8193b9d148
Tests a real connect, a real accept, and really sending and receiving a message over a UNIX socket. Brings coverage of protocol.py up to ~93%. Signed-off-by: John Snow <jsnow@redhat.com> Message-id: 20210915162955.333025-27-jsnow@redhat.com Signed-off-by: John Snow <jsnow@redhat.com>
584 lines
18 KiB
Python
584 lines
18 KiB
Python
import asyncio
|
|
from contextlib import contextmanager
|
|
import os
|
|
import socket
|
|
from tempfile import TemporaryDirectory
|
|
|
|
import avocado
|
|
|
|
from qemu.aqmp import ConnectError, Runstate
|
|
from qemu.aqmp.protocol import AsyncProtocol, StateError
|
|
from qemu.aqmp.util import asyncio_run, create_task
|
|
|
|
|
|
class NullProtocol(AsyncProtocol[None]):
    """
    NullProtocol is a test mockup of an AsyncProtocol implementation.

    It adds a fake_session instance variable that enables a code path
    that bypasses the actual connection logic, but still allows the
    reader/writers to start.

    Because the message type is defined as None, an asyncio.Event named
    'trigger_input' is created that prevents the reader from yielding
    None incessantly; this event can be poked to simulate an incoming
    message.

    For testing symmetry with do_recv, an interface is added to "send" a
    Null message.

    For testing purposes, a "simulate_disconnect" method is also added
    which allows us to trigger a bottom half disconnect without
    injecting any real errors into the reader/writer loops; in essence
    it performs exactly half of what disconnect() normally does.
    """
    def __init__(self, name=None):
        # When True, _do_accept()/_do_connect() skip the real transport
        # setup performed by the base class.
        self.fake_session = False
        # Only declared here; the Event itself is created per-session in
        # _establish_session().
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        """Create a fresh trigger_input event, then defer to the base class."""
        # Created before delegating, presumably because the base
        # implementation starts the reader that waits on it.
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_accept(self, address, ssl=None):
        """Accept for real unless a fake session was requested."""
        if not self.fake_session:
            await super()._do_accept(address, ssl)

    async def _do_connect(self, address, ssl=None):
        """Connect for real unless a fake session was requested."""
        if not self.fake_session:
            await super()._do_connect(address, ssl)

    async def _do_recv(self) -> None:
        """Wait until trigger_input is set, re-arm it, and "receive" None."""
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        """Sending a Null message is a no-op."""
        pass

    async def send_msg(self) -> None:
        """Queue a Null message for the writer."""
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulates a bottom-half disconnect.

        This method schedules a disconnection but does not wait for it
        to complete. This is used to put the loop into the DISCONNECTING
        state without fully quiescing it back to IDLE. This is normally
        something you cannot coax AsyncProtocol to do on purpose, but it
        will be similar to what happens with an unhandled Exception in
        the reader/writer.

        Under normal circumstances, the library design requires you to
        await on disconnect(), which awaits the disconnect task and
        returns bottom half errors as a pre-condition to allowing the
        loop to return back to IDLE.
        """
        self._schedule_disconnect()
|
|
|
|
|
|
class LineProtocol(AsyncProtocol[str]):
    """
    A minimal newline-delimited text protocol for real socket tests.

    Outgoing messages are encoded and terminated with a newline; every
    incoming line is decoded and also recorded in ``rx_history``.
    """
    def __init__(self, name=None):
        super().__init__(name)
        # Decoded incoming messages, in arrival order.
        self.rx_history = []

    async def _do_recv(self) -> str:
        line = (await self._readline()).decode()
        self.rx_history.append(line)
        return line

    def _do_send(self, msg: str) -> None:
        assert self._writer is not None
        payload = msg.encode() + b'\n'
        self._writer.write(payload)

    async def send_msg(self, msg: str) -> None:
        """Queue *msg* for transmission by the writer."""
        await self._outgoing.put(msg)
|
|
|
|
|
|
def run_as_task(coro, allow_cancellation=False):
    """
    Schedule *coro* as a task and return that task.

    When *allow_cancellation* is true, an asyncio.CancelledError raised
    while awaiting *coro* is swallowed, so the task completes normally
    instead of re-raising the cancellation to whoever awaits it.
    """
    async def _wrapped():
        try:
            await coro
        except asyncio.CancelledError:
            if not allow_cancellation:
                raise

    return create_task(_wrapped())
|
|
|
|
|
|
@contextmanager
def jammed_socket():
    """
    Opens up a random unused TCP port on localhost, then jams it.

    A listening socket is bound to an ephemeral port on 127.0.0.1 and
    then filled with connections that are never accepted, so that later
    connection attempts to the yielded address hang (which the
    hanging-connection tests rely on).

    All sockets created here are closed on exit, including when setup
    fails part-way through or the body of the ``with`` raises.

    :yield: The ``(host, port)`` address of the jammed listener.
    """
    socks = []

    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Track the socket right away so the finally clause closes it
        # even if any of the setup calls below raise.
        socks.append(sock)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('127.0.0.1', 0))
        sock.listen(1)
        address = sock.getsockname()

        # It takes *two* un-accepted connections to start jamming the
        # socket. Presumably the kernel admits backlog + 1 pending
        # connections (observed on Linux; not verified elsewhere).
        for _ in range(2):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            socks.append(sock)
            sock.connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()
|
|
|
|
|
|
class Smoke(avocado.Test):
    """Sanity checks on a freshly constructed, idle NullProtocol."""

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        expected = "<NullProtocol runstate=IDLE>"
        self.assertEqual(repr(self.proto), expected)

    def testRunstate(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    def testDefaultName(self):
        self.assertEqual(self.proto.name, None)

    def testLogger(self):
        self.assertEqual(self.proto.logger.name, 'qemu.aqmp.protocol')

    def testName(self):
        # A named instance should carry its name into its logger and repr.
        self.proto = NullProtocol('Steve')
        named = self.proto

        self.assertEqual(named.name, 'Steve')
        self.assertEqual(named.logger.name, 'qemu.aqmp.protocol.Steve')
        self.assertEqual(
            repr(named),
            "<NullProtocol name='Steve' runstate=IDLE>",
        )
|
|
|
|
|
|
class TestBase(avocado.Test):
    """
    Common fixture for asynchronous NullProtocol tests.

    Every test starts with a fresh NullProtocol (named after the test
    class) in the IDLE state, and must leave it IDLE again by the time
    tearDown() runs.
    """

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        # Task created by _watch_runstates(); awaited in _asyncTearDown().
        self.runstate_watcher = None

    def tearDown(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        pass

    async def _asyncTearDown(self):
        # Waiting on the watcher surfaces any assertion it raised, and
        # only completes once every expected transition has been seen.
        if self.runstate_watcher:
            await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; adds SetUp and TearDown to async tests.

        The returned coroutine function enables asyncio debug mode on the
        current event loop, then runs _asyncSetUp(), the test body, and
        _asyncTearDown() in that order. _asyncTearDown() is not run if
        the test body raises.

        functools.wraps keeps the decorated test's name and docstring.
        """
        from functools import wraps

        @wraps(async_test_method)
        async def _wrapper(self, *args, **kwargs):
            loop = asyncio.get_event_loop()
            loop.set_debug(True)

            await self._asyncSetUp()
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # The states we expect a "bad" connect/accept attempt to transition through
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # The states we expect a "good" session to transition through
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        This launches a task alongside (most) tests below to confirm that
        the sequence of runstate changes that occur is exactly as
        anticipated.

        :param states: The Runstates expected, in order.
        """
        async def _watcher():
            for state in states:
                new_state = await self.proto.runstate_changed()
                self.assertEqual(
                    new_state,
                    state,
                    msg=f"Expected state '{state.name}'",
                )

        self.runstate_watcher = create_task(_watcher())
        # Kick the loop and force the task to block on the event.
        await asyncio.sleep(0)
|
|
|
|
|
|
class State(TestBase):
    """Runstate transition tests that need no connection at all."""

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """
        Verify that disconnect() is harmless while already disconnected.

        Even starting from IDLE, the protocol is expected to pass through
        DISCONNECTING on its way back to IDLE.
        """
        expected = (Runstate.DISCONNECTING, Runstate.IDLE)
        await self._watch_runstates(*expected)
        await self.proto.disconnect()
|
|
|
|
|
|
class Connect(TestBase):
    """
    Tests primarily related to calling connect().
    """
    async def _bad_connection(self, family: str):
        """
        Make a connection attempt that is expected to fail right away.

        :param family: 'INET' for an IP target, 'UNIX' for a socket path.
        """
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.connect(('127.0.0.1', 0))
        elif family == 'UNIX':
            # /dev/null exists, but it is not a listening UNIX socket.
            await self.proto.connect('/dev/null')

    async def _hanging_connection(self):
        """Make a connection attempt that is expected to block indefinitely."""
        with jammed_socket() as addr:
            await self.proto.connect(addr)

    async def _bad_connection_test(self, family: str):
        """
        Check that a bad connection attempt raises ConnectError wrapping
        an OSError, while the runstate walks through BAD_CONNECTION_STATES.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as context:
            await self._bad_connection(family)

        self.assertIsInstance(context.exception.exc, OSError)
        self.assertEqual(
            context.exception.error_message,
            "Failed to establish connection"
        )

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # A bare connect()/accept() coroutine cannot be cancelled outright,
        # as it isn't a task. However, we can wrap it in a task and cancel
        # *that*.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        # This is insider baseball, but the connection attempt has
        # yielded *just* before the actual connection attempt, so kick
        # the loop to make sure it's truly wedged.
        await asyncio.sleep(0)

        task.cancel()
        await task

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection())

        # More insider baseball: to improve the speed of this test while
        # guaranteeing that the connection even gets a chance to start,
        # verify that the connection hangs *first*, then await the
        # result of the task with a nearly-zero timeout.

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(task, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        with self.assertRaises(StateError) as context:
            await self._bad_connection('UNIX')

        self.assertEqual(
            context.exception.error_message,
            "NullProtocol is currently connecting."
        )
        self.assertEqual(context.exception.state, Runstate.CONNECTING)
        self.assertEqual(context.exception.required, Runstate.IDLE)

        task.cancel()
        await task

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        # Kick the loop to coerce the state change
        await asyncio.sleep(0)
        # Use the test-case assertion (not a bare assert) so the check is
        # never optimized away and failures carry a useful message.
        self.assertEqual(self.proto.runstate, Runstate.CONNECTING)

        # We already missed the transition to CONNECTING
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        task.cancel()
        await task
|
|
|
|
|
|
class Accept(Connect):
    """
    All of the same tests as Connect, but using the accept() interface.
    """
    async def _bad_connection(self, family: str):
        """
        Make an accept() attempt that is expected to fail right away.

        :param family: 'INET' for an IP address, 'UNIX' for a socket path.
        """
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            # 192.0.2.1 is reserved for documentation (RFC 5737, TEST-NET-1)
            # and is not expected to be assigned to any local interface, so
            # listening on it should fail with an OSError. A literal address
            # avoids the DNS lookup a hostname would need, so the test does
            # not depend on network access or resolver latency.
            await self.proto.accept(('192.0.2.1', 1))
        elif family == 'UNIX':
            # /dev/null exists and is not a socket; binding there should fail.
            await self.proto.accept('/dev/null')

    async def _hanging_connection(self):
        """Listen on a fresh UNIX socket that no client ever connects to."""
        with TemporaryDirectory(suffix='.aqmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            await self.proto.accept(sock)
|
|
|
|
|
|
class FakeSession(TestBase):
    """
    Lifecycle tests that run NullProtocol in fake-session mode.

    With fake_session enabled, NullProtocol skips the base class's real
    _do_accept()/_do_connect() logic, so the placeholder path passed to
    accept()/connect() is never actually used to open a connection.
    Every test is expected to walk through GOOD_CONNECTION_STATES.
    """

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        await super()._asyncTearDown()

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Drive a no-op session to RUNNING through connect()."""
        await self.proto.connect('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Drive a no-op session to RUNNING through accept()."""
        await self.proto.accept('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Poke trigger_input and check that one Null message is logged."""
        await self.proto.accept('/not/a/real/path')

        logger_name = self.proto.logger.name
        with self.assertLogs(logger_name, level='DEBUG') as captured:
            trigger = self.proto.trigger_input
            trigger.set()
            trigger.clear()
            await asyncio.sleep(0)  # Give the reader a turn.

        self.assertEqual(captured.output, [f"DEBUG:{logger_name}:<-- None"])

    @TestBase.async_test
    async def testFakeSend(self):
        """Queue a Null message and check that the writer logs it."""
        await self.proto.accept('/not/a/real/path')

        logger_name = self.proto.logger.name
        with self.assertLogs(logger_name, level='DEBUG') as captured:
            # Cheat: Send a Null message to nobody.
            await self.proto.send_msg()
            # A queue.put alone does not yield, so give the writer a turn.
            await asyncio.sleep(0)

        self.assertEqual(captured.output, [f"DEBUG:{logger_name}:--> None"])

    async def _prod_session_api(
            self,
            current_state: Runstate,
            error_message: str,
            accept: bool = True
    ):
        """
        Assert that accept() (or connect(), when *accept* is False) is
        refused with a StateError describing *current_state*.
        """
        with self.assertRaises(StateError) as context:
            start = self.proto.accept if accept else self.proto.connect
            await start('/not/a/real/path')

        err = context.exception
        self.assertEqual(err.error_message, error_message)
        self.assertEqual(err.state, current_state)
        self.assertEqual(err.required, Runstate.IDLE)

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """A second accept() must be rejected while RUNNING."""
        await self.proto.accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """A connect() must be rejected while RUNNING."""
        await self.proto.accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=False,
        )

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """An accept() must be rejected while DISCONNECTING."""
        await self.proto.accept('/not/a/real/path')

        # Cheat: strand the protocol in DISCONNECTING.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """A connect() must be rejected while DISCONNECTING."""
        await self.proto.accept('/not/a/real/path')

        # Cheat: strand the protocol in DISCONNECTING.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=False,
        )
|
|
|
|
|
|
class SimpleSession(TestBase):
    """
    Tests that establish a real session over a UNIX socket.

    self.proto (a NullProtocol in normal, non-fake mode) is the client;
    self.server (a LineProtocol) is the server.
    """

    def setUp(self):
        super().setUp()
        server_name = type(self).__name__ + '-server'
        self.server = LineProtocol(server_name)

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        try:
            await self.server.disconnect()
        except EOFError:
            # The client hung up first, so the server may report EOF.
            pass
        await super()._asyncTearDown()

    @TestBase.async_test
    async def testSmoke(self):
        """Connect the client to a listening server over a UNIX socket."""
        with TemporaryDirectory(suffix='.aqmp') as tmpdir:
            sock_path = os.path.join(
                tmpdir, type(self.proto).__name__ + ".sock"
            )
            server_task = create_task(self.server.accept(sock_path))

            # Yield once so the server task can start listening before
            # the client attempts to connect.
            await asyncio.sleep(0)
            await self.proto.connect(sock_path)
|