Python patches

Hopefully, this fixes the race conditions witnessed in the NetBSD VM tests.
 -----BEGIN PGP SIGNATURE-----
 
 iQIzBAABCAAdFiEE+ber27ys35W+dsvQfe+BBqr8OQ4FAmImg9IACgkQfe+BBqr8
 OQ6pMxAAgilUH8OIJzJfV2C/1qWM2Hzrl/jwTUEuYxmMYacdL9kJvR3NJ4CMv5Nn
 996TyJROK+QDQoVsUuoEjkdrezbI4UDoixM9ku7KWAUMEsxXmRR5kcclSkCWX4HX
 o+My1UR+6LxPgH894JMTcnKzH9gDHkU0Aww/nu5LumJoVB12Gu1iLif/2JneQKFB
 rWaQu+8DHGH7Jv9s0ShrmkDYwtwq5XXGtefR6DEdo5xGGCjzYrYr80Frg7R1OYVU
 xlGV0MbLjTmePM5F4ZxiQGohFSOY6QsraxDMiqVOc+gBjz2J8l+7i8AA3Zirwotz
 V9BYPDRZ9pZV3ERDPqh0L3homsmk2wepkXi6YAz9/DMn0pDHizmvntPCCdhzBXyH
 cA63+QayvCYADDoHkUbMT5jc7X6ayfauj7ZkJPzfr7YtzYKs6k0bDmtgJBMyNRj1
 pHILnv5oGnnVz4kO5W98oV2jijAdqi9or3+4B2woeUmaROoQJA0ObU35ke961KNE
 n66kTOibgMj/TQmDE1veBgNvCxY0cRE+ZB7SYL7ZaqvavEwfeYQRz851sDxTdiFF
 v5b/Ls8IDKPbU8qPLDzTQrAy19CWtOkJTD4b4/6WAv9K0SAxghQEyoCUCZbk+PLt
 xGeCyxImTC7XaqFlops9WzBTK3jz/7m9EvgfJNRKj8QZ49yxCBo=
 =0ieN
 -----END PGP SIGNATURE-----

Merge remote-tracking branch 'remotes/jsnow-gitlab/tags/python-pull-request' into staging

Python patches

Hopefully, this fixes the race conditions witnessed in the NetBSD VM tests.

# gpg: Signature made Mon 07 Mar 2022 22:14:42 GMT
# gpg:                using RSA key F9B7ABDBBCACDF95BE76CBD07DEF8106AAFC390E
# gpg: Good signature from "John Snow (John Huston) <jsnow@redhat.com>" [full]
# Primary key fingerprint: FAEB 9711 A12C F475 812F  18F2 88A9 064D 1835 61EB
#      Subkey fingerprint: F9B7 ABDB BCAC DF95 BE76  CBD0 7DEF 8106 AAFC 390E

* remotes/jsnow-gitlab/tags/python-pull-request:
  scripts/qmp-shell-wrap: Fix import path
  python/aqmp: drop _bind_hack()
  python/aqmp: fix race condition in legacy.py
  python/aqmp: add start_server() and accept() methods
  python/aqmp: stop the server during disconnect()
  python/aqmp: refactor _do_accept() into two distinct steps
  python/aqmp: squelch pylint warning for too many lines
  python/aqmp: split _client_connected_cb() out as _incoming()
  python/aqmp: remove _new_session and _establish_connection
  python/aqmp: rename 'accept()' to 'start_server_and_accept()'
  python/aqmp: add _session_guard()

Signed-off-by: Peter Maydell <peter.maydell@linaro.org>
This commit is contained in:
Peter Maydell 2022-03-08 19:31:05 +00:00
commit 2ad7624900
4 changed files with 272 additions and 171 deletions

View File

