Updated test_shelltask, test_shelltask_inputspec, test_workflow to new syntax #772

Merged (20 commits) on Mar 3, 2025
52 changes: 32 additions & 20 deletions pydra/design/base.py
@@ -62,6 +62,13 @@ def convert_default_value(value: ty.Any, self_: "Field") -> ty.Any:
return TypeParser[self_.type](self_.type, label=self_.name)(value)


def allowed_values_converter(value: ty.Iterable[str] | None) -> list[str] | None:
"""Ensure the allowed_values field is a list of strings or None"""
if value is None:
return None
return list(value)


@attrs.define
class Requirement:
"""Define a requirement for a task input field
@@ -76,14 +83,19 @@ class Requirement:
"""

name: str
allowed_values: list[str] = attrs.field(factory=list, converter=list)
allowed_values: list[str] | None = attrs.field(
default=None, converter=allowed_values_converter
)

def satisfied(self, inputs: "TaskDef") -> bool:
"""Check if the requirement is satisfied by the inputs"""
value = getattr(inputs, self.name)
if value is attrs.NOTHING:
field = {f.name: f for f in list_fields(inputs)}[self.name]
if value is attrs.NOTHING or field.type is bool and value is False:
return False
return not self.allowed_values or value in self.allowed_values
if self.allowed_values is None:
return True
return value in self.allowed_values

@classmethod
def parse(cls, value: ty.Any) -> Self:
@@ -350,8 +362,8 @@ def get_fields(klass, field_type, auto_attribs, helps) -> dict[str, Field]:

if not issubclass(klass, spec_type):
raise ValueError(
f"The canonical form of {spec_type.__module__.split('.')[-1]} task definitions, "
f"{klass}, must inherit from {spec_type}"
f"When using the canonical form for {spec_type.__module__.split('.')[-1]} "
f"tasks, {klass} must inherit from {spec_type}"
)

inputs = get_fields(klass, arg_type, auto_attribs, input_helps)
@@ -364,8 +376,8 @@ def get_fields(klass, field_type, auto_attribs, helps) -> dict[str, Field]:
) from None
if not issubclass(outputs_klass, outputs_type):
raise ValueError(
f"The canonical form of {spec_type.__module__.split('.')[-1]} task definitions, "
f"{klass}, must inherit from {spec_type}"
f"When using the canonical form for {outputs_type.__module__.split('.')[-1]} "
f"task outputs {outputs_klass}, you must inherit from {outputs_type}"
)

output_helps, _ = parse_doc_string(outputs_klass.__doc__)
@@ -416,10 +428,12 @@ def make_task_def(

spec_type._check_arg_refs(inputs, outputs)

# Check that the field attributes are valid after all fields have been set
# (especially the type)
for inpt in inputs.values():
set_none_default_if_optional(inpt)
for outpt in inputs.values():
set_none_default_if_optional(outpt)
attrs.validate(inpt)
for outpt in outputs.values():
attrs.validate(outpt)

if name is None and klass is not None:
name = klass.__name__
@@ -459,10 +473,10 @@ def make_task_def(
if getattr(arg, "path_template", False):
if is_optional(arg.type):
field_type = Path | bool | None
# Will default to None and not be inserted into the command
attrs_kwargs = {"default": None}
else:
field_type = Path | bool
attrs_kwargs = {"default": True}
attrs_kwargs = {"default": True} # use the template by default
elif is_optional(arg.type):
field_type = Path | None
else:
@@ -988,12 +1002,10 @@ def check_explicit_fields_are_none(klass, inputs, outputs):

def _get_attrs_kwargs(field: Field) -> dict[str, ty.Any]:
kwargs = {}
if not hasattr(field, "default"):
kwargs["factory"] = nothing_factory
elif field.default is not NO_DEFAULT:
if field.default is not NO_DEFAULT:
kwargs["default"] = field.default
elif is_optional(field.type):
kwargs["default"] = None
# elif is_optional(field.type):
# kwargs["default"] = None
else:
kwargs["factory"] = nothing_factory
if field.hash_eq:
@@ -1005,9 +1017,9 @@ def nothing_factory():
return attrs.NOTHING


def set_none_default_if_optional(field: Field) -> None:
if is_optional(field.type) and field.default is NO_DEFAULT:
field.default = None
# def set_none_default_if_optional(field: Field) -> None:
# if is_optional(field.type) and field.default is NO_DEFAULT:
# field.default = None


white_space_re = re.compile(r"\s+")
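The base.py hunks above change what it means for a Requirement to be satisfied: allowed_values now defaults to None (meaning any set value is acceptable) instead of an empty list, and a boolean field left at False no longer satisfies a requirement. The snippet below is a minimal standalone sketch that mirrors this behaviour without importing pydra; this Requirement is a simplified stand-in for pydra.design.base.Requirement (the real version inspects the field's declared type rather than the value for the boolean case), and ExampleInputs is a hypothetical inputs object used only for illustration.

import typing as ty

import attrs


def allowed_values_converter(value: ty.Iterable[str] | None) -> list[str] | None:
    """Keep None as 'no restriction'; otherwise normalise to a list."""
    if value is None:
        return None
    return list(value)


@attrs.define
class Requirement:
    """Simplified stand-in mirroring the new requirement semantics."""

    name: str
    allowed_values: list[str] | None = attrs.field(
        default=None, converter=allowed_values_converter
    )

    def satisfied(self, inputs: ty.Any) -> bool:
        value = getattr(inputs, self.name)
        # an unset value, or a boolean flag left at False, does not satisfy it
        if value is attrs.NOTHING or value is False:
            return False
        if self.allowed_values is None:  # None now means any set value is fine
            return True
        return value in self.allowed_values


@attrs.define
class ExampleInputs:  # hypothetical inputs object, for illustration only
    mode: str = "fast"
    verbose: bool = False


inputs = ExampleInputs()
assert Requirement("mode").satisfied(inputs)                # any value accepted
assert not Requirement("mode", ["slow"]).satisfied(inputs)  # not in the allowed set
assert not Requirement("verbose").satisfied(inputs)         # a False flag never satisfies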
38 changes: 25 additions & 13 deletions pydra/design/shell.py
@@ -110,14 +110,16 @@ def _sep_default(self):

@sep.validator
def _validate_sep(self, _, sep):
if self.type is ty.Any:
return
if ty.get_origin(self.type) is MultiInputObj:
if self.type is MultiInputObj:
tp = ty.Any
elif ty.get_origin(self.type) is MultiInputObj:
tp = ty.get_args(self.type)[0]
else:
tp = self.type
if is_optional(tp):
tp = optional_type(tp)
if tp is ty.Any:
return
origin = ty.get_origin(tp) or tp

if (
@@ -238,16 +240,21 @@ class outarg(arg, Out):

@path_template.validator
def _validate_path_template(self, attribute, value):
if value and self.default not in (NO_DEFAULT, True, None):
raise ValueError(
f"path_template ({value!r}) can only be provided when no default "
f"({self.default!r}) is provided"
)
if value and not (is_fileset_or_union(self.type) or self.type is ty.Any):
raise ValueError(
f"path_template ({value!r}) can only be provided when type is a FileSet, "
f"or union thereof, not {self.type!r}"
)
if value:
if self.default not in (NO_DEFAULT, True, None):
raise ValueError(
f"path_template ({value!r}) can only be provided when no default "
f"({self.default!r}) is provided"
)
if not (is_fileset_or_union(self.type) or self.type is ty.Any):
raise ValueError(
f"path_template ({value!r}) can only be provided when type is a FileSet, "
f"or union thereof, not {self.type!r}"
)
if self.argstr is None:
raise ValueError(
f"path_template ({value!r}) can only be provided when argstr is not None"
)

@keep_extension.validator
def _validate_keep_extension(self, attribute, value):
@@ -386,6 +393,7 @@ def make(
input_helps=input_helps,
output_helps=output_helps,
)

if name:
class_name = name
else:
@@ -679,6 +687,10 @@ def from_type_str(type_str) -> type:
if ext_type.ext is not None:
path_template = name + ext_type.ext
kwds["path_template"] = path_template
# Set the default value to None if the field is optional and no default is
# provided
if is_optional(type_) and "default" not in kwds:
kwds["default"] = None
if option is None:
add_arg(name, field_type, kwds)
else:
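The tightened _validate_path_template above adds a third rule: an output with a path_template must also carry a non-None argstr, on top of the existing checks that no concrete default is set and that the type is a FileSet (or a union including one). Below is a hedged usage sketch, assuming the pydra.design.shell field API shown elsewhere in this diff; the field name, flag, and template string are invented for illustration.

from fileformats.generic import File

from pydra.design import shell

# valid: a template-generated output with an explicit command-line flag
out_file = shell.outarg(
    name="out_file",
    type=File,
    argstr="-o",  # must not be None once a path_template is given
    path_template="out_file.txt",
)

# these would now raise ValueError when the field is defined:
#   shell.outarg(name="out_file", type=File, argstr=None, path_template="out_file.txt")
#   shell.outarg(name="out_file", type=File, argstr="-o", default="fixed.txt",
#                path_template="out_file.txt")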
1 change: 0 additions & 1 deletion pydra/design/tests/test_shell.py
@@ -434,7 +434,6 @@ def test_interface_template_with_type_overrides():
name="int_arg",
argstr="--int-arg",
type=int | None,
default=None,
position=5,
),
shell.arg(
16 changes: 7 additions & 9 deletions pydra/engine/core.py
@@ -371,7 +371,7 @@ def run(self, rerun: bool = False):
self.audit.audit_task(task=self)
try:
self.audit.monitor()
self.definition._run(self)
self.definition._run(self, rerun)
result.outputs = self.definition.Outputs._from_task(self)
except Exception:
etype, eval, etr = sys.exc_info()
@@ -425,7 +425,7 @@ async def run_async(self, rerun: bool = False) -> Result:
self.audit.start_audit(odir=self.output_dir)
try:
self.audit.monitor()
await self.definition._run_async(self)
await self.definition._run_async(self, rerun)
result.outputs = self.definition.Outputs._from_task(self)
except Exception:
etype, eval, etr = sys.exc_info()
@@ -628,8 +628,7 @@ def clear_cache(

@classmethod
def construct(
cls,
definition: WorkflowDef[WorkflowOutputsType],
cls, definition: WorkflowDef[WorkflowOutputsType], dont_cache: bool = False
) -> Self:
"""Construct a workflow from a definition, caching the constructed worklow"""

@@ -710,7 +709,7 @@ def construct(
f"{len(output_lazy_fields)} ({output_lazy_fields})"
)
for outpt, outpt_lf in zip(output_fields, output_lazy_fields):
# Automatically combine any uncombined state arrays into lists
# Automatically combine any uncombined state arrays into a single list
if TypeParser.get_origin(outpt_lf._type) is StateArray:
outpt_lf._type = list[TypeParser.strip_splits(outpt_lf._type)[0]]
setattr(outputs, outpt.name, outpt_lf)
@@ -722,8 +721,8 @@ def construct(
f"Expected outputs {unset_outputs} to be set by the "
f"constructor of {workflow!r}"
)

cls._constructed_cache[defn_hash][non_lazy_keys][non_lazy_hash] = workflow
if not dont_cache:
cls._constructed_cache[defn_hash][non_lazy_keys][non_lazy_hash] = workflow

return workflow

@@ -735,8 +734,7 @@ def under_construction(cls) -> "Workflow[ty.Any]":
# Find the frame where the construct method was called
if (
frame.f_code.co_name == "construct"
and "cls" in frame.f_locals
and frame.f_locals["cls"] is cls
and frame.f_locals.get("cls") is cls
and "workflow" in frame.f_locals
):
return frame.f_locals["workflow"] # local var "workflow" in construct
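Two behavioural tweaks in core.py above are easy to miss: construct() gains a dont_cache flag so a workflow graph can be built without being memoised in _constructed_cache, and under_construction() now uses frame.f_locals.get("cls") rather than a separate membership test while walking the call stack. The helper below is a self-contained sketch of that frame-walking pattern; find_construct_frame is a hypothetical name and not part of pydra.

import inspect


def find_construct_frame(cls: type):
    """Walk up the call stack for a frame running a method named 'construct'
    whose local 'cls' is this class, and return its local 'workflow' variable
    if one has been bound yet."""
    frame = inspect.currentframe().f_back
    while frame is not None:
        if (
            frame.f_code.co_name == "construct"
            # .get() simply yields None when 'cls' is not a local of the frame,
            # so a single identity check replaces the old two-step test
            and frame.f_locals.get("cls") is cls
            and "workflow" in frame.f_locals
        ):
            return frame.f_locals["workflow"]
        frame = frame.f_back
    return None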
4 changes: 1 addition & 3 deletions pydra/engine/helpers_file.py
@@ -135,9 +135,7 @@ def template_update(
if isinstance(field, shell.outarg)
and field.path_template
and getattr(definition, field.name)
and all(
getattr(definition, required_field) for required_field in field.requires
)
and all(req.satisfied(definition) for req in field.requires)
]

dict_mod = {}
24 changes: 2 additions & 22 deletions pydra/engine/node.py
@@ -134,7 +134,7 @@ def lzout(self) -> OutputType:
type_, _ = TypeParser.strip_splits(outpt._type)
if self._state.combiner:
type_ = list[type_]
for _ in range(self._state.depth - int(bool(self._state.combiner))):
for _ in range(self._state.depth()):
type_ = StateArray[type_]
outpt._type = type_
# Flag the output lazy fields as being not typed checked (i.e. assigned to
@@ -272,7 +272,7 @@ def _get_upstream_states(self) -> dict[str, tuple["State", list[str]]]:
if (
isinstance(val, lazy.LazyOutField)
and val._node.state
and val._node.state.depth
and val._node.state.depth()
):
node: Node = val._node
# variables that are part of inner splitters should be treated as containers
@@ -305,26 +305,6 @@ def _extract_input_el(self, inputs, inp_nm, ind):
else:
return getattr(inputs, inp_nm)[ind]

def _split_definition(self) -> dict[StateIndex, "TaskDef[OutputType]"]:
"""Split the definition into the different states it will be run over"""
# TODO: doesn't work properly for more complicated wf (check if still an issue)
if not self.state:
return {None: self._definition}
split_defs = {}
for input_ind in self.state.inputs_ind:
inputs_dict = {}
for inp in set(self.input_names):
if f"{self.name}.{inp}" in input_ind:
inputs_dict[inp] = self._extract_input_el(
inputs=self._definition,
inp_nm=inp,
ind=input_ind[f"{self.name}.{inp}"],
)
split_defs[StateIndex(input_ind)] = attrs.evolve(
self._definition, **inputs_dict
)
return split_defs

# else:
# # todo it never gets here
# breakpoint()