diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 048915f5..fc926008 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -34,7 +34,6 @@ _LOGGER = logging.getLogger(__name__) -LABEL_TYPE = Union[None, enum.Enum, str] EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None] @@ -131,10 +130,12 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: @runtime_checkable class State(Protocol): - LABEL: ClassVar[LABEL_TYPE] - ALLOWED: ClassVar[set[str]] + LABEL: ClassVar[Any] + ALLOWED: ClassVar[set[Any]] is_terminal: ClassVar[bool] + def __init__(self, *args: Any, **kwargs: Any): ... + def enter(self) -> None: ... def exit(self) -> None: ... @@ -155,7 +156,7 @@ def execute(self) -> State | None: ... -def create_state(st: StateMachine, state_label: Hashable, *args, **kwargs: Any) -> State: +def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State: if state_label not in st.get_states_map(): raise ValueError(f'{state_label} is not a valid state') @@ -211,13 +212,13 @@ def get_states(cls) -> Sequence[Type[State]]: raise RuntimeError('States not defined') @classmethod - def initial_state_label(cls) -> LABEL_TYPE: + def initial_state_label(cls) -> Any: cls.__ensure_built() assert cls.STATES is not None return cls.STATES[0].LABEL @classmethod - def get_state_class(cls, label: LABEL_TYPE) -> Type[State]: + def get_state_class(cls, label: Any) -> Type[State]: cls.__ensure_built() assert cls._STATES_MAP is not None return cls._STATES_MAP[label] @@ -261,11 +262,11 @@ def init(self) -> None: def __str__(self) -> str: return f'<{self.__class__.__name__}> ({self.state})' - def create_initial_state(self) -> State: - return self.get_state_class(self.initial_state_label())(self) + def create_initial_state(self, *args: Any, **kwargs: Any) -> State: + return self.get_state_class(self.initial_state_label())(self, *args, **kwargs) @property - def state(self) -> Optional[LABEL_TYPE]: + def state(self) -> Any: if self._state is None: return None return self._state.LABEL @@ -373,7 +374,7 @@ def _exit_current_state(self, next_state: State) -> None: return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: - raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') + raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.LABEL}') self._fire_state_event(StateEventHook.EXITING_STATE, next_state) self._state.exit() diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 80971154..5f3e8237 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -5,7 +5,20 @@ import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast, final +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Optional, + Protocol, + Tuple, + Type, + Union, + cast, + final, + runtime_checkable, +) import yaml from yaml.loader import Loader @@ -22,7 +35,7 @@ from . import exceptions, futures, persistence, utils from .base import state_machine as st from .lang import NULL -from .persistence import auto_persist +from .persistence import LoadSaveContext, auto_persist from .utils import SAVED_STATE_TYPE __all__ = [ @@ -147,14 +160,19 @@ class ProcessState(Enum): KILLED = 'killed' +@runtime_checkable +class Savable(Protocol): + def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... + + @final @auto_persist('args', 'kwargs') class Created(persistence.Savable): - LABEL = ProcessState.CREATED - ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} + LABEL: ClassVar = ProcessState.CREATED + ALLOWED: ClassVar = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} RUN_FN = 'run_fn' - is_terminal = False + is_terminal: ClassVar[bool] = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: assert run_fn is not None @@ -186,8 +204,8 @@ def exit(self) -> None: ... @final @auto_persist('args', 'kwargs') class Running(persistence.Savable): - LABEL = ProcessState.RUNNING - ALLOWED = { + LABEL: ClassVar = ProcessState.RUNNING + ALLOWED: ClassVar = { ProcessState.RUNNING, ProcessState.WAITING, ProcessState.FINISHED, @@ -203,7 +221,7 @@ class Running(persistence.Savable): _running: bool = False _run_handle = None - is_terminal = False + is_terminal: ClassVar[bool] = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: assert run_fn is not None @@ -291,18 +309,17 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> st.Stat else: raise ValueError('Unrecognised command') - return cast(st.State, state) # casting from base.State to process.State + return state def enter(self) -> None: ... def exit(self) -> None: ... -@final @auto_persist('msg', 'data') class Waiting(persistence.Savable): - LABEL = ProcessState.WAITING - ALLOWED = { + LABEL: ClassVar = ProcessState.WAITING + ALLOWED: ClassVar = { ProcessState.RUNNING, ProcessState.WAITING, ProcessState.KILLED, @@ -314,7 +331,7 @@ class Waiting(persistence.Savable): _interruption = None - is_terminal = False + is_terminal: ClassVar[bool] = False def __str__(self) -> str: state_info = super().__str__() @@ -355,7 +372,7 @@ def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) - async def execute(self) -> st.State: # type: ignore + async def execute(self) -> st.State: try: result = await self._waiting_future except Interruption: @@ -374,7 +391,7 @@ async def execute(self) -> st.State: # type: ignore self.process, ProcessState.RUNNING, process=self.process, done_callback=self.done_callback, *result ) - return cast(st.State, next_state) # casting from base.State to process.State + return next_state def resume(self, value: Any = NULL) -> None: assert self._waiting_future is not None, 'Not yet waiting' @@ -398,13 +415,13 @@ class Excepted(persistence.Savable): :param traceback: An optional exception traceback """ - LABEL = ProcessState.EXCEPTED - ALLOWED: set[str] = set() + LABEL: ClassVar = ProcessState.EXCEPTED + ALLOWED: ClassVar[set[str]] = set() EXC_VALUE = 'ex_value' TRACEBACK = 'traceback' - is_terminal = True + is_terminal: ClassVar = True def __init__( self, @@ -466,10 +483,10 @@ class Finished(persistence.Savable): :param successful: Boolean for the exit code is ``0`` the process is successful. """ - LABEL = ProcessState.FINISHED - ALLOWED: set[str] = set() + LABEL: ClassVar = ProcessState.FINISHED + ALLOWED: ClassVar[set[str]] = set() - is_terminal = True + is_terminal: ClassVar[bool] = True def __init__(self, result: Any, successful: bool) -> None: self.result = result @@ -495,10 +512,10 @@ class Killed(persistence.Savable): :param msg: An optional message explaining the reason for the process termination. """ - LABEL = ProcessState.KILLED - ALLOWED: set[str] = set() + LABEL: ClassVar = ProcessState.KILLED + ALLOWED: ClassVar[set[str]] = set() - is_terminal = True + is_terminal: ClassVar[bool] = True def __init__(self, msg: Optional[MessageType]): """ diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 8e101900..bae08dd4 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """The main Process module""" +from __future__ import annotations + import abc import asyncio import contextlib @@ -58,8 +60,8 @@ StateMachine, StateMachineError, TransitionFailed, - event, create_state, + event, ) from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper @@ -195,7 +197,7 @@ def get_states(cls) -> Sequence[Type[state_machine.State]]: ) @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]: + def get_state_classes(cls) -> dict[process_states.ProcessState, Type[state_machine.State]]: # A mapping of the State constants to the corresponding state class return { process_states.ProcessState.CREATED: process_states.Created, @@ -634,7 +636,9 @@ def save_instance_state( """ super().save_instance_state(out_state, save_context) - out_state['_state'] = self._state.save() + # FIXME: the combined ProcessState protocol should cover the case + if isinstance(self._state, process_states.Savable): + out_state['_state'] = self._state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -1205,7 +1209,7 @@ def fail(self, exception: Optional[BaseException], traceback: Optional[Traceback :param traceback: Optional exception traceback """ # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] - new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace_back) + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback) self.transition_to(new_state) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: @@ -1225,7 +1229,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] # Already killing return self._killing - if self._stepping: + if self._stepping and isinstance(self._state, Interruptable): # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.KillInterruption(msg) @@ -1252,10 +1256,7 @@ def create_initial_state(self) -> state_machine.State: :return: A Created state """ - return cast( - state_machine.State, - self.get_state_class(process_states.ProcessState.CREATED)(self, self.run), - ) + return self.get_state_class(process_states.ProcessState.CREATED)(self, self.run) def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State: """ @@ -1326,9 +1327,8 @@ async def step(self) -> None: raise except Exception: # Overwrite the next state to go to excepted directly - _, exception, traceback = sys.exc_info() next_state = create_state( - self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback + self, process_states.ProcessState.EXCEPTED, exception=sys.exc_info()[1], traceback=sys.exc_info()[2] ) self._set_interrupt_action(None) diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 2389942b..865a5b61 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -11,7 +11,6 @@ Any, Callable, Dict, - Hashable, List, Mapping, MutableSequence, @@ -71,6 +70,7 @@ def get_outline(self) -> Union['_Instruction', '_FunctionCall']: return self._outline +# FIXME: better use composition here @persistence.auto_persist('_awaiting') class Waiting(process_states.Waiting): """Overwrite the waiting state""" @@ -124,7 +124,7 @@ class WorkChain(mixins.ContextMixin, processes.Process): _CONTEXT = 'CONTEXT' @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]: + def get_state_classes(cls) -> Dict[process_states.ProcessState, Type[state_machine.State]]: states_map = super().get_state_classes() states_map[process_states.ProcessState.WAITING] = Waiting return states_map