python/aqmp: add runstate state machine to AsyncProtocol

This serves a few purposes:

1. Protect interfaces when it's not safe to call them (via @require)

2. Add an interface by which an async client can determine if the state
has changed, for the purposes of connection management.

Signed-off-by: John Snow <jsnow@redhat.com>
Reviewed-by: Eric Blake <eblake@redhat.com>
Message-id: 20210915162955.333025-7-jsnow@redhat.com
Signed-off-by: John Snow <jsnow@redhat.com>
This commit is contained in:
John Snow 2021-09-15 12:29:34 -04:00
parent 4ccaab0377
commit c58b42e095
2 changed files with 160 additions and 5 deletions

View File

@ -22,12 +22,16 @@ managing QMP events.
# the COPYING file in the top-level directory. # the COPYING file in the top-level directory.
from .error import AQMPError from .error import AQMPError
from .protocol import ConnectError from .protocol import ConnectError, Runstate, StateError
# The order of these fields impact the Sphinx documentation order. # The order of these fields impact the Sphinx documentation order.
__all__ = ( __all__ = (
# Classes
'Runstate',
# Exceptions, most generic to most explicit # Exceptions, most generic to most explicit
'AQMPError', 'AQMPError',
'StateError',
'ConnectError', 'ConnectError',
) )

View File

@ -12,11 +12,10 @@ class.
import asyncio import asyncio
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from enum import Enum
from functools import wraps
from ssl import SSLContext from ssl import SSLContext
# import exceptions will be removed in a forthcoming commit. from typing import (
# The problem stems from pylint/flake8 believing that 'Any'
# is unused because of its only use in a string-quoted type.
from typing import ( # pylint: disable=unused-import # noqa
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
@ -26,6 +25,7 @@ from typing import ( # pylint: disable=unused-import # noqa
Tuple, Tuple,
TypeVar, TypeVar,
Union, Union,
cast,
) )
from .error import AQMPError from .error import AQMPError
@ -44,6 +44,20 @@ _TaskFN = Callable[[], Awaitable[None]] # aka ``async def func() -> None``
_FutureT = TypeVar('_FutureT', bound=Optional['asyncio.Future[Any]']) _FutureT = TypeVar('_FutureT', bound=Optional['asyncio.Future[Any]'])
class Runstate(Enum):
"""Protocol session runstate."""
#: Fully quiesced and disconnected.
IDLE = 0
#: In the process of connecting or establishing a session.
CONNECTING = 1
#: Fully connected and active session.
RUNNING = 2
#: In the process of disconnecting.
#: Runstate may be returned to `IDLE` by calling `disconnect()`.
DISCONNECTING = 3
class ConnectError(AQMPError): class ConnectError(AQMPError):
""" """
Raised when the initial connection process has failed. Raised when the initial connection process has failed.
@ -65,6 +79,76 @@ class ConnectError(AQMPError):
return f"{self.error_message}: {self.exc!s}" return f"{self.error_message}: {self.exc!s}"
class StateError(AQMPError):
"""
An API command (connect, execute, etc) was issued at an inappropriate time.
This error is raised when a command like
:py:meth:`~AsyncProtocol.connect()` is issued at an inappropriate
time.
:param error_message: Human-readable string describing the state violation.
:param state: The actual `Runstate` seen at the time of the violation.
:param required: The `Runstate` required to process this command.
"""
def __init__(self, error_message: str,
state: Runstate, required: Runstate):
super().__init__(error_message)
self.error_message = error_message
self.state = state
self.required = required
F = TypeVar('F', bound=Callable[..., Any]) # pylint: disable=invalid-name
# Don't Panic.
def require(required_state: Runstate) -> Callable[[F], F]:
"""
Decorator: protect a method so it can only be run in a certain `Runstate`.
:param required_state: The `Runstate` required to invoke this method.
:raise StateError: When the required `Runstate` is not met.
"""
def _decorator(func: F) -> F:
# _decorator is the decorator that is built by calling the
# require() decorator factory; e.g.:
#
# @require(Runstate.IDLE) def foo(): ...
# will replace 'foo' with the result of '_decorator(foo)'.
@wraps(func)
def _wrapper(proto: 'AsyncProtocol[Any]',
*args: Any, **kwargs: Any) -> Any:
# _wrapper is the function that gets executed prior to the
# decorated method.
name = type(proto).__name__
if proto.runstate != required_state:
if proto.runstate == Runstate.CONNECTING:
emsg = f"{name} is currently connecting."
elif proto.runstate == Runstate.DISCONNECTING:
emsg = (f"{name} is disconnecting."
" Call disconnect() to return to IDLE state.")
elif proto.runstate == Runstate.RUNNING:
emsg = f"{name} is already connected and running."
elif proto.runstate == Runstate.IDLE:
emsg = f"{name} is disconnected and idle."
else:
assert False
raise StateError(emsg, proto.runstate, required_state)
# No StateError, so call the wrapped method.
return func(proto, *args, **kwargs)
# Return the decorated method;
# Transforming Func to Decorated[Func].
return cast(F, _wrapper)
# Return the decorator instance from the decorator factory. Phew!
return _decorator
class AsyncProtocol(Generic[T]): class AsyncProtocol(Generic[T]):
""" """
AsyncProtocol implements a generic async message-based protocol. AsyncProtocol implements a generic async message-based protocol.
@ -118,7 +202,24 @@ class AsyncProtocol(Generic[T]):
#: exit. #: exit.
self._dc_task: Optional[asyncio.Future[None]] = None self._dc_task: Optional[asyncio.Future[None]] = None
self._runstate = Runstate.IDLE
self._runstate_changed: Optional[asyncio.Event] = None
@property # @upper_half
def runstate(self) -> Runstate:
"""The current `Runstate` of the connection."""
return self._runstate
@upper_half @upper_half
async def runstate_changed(self) -> Runstate:
"""
Wait for the `runstate` to change, then return that runstate.
"""
await self._runstate_event.wait()
return self.runstate
@upper_half
@require(Runstate.IDLE)
async def connect(self, address: Union[str, Tuple[str, int]], async def connect(self, address: Union[str, Tuple[str, int]],
ssl: Optional[SSLContext] = None) -> None: ssl: Optional[SSLContext] = None) -> None:
""" """
@ -152,6 +253,30 @@ class AsyncProtocol(Generic[T]):
# Section: Session machinery # Section: Session machinery
# -------------------------- # --------------------------
@property
def _runstate_event(self) -> asyncio.Event:
# asyncio.Event() objects should not be created prior to entrance into
# an event loop, so we can ensure we create it in the correct context.
# Create it on-demand *only* at the behest of an 'async def' method.
if not self._runstate_changed:
self._runstate_changed = asyncio.Event()
return self._runstate_changed
@upper_half
@bottom_half
def _set_state(self, state: Runstate) -> None:
"""
Change the `Runstate` of the protocol connection.
Signals the `runstate_changed` event.
"""
if state == self._runstate:
return
self._runstate = state
self._runstate_event.set()
self._runstate_event.clear()
@upper_half @upper_half
async def _new_session(self, async def _new_session(self,
address: Union[str, Tuple[str, int]], address: Union[str, Tuple[str, int]],
@ -176,6 +301,8 @@ class AsyncProtocol(Generic[T]):
protocol-level failure occurs while establishing a new protocol-level failure occurs while establishing a new
session, the wrapped error may also be an `AQMPError`. session, the wrapped error may also be an `AQMPError`.
""" """
assert self.runstate == Runstate.IDLE
try: try:
phase = "connection" phase = "connection"
await self._establish_connection(address, ssl) await self._establish_connection(address, ssl)
@ -185,6 +312,7 @@ class AsyncProtocol(Generic[T]):
except BaseException as err: except BaseException as err:
emsg = f"Failed to establish {phase}" emsg = f"Failed to establish {phase}"
# Reset from CONNECTING back to IDLE.
await self.disconnect() await self.disconnect()
# NB: CancelledError is not a BaseException before Python 3.8 # NB: CancelledError is not a BaseException before Python 3.8
@ -197,6 +325,8 @@ class AsyncProtocol(Generic[T]):
# Raise BaseExceptions un-wrapped, they're more important. # Raise BaseExceptions un-wrapped, they're more important.
raise raise
assert self.runstate == Runstate.RUNNING
@upper_half @upper_half
async def _establish_connection( async def _establish_connection(
self, self,
@ -211,6 +341,14 @@ class AsyncProtocol(Generic[T]):
UNIX socket path or TCP address/port. UNIX socket path or TCP address/port.
:param ssl: SSL context to use, if any. :param ssl: SSL context to use, if any.
""" """
assert self.runstate == Runstate.IDLE
self._set_state(Runstate.CONNECTING)
# Allow runstate watchers to witness 'CONNECTING' state; some
# failures in the streaming layer are synchronous and will not
# otherwise yield.
await asyncio.sleep(0)
await self._do_connect(address, ssl) await self._do_connect(address, ssl)
@upper_half @upper_half
@ -240,6 +378,8 @@ class AsyncProtocol(Generic[T]):
own negotiations here. The Runstate will be RUNNING upon own negotiations here. The Runstate will be RUNNING upon
successful conclusion. successful conclusion.
""" """
assert self.runstate == Runstate.CONNECTING
self._outgoing = asyncio.Queue() self._outgoing = asyncio.Queue()
reader_coro = self._bh_loop_forever(self._bh_recv_message) reader_coro = self._bh_loop_forever(self._bh_recv_message)
@ -253,6 +393,9 @@ class AsyncProtocol(Generic[T]):
self._writer_task, self._writer_task,
) )
self._set_state(Runstate.RUNNING)
await asyncio.sleep(0) # Allow runstate_event to process
@upper_half @upper_half
@bottom_half @bottom_half
def _schedule_disconnect(self) -> None: def _schedule_disconnect(self) -> None:
@ -266,6 +409,7 @@ class AsyncProtocol(Generic[T]):
It can be invoked no matter what the `runstate` is. It can be invoked no matter what the `runstate` is.
""" """
if not self._dc_task: if not self._dc_task:
self._set_state(Runstate.DISCONNECTING)
self._dc_task = create_task(self._bh_disconnect()) self._dc_task = create_task(self._bh_disconnect())
@upper_half @upper_half
@ -281,6 +425,7 @@ class AsyncProtocol(Generic[T]):
:raise Exception: :raise Exception:
Arbitrary exception re-raised on behalf of the reader/writer. Arbitrary exception re-raised on behalf of the reader/writer.
""" """
assert self.runstate == Runstate.DISCONNECTING
assert self._dc_task assert self._dc_task
aws: List[Awaitable[object]] = [self._dc_task] aws: List[Awaitable[object]] = [self._dc_task]
@ -295,6 +440,7 @@ class AsyncProtocol(Generic[T]):
await all_defined_tasks # Raise Exceptions from the bottom half. await all_defined_tasks # Raise Exceptions from the bottom half.
finally: finally:
self._cleanup() self._cleanup()
self._set_state(Runstate.IDLE)
@upper_half @upper_half
def _cleanup(self) -> None: def _cleanup(self) -> None:
@ -306,6 +452,7 @@ class AsyncProtocol(Generic[T]):
assert (task is None) or task.done() assert (task is None) or task.done()
return None if (task and task.done()) else task return None if (task and task.done()) else task
assert self.runstate == Runstate.DISCONNECTING
self._dc_task = _paranoid_task_erase(self._dc_task) self._dc_task = _paranoid_task_erase(self._dc_task)
self._reader_task = _paranoid_task_erase(self._reader_task) self._reader_task = _paranoid_task_erase(self._reader_task)
self._writer_task = _paranoid_task_erase(self._writer_task) self._writer_task = _paranoid_task_erase(self._writer_task)
@ -314,6 +461,9 @@ class AsyncProtocol(Generic[T]):
self._reader = None self._reader = None
self._writer = None self._writer = None
# NB: _runstate_changed cannot be cleared because we still need it to
# send the final runstate changed event ...!
# ---------------------------- # ----------------------------
# Section: Bottom Half methods # Section: Bottom Half methods
# ---------------------------- # ----------------------------
@ -328,6 +478,7 @@ class AsyncProtocol(Generic[T]):
it is free to wait on any pending actions that may still need to it is free to wait on any pending actions that may still need to
occur in either the reader or writer tasks. occur in either the reader or writer tasks.
""" """
assert self.runstate == Runstate.DISCONNECTING
def _done(task: Optional['asyncio.Future[Any]']) -> bool: def _done(task: Optional['asyncio.Future[Any]']) -> bool:
return task is not None and task.done() return task is not None and task.done()