Skip to content

Commit

Permalink
Refactoring create_state as static function initialize state from label
Browse files Browse the repository at this point in the history
create_state refact

Hashable initialized + parameters passed to Hashable

Fix pre-commit errors
  • Loading branch information
unkcpz committed Dec 4, 2024
1 parent 080d036 commit 6bfb87d
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 174 deletions.
45 changes: 19 additions & 26 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

_LOGGER = logging.getLogger(__name__)

LABEL_TYPE = Union[None, enum.Enum, str]
EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None]


Expand Down Expand Up @@ -131,9 +130,12 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:

@runtime_checkable
class State(Protocol):
LABEL: ClassVar[LABEL_TYPE]
LABEL: ClassVar[Any]
ALLOWED: ClassVar[set[Any]]
is_terminal: ClassVar[bool]

def __init__(self, *args: Any, **kwargs: Any): ...

def enter(self) -> None: ...

def exit(self) -> None: ...
Expand All @@ -146,7 +148,6 @@ def interrupt(self, reason: Exception) -> None: ...

@runtime_checkable
class Proceedable(Protocol):

def execute(self) -> State | None:
"""
Execute the state, performing the actions that this state is responsible for.
Expand All @@ -155,6 +156,14 @@ def execute(self) -> State | None:
...


def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
if state_label not in st.get_states_map():
raise ValueError(f'{state_label} is not a valid state')

state_cls = st.get_states_map()[state_label]
return state_cls(*args, **kwargs)


class StateEventHook(enum.Enum):
"""
Hooks that can be used to register callback at various points in the state transition
Expand Down Expand Up @@ -203,13 +212,13 @@ def get_states(cls) -> Sequence[Type[State]]:
raise RuntimeError('States not defined')

@classmethod
def initial_state_label(cls) -> LABEL_TYPE:
def initial_state_label(cls) -> Any:
cls.__ensure_built()
assert cls.STATES is not None
return cls.STATES[0].LABEL

@classmethod
def get_state_class(cls, label: LABEL_TYPE) -> Type[State]:
def get_state_class(cls, label: Any) -> Type[State]:
cls.__ensure_built()
assert cls._STATES_MAP is not None
return cls._STATES_MAP[label]
Expand Down Expand Up @@ -253,11 +262,11 @@ def init(self) -> None:
def __str__(self) -> str:
return f'<{self.__class__.__name__}> ({self.state})'

def create_initial_state(self) -> State:
return self.get_state_class(self.initial_state_label())(self)
def create_initial_state(self, *args: Any, **kwargs: Any) -> State:
return self.get_state_class(self.initial_state_label())(self, *args, **kwargs)

@property
def state(self) -> Optional[LABEL_TYPE]:
def state(self) -> Any:
if self._state is None:
return None
return self._state.LABEL
Expand Down Expand Up @@ -297,6 +306,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
The arguments are passed to the state class to create state instance.
(process arg does not need to pass since it will always call with 'self' as process)
"""
print(f'try: {self._state} -> {new_state}')
assert not self._transitioning, 'Cannot call transition_to when already transitioning state'

if new_state is None:
Expand Down Expand Up @@ -353,17 +363,6 @@ def get_debug(self) -> bool:
def set_debug(self, enabled: bool) -> None:
self._debug: bool = enabled

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
# XXX: this method create state from label, which is duplicate as _create_state_instance and less generic
# because the label is defined after the state and required to be know before calling this function.
# This method should be replaced by `_create_state_instance`.
# aiida-core using this method for its Waiting state override.
try:
state_cls = self.get_states_map()[state_label]
return state_cls(self, *args, **kwargs)
except KeyError:
raise ValueError(f'{state_label} is not a valid state')

def _exit_current_state(self, next_state: State) -> None:
"""Exit the given state"""

Expand All @@ -375,7 +374,7 @@ def _exit_current_state(self, next_state: State) -> None:
return # Nothing to exit

if next_state.LABEL not in self._state.ALLOWED:
raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}')
raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.LABEL}')
self._fire_state_event(StateEventHook.EXITING_STATE, next_state)
self._state.exit()

Expand All @@ -386,9 +385,3 @@ def _enter_next_state(self, next_state: State) -> None:
next_state.enter()
self._state = next_state
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State:
if state_cls.LABEL not in self.get_states_map():
raise ValueError(f'{state_cls.LABEL} is not a valid state')

return state_cls(self, **kwargs)
Loading

0 comments on commit 6bfb87d

Please sign in to comment.