diff --git a/scripts/qmp/qmp-shell b/scripts/qmp/qmp-shell index f14fe211cc..ec028d662e 100755 --- a/scripts/qmp/qmp-shell +++ b/scripts/qmp/qmp-shell @@ -66,7 +66,6 @@ # sent to QEMU, which is useful for debugging and documentation generation. import argparse import ast -import atexit import json import os import re @@ -142,6 +141,11 @@ class QMPShell(qmp.QEMUMonitorProtocol): self.pretty = pretty self.verbose = verbose + def close(self) -> None: + # Hook into context manager of parent to save shell history. + self._save_history() + super().close() + def _fill_completion(self) -> None: cmds = self.cmd('query-commands') if 'error' in cmds: @@ -164,9 +168,8 @@ class QMPShell(qmp.QEMUMonitorProtocol): pass except IOError as err: print(f"Failed to read history '{self._histfile}': {err!s}") - atexit.register(self.__save_history) - def __save_history(self) -> None: + def _save_history(self) -> None: try: readline.write_history_file(self._histfile) except IOError as err: @@ -448,25 +451,25 @@ def main() -> None: parser.error("QMP socket or TCP address must be specified") shell_class = HMPShell if args.hmp else QMPShell + try: address = shell_class.parse_address(args.qmp_server) except qmp.QMPBadPortError: parser.error(f"Bad port number: {args.qmp_server}") return # pycharm doesn't know error() is noreturn - qemu = shell_class(address, args.pretty, args.verbose) + with shell_class(address, args.pretty, args.verbose) as qemu: + try: + qemu.connect(negotiate=not args.skip_negotiation) + except qmp.QMPConnectError: + die("Didn't get QMP greeting message") + except qmp.QMPCapabilitiesError: + die("Couldn't negotiate capabilities") + except OSError as err: + die(f"Couldn't connect to {args.qmp_server}: {err!s}") - try: - qemu.connect(negotiate=not args.skip_negotiation) - except qmp.QMPConnectError: - die("Didn't get QMP greeting message") - except qmp.QMPCapabilitiesError: - die("Couldn't negotiate capabilities") - except OSError as err: - die(f"Couldn't connect to {args.qmp_server}: {err!s}") - - for _ in qemu.repl(): - pass + for _ in qemu.repl(): + pass if __name__ == '__main__':