Skip to content

Commit

Permalink
Furthur simplipy _create_state_instant only create state from class
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Nov 30, 2024
1 parent 989f995 commit 19cc257
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 18 deletions.
23 changes: 10 additions & 13 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
Set,
Type,
Union,
cast,
)

from plumpy.futures import Future
Expand All @@ -44,7 +43,7 @@ class StateEntryFailed(Exception):
Failed to enter a state, can provide the next state to go to via this exception
"""

def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=keyword-arg-before-vararg
def __init__(self, state: type["State"] = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=keyword-arg-before-vararg
super().__init__("failed to enter state")
self.state = state
self.args = args
Expand Down Expand Up @@ -330,12 +329,12 @@ def on_terminated(self) -> None:
"""Called when a terminal state is entered"""

def transition_to(
self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any
self, new_state: Union[State, Type[State]], *args: Any, **kwargs: Any
) -> None:
"""Transite to the new state.
The new target state will be create lazily when the state
is not yet instantiated, which will happened for states not in the expect path such as
is not yet instantiated, which will happened for states not in the expect path such as
pause and kill.
"""
assert (
Expand Down Expand Up @@ -403,6 +402,10 @@ def set_debug(self, enabled: bool) -> None:
self._debug: bool = enabled

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
# XXX: this method create state from label, which is duplicate as _create_state_instance and less generic
# because the label is defined after the state and required to be know before calling this function.
# This method should be replaced by `_create_state_instance`.
# aiida-core using this method for its Waiting state override.
try:
return self.get_states_map()[state_label](self, *args, **kwargs) # pylint: disable=unsubscriptable-object
except KeyError:
Expand Down Expand Up @@ -436,15 +439,9 @@ def _enter_next_state(self, next_state: State) -> None:
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(
self, state: Union[Hashable, Type[State]], *args: Any, **kwargs: Any
self, state_cls: type[State], *args: Any, **kwargs: Any
) -> State:
# build from state class
if inspect.isclass(state) and issubclass(state, State):
state_cls = state
else:
try:
state_cls = self.get_states_map()[cast(Hashable, state)] # pylint: disable=unsubscriptable-object
except KeyError:
raise ValueError(f"{state} is not a valid state")
if state_cls.LABEL not in self.get_states_map():
raise ValueError(f"{state_cls.LABEL} is not a valid state")

return state_cls(self, *args, **kwargs)
10 changes: 5 additions & 5 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ def on_finish(self, result: Any, successful: bool) -> None:
if successful:
validation_error = self.spec().outputs.validate(self.outputs)
if validation_error:
raise StateEntryFailed(process_states.ProcessState.FINISHED, result, False)
raise StateEntryFailed(process_states.Finished, result, False)

self.future().set_result(self.outputs)

Expand Down Expand Up @@ -1017,7 +1017,7 @@ def transition_failed(
if final_state == process_states.ProcessState.CREATED:
raise exception.with_traceback(trace)

self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace)
self.transition_to(process_states.Excepted, exception, trace)

def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
"""Pause the process.
Expand Down Expand Up @@ -1082,7 +1082,7 @@ def do_kill(_next_state: process_states.State) -> Any:
try:
# Ignore the next state
# __import__('ipdb').set_trace()
self.transition_to(process_states.ProcessState.KILLED, exception)
self.transition_to(process_states.Killed, exception)
return True
finally:
self._killing = None
Expand Down Expand Up @@ -1134,7 +1134,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
:param exception: The exception that caused the failure
:param trace_back: Optional exception traceback
"""
self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back)
self.transition_to(process_states.Excepted, exception, trace_back)

def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
"""
Expand Down Expand Up @@ -1162,7 +1162,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

self.transition_to(process_states.ProcessState.KILLED, msg)
self.transition_to(process_states.Killed, msg)
return True

@property
Expand Down

0 comments on commit 19cc257

Please sign in to comment.