diff --git a/pydra/engine/core.py b/pydra/engine/core.py
index 9b7af8500..9b0ac65d7 100644
--- a/pydra/engine/core.py
+++ b/pydra/engine/core.py
@@ -817,9 +817,7 @@ def node_names(self) -> list[str]:
     def execution_graph(self, submitter: "Submitter") -> DiGraph:
         from pydra.engine.submitter import NodeExecution
 
-        exec_nodes = [
-            NodeExecution(n, submitter, workflow_inputs=self.inputs) for n in self.nodes
-        ]
+        exec_nodes = [NodeExecution(n, submitter, workflow=self) for n in self.nodes]
         graph = self._create_graph(exec_nodes)
         # Set the graph attribute of the nodes so lazy fields can be resolved as tasks
         # are created
diff --git a/pydra/engine/lazy.py b/pydra/engine/lazy.py
index f668acfc1..fd1f628a2 100644
--- a/pydra/engine/lazy.py
+++ b/pydra/engine/lazy.py
@@ -6,10 +6,9 @@
 from . import node
 
 if ty.TYPE_CHECKING:
-    from .graph import DiGraph
-    from .submitter import NodeExecution
+    from .submitter import DiGraph, NodeExecution
     from .core import Task, Workflow
-    from .specs import TaskDef, WorkflowDef
+    from .specs import TaskDef
     from .state import StateIndex
 
 
@@ -46,6 +45,30 @@ def _apply_cast(self, value):
             value = self._type(value)
         return value
 
+    def _get_value(
+        self,
+        workflow: "Workflow",
+        graph: "DiGraph[NodeExecution]",
+        state_index: "StateIndex | None" = None,
+    ) -> ty.Any:
+        """Return the value of a lazy field.
+
+        Parameters
+        ----------
+        workflow : Workflow
+            the workflow object
+        graph : DiGraph[NodeExecution]
+            the graph representing the execution state of the workflow
+        state_index : StateIndex, optional
+            the state index of the field to access
+
+        Returns
+        -------
+        value : Any
+            the resolved value of the lazy-field
+        """
+        raise NotImplementedError("LazyField is an abstract class")
+
 
 @attrs.define(kw_only=True)
 class LazyInField(LazyField[T]):
@@ -70,15 +93,19 @@ def _source(self):
 
     def _get_value(
         self,
-        workflow_def: "WorkflowDef",
+        workflow: "Workflow",
+        graph: "DiGraph[NodeExecution]",
+        state_index: "StateIndex | None" = None,
     ) -> ty.Any:
         """Return the value of a lazy field.
 
         Parameters
         ----------
-        wf : Workflow
-            the workflow the lazy field references
-        state_index : int, optional
+        workflow : Workflow
+            the workflow object
+        graph : DiGraph[NodeExecution]
+            the graph representing the execution state of the workflow
+        state_index : StateIndex, optional
             the state index of the field to access
 
         Returns
@@ -86,7 +113,7 @@ def _get_value(
         value : Any
             the resolved value of the lazy-field
         """
-        value = workflow_def[self._field]
+        value = workflow.inputs[self._field]
         value = self._apply_cast(value)
         return value
 
@@ -105,6 +132,7 @@ def __repr__(self):
 
     def _get_value(
         self,
+        workflow: "Workflow",
         graph: "DiGraph[NodeExecution]",
         state_index: "StateIndex | None" = None,
     ) -> ty.Any:
@@ -112,9 +140,11 @@ def _get_value(
 
         Parameters
         ----------
-        wf : Workflow
-            the workflow the lazy field references
-        state_index : int, optional
+        workflow : Workflow
+            the workflow object
+        graph : DiGraph[NodeExecution]
+            the graph representing the execution state of the workflow
+        state_index : StateIndex, optional
             the state index of the field to access
 
         Returns
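Review note: the three lazy-field classes now share a single `_get_value(workflow, graph, state_index)` signature, so a caller can resolve any lazy field without first checking whether it is an input or an output reference. A minimal standalone sketch of that uniform protocol (`MiniWorkflow`, `MiniIn` and `MiniOut` are hypothetical stand-ins, not the pydra classes):

```python
class MiniWorkflow:
    """Hypothetical stand-in for pydra's Workflow; only carries inputs."""

    def __init__(self, inputs: dict):
        self.inputs = inputs


class MiniIn:
    """Stand-in for LazyInField: reads straight off the workflow inputs."""

    def __init__(self, field: str):
        self._field = field

    def _get_value(self, workflow, graph, state_index=None):
        return workflow.inputs[self._field]


class MiniOut:
    """Stand-in for LazyOutField: the real class would look the value up in
    the execution graph; hard-wired here to keep the sketch self-contained."""

    def __init__(self, value):
        self._value = value

    def _get_value(self, workflow, graph, state_index=None):
        return self._value


wf = MiniWorkflow({"x": 1})
fields = [MiniIn("x"), MiniOut(42)]
# one call shape for every kind of lazy field, as in the diff above
assert [f._get_value(wf, graph=None) for f in fields] == [1, 42]
```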
diff --git a/pydra/engine/node.py b/pydra/engine/node.py
index 2598a08fb..8fd3bf041 100644
--- a/pydra/engine/node.py
+++ b/pydra/engine/node.py
@@ -121,14 +121,28 @@ def lzout(self) -> OutputType:
                 type=field.type,
             )
         outputs = self.inputs.Outputs(**lazy_fields)
-        # Flag the output lazy fields as being not typed checked (i.e. assigned to another
-        # node's inputs) yet
+        outpt: lazy.LazyOutField
         for outpt in attrs_values(outputs).values():
-            outpt._type_checked = False
+            # Assign the current node to the lazy fields so they can access the state
             outpt._node = self
+            # If the node has a non-empty state, wrap the type of the lazy field in
+            # a combination of an optional list and a number of nested StateArray
+            # types based on the number of states the node is split over and whether
+            # it has a combiner
+            if self._state:
+                type_, _ = TypeParser.strip_splits(outpt._type)
+                if self._state.combiner:
+                    type_ = list[type_]
+                for _ in range(self._state.depth - int(bool(self._state.combiner))):
+                    type_ = StateArray[type_]
+                outpt._type = type_
+            # Flag the output lazy fields as being not type-checked (i.e. not yet
+            # assigned to another node's inputs). This is used to prevent the user
+            # from changing the type of the output after it has been accessed by
+            # connecting it to the input of a downstream node with additional state
+            # variables.
+            outpt._type_checked = False
         self._lzout = outputs
-        self._wrap_lzout_types_in_state_arrays()
         return outputs
 
     @property
@@ -217,20 +231,6 @@ def _check_if_outputs_have_been_used(self, msg):
                 + msg
             )
 
-    def _wrap_lzout_types_in_state_arrays(self) -> None:
-        """Wraps a types of the lazy out fields in a number of nested StateArray types
-        based on the number of states the node is split over"""
-        # Unwrap StateArray types from the output types
-        if not self.state:
-            return
-        outpt_lf: lazy.LazyOutField
-        for outpt_lf in attrs_values(self.lzout).values():
-            assert not outpt_lf._type_checked
-            type_, _ = TypeParser.strip_splits(outpt_lf._type)
-            for _ in range(self._state.depth):
-                type_ = StateArray[type_]
-            outpt_lf._type = type_
-
     def _set_state(self) -> None:
         # Add node name to state's splitter, combiner and cont_dim loaded from the def
         splitter = self._definition._splitter
@@ -269,7 +269,11 @@ def _get_upstream_states(self) -> dict[str, tuple["State", list[str]]]:
         """Get the states of the upstream nodes that are connected to this node"""
         upstream_states = {}
         for inpt_name, val in self.input_values:
-            if isinstance(val, lazy.LazyOutField) and val._node.state:
+            if (
+                isinstance(val, lazy.LazyOutField)
+                and val._node.state
+                and val._node.state.depth
+            ):
                 node: Node = val._node
                 # variables that are part of inner splitters should be treated as a containers
                 if node.state and f"{node.name}.{inpt_name}" in node.state.splitter:
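Review note: the type-wrapping rule folded into `lzout` above reads as: strip any existing splits from the field's type, wrap it once in `list` if the node has a combiner, then wrap it in `StateArray` once per remaining level of state depth. A runnable sketch under that reading, using a local `StateArray` stand-in rather than `pydra.utils.typing.StateArray`:

```python
import typing as ty

T = ty.TypeVar("T")


class StateArray(ty.List[T]):
    """Local stand-in for pydra.utils.typing.StateArray."""


def wrap_lzout_type(type_: type, depth: int, has_combiner: bool) -> type:
    # mirrors the new logic in Node.lzout: a combiner collapses one split
    # into a plain list; every remaining split adds one StateArray level
    if has_combiner:
        type_ = list[type_]
    for _ in range(depth - int(has_combiner)):
        type_ = StateArray[type_]
    return type_


# one split, no combiner -> StateArray[int]
assert wrap_lzout_type(int, depth=1, has_combiner=False) == StateArray[int]
# the same split, combined -> plain list[int], no StateArray left
assert wrap_lzout_type(int, depth=1, has_combiner=True) == list[int]
# two independent splits, one combined -> StateArray[list[int]]
assert wrap_lzout_type(int, depth=2, has_combiner=True) == StateArray[list[int]]
```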
diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py
index cd7f43b4a..baea685f1 100644
--- a/pydra/engine/specs.py
+++ b/pydra/engine/specs.py
@@ -35,14 +35,12 @@
 from pydra.utils.typing import StateArray, MultiInputObj
 from pydra.design.base import Field, Arg, Out, RequirementSet, NO_DEFAULT
 from pydra.design import shell
-from pydra.engine.lazy import LazyInField, LazyOutField
 
 if ty.TYPE_CHECKING:
     from pydra.engine.core import Task
     from pydra.engine.graph import DiGraph
     from pydra.engine.submitter import NodeExecution
     from pydra.engine.core import Workflow
-    from pydra.engine.state import StateIndex
     from pydra.engine.environments import Environment
     from pydra.engine.workers import Worker
 
@@ -476,41 +474,6 @@ def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]:
         }
         return hash_function(sorted(field_hashes.items())), field_hashes
 
-    def _resolve_lazy_inputs(
-        self,
-        workflow_inputs: "WorkflowDef",
-        graph: "DiGraph[NodeExecution]",
-        state_index: "StateIndex | None" = None,
-    ) -> Self:
-        """Resolves lazy fields in the task definition by replacing them with their
-        actual values.
-
-        Parameters
-        ----------
-        workflow : Workflow
-            The workflow the task is part of
-        graph : DiGraph[NodeExecution]
-            The execution graph of the workflow
-        state_index : StateIndex, optional
-            The state index for the workflow, by default None
-
-        Returns
-        -------
-        Self
-            The task definition with all lazy fields resolved
-        """
-        from pydra.engine.state import StateIndex
-
-        if state_index is None:
-            state_index = StateIndex()
-        resolved = {}
-        for name, value in attrs_values(self).items():
-            if isinstance(value, LazyInField):
-                resolved[name] = value._get_value(workflow_inputs)
-            elif isinstance(value, LazyOutField):
-                resolved[name] = value._get_value(graph, state_index)
-        return attrs.evolve(self, **resolved)
-
     def _check_rules(self):
         """Check if all rules are satisfied."""
 
@@ -773,7 +736,7 @@ def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
         nodes_dict = {n.name: n for n in exec_graph.nodes}
         for name, lazy_field in attrs_values(workflow.outputs).items():
             try:
-                val_out = lazy_field._get_value(exec_graph)
+                val_out = lazy_field._get_value(workflow=workflow, graph=exec_graph)
                 output_wf[name] = val_out
             except (ValueError, AttributeError):
                 output_wf[name] = None
diff --git a/pydra/engine/state.py b/pydra/engine/state.py
index 96c288118..11a4290bc 100644
--- a/pydra/engine/state.py
+++ b/pydra/engine/state.py
@@ -41,7 +41,13 @@ def __init__(self, indices: dict[str, int] | None = None):
         else:
             self.indices = OrderedDict(sorted(indices.items()))
 
-    def __repr__(self):
+    def __len__(self) -> int:
+        return len(self.indices)
+
+    def __iter__(self) -> ty.Generator[str, None, None]:
+        return iter(self.indices)
+
+    def __repr__(self) -> str:
         return (
             "StateIndex(" + ", ".join(f"{n}={v}" for n, v in self.indices.items()) + ")"
         )
@@ -49,15 +55,49 @@ def __repr__(self):
     def __hash__(self):
         return hash(tuple(self.indices.items()))
 
-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
         return self.indices == other.indices
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "__".join(f"{n}-{i}" for n, i in self.indices.items())
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return bool(self.indices)
 
+    def subset(self, state_names: ty.Iterable[str]) -> ty.Self:
+        """Create a new StateIndex with only the specified fields
+
+        Parameters
+        ----------
+        state_names : Iterable[str]
+            the fields to keep in the new StateIndex
+
+        Returns
+        -------
+        StateIndex
+            a new StateIndex with only the specified fields
+        """
+        return type(self)({k: v for k, v in self.indices.items() if k in state_names})
+
+    def matches(self, other: "StateIndex") -> bool:
+        """Check if the indices that are present in the other StateIndex match
+
+        Parameters
+        ----------
+        other : StateIndex
+            the other StateIndex to compare against
+
+        Returns
+        -------
+        bool
+            True if all the indices in the other StateIndex match
+        """
+        if not set(self.indices).issuperset(other.indices):
+            raise ValueError(
+                f"StateIndex {self} does not contain all the indices in {other}"
+            )
+        return all(self.indices[k] == v for k, v in other.indices.items())
+
 
 class State:
     """
@@ -172,6 +212,9 @@ def __str__(self):
     def names(self):
         """Return the names of the states."""
         # analysing states from connected tasks if inner_inputs
+        if not hasattr(self, "keys_final"):
+            self.prepare_states()
+            self.prepare_inputs()
         previous_states_keys = {
             f"_{v.name}": v.keys_final for v in self.inner_inputs.values()
         }
@@ -190,13 +233,13 @@ def names(self):
 
     @property
     def depth(self) -> int:
-        """Return the number of uncombined splits of the state, i.e. the number nested
+        """Return the number of splits of the state, i.e. the number of nested
         state arrays to wrap around the type of lazy out fields
 
         Returns
         -------
        int
-            number of uncombined splits
+            number of uncombined independent splits (i.e. linked splits only add 1)
         """
         depth = 0
         stack = []
@@ -210,7 +253,8 @@ def depth(self) -> int:
                 stack = []
             else:
                 stack.append(spl)
-        return depth + len(stack)
+        remaining_stack = [s for s in stack if s not in self.combiner]
+        return depth + len(remaining_stack)
 
     @property
     def splitter(self):
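Review note: `subset` and `matches` are what let `NodeExecution.task` relate state indices of different dimensionality (e.g. a combined node asking a split predecessor for all of its jobs). A standalone mirror of the two methods as written above (`MiniStateIndex` is a local stand-in, not the pydra class):

```python
from collections import OrderedDict
import typing as ty


class MiniStateIndex:
    """Local stand-in mirroring the new StateIndex helpers."""

    def __init__(self, indices: dict[str, int] | None = None):
        self.indices = OrderedDict(sorted((indices or {}).items()))

    def __len__(self) -> int:
        return len(self.indices)

    def subset(self, state_names: ty.Iterable[str]) -> "MiniStateIndex":
        # keep only the requested state axes
        return MiniStateIndex(
            {k: v for k, v in self.indices.items() if k in state_names}
        )

    def matches(self, other: "MiniStateIndex") -> bool:
        # a coarser index matches if every axis it carries agrees with ours;
        # asking about axes we do not carry is an error, as in the diff
        if not set(self.indices).issuperset(other.indices):
            raise ValueError(f"{self.indices} does not contain {other.indices}")
        return all(self.indices[k] == v for k, v in other.indices.items())


full = MiniStateIndex({"A.x": 0, "B.b": 2})
assert full.matches(MiniStateIndex({"B.b": 2}))      # shared axis agrees
assert not full.matches(MiniStateIndex({"B.b": 1}))  # shared axis disagrees
assert len(full.subset(["A.x"])) == 1                # project onto one axis
```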
diff --git a/pydra/engine/submitter.py b/pydra/engine/submitter.py
index c9433e2d2..3236306a5 100644
--- a/pydra/engine/submitter.py
+++ b/pydra/engine/submitter.py
@@ -9,15 +9,18 @@
 from copy import copy
 from datetime import datetime
 from collections import defaultdict
+import attrs
 from .workers import Worker, WORKERS
 from .graph import DiGraph
 from .helpers import (
     get_open_loop,
     list_fields,
+    attrs_values,
 )
 from pydra.utils.hash import PersistentCache
 from .state import StateIndex
 from pydra.utils.typing import StateArray
+from pydra.engine.lazy import LazyField
 from .audit import Audit
 from .core import Task
 from pydra.utils.messenger import AuditFlag, Messenger
@@ -29,7 +32,8 @@
 
 if ty.TYPE_CHECKING:
     from .node import Node
-    from .specs import TaskDef, TaskOutputs, WorkflowDef, TaskHooks, Result
+    from .specs import WorkflowDef, TaskDef, TaskOutputs, TaskHooks, Result
+    from .core import Workflow
     from .environments import Environment
     from .state import State
 
@@ -498,7 +502,7 @@ class NodeExecution(ty.Generic[DefType]):
 
     _tasks: dict[StateIndex | None, "Task[DefType]"] | None
 
-    workflow_inputs: "WorkflowDef"
+    workflow: "Workflow"
 
     graph: DiGraph["NodeExecution"] | None
 
@@ -506,7 +510,7 @@ def __init__(
         self,
         node: "Node",
         submitter: Submitter,
-        workflow_inputs: "WorkflowDef",
+        workflow: "Workflow",
     ):
         self.name = node.name
         self.node = node
@@ -520,9 +524,17 @@ def __init__(
         self.running = {}  # Not used in logic, but may be useful for progress tracking
         self.unrunnable = defaultdict(list)
         self.state_names = self.node.state.names if self.node.state else []
-        self.workflow_inputs = workflow_inputs
+        self.workflow = workflow
         self.graph = None
 
+    def __repr__(self):
+        return (
+            f"NodeExecution(name={self.name!r}, blocked={list(self.blocked)}, "
+            f"queued={list(self.queued)}, running={list(self.running)}, "
+            f"successful={list(self.successful)}, errored={list(self.errored)}, "
+            f"unrunnable={list(self.unrunnable)})"
+        )
+
     @property
     def inputs(self) -> "Node.Inputs":
         return self.node.inputs
@@ -544,12 +556,16 @@ def tasks(self) -> ty.Iterable["Task[DefType]"]:
 
     def task(self, index: StateIndex = StateIndex()) -> "Task | list[Task[DefType]]":
         """Get a task object for a given state index."""
         self.tasks  # Ensure tasks are loaded
-        try:
-            return self._tasks[index]
-        except KeyError:
-            if not index:
-                return StateArray(self._tasks.values())
-            raise
+        task_index = next(iter(self._tasks))
+        if len(task_index) > len(index):
+            tasks = []
+            for ind, task in self._tasks.items():
+                if ind.matches(index):
+                    tasks.append(task)
+            return StateArray(tasks)
+        elif len(index) > len(task_index):
+            index = index.subset(task_index)
+        return self._tasks[index]
 
     @property
     def started(self) -> bool:
@@ -607,10 +623,7 @@ def all_failed(self) -> bool:
 
     def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]:
         if not self.node.state:
             yield Task(
-                definition=self.node._definition._resolve_lazy_inputs(
-                    workflow_inputs=self.workflow_inputs,
-                    graph=self.graph,
-                ),
+                definition=self._resolve_lazy_inputs(task_def=self.node._definition),
                 submitter=self.submitter,
                 environment=self.node._environment,
                 hooks=self.node._hooks,
@@ -619,9 +632,8 @@ def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]:
         else:
             for index, split_defn in self.node._split_definition().items():
                 yield Task(
-                    definition=split_defn._resolve_lazy_inputs(
-                        workflow_inputs=self.workflow_inputs,
-                        graph=self.graph,
+                    definition=self._resolve_lazy_inputs(
+                        task_def=split_defn,
                         state_index=index,
                     ),
                     submitter=self.submitter,
@@ -631,6 +643,34 @@ def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]:
                     state_index=index,
                 )
 
+    def _resolve_lazy_inputs(
+        self,
+        task_def: "TaskDef",
+        state_index: "StateIndex | None" = None,
+    ) -> "TaskDef":
+        """Resolves lazy fields in the task definition by replacing them with their
+        actual values calculated by upstream jobs.
+
+        Parameters
+        ----------
+        task_def : TaskDef
+            The definition to resolve the lazy fields of
+        state_index : StateIndex, optional
+            The state index for the workflow, by default None
+
+        Returns
+        -------
+        TaskDef
+            The task definition with all lazy fields resolved
+        """
+        resolved = {}
+        for name, value in attrs_values(task_def).items():
+            if isinstance(value, LazyField):
+                resolved[name] = value._get_value(
+                    workflow=self.workflow, graph=self.graph, state_index=state_index
+                )
+        return attrs.evolve(task_def, **resolved)
+
     def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
         """For a given node, check to see which tasks have been successfully run, are ready
         to run, can't be run due to upstream errors, or are blocked on other tasks to complete.
@@ -651,19 +691,35 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
         runnable: list["Task[DefType]"] = []
         self.tasks  # Ensure tasks are loaded
         if not self.started:
+            assert self._tasks
             self.blocked = copy(self._tasks)
         # Check to see if any blocked tasks are now runnable/unrunnable
         for index, task in list(self.blocked.items()):
             pred: NodeExecution
             is_runnable = True
             for pred in graph.predecessors[self.node.name]:
-                if index not in pred.successful:
+                pred_jobs = pred.task(index)
+                if isinstance(pred_jobs, StateArray):
+                    pred_inds = [j.state_index for j in pred_jobs]
+                else:
+                    pred_inds = [pred_jobs.state_index]
+                if not all(i in pred.successful for i in pred_inds):
                     is_runnable = False
-                    if index in pred.errored:
-                        self.unrunnable[index].append(self.blocked.pop(index))
-                    if index in pred.unrunnable:
-                        self.unrunnable[index].extend(pred.unrunnable[index])
-                        self.blocked.pop(index)
+                    blocked = True
+                    if pred_errored := [i for i in pred_inds if i in pred.errored]:
+                        self.unrunnable[index].extend(
+                            [pred.errored[i] for i in pred_errored]
+                        )
+                        blocked = False
+                    if pred_unrunnable := [
+                        i for i in pred_inds if i in pred.unrunnable
+                    ]:
+                        self.unrunnable[index].extend(
+                            [pred.unrunnable[i] for i in pred_unrunnable]
+                        )
+                        blocked = False
+                    if not blocked:
+                        del self.blocked[index]
                     break
             if is_runnable:
                 runnable.append(self.blocked.pop(index))
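Review note: the rewritten `task()` reconciles index dimensionality in both directions: a request carrying fewer state axes than the stored tasks fans out to a `StateArray` of every matching job, while a request carrying more axes is projected down via `StateIndex.subset`. A dict-based sketch of that lookup rule, using frozensets of `(axis, value)` pairs in place of real `StateIndex` objects:

```python
def lookup(tasks: dict[frozenset, str], index: frozenset) -> "str | list[str]":
    task_index = next(iter(tasks))  # all stored indices share the same axes
    if len(task_index) > len(index):
        # requested index is coarser: return every job whose index agrees
        # with the request on the shared axes (a StateArray in pydra)
        return [t for ind, t in tasks.items() if index <= ind]
    elif len(index) > len(task_index):
        # requested index is finer: drop the extra axes (StateIndex.subset)
        axes = {axis for axis, _ in task_index}
        index = frozenset((a, v) for a, v in index if a in axes)
    return tasks[index]


jobs = {
    frozenset({("B.b", 0)}): "B-job-0",
    frozenset({("B.b", 1)}): "B-job-1",
}
assert lookup(jobs, frozenset()) == ["B-job-0", "B-job-1"]  # fan out
assert lookup(jobs, frozenset({("B.b", 1), ("C.c", 5)})) == "B-job-1"  # project
```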
diff --git a/pydra/engine/tests/test_specs.py b/pydra/engine/tests/test_specs.py
index 950d68b8a..8b1407d23 100644
--- a/pydra/engine/tests/test_specs.py
+++ b/pydra/engine/tests/test_specs.py
@@ -1,32 +1,64 @@
 from pathlib import Path
 import typing as ty
-import os
-import attrs
-from unittest.mock import Mock
-
-# from copy import deepcopy
 import time
+import pytest
 from fileformats.generic import File
-from ..specs import (
+from pydra.engine.specs import (
     Runtime,
     Result,
+    WorkflowDef,
 )
 from pydra.engine.lazy import (
     LazyInField,
     LazyOutField,
 )
-
+from pydra.engine.core import Workflow
+from pydra.engine.node import Node
+from pydra.engine.submitter import Submitter, NodeExecution, DiGraph
 from pydra.utils.typing import StateArray
-
-# from ..helpers import make_klass
-from .utils import Foo, BasicWorkflow
 from pydra.design import python, workflow
-import pytest
+from .utils import Foo, FunAddTwo, FunAddVar, ListSum
+
+
+@workflow.define
+def TestWorkflow(x: int, y: list[int]) -> int:
+    node_a = workflow.add(FunAddTwo(a=x), name="A")
+    node_b = workflow.add(FunAddVar(a=node_a.out).split(b=y).combine("b"), name="B")
+    node_c = workflow.add(ListSum(x=node_b.out), name="C")
+    return node_c.out
+
+
+@pytest.fixture
+def workflow_task(submitter: Submitter) -> WorkflowDef:
+    wf = TestWorkflow(x=1, y=[1, 2, 3])
+    with submitter:
+        submitter(wf)
+    return wf
+
+
+@pytest.fixture
+def wf(workflow_task: WorkflowDef) -> Workflow:
+    wf = Workflow.construct(workflow_task)
+    for n in wf.nodes:
+        if n._state:
+            n._state.prepare_states()
+            n._state.prepare_inputs()
+    return wf
 
-# @python.define
-# def Foo(a: str, b: int, c: float) -> str:
-#     return f"{a}{b}{c}"
+
+@pytest.fixture
+def submitter(tmp_path) -> Submitter:
+    return Submitter(tmp_path)
+
+
+@pytest.fixture
+def graph(wf: Workflow, submitter: Submitter) -> DiGraph[NodeExecution]:
+    return wf.execution_graph(submitter=submitter)
+
+
+@pytest.fixture
+def node_a(wf) -> Node:
+    return wf["A"]  # we can pick any node to retrieve the values from
 
 
 def test_runtime():
@@ -44,111 +76,34 @@ def test_result(tmp_path):
     assert getattr(result, "errored") is False
 
 
-class NodeTesting:
-    @attrs.define()
-    class Input:
-        inp_a: str = "A"
-        inp_b: str = "B"
-
-    def __init__(self):
-        class InpDef:
-            def __init__(self):
-                self.fields = [("inp_a", int), ("inp_b", int)]
-
-        class Outputs:
-            def __init__(self):
-                self.fields = [("out_a", int)]
-
-        self.name = "tn"
-        self.inputs = self.Input()
-        self.input_spec = InpDef()
-        self.output_spec = Outputs()
-        self.output_names = ["out_a"]
-        self.state = None
-
-    def result(self, state_index=None):
-        class Output:
-            def __init__(self):
-                self.out_a = "OUT_A"
-
-        class Result:
-            def __init__(self):
-                self.output = Output()
-                self.errored = False
-
-            def get_output_field(self, field):
-                return getattr(self.output, field)
-
-        return Result()
-
-
-class WorkflowTesting:
-    def __init__(self):
-        class Input:
-            def __init__(self):
-                self.inp_a = "A"
-                self.inp_b = "B"
-
-        self.inputs = Input()
-        self.tn = NodeTesting()
-
-
-@pytest.fixture
-def mock_node():
-    node = Mock()
-    node.name = "tn"
-    node.definition = Foo(a="a", b=1, c=2.0)
-    return node
-
-
-@pytest.fixture
-def mock_workflow():
-    mock_workflow = Mock()
-    mock_workflow.inputs = BasicWorkflow(x=1)
-    mock_workflow.outputs = BasicWorkflow.Outputs(out=attrs.NOTHING)
-    return mock_workflow
-
+def test_lazy_inp(wf: Workflow, graph: DiGraph[NodeExecution]):
+    lf = LazyInField(field="x", type=int, workflow=wf)
+    assert lf._get_value(workflow=wf, graph=graph) == 1
 
-def test_lazy_inp(mock_workflow):
-    lf = LazyInField(field="a", type=int, workflow=mock_workflow)
-    assert lf._get_value() == "a"
+    lf = LazyInField(field="y", type=str, workflow=wf)
+    assert lf._get_value(workflow=wf, graph=graph) == [1, 2, 3]
 
-    lf = LazyInField(field="b", type=str, workflow_def=mock_workflow)
-    assert lf._get_value() == 1
 
+def test_lazy_out(node_a, wf, graph):
+    lf = LazyOutField(field="out", type=int, node=node_a)
+    assert lf._get_value(wf, graph) == 3
 
-def test_lazy_out():
-    tn = NodeTesting()
-    lzout = LazyOutField(task=tn)
-    lf = lzout.out_a
-    assert lf.get_value(wf=WorkflowTesting()) == "OUT_A"
 
+def test_input_file_hash_1(tmp_path):
 
-def test_lazy_getvale():
-    tn = NodeTesting()
-    lf = LazyIn(task=tn)
-    with pytest.raises(Exception) as excinfo:
-        lf.inp_c
-    assert (
-        str(excinfo.value)
-        == "Task 'tn' has no input attribute 'inp_c', available: 'inp_a', 'inp_b'"
-    )
+    outfile = tmp_path / "test.file"
+    outfile.touch()
 
+    @python.define
+    def A(in_file: File) -> File:
+        return in_file
 
-def test_input_file_hash_1(tmp_path):
-    os.chdir(tmp_path)
-    outfile = "test.file"
-    fields = [("in_file", ty.Any)]
-    input_spec = SpecInfo(name="Inputs", fields=fields, bases=(BaseDef,))
-    inputs = make_klass(input_spec)
-    assert inputs(in_file=outfile).hash == "9a106eb2830850834d9b5bf098d5fa85"
+    assert A(in_file=outfile)._hash == "9644d3998748b339819c23ec6abec520"
 
     with open(outfile, "w") as fp:
         fp.write("test")
-    fields = [("in_file", File)]
-    input_spec = SpecInfo(name="Inputs", fields=fields, bases=(BaseDef,))
-    inputs = make_klass(input_spec)
-    assert inputs(in_file=outfile).hash == "02fa5f6f1bbde7f25349f54335e1adaf"
+
+    assert A(in_file=outfile)._hash == "9f7f9377ddef6d8c018f1bf8e89c208c"
@@ -157,26 +112,26 @@ def test_input_file_hash_2(tmp_path):
     file = tmp_path / "in_file_1.txt"
     with open(file, "w") as f:
         f.write("hello")
 
-    input_spec = SpecInfo(name="Inputs", fields=[("in_file", File)], bases=(BaseDef,))
-    inputs = make_klass(input_spec)
+    @python.define
+    def A(in_file: File) -> File:
+        return in_file
 
     # checking specific hash value
-    hash1 = inputs(in_file=file).hash
-    assert hash1 == "aaa50d60ed33d3a316d58edc882a34c3"
+    hash1 = A(in_file=file)._hash
+    assert hash1 == "179bd3cbdc747edc4957579376fe8c7d"
 
     # checking if different name doesn't affect the hash
     file_diffname = tmp_path / "in_file_2.txt"
     with open(file_diffname, "w") as f:
         f.write("hello")
-    hash2 = inputs(in_file=file_diffname).hash
+    hash2 = A(in_file=file_diffname)._hash
     assert hash1 == hash2
 
     # checking if different content (the same name) affects the hash
-    time.sleep(2)  # ensure mtime is different
     file_diffcontent = tmp_path / "in_file_1.txt"
     with open(file_diffcontent, "w") as f:
         f.write("hi")
-    hash3 = inputs(in_file=file_diffcontent).hash
+    hash3 = A(in_file=file_diffcontent)._hash
     assert hash1 != hash3
@@ -186,33 +141,31 @@ def test_input_file_hash_2a(tmp_path):
     file = tmp_path / "in_file_1.txt"
     with open(file, "w") as f:
         f.write("hello")
 
-    input_spec = SpecInfo(
-        name="Inputs", fields=[("in_file", ty.Union[File, int])], bases=(BaseDef,)
-    )
-    inputs = make_klass(input_spec)
+    @python.define
+    def A(in_file: ty.Union[File, int]) -> File:
+        return in_file
 
     # checking specific hash value
-    hash1 = inputs(in_file=file).hash
-    assert hash1 == "aaa50d60ed33d3a316d58edc882a34c3"
+    hash1 = A(in_file=file)._hash
+    assert hash1 == "179bd3cbdc747edc4957579376fe8c7d"
 
     # checking if different name doesn't affect the hash
     file_diffname = tmp_path / "in_file_2.txt"
     with open(file_diffname, "w") as f:
         f.write("hello")
-    hash2 = inputs(in_file=file_diffname).hash
+    hash2 = A(in_file=file_diffname)._hash
     assert hash1 == hash2
 
+    # checking if string is also accepted
+    hash3 = A(in_file=str(file))._hash
+    assert hash3 == hash1
+
     # checking if different content (the same name) affects the hash
-    time.sleep(2)  # ensure mtime is different
     file_diffcontent = tmp_path / "in_file_1.txt"
     with open(file_diffcontent, "w") as f:
         f.write("hi")
-    hash3 = inputs(in_file=file_diffcontent).hash
-    assert hash1 != hash3
-
-    # checking if string is also accepted
-    hash4 = inputs(in_file=str(file)).hash
-    assert hash4 == "800af2b5b334c9e3e5c40c0e49b7ffb5"
+    hash4 = A(in_file=file_diffcontent)._hash
+    assert hash1 != hash4
 
 
 def test_input_file_hash_3(tmp_path):
@@ -221,22 +174,21 @@
     with open(file, "w") as f:
         f.write("hello")
 
-    input_spec = SpecInfo(
-        name="Inputs", fields=[("in_file", File), ("in_int", int)], bases=(BaseDef,)
-    )
-    inputs = make_klass(input_spec)
+    @python.define
+    def A(in_file: File, in_int: int) -> File:
+        return in_file, in_int
 
-    my_inp = inputs(in_file=file, in_int=3)
+    a = A(in_file=file, in_int=3)
     # original hash and files_hash (dictionary contains info about files)
-    hash1 = my_inp.hash
+    hash1 = a._hash
     # files_hash1 = deepcopy(my_inp.files_hash)
     # file name should be in files_hash1[in_file]
     filename = str(Path(file))
     # assert filename in files_hash1["in_file"]
 
     # changing int input
-    my_inp.in_int = 5
-    hash2 = my_inp.hash
+    a.in_int = 5
+    hash2 = a._hash
     # files_hash2 = deepcopy(my_inp.files_hash)
     # hash should be different
     assert hash1 != hash2
@@ -249,7 +201,7 @@
     with open(file, "w") as f:
         f.write("hello")
 
-    hash3 = my_inp.hash
+    hash3 = a._hash
     # files_hash3 = deepcopy(my_inp.files_hash)
     # hash should be the same,
     # but the entry for in_file in files_hash should be different (modification time)
@@ -261,11 +213,11 @@
     # assert files_hash3["in_file"][filename][1] == files_hash2["in_file"][filename][1]
 
     # setting the in_file again
-    my_inp.in_file = file
+    a.in_file = file
     # filename should be removed from files_hash
     # assert my_inp.files_hash["in_file"] == {}
     # will be saved again when hash is calculated
-    assert my_inp.hash == hash3
+    assert a._hash == hash3
     # assert filename in my_inp.files_hash["in_file"]
 
 
 def test_input_file_hash_4(tmp_path):
@@ -277,26 +229,23 @@
     with open(file, "w") as f:
         f.write("hello")
 
-    input_spec = SpecInfo(
-        name="Inputs",
-        fields=[("in_file", ty.List[ty.List[ty.Union[int, File]]])],
-        bases=(BaseDef,),
-    )
-    inputs = make_klass(input_spec)
+    @python.define
+    def A(in_file: ty.List[ty.List[ty.Union[int, File]]]) -> File:
+        return in_file
 
     # checking specific hash value
-    hash1 = inputs(in_file=[[file, 3]]).hash
-    assert hash1 == "0693adbfac9f675af87e900065b1de00"
+    hash1 = A(in_file=[[file, 3]])._hash
+    assert hash1 == "ffd7afe0ca9d4585518809a509244b4b"
 
     # the same file, but int field changes
-    hash1a = inputs(in_file=[[file, 5]]).hash
+    hash1a = A(in_file=[[file, 5]])._hash
     assert hash1 != hash1a
 
     # checking if different name doesn't affect the hash
     file_diffname = tmp_path / "in_file_2.txt"
     with open(file_diffname, "w") as f:
         f.write("hello")
-    hash2 = inputs(in_file=[[file_diffname, 3]]).hash
+    hash2 = A(in_file=[[file_diffname, 3]])._hash
     assert hash1 == hash2
 
     # checking if different content (the same name) affects the hash
@@ -304,7 +253,7 @@
     file_diffcontent = tmp_path / "in_file_1.txt"
     with open(file_diffcontent, "w") as f:
         f.write("hi")
-    hash3 = inputs(in_file=[[file_diffcontent, 3]]).hash
+    hash3 = A(in_file=[[file_diffcontent, 3]])._hash
     assert hash1 != hash3
 
 
 def test_input_file_hash_5(tmp_path):
@@ -314,26 +263,23 @@
     with open(file, "w") as f:
         f.write("hello")
 
-    input_spec = SpecInfo(
-        name="Inputs",
-        fields=[("in_file", ty.List[ty.Dict[ty.Any, ty.Union[File, int]]])],
-        bases=(BaseDef,),
-    )
-    inputs = make_klass(input_spec)
+    @python.define
+    def A(in_file: ty.List[ty.Dict[ty.Any, ty.Union[File, int]]]) -> File:
+        return in_file
 
     # checking specific hash value
-    hash1 = inputs(in_file=[{"file": file, "int": 3}]).hash
-    assert hash1 == "56e6e2c9f3bdf0cd5bd3060046dea480"
+    hash1 = A(in_file=[{"file": file, "int": 3}])._hash
+    assert hash1 == "ba884a74e33552854271f55b03e53947"
 
     # the same file, but int field changes
-    hash1a = inputs(in_file=[{"file": file, "int": 5}]).hash
+    hash1a = A(in_file=[{"file": file, "int": 5}])._hash
    assert hash1 != hash1a
 
     # checking if different name doesn't affect the hash
     file_diffname = tmp_path / "in_file_2.txt"
     with open(file_diffname, "w") as f:
         f.write("hello")
-    hash2 = inputs(in_file=[{"file": file_diffname, "int": 3}]).hash
+    hash2 = A(in_file=[{"file": file_diffname, "int": 3}])._hash
     assert hash1 == hash2
 
     # checking if different content (the same name) affects the hash
@@ -341,53 +287,15 @@
     file_diffcontent = tmp_path / "in_file_1.txt"
     with open(file_diffcontent, "w") as f:
         f.write("hi")
-    hash3 = inputs(in_file=[{"file": file_diffcontent, "int": 3}]).hash
+    hash3 = A(in_file=[{"file": file_diffcontent, "int": 3}])._hash
     assert hash1 != hash3
 
 
-def test_lazy_field_cast():
-    task = Foo(a="a", b=1, c=2.0, name="foo")
-
-    assert task.lzout.y._type is int
-    assert workflow.cast(task.lzout.y, float)._type is float
-
-
-def test_lazy_field_multi_same_split():
-    @python.define
-    def f(x: ty.List[int]) -> ty.List[int]:
-        return x
-
-    task = f(x=[1, 2, 3], name="foo")
-
-    lf = task.lzout.out.split("foo.x")
-
-    assert lf.type == StateArray[int]
-    assert lf.splits == set([(("foo.x",),)])
-
-    lf2 = lf.split("foo.x")
-    assert lf2.type == StateArray[int]
-    assert lf2.splits == set([(("foo.x",),)])
-
-
-def test_lazy_field_multi_diff_split():
-    @python.define
-    def F(x: ty.Any, y: ty.Any) -> ty.Any:
-        return x
-
-    task = F(x=[1, 2, 3], name="foo")
-
-    lf = task.lzout.out.split("foo.x")
-
-    assert lf.type == StateArray[ty.Any]
-    assert lf.splits == set([(("foo.x",),)])
-
-    lf2 = lf.split("foo.x")
-    assert lf2.type == StateArray[ty.Any]
-    assert lf2.splits == set([(("foo.x",),)])
+def test_lazy_field_cast(wf: Workflow):
+    lzout = wf.add(Foo(a="a", b=1, c=2.0), name="foo")
 
-    lf3 = lf.split("foo.y")
-    assert lf3.type == StateArray[StateArray[ty.Any]]
-    assert lf3.splits == set([(("foo.x",),), (("foo.y",),)])
+    assert lzout.y._type is int
+    assert workflow.cast(lzout.y, float)._type is float
 
 
 def test_wf_lzin_split(tmp_path):
diff --git a/pydra/engine/tests/utils.py b/pydra/engine/tests/utils.py
index 9fc1d5f91..47ba21e4c 100644
--- a/pydra/engine/tests/utils.py
+++ b/pydra/engine/tests/utils.py
@@ -290,8 +290,8 @@ def FunFileList(filename_list: ty.List[File]):
 
 @workflow.define(outputs=["out"])
 def BasicWorkflow(x):
-    task1 = workflow.add(FunAddTwo(a=x))
-    task2 = workflow.add(FunAddVar(a=task1.out, b=2))
+    task1 = workflow.add(FunAddTwo(a=x), name="A")
+    task2 = workflow.add(FunAddVar(a=task1.out, b=2), name="B")
     return task2.out
diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py
index 7e39fc541..6c538efaa 100644
--- a/pydra/utils/typing.py
+++ b/pydra/utils/typing.py
@@ -7,7 +7,7 @@
 import types
 import typing as ty
 import logging
-import attr
+import attrs
 from pydra.utils import add_exc_note
 from fileformats import field, core, generic
@@ -217,8 +217,8 @@ def __call__(self, obj: ty.Any) -> T:
         from pydra.engine.helpers import is_lazy
 
         coerced: T
-        if obj is attr.NOTHING:
-            coerced = attr.NOTHING  # type: ignore[assignment]
+        if obj is attrs.NOTHING:
+            coerced = attrs.NOTHING  # type: ignore[assignment]
         elif is_lazy(obj):
             try:
                 self.check_type(obj._type)
@@ -279,8 +279,8 @@ def coerce(self, object_: ty.Any) -> T:
         def expand_and_coerce(obj, pattern: ty.Union[type, tuple]):
             """Attempt to expand the object along the lines of the coercion pattern"""
-            if obj is attr.NOTHING:
-                return attr.NOTHING
+            if obj is attrs.NOTHING:
+                return attrs.NOTHING
             if not isinstance(pattern, tuple):
                 return coerce_basic(obj, pattern)
             origin, pattern_args = pattern