Skip to content

Commit 4e331e5

Browse files
committed
debugging states and workflows
1 parent 54bbd47 commit 4e331e5

File tree

10 files changed

+1492
-1895
lines changed

10 files changed

+1492
-1895
lines changed

pydra/conftest.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def pytest_generate_tests(metafunc):
2020
if bool(shutil.which("sbatch")):
2121
Plugins = ["slurm"]
2222
else:
23-
Plugins = ["cf"]
23+
Plugins = ["debug"] # ["debug", "cf"]
2424
try:
2525
if metafunc.config.getoption("dask"):
2626
Plugins.append("dask")
@@ -50,7 +50,7 @@ def pytest_generate_tests(metafunc):
5050
elif bool(shutil.which("sbatch")):
5151
Plugins = ["slurm"]
5252
else:
53-
Plugins = ["cf"]
53+
Plugins = ["debug"] # ["debug", "cf"]
5454
try:
5555
if metafunc.config.getoption("psij"):
5656
Plugins.append("psij-" + metafunc.config.getoption("psij"))

pydra/design/tests/test_python.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
sort_key = attrgetter("name")
1212

1313

14-
def test_interface_wrap_function():
14+
def test_interface_wrap_function(tmp_path):
1515
def func(a: int) -> float:
1616
"""Sample function with inputs and outputs"""
1717
return a * 2
@@ -27,7 +27,7 @@ def func(a: int) -> float:
2727
]
2828
assert outputs == [python.out(name="out", type=float)]
2929
definition = SampleDef(a=1)
30-
outputs = definition()
30+
outputs = definition(cache_dir=tmp_path)
3131
assert outputs.out == 2.0
3232
with pytest.raises(TypeError):
3333
SampleDef(a=1.5)

pydra/engine/audit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, audit_flags, messengers, messenger_args, develop=None):
3030
Base configuration of auditing.
3131
messengers : :obj:`pydra.util.messenger.Messenger`
3232
or list of :class:`pydra.util.messenger.Messenger`, optional
33-
Defify types of messenger used by Audit to send a message.
33+
Taskify types of messenger used by Audit to send a message.
3434
Could be `PrintMessenger`, `FileMessenger`, or `RemoteRESTMessenger`.
3535
messenger_args : :obj:`dict`, optional
3636
Optional arguments for the `Messenger.send` method.

pydra/engine/core.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -792,9 +792,8 @@ def execution_graph(self, submitter: "Submitter") -> DiGraph:
792792
node.graph = graph
793793
return graph
794794

795-
@property
796-
def graph(self) -> DiGraph:
797-
return self._create_graph(self.nodes, detailed=True)
795+
def graph(self, detailed: bool = False) -> DiGraph:
796+
return self._create_graph(self.nodes, detailed=detailed)
798797

