Skip to content

Commit 7b64080

Browse files
committed
implemented state depth() implementation by handling the RPN representation properly
1 parent 5611651 commit 7b64080

File tree

2 files changed

+27
-28
lines changed

2 files changed

+27
-28
lines changed

pydra/engine/state.py

+25-22
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from copy import deepcopy
44
import itertools
55
from collections import OrderedDict
6-
from operator import itemgetter
76
from functools import reduce
87
import typing as ty
98
from . import helpers_state as hlpst
@@ -41,7 +40,7 @@ def __init__(self, indices: dict[str, int] | None = None):
4140
if indices is None:
4241
self.indices = OrderedDict()
4342
else:
44-
self.indices = OrderedDict(sorted(indices.items()))
43+
self.indices = OrderedDict(indices.items())
4544

4645
def __len__(self) -> int:
4746
return len(self.indices)
@@ -53,13 +52,12 @@ def __getitem__(self, key: str) -> int:
5352
return self.indices[key]
5453

5554
def __lt__(self, other: "StateIndex") -> bool:
56-
if set(self.indices) != set(other.indices):
55+
if list(self.indices) != list(other.indices):
5756
raise ValueError(
58-
f"StateIndex {self} does not contain the same indices as {other}"
57+
f"StateIndex {self} does not contain the same indices in the same order "
58+
f"as {other}: {list(self.indices)} != {list(other.indices)}"
5959
)
60-
return sorted(self.indices.items(), key=itemgetter(0)) < sorted(
61-
other.indices.items(), key=itemgetter(0)
62-
)
60+
return tuple(self.indices.items()) < tuple(other.indices.items())
6361

6462
def __repr__(self) -> str:
6563
return (
@@ -273,24 +271,29 @@ def depth(self, after_combine: bool = True) -> int:
273271
int
274272
number of splits in the state (i.e. linked splits only add 1)
275273
"""
276-
depth = 0
277-
stack = []
278274

279-
def included(s):
280-
return s not in self.combiner if after_combine else True
275+
# replace field names with 1 or 0 (1 if the field is included in the state)
276+
include_rpn = [
277+
(
278+
s
279+
if s in [".", "*"]
280+
else (int(s not in self.combiner) if after_combine else 1)
281+
)
282+
for s in self.splitter_rpn
283+
]
281284

282-
for spl in self.splitter_rpn:
283-
if spl in [".", "*"]:
284-
if spl == ".":
285-
depth += int(all(included(s) for s in stack))
286-
else:
287-
assert spl == "*"
288-
depth += len([s for s in stack if included(s)])
289-
stack = []
285+
stack = []
286+
for opr in include_rpn:
287+
if opr == ".":
288+
assert len(stack) >= 2
289+
stack.append(stack.pop() and stack.pop())
290+
elif opr == "*":
291+
assert len(stack) >= 2
292+
stack.append(stack.pop() + stack.pop())
290293
else:
291-
stack.append(spl)
292-
remaining_stack = [s for s in stack if included(s)]
293-
return depth + len(remaining_stack)
294+
stack.append(opr)
295+
assert len(stack) == 1
296+
return stack[0]
294297

295298
def nest_output_type(self, type_: type) -> type:
296299
"""Nests a type of an output field in a combination of lists and state-arrays

pydra/engine/tests/test_node_task.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -995,9 +995,7 @@ def test_task_state_6(plugin, tmp_path):
995995
assert not results.errored, "\n".join(results.errors["error message"])
996996

997997
# checking the results
998-
999-
for i, expected in enumerate([3, 2, 33, 12]):
1000-
assert results.outputs.out[i] == expected
998+
assert results.outputs.out == [3.0, 2.0, 33.0, 12.0]
1001999

10021000

10031001
def test_task_state_6a(plugin, tmp_path):
@@ -1014,9 +1012,7 @@ def test_task_state_6a(plugin, tmp_path):
10141012
assert not results.errored, "\n".join(results.errors["error message"])
10151013

10161014
# checking the results
1017-
1018-
for i, expected in enumerate([3, 2, 33, 12]):
1019-
assert results.outputs.out[i] == expected
1015+
assert results.outputs.out == [3.0, 2.0, 33.0, 12.0]
10201016

10211017

10221018
@pytest.mark.flaky(reruns=2) # when dask

0 commit comments

Comments
 (0)