From f7021e595dc17e4e9245564c734c09fa7aab7bd6 Mon Sep 17 00:00:00 2001
From: Tom Close
Date: Wed, 26 Feb 2025 10:18:36 +1100
Subject: [PATCH 1/3] fixed bug in state depth calculation where single,
 combined states were being included

---
 pydra/engine/state.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/pydra/engine/state.py b/pydra/engine/state.py
index 96c2881189..7ef22aa848 100644
--- a/pydra/engine/state.py
+++ b/pydra/engine/state.py
@@ -190,13 +190,13 @@ def names(self):

     @property
     def depth(self) -> int:
-        """Return the number of uncombined splits of the state, i.e. the number nested
+        """Return the number of splits of the state, i.e. the number of nested
         state arrays to wrap around the type of lazy out fields

         Returns
         -------
         int
-            number of uncombined splits
+            number of uncombined independent splits (i.e. linked splits only add 1)
         """
         depth = 0
         stack = []
@@ -210,7 +210,8 @@ def depth(self) -> int:
                 stack = []
             else:
                 stack.append(spl)
-        return depth + len(stack)
+        remaining_stack = [s for s in stack if s not in self.combiner]
+        return depth + len(remaining_stack)

     @property
     def splitter(self):

From b2034d56f7d5681b2a3fbe6d47238344c3e05d79 Mon Sep 17 00:00:00 2001
From: Tom Close
Date: Wed, 26 Feb 2025 10:20:06 +1100
Subject: [PATCH 2/3] moved resolution of lazy inputs from TaskDef into the
 NodeExecution class

---
 pydra/engine/lazy.py             |  48 ++++++---
 pydra/engine/node.py             |  36 ++++----
 pydra/engine/specs.py            |  37 --------
 pydra/engine/submitter.py        |  39 +++++++--
 pydra/engine/tests/test_specs.py | 144 ++++++++++++------------------
 pydra/engine/tests/utils.py      |   4 +-
 pydra/utils/typing.py            |  10 +--
 7 files changed, 148 insertions(+), 170 deletions(-)

diff --git a/pydra/engine/lazy.py b/pydra/engine/lazy.py
index f668acfc1b..936d47ead5 100644
--- a/pydra/engine/lazy.py
+++ b/pydra/engine/lazy.py
@@ -6,10 +6,9 @@
 from . import node

 if ty.TYPE_CHECKING:
-    from .graph import DiGraph
     from .submitter import NodeExecution
     from .core import Task, Workflow
-    from .specs import TaskDef, WorkflowDef
+    from .specs import TaskDef
     from .state import StateIndex


@@ -46,6 +45,27 @@ def _apply_cast(self, value):
             value = self._type(value)
         return value

+    def _get_value(
+        self,
+        node_exec: "NodeExecution",
+        state_index: "StateIndex | None" = None,
+    ) -> ty.Any:
+        """Return the value of a lazy field.
+
+        Parameters
+        ----------
+        node_exec: NodeExecution
+            the object representing the execution state of the current node
+        state_index : StateIndex, optional
+            the state index of the field to access
+
+        Returns
+        -------
+        value : Any
+            the resolved value of the lazy-field
+        """
+        raise NotImplementedError("LazyField is an abstract class")
+

 @attrs.define(kw_only=True)
 class LazyInField(LazyField[T]):
@@ -70,23 +90,25 @@ def _source(self):

     def _get_value(
         self,
-        workflow_def: "WorkflowDef",
+        node_exec: "NodeExecution",
+        state_index: "StateIndex | None" = None,
     ) -> ty.Any:
         """Return the value of a lazy field.
        Parameters
         ----------
-        wf : Workflow
-            the workflow the lazy field references
-        state_index : int, optional
-            the state index of the field to access
+        node_exec: NodeExecution
+            the object representing the execution state of the current node
+        state_index : StateIndex, optional
+            the state index of the field to access (ignored, used for duck-typing
+            with LazyOutField)

         Returns
         -------
         value : Any
             the resolved value of the lazy-field
         """
-        value = workflow_def[self._field]
+        value = node_exec.workflow_inputs[self._field]
         value = self._apply_cast(value)
         return value

@@ -105,16 +127,16 @@ def __repr__(self):

     def _get_value(
         self,
-        graph: "DiGraph[NodeExecution]",
+        node_exec: "NodeExecution",
         state_index: "StateIndex | None" = None,
     ) -> ty.Any:
         """Return the value of a lazy field.

         Parameters
         ----------
-        wf : Workflow
-            the workflow the lazy field references
-        state_index : int, optional
+        node_exec: NodeExecution
+            the object representing the execution state of the current node
+        state_index : StateIndex, optional
             the state index of the field to access

         Returns
@@ -130,7 +152,7 @@ def _get_value(
         if state_index is None:
             state_index = StateIndex()

-        task = graph.node(self._node.name).task(state_index)
+        task = node_exec.graph.node(self._node.name).task(state_index)
         _, split_depth = TypeParser.strip_splits(self._type)

         def get_nested(task: "Task[DefType]", depth: int):
diff --git a/pydra/engine/node.py b/pydra/engine/node.py
index 2598a08fb6..156478b7e1 100644
--- a/pydra/engine/node.py
+++ b/pydra/engine/node.py
@@ -121,14 +121,28 @@ def lzout(self) -> OutputType:
                 type=field.type,
             )
         outputs = self.inputs.Outputs(**lazy_fields)
-        # Flag the output lazy fields as being not typed checked (i.e. assigned to another
-        # node's inputs) yet
+        outpt: lazy.LazyOutField
         for outpt in attrs_values(outputs).values():
-            outpt._type_checked = False
+            # Assign the current node to the lazy fields so they can access the state
             outpt._node = self
+            # If the node has a non-empty state, wrap the type of the lazy field in
+            # a combination of an optional list and a number of nested StateArray
+            # types based on the number of states the node is split over and whether
+            # it has a combiner
+            if self._state:
+                type_, _ = TypeParser.strip_splits(outpt._type)
+                if self._state.combiner:
+                    type_ = list[type_]
+                for _ in range(self._state.depth - int(bool(self._state.combiner))):
+                    type_ = StateArray[type_]
+                outpt._type = type_
+            # Flag the output lazy fields as being not type-checked (i.e. assigned to
+            # another node's inputs) yet. This is used to prevent the user from changing
+            # the type of the output after it has been accessed by connecting it to an
+            # output of an upstream node with additional state variables.
+ outpt._type_checked = False self._lzout = outputs - self._wrap_lzout_types_in_state_arrays() return outputs @property @@ -217,20 +231,6 @@ def _check_if_outputs_have_been_used(self, msg): + msg ) - def _wrap_lzout_types_in_state_arrays(self) -> None: - """Wraps a types of the lazy out fields in a number of nested StateArray types - based on the number of states the node is split over""" - # Unwrap StateArray types from the output types - if not self.state: - return - outpt_lf: lazy.LazyOutField - for outpt_lf in attrs_values(self.lzout).values(): - assert not outpt_lf._type_checked - type_, _ = TypeParser.strip_splits(outpt_lf._type) - for _ in range(self._state.depth): - type_ = StateArray[type_] - outpt_lf._type = type_ - def _set_state(self) -> None: # Add node name to state's splitter, combiner and cont_dim loaded from the def splitter = self._definition._splitter diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py index cd7f43b4ae..d92ce66434 100644 --- a/pydra/engine/specs.py +++ b/pydra/engine/specs.py @@ -35,14 +35,12 @@ from pydra.utils.typing import StateArray, MultiInputObj from pydra.design.base import Field, Arg, Out, RequirementSet, NO_DEFAULT from pydra.design import shell -from pydra.engine.lazy import LazyInField, LazyOutField if ty.TYPE_CHECKING: from pydra.engine.core import Task from pydra.engine.graph import DiGraph from pydra.engine.submitter import NodeExecution from pydra.engine.core import Workflow - from pydra.engine.state import StateIndex from pydra.engine.environments import Environment from pydra.engine.workers import Worker @@ -476,41 +474,6 @@ def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]: } return hash_function(sorted(field_hashes.items())), field_hashes - def _resolve_lazy_inputs( - self, - workflow_inputs: "WorkflowDef", - graph: "DiGraph[NodeExecution]", - state_index: "StateIndex | None" = None, - ) -> Self: - """Resolves lazy fields in the task definition by replacing them with their - actual values. 
- - Parameters - ---------- - workflow : Workflow - The workflow the task is part of - graph : DiGraph[NodeExecution] - The execution graph of the workflow - state_index : StateIndex, optional - The state index for the workflow, by default None - - Returns - ------- - Self - The task definition with all lazy fields resolved - """ - from pydra.engine.state import StateIndex - - if state_index is None: - state_index = StateIndex() - resolved = {} - for name, value in attrs_values(self).items(): - if isinstance(value, LazyInField): - resolved[name] = value._get_value(workflow_inputs) - elif isinstance(value, LazyOutField): - resolved[name] = value._get_value(graph, state_index) - return attrs.evolve(self, **resolved) - def _check_rules(self): """Check if all rules are satisfied.""" diff --git a/pydra/engine/submitter.py b/pydra/engine/submitter.py index c9433e2d21..a869c10272 100644 --- a/pydra/engine/submitter.py +++ b/pydra/engine/submitter.py @@ -9,15 +9,18 @@ from copy import copy from datetime import datetime from collections import defaultdict +import attrs from .workers import Worker, WORKERS from .graph import DiGraph from .helpers import ( get_open_loop, list_fields, + attrs_values, ) from pydra.utils.hash import PersistentCache from .state import StateIndex from pydra.utils.typing import StateArray +from pydra.engine.lazy import LazyField from .audit import Audit from .core import Task from pydra.utils.messenger import AuditFlag, Messenger @@ -607,10 +610,7 @@ def all_failed(self) -> bool: def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]: if not self.node.state: yield Task( - definition=self.node._definition._resolve_lazy_inputs( - workflow_inputs=self.workflow_inputs, - graph=self.graph, - ), + definition=self._resolve_lazy_inputs(task_def=self.node._definition), submitter=self.submitter, environment=self.node._environment, hooks=self.node._hooks, @@ -619,9 +619,8 @@ def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]: else: for index, split_defn in self.node._split_definition().items(): yield Task( - definition=split_defn._resolve_lazy_inputs( - workflow_inputs=self.workflow_inputs, - graph=self.graph, + definition=self._resolve_lazy_inputs( + task_def=split_defn, state_index=index, ), submitter=self.submitter, @@ -631,6 +630,32 @@ def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]: state_index=index, ) + def _resolve_lazy_inputs( + self, + task_def: "TaskDef", + state_index: "StateIndex | None" = None, + ) -> "TaskDef": + """Resolves lazy fields in the task definition by replacing them with their + actual values calculated by upstream jobs. + + Parameters + ---------- + task_def : TaskDef + The definition to resolve the lazy fields of + state_index : StateIndex, optional + The state index for the workflow, by default None + + Returns + ------- + TaskDef + The task definition with all lazy fields resolved + """ + resolved = {} + for name, value in attrs_values(self).items(): + if isinstance(value, LazyField): + resolved[name] = value._get_value(self, state_index) + return attrs.evolve(self, **resolved) + def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]: """For a given node, check to see which tasks have been successfully run, are ready to run, can't be run due to upstream errors, or are blocked on other tasks to complete. 
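The `_resolve_lazy_inputs` method added to `NodeExecution` above is the heart of this patch: just before a `Task` is instantiated, every attrs field of the definition that still holds a lazy placeholder is swapped for the concrete value produced upstream, and `attrs.evolve` returns a fresh definition so the node's original definition is never mutated (note that PATCH 3/3 below corrects the body to iterate over `task_def` rather than `self`). A minimal, self-contained sketch of the same pattern — the `Lazy` and `AddDef` classes here are illustrative stand-ins, not pydra APIs:

import attrs

class Lazy:
    """Stand-in for pydra's LazyField; a real one would look the value
    up in the execution graph rather than storing it directly."""

    def __init__(self, value):
        self.value = value

    def _get_value(self, **_):
        # accept and ignore workflow/graph/state_index-style kwargs
        return self.value


@attrs.define
class AddDef:
    a: object  # may hold a concrete value or a Lazy placeholder
    b: int = 0


def resolve_lazy_inputs(task_def):
    # Replace every field that still holds a Lazy placeholder with its value
    resolved = {
        name: value._get_value()
        for name, value in attrs.asdict(task_def, recurse=False).items()
        if isinstance(value, Lazy)
    }
    # evolve() builds a new definition, leaving the original untouched
    return attrs.evolve(task_def, **resolved)


print(resolve_lazy_inputs(AddDef(a=Lazy(3), b=2)))  # AddDef(a=3, b=2)

Keeping the resolution on the execution side (rather than on `TaskDef` itself) is what lets the same immutable definition be resolved once per state index during a split.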
diff --git a/pydra/engine/tests/test_specs.py b/pydra/engine/tests/test_specs.py index 950d68b8a3..20b410026e 100644 --- a/pydra/engine/tests/test_specs.py +++ b/pydra/engine/tests/test_specs.py @@ -1,13 +1,11 @@ from pathlib import Path import typing as ty import os -import attrs from unittest.mock import Mock - -# from copy import deepcopy import time +import pytest from fileformats.generic import File -from ..specs import ( +from pydra.engine.specs import ( Runtime, Result, ) @@ -15,113 +13,68 @@ LazyInField, LazyOutField, ) - +from pydra.engine.core import Workflow +from pydra.engine.submitter import NodeExecution from pydra.utils.typing import StateArray - -# from ..helpers import make_klass -from .utils import Foo, BasicWorkflow from pydra.design import python, workflow -import pytest - - -# @python.define -# def Foo(a: str, b: int, c: float) -> str: -# return f"{a}{b}{c}" - - -def test_runtime(): - runtime = Runtime() - assert hasattr(runtime, "rss_peak_gb") - assert hasattr(runtime, "vms_peak_gb") - assert hasattr(runtime, "cpu_peak_percent") - - -def test_result(tmp_path): - result = Result(output_dir=tmp_path) - assert hasattr(result, "runtime") - assert hasattr(result, "outputs") - assert hasattr(result, "errored") - assert getattr(result, "errored") is False +from .utils import Foo, FunAddTwo, FunAddVar, ListSum -class NodeTesting: - @attrs.define() - class Input: - inp_a: str = "A" - inp_b: str = "B" +@workflow.define +def TestWorkflow(x: int, y: list[int]) -> int: + node_a = workflow.add(FunAddTwo(a=x), name="A") + node_b = workflow.add(FunAddVar(a=node_a.out).split(b=y).combine("b"), name="B") + node_c = workflow.add(ListSum(x=node_b.out), name="C") + return node_c.out - def __init__(self): - class InpDef: - def __init__(self): - self.fields = [("inp_a", int), ("inp_b", int)] - class Outputs: - def __init__(self): - self.fields = [("out_a", int)] - - self.name = "tn" - self.inputs = self.Input() - self.input_spec = InpDef() - self.output_spec = Outputs() - self.output_names = ["out_a"] - self.state = None - - def result(self, state_index=None): - class Output: - def __init__(self): - self.out_a = "OUT_A" +@pytest.fixture +def workflow_task(): + return TestWorkflow(x=1, y=[1, 2, 3]) - class Result: - def __init__(self): - self.output = Output() - self.errored = False - def get_output_field(self, field): - return getattr(self.output, field) +@pytest.fixture +def workflow_obj(workflow_task): + wf = Workflow.construct(workflow_task) + for n in wf.nodes: + if n._state: + n._state.prepare_states() + n._state.prepare_inputs() + return wf - return Result() +@pytest.fixture +def node_a(workflow_obj): + return workflow_obj["A"] -class WorkflowTesting: - def __init__(self): - class Input: - def __init__(self): - self.inp_a = "A" - self.inp_b = "B" - self.inputs = Input() - self.tn = NodeTesting() +@pytest.fixture +def node_b(workflow_obj): + return workflow_obj["B"] @pytest.fixture -def mock_node(): - node = Mock() - node.name = "tn" - node.definition = Foo(a="a", b=1, c=2.0) - return node +def node_c(workflow_obj): + return workflow_obj["C"] @pytest.fixture -def mock_workflow(): - mock_workflow = Mock() - mock_workflow.inputs = BasicWorkflow(x=1) - mock_workflow.outputs = BasicWorkflow.Outputs(out=attrs.NOTHING) - return mock_workflow +def node_exec(node_c, workflow_task): + # We only use this to resolve the upstream outputs from, can be any node + return NodeExecution(node=node_c, workflow_inputs=workflow_task, submitter=Mock()) -def test_lazy_inp(mock_workflow): - lf = 
LazyInField(field="a", type=int, workflow=mock_workflow) - assert lf._get_value() == "a" +def test_lazy_inp(workflow_obj, node_exec): + lf = LazyInField(field="x", type=int, workflow=workflow_obj) + assert lf._get_value(node_exec) == 1 - lf = LazyInField(field="b", type=str, workflow_def=mock_workflow) - assert lf._get_value() == 1 + lf = LazyInField(field="y", type=str, workflow=workflow_obj) + assert lf._get_value(node_exec) == [1, 2, 3] -def test_lazy_out(): - tn = NodeTesting() - lzout = LazyOutField(task=tn) - lf = lzout.out_a - assert lf.get_value(wf=WorkflowTesting()) == "OUT_A" +def test_lazy_out(node_a, node_exec): + lf = LazyOutField(field="a", type=int, node=node_a) + assert lf._get_value(node_exec) == 3 def test_lazy_getvale(): @@ -409,3 +362,18 @@ def Outer(xs): outputs = outer(cache_dir=tmp_path) assert outputs.out == [1, 2, 3] + + +def test_runtime(): + runtime = Runtime() + assert hasattr(runtime, "rss_peak_gb") + assert hasattr(runtime, "vms_peak_gb") + assert hasattr(runtime, "cpu_peak_percent") + + +def test_result(tmp_path): + result = Result(output_dir=tmp_path) + assert hasattr(result, "runtime") + assert hasattr(result, "outputs") + assert hasattr(result, "errored") + assert getattr(result, "errored") is False diff --git a/pydra/engine/tests/utils.py b/pydra/engine/tests/utils.py index 9fc1d5f91f..47ba21e4ce 100644 --- a/pydra/engine/tests/utils.py +++ b/pydra/engine/tests/utils.py @@ -290,8 +290,8 @@ def FunFileList(filename_list: ty.List[File]): @workflow.define(outputs=["out"]) def BasicWorkflow(x): - task1 = workflow.add(FunAddTwo(a=x)) - task2 = workflow.add(FunAddVar(a=task1.out, b=2)) + task1 = workflow.add(FunAddTwo(a=x), name="A") + task2 = workflow.add(FunAddVar(a=task1.out, b=2), name="B") return task2.out diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py index 7e39fc541b..6c538efaa8 100644 --- a/pydra/utils/typing.py +++ b/pydra/utils/typing.py @@ -7,7 +7,7 @@ import types import typing as ty import logging -import attr +import attrs from pydra.utils import add_exc_note from fileformats import field, core, generic @@ -217,8 +217,8 @@ def __call__(self, obj: ty.Any) -> T: from pydra.engine.helpers import is_lazy coerced: T - if obj is attr.NOTHING: - coerced = attr.NOTHING # type: ignore[assignment] + if obj is attrs.NOTHING: + coerced = attrs.NOTHING # type: ignore[assignment] elif is_lazy(obj): try: self.check_type(obj._type) @@ -279,8 +279,8 @@ def coerce(self, object_: ty.Any) -> T: def expand_and_coerce(obj, pattern: ty.Union[type, tuple]): """Attempt to expand the object along the lines of the coercion pattern""" - if obj is attr.NOTHING: - return attr.NOTHING + if obj is attrs.NOTHING: + return attrs.NOTHING if not isinstance(pattern, tuple): return coerce_basic(obj, pattern) origin, pattern_args = pattern From db8f7992a929b26af94bd4e73f04d65ab4e4cae7 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Wed, 26 Feb 2025 14:26:06 +1100 Subject: [PATCH 3/3] finished debugging test_specs --- pydra/engine/core.py | 4 +- pydra/engine/lazy.py | 36 +++-- pydra/engine/node.py | 6 +- pydra/engine/specs.py | 2 +- pydra/engine/state.py | 51 ++++++- pydra/engine/submitter.py | 69 ++++++--- pydra/engine/tests/test_specs.py | 248 ++++++++++++------------------- 7 files changed, 220 insertions(+), 196 deletions(-) diff --git a/pydra/engine/core.py b/pydra/engine/core.py index 9b7af85008..9b0ac65d79 100644 --- a/pydra/engine/core.py +++ b/pydra/engine/core.py @@ -817,9 +817,7 @@ def node_names(self) -> list[str]: def execution_graph(self, submitter: "Submitter") 
-> DiGraph: from pydra.engine.submitter import NodeExecution - exec_nodes = [ - NodeExecution(n, submitter, workflow_inputs=self.inputs) for n in self.nodes - ] + exec_nodes = [NodeExecution(n, submitter, workflow=self) for n in self.nodes] graph = self._create_graph(exec_nodes) # Set the graph attribute of the nodes so lazy fields can be resolved as tasks # are created diff --git a/pydra/engine/lazy.py b/pydra/engine/lazy.py index 936d47ead5..fd1f628a24 100644 --- a/pydra/engine/lazy.py +++ b/pydra/engine/lazy.py @@ -6,7 +6,7 @@ from . import node if ty.TYPE_CHECKING: - from .submitter import NodeExecution + from .submitter import DiGraph, NodeExecution from .core import Task, Workflow from .specs import TaskDef from .state import StateIndex @@ -47,15 +47,18 @@ def _apply_cast(self, value): def _get_value( self, - node_exec: "NodeExecution", + workflow: "Workflow", + graph: "DiGraph[NodeExecution]", state_index: "StateIndex | None" = None, ) -> ty.Any: """Return the value of a lazy field. Parameters ---------- - node_exec: NodeExecution - the object representing the execution state of the current node + workflow: Workflow + the workflow object + graph: DiGraph[NodeExecution] + the graph representing the execution state of the workflow state_index : StateIndex, optional the state index of the field to access @@ -90,25 +93,27 @@ def _source(self): def _get_value( self, - node_exec: "NodeExecution", + workflow: "Workflow", + graph: "DiGraph[NodeExecution]", state_index: "StateIndex | None" = None, ) -> ty.Any: """Return the value of a lazy field. Parameters ---------- - node_exec: NodeExecution - the object representing the execution state of the current node + workflow: Workflow + the workflow object + graph: DiGraph[NodeExecution] + the graph representing the execution state of the workflow state_index : StateIndex, optional - the state index of the field to access (ignored, used for duck-typing with - LazyOutField) + the state index of the field to access Returns ------- value : Any the resolved value of the lazy-field """ - value = node_exec.workflow_inputs[self._field] + value = workflow.inputs[self._field] value = self._apply_cast(value) return value @@ -127,15 +132,18 @@ def __repr__(self): def _get_value( self, - node_exec: "NodeExecution", + workflow: "Workflow", + graph: "DiGraph[NodeExecution]", state_index: "StateIndex | None" = None, ) -> ty.Any: """Return the value of a lazy field. 
        Parameters
         ----------
-        node_exec: NodeExecution
-            the object representing the execution state of the current node
+        workflow: Workflow
+            the workflow object
+        graph: DiGraph[NodeExecution]
+            the graph representing the execution state of the workflow
         state_index : StateIndex, optional
             the state index of the field to access

         Returns
@@ -152,7 +160,7 @@ def _get_value(
         if state_index is None:
             state_index = StateIndex()

-        task = node_exec.graph.node(self._node.name).task(state_index)
+        task = graph.node(self._node.name).task(state_index)
         _, split_depth = TypeParser.strip_splits(self._type)

         def get_nested(task: "Task[DefType]", depth: int):
diff --git a/pydra/engine/node.py b/pydra/engine/node.py
index 156478b7e1..8fd3bf0415 100644
--- a/pydra/engine/node.py
+++ b/pydra/engine/node.py
@@ -269,7 +269,11 @@ def _get_upstream_states(self) -> dict[str, tuple["State", list[str]]]:
         """Get the states of the upstream nodes that are connected to this node"""
         upstream_states = {}
         for inpt_name, val in self.input_values:
-            if isinstance(val, lazy.LazyOutField) and val._node.state:
+            if (
+                isinstance(val, lazy.LazyOutField)
+                and val._node.state
+                and val._node.state.depth
+            ):
                 node: Node = val._node
                 # variables that are part of inner splitters should be treated as a containers
                 if node.state and f"{node.name}.{inpt_name}" in node.state.splitter:
diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py
index d92ce66434..baea685f14 100644
--- a/pydra/engine/specs.py
+++ b/pydra/engine/specs.py
@@ -736,7 +736,7 @@ def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
         nodes_dict = {n.name: n for n in exec_graph.nodes}
         for name, lazy_field in attrs_values(workflow.outputs).items():
             try:
-                val_out = lazy_field._get_value(exec_graph)
+                val_out = lazy_field._get_value(workflow=workflow, graph=exec_graph)
                 output_wf[name] = val_out
             except (ValueError, AttributeError):
                 output_wf[name] = None
diff --git a/pydra/engine/state.py b/pydra/engine/state.py
index 7ef22aa848..11a4290bcf 100644
--- a/pydra/engine/state.py
+++ b/pydra/engine/state.py
@@ -41,7 +41,13 @@ def __init__(self, indices: dict[str, int] | None = None):
         else:
             self.indices = OrderedDict(sorted(indices.items()))

-    def __repr__(self):
+    def __len__(self) -> int:
+        return len(self.indices)
+
+    def __iter__(self) -> ty.Generator[str, None, None]:
+        return iter(self.indices)
+
+    def __repr__(self) -> str:
         return (
             "StateIndex(" + ", ".join(f"{n}={v}" for n, v in self.indices.items()) + ")"
         )
@@ -49,15 +55,49 @@ def __repr__(self):
     def __hash__(self):
         return hash(tuple(self.indices.items()))

-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
         return self.indices == other.indices

-    def __str__(self):
+    def __str__(self) -> str:
         return "__".join(f"{n}-{i}" for n, i in self.indices.items())

-    def __bool__(self):
+    def __bool__(self) -> bool:
         return bool(self.indices)

+    def subset(self, state_names: ty.Iterable[str]) -> ty.Self:
+        """Create a new StateIndex with only the specified state names
+
+        Parameters
+        ----------
+        state_names : ty.Iterable[str]
+            the state names to keep in the new StateIndex
+
+        Returns
+        -------
+        StateIndex
+            a new StateIndex with only the specified state names
+        """
+        return type(self)({k: v for k, v in self.indices.items() if k in state_names})
+
+    def matches(self, other: "StateIndex") -> bool:
+        """Check if the indices that are present in the other StateIndex match
+
+        Parameters
+        ----------
+        other : StateIndex
+            the other StateIndex to compare against
+
+        Returns
+        -------
+        bool
+            True if all the indices in the other StateIndex match
+        """
+        if not 
set(self.indices).issuperset(other.indices): + raise ValueError( + f"StateIndex {self} does not contain all the indices in {other}" + ) + return all(self.indices[k] == v for k, v in other.indices.items()) + class State: """ @@ -172,6 +212,9 @@ def __str__(self): def names(self): """Return the names of the states.""" # analysing states from connected tasks if inner_inputs + if not hasattr(self, "keys_final"): + self.prepare_states() + self.prepare_inputs() previous_states_keys = { f"_{v.name}": v.keys_final for v in self.inner_inputs.values() } diff --git a/pydra/engine/submitter.py b/pydra/engine/submitter.py index a869c10272..3236306a57 100644 --- a/pydra/engine/submitter.py +++ b/pydra/engine/submitter.py @@ -32,7 +32,8 @@ if ty.TYPE_CHECKING: from .node import Node - from .specs import TaskDef, TaskOutputs, WorkflowDef, TaskHooks, Result + from .specs import WorkflowDef, TaskDef, TaskOutputs, TaskHooks, Result + from .core import Workflow from .environments import Environment from .state import State @@ -501,7 +502,7 @@ class NodeExecution(ty.Generic[DefType]): _tasks: dict[StateIndex | None, "Task[DefType]"] | None - workflow_inputs: "WorkflowDef" + workflow: "Workflow" graph: DiGraph["NodeExecution"] | None @@ -509,7 +510,7 @@ def __init__( self, node: "Node", submitter: Submitter, - workflow_inputs: "WorkflowDef", + workflow: "Workflow", ): self.name = node.name self.node = node @@ -523,9 +524,17 @@ def __init__( self.running = {} # Not used in logic, but may be useful for progress tracking self.unrunnable = defaultdict(list) self.state_names = self.node.state.names if self.node.state else [] - self.workflow_inputs = workflow_inputs + self.workflow = workflow self.graph = None + def __repr__(self): + return ( + f"NodeExecution(name={self.name!r}, blocked={list(self.blocked)}, " + f"queued={list(self.queued)}, running={list(self.running)}, " + f"successful={list(self.successful)}, errored={list(self.errored)}, " + f"unrunnable={list(self.unrunnable)})" + ) + @property def inputs(self) -> "Node.Inputs": return self.node.inputs @@ -547,12 +556,16 @@ def tasks(self) -> ty.Iterable["Task[DefType]"]: def task(self, index: StateIndex = StateIndex()) -> "Task | list[Task[DefType]]": """Get a task object for a given state index.""" self.tasks # Ensure tasks are loaded - try: - return self._tasks[index] - except KeyError: - if not index: - return StateArray(self._tasks.values()) - raise + task_index = next(iter(self._tasks)) + if len(task_index) > len(index): + tasks = [] + for ind, task in self._tasks.items(): + if ind.matches(index): + tasks.append(task) + return StateArray(tasks) + elif len(index) > len(task_index): + index = index.subset(task_index) + return self._tasks[index] @property def started(self) -> bool: @@ -651,10 +664,12 @@ def _resolve_lazy_inputs( The task definition with all lazy fields resolved """ resolved = {} - for name, value in attrs_values(self).items(): + for name, value in attrs_values(task_def).items(): if isinstance(value, LazyField): - resolved[name] = value._get_value(self, state_index) - return attrs.evolve(self, **resolved) + resolved[name] = value._get_value( + workflow=self.workflow, graph=self.graph, state_index=state_index + ) + return attrs.evolve(task_def, **resolved) def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]: """For a given node, check to see which tasks have been successfully run, are ready @@ -676,19 +691,35 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]: runnable: list["Task[DefType]"] = [] 
self.tasks # Ensure tasks are loaded if not self.started: + assert self._tasks self.blocked = copy(self._tasks) # Check to see if any blocked tasks are now runnable/unrunnable for index, task in list(self.blocked.items()): pred: NodeExecution is_runnable = True for pred in graph.predecessors[self.node.name]: - if index not in pred.successful: + pred_jobs = pred.task(index) + if isinstance(pred_jobs, StateArray): + pred_inds = [j.state_index for j in pred_jobs] + else: + pred_inds = [pred_jobs.state_index] + if not all(i in pred.successful for i in pred_inds): is_runnable = False - if index in pred.errored: - self.unrunnable[index].append(self.blocked.pop(index)) - if index in pred.unrunnable: - self.unrunnable[index].extend(pred.unrunnable[index]) - self.blocked.pop(index) + blocked = True + if pred_errored := [i for i in pred_inds if i in pred.errored]: + self.unrunnable[index].extend( + [pred.errored[i] for i in pred_errored] + ) + blocked = False + if pred_unrunnable := [ + i for i in pred_inds if i in pred.unrunnable + ]: + self.unrunnable[index].extend( + [pred.unrunnable[i] for i in pred_unrunnable] + ) + blocked = False + if not blocked: + del self.blocked[index] break if is_runnable: runnable.append(self.blocked.pop(index)) diff --git a/pydra/engine/tests/test_specs.py b/pydra/engine/tests/test_specs.py index 20b410026e..8b1407d231 100644 --- a/pydra/engine/tests/test_specs.py +++ b/pydra/engine/tests/test_specs.py @@ -1,20 +1,20 @@ from pathlib import Path import typing as ty -import os -from unittest.mock import Mock import time import pytest from fileformats.generic import File from pydra.engine.specs import ( Runtime, Result, + WorkflowDef, ) from pydra.engine.lazy import ( LazyInField, LazyOutField, ) from pydra.engine.core import Workflow -from pydra.engine.submitter import NodeExecution +from pydra.engine.node import Node +from pydra.engine.submitter import Submitter, NodeExecution, DiGraph from pydra.utils.typing import StateArray from pydra.design import python, workflow from .utils import Foo, FunAddTwo, FunAddVar, ListSum @@ -29,12 +29,15 @@ def TestWorkflow(x: int, y: list[int]) -> int: @pytest.fixture -def workflow_task(): - return TestWorkflow(x=1, y=[1, 2, 3]) +def workflow_task(submitter: Submitter) -> WorkflowDef: + wf = TestWorkflow(x=1, y=[1, 2, 3]) + with submitter: + submitter(wf) + return wf @pytest.fixture -def workflow_obj(workflow_task): +def wf(workflow_task: WorkflowDef) -> Workflow: wf = Workflow.construct(workflow_task) for n in wf.nodes: if n._state: @@ -44,64 +47,63 @@ def workflow_obj(workflow_task): @pytest.fixture -def node_a(workflow_obj): - return workflow_obj["A"] +def submitter(tmp_path) -> Submitter: + return Submitter(tmp_path) @pytest.fixture -def node_b(workflow_obj): - return workflow_obj["B"] +def graph(wf: Workflow, submitter: Submitter) -> DiGraph[NodeExecution]: + return wf.execution_graph(submitter=submitter) @pytest.fixture -def node_c(workflow_obj): - return workflow_obj["C"] +def node_a(wf) -> Node: + return wf["A"] # we can pick any node to retrieve the values to -@pytest.fixture -def node_exec(node_c, workflow_task): - # We only use this to resolve the upstream outputs from, can be any node - return NodeExecution(node=node_c, workflow_inputs=workflow_task, submitter=Mock()) +def test_runtime(): + runtime = Runtime() + assert hasattr(runtime, "rss_peak_gb") + assert hasattr(runtime, "vms_peak_gb") + assert hasattr(runtime, "cpu_peak_percent") -def test_lazy_inp(workflow_obj, node_exec): - lf = LazyInField(field="x", type=int, 
workflow=workflow_obj) - assert lf._get_value(node_exec) == 1 +def test_result(tmp_path): + result = Result(output_dir=tmp_path) + assert hasattr(result, "runtime") + assert hasattr(result, "outputs") + assert hasattr(result, "errored") + assert getattr(result, "errored") is False - lf = LazyInField(field="y", type=str, workflow=workflow_obj) - assert lf._get_value(node_exec) == [1, 2, 3] +def test_lazy_inp(wf: Workflow, graph: DiGraph[NodeExecution]): + lf = LazyInField(field="x", type=int, workflow=wf) + assert lf._get_value(workflow=wf, graph=graph) == 1 -def test_lazy_out(node_a, node_exec): - lf = LazyOutField(field="a", type=int, node=node_a) - assert lf._get_value(node_exec) == 3 + lf = LazyInField(field="y", type=str, workflow=wf) + assert lf._get_value(workflow=wf, graph=graph) == [1, 2, 3] -def test_lazy_getvale(): - tn = NodeTesting() - lf = LazyIn(task=tn) - with pytest.raises(Exception) as excinfo: - lf.inp_c - assert ( - str(excinfo.value) - == "Task 'tn' has no input attribute 'inp_c', available: 'inp_a', 'inp_b'" - ) +def test_lazy_out(node_a, wf, graph): + lf = LazyOutField(field="out", type=int, node=node_a) + assert lf._get_value(wf, graph) == 3 def test_input_file_hash_1(tmp_path): - os.chdir(tmp_path) - outfile = "test.file" - fields = [("in_file", ty.Any)] - input_spec = SpecInfo(name="Inputs", fields=fields, bases=(BaseDef,)) - inputs = make_klass(input_spec) - assert inputs(in_file=outfile).hash == "9a106eb2830850834d9b5bf098d5fa85" + + outfile = tmp_path / "test.file" + outfile.touch() + + @python.define + def A(in_file: File) -> File: + return in_file + + assert A(in_file=outfile)._hash == "9644d3998748b339819c23ec6abec520" with open(outfile, "w") as fp: fp.write("test") - fields = [("in_file", File)] - input_spec = SpecInfo(name="Inputs", fields=fields, bases=(BaseDef,)) - inputs = make_klass(input_spec) - assert inputs(in_file=outfile).hash == "02fa5f6f1bbde7f25349f54335e1adaf" + + assert A(in_file=outfile)._hash == "9f7f9377ddef6d8c018f1bf8e89c208c" def test_input_file_hash_2(tmp_path): @@ -110,26 +112,26 @@ def test_input_file_hash_2(tmp_path): with open(file, "w") as f: f.write("hello") - input_spec = SpecInfo(name="Inputs", fields=[("in_file", File)], bases=(BaseDef,)) - inputs = make_klass(input_spec) + @python.define + def A(in_file: File) -> File: + return in_file # checking specific hash value - hash1 = inputs(in_file=file).hash - assert hash1 == "aaa50d60ed33d3a316d58edc882a34c3" + hash1 = A(in_file=file)._hash + assert hash1 == "179bd3cbdc747edc4957579376fe8c7d" # checking if different name doesn't affect the hash file_diffname = tmp_path / "in_file_2.txt" with open(file_diffname, "w") as f: f.write("hello") - hash2 = inputs(in_file=file_diffname).hash + hash2 = A(in_file=file_diffname)._hash assert hash1 == hash2 # checking if different content (the same name) affects the hash - time.sleep(2) # ensure mtime is different file_diffcontent = tmp_path / "in_file_1.txt" with open(file_diffcontent, "w") as f: f.write("hi") - hash3 = inputs(in_file=file_diffcontent).hash + hash3 = A(in_file=file_diffcontent)._hash assert hash1 != hash3 @@ -139,33 +141,31 @@ def test_input_file_hash_2a(tmp_path): with open(file, "w") as f: f.write("hello") - input_spec = SpecInfo( - name="Inputs", fields=[("in_file", ty.Union[File, int])], bases=(BaseDef,) - ) - inputs = make_klass(input_spec) + @python.define + def A(in_file: ty.Union[File, int]) -> File: + return in_file # checking specific hash value - hash1 = inputs(in_file=file).hash - assert hash1 == 
"aaa50d60ed33d3a316d58edc882a34c3" + hash1 = A(in_file=file)._hash + assert hash1 == "179bd3cbdc747edc4957579376fe8c7d" # checking if different name doesn't affect the hash file_diffname = tmp_path / "in_file_2.txt" with open(file_diffname, "w") as f: f.write("hello") - hash2 = inputs(in_file=file_diffname).hash + hash2 = A(in_file=file_diffname)._hash assert hash1 == hash2 + # checking if string is also accepted + hash3 = A(in_file=str(file))._hash + assert hash3 == hash1 + # checking if different content (the same name) affects the hash - time.sleep(2) # ensure mtime is different file_diffcontent = tmp_path / "in_file_1.txt" with open(file_diffcontent, "w") as f: f.write("hi") - hash3 = inputs(in_file=file_diffcontent).hash - assert hash1 != hash3 - - # checking if string is also accepted - hash4 = inputs(in_file=str(file)).hash - assert hash4 == "800af2b5b334c9e3e5c40c0e49b7ffb5" + hash4 = A(in_file=file_diffcontent)._hash + assert hash1 != hash4 def test_input_file_hash_3(tmp_path): @@ -174,22 +174,21 @@ def test_input_file_hash_3(tmp_path): with open(file, "w") as f: f.write("hello") - input_spec = SpecInfo( - name="Inputs", fields=[("in_file", File), ("in_int", int)], bases=(BaseDef,) - ) - inputs = make_klass(input_spec) + @python.define + def A(in_file: File, in_int: int) -> File: + return in_file, in_int - my_inp = inputs(in_file=file, in_int=3) + a = A(in_file=file, in_int=3) # original hash and files_hash (dictionary contains info about files) - hash1 = my_inp.hash + hash1 = a._hash # files_hash1 = deepcopy(my_inp.files_hash) # file name should be in files_hash1[in_file] filename = str(Path(file)) # assert filename in files_hash1["in_file"] # changing int input - my_inp.in_int = 5 - hash2 = my_inp.hash + a.in_int = 5 + hash2 = a._hash # files_hash2 = deepcopy(my_inp.files_hash) # hash should be different assert hash1 != hash2 @@ -202,7 +201,7 @@ def test_input_file_hash_3(tmp_path): with open(file, "w") as f: f.write("hello") - hash3 = my_inp.hash + hash3 = a._hash # files_hash3 = deepcopy(my_inp.files_hash) # hash should be the same, # but the entry for in_file in files_hash should be different (modification time) @@ -214,11 +213,11 @@ def test_input_file_hash_3(tmp_path): # assert files_hash3["in_file"][filename][1] == files_hash2["in_file"][filename][1] # setting the in_file again - my_inp.in_file = file + a.in_file = file # filename should be removed from files_hash # assert my_inp.files_hash["in_file"] == {} # will be saved again when hash is calculated - assert my_inp.hash == hash3 + assert a._hash == hash3 # assert filename in my_inp.files_hash["in_file"] @@ -230,26 +229,23 @@ def test_input_file_hash_4(tmp_path): with open(file, "w") as f: f.write("hello") - input_spec = SpecInfo( - name="Inputs", - fields=[("in_file", ty.List[ty.List[ty.Union[int, File]]])], - bases=(BaseDef,), - ) - inputs = make_klass(input_spec) + @python.define + def A(in_file: ty.List[ty.List[ty.Union[int, File]]]) -> File: + return in_file # checking specific hash value - hash1 = inputs(in_file=[[file, 3]]).hash - assert hash1 == "0693adbfac9f675af87e900065b1de00" + hash1 = A(in_file=[[file, 3]])._hash + assert hash1 == "ffd7afe0ca9d4585518809a509244b4b" # the same file, but int field changes - hash1a = inputs(in_file=[[file, 5]]).hash + hash1a = A(in_file=[[file, 5]])._hash assert hash1 != hash1a # checking if different name doesn't affect the hash file_diffname = tmp_path / "in_file_2.txt" with open(file_diffname, "w") as f: f.write("hello") - hash2 = inputs(in_file=[[file_diffname, 3]]).hash + 
hash2 = A(in_file=[[file_diffname, 3]])._hash assert hash1 == hash2 # checking if different content (the same name) affects the hash @@ -257,7 +253,7 @@ def test_input_file_hash_4(tmp_path): file_diffcontent = tmp_path / "in_file_1.txt" with open(file_diffcontent, "w") as f: f.write("hi") - hash3 = inputs(in_file=[[file_diffcontent, 3]]).hash + hash3 = A(in_file=[[file_diffcontent, 3]])._hash assert hash1 != hash3 @@ -267,26 +263,23 @@ def test_input_file_hash_5(tmp_path): with open(file, "w") as f: f.write("hello") - input_spec = SpecInfo( - name="Inputs", - fields=[("in_file", ty.List[ty.Dict[ty.Any, ty.Union[File, int]]])], - bases=(BaseDef,), - ) - inputs = make_klass(input_spec) + @python.define + def A(in_file: ty.List[ty.Dict[ty.Any, ty.Union[File, int]]]) -> File: + return in_file # checking specific hash value - hash1 = inputs(in_file=[{"file": file, "int": 3}]).hash - assert hash1 == "56e6e2c9f3bdf0cd5bd3060046dea480" + hash1 = A(in_file=[{"file": file, "int": 3}])._hash + assert hash1 == "ba884a74e33552854271f55b03e53947" # the same file, but int field changes - hash1a = inputs(in_file=[{"file": file, "int": 5}]).hash + hash1a = A(in_file=[{"file": file, "int": 5}])._hash assert hash1 != hash1a # checking if different name doesn't affect the hash file_diffname = tmp_path / "in_file_2.txt" with open(file_diffname, "w") as f: f.write("hello") - hash2 = inputs(in_file=[{"file": file_diffname, "int": 3}]).hash + hash2 = A(in_file=[{"file": file_diffname, "int": 3}])._hash assert hash1 == hash2 # checking if different content (the same name) affects the hash @@ -294,53 +287,15 @@ def test_input_file_hash_5(tmp_path): file_diffcontent = tmp_path / "in_file_1.txt" with open(file_diffcontent, "w") as f: f.write("hi") - hash3 = inputs(in_file=[{"file": file_diffcontent, "int": 3}]).hash + hash3 = A(in_file=[{"file": file_diffcontent, "int": 3}])._hash assert hash1 != hash3 -def test_lazy_field_cast(): - task = Foo(a="a", b=1, c=2.0, name="foo") - - assert task.lzout.y._type is int - assert workflow.cast(task.lzout.y, float)._type is float +def test_lazy_field_cast(wf: Workflow): + lzout = wf.add(Foo(a="a", b=1, c=2.0), name="foo") - -def test_lazy_field_multi_same_split(): - @python.define - def f(x: ty.List[int]) -> ty.List[int]: - return x - - task = f(x=[1, 2, 3], name="foo") - - lf = task.lzout.out.split("foo.x") - - assert lf.type == StateArray[int] - assert lf.splits == set([(("foo.x",),)]) - - lf2 = lf.split("foo.x") - assert lf2.type == StateArray[int] - assert lf2.splits == set([(("foo.x",),)]) - - -def test_lazy_field_multi_diff_split(): - @python.define - def F(x: ty.Any, y: ty.Any) -> ty.Any: - return x - - task = F(x=[1, 2, 3], name="foo") - - lf = task.lzout.out.split("foo.x") - - assert lf.type == StateArray[ty.Any] - assert lf.splits == set([(("foo.x",),)]) - - lf2 = lf.split("foo.x") - assert lf2.type == StateArray[ty.Any] - assert lf2.splits == set([(("foo.x",),)]) - - lf3 = lf.split("foo.y") - assert lf3.type == StateArray[StateArray[ty.Any]] - assert lf3.splits == set([(("foo.x",),), (("foo.y",),)]) + assert lzout.y._type is int + assert workflow.cast(lzout.y, float)._type is float def test_wf_lzin_split(tmp_path): @@ -362,18 +317,3 @@ def Outer(xs): outputs = outer(cache_dir=tmp_path) assert outputs.out == [1, 2, 3] - - -def test_runtime(): - runtime = Runtime() - assert hasattr(runtime, "rss_peak_gb") - assert hasattr(runtime, "vms_peak_gb") - assert hasattr(runtime, "cpu_peak_percent") - - -def test_result(tmp_path): - result = Result(output_dir=tmp_path) - assert 
hasattr(result, "runtime") - assert hasattr(result, "outputs") - assert hasattr(result, "errored") - assert getattr(result, "errored") is False
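
Looping back to PATCH 1/3: the `depth` fix means a split that is subsequently combined no longer contributes a nesting level, which is what lets the `lzout` typing in PATCH 2/3 wrap combined outputs in a plain `list` instead of a `StateArray`. Below is a rough, self-contained reconstruction of that rule, with the splitter expressed in the reverse-Polish form pydra uses ('.' links two splits, '*' crosses them); the loop details are assumptions inferred from the visible tail of the method, not code copied from pydra:

def state_depth(splitter_rpn: list, combiner: set) -> int:
    # '.'-linked splits share one nesting level; '*'-crossed splits add one
    # level each; the patched rule drops any split named in the combiner.
    depth, stack = 0, []
    for tok in splitter_rpn:
        if tok == ".":
            depth += int(any(s not in combiner for s in stack))
            stack = []
        elif tok == "*":
            depth += len([s for s in stack if s not in combiner])
            stack = []
        else:
            stack.append(tok)
    # the fix: single splits that are combined away no longer add a level
    return depth + len([s for s in stack if s not in combiner])


assert state_depth(["a.x"], combiner=set()) == 1               # plain split
assert state_depth(["a.x"], combiner={"a.x"}) == 0             # split then combined
assert state_depth(["a.x", "a.y", "."], combiner=set()) == 1   # linked splits add 1

The second assertion is the case this series fixes: node B in the test workflow splits over "b" and immediately combines it, so its lazy output is typed `list[int]` rather than `StateArray[int]`.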