799798
def _create_graph(
800799
self, nodes: "list[Node | NodeExecution]", detailed: bool = False

pydra/engine/graph.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def nodes(self) -> list[NodeType]:
8181
def nodes(self, nodes: ty.Iterable[NodeType]) -> None:
8282
if nodes:
8383
nodes = ensure_list(nodes)
84-
if len(set(nodes)) != len(nodes):
85-
raise Exception("nodes have repeated elements")
84+
# if len(set(nodes)) != len(nodes):
85+
# raise Exception("nodes have repeated elements")
8686
self._nodes = nodes
8787

8888
def node(self, name: str) -> NodeType:

pydra/engine/helpers.py

+5-14
Original file line numberDiff line numberDiff line change
@@ -47,26 +47,17 @@ def plot_workflow(
4747

4848
# Construct the workflow object
4949
wf = Workflow.construct(workflow_task)
50-
graph = wf.graph
50+
5151
if not name:
52-
name = f"graph_{wf._node.name}"
52+
name = f"graph_{type(workflow_task).__name__}"
5353
if type == "simple":
54-
for task in graph.nodes:
55-
wf.create_connections(task)
54+
graph = wf.graph()
5655
dotfile = graph.create_dotfile_simple(outdir=out_dir, name=name)
5756
elif type == "nested":
58-
for task in graph.nodes:
59-
wf.create_connections(task)
57+
graph = wf.graph()
6058
dotfile = graph.create_dotfile_nested(outdir=out_dir, name=name)
6159
elif type == "detailed":
62-
# create connections with detailed=True
63-
for task in graph.nodes:
64-
wf.create_connections(task, detailed=True)
65-
# adding wf outputs
66-
for wf_out, lf in wf._connections:
67-
graph.add_edges_description(
68-
(wf._node.name, wf_out, lf._node.name, lf.field)
69-
)
60+
graph = wf.graph(detailed=True)
7061
dotfile = graph.create_dotfile_detailed(outdir=out_dir, name=name)
7162
else:
7263
raise Exception(

pydra/engine/lazy.py

+19-17
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def _get_value(
151151
value : Any
152152
the resolved value of the lazy-field
153153
"""
154-
154+
state = self._node.state
155155
jobs = graph.node(self._node.name).get_jobs(state_index)
156156

157157
def retrieve_from_job(job: "Task[DefType]") -> ty.Any:
@@ -184,25 +184,27 @@ def retrieve_from_job(job: "Task[DefType]") -> ty.Any:
184184
val = self._apply_cast(val)
185185
return val
186186

187-
if not isinstance(jobs, StateArray):
187+
if not isinstance(jobs, StateArray): # single job
188188
return retrieve_from_job(jobs)
189-
elif not self._node.state or not self._node.state.depth(before_combine=True):
189+
elif not state or not state.depth(before_combine=True):
190190
assert len(jobs) == 1
191191
return retrieve_from_job(jobs[0])
192-
elif not self._node.state.keys_final: # all states are combined over
193-
return [retrieve_from_job(j) for j in jobs]
194-
elif self._node.state.combiner:
195-
sorted_values = {
196-
frozenset(i.items()): [] for i in self._node.state.states_ind_final
197-
}
198-
assert len(jobs) == len(self._node.state.inputs_ind)
199-
for ind, job in zip(self._node.state.inputs_ind, jobs):
200-
sorted_values[
201-
frozenset((key, ind[key]) for key in self._node.state.keys_final)
202-
].append(retrieve_from_job(job))
203-
return StateArray(sorted_values.values())
204-
else:
205-
return StateArray(retrieve_from_job(j) for j in jobs)
192+
# elif state.combiner and state.keys_final:
193+
# # We initialise it here rather than using a defaultdict to ensure the order
194+
# # of the keys matches how it is defined in the state so we can return the
195+
# # values in the correct order
196+
# sorted_values = {frozenset(i.items()): [] for i in state.states_ind_final}
197+
# # Iterate through the jobs and append the values to the correct final state
198+
# # key
199+
# for job in jobs:
200+
# state_key = frozenset(
201+
# (key, state.states_ind[job.state_index][key])
202+
# for key in state.keys_final
203+
# )
204+
# sorted_values[state_key].append(retrieve_from_job(job))
205+
# return StateArray(sorted_values.values())
206+
# else:
207+
return [retrieve_from_job(j) for j in jobs]
206208

207209
@property
208210
def _source(self):

pydra/engine/node.py

-6
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,6 @@ class Node(ty.Generic[OutputType]):
4545
init=False, default=None, eq=False, hash=False, repr=False
4646
)
4747
_state: State | None = attrs.field(init=False, default=NOT_SET)
48-
# _cont_dim: dict[str, int] | None = attrs.field(
49-
# init=False, default=None
50-
# ) # QUESTION: should this be included in the state?
51-
# _inner_cont_dim: dict[str, int] = attrs.field(
52-
# init=False, factory=dict
53-
# ) # QUESTION: should this be included in the state?
5448

5549
def __attrs_post_init__(self):
5650
self._set_state()

pydra/engine/submitter.py

+39-25
Original file line numberDiff line numberDiff line change
@@ -569,39 +569,39 @@ def tasks(self) -> ty.Generator["Task[DefType]", None, None]:
569569
self._tasks = {t.state_index: t for t in self._generate_tasks()}
570570
return self._tasks.values()
571571

572-
def get_jobs(
573-
self, index: int | None = None, as_array: bool = False
574-
) -> "Task | StateArray[Task]":
572+
def get_jobs(self, final_index: int | None = None) -> "Task | StateArray[Task]":
575573
"""Get the jobs that match a given state index.
576574
577575
Parameters
578576
----------
579-
index : int, optional
580-
The index of the state of the task to get, by default None
581-
as_array : bool, optional
582-
Whether to return the tasks in a state-array object, by default if the index
583-
matches
577+
final_index : int, optional
578+
The index of the output state array (i.e. after any combinations) of the
579+
job to get, by default None
584580
585581
Returns
586582
-------
587583
matching : Task | StateArray[Task]
588584
The task or tasks that match the given index
589585
"""
590-
matching = StateArray()
591-
if self.tasks:
592-
try:
593-
task = self._tasks[index]
594-
except KeyError:
595-
if index is None:
596-
return StateArray(self._tasks.values())
597-
# Select matching tasks and return them in nested state-array objects
598-
for ind, task in self._tasks.items():
599-
matching.append(task)
600-
else:
601-
if not as_array:
602-
return task
603-
matching.append(task)
604-
return matching
586+
if not self.tasks: # No jobs, return empty state array
587+
return StateArray()
588+
if not self.node.state: # Return the singular job
589+
assert final_index is None
590+
task = self._tasks[None]
591+
return task
592+
if final_index is None: # return all jobs in a state array
593+
return StateArray(self._tasks.values())
594+
if not self.node.state.combiner: # Select the job that matches the index
595+
task = self._tasks[final_index]
596+
return task
597+
# Get a slice of the tasks that match the given index of the state array of the
598+
# combined values
599+
final_index = set(self.node.state.states_ind_final[final_index].items())
600+
return StateArray(
601+
self._tasks[i]
602+
for i, ind in enumerate(self.node.state.states_ind)
603+
if set(ind.items()).issuperset(final_index)
604+
)
605605

606606
@property
607607
def started(self) -> bool:
@@ -762,9 +762,23 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
762762
for index, task in list(self.blocked.items()):
763763
pred: NodeExecution
764764
is_runnable = True
765+
states_ind = (
766+
list(self.node.state.states_ind[index].items())
767+
if self.node.state
768+
else []
769+
)
765770
for pred in graph.predecessors[self.node.name]:
766-
pred_jobs: StateArray[Task] = pred.get_jobs(index, as_array=True)
767-
pred_inds = [j.state_index for j in pred_jobs]
771+
if pred.node.state:
772+
pred_states_ind = {
773+
(k, i) for k, i in states_ind if k.startswith(pred.name + ".")
774+
}
775+
pred_inds = [
776+
i
777+
for i, ind in enumerate(pred.node.state.states_ind)
778+
if set(ind.items()).issuperset(pred_states_ind)
779+
]
780+
else:
781+
pred_inds = [None]
768782
if not all(i in pred.successful for i in pred_inds):
769783
is_runnable = False
770784
blocked = True

0 commit comments

Comments
 (0)