Skip to content

Commit 35c1594

Browse files
authored
Merge pull request #772 from nipype/test_task
Updated test_shelltask, test_shelltask_inputspec, test_workflow to new syntax
2 parents a5f6e45 + bddb1f1 commit 35c1594

18 files changed

+6269
-8530
lines changed

pydra/design/base.py

+32-20
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ def convert_default_value(value: ty.Any, self_: "Field") -> ty.Any:
6262
return TypeParser[self_.type](self_.type, label=self_.name)(value)
6363

6464

65+
def allowed_values_converter(value: ty.Iterable[str] | None) -> list[str] | None:
66+
"""Ensure the allowed_values field is a list of strings or None"""
67+
if value is None:
68+
return None
69+
return list(value)
70+
71+
6572
@attrs.define
6673
class Requirement:
6774
"""Define a requirement for a task input field
@@ -76,14 +83,19 @@ class Requirement:
7683
"""
7784

7885
name: str
79-
allowed_values: list[str] = attrs.field(factory=list, converter=list)
86+
allowed_values: list[str] | None = attrs.field(
87+
default=None, converter=allowed_values_converter
88+
)
8089

8190
def satisfied(self, inputs: "TaskDef") -> bool:
8291
"""Check if the requirement is satisfied by the inputs"""
8392
value = getattr(inputs, self.name)
84-
if value is attrs.NOTHING:
93+
field = {f.name: f for f in list_fields(inputs)}[self.name]
94+
if value is attrs.NOTHING or field.type is bool and value is False:
8595
return False
86-
return not self.allowed_values or value in self.allowed_values
96+
if self.allowed_values is None:
97+
return True
98+
return value in self.allowed_values
8799

88100
@classmethod
89101
def parse(cls, value: ty.Any) -> Self:
@@ -350,8 +362,8 @@ def get_fields(klass, field_type, auto_attribs, helps) -> dict[str, Field]:
350362

351363
if not issubclass(klass, spec_type):
352364
raise ValueError(
353-
f"The canonical form of {spec_type.__module__.split('.')[-1]} task definitions, "
354-
f"{klass}, must inherit from {spec_type}"
365+
f"When using the canonical form for {spec_type.__module__.split('.')[-1]} "
366+
f"tasks, {klass} must inherit from {spec_type}"
355367
)
356368

357369
inputs = get_fields(klass, arg_type, auto_attribs, input_helps)
@@ -364,8 +376,8 @@ def get_fields(klass, field_type, auto_attribs, helps) -> dict[str, Field]:
364376
) from None
365377
if not issubclass(outputs_klass, outputs_type):
366378
raise ValueError(
367-
f"The canonical form of {spec_type.__module__.split('.')[-1]} task definitions, "
368-
f"{klass}, must inherit from {spec_type}"
379+
f"When using the canonical form for {outputs_type.__module__.split('.')[-1]} "
380+
f"task outputs {outputs_klass}, you must inherit from {outputs_type}"
369381
)
370382

