diff --git a/docs/source/howto.rst b/docs/source/howto.rst index d70f3af..7583083 100644 --- a/docs/source/howto.rst +++ b/docs/source/howto.rst @@ -568,9 +568,11 @@ The ``parser`` keyword allows to define a "custom" parser, which is a function w .. code-block:: python - def parser(self, dirpath: pathlib.Path) -> dict[str, Data]: + def custom_parser(dirpath: pathlib.Path) -> dict[str, Data]: """Parse any output file generated by the shell command and return it as any ``Data`` node.""" +The ``dirpath`` argument receives the filepath to a directory that contains the retrieved output files that can then be read and parsed. +The parsed results should be returned as a dictionary of ``Data`` nodes, such that they can be attached to the job's node as outputs in the provenance graph. The following example shows how a custom parser can be implemented: @@ -578,14 +580,14 @@ The following example shows how a custom parser can be implemented: from aiida_shell import launch_shell_job - def parser(self, dirpath): + def custom_parser(dirpath): from aiida.orm import Str return {'string': Str((dirpath / 'stdout').read_text().strip())} results, node = launch_shell_job( 'echo', arguments='some output', - parser=parser + parser=custom_parser ) print(results['string'].value) @@ -623,6 +625,46 @@ which prints ``some output``. which prints ``{'a': 1}``. +Optionally, the parsing function can also define the ``parser`` argument. +If defined, in addition to the ``dirpath``, the parser receives an instance of the ``Parser`` class. +This instance can be useful for a number of things, such as: + +* Access the logger in order to log messages +* Access the node that represents the ``ShellJob``, from which, e.g., its input nodes can be accessed + +Below is an example of how the ``parser`` argument can be put to use: + +.. 
code-block:: python + + from pathlib import Path + from aiida_shell import launch_shell_job + from aiida.parsers import Parser + + def custom_parser(dirpath: Path, parser: Parser): + from aiida.orm import Bool, Str + + inputs = parser.node.inputs # Retrieve inputs of the job + + if inputs.arguments[0] == 'return-bool': + parser.logger.warning('Arguments set to `return-bool`, returning a bool') + return {'output': Bool(True)} + else: + return {'output': Str((dirpath / 'stdout').read_text().strip())} + + results, node = launch_shell_job( + 'echo', + arguments='return-bool', + parser=custom_parser + ) + print(results['output'].value) + +which should print + +.. code-block:: console + + 07/18/2024 03:49:32 PM <13555> aiida.parser.ShellParser: [WARNING] Arguments set to `return-bool`, returning a bool + True + .. tip:: If you find yourself reusing the same parser often, you can also register it with an entry point and use that for the ``parser`` input. diff --git a/src/aiida_shell/calculations/shell.py b/src/aiida_shell/calculations/shell.py index 498adca..d9348b3 100644 --- a/src/aiida_shell/calculations/shell.py +++ b/src/aiida_shell/calculations/shell.py @@ -11,11 +11,16 @@ from aiida.common.folders import Folder from aiida.engine import CalcJob, CalcJobProcessSpec from aiida.orm import Data, Dict, FolderData, List, RemoteData, SinglefileData, to_aiida_type +from aiida.parsers import Parser from aiida_shell.data import EntryPointData, PickledData __all__ = ('ShellJob',) +ParserFunctionType = t.Union[ + t.Callable[[pathlib.Path, Parser], dict[str, Data]], t.Callable[[pathlib.Path], dict[str, Data]] +] + class ShellJob(CalcJob): """Implementation of :class:`aiida.engine.CalcJob` to run a simple shell command.""" @@ -191,8 +196,11 @@ def validate_parser(cls, value: t.Any, _: t.Any) -> str | None: parameters = list(signature.parameters.keys()) - if any(required_parameter not in parameters for required_parameter in ('self', 'dirpath')): - correct_signature = '(self, 
dirpath: pathlib.Path) -> dict[str, Data]:' + if sorted(parameters) not in (['dirpath'], ['dirpath', 'parser']): + correct_signature = ( + '(dirpath: pathlib.Path) -> dict[str, Data]: or ' + '(dirpath: pathlib.Path, parser: Parser) -> dict[str, Data]:' + ) return f'The `parser` has an invalid function signature, it should be: {correct_signature}' return None diff --git a/src/aiida_shell/launch.py b/src/aiida_shell/launch.py index 4b7810f..99f91ce 100644 --- a/src/aiida_shell/launch.py +++ b/src/aiida_shell/launch.py @@ -10,10 +10,11 @@ from aiida.common import exceptions, lang from aiida.engine import Process, WorkChain, launch from aiida.orm import AbstractCode, Computer, Data, ProcessNode, SinglefileData, load_code, load_computer -from aiida.parsers import Parser from aiida_shell import ShellCode, ShellJob +from .calculations.shell import ParserFunctionType + __all__ = ('launch_shell_job',) LOGGER = logging.getLogger('aiida_shell') @@ -25,7 +26,7 @@ def launch_shell_job( # noqa: PLR0913 nodes: t.Mapping[str, str | pathlib.Path | Data] | None = None, filenames: dict[str, str] | None = None, outputs: list[str] | None = None, - parser: t.Callable[[Parser, pathlib.Path], dict[str, Data]] | str | None = None, + parser: ParserFunctionType | str | None = None, metadata: dict[str, t.Any] | None = None, submit: bool = False, resolve_command: bool = True, diff --git a/src/aiida_shell/parsers/shell.py b/src/aiida_shell/parsers/shell.py index ef64640..4b701e7 100644 --- a/src/aiida_shell/parsers/shell.py +++ b/src/aiida_shell/parsers/shell.py @@ -138,8 +138,15 @@ def parse_custom_outputs(self, dirpath: pathlib.Path) -> list[str]: def call_parser_hook(self, dirpath: pathlib.Path) -> None: """Execute the ``parser`` custom parser hook that was passed as input to the ``ShellJob``.""" + from inspect import signature + unpickled_parser = self.node.inputs.parser.load() - results = unpickled_parser(self, dirpath) or {} + parser_signature = signature(unpickled_parser) + + if 'parser' 
in parser_signature.parameters: + results = unpickled_parser(dirpath=dirpath, parser=self) or {}  # keywords: order is not validated + else: + results = unpickled_parser(dirpath) or {} if not isinstance(results, dict) or any(not isinstance(value, Data) for value in results.values()): raise TypeError(f'{unpickled_parser} did not return a dictionary of `Data` nodes but: {results}') diff --git a/tests/calculations/test_shell.py b/tests/calculations/test_shell.py index 09d45c6..a65aa1c 100644 --- a/tests/calculations/test_shell.py +++ b/tests/calculations/test_shell.py @@ -10,10 +10,14 @@ from aiida_shell.data import EntryPointData, PickledData -def custom_parser(self, dirpath): +def custom_parser(dirpath): """Implement a custom parser to test the ``parser`` input for a ``ShellJob``.""" +def custom_parser_with_parser_argument(dirpath, parser): + """Implement a custom parser that defines the optional ``parser`` argument.""" + + def test_code(generate_calc_job, generate_code): """Test the ``code`` input.""" code = generate_code() @@ -347,6 +351,16 @@ def test_parser(generate_calc_job, generate_code): assert isinstance(process.inputs.parser, PickledData) +def test_parser_with_parser_argument(generate_calc_job, generate_code): + """Test the ``parser`` input for a callable that defines the optional ``parser`` argument.""" + process = generate_calc_job( + 'core.shell', + inputs={'code': generate_code(), 'parser': custom_parser_with_parser_argument}, + return_process=True, + ) + assert isinstance(process.inputs.parser, PickledData) + + def test_parser_entry_point(generate_calc_job, generate_code, entry_points): """Test the ``parser`` serialization and validation when input is an entry point.""" entry_point_name = 'aiida.parsers:shell.parser' @@ -374,7 +388,7 @@ def test_parser_over_daemon(generate_code, submit_and_await): """Test submitting a ``ShellJob`` with a custom parser over the daemon.""" value = 'testing' - def parser(self, dirpath): + def parser(dirpath): from aiida.orm import Str return {'string': Str((dirpath / 'stdout').read_text().strip())} diff --git
a/tests/test_launch.py b/tests/test_launch.py index c830fd9..c9da67b 100644 --- a/tests/test_launch.py +++ b/tests/test_launch.py @@ -256,7 +256,7 @@ def test_resolve_command(aiida_profile, resolve_command, executable): def test_parser(): """Test the ``parser`` argument.""" - def parser(self, dirpath): + def parser(dirpath): from aiida.orm import Str return {'string': Str((dirpath / 'stdout').read_text().strip())} @@ -276,7 +276,7 @@ def test_parser_non_stdout(): """ filename = 'results.json' - def parser(self, dirpath): + def parser(dirpath): import json from aiida.orm import Dict @@ -296,6 +296,22 @@ def parser(self, dirpath): assert results['json'] == dictionary +def test_parser_with_parser_argument(): + """Test the ``parser`` argument for callable that specifies optional ``parser`` argument.""" + + def parser(dirpath, parser): + from aiida.orm import Str + + return {'arguments': Str(parser.node.inputs.arguments[0])} + + value = 'test_string' + arguments = [value] + results, node = launch_shell_job('echo', arguments=arguments, parser=parser) + + assert node.is_finished_ok + assert results['arguments'] == value + + @pytest.mark.parametrize('scheduler_type', ('core.direct', 'core.sge')) def test_preexisting_localhost_no_default_mpiprocs_per_machine( aiida_profile, generate_computer, scheduler_type, caplog