From 84cabf69f5eb4cc872b496b2f1b04584a7e2b2bd Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Dec 2020 10:17:24 +0100 Subject: [PATCH] Runtime signals (#161) * wip implementation * full implementation Process runtime methods may also return signals. * fix/update some tests * black * unrelated doc change * add api docs * fix handling CONTINUE signal * add tests * black * update release notes * update doc (general concepts) and docstrings * doc: add user guide subsection * typo --- doc/api.rst | 1 + doc/framework.rst | 32 +++++---- doc/monitor.rst | 47 ++++++++++-- doc/run_model.rst | 9 +-- doc/whats_new.rst | 2 + xsimlab/__init__.py | 1 + xsimlab/drivers.py | 20 ++++-- xsimlab/hook.py | 7 +- xsimlab/model.py | 130 +++++++++++++++++++++++++++------- xsimlab/monitoring.py | 2 +- xsimlab/process.py | 88 ++++++++++++++++++++--- xsimlab/tests/test_process.py | 15 ++-- xsimlab/tests/test_signals.py | 111 +++++++++++++++++++++++++++++ 13 files changed, 395 insertions(+), 70 deletions(-) create mode 100644 xsimlab/tests/test_signals.py diff --git a/doc/api.rst b/doc/api.rst index ed486811..6b286a80 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -189,3 +189,4 @@ Model runtime monitoring monitoring.ProgressBar runtime_hook RuntimeHook + RuntimeSignal diff --git a/doc/framework.rst b/doc/framework.rst index cbcf509c..2c1a880b 100644 --- a/doc/framework.rst +++ b/doc/framework.rst @@ -224,18 +224,26 @@ A model run is divided into four successive stages: 3. finalize step 4. finalization -During a simulation, stages 1 and 4 are run only once while stages 2 -and 3 are repeated for a given number of (time) steps. Stage 4 is run even if -an exception is raised during stage 1, 2 or 3. - -Each process-ified class may provide its own computation instructions -by implementing specific methods named ``.initialize()``, -``.run_step()``, ``.finalize_step()`` and ``.finalize()`` for each -stage above, respectively. Note that this is entirely optional. For -example, time-independent processes (e.g., for setting model grids) -usually implement stage 1 only. In a few cases, the role of a process -may even consist of just declaring some variables that are used -elsewhere. +During a simulation, stages 1 and 4 are run only once while stages 2 and 3 are +repeated for a given number of (time) steps. Stage 4 is always run even when an +exception is raised during stage 1, 2 or 3. + +Each :func:`~xsimlab.process` decorated class may provide its own computation +instructions by implementing specific "runtime" methods named ``.initialize()``, +``.run_step()``, ``.finalize_step()`` and ``.finalize()`` for each stage above, +respectively. Note that this is entirely optional. For example, time-independent +processes (e.g., for setting model grids) usually implement stage 1 only. In a +few cases, the role of a process may even consist of just declaring some +variables that are used elsewhere. + +Runtime methods may be decorated by :func:`~xsimlab.runtime`. This is useful if +one needs to access the value of some runtime-specific variables like the +current step, time step duration, etc. from within those methods. Runtime +methods may also return a :func:`~xsimlab.RuntimeSignal` to control the +workflow, e.g., break the execution of the current stage. + +It is also possible to monitor and/or control simulations independently of any +model, using runtime hooks. See Section :ref:`monitor`. Get / set variable values inside a process ------------------------------------------ diff --git a/doc/monitor.rst b/doc/monitor.rst index f0a9764b..68b2a96e 100644 --- a/doc/monitor.rst +++ b/doc/monitor.rst @@ -16,25 +16,24 @@ it exemplifies how to create your own custom monitoring. sys.path.append('scripts') from advection_model import advect_model, advect_model_src -The following imports are necessary for the examples below. +Let's use the following setup for the examples below. It is based on the +``advect_model`` created in Section :ref:`create_model`. .. ipython:: python import xsimlab as xs -.. ipython:: python - :suppress: - in_ds = xs.create_setup( model=advect_model, clocks={ - 'time': np.linspace(0., 1., 5), + 'time': np.linspace(0., 1., 6), }, input_vars={ 'grid': {'length': 1.5, 'spacing': 0.01}, 'init': {'loc': 0.3, 'scale': 0.1}, 'advect__v': 1. }, + output_vars={'profile__u': 'time'} ) @@ -172,3 +171,41 @@ methods that may share some state: with PrintStepTime(): in_ds.xsimlab.run(model=advect_model) + + +Control simulation runtime +-------------------------- + +Runtime hook functions may return a :class:`~xsimlab.RuntimeSignal` so that you +can control the simulation workflow (e.g., skip the current stage or process, +break the simulation time steps) based on some condition or some computed value. + +In the example below, the simulation stops as soon as the gaussian pulse (peak +value) has been advected past ``x = 0.4``. + +.. ipython:: + + In [2]: @xs.runtime_hook("run_step", "model", "post") + ...: def maybe_stop(model, context, state): + ...: peak_idx = np.argmax(state[('profile', 'u')]) + ...: peak_x = state[('grid', 'x')][peak_idx] + ...: + ...: if peak_x > 0.4: + ...: print("Peak crossed x=0.4, stop simulation!") + ...: return xs.RuntimeSignal.BREAK + ...: + + In [3]: out_ds = in_ds.xsimlab.run( + ...: model=advect_model, + ...: hooks=[print_step_start, maybe_stop] + ...: ) + +Even when a simulation stops early like in the example above, the resulting +xarray Dataset still contains all time steps defined in the input Dataset. +Output variables have fill (masked) values for the time steps that were not run, +as shown below with the ``nan`` values for ``profile__u`` (fill values are not +stored physically in the Zarr output store). + +.. ipython:: python + + out_ds diff --git a/doc/run_model.rst b/doc/run_model.rst index 3b2a4ade..a055f03a 100644 --- a/doc/run_model.rst +++ b/doc/run_model.rst @@ -106,20 +106,21 @@ IPython (Jupyter) magic commands Writing a new setup from scratch may be tedious, especially for big models with a lot of input variables. If you are using IPython (Jupyter), xarray-simlab -provides convenient commands that can be activated with: +provides helper commands that are available after loading the +``xsimlab.ipython`` extension, i.e., .. ipython:: python %load_ext xsimlab.ipython The ``%create_setup`` magic command auto-generates the -:func:`~xsimlab.create_setup` code cell above from a given model: +:func:`~xsimlab.create_setup` code cell above from a given model, e.g., .. ipython:: python - %create_setup advect_model --default --comment + %create_setup advect_model --default --verbose -The ``--default`` and ``--comment`` options respectively add default values found +The ``--default`` and ``--verbose`` options respectively add default values found for input variables in the model and input variable description as line comments. Full command help: diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 95460fd1..3a7101fe 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -30,6 +30,8 @@ Enhancements - Added :func:`~xsimlab.group_dict` variable (:issue:`159`). - Added :func:`~xsimlab.global_ref` variable for model-wise implicit linking of variables in separate processes, based on global names (:issue:`160`). +- Added :class:`~xsimlab.RuntimeSignal` for controlling simulation workflow from + process runtime methods and/or runtime hook functions (:issue:`161`). Bug fixes ~~~~~~~~~ diff --git a/xsimlab/__init__.py b/xsimlab/__init__.py index 94e81bbe..f0399169 100644 --- a/xsimlab/__init__.py +++ b/xsimlab/__init__.py @@ -11,6 +11,7 @@ process, process_info, runtime, + RuntimeSignal, variable_info, ) from .variable import ( diff --git a/xsimlab/drivers.py b/xsimlab/drivers.py index 4fb47f61..0d8a5d47 100644 --- a/xsimlab/drivers.py +++ b/xsimlab/drivers.py @@ -5,6 +5,7 @@ import pandas as pd from .hook import flatten_hooks, group_hooks, RuntimeHook +from .process import RuntimeSignal from .stores import ZarrSimulationStore from .utils import get_batch_size @@ -345,16 +346,27 @@ def _run( in_vars = _get_input_vars(ds_step, model) model.update_state(in_vars, validate=validate_inputs, ignore_static=False) - model.execute("run_step", rt_context, **execute_kwargs) + signal = model.execute("run_step", rt_context, **execute_kwargs) + + if signal == RuntimeSignal.BREAK: + break store.write_output_vars(batch, step, model=model) - model.execute("finalize_step", rt_context, **execute_kwargs) + # after writing output variables so that index positions + # are properly updated in store. + if signal == RuntimeSignal.CONTINUE: + continue + + signal = model.execute("finalize_step", rt_context, **execute_kwargs) + + if signal == RuntimeSignal.BREAK: + break store.write_output_vars(batch, -1, model=model) store.write_index_vars(model=model) - except Exception as error: - raise error + except Exception: + raise finally: model.execute("finalize", rt_context, **execute_kwargs) diff --git a/xsimlab/hook.py b/xsimlab/hook.py index d447e001..b3998d3e 100644 --- a/xsimlab/hook.py +++ b/xsimlab/hook.py @@ -1,18 +1,17 @@ import inspect +from enum import Enum from typing import Callable, Dict, Iterable, List, Union from .process import SimulationStage -__all__ = ("runtime_hook", "RuntimeHook") - - def runtime_hook(stage, level="model", trigger="post"): """Decorator that allows to call a function or a method at one or more specific times during a simulation. The decorated function / method must have the following signature: - ``func(model, context, state)`` or ``meth(self, model, context, state)``. + ``func(model, context, state)`` or ``meth(self, model, context, state)``. It + may return a :class:`RuntimeSignal` (optional). Parameters ---------- diff --git a/xsimlab/model.py b/xsimlab/model.py index 83b7a92c..135aaf61 100644 --- a/xsimlab/model.py +++ b/xsimlab/model.py @@ -11,6 +11,7 @@ filter_variables, get_process_cls, get_target_variable, + RuntimeSignal, SimulationStage, ) from .utils import AttrMapping, Frozen, variables_dict @@ -795,25 +796,61 @@ def _call_hooks(self, hooks, runtime_context, stage, level, trigger): try: event_hooks = hooks[stage][level][trigger] except KeyError: - return + return RuntimeSignal.NONE + + signals = [] for h in event_hooks: - h(self, Frozen(runtime_context), Frozen(self.state)) + s = h(self, Frozen(runtime_context), Frozen(self.state)) + + if s is None: + s = RuntimeSignal(0) + else: + s = RuntimeSignal(s) + + signals.append(s) + + # Signal with highest value has highest priority + return RuntimeSignal(max([s.value for s in signals])) def _execute_process( self, p_obj, stage, runtime_context, hooks, validate, state=None ): + """Internal process execution method, which calls the process object's + executor. + + A state may be passed to the executor instead of using the executor's + state (this is to avoid stateful objects when calling the executor + during execution of a Dask graph). + + The process executor returns a partial state (only the variables that + have been updated by the executor, which will be needed for executing + further tasks in the Dask graph). + + This method returns this updated state as well as any runtime signal returned + by the hook functions and/or the executor (the one with highest priority). + + """ executor = p_obj.__xsimlab_executor__ p_name = p_obj.__xsimlab_name__ - self._call_hooks(hooks, runtime_context, stage, "process", "pre") - out_state = executor.execute(p_obj, stage, runtime_context, state=state) - self._call_hooks(hooks, runtime_context, stage, "process", "post") + signal_pre = self._call_hooks(hooks, runtime_context, stage, "process", "pre") + + if signal_pre.value > 0: + return p_name, ({}, signal_pre) + + state_out, signal_out = executor.execute( + p_obj, stage, runtime_context, state=state + ) + + signal_post = self._call_hooks(hooks, runtime_context, stage, "process", "post") + if signal_post.value > signal_out.value: + signal_out = signal_post if validate: self.validate(self._processes_to_validate[p_name]) - return p_name, out_state + return p_name, (state_out, signal_out) def _build_dask_graph(self, execute_args): """Build a custom, 'stateless' graph of tasks (process execution) that will @@ -821,34 +858,56 @@ def _build_dask_graph(self, execute_args): """ - def exec_process(p_obj, model_state, out_states): - # update model state with output state from all dependent processes + def exec_process(p_obj, model_state, exec_outputs): + # update model state with output states from all dependent processes + # gather signals returned by all dependent processes and sort them by highest priority state = {} + signal = RuntimeSignal.NONE + state.update(model_state) - for _, s in out_states: - state.update(s) - return self._execute_process(p_obj, *execute_args, state=state) + for _, (state_out, signal_out) in exec_outputs: + state.update(state_out) + + if signal_out.value > signal.value: + signal = signal_out + + if signal == RuntimeSignal.BREAK: + # received a BREAK signal from the execution of a dependent process + # -> skip execution of current process as well as all downstream processes + # in the graph (by forwarding the signal). + return p_obj.__xsimlab_name__, ({}, signal) + else: + return self._execute_process(p_obj, *execute_args, state=state) dsk = {} for p_name, p_deps in self._dep_processes.items(): dsk[p_name] = (exec_process, self._processes[p_name], self._state, p_deps) - # add a node to gather output state from all executed processes - dsk["_gather"] = (lambda out_states: dict(out_states), list(self._processes)) + # add a node to gather output signals and state from all executed processes + dsk["_gather"] = ( + lambda exec_outputs: dict(exec_outputs), + list(self._processes), + ) return dsk - def _merge_and_update_state(self, out_states): - """Collect, merge together and update model state from the output - states returned by all executed processes (dask graph). + def _merge_exec_outputs(self, exec_outputs) -> RuntimeSignal: + """Collect and merge process execution outputs (from dask graph). + + - combine all output states and update model's state. + - sort all output runtime signals and return the signal with highest priority. """ new_state = {} + signal = RuntimeSignal.NONE - # process order matters! + # process order matters for properly updating state! for p_name in self._processes: - new_state.update(out_states[p_name]) + state_out, signal_out = exec_outputs[p_name] + new_state.update(state_out) + if signal_out.value > signal.value: + signal = signal_out self._state.update(new_state) @@ -857,6 +916,8 @@ def _merge_and_update_state(self, out_states): for p_obj in self._processes.values(): p_obj.__xsimlab_state__ = self._state + return signal + def _clear_od_cache(self): """Clear cached values of on-demand variables.""" @@ -896,6 +957,13 @@ def execute( Dask's scheduler used to run the stage in parallel (Dask's threads scheduler is used as failback). + Returns + ------- + signal : :class:`RuntimeSignal` + Signal with hightest priority among all signals returned by hook + functions and/or process runtime methods, if any. Otherwise, + returns ``RuntimeSignal.NONE``. + Notes ----- Even when run in parallel, xarray-simlab ensures that processes will @@ -925,33 +993,41 @@ def execute( if hooks is None: hooks = {} - dsk_get = dask.base.get_scheduler(scheduler=scheduler) - if dsk_get is None: - dsk_get = dask.threaded.get - stage = SimulationStage(stage) execute_args = (stage, runtime_context, hooks, validate) self._clear_od_cache() - self._call_hooks(hooks, runtime_context, stage, "model", "pre") + signal_pre = self._call_hooks(hooks, runtime_context, stage, "model", "pre") + + if signal_pre.value > 0: + return signal_pre if parallel: + dsk_get = dask.base.get_scheduler(scheduler=scheduler) + if dsk_get is None: + dsk_get = dask.threaded.get + dsk = self._build_dask_graph(execute_args) - out_states = dsk_get(dsk, "_gather", scheduler=scheduler) + exec_outputs = dsk_get(dsk, "_gather", scheduler=scheduler) # TODO: without this -> flaky tests (don't know why) # state is not properly updated -> error when writing output vars in store if isinstance(scheduler, Client): time.sleep(0.001) - self._merge_and_update_state(out_states) + signal_process = self._merge_exec_outputs(exec_outputs) else: for p_obj in self._processes.values(): - self._execute_process(p_obj, *execute_args) + _, (_, signal_process) = self._execute_process(p_obj, *execute_args) + + if signal_process == RuntimeSignal.BREAK: + break + + signal_post = self._call_hooks(hooks, runtime_context, stage, "model", "post") - self._call_hooks(hooks, runtime_context, stage, "model", "post") + return signal_post def clone(self): """Clone the Model. diff --git a/xsimlab/monitoring.py b/xsimlab/monitoring.py index b6fa2d26..17f49dd5 100644 --- a/xsimlab/monitoring.py +++ b/xsimlab/monitoring.py @@ -78,7 +78,7 @@ def init_bar(self, model, context, state): self.pbar_model = self.tqdm(**self.tqdm_kwargs) @runtime_hook("initialize", trigger="post") - def update_init(self, mode, context, state): + def update_init(self, model, context, state): self.pbar_model.update(1) @runtime_hook("run_step", trigger="post") diff --git a/xsimlab/process.py b/xsimlab/process.py index 8427dfe2..f0f884c5 100644 --- a/xsimlab/process.py +++ b/xsimlab/process.py @@ -308,10 +308,67 @@ def getter_state_or_on_demand(self): return property(fget=getter_state_or_on_demand, doc=var_details(var)) +class RuntimeSignal(Enum): + """Signal controlling simulation runtime. + + Such signal may be returned either by runtime methods of :func:`process` + decorated classes or by :func:`runtime_hook` decorated functions or methods. + + All signals listed below are ordered from the lowest to the highest + priority. If multiple signals are emitted at the same time during a + simulation, the one with the highest priority prevails. + + One signal may affect the simulation workflow in different ways, depending on + whether it is returned during the execution of a simulation stage (process-level) + or just before/after a stage (model-level). + + Attributes + ---------- + NONE : int + Do nothing (blank signal). Used by default when a process + runtime method or a hook function doesn't explicitly return a value. + Priority = 0. + SKIP : int + If returned by a pre-hook function, skip the current process + (process-level) or the current simulation stage (model-level). + Otherwise do nothing. + Priority = 1. + CONTINUE : int + Skip the current process or current simulation stage (pre-hook). + When returned by a model-level hook function in a looped + simulation stage (e.g., ``run_step``), also skip all remaining + simulation stages (e.g., ``finalize_step``) at the current step and + continues the simulation at the next step. + Priority = 2. + BREAK : int + Skip the current process or current simulation stage (pre-hook). + Also skip all remaining processes in a simulation stage (process-level) + or all remaining steps in looped simulation stages (model-level). + Priority = 3. + + """ + + NONE = 0 + """Do nothing (blank signal).""" + SKIP = 1 + """Skip the current process or simulation stage.""" + CONTINUE = 2 + """Skip all the remaining stages of the current step.""" + BREAK = 3 + """Skip all remaining processes or steps.""" + + class _RuntimeMethodExecutor: """Used to execute a process 'runtime' method in the context of a simulation. + This is a thin wrapper around a process class runtime method, which: + + - maps method argument(s) to their corresponding simulation runtime variable. + - optionally overrides the process object's state with an input state (i.e., when + executed as part of a Dask graph). + - returns a valid runtime signal. + """ def __init__(self, meth, args=None): @@ -328,13 +385,18 @@ def __init__(self, meth, args=None): self.args = tuple(args) - def execute(self, obj, runtime_context, state=None): + def execute(self, p_obj, runtime_context, state=None): if state is not None: - obj.__xsimlab_state__ = state + p_obj.__xsimlab_state__ = state args = [runtime_context[k] for k in self.args] - self.meth(obj, *args) + signal = self.meth(p_obj, *args) + + if signal is None: + return RuntimeSignal.NONE + else: + return RuntimeSignal(signal) def runtime(meth=None, args=None): @@ -453,17 +515,25 @@ def __init__(self, cls): def stages(self): return [k.value for k in self.runtime_executors] - def execute(self, obj, stage, runtime_context, state=None): + def execute(self, p_obj, stage, runtime_context, state=None): + """Maybe execute the given simulation stage (if implemented). + + Returns a state dictionary with only the 'out'/'inout' variables and + a runtime signal. + + """ executor = self.runtime_executors.get(stage) if executor is None: - return {} + return {}, RuntimeSignal.NONE else: - executor.execute(obj, runtime_context, state=state) + signal_out = executor.execute(p_obj, runtime_context, state=state) + + skeys = [p_obj.__xsimlab_state_keys__[k] for k in self.out_vars] + sobj = p_obj.__xsimlab_state__ + state_out = {k: sobj[k] for k in skeys if k in sobj} - skeys = [obj.__xsimlab_state_keys__[k] for k in self.out_vars] - sobj = obj.__xsimlab_state__ - return {k: sobj[k] for k in skeys if k in sobj} + return state_out, signal_out def _process_cls_init(obj): diff --git a/xsimlab/tests/test_process.py b/xsimlab/tests/test_process.py index 807a19e9..992f33c5 100644 --- a/xsimlab/tests/test_process.py +++ b/xsimlab/tests/test_process.py @@ -293,6 +293,7 @@ class P: def run_step(self): self.out_var = self.in_var * 2 + return xs.RuntimeSignal.BREAK @od_var.compute def _dummy(self): @@ -302,11 +303,17 @@ def _dummy(self): executor = m.p.__xsimlab_executor__ state = {("p", "in_var"): 1} - expected = {("p", "out_var"): 2} - actual = executor.execute(m.p, SimulationStage.RUN_STEP, {}, state=state) - assert actual == expected + state_out, signal_out = executor.execute( + m.p, SimulationStage.RUN_STEP, {}, state=state + ) + assert state_out == {("p", "out_var"): 2} + assert signal_out == xs.RuntimeSignal.BREAK - assert executor.execute(m.p, SimulationStage.INITIALIZE, {}, state=state) == {} + state_out, signal_out = executor.execute( + m.p, SimulationStage.INITIALIZE, {}, state=state + ) + assert state_out == {} + assert signal_out == xs.RuntimeSignal.NONE def test_process_executor_raise(): diff --git a/xsimlab/tests/test_signals.py b/xsimlab/tests/test_signals.py new file mode 100644 index 00000000..7fa042fa --- /dev/null +++ b/xsimlab/tests/test_signals.py @@ -0,0 +1,111 @@ +import numpy as np +import pytest + +import xsimlab as xs +from xsimlab.process import SimulationStage + + +@pytest.mark.parametrize( + "trigger,signal,expected", + [ + ("pre", xs.RuntimeSignal.SKIP, [1.0, 2.0, 4.0, 5.0]), + ("post", xs.RuntimeSignal.SKIP, [1.0, 3.0, 5.0, 6.0]), + ("pre", xs.RuntimeSignal.CONTINUE, [1.0, 2.0, 3.0, 4.0]), + ("post", xs.RuntimeSignal.CONTINUE, [1.0, 3.0, 4.0, 5.0]), + ("pre", xs.RuntimeSignal.BREAK, [1.0, 2.0, np.nan, np.nan]), + ("post", xs.RuntimeSignal.BREAK, [1.0, 3.0, np.nan, np.nan]), + ], +) +def test_signal_model_level(trigger, signal, expected): + @xs.process + class Foo: + v = xs.variable(intent="out") + vv = xs.variable(intent="out") + + def initialize(self): + self.v = 0.0 + self.vv = 10.0 + + def run_step(self): + self.v += 1.0 + + def finalize_step(self): + self.v += 1.0 + + @xs.runtime_hook("run_step", level="model", trigger=trigger) + def hook_func(model, context, state): + if context["step"] == 1: + return signal + + model = xs.Model({"foo": Foo}) + ds_in = xs.create_setup( + model=model, + clocks={"clock": range(4)}, + output_vars={"foo__v": "clock", "foo__vv": None}, + ) + ds_out = ds_in.xsimlab.run(model=model, hooks=[hook_func]) + + np.testing.assert_equal(ds_out.foo__v.values, expected) + + # ensure that clock-independent output variables are properly + # saved even when the simulation stops early + assert ds_out.foo__vv == 10.0 + + +@pytest.fixture(params=[True, False]) +def parallel(request): + return request.param + + +@pytest.mark.parametrize( + "step,trigger,signal,break_bar,expected_v1,expected_v2", + [ + # Both Foo.run_step and Bar.run_step are run + (0, "pre", xs.RuntimeSignal.SKIP, False, 1, 2), + # None of Foo.run_step and Bar.run_step are run + (1, "pre", xs.RuntimeSignal.SKIP, False, 0, 0), + # BREAK signal returned in Bar.run_step prevails + (1, "post", xs.RuntimeSignal.SKIP, True, 0, 2), + # BREAK signal returned by Bar.run_step only + (0, "post", xs.RuntimeSignal.BREAK, True, 0, 2), + ], +) +def test_signal_process_level( + step, trigger, signal, break_bar, expected_v1, expected_v2, parallel +): + @xs.process + class Foo: + v1 = xs.variable(intent="out") + v2 = xs.variable() + + def initialize(self): + self.v1 = 0 + + def run_step(self): + self.v1 = 1 + + @xs.process + class Bar: + v2 = xs.foreign(Foo, "v2", intent="out") + + def initialize(self): + self.v2 = 0 + + def run_step(self): + self.v2 = 2 + + if break_bar: + return xs.RuntimeSignal.BREAK + + def hook_func(model, context, state): + if context["step"] == 1: + return signal + + hook_dict = {SimulationStage.RUN_STEP: {"process": {trigger: [hook_func]}}} + + model = xs.Model({"foo": Foo, "bar": Bar}) + model.execute("initialize", {}) + model.execute("run_step", {"step": step}, hooks=hook_dict, parallel=parallel) + + assert model.state[("foo", "v1")] == expected_v1 + assert model.state[("foo", "v2")] == expected_v2