Skip to content

Commit

Permalink
Fix pre-commit errors
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 4, 2024
1 parent 2fed08a commit 469298a
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 47 deletions.
21 changes: 11 additions & 10 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

_LOGGER = logging.getLogger(__name__)

LABEL_TYPE = Union[None, enum.Enum, str]
EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None]


Expand Down Expand Up @@ -131,10 +130,12 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:

@runtime_checkable
class State(Protocol):
LABEL: ClassVar[LABEL_TYPE]
ALLOWED: ClassVar[set[str]]
LABEL: ClassVar[Any]
ALLOWED: ClassVar[set[Any]]
is_terminal: ClassVar[bool]

def __init__(self, *args: Any, **kwargs: Any): ...

def enter(self) -> None: ...

def exit(self) -> None: ...
Expand All @@ -155,7 +156,7 @@ def execute(self) -> State | None:
...


def create_state(st: StateMachine, state_label: Hashable, *args, **kwargs: Any) -> State:
def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
if state_label not in st.get_states_map():
raise ValueError(f'{state_label} is not a valid state')

Expand Down Expand Up @@ -211,13 +212,13 @@ def get_states(cls) -> Sequence[Type[State]]:
raise RuntimeError('States not defined')

@classmethod
def initial_state_label(cls) -> LABEL_TYPE:
def initial_state_label(cls) -> Any:
cls.__ensure_built()
assert cls.STATES is not None
return cls.STATES[0].LABEL

@classmethod
def get_state_class(cls, label: LABEL_TYPE) -> Type[State]:
def get_state_class(cls, label: Any) -> Type[State]:
cls.__ensure_built()
assert cls._STATES_MAP is not None
return cls._STATES_MAP[label]
Expand Down Expand Up @@ -261,11 +262,11 @@ def init(self) -> None:
def __str__(self) -> str:
return f'<{self.__class__.__name__}> ({self.state})'

def create_initial_state(self) -> State:
return self.get_state_class(self.initial_state_label())(self)
def create_initial_state(self, *args: Any, **kwargs: Any) -> State:
return self.get_state_class(self.initial_state_label())(self, *args, **kwargs)

@property
def state(self) -> Optional[LABEL_TYPE]:
def state(self) -> Any:
if self._state is None:
return None
return self._state.LABEL
Expand Down Expand Up @@ -373,7 +374,7 @@ def _exit_current_state(self, next_state: State) -> None:
return # Nothing to exit

if next_state.LABEL not in self._state.ALLOWED:
raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}')
raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.LABEL}')
self._fire_state_event(StateEventHook.EXITING_STATE, next_state)
self._state.exit()

Expand Down
65 changes: 41 additions & 24 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,20 @@
import traceback
from enum import Enum
from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast, final
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Optional,
Protocol,
Tuple,
Type,
Union,
cast,
final,
runtime_checkable,
)

import yaml
from yaml.loader import Loader
Expand All @@ -22,7 +35,7 @@
from . import exceptions, futures, persistence, utils
from .base import state_machine as st
from .lang import NULL
from .persistence import auto_persist
from .persistence import LoadSaveContext, auto_persist
from .utils import SAVED_STATE_TYPE

__all__ = [
Expand Down Expand Up @@ -147,14 +160,19 @@ class ProcessState(Enum):
KILLED = 'killed'


@runtime_checkable
class Savable(Protocol):
def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ...


@final
@auto_persist('args', 'kwargs')
class Created(persistence.Savable):
LABEL = ProcessState.CREATED
ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED}
LABEL: ClassVar = ProcessState.CREATED
ALLOWED: ClassVar = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED}

RUN_FN = 'run_fn'
is_terminal = False
is_terminal: ClassVar[bool] = False

