
Develop: Test specs #770

Merged: 3 commits, Feb 26, 2025

4 changes: 1 addition & 3 deletions pydra/engine/core.py
@@ -817,9 +817,7 @@ def node_names(self) -> list[str]:
def execution_graph(self, submitter: "Submitter") -> DiGraph:
from pydra.engine.submitter import NodeExecution

exec_nodes = [
NodeExecution(n, submitter, workflow_inputs=self.inputs) for n in self.nodes
]
exec_nodes = [NodeExecution(n, submitter, workflow=self) for n in self.nodes]
graph = self._create_graph(exec_nodes)
# Set the graph attribute of the nodes so lazy fields can be resolved as tasks
# are created
52 changes: 41 additions & 11 deletions pydra/engine/lazy.py
@@ -6,10 +6,9 @@
from . import node

if ty.TYPE_CHECKING:
from .graph import DiGraph
from .submitter import NodeExecution
from .submitter import DiGraph, NodeExecution
from .core import Task, Workflow
from .specs import TaskDef, WorkflowDef
from .specs import TaskDef
from .state import StateIndex


@@ -46,6 +45,30 @@ def _apply_cast(self, value):
value = self._type(value)
return value

def _get_value(
self,
workflow: "Workflow",
graph: "DiGraph[NodeExecution]",
state_index: "StateIndex | None" = None,
) -> ty.Any:
"""Return the value of a lazy field.

Parameters
----------
workflow: Workflow
the workflow object
graph: DiGraph[NodeExecution]
the graph representing the execution state of the workflow
state_index : StateIndex, optional
the state index of the field to access

Returns
-------
value : Any
the resolved value of the lazy-field
"""
raise NotImplementedError("LazyField is an abstract class")


@attrs.define(kw_only=True)
class LazyInField(LazyField[T]):
@@ -70,23 +93,27 @@ def _source(self):

def _get_value(
self,
workflow_def: "WorkflowDef",
workflow: "Workflow",
graph: "DiGraph[NodeExecution]",
state_index: "StateIndex | None" = None,
) -> ty.Any:
"""Return the value of a lazy field.

Parameters
----------
wf : Workflow
the workflow the lazy field references
state_index : int, optional
workflow: Workflow
the workflow object
graph: DiGraph[NodeExecution]
the graph representing the execution state of the workflow
state_index : StateIndex, optional
the state index of the field to access

Returns
-------
value : Any
the resolved value of the lazy-field
"""
value = workflow_def[self._field]
value = workflow.inputs[self._field]
value = self._apply_cast(value)
return value
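The change above switches LazyInField resolution from indexing a WorkflowDef to looking the field up on the workflow's inputs. A minimal, self-contained sketch of the new calling convention (toy stand-in objects and hypothetical names, not the pydra API):

```python
import typing as ty


class ToyInputs(dict):
    """Hypothetical stand-in for ``workflow.inputs`` (supports item access)."""


class ToyWorkflow(ty.NamedTuple):
    inputs: ToyInputs


def resolve_in_field(field: str, workflow: ToyWorkflow, graph=None, state_index=None):
    # Mirrors LazyInField._get_value: an input reference is resolved purely from
    # the workflow object; graph/state_index only matter for node-output fields.
    return workflow.inputs[field]


wf = ToyWorkflow(inputs=ToyInputs(x=1, y="a"))
assert resolve_in_field("x", wf) == 1
```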

@@ -105,16 +132,19 @@ def __repr__(self):

def _get_value(
self,
workflow: "Workflow",
graph: "DiGraph[NodeExecution]",
state_index: "StateIndex | None" = None,
) -> ty.Any:
"""Return the value of a lazy field.

Parameters
----------
wf : Workflow
the workflow the lazy field references
state_index : int, optional
workflow: Workflow
the workflow object
graph: DiGraph[NodeExecution]
the graph representing the execution state of the workflow
state_index : StateIndex, optional
the state index of the field to access

Returns
42 changes: 23 additions & 19 deletions pydra/engine/node.py
@@ -121,14 +121,28 @@ def lzout(self) -> OutputType:
type=field.type,
)
outputs = self.inputs.Outputs(**lazy_fields)
# Flag the output lazy fields as being not typed checked (i.e. assigned to another
# node's inputs) yet

outpt: lazy.LazyOutField
for outpt in attrs_values(outputs).values():
outpt._type_checked = False
# Assign the current node to the lazy fields so they can access the state
outpt._node = self
# If the node has a non-empty state, wrap the type of the lazy field in
# a combination of an optional list and a number of nested StateArray
# types based on the number of states the node is split over and whether
# it has a combiner
if self._state:
type_, _ = TypeParser.strip_splits(outpt._type)
if self._state.combiner:
type_ = list[type_]
for _ in range(self._state.depth - int(bool(self._state.combiner))):
type_ = StateArray[type_]
outpt._type = type_
# Flag the output lazy fields as not yet type-checked (i.e. not yet assigned to
# another node's inputs). This is used to prevent the user from changing
# the type of the output after it has been accessed by connecting it to an
# output of an upstream node with additional state variables.
outpt._type_checked = False
self._lzout = outputs
self._wrap_lzout_types_in_state_arrays()
return outputs
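The loop above replaces the old _wrap_lzout_types_in_state_arrays helper: when the node has state, each output's type is wrapped in a plain list (if a combiner is set) and then in one StateArray per remaining uncombined split. A rough sketch of that wrapping arithmetic (toy StateArray and helper function, not the pydra implementation):

```python
import typing as ty

T = ty.TypeVar("T")


class StateArray(ty.List[T]):
    """Toy stand-in for pydra.utils.typing.StateArray."""


def wrap_output_type(base: type, depth: int, has_combiner: bool) -> ty.Any:
    """Mimic the wrapping applied to ``outpt._type`` in the loop above."""
    type_ = base
    if has_combiner:
        type_ = list[type_]  # the combined dimension collapses into a plain list
    for _ in range(depth - int(has_combiner)):
        type_ = StateArray[type_]  # each remaining split adds a StateArray level
    return type_


print(wrap_output_type(int, depth=1, has_combiner=False))  # -> StateArray[int]
print(wrap_output_type(int, depth=2, has_combiner=True))   # -> StateArray[list[int]]
```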

@property
@@ -217,20 +231,6 @@ def _check_if_outputs_have_been_used(self, msg):
+ msg
)

def _wrap_lzout_types_in_state_arrays(self) -> None:
"""Wraps a types of the lazy out fields in a number of nested StateArray types
based on the number of states the node is split over"""
# Unwrap StateArray types from the output types
if not self.state:
return
outpt_lf: lazy.LazyOutField
for outpt_lf in attrs_values(self.lzout).values():
assert not outpt_lf._type_checked
type_, _ = TypeParser.strip_splits(outpt_lf._type)
for _ in range(self._state.depth):
type_ = StateArray[type_]
outpt_lf._type = type_

def _set_state(self) -> None:
# Add node name to state's splitter, combiner and cont_dim loaded from the def
splitter = self._definition._splitter
@@ -269,7 +269,11 @@ def _get_upstream_states(self) -> dict[str, tuple["State", list[str]]]:
"""Get the states of the upstream nodes that are connected to this node"""
upstream_states = {}
for inpt_name, val in self.input_values:
if isinstance(val, lazy.LazyOutField) and val._node.state:
if (
isinstance(val, lazy.LazyOutField)
and val._node.state
and val._node.state.depth
):
node: Node = val._node
# variables that are part of inner splitters should be treated as containers
if node.state and f"{node.name}.{inpt_name}" in node.state.splitter:
39 changes: 1 addition & 38 deletions pydra/engine/specs.py
@@ -35,14 +35,12 @@
from pydra.utils.typing import StateArray, MultiInputObj
from pydra.design.base import Field, Arg, Out, RequirementSet, NO_DEFAULT
from pydra.design import shell
from pydra.engine.lazy import LazyInField, LazyOutField

if ty.TYPE_CHECKING:
from pydra.engine.core import Task
from pydra.engine.graph import DiGraph
from pydra.engine.submitter import NodeExecution
from pydra.engine.core import Workflow
from pydra.engine.state import StateIndex
from pydra.engine.environments import Environment
from pydra.engine.workers import Worker

@@ -476,41 +474,6 @@ def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]:
}
return hash_function(sorted(field_hashes.items())), field_hashes

def _resolve_lazy_inputs(
self,
workflow_inputs: "WorkflowDef",
graph: "DiGraph[NodeExecution]",
state_index: "StateIndex | None" = None,
) -> Self:
"""Resolves lazy fields in the task definition by replacing them with their
actual values.

Parameters
----------
workflow : Workflow
The workflow the task is part of
graph : DiGraph[NodeExecution]
The execution graph of the workflow
state_index : StateIndex, optional
The state index for the workflow, by default None

Returns
-------
Self
The task definition with all lazy fields resolved
"""
from pydra.engine.state import StateIndex

if state_index is None:
state_index = StateIndex()
resolved = {}
for name, value in attrs_values(self).items():
if isinstance(value, LazyInField):
resolved[name] = value._get_value(workflow_inputs)
elif isinstance(value, LazyOutField):
resolved[name] = value._get_value(graph, state_index)
return attrs.evolve(self, **resolved)

def _check_rules(self):
"""Check if all rules are satisfied."""

@@ -773,7 +736,7 @@ def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
nodes_dict = {n.name: n for n in exec_graph.nodes}
for name, lazy_field in attrs_values(workflow.outputs).items():
try:
val_out = lazy_field._get_value(exec_graph)
val_out = lazy_field._get_value(workflow=workflow, graph=exec_graph)
output_wf[name] = val_out
except (ValueError, AttributeError):
output_wf[name] = None
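The updated call passes both the workflow and the execution graph when resolving each workflow output, keeping the existing fall-back to None for outputs that cannot be resolved. A condensed sketch of that pattern (hypothetical helper name):

```python
def collect_outputs(lazy_outputs: dict, workflow, exec_graph) -> dict:
    """Resolve each lazy output against the executed graph, defaulting to None."""
    resolved = {}
    for name, lazy_field in lazy_outputs.items():
        try:
            resolved[name] = lazy_field._get_value(workflow=workflow, graph=exec_graph)
        except (ValueError, AttributeError):
            # source node errored or never ran; leave the output unset
            resolved[name] = None
    return resolved
```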
58 changes: 51 additions & 7 deletions pydra/engine/state.py
@@ -41,23 +41,63 @@ def __init__(self, indices: dict[str, int] | None = None):
else:
self.indices = OrderedDict(sorted(indices.items()))

def __repr__(self):
def __len__(self) -> int:
return len(self.indices)

def __iter__(self) -> ty.Generator[str, None, None]:
return iter(self.indices)

def __repr__(self) -> str:
return (
"StateIndex(" + ", ".join(f"{n}={v}" for n, v in self.indices.items()) + ")"
)

def __hash__(self):
return hash(tuple(self.indices.items()))

def __eq__(self, other):
def __eq__(self, other) -> bool:
return self.indices == other.indices

def __str__(self):
def __str__(self) -> str:
return "__".join(f"{n}-{i}" for n, i in self.indices.items())

def __bool__(self):
def __bool__(self) -> bool:
return bool(self.indices)

def subset(self, state_names: ty.Iterable[str]) -> ty.Self:
"""Create a new StateIndex with only the specified fields

Parameters
----------
fields : list[str]
the fields to keep in the new StateIndex

Returns
-------
StateIndex
a new StateIndex with only the specified fields
"""
return type(self)({k: v for k, v in self.indices.items() if k in state_names})

def matches(self, other: "StateIndex") -> bool:
"""Check if the indices that are present in the other StateIndex match

Parameters
----------
other : StateIndex
the other StateIndex to compare against

Returns
-------
bool
True if all the indices in the other StateIndex match
"""
if not set(self.indices).issuperset(other.indices):
raise ValueError(
f"StateIndex {self} does not contain all the indices in {other}"
)
return all(self.indices[k] == v for k, v in other.indices.items())
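A short usage sketch of the new StateIndex helpers, based on the constructor and methods shown in this hunk (the `full`/`partial` names are illustrative; assumes the development branch where these methods exist):

```python
from pydra.engine.state import StateIndex

full = StateIndex({"a.x": 1, "b.y": 0, "c.z": 2})
partial = StateIndex({"a.x": 1, "b.y": 0})

assert len(full) == 3
assert list(full) == ["a.x", "b.y", "c.z"]        # keys are kept sorted
assert full.subset(["a.x", "c.z"]) == StateIndex({"a.x": 1, "c.z": 2})
assert full.matches(partial)                      # all shared indices agree
assert not full.matches(StateIndex({"a.x": 0}))   # "a.x" disagrees
# full.matches(StateIndex({"d.w": 0})) raises ValueError: "d.w" is not present
```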


class State:
"""
@@ -172,6 +212,9 @@ def __str__(self):
def names(self):
"""Return the names of the states."""
# analysing states from connected tasks if inner_inputs
if not hasattr(self, "keys_final"):
self.prepare_states()
self.prepare_inputs()
previous_states_keys = {
f"_{v.name}": v.keys_final for v in self.inner_inputs.values()
}
@@ -190,13 +233,13 @@

@property
def depth(self) -> int:
"""Return the number of uncombined splits of the state, i.e. the number nested
"""Return the number of splits of the state, i.e. the number nested
state arrays to wrap around the type of lazy out fields

Returns
-------
int
number of uncombined splits
number of uncombined independent splits (i.e. linked splits only add 1)
"""
depth = 0
stack = []
@@ -210,7 +253,8 @@ def depth(self) -> int:
stack = []
else:
stack.append(spl)
return depth + len(stack)
remaining_stack = [s for s in stack if s not in self.combiner]
return depth + len(remaining_stack)
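The changed return statement means splits that have already been combined no longer contribute to the reported depth. A toy illustration of just that final step (not the full State.depth traversal):

```python
def remaining_depth(stack: list[str], combiner: list[str], depth: int = 0) -> int:
    """Mirror the new return statement: drop combined splits before counting."""
    remaining_stack = [s for s in stack if s not in combiner]
    return depth + len(remaining_stack)


# two outstanding splits, one of them combined -> only one StateArray wrapper needed
assert remaining_depth(["a.x", "a.y"], combiner=["a.x"]) == 1
# previously (return depth + len(stack)) this would have reported 2
assert remaining_depth(["a.x", "a.y"], combiner=[]) == 2
```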

@property
def splitter(self):