371383
output_helps, _ = parse_doc_string(outputs_klass.__doc__)
@@ -416,10 +428,12 @@ def make_task_def(
416428

417429
spec_type._check_arg_refs(inputs, outputs)
418430

431+
# Check that the field attributes are valid after all fields have been set
432+
# (especially the type)
419433
for inpt in inputs.values():
420-
set_none_default_if_optional(inpt)
421-
for outpt in inputs.values():
422-
set_none_default_if_optional(outpt)
434+
attrs.validate(inpt)
435+
for outpt in outputs.values():
436+
attrs.validate(outpt)
423437

424438
if name is None and klass is not None:
425439
name = klass.__name__
@@ -459,10 +473,10 @@ def make_task_def(
459473
if getattr(arg, "path_template", False):
460474
if is_optional(arg.type):
461475
field_type = Path | bool | None
462-
# Will default to None and not be inserted into the command
476+
attrs_kwargs = {"default": None}
463477
else:
464478
field_type = Path | bool
465-
attrs_kwargs = {"default": True}
479+
attrs_kwargs = {"default": True} # use the template by default
466480
elif is_optional(arg.type):
467481
field_type = Path | None
468482
else:
@@ -988,12 +1002,10 @@ def check_explicit_fields_are_none(klass, inputs, outputs):
9881002

9891003
def _get_attrs_kwargs(field: Field) -> dict[str, ty.Any]:
9901004
kwargs = {}
991-
if not hasattr(field, "default"):
992-
kwargs["factory"] = nothing_factory
993-
elif field.default is not NO_DEFAULT:
1005+
if field.default is not NO_DEFAULT:
9941006
kwargs["default"] = field.default
995-
elif is_optional(field.type):
996-
kwargs["default"] = None
1007+
# elif is_optional(field.type):
1008+
# kwargs["default"] = None
9971009
else:
9981010
kwargs["factory"] = nothing_factory
9991011
if field.hash_eq:
@@ -1005,9 +1017,9 @@ def nothing_factory():
10051017
return attrs.NOTHING
10061018

10071019

1008-
def set_none_default_if_optional(field: Field) -> None:
1009-
if is_optional(field.type) and field.default is NO_DEFAULT:
1010-
field.default = None
1020+
# def set_none_default_if_optional(field: Field) -> None:
1021+
# if is_optional(field.type) and field.default is NO_DEFAULT:
1022+
# field.default = None
10111023

10121024

10131025
white_space_re = re.compile(r"\s+")

pydra/design/shell.py

+25-13
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,16 @@ def _sep_default(self):
110110

111111
@sep.validator
112112
def _validate_sep(self, _, sep):
113-
if self.type is ty.Any:
114-
return
115-
if ty.get_origin(self.type) is MultiInputObj:
113+
if self.type is MultiInputObj:
114+
tp = ty.Any
115+
elif ty.get_origin(self.type) is MultiInputObj:
116116
tp = ty.get_args(self.type)[0]
117117
else:
118118
tp = self.type
119119
if is_optional(tp):
120120
tp = optional_type(tp)
121+
if tp is ty.Any:
122+
return
121123
origin = ty.get_origin(tp) or tp
122124

123125
if (
@@ -238,16 +240,21 @@ class outarg(arg, Out):
238240

239241
@path_template.validator
240242
def _validate_path_template(self, attribute, value):
241-
if value and self.default not in (NO_DEFAULT, True, None):
242-
raise ValueError(
243-
f"path_template ({value!r}) can only be provided when no default "
244-
f"({self.default!r}) is provided"
245-
)
246-
if value and not (is_fileset_or_union(self.type) or self.type is ty.Any):
247-
raise ValueError(
248-
f"path_template ({value!r}) can only be provided when type is a FileSet, "
249-
f"or union thereof, not {self.type!r}"
250-
)
243+
if value:
244+
if self.default not in (NO_DEFAULT, True, None):
245+
raise ValueError(
246+
f"path_template ({value!r}) can only be provided when no default "
247+
f"({self.default!r}) is provided"
248+
)
249+
if not (is_fileset_or_union(self.type) or self.type is ty.Any):
250+
raise ValueError(
251+
f"path_template ({value!r}) can only be provided when type is a FileSet, "
252+
f"or union thereof, not {self.type!r}"
253+
)
254+
if self.argstr is None:
255+
raise ValueError(
256+
f"path_template ({value!r}) can only be provided when argstr is not None"
257+
)
251258

252259
@keep_extension.validator
253260
def _validate_keep_extension(self, attribute, value):
@@ -386,6 +393,7 @@ def make(
386393
input_helps=input_helps,
387394
output_helps=output_helps,
388395
)
396+
389397
if name:
390398
class_name = name
391399
else:
@@ -679,6 +687,10 @@ def from_type_str(type_str) -> type:
679687
if ext_type.ext is not None:
680688
path_template = name + ext_type.ext
681689
kwds["path_template"] = path_template
690+
# Set the default value to None if the field is optional and no default is
691+
# provided
692+
if is_optional(type_) and "default" not in kwds:
693+
kwds["default"] = None
682694
if option is None:
683695
add_arg(name, field_type, kwds)
684696
else:

pydra/design/tests/test_shell.py

-1
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,6 @@ def test_interface_template_with_type_overrides():
434434
name="int_arg",
435435
argstr="--int-arg",
436436
type=int | None,
437-
default=None,
438437
position=5,
439438
),
440439
shell.arg(

pydra/engine/core.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def run(self, rerun: bool = False):
371371
self.audit.audit_task(task=self)
372372
try:
373373
self.audit.monitor()
374-
self.definition._run(self)
374+
self.definition._run(self, rerun)
375375
result.outputs = self.definition.Outputs._from_task(self)
376376
except Exception:
377377
etype, eval, etr = sys.exc_info()
@@ -425,7 +425,7 @@ async def run_async(self, rerun: bool = False) -> Result:
425425
self.audit.start_audit(odir=self.output_dir)
426426
try:
427427
self.audit.monitor()
428-
await self.definition._run_async(self)
428+
await self.definition._run_async(self, rerun)
429429
result.outputs = self.definition.Outputs._from_task(self)
430430
except Exception:
431431
etype, eval, etr = sys.exc_info()
@@ -628,8 +628,7 @@ def clear_cache(
628628

629629
@classmethod
630630
def construct(
631-
cls,
632-
definition: WorkflowDef[WorkflowOutputsType],
631+
cls, definition: WorkflowDef[WorkflowOutputsType], dont_cache: bool = False
633632
) -> Self:
634633
"""Construct a workflow from a definition, caching the constructed worklow"""
635634

@@ -710,7 +709,7 @@ def construct(
710709
f"{len(output_lazy_fields)} ({output_lazy_fields})"
711710
)
712711
for outpt, outpt_lf in zip(output_fields, output_lazy_fields):
713-
# Automatically combine any uncombined state arrays into lists
712+
# Automatically combine any uncombined state arrays into a single lists
714713
if TypeParser.get_origin(outpt_lf._type) is StateArray:
715714
outpt_lf._type = list[TypeParser.strip_splits(outpt_lf._type)[0]]
716715
setattr(outputs, outpt.name, outpt_lf)
@@ -722,8 +721,8 @@ def construct(
722721
f"Expected outputs {unset_outputs} to be set by the "
723722
f"constructor of {workflow!r}"
724723
)
725-
726-
cls._constructed_cache[defn_hash][non_lazy_keys][non_lazy_hash] = workflow
724+
if not dont_cache:
725+
cls._constructed_cache[defn_hash][non_lazy_keys][non_lazy_hash] = workflow
727726

728727
return workflow
729728

@@ -735,8 +734,7 @@ def under_construction(cls) -> "Workflow[ty.Any]":
735734
# Find the frame where the construct method was called
736735
if (
737736
frame.f_code.co_name == "construct"
738-
and "cls" in frame.f_locals
739-
and frame.f_locals["cls"] is cls
737+
and frame.f_locals.get("cls") is cls
740738
and "workflow" in frame.f_locals
741739
):
742740
return frame.f_locals["workflow"] # local var "workflow" in construct

pydra/engine/helpers_file.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,7 @@ def template_update(
135135
if isinstance(field, shell.outarg)
136136
and field.path_template
137137
and getattr(definition, field.name)
138-
and all(
139-
getattr(definition, required_field) for required_field in field.requires
140-
)
138+
and all(req.satisfied(definition) for req in field.requires)
141139
]
142140

143141
dict_mod = {}

pydra/engine/node.py

+2-22
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def lzout(self) -> OutputType:
134134
type_, _ = TypeParser.strip_splits(outpt._type)
135135
if self._state.combiner:
136136
type_ = list[type_]
137-
for _ in range(self._state.depth - int(bool(self._state.combiner))):
137+
for _ in range(self._state.depth()):
138138
type_ = StateArray[type_]
139139
outpt._type = type_
140140
# Flag the output lazy fields as being not typed checked (i.e. assigned to
@@ -272,7 +272,7 @@ def _get_upstream_states(self) -> dict[str, tuple["State", list[str]]]:
272272
if (
273273
isinstance(val, lazy.LazyOutField)
274274
and val._node.state
275-
and val._node.state.depth
275+
and val._node.state.depth()
276276
):
277277
node: Node = val._node
278278
# variables that are part of inner splitters should be treated as a containers
@@ -305,26 +305,6 @@ def _extract_input_el(self, inputs, inp_nm, ind):
305305
else:
306306
return getattr(inputs, inp_nm)[ind]
307307

308-
def _split_definition(self) -> dict[StateIndex, "TaskDef[OutputType]"]:
309-
"""Split the definition into the different states it will be run over"""
310-
# TODO: doesn't work properly for more cmplicated wf (check if still an issue)
311-
if not self.state:
312-
return {None: self._definition}
313-
split_defs = {}
314-
for input_ind in self.state.inputs_ind:
315-
inputs_dict = {}
316-
for inp in set(self.input_names):
317-
if f"{self.name}.{inp}" in input_ind:
318-
inputs_dict[inp] = self._extract_input_el(
319-
inputs=self._definition,
320-
inp_nm=inp,
321-
ind=input_ind[f"{self.name}.{inp}"],
322-
)
323-
split_defs[StateIndex(input_ind)] = attrs.evolve(
324-
self._definition, **inputs_dict
325-
)
326-
return split_defs
327-
328308
# else:
329309
# # todo it never gets here
330310
# breakpoint()

0 commit comments

Comments
 (0)