def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
assert run_fn is not None
Expand Down Expand Up @@ -186,8 +204,8 @@ def exit(self) -> None: ...
@final
@auto_persist('args', 'kwargs')
class Running(persistence.Savable):
LABEL = ProcessState.RUNNING
ALLOWED = {
LABEL: ClassVar = ProcessState.RUNNING
ALLOWED: ClassVar = {
ProcessState.RUNNING,
ProcessState.WAITING,
ProcessState.FINISHED,
Expand All @@ -203,7 +221,7 @@ class Running(persistence.Savable):
_running: bool = False
_run_handle = None

is_terminal = False
is_terminal: ClassVar[bool] = False

def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
assert run_fn is not None
Expand Down Expand Up @@ -291,18 +309,17 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> st.Stat
else:
raise ValueError('Unrecognised command')

return cast(st.State, state) # casting from base.State to process.State
return state

def enter(self) -> None: ...

def exit(self) -> None: ...


@final
@auto_persist('msg', 'data')
class Waiting(persistence.Savable):
LABEL = ProcessState.WAITING
ALLOWED = {
LABEL: ClassVar = ProcessState.WAITING
ALLOWED: ClassVar = {
ProcessState.RUNNING,
ProcessState.WAITING,
ProcessState.KILLED,
Expand All @@ -314,7 +331,7 @@ class Waiting(persistence.Savable):

_interruption = None

is_terminal = False
is_terminal: ClassVar[bool] = False

def __str__(self) -> str:
state_info = super().__str__()
Expand Down Expand Up @@ -355,7 +372,7 @@ def interrupt(self, reason: Exception) -> None:
# This will cause the future in execute() to raise the exception
self._waiting_future.set_exception(reason)

async def execute(self) -> st.State: # type: ignore
async def execute(self) -> st.State:
try:
result = await self._waiting_future
except Interruption:
Expand All @@ -374,7 +391,7 @@ async def execute(self) -> st.State: # type: ignore
self.process, ProcessState.RUNNING, process=self.process, done_callback=self.done_callback, *result
)

return cast(st.State, next_state) # casting from base.State to process.State
return next_state

def resume(self, value: Any = NULL) -> None:
assert self._waiting_future is not None, 'Not yet waiting'
Expand All @@ -398,13 +415,13 @@ class Excepted(persistence.Savable):
:param traceback: An optional exception traceback
"""

LABEL = ProcessState.EXCEPTED
ALLOWED: set[str] = set()
LABEL: ClassVar = ProcessState.EXCEPTED
ALLOWED: ClassVar[set[str]] = set()

EXC_VALUE = 'ex_value'
TRACEBACK = 'traceback'

is_terminal = True
is_terminal: ClassVar = True

def __init__(
self,
Expand Down Expand Up @@ -466,10 +483,10 @@ class Finished(persistence.Savable):
:param successful: Boolean for the exit code is ``0`` the process is successful.
"""

LABEL = ProcessState.FINISHED
ALLOWED: set[str] = set()
LABEL: ClassVar = ProcessState.FINISHED
ALLOWED: ClassVar[set[str]] = set()

is_terminal = True
is_terminal: ClassVar[bool] = True

def __init__(self, result: Any, successful: bool) -> None:
self.result = result
Expand All @@ -495,10 +512,10 @@ class Killed(persistence.Savable):
:param msg: An optional message explaining the reason for the process termination.
"""

LABEL = ProcessState.KILLED
ALLOWED: set[str] = set()
LABEL: ClassVar = ProcessState.KILLED
ALLOWED: ClassVar[set[str]] = set()

is_terminal = True
is_terminal: ClassVar[bool] = True

def __init__(self, msg: Optional[MessageType]):
"""
Expand Down
22 changes: 11 additions & 11 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
"""The main Process module"""

from __future__ import annotations

import abc
import asyncio
import contextlib
Expand Down Expand Up @@ -58,8 +60,8 @@
StateMachine,
StateMachineError,
TransitionFailed,
event,
create_state,
event,
)
from .base.utils import call_with_super_check, super_check
from .event_helper import EventHelper
Expand Down Expand Up @@ -195,7 +197,7 @@ def get_states(cls) -> Sequence[Type[state_machine.State]]:
)

@classmethod
def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]:
def get_state_classes(cls) -> dict[process_states.ProcessState, Type[state_machine.State]]:
# A mapping of the State constants to the corresponding state class
return {
process_states.ProcessState.CREATED: process_states.Created,
Expand Down Expand Up @@ -634,7 +636,9 @@ def save_instance_state(
"""
super().save_instance_state(out_state, save_context)

out_state['_state'] = self._state.save()
# FIXME: the combined ProcessState protocol should cover the case
if isinstance(self._state, process_states.Savable):
out_state['_state'] = self._state.save()

# Inputs/outputs
if self.raw_inputs is not None:
Expand Down Expand Up @@ -1205,7 +1209,7 @@ def fail(self, exception: Optional[BaseException], traceback: Optional[Traceback
:param traceback: Optional exception traceback
"""
# state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace_back)
new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback)
self.transition_to(new_state)

def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
Expand All @@ -1225,7 +1229,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
# Already killing
return self._killing

if self._stepping:
if self._stepping and isinstance(self._state, Interruptable):
# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = process_states.KillInterruption(msg)
Expand All @@ -1252,10 +1256,7 @@ def create_initial_state(self) -> state_machine.State:
:return: A Created state
"""
return cast(
state_machine.State,
self.get_state_class(process_states.ProcessState.CREATED)(self, self.run),
)
return self.get_state_class(process_states.ProcessState.CREATED)(self, self.run)

def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State:
"""
Expand Down Expand Up @@ -1326,9 +1327,8 @@ async def step(self) -> None:
raise
except Exception:
# Overwrite the next state to go to excepted directly
_, exception, traceback = sys.exc_info()
next_state = create_state(
self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback
self, process_states.ProcessState.EXCEPTED, exception=sys.exc_info()[1], traceback=sys.exc_info()[2]
)
self._set_interrupt_action(None)

Expand Down
4 changes: 2 additions & 2 deletions src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Any,
Callable,
Dict,
Hashable,
List,
Mapping,
MutableSequence,
Expand Down Expand Up @@ -71,6 +70,7 @@ def get_outline(self) -> Union['_Instruction', '_FunctionCall']:
return self._outline


# FIXME: better use composition here
@persistence.auto_persist('_awaiting')
class Waiting(process_states.Waiting):
"""Overwrite the waiting state"""
Expand Down Expand Up @@ -124,7 +124,7 @@ class WorkChain(mixins.ContextMixin, processes.Process):
_CONTEXT = 'CONTEXT'

@classmethod
def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]:
def get_state_classes(cls) -> Dict[process_states.ProcessState, Type[state_machine.State]]:
states_map = super().get_state_classes()
states_map[process_states.ProcessState.WAITING] = Waiting
return states_map
Expand Down

0 comments on commit 469298a

Please sign in to comment.