From f760b4aaf6a46bbfc13bab88e36271aab122a641 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Fri, 13 Dec 2024 10:10:08 +0100 Subject: [PATCH] Message builder for constructing message with carrying more information (#291) The messages to passing over rabbitmq for process control is build dynamically and able to carry more information. In the old implementation, the messages ace global dictionary variables and when the message need to change by copy which is error-prone. This commit introduce the `MessageBuilder` with class methods for creating kill/pause/status/play messages. For "kill" message, I also add support for passing the `force_kill` option. --- docs/source/tutorial.ipynb | 4 +- src/plumpy/base/state_machine.py | 70 +++++++++++------ src/plumpy/process_comms.py | 95 ++++++++++++++-------- src/plumpy/process_states.py | 58 ++++++++++++-- src/plumpy/processes.py | 121 ++++++++++++++++++++++------- tests/base/test_statemachine.py | 9 ++- tests/persistence/test_inmemory.py | 4 +- tests/persistence/test_pickle.py | 4 +- tests/rmq/test_process_comms.py | 4 +- tests/test_expose.py | 4 +- tests/test_process_comms.py | 2 +- tests/test_processes.py | 11 +-- tests/utils.py | 6 +- 13 files changed, 273 insertions(+), 119 deletions(-) diff --git a/docs/source/tutorial.ipynb b/docs/source/tutorial.ipynb index c1fdb3b2..b544d38b 100644 --- a/docs/source/tutorial.ipynb +++ b/docs/source/tutorial.ipynb @@ -281,7 +281,7 @@ " def continue_fn(self):\n", " print('continuing')\n", " # message is stored in the process status\n", - " return plumpy.Kill('I was killed')\n", + " return plumpy.Kill(plumpy.MessageBuilder.kill('I was killed'))\n", "\n", "\n", "process = ContinueProcess()\n", @@ -1118,7 +1118,7 @@ "\n", "process = SimpleProcess(communicator=communicator)\n", "\n", - "pprint(communicator.rpc_send(str(process.pid), plumpy.STATUS_MSG).result())" + "pprint(communicator.rpc_send(str(process.pid), plumpy.MessageBuilder.status()).result())" ] }, { diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index d99d0705..681858f0 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """The state machine for processes""" +from __future__ import annotations + import enum import functools import inspect @@ -8,7 +10,19 @@ import os import sys from types import TracebackType -from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterable, + List, + Optional, + Sequence, + Set, + Type, + Union, +) from plumpy.futures import Future @@ -31,7 +45,7 @@ class StateEntryFailed(Exception): # noqa: N818 Failed to enter a state, can provide the next state to go to via this exception """ - def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: + def __init__(self, state: State, *args: Any, **kwargs: Any) -> None: super().__init__('failed to enter state') self.state = state self.args = args @@ -187,7 +201,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine': :param kwargs: Any keyword arguments to be passed to the constructor :return: An instance of the state machine """ - inst = super().__call__(*args, **kwargs) + inst: StateMachine = super().__call__(*args, **kwargs) inst.transition_to(inst.create_initial_state()) call_with_super_check(inst.init) return inst @@ -300,16 +314,25 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: def on_terminated(self) -> None: """Called when a terminal state is entered""" - def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None: + def transition_to(self, new_state: State | None, **kwargs: Any) -> None: + """Transite to the new state. + + The new target state will be create lazily when the state is not yet instantiated, + which will happened for states not in the expect path such as pause and kill. + The arguments are passed to the state class to create state instance. + (process arg does not need to pass since it will always call with 'self' as process) + """ assert not self._transitioning, 'Cannot call transition_to when already transitioning state' + if new_state is None: + # early return if the new state is `None` + # it can happened when transit from terminal state + return None + initial_state_label = self._state.LABEL if self._state is not None else None label = None try: self._transitioning = True - - # Make sure we have a state instance - new_state = self._create_state_instance(new_state, *args, **kwargs) label = new_state.LABEL # If the previous transition failed, do not try to exit it but go straight to next state @@ -319,8 +342,7 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A try: self._enter_next_state(new_state) except StateEntryFailed as exception: - # Make sure we have a state instance - new_state = self._create_state_instance(exception.state, *exception.args, **exception.kwargs) + new_state = exception.state label = new_state.LABEL self._exit_current_state(new_state) self._enter_next_state(new_state) @@ -338,7 +360,11 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A self._transitioning = False def transition_failed( - self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType + self, + initial_state: Hashable, + final_state: Hashable, + exception: Exception, + trace: TracebackType, ) -> None: """Called when a state transitions fails. @@ -355,6 +381,10 @@ def set_debug(self, enabled: bool) -> None: self._debug: bool = enabled def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: + # XXX: this method create state from label, which is duplicate as _create_state_instance and less generic + # because the label is defined after the state and required to be know before calling this function. + # This method should be replaced by `_create_state_instance`. + # aiida-core using this method for its Waiting state override. try: return self.get_states_map()[state_label](self, *args, **kwargs) except KeyError: @@ -383,20 +413,10 @@ def _enter_next_state(self, next_state: State) -> None: self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> State: - if isinstance(state, State): - # It's already a state instance - return state - - # OK, have to create it - state_cls = self._ensure_state_class(state) - return state_cls(self, *args, **kwargs) + def _create_state_instance(self, state_cls: Hashable, **kwargs: Any) -> State: + if state_cls not in self.get_states_map(): + raise ValueError(f'{state_cls} is not a valid state') - def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]: - if inspect.isclass(state) and issubclass(state, State): - return state + cls = self.get_states_map()[state_cls] - try: - return self.get_states_map()[cast(Hashable, state)] - except KeyError: - raise ValueError(f'{state} is not a valid state') + return cls(self, **kwargs) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 293c680b..e615ee4a 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- """Module for process level communication functions and classes""" +from __future__ import annotations + import asyncio -import copy import logging from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast @@ -12,10 +13,7 @@ from .utils import PID_TYPE __all__ = [ - 'KILL_MSG', - 'PAUSE_MSG', - 'PLAY_MSG', - 'STATUS_MSG', + 'MessageBuilder', 'ProcessLauncher', 'RemoteProcessController', 'RemoteProcessThreadController', @@ -31,6 +29,7 @@ INTENT_KEY = 'intent' MESSAGE_KEY = 'message' +FORCE_KILL_KEY = 'force_kill' class Intent: @@ -42,10 +41,45 @@ class Intent: STATUS: str = 'status' -PAUSE_MSG = {INTENT_KEY: Intent.PAUSE} -PLAY_MSG = {INTENT_KEY: Intent.PLAY} -KILL_MSG = {INTENT_KEY: Intent.KILL} -STATUS_MSG = {INTENT_KEY: Intent.STATUS} +MessageType = Dict[str, Any] + + +class MessageBuilder: + """MessageBuilder will construct different messages that can passing over communicator.""" + + @classmethod + def play(cls, text: str | None = None) -> MessageType: + """The play message send over communicator.""" + return { + INTENT_KEY: Intent.PLAY, + MESSAGE_KEY: text, + } + + @classmethod + def pause(cls, text: str | None = None) -> MessageType: + """The pause message send over communicator.""" + return { + INTENT_KEY: Intent.PAUSE, + MESSAGE_KEY: text, + } + + @classmethod + def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType: + """The kill message send over communicator.""" + return { + INTENT_KEY: Intent.KILL, + MESSAGE_KEY: text, + FORCE_KILL_KEY: force_kill, + } + + @classmethod + def status(cls, text: str | None = None) -> MessageType: + """The status message send over communicator.""" + return { + INTENT_KEY: Intent.STATUS, + MESSAGE_KEY: text, + } + TASK_KEY = 'task' TASK_ARGS = 'args' @@ -162,7 +196,7 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus': :param pid: the process id :return: the status response from the process """ - future = self._communicator.rpc_send(pid, STATUS_MSG) + future = self._communicator.rpc_send(pid, MessageBuilder.status()) result = await asyncio.wrap_future(future) return result @@ -174,11 +208,9 @@ async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pr :param msg: optional pause message :return: True if paused, False otherwise """ - message = copy.copy(PAUSE_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + msg = MessageBuilder.pause(text=msg) - pause_future = self._communicator.rpc_send(pid, message) + pause_future = self._communicator.rpc_send(pid, msg) # rpc_send return a thread future from communicator future = await asyncio.wrap_future(pause_future) # future is just returned from rpc call which return a kiwipy future @@ -192,12 +224,12 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult': :param pid: the pid of the process to play :return: True if played, False otherwise """ - play_future = self._communicator.rpc_send(pid, PLAY_MSG) + play_future = self._communicator.rpc_send(pid, MessageBuilder.play()) future = await asyncio.wrap_future(play_future) result = await asyncio.wrap_future(future) return result - async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult': + async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult': """ Kill the process @@ -205,12 +237,11 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pro :param msg: optional kill message :return: True if killed, False otherwise """ - message = copy.copy(KILL_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + if msg is None: + msg = MessageBuilder.kill() # Wait for the communication to go through - kill_future = self._communicator.rpc_send(pid, message) + kill_future = self._communicator.rpc_send(pid, msg) future = await asyncio.wrap_future(kill_future) # Now wait for the kill to be enacted result = await asyncio.wrap_future(future) @@ -331,7 +362,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future: :param pid: the process id :return: the status response from the process """ - return self._communicator.rpc_send(pid, STATUS_MSG) + return self._communicator.rpc_send(pid, MessageBuilder.status()) def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: """ @@ -342,11 +373,9 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu :return: a response future from the process to be paused """ - message = copy.copy(PAUSE_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + msg = MessageBuilder.pause(text=msg) - return self._communicator.rpc_send(pid, message) + return self._communicator.rpc_send(pid, msg) def pause_all(self, msg: Any) -> None: """ @@ -364,7 +393,7 @@ def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future: :return: a response future from the process to be played """ - return self._communicator.rpc_send(pid, PLAY_MSG) + return self._communicator.rpc_send(pid, MessageBuilder.play()) def play_all(self) -> None: """ @@ -372,7 +401,7 @@ def play_all(self) -> None: """ self._communicator.broadcast_send(None, subject=Intent.PLAY) - def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: + def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future: """ Kill the process @@ -381,18 +410,20 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fut :return: a response future from the process to be killed """ - message = copy.copy(KILL_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + if msg is None: + msg = MessageBuilder.kill() - return self._communicator.rpc_send(pid, message) + return self._communicator.rpc_send(pid, msg) - def kill_all(self, msg: Optional[Any]) -> None: + def kill_all(self, msg: Optional[MessageType]) -> None: """ Kill all processes that are subscribed to the same communicator :param msg: an optional pause message """ + if msg is None: + msg = MessageBuilder.kill() + self._communicator.broadcast_send(msg, subject=Intent.KILL) def continue_process( diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index cf29973a..d369a1e9 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import sys import traceback from enum import Enum @@ -8,6 +10,8 @@ import yaml from yaml.loader import Loader +from plumpy.process_comms import MessageBuilder, MessageType + try: import tblib @@ -48,7 +52,12 @@ class Interruption(Exception): # noqa: N818 class KillInterruption(Interruption): - pass + def __init__(self, msg: MessageType | None): + super().__init__() + if msg is None: + msg = MessageBuilder.kill() + + self.msg: MessageType = msg class PauseInterruption(Interruption): @@ -64,7 +73,7 @@ class Command(persistence.Savable): @auto_persist('msg') class Kill(Command): - def __init__(self, msg: Optional[Any] = None): + def __init__(self, msg: Optional[MessageType] = None): super().__init__() self.msg = msg @@ -76,7 +85,10 @@ class Pause(Command): @auto_persist('msg', 'data') class Wait(Command): def __init__( - self, continue_fn: Optional[Callable[..., Any]] = None, msg: Optional[Any] = None, data: Optional[Any] = None + self, + continue_fn: Optional[Callable[..., Any]] = None, + msg: Optional[Any] = None, + data: Optional[Any] = None, ): super().__init__() self.continue_fn = continue_fn @@ -349,13 +361,23 @@ def resume(self, value: Any = NULL) -> None: class Excepted(State): + """ + Excepted state, can optionally provide exception and trace_back + + :param exception: The exception instance + :param trace_back: An optional exception traceback + """ + LABEL = ProcessState.EXCEPTED EXC_VALUE = 'ex_value' TRACEBACK = 'traceback' def __init__( - self, process: 'Process', exception: Optional[BaseException], trace_back: Optional[TracebackType] = None + self, + process: 'Process', + exception: Optional[BaseException], + trace_back: Optional[TracebackType] = None, ): """ :param process: The associated process @@ -387,15 +409,27 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi else: self.traceback = None - def get_exc_info(self) -> Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]: + def get_exc_info( + self, + ) -> Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]: """ Recreate the exc_info tuple and return it """ - return type(self.exception) if self.exception else None, self.exception, self.traceback + return ( + type(self.exception) if self.exception else None, + self.exception, + self.traceback, + ) @auto_persist('result', 'successful') class Finished(State): + """State for process is finished. + + :param result: The result of process + :param successful: Boolean for the exit code is ``0`` the process is successful. + """ + LABEL = ProcessState.FINISHED def __init__(self, process: 'Process', result: Any, successful: bool) -> None: @@ -406,13 +440,21 @@ def __init__(self, process: 'Process', result: Any, successful: bool) -> None: @auto_persist('msg') class Killed(State): + """ + Represents a state where a process has been killed. + + This state is used to indicate that a process has been terminated and can optionally + include a message providing details about the termination. + + :param msg: An optional message explaining the reason for the process termination. + """ + LABEL = ProcessState.KILLED - def __init__(self, process: 'Process', msg: Optional[str]): + def __init__(self, process: 'Process', msg: Optional[MessageType]): """ :param process: The associated process :param msg: Optional kill message - """ super().__init__(process) self.msg = msg diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index ffddf7b5..0866ee41 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -26,6 +26,7 @@ Sequence, Tuple, Type, + TypeVar, Union, cast, ) @@ -39,15 +40,27 @@ import yaml from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed -from . import events, exceptions, futures, persistence, ports, process_comms, process_states, utils +from . import ( + events, + exceptions, + futures, + persistence, + ports, + process_comms, + process_states, + utils, +) from .base import state_machine from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper +from .process_comms import MESSAGE_KEY, MessageBuilder, MessageType from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected +T = TypeVar('T') + __all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] _LOGGER = logging.getLogger(__name__) @@ -91,7 +104,13 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: @persistence.auto_persist( - '_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper' + '_pid', + '_creation_time', + '_future', + '_paused', + '_status', + '_pre_paused_status', + '_event_helper', ) class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): """ @@ -231,7 +250,9 @@ def get_description(cls) -> Dict[str, Any]: @classmethod def recreate_from( - cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None + cls, + saved_state: SAVED_STATE_TYPE, + load_context: Optional[persistence.LoadSaveContext] = None, ) -> 'Process': """ Recreate a process from a saved state, passing any positional and @@ -314,14 +335,21 @@ def init(self) -> None: identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) except kiwipy.TimeoutError: - self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) + self.logger.exception( + 'Process<%s>: failed to register as a broadcast subscriber', + self.pid, + ) if not self._future.done(): def try_killing(future: futures.Future) -> None: if future.cancelled(): - if not self.kill('Killed by future being cancelled'): - self.logger.warning('Process<%s>: Failed to kill process on future cancel', self.pid) + msg = MessageBuilder.kill(text='Killed by future being cancelled') + if not self.kill(msg): + self.logger.warning( + 'Process<%s>: Failed to kill process on future cancel', + self.pid, + ) self._future.add_done_callback(try_killing) @@ -425,7 +453,13 @@ def launch( The process is started asynchronously, without blocking other task in the event loop. """ - process = process_class(inputs=inputs, pid=pid, logger=logger, loop=self.loop, communicator=self._communicator) + process = process_class( + inputs=inputs, + pid=pid, + logger=logger, + loop=self.loop, + communicator=self._communicator, + ) self.loop.create_task(process.step_until_terminated()) return process @@ -477,7 +511,7 @@ def killed(self) -> bool: """Return whether the process is killed.""" return self.state == process_states.ProcessState.KILLED - def killed_msg(self) -> Optional[str]: + def killed_msg(self) -> Optional[MessageType]: """Return the killed message.""" if isinstance(self._state, process_states.Killed): return self._state.msg @@ -529,7 +563,10 @@ def call_soon(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> return handle def callback_excepted( - self, _callback: Callable[..., Any], exception: Optional[BaseException], trace: Optional[TracebackType] + self, + _callback: Callable[..., Any], + exception: Optional[BaseException], + trace: Optional[TracebackType], ) -> None: if self.state != process_states.ProcessState.EXCEPTED: self.fail(exception, trace) @@ -555,7 +592,7 @@ def _process_scope(self) -> Generator[None, None, None]: stack_copy.pop() PROCESS_STACK.set(stack_copy) - async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + async def _run_task(self, callback: Callable[..., T], *args: Any, **kwargs: Any) -> T: """ This method should be used to run all Process related functions and coroutines. If there is an exception the process will enter the EXCEPTED state. @@ -576,7 +613,9 @@ async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: An # region Persistence def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] + self, + out_state: SAVED_STATE_TYPE, + save_context: Optional[persistence.LoadSaveContext], ) -> None: """ Ask the process to save its current instance state. @@ -828,7 +867,9 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed(process_states.ProcessState.FINISHED, result, False) + state_cls = self.get_states_map()[process_states.ProcessState.FINISHED] + finished_state = state_cls(self, result=result, successful=False) + raise StateEntryFailed(finished_state) self.future().set_result(self.outputs) @@ -857,10 +898,15 @@ def on_excepted(self) -> None: self._fire_event(ProcessListener.on_process_excepted, str(self.future().exception())) @super_check - def on_kill(self, msg: Optional[str]) -> None: + def on_kill(self, msg: Optional[MessageType]) -> None: """Entering the KILLED state.""" - self.set_status(msg) - self.future().set_exception(exceptions.KilledError(msg)) + if msg is None: + msg_txt = '' + else: + msg_txt = msg[MESSAGE_KEY] or '' + + self.set_status(msg_txt) + self.future().set_exception(exceptions.KilledError(msg_txt)) @super_check def on_killed(self) -> None: @@ -906,7 +952,12 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An :param msg: the message :return: the outcome of processing the message, the return value will be sent back as a response to the sender """ - self.logger.debug("Process<%s>: received RPC message with communicator '%s': %r", self.pid, _comm, msg) + self.logger.debug( + "Process<%s>: received RPC message with communicator '%s': %r", + self.pid, + _comm, + msg, + ) intent = msg[process_comms.INTENT_KEY] @@ -915,7 +966,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An if intent == process_comms.Intent.PAUSE: return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)) if intent == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None)) + return self._schedule_rpc(self.kill, msg=msg) if intent == process_comms.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) @@ -935,7 +986,11 @@ def broadcast_receive( """ self.logger.debug( - "Process<%s>: received broadcast message '%s' with communicator '%s': %r", self.pid, subject, _comm, body + "Process<%s>: received broadcast message '%s' with communicator '%s': %r", + self.pid, + subject, + _comm, + body, ) # If we get a message we recognise then action it, otherwise ignore @@ -1001,13 +1056,20 @@ def close(self) -> None: # region State related methods def transition_failed( - self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType + self, + initial_state: Hashable, + final_state: Hashable, + exception: Exception, + trace: TracebackType, ) -> None: # If we are creating, then reraise instead of failing. if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace) + new_state = self._create_state_instance( + process_states.ProcessState.EXCEPTED, exception=exception, trace_back=trace + ) + self.transition_to(new_state) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. @@ -1070,8 +1132,8 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: process_states.State) -> Any: try: - # Ignore the next state - self.transition_to(process_states.ProcessState.KILLED, str(exception)) + new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=exception.msg) + self.transition_to(new_state) return True finally: self._killing = None @@ -1123,9 +1185,12 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back) + new_state = self._create_state_instance( + process_states.ProcessState.EXCEPTED, exception=exception, trace_back=trace_back + ) + self.transition_to(new_state) - def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]: + def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ Kill the process :param msg: An optional kill message @@ -1151,7 +1216,8 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]: self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - self.transition_to(process_states.ProcessState.KILLED, msg) + new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg) + self.transition_to(new_state) return True @property @@ -1168,7 +1234,10 @@ def create_initial_state(self) -> process_states.State: :return: A Created state """ - return cast(process_states.State, self.get_state_class(process_states.ProcessState.CREATED)(self, self.run)) + return cast( + process_states.State, + self.get_state_class(process_states.ProcessState.CREATED)(self, self.run), + ) def recreate_state(self, saved_state: persistence.Bundle) -> process_states.State: """ diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 5b4b73d8..3a1621a2 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -57,6 +57,7 @@ class Paused(state_machine.State): def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' super().__init__(player) + self._player = player self.playing_state = playing_state def __str__(self): @@ -64,7 +65,7 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing, track=track) + self.state_machine.transition_to(Playing(player=self.state_machine, track=track)) else: self.state_machine.transition_to(self.playing_state) @@ -80,7 +81,7 @@ def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing, track=track) + self.state_machine.transition_to(Playing(self.state_machine, track=track)) class CdPlayer(state_machine.StateMachine): @@ -107,12 +108,12 @@ def play(self, track=None): @state_machine.event(from_states=Playing, to_states=Paused) def pause(self): - self.transition_to(Paused, playing_state=self._state) + self.transition_to(Paused(self, playing_state=self._state)) return True @state_machine.event(from_states=(Playing, Paused), to_states=Stopped) def stop(self): - self.transition_to(Stopped) + self.transition_to(Stopped(self)) class TestStateMachine(unittest.TestCase): diff --git a/tests/persistence/test_inmemory.py b/tests/persistence/test_inmemory.py index b0db46e7..9e3141de 100644 --- a/tests/persistence/test_inmemory.py +++ b/tests/persistence/test_inmemory.py @@ -1,11 +1,9 @@ # -*- coding: utf-8 -*- import unittest -from ..utils import ProcessWithCheckpoint - import plumpy -import plumpy +from ..utils import ProcessWithCheckpoint class TestInMemoryPersister(unittest.TestCase): diff --git a/tests/persistence/test_pickle.py b/tests/persistence/test_pickle.py index dd68b4fd..da4ede51 100644 --- a/tests/persistence/test_pickle.py +++ b/tests/persistence/test_pickle.py @@ -5,10 +5,10 @@ if getattr(tempfile, 'TemporaryDirectory', None) is None: from backports import tempfile -from ..utils import ProcessWithCheckpoint - import plumpy +from ..utils import ProcessWithCheckpoint + class TestPicklePersister(unittest.TestCase): def test_save_load_roundtrip(self): diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index 7223b888..a6249d10 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- import asyncio -import copy import kiwipy import pytest @@ -196,8 +195,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): for _ in range(10): procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) - msg = copy.copy(process_comms.KILL_MSG) - msg[process_comms.MESSAGE_KEY] = 'bang bang, I shot you down' + msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down') sync_controller.kill_all(msg) await utils.wait_util(lambda: all([proc.killed() for proc in procs])) diff --git a/tests/test_expose.py b/tests/test_expose.py index 0f6f8087..c5e6014c 100644 --- a/tests/test_expose.py +++ b/tests/test_expose.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- import unittest -from .utils import NewLoopProcess - from plumpy.ports import PortNamespace from plumpy.process_spec import ProcessSpec from plumpy.processes import Process +from .utils import NewLoopProcess + def validator_function(input, port): pass diff --git a/tests/test_process_comms.py b/tests/test_process_comms.py index c59737ac..44947230 100644 --- a/tests/test_process_comms.py +++ b/tests/test_process_comms.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- import pytest -from tests import utils import plumpy from plumpy import process_comms +from tests import utils class Process(plumpy.Process): diff --git a/tests/test_processes.py b/tests/test_processes.py index faea9eae..7b21c463 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -2,18 +2,17 @@ """Process tests""" import asyncio -import copy import enum import unittest import kiwipy import pytest -from tests import utils import plumpy from plumpy import BundleKeys, Process, ProcessState -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from plumpy.process_comms import MessageBuilder from plumpy.utils import AttributesFrozendict +from tests import utils class ForgetToCallParent(plumpy.Process): @@ -323,8 +322,7 @@ def run(self, **kwargs): def test_kill(self): proc: Process = utils.DummyProcess() - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'Farewell!' + msg = MessageBuilder.kill(text='Farewell!') proc.kill(msg) self.assertTrue(proc.killed()) self.assertEqual(proc.killed_msg(), msg) @@ -430,8 +428,7 @@ class KillProcess(Process): after_kill = False def run(self, **kwargs): - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'killed' + msg = MessageBuilder.kill(text='killed') self.kill(msg) # The following line should be executed because kill will not # interrupt execution of a method call in the RUNNING state diff --git a/tests/utils.py b/tests/utils.py index f2a58dfc..13abc38c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,13 +3,12 @@ import asyncio import collections -import copy import unittest from collections.abc import Mapping import plumpy from plumpy import persistence, process_states, processes, utils -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from plumpy.process_comms import MessageBuilder Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs']) @@ -86,8 +85,7 @@ def last_step(self): class KillProcess(processes.Process): @utils.override def run(self): - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'killed' + msg = MessageBuilder.kill(text='killed') return process_states.Kill(msg=msg)