@ -57,7 +57,7 @@ class QEMUMonitorProtocol(qemu.qmp.QEMUMonitorProtocol):
self._timeout: Optional[float] = None
if server:
self._aqmp._bind_hack(address) # pylint: disable=protected-access
self._sync(self._aqmp.start_server(self._address))
_T = TypeVar('_T')
@ -90,10 +90,7 @@ class QEMUMonitorProtocol(qemu.qmp.QEMUMonitorProtocol):
self._aqmp.await_greeting = True
self._aqmp.negotiate = True
self._sync(
self._aqmp.accept(self._address),
timeout
)
self._sync(self._aqmp.accept(), timeout)
ret = self._get_greeting()
assert ret is not None

View File

@ -10,12 +10,14 @@ In this package, it is used as the implementation for the `QMPClient`
class.
"""
# It's all the docstrings ... ! It's long for a good reason ^_^;
# pylint: disable=too-many-lines
import asyncio
from asyncio import StreamReader, StreamWriter
from enum import Enum
from functools import wraps
import logging
import socket
from ssl import SSLContext
from typing import (
Any,
@ -239,8 +241,9 @@ class AsyncProtocol(Generic[T]):
self._runstate = Runstate.IDLE
self._runstate_changed: Optional[asyncio.Event] = None
# Workaround for bind()
self._sock: Optional[socket.socket] = None
# Server state for start_server() and _incoming()
self._server: Optional[asyncio.AbstractServer] = None
self._accepted: Optional[asyncio.Event] = None
def __repr__(self) -> str:
cls_name = type(self).__name__
@ -265,21 +268,90 @@ class AsyncProtocol(Generic[T]):
@upper_half
@require(Runstate.IDLE)
async def accept(self, address: SocketAddrT,
ssl: Optional[SSLContext] = None) -> None:
async def start_server_and_accept(
self, address: SocketAddrT,
ssl: Optional[SSLContext] = None
) -> None:
"""
Accept a connection and begin processing message queues.
If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
This method is precisely equivalent to calling `start_server()`
followed by `accept()`.
:param address:
Address to listen to; UNIX socket path or TCP address/port.
Address to listen on; UNIX socket path or TCP address/port.
:param ssl: SSL context to use, if any.
:raise StateError: When the `Runstate` is not `IDLE`.
:raise ConnectError: If a connection could not be accepted.
:raise ConnectError:
When a connection or session cannot be established.
This exception will wrap a more concrete one. In most cases,
the wrapped exception will be `OSError` or `EOFError`. If a
protocol-level failure occurs while establishing a new
session, the wrapped error may also be a `QMPError`.
"""
await self._new_session(address, ssl, accept=True)
await self.start_server(address, ssl)
await self.accept()
assert self.runstate == Runstate.RUNNING
@upper_half
@require(Runstate.IDLE)
async def start_server(self, address: SocketAddrT,
ssl: Optional[SSLContext] = None) -> None:
"""
Start listening for an incoming connection, but do not wait for a peer.
This method starts listening for an incoming connection, but
does not block waiting for a peer. This call will return
immediately after binding and listening on a socket. A later
call to `accept()` must be made in order to finalize the
incoming connection.
:param address:
Address to listen on; UNIX socket path or TCP address/port.
:param ssl: SSL context to use, if any.
:raise StateError: When the `Runstate` is not `IDLE`.
:raise ConnectError:
When the server could not start listening on this address.
This exception will wrap a more concrete one. In most cases,
the wrapped exception will be `OSError`.
"""
await self._session_guard(
self._do_start_server(address, ssl),
'Failed to establish connection')
assert self.runstate == Runstate.CONNECTING
@upper_half
@require(Runstate.CONNECTING)
async def accept(self) -> None:
"""
Accept an incoming connection and begin processing message queues.
If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
:raise StateError: When the `Runstate` is not `CONNECTING`.
:raise QMPError: When `start_server()` was not called yet.
:raise ConnectError:
When a connection or session cannot be established.
This exception will wrap a more concrete one. In most cases,
the wrapped exception will be `OSError` or `EOFError`. If a
protocol-level failure occurs while establishing a new
session, the wrapped error may also be a `QMPError`.
"""
if self._accepted is None:
raise QMPError("Cannot call accept() before start_server().")
await self._session_guard(
self._do_accept(),
'Failed to establish connection')
await self._session_guard(
self._establish_session(),
'Failed to establish session')
assert self.runstate == Runstate.RUNNING
@upper_half
@require(Runstate.IDLE)
@ -295,9 +367,21 @@ class AsyncProtocol(Generic[T]):
:param ssl: SSL context to use, if any.
:raise StateError: When the `Runstate` is not `IDLE`.
:raise ConnectError: If a connection cannot be made to the server.
:raise ConnectError:
When a connection or session cannot be established.
This exception will wrap a more concrete one. In most cases,
the wrapped exception will be `OSError` or `EOFError`. If a
protocol-level failure occurs while establishing a new
session, the wrapped error may also be a `QMPError`.
"""
await self._new_session(address, ssl)
await self._session_guard(
self._do_connect(address, ssl),
'Failed to establish connection')
await self._session_guard(
self._establish_session(),
'Failed to establish session')
assert self.runstate == Runstate.RUNNING
@upper_half
async def disconnect(self) -> None:
@ -317,6 +401,62 @@ class AsyncProtocol(Generic[T]):
# Section: Session machinery
# --------------------------
async def _session_guard(self, coro: Awaitable[None], emsg: str) -> None:
"""
Async guard function used to roll back to `IDLE` on any error.
On any Exception, the state machine will be reset back to
`IDLE`. Most Exceptions will be wrapped with `ConnectError`, but
`BaseException` events will be left alone (This includes
asyncio.CancelledError, even prior to Python 3.8).
:param coro:
The awaitable to run under the guard.
:param emsg:
Human-readable string describing what connection phase failed.
:raise BaseException:
When `BaseException` occurs in the guarded block.
:raise ConnectError:
When any other error is encountered in the guarded block.
"""
# Note: After Python 3.6 support is removed, this should be an
# @asynccontextmanager instead of accepting an awaitable.
try:
await coro
except BaseException as err:
self.logger.error("%s: %s", emsg, exception_summary(err))
self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
try:
# Reset the runstate back to IDLE.
await self.disconnect()
except:
# We don't expect any Exceptions from the disconnect function
# here, because we failed to connect in the first place.
# The disconnect() function is intended to perform
# only cannot-fail cleanup here, but you never know.
emsg = (
"Unexpected bottom half exception. "
"This is a bug in the QMP library. "
"Please report it to <qemu-devel@nongnu.org> and "
"CC: John Snow <jsnow@redhat.com>."
)
self.logger.critical("%s:\n%s\n", emsg, pretty_traceback())
raise
# CancelledError is an Exception with special semantic meaning;
# We do NOT want to wrap it up under ConnectError.
# NB: CancelledError is not a BaseException before Python 3.8
if isinstance(err, asyncio.CancelledError):
raise
# Any other kind of error can be treated as some kind of connection
# failure broadly. Inspect the 'exc' field to explore the root
# cause in greater detail.
if isinstance(err, Exception):
raise ConnectError(emsg, err) from err
# Raise BaseExceptions un-wrapped, they're more important.
raise
@property
def _runstate_event(self) -> asyncio.Event:
# asyncio.Event() objects should not be created prior to entrance into
@ -343,127 +483,64 @@ class AsyncProtocol(Generic[T]):
self._runstate_event.set()
self._runstate_event.clear()
@upper_half
async def _new_session(self,
address: SocketAddrT,
ssl: Optional[SSLContext] = None,
accept: bool = False) -> None:
@bottom_half
async def _stop_server(self) -> None:
"""
Establish a new connection and initialize the session.
Connect or accept a new connection, then begin the protocol
session machinery. If this call fails, `runstate` is guaranteed
to be set back to `IDLE`.
:param address:
Address to connect to/listen on;
UNIX socket path or TCP address/port.
:param ssl: SSL context to use, if any.
:param accept: Accept a connection instead of connecting when `True`.
:raise ConnectError:
When a connection or session cannot be established.
This exception will wrap a more concrete one. In most cases,
the wrapped exception will be `OSError` or `EOFError`. If a
protocol-level failure occurs while establishing a new
session, the wrapped error may also be a `QMPError`.
Stop listening for / accepting new incoming connections.
"""
assert self.runstate == Runstate.IDLE
if self._server is None:
return
try:
phase = "connection"
await self._establish_connection(address, ssl, accept)
self.logger.debug("Stopping server.")
self._server.close()
await self._server.wait_closed()
self.logger.debug("Server stopped.")
finally:
self._server = None
phase = "session"
await self._establish_session()
@bottom_half # However, it does not run from the R/W tasks.
async def _incoming(self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter) -> None:
"""
Accept an incoming connection and signal the upper_half.
except BaseException as err:
emsg = f"Failed to establish {phase}"
self.logger.error("%s: %s", emsg, exception_summary(err))
self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
try:
# Reset from CONNECTING back to IDLE.
await self.disconnect()
except:
emsg = "Unexpected bottom half exception"
self.logger.critical("%s:\n%s\n", emsg, pretty_traceback())
raise
This method does the minimum necessary to accept a single
incoming connection. It signals back to the upper_half ASAP so
that any errors during session initialization can occur
naturally in the caller's stack.
# NB: CancelledError is not a BaseException before Python 3.8
if isinstance(err, asyncio.CancelledError):
raise
:param reader: Incoming `asyncio.StreamReader`
:param writer: Incoming `asyncio.StreamWriter`
"""
peer = writer.get_extra_info('peername', 'Unknown peer')
self.logger.debug("Incoming connection from %s", peer)
if isinstance(err, Exception):
raise ConnectError(emsg, err) from err
if self._reader or self._writer:
# Sadly, we can have more than one pending connection
# because of https://bugs.python.org/issue46715
# Close any extra connections we don't actually want.
self.logger.warning("Extraneous connection inadvertently accepted")
writer.close()
return
# Raise BaseExceptions un-wrapped, they're more important.
raise
assert self.runstate == Runstate.RUNNING
# A connection has been accepted; stop listening for new ones.
assert self._accepted is not None
await self._stop_server()
self._reader, self._writer = (reader, writer)
self._accepted.set()
@upper_half
async def _establish_connection(
self,
address: SocketAddrT,
ssl: Optional[SSLContext] = None,
accept: bool = False
) -> None:
async def _do_start_server(self, address: SocketAddrT,
ssl: Optional[SSLContext] = None) -> None:
"""
Establish a new connection.
Start listening for an incoming connection, but do not wait for a peer.
:param address:
Address to connect to/listen on;
UNIX socket path or TCP address/port.
:param ssl: SSL context to use, if any.
:param accept: Accept a connection instead of connecting when `True`.
"""
assert self.runstate == Runstate.IDLE
self._set_state(Runstate.CONNECTING)
# Allow runstate watchers to witness 'CONNECTING' state; some
# failures in the streaming layer are synchronous and will not
# otherwise yield.
await asyncio.sleep(0)
if accept:
await self._do_accept(address, ssl)
else:
await self._do_connect(address, ssl)
def _bind_hack(self, address: Union[str, Tuple[str, int]]) -> None:
"""
Used to create a socket in advance of accept().
This is a workaround to ensure that we can guarantee timing of
precisely when a socket exists to avoid a connection attempt
bouncing off of nothing.
Python 3.7+ adds a feature to separate the server creation and
listening phases instead, and should be used instead of this
hack.
"""
if isinstance(address, tuple):
family = socket.AF_INET
else:
family = socket.AF_UNIX
sock = socket.socket(family, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
sock.bind(address)
except:
sock.close()
raise
self._sock = sock
@upper_half
async def _do_accept(self, address: SocketAddrT,
ssl: Optional[SSLContext] = None) -> None:
"""
Acting as the transport server, accept a single connection.
This method starts listening for an incoming connection, but does not
block waiting for a peer. This call will return immediately after
binding and listening to a socket. A later call to accept() must be
made in order to finalize the incoming connection.
:param address:
Address to listen on; UNIX socket path or TCP address/port.
@ -471,52 +548,54 @@ class AsyncProtocol(Generic[T]):
:raise OSError: For stream-related errors.
"""
assert self.runstate == Runstate.IDLE
self._set_state(Runstate.CONNECTING)
self.logger.debug("Awaiting connection on %s ...", address)
connected = asyncio.Event()
server: Optional[asyncio.AbstractServer] = None
async def _client_connected_cb(reader: asyncio.StreamReader,
writer: asyncio.StreamWriter) -> None:
"""Used to accept a single incoming connection, see below."""
nonlocal server
nonlocal connected
# A connection has been accepted; stop listening for new ones.
assert server is not None
server.close()
await server.wait_closed()
server = None
# Register this client as being connected
self._reader, self._writer = (reader, writer)
# Signal back: We've accepted a client!
connected.set()
self._accepted = asyncio.Event()
if isinstance(address, tuple):
coro = asyncio.start_server(
_client_connected_cb,
host=None if self._sock else address[0],
port=None if self._sock else address[1],
self._incoming,
host=address[0],
port=address[1],
ssl=ssl,
backlog=1,
limit=self._limit,
sock=self._sock,
)
else:
coro = asyncio.start_unix_server(
_client_connected_cb,
path=None if self._sock else address,
self._incoming,
path=address,
ssl=ssl,
backlog=1,
limit=self._limit,
sock=self._sock,
)
server = await coro # Starts listening
await connected.wait() # Waits for the callback to fire (and finish)
assert server is None
self._sock = None
# Allow runstate watchers to witness 'CONNECTING' state; some
# failures in the streaming layer are synchronous and will not
# otherwise yield.
await asyncio.sleep(0)
# This will start the server (bind(2), listen(2)). It will also
# call accept(2) if we yield, but we don't block on that here.
self._server = await coro
self.logger.debug("Server listening on %s", address)
@upper_half
async def _do_accept(self) -> None:
"""
Wait for and accept an incoming connection.
Requires that we have not yet accepted an incoming connection
from the upper_half, but it's OK if the server is no longer
running because the bottom_half has already accepted the
connection.
"""
assert self._accepted is not None
await self._accepted.wait()
assert self._server is None
self._accepted = None
self.logger.debug("Connection accepted.")
@ -532,6 +611,14 @@ class AsyncProtocol(Generic[T]):
:raise OSError: For stream-related errors.
"""
assert self.runstate == Runstate.IDLE
self._set_state(Runstate.CONNECTING)
# Allow runstate watchers to witness 'CONNECTING' state; some
# failures in the streaming layer are synchronous and will not
# otherwise yield.
await asyncio.sleep(0)
self.logger.debug("Connecting to %s ...", address)
if isinstance(address, tuple):
@ -644,6 +731,7 @@ class AsyncProtocol(Generic[T]):
self._reader = None
self._writer = None
self._accepted = None
# NB: _runstate_changed cannot be cleared because we still need it to
# send the final runstate changed event ...!
@ -667,6 +755,9 @@ class AsyncProtocol(Generic[T]):
def _done(task: Optional['asyncio.Future[Any]']) -> bool:
return task is not None and task.done()
# If the server is running, stop it.
await self._stop_server()
# Are we already in an error pathway? If either of the tasks are
# already done, or if we have no tasks but a reader/writer; we
# must be.

View File

@ -41,12 +41,25 @@ class NullProtocol(AsyncProtocol[None]):
self.trigger_input = asyncio.Event()
await super()._establish_session()
async def _do_accept(self, address, ssl=None):
if not self.fake_session:
await super()._do_accept(address, ssl)
async def _do_start_server(self, address, ssl=None):
if self.fake_session:
self._accepted = asyncio.Event()
self._set_state(Runstate.CONNECTING)
await asyncio.sleep(0)
else:
await super()._do_start_server(address, ssl)
async def _do_accept(self):
if self.fake_session:
self._accepted = None
else:
await super()._do_accept()
async def _do_connect(self, address, ssl=None):
if not self.fake_session:
if self.fake_session:
self._set_state(Runstate.CONNECTING)
await asyncio.sleep(0)
else:
await super()._do_connect(address, ssl)
async def _do_recv(self) -> None:
@ -413,14 +426,14 @@ class Accept(Connect):
assert family in ('INET', 'UNIX')
if family == 'INET':
await self.proto.accept(('example.com', 1))
await self.proto.start_server_and_accept(('example.com', 1))
elif family == 'UNIX':
await self.proto.accept('/dev/null')
await self.proto.start_server_and_accept('/dev/null')
async def _hanging_connection(self):
with TemporaryDirectory(suffix='.aqmp') as tmpdir:
sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
await self.proto.accept(sock)
await self.proto.start_server_and_accept(sock)
class FakeSession(TestBase):
@ -449,13 +462,13 @@ class FakeSession(TestBase):
@TestBase.async_test
async def testFakeAccept(self):
"""Test the full state lifecycle (via accept) with a no-op session."""
await self.proto.accept('/not/a/real/path')
await self.proto.start_server_and_accept('/not/a/real/path')
self.assertEqual(self.proto.runstate, Runstate.RUNNING)
@TestBase.async_test
async def testFakeRecv(self):
"""Test receiving a fake/null message."""
await self.proto.accept('/not/a/real/path')
await self.proto.start_server_and_accept('/not/a/real/path')
logname = self.proto.logger.name
with self.assertLogs(logname, level='DEBUG') as context:
@ -471,7 +484,7 @@ class FakeSession(TestBase):
@TestBase.async_test
async def testFakeSend(self):
"""Test sending a fake/null message."""
await self.proto.accept('/not/a/real/path')
await self.proto.start_server_and_accept('/not/a/real/path')
logname = self.proto.logger.name
with self.assertLogs(logname, level='DEBUG') as context:
@ -493,7 +506,7 @@ class FakeSession(TestBase):
):
with self.assertRaises(StateError) as context:
if accept:
await self.proto.accept('/not/a/real/path')
await self.proto.start_server_and_accept('/not/a/real/path')
else:
await self.proto.connect('/not/a/real/path')
@ -504,7 +517,7 @@ class FakeSession(TestBase):
@TestBase.async_test
async def testAcceptRequireRunning(self):
"""Test that accept() cannot be called when Runstate=RUNNING"""
await self.proto.accept('/not/a/real/path')
await self.proto.start_server_and_accept('/not/a/real/path')
await self._prod_session_api(
Runstate.RUNNING,
@ -515,7 +528,7 @@ class FakeSession(TestBase):
@TestBase.async_test
async def testConnectRequireRunning(self):
"""Test that connect() cannot be called when Runstate=RUNNING"""
await self.proto.accept('/not/a/real/path')
await self.proto.start_server_and_accept('/not/a/real/path')
await self._prod_session_api(
Runstate.RUNNING,
@ -526,7 +539,7 @@ class FakeSession(TestBase):
@TestBase.async_test
async def testAcceptRequireDisconnecting(self):
"""Test that accept() cannot be called when Runstate=DISCONNECTING"""
await self.proto.accept('/not/a/real/path')
await self.proto.start_server_and_accept('/not/a/real/path')
# Cheat: force a disconnect.
await self.proto.simulate_disconnect()
@ -541,7 +554,7 @@ class FakeSession(TestBase):
@TestBase.async_test
async def testConnectRequireDisconnecting(self):
"""Test that connect() cannot be called when Runstate=DISCONNECTING"""
await self.proto.accept('/not/a/real/path')
await self.proto.start_server_and_accept('/not/a/real/path')
# Cheat: force a disconnect.
await self.proto.simulate_disconnect()
@ -576,7 +589,7 @@ class SimpleSession(TestBase):
async def testSmoke(self):
with TemporaryDirectory(suffix='.aqmp') as tmpdir:
sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
server_task = create_task(self.server.accept(sock))
server_task = create_task(self.server.start_server_and_accept(sock))
# give the server a chance to start listening [...]
await asyncio.sleep(0)

View File

@ -4,7 +4,7 @@ import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'python'))
from qemu.qmp import qmp_shell
from qemu.aqmp import qmp_shell
if __name__ == '__main__':