Skip to content

Commit

Permalink
rali
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 11, 2024
1 parent e82824b commit 80b3458
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 14 deletions.
12 changes: 8 additions & 4 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
assert not self._transitioning, 'Cannot call transition_to when already transitioning state'

if new_state is None:
# early return if the new state is `None`
# it can happened when transit from terminal state
return None

initial_state_label = self._state.LABEL if self._state is not None else None
Expand Down Expand Up @@ -411,8 +413,10 @@ def _enter_next_state(self, next_state: State) -> None:
self._state = next_state
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State:
if state_cls.LABEL not in self.get_states_map():
raise ValueError(f'{state_cls.LABEL} is not a valid state')
def _create_state_instance(self, state_cls: Hashable, **kwargs: Any) -> State:
if state_cls not in self.get_states_map():
raise ValueError(f'{state_cls} is not a valid state')

return state_cls(self, **kwargs)
cls = self.get_states_map()[state_cls]

return cls(self, **kwargs)
12 changes: 10 additions & 2 deletions src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class Intent:


class PlayMessage:
"""The play message send over communicator."""

@classmethod
def build(cls, message: str | None = None) -> MessageType:
return {
Expand All @@ -57,6 +59,8 @@ def build(cls, message: str | None = None) -> MessageType:


class PauseMessage:
"""The pause message send over communicator."""

@classmethod
def build(cls, message: str | None = None) -> MessageType:
return {
Expand All @@ -66,16 +70,20 @@ def build(cls, message: str | None = None) -> MessageType:


class KillMessage:
"""The kill message send over communicator."""

@classmethod
def build(cls, message: str | None = None, force: bool = False) -> MessageType:
def build(cls, message: str | None = None, force_kill: bool = False) -> MessageType:
return {
INTENT_KEY: Intent.KILL,
MESSAGE_KEY: message,
FORCE_KILL_KEY: force,
FORCE_KILL_KEY: force_kill,
}


class StatusMessage:
"""The status message send over communicator."""

@classmethod
def build(cls, message: str | None = None) -> MessageType:
return {
Expand Down
16 changes: 8 additions & 8 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,8 +1066,9 @@ def transition_failed(
if final_state == process_states.ProcessState.CREATED:
raise exception.with_traceback(trace)

state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace)
new_state = self._create_state_instance(
process_states.ProcessState.EXCEPTED, exception=exception, trace_back=trace
)
self.transition_to(new_state)

def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
Expand Down Expand Up @@ -1131,8 +1132,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu

def do_kill(_next_state: process_states.State) -> Any:
try:
state_class = self.get_states_map()[process_states.ProcessState.KILLED]
new_state = self._create_state_instance(state_class, msg=exception.msg)
new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=exception.msg)
self.transition_to(new_state)
return True
finally:
Expand Down Expand Up @@ -1185,8 +1185,9 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
:param exception: The exception that caused the failure
:param trace_back: Optional exception traceback
"""
state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED]
new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace_back)
new_state = self._create_state_instance(
process_states.ProcessState.EXCEPTED, exception=exception, trace_back=trace_back
)
self.transition_to(new_state)

def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
Expand Down Expand Up @@ -1215,8 +1216,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

state_class = self.get_states_map()[process_states.ProcessState.KILLED]
new_state = self._create_state_instance(state_class, msg=msg)
new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg)
self.transition_to(new_state)
return True

Expand Down

0 comments on commit 80b3458

Please sign in to comment.