diff --git a/docs/input_spec.rst b/docs/input_spec.rst index 48d66fd814..2940c17820 100644 --- a/docs/input_spec.rst +++ b/docs/input_spec.rst @@ -174,8 +174,9 @@ In the example we used multiple keys in the metadata dictionary including `help_ (a specific input field will be sent). -Validators ----------- -Pydra allows for using simple validator for types and `allowev_values`. -The validators are disabled by default, but can be enabled by calling -`pydra.set_input_validator(flag=True)`. +`shell_arg` Function +-------------------- + +For convenience, there is a function in `pydra.mark` called `shell_arg()`, which will +takes the above metadata values as arguments and inserts them into the metadata passed +to `attrs.field`. This can be especially useful when using an IDE with code-completion. diff --git a/pydra/mark/__init__.py b/pydra/mark/__init__.py index 31e4cf832e..f2434e5a1c 100644 --- a/pydra/mark/__init__.py +++ b/pydra/mark/__init__.py @@ -1,3 +1,4 @@ from .functions import annotate, task +from .shell import shell_task, shell_arg, shell_out -__all__ = ("annotate", "task") +__all__ = ("annotate", "task", "shell_task", "shell_arg", "shell_out") diff --git a/pydra/mark/shell.py b/pydra/mark/shell.py new file mode 100644 index 0000000000..9abdcf61fe --- /dev/null +++ b/pydra/mark/shell.py @@ -0,0 +1,408 @@ +"""Decorators and helper functions to create ShellCommandTasks used in Pydra workflows""" +from __future__ import annotations +import typing as ty +import attrs + +# import os +import pydra.engine.specs + + +def shell_task( + klass_or_name: ty.Union[type, str], + executable: ty.Optional[str] = None, + input_fields: ty.Optional[dict[str, dict]] = None, + output_fields: ty.Optional[dict[str, dict]] = None, + bases: ty.Optional[list[type]] = None, + inputs_bases: ty.Optional[list[type]] = None, + outputs_bases: ty.Optional[list[type]] = None, +) -> type: + """ + Construct an analysis class and validate all the components fit together + + Parameters + ---------- + klass_or_name : type or str + Either the class decorated by the @shell_task decorator or the name for a + dynamically generated class + executable : str, optional + If dynamically constructing a class (instead of decorating an existing one) the + name of the executable to run is provided + input_fields : dict[str, dict], optional + If dynamically constructing a class (instead of decorating an existing one) the + input fields can be provided as a dictionary of dictionaries, where the keys + are the name of the fields and the dictionary contents are passed as keyword + args to cmd_arg, with the exception of "type", which is used as the type annotation + of the field. + output_fields : dict[str, dict], optional + If dynamically constructing a class (instead of decorating an existing one) the + output fields can be provided as a dictionary of dictionaries, where the keys + are the name of the fields and the dictionary contents are passed as keyword + args to cmd_out, with the exception of "type", which is used as the type annotation + of the field. + bases : list[type] + Base classes for dynamically constructed shell command classes + inputs_bases : list[type] + Base classes for the input spec of dynamically constructed shell command classes + outputs_bases : list[type] + Base classes for the input spec of dynamically constructed shell command classes + + Returns + ------- + type + the shell command task class + """ + + annotations = { + "executable": str, + "Inputs": type, + "Outputs": type, + } + dct = {"__annotations__": annotations} + + if isinstance(klass_or_name, str): + # Dynamically created classes using shell_task as a function + name = klass_or_name + + if executable is not None: + dct["executable"] = executable + if input_fields is None: + input_fields = {} + if output_fields is None: + output_fields = {} + bases = list(bases) if bases is not None else [] + inputs_bases = list(inputs_bases) if inputs_bases is not None else [] + outputs_bases = list(outputs_bases) if outputs_bases is not None else [] + + # Ensure base classes included somewhere in MRO + def ensure_base_included(base_class: type, bases_list: list[type]): + if not any(issubclass(b, base_class) for b in bases_list): + bases_list.append(base_class) + + # Get inputs and outputs bases from base class if not explicitly provided + for base in bases: + if not inputs_bases: + try: + inputs_bases = [base.Inputs] + except AttributeError: + pass + if not outputs_bases: + try: + outputs_bases = [base.Outputs] + except AttributeError: + pass + + # Ensure bases are lists and can be modified + ensure_base_included(pydra.engine.task.ShellCommandTask, bases) + ensure_base_included(pydra.engine.specs.ShellSpec, inputs_bases) + ensure_base_included(pydra.engine.specs.ShellOutSpec, outputs_bases) + + def convert_to_attrs(fields: dict[str, dict[str, ty.Any]], attrs_func): + annotations = {} + attrs_dict = {"__annotations__": annotations} + for name, dct in fields.items(): + kwargs = dict(dct) # copy to avoid modifying input to outer function + annotations[name] = kwargs.pop("type") + attrs_dict[name] = attrs_func(**kwargs) + return attrs_dict + + Inputs = attrs.define(kw_only=True, slots=False)( + type( + "Inputs", + tuple(inputs_bases), + convert_to_attrs(input_fields, shell_arg), + ) + ) + + Outputs = attrs.define(kw_only=True, slots=False)( + type( + "Outputs", + tuple(outputs_bases), + convert_to_attrs(output_fields, shell_out), + ) + ) + + else: + # Statically defined classes using shell_task as decorator + if ( + executable, + input_fields, + output_fields, + bases, + inputs_bases, + outputs_bases, + ) != (None, None, None, None, None, None): + raise RuntimeError( + "When used as a decorator on a class, `shell_task` should not be " + "provided any other arguments" + ) + klass = klass_or_name + name = klass.__name__ + + bases = [klass] + if not issubclass(klass, pydra.engine.task.ShellCommandTask): + bases.append(pydra.engine.task.ShellCommandTask) + + try: + executable = klass.executable + except AttributeError: + raise RuntimeError( + "Classes decorated by `shell_task` should contain an `executable` " + "attribute specifying the shell tool to run" + ) + try: + Inputs = klass.Inputs + except AttributeError: + raise RuntimeError( + "Classes decorated by `shell_task` should contain an `Inputs` class " + "attribute specifying the inputs to the shell tool" + ) + + try: + Outputs = klass.Outputs + except AttributeError: + Outputs = type("Outputs", (pydra.engine.specs.ShellOutSpec,), {}) + + # Pass Inputs and Outputs in attrs.define if they are present in klass (i.e. + # not in a base class) + if "Inputs" in klass.__dict__: + Inputs = attrs.define(kw_only=True, slots=False)(Inputs) + if "Outputs" in klass.__dict__: + Outputs = attrs.define(kw_only=True, slots=False)(Outputs) + + if not issubclass(Inputs, pydra.engine.specs.ShellSpec): + Inputs = attrs.define(kw_only=True, slots=False)( + type("Inputs", (Inputs, pydra.engine.specs.ShellSpec), {}) + ) + + template_fields = _gen_output_template_fields(Inputs, Outputs) + + if not issubclass(Outputs, pydra.engine.specs.ShellOutSpec): + outputs_bases = (Outputs, pydra.engine.specs.ShellOutSpec) + add_base_class = True + else: + outputs_bases = (Outputs,) + add_base_class = False + + if add_base_class or template_fields: + Outputs = attrs.define(kw_only=True, slots=False)( + type("Outputs", outputs_bases, template_fields) + ) + + dct["Inputs"] = Inputs + dct["Outputs"] = Outputs + + task_klass = type(name, tuple(bases), dct) + + if not hasattr(task_klass, "executable"): + raise RuntimeError( + "Classes generated by `shell_task` should contain an `executable` " + "attribute specifying the shell tool to run" + ) + + task_klass.input_spec = pydra.engine.specs.SpecInfo( + name=f"{name}Inputs", fields=[], bases=(task_klass.Inputs,) + ) + task_klass.output_spec = pydra.engine.specs.SpecInfo( + name=f"{name}Outputs", fields=[], bases=(task_klass.Outputs,) + ) + + return task_klass + + +def shell_arg( + help_string: str, + default: ty.Any = attrs.NOTHING, + argstr: str = None, + position: int = None, + mandatory: bool = False, + sep: str = None, + allowed_values: list = None, + requires: list = None, + xor: list = None, + copyfile: bool = None, + container_path: bool = False, + output_file_template: str = None, + output_field_name: str = None, + keep_extension: bool = True, + readonly: bool = False, + formatter: ty.Callable = None, + **kwargs, +): + """ + Returns an attrs field with appropriate metadata for it to be added as an argument in + a Pydra shell command task definition + + Parameters + ------------ + help_string: str + A short description of the input field. + default : Any, optional + the default value for the argument + argstr: str, optional + A flag or string that is used in the command before the value, e.g. -v or + -v {inp_field}, but it could be and empty string, “”. If … are used, e.g. -v…, + the flag is used before every element if a list is provided as a value. If no + argstr is used the field is not part of the command. + position: int, optional + Position of the field in the command, could be nonnegative or negative integer. + If nothing is provided the field will be inserted between all fields with + nonnegative positions and fields with negative positions. + mandatory: bool, optional + If True user has to provide a value for the field, by default it is False + sep: str, optional + A separator if a list is provided as a value. + allowed_values: list, optional + List of allowed values for the field. + requires: list, optional + List of field names that are required together with the field. + xor: list, optional + List of field names that are mutually exclusive with the field. + copyfile: bool, optional + If True, a hard link is created for the input file in the output directory. If + hard link not possible, the file is copied to the output directory, by default + it is False + container_path: bool, optional + If True a path will be consider as a path inside the container (and not as a + local path, by default it is False + output_file_template: str, optional + If provided, the field is treated also as an output field and it is added to + the output spec. The template can use other fields, e.g. {file1}. Used in order + to create an output specification. + output_field_name: str, optional + If provided the field is added to the output spec with changed name. Used in + order to create an output specification. Used together with output_file_template + keep_extension: bool, optional + A flag that specifies if the file extension should be removed from the field value. + Used in order to create an output specification, by default it is True + readonly: bool, optional + If True the input field can’t be provided by the user but it aggregates other + input fields (for example the fields with argstr: -o {fldA} {fldB}), by default + it is False + formatter: function, optional + If provided the argstr of the field is created using the function. This function + can for example be used to combine several inputs into one command argument. The + function can take field (this input field will be passed to the function), + inputs (entire inputs will be passed) or any input field name (a specific input + field will be sent). + **kwargs + remaining keyword arguments are passed onto the underlying attrs.field function + """ + + metadata = { + "help_string": help_string, + "argstr": argstr, + "position": position, + "mandatory": mandatory, + "sep": sep, + "allowed_values": allowed_values, + "requires": requires, + "xor": xor, + "copyfile": copyfile, + "container_path": container_path, + "output_file_template": output_file_template, + "output_field_name": output_field_name, + "keep_extension": keep_extension, + "readonly": readonly, + "formatter": formatter, + } + + return attrs.field( + default=default, + metadata={k: v for k, v in metadata.items() if v is not None}, + **kwargs, + ) + + +def shell_out( + help_string: str, + mandatory: bool = False, + output_file_template: str = None, + output_field_name: str = None, + keep_extension: bool = True, + requires: list = None, + callable: ty.Callable = None, + **kwargs, +): + """Returns an attrs field with appropriate metadata for it to be added as an output of + a Pydra shell command task definition + + Parameters + ---------- + help_string: str + A short description of the input field. The same as in input_spec. + mandatory: bool, default: False + If True the output file has to exist, otherwise an error will be raised. + output_file_template: str, optional + If provided the output file name (or list of file names) is created using the + template. The template can use other fields, e.g. {file1}. The same as in + input_spec. + output_field_name: str, optional + If provided the field is added to the output spec with changed name. The same as + in input_spec. Used together with output_file_template + keep_extension: bool, default: True + A flag that specifies if the file extension should be removed from the field + value. The same as in input_spec. + requires: list + List of field names that are required to create a specific output. The fields + do not have to be a part of the output_file_template and if any field from the + list is not provided in the input, a NOTHING is returned for the specific output. + This has a different meaning than the requires form the input_spec. + callable: Callable + If provided the output file name (or list of file names) is created using the + function. The function can take field (the specific output field will be passed + to the function), output_dir (task output_dir will be used), stdout, stderr + (stdout and stderr of the task will be sent) inputs (entire inputs will be + passed) or any input field name (a specific input field will be sent). + **kwargs + remaining keyword arguments are passed onto the underlying attrs.field function + """ + metadata = { + "help_string": help_string, + "mandatory": mandatory, + "output_file_template": output_file_template, + "output_field_name": output_field_name, + "keep_extension": keep_extension, + "requires": requires, + "callable": callable, + } + + return attrs.field( + metadata={k: v for k, v in metadata.items() if v is not None}, **kwargs + ) + + +def _gen_output_template_fields(Inputs: type, Outputs: type) -> dict: + """Auto-generates output fields for inputs that specify an 'output_file_template' + + Parameters + ---------- + Inputs : type + Inputs specification class + Outputs : type + Outputs specification class + + Returns + ------- + template_fields: dict[str, attrs._make_CountingAttribute] + the template fields to add to the output spec + """ + annotations = {} + template_fields = {"__annotations__": annotations} + output_field_names = [f.name for f in attrs.fields(Outputs)] + for fld in attrs.fields(Inputs): + if "output_file_template" in fld.metadata: + if "output_field_name" in fld.metadata: + field_name = fld.metadata["output_field_name"] + else: + field_name = fld.name + # skip adding if the field already in the output_spec + exists_already = field_name in output_field_names + if not exists_already: + metadata = { + "help_string": fld.metadata["help_string"], + "mandatory": fld.metadata["mandatory"], + "keep_extension": fld.metadata["keep_extension"], + } + template_fields[field_name] = attrs.field(metadata=metadata) + annotations[field_name] = str + return template_fields diff --git a/pydra/mark/tests/test_shell.py b/pydra/mark/tests/test_shell.py new file mode 100644 index 0000000000..6fee7259b1 --- /dev/null +++ b/pydra/mark/tests/test_shell.py @@ -0,0 +1,467 @@ +import os +import tempfile +import attrs +from pathlib import Path +import pytest +import cloudpickle as cp +from pydra.mark import shell_task, shell_arg, shell_out + + +def list_entries(stdout): + return stdout.split("\n")[:-1] + + +@pytest.fixture +def tmpdir(): + return Path(tempfile.mkdtemp()) + + +@pytest.fixture(params=["static", "dynamic"]) +def Ls(request): + if request.param == "static": + + @shell_task + class Ls: + executable = "ls" + + class Inputs: + directory: os.PathLike = shell_arg( + help_string="the directory to list the contents of", + argstr="", + mandatory=True, + position=-1, + ) + hidden: bool = shell_arg( + help_string=("display hidden FS objects"), + argstr="-a", + default=False, + ) + long_format: bool = shell_arg( + help_string=( + "display properties of FS object, such as permissions, size and " + "timestamps " + ), + default=False, + argstr="-l", + ) + human_readable: bool = shell_arg( + help_string="display file sizes in human readable form", + argstr="-h", + default=False, + requires=["long_format"], + ) + complete_date: bool = shell_arg( + help_string="Show complete date in long format", + argstr="-T", + default=False, + requires=["long_format"], + xor=["date_format_str"], + ) + date_format_str: str = shell_arg( + help_string="format string for ", + argstr="-D", + default=attrs.NOTHING, + requires=["long_format"], + xor=["complete_date"], + ) + + class Outputs: + entries: list = shell_out( + help_string="list of entries returned by ls command", + callable=list_entries, + ) + + elif request.param == "dynamic": + Ls = shell_task( + "Ls", + executable="ls", + input_fields={ + "directory": { + "type": os.PathLike, + "help_string": "the directory to list the contents of", + "argstr": "", + "mandatory": True, + "position": -1, + }, + "hidden": { + "type": bool, + "help_string": "display hidden FS objects", + "argstr": "-a", + }, + "long_format": { + "type": bool, + "help_string": ( + "display properties of FS object, such as permissions, size and " + "timestamps " + ), + "argstr": "-l", + }, + "human_readable": { + "type": bool, + "help_string": "display file sizes in human readable form", + "argstr": "-h", + "requires": ["long_format"], + }, + "complete_date": { + "type": bool, + "help_string": "Show complete date in long format", + "argstr": "-T", + "requires": ["long_format"], + "xor": ["date_format_str"], + }, + "date_format_str": { + "type": str, + "help_string": "format string for ", + "argstr": "-D", + "requires": ["long_format"], + "xor": ["complete_date"], + }, + }, + output_fields={ + "entries": { + "type": list, + "help_string": "list of entries returned by ls command", + "callable": list_entries, + } + }, + ) + + else: + assert False + + return Ls + + +def test_shell_fields(Ls): + assert [a.name for a in attrs.fields(Ls.Inputs)] == [ + "executable", + "args", + "directory", + "hidden", + "long_format", + "human_readable", + "complete_date", + "date_format_str", + ] + + assert [a.name for a in attrs.fields(Ls.Outputs)] == [ + "return_code", + "stdout", + "stderr", + "entries", + ] + + +def test_shell_pickle_roundtrip(Ls, tmpdir): + pkl_file = tmpdir / "ls.pkl" + with open(pkl_file, "wb") as f: + cp.dump(Ls, f) + + with open(pkl_file, "rb") as f: + RereadLs = cp.load(f) + + assert RereadLs is Ls + + +def test_shell_run(Ls, tmpdir): + Path.touch(tmpdir / "a") + Path.touch(tmpdir / "b") + Path.touch(tmpdir / "c") + + ls = Ls(directory=tmpdir, long_format=True) + + # Test cmdline + assert ls.inputs.directory == tmpdir + assert not ls.inputs.hidden + assert ls.inputs.long_format + assert ls.cmdline == f"ls -l {tmpdir}" + + # Drop Long format flag to make output simpler + ls = Ls(directory=tmpdir) + result = ls() + + assert result.output.entries == ["a", "b", "c"] + + +@pytest.fixture(params=["static", "dynamic"]) +def A(request): + if request.param == "static": + + @shell_task + class A: + executable = "cp" + + class Inputs: + x: os.PathLike = shell_arg( + help_string="an input file", argstr="", position=0 + ) + y: str = shell_arg( + help_string="path of output file", + output_file_template="{x}_out", + argstr="", + ) + + elif request.param == "dynamic": + A = shell_task( + "A", + executable="cp", + input_fields={ + "x": { + "type": os.PathLike, + "help_string": "an input file", + "argstr": "", + "position": 0, + }, + "y": { + "type": str, + "help_string": "path of output file", + "argstr": "", + "output_file_template": "{x}_out", + }, + }, + ) + else: + assert False + + return A + + +def test_shell_output_file_template(A): + assert "y" in [a.name for a in attrs.fields(A.Outputs)] + + +def test_shell_output_field_name_static(): + @shell_task + class A: + executable = "cp" + + class Inputs: + x: os.PathLike = shell_arg( + help_string="an input file", argstr="", position=0 + ) + y: str = shell_arg( + help_string="path of output file", + output_file_template="{x}_out", + output_field_name="y_out", + argstr="", + ) + + assert "y_out" in [a.name for a in attrs.fields(A.Outputs)] + + +def test_shell_output_field_name_dynamic(): + A = shell_task( + "A", + executable="cp", + input_fields={ + "x": { + "type": os.PathLike, + "help_string": "an input file", + "argstr": "", + "position": 0, + }, + "y": { + "type": str, + "help_string": "path of output file", + "argstr": "", + "output_field_name": "y_out", + "output_file_template": "{x}_out", + }, + }, + ) + + assert "y_out" in [a.name for a in attrs.fields(A.Outputs)] + + +def get_file_size(y: Path): + result = os.stat(y) + return result.st_size + + +def test_shell_bases_dynamic(A, tmpdir): + B = shell_task( + "B", + output_fields={ + "out_file_size": { + "type": int, + "help_string": "size of the output directory", + "callable": get_file_size, + } + }, + bases=[A], + ) + + xpath = tmpdir / "x.txt" + ypath = tmpdir / "y.txt" + Path.touch(xpath) + + b = B(x=xpath, y=str(ypath)) + + result = b() + + assert b.inputs.x == xpath + assert result.output.y == str(ypath) + + +def test_shell_bases_static(A, tmpdir): + @shell_task + class B(A): + class Outputs: + out_file_size: int = shell_out( + help_string="size of the output directory", callable=get_file_size + ) + + xpath = tmpdir / "x.txt" + ypath = tmpdir / "y.txt" + Path.touch(xpath) + + b = B(x=xpath, y=str(ypath)) + + result = b() + + assert b.inputs.x == xpath + assert result.output.y == str(ypath) + + +def test_shell_inputs_outputs_bases_dynamic(tmpdir): + A = shell_task( + "A", + "ls", + input_fields={ + "directory": { + "type": os.PathLike, + "help_string": "input directory", + "argstr": "", + "position": -1, + } + }, + output_fields={ + "entries": { + "type": list, + "help_string": "list of entries returned by ls command", + "callable": list_entries, + } + }, + ) + B = shell_task( + "B", + "ls", + input_fields={ + "hidden": { + "type": bool, + "argstr": "-a", + "help_string": "show hidden files", + "default": False, + } + }, + bases=[A], + inputs_bases=[A.Inputs], + ) + + Path.touch(tmpdir / ".hidden") + + b = B(directory=tmpdir, hidden=True) + + assert b.inputs.directory == tmpdir + assert b.inputs.hidden + assert b.cmdline == f"ls -a {tmpdir}" + + result = b() + assert result.output.entries == [".", "..", ".hidden"] + + +def test_shell_inputs_outputs_bases_static(tmpdir): + @shell_task + class A: + executable = "ls" + + class Inputs: + directory: os.PathLike = shell_arg( + help_string="input directory", argstr="", position=-1 + ) + + class Outputs: + entries: list = shell_out( + help_string="list of entries returned by ls command", + callable=list_entries, + ) + + @shell_task + class B(A): + class Inputs(A.Inputs): + hidden: bool = shell_arg( + help_string="show hidden files", + argstr="-a", + default=False, + ) + + Path.touch(tmpdir / ".hidden") + + b = B(directory=tmpdir, hidden=True) + + assert b.inputs.directory == tmpdir + assert b.inputs.hidden + + result = b() + assert result.output.entries == [".", "..", ".hidden"] + + +def test_shell_missing_executable_static(): + with pytest.raises(RuntimeError, match="should contain an `executable`"): + + @shell_task + class A: + class Inputs: + directory: os.PathLike = shell_arg( + help_string="input directory", argstr="", position=-1 + ) + + class Outputs: + entries: list = shell_out( + help_string="list of entries returned by ls command", + callable=list_entries, + ) + + +def test_shell_missing_executable_dynamic(): + with pytest.raises(RuntimeError, match="should contain an `executable`"): + A = shell_task( + "A", + executable=None, + input_fields={ + "directory": { + "type": os.PathLike, + "help_string": "input directory", + "argstr": "", + "position": -1, + } + }, + output_fields={ + "entries": { + "type": list, + "help_string": "list of entries returned by ls command", + "callable": list_entries, + } + }, + ) + + +def test_shell_missing_inputs_static(): + with pytest.raises(RuntimeError, match="should contain an `Inputs`"): + + @shell_task + class A: + executable = "ls" + + class Outputs: + entries: list = shell_out( + help_string="list of entries returned by ls command", + callable=list_entries, + ) + + +def test_shell_decorator_misuse(A): + with pytest.raises( + RuntimeError, match=("`shell_task` should not be provided any other arguments") + ): + shell_task(A, executable="cp")