diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index be5aaa3cc9..4c9f688f07 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -110,7 +110,6 @@ repos: aiida/engine/processes/calcjobs/monitors.py| aiida/engine/processes/calcjobs/tasks.py| aiida/engine/processes/control.py| - aiida/engine/processes/functions.py| aiida/engine/processes/ports.py| aiida/manage/configuration/__init__.py| aiida/manage/configuration/config.py| diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index 051b5252b1..23d8a87522 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -15,15 +15,41 @@ import inspect import logging import signal -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Tuple, Type, TypeVar +import types +import typing as t +from typing import TYPE_CHECKING from aiida.common.lang import override from aiida.manage import get_manager -from aiida.orm import CalcFunctionNode, Data, ProcessNode, WorkFunctionNode, to_aiida_type +from aiida.orm import ( + Bool, + CalcFunctionNode, + Data, + Dict, + Float, + Int, + List, + ProcessNode, + Str, + WorkFunctionNode, + to_aiida_type, +) from aiida.orm.utils.mixins import FunctionCalculationMixin from .process import Process +try: + UnionType = types.UnionType # type: ignore[attr-defined] +except AttributeError: + # This type is not available for Python 3.9 and older + UnionType = None # pylint: disable=invalid-name + +try: + get_annotations = inspect.get_annotations # type: ignore[attr-defined] +except AttributeError: + # This is the backport for Python 3.9 and older + from get_annotations import get_annotations # type: ignore[no-redef] + if TYPE_CHECKING: from .exit_code import ExitCode @@ -31,7 +57,7 @@ LOGGER = logging.getLogger(__name__) -FunctionType = TypeVar('FunctionType', bound=Callable[..., Any]) +FunctionType = t.TypeVar('FunctionType', bound=t.Callable[..., t.Any]) def calcfunction(function: FunctionType) -> FunctionType: @@ -88,14 +114,14 @@ def workfunction(function: FunctionType) -> FunctionType: return process_function(node_class=WorkFunctionNode)(function) -def process_function(node_class: Type['ProcessNode']) -> Callable[[Callable[..., Any]], Callable[..., Any]]: +def process_function(node_class: t.Type['ProcessNode']) -> t.Callable[[FunctionType], FunctionType]: """ The base function decorator to create a FunctionProcess out of a normal python function. :param node_class: the ORM class to be used as the Node record for the FunctionProcess """ - def decorator(function: Callable[..., Any]) -> Callable[..., Any]: + def decorator(function: FunctionType) -> FunctionType: """ Turn the decorated function into a FunctionProcess. @@ -104,7 +130,7 @@ def decorator(function: Callable[..., Any]) -> Callable[..., Any]: """ process_class = FunctionProcess.build(function, node_class=node_class) - def run_get_node(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], 'ProcessNode']: + def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode']: """ Run the FunctionProcess with the supplied inputs in a local runner. @@ -159,7 +185,7 @@ def kill_process(_num, _frame): return result, process.node - def run_get_pk(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], int]: + def run_get_pk(*args, **kwargs) -> tuple[dict[str, t.Any] | None, int]: """Recreate the `run_get_pk` utility launcher. :param args: input arguments to construct the FunctionProcess @@ -185,15 +211,52 @@ def decorated_function(*args, **kwargs): decorated_function.recreate_from = process_class.recreate_from # type: ignore[attr-defined] decorated_function.spec = process_class.spec # type: ignore[attr-defined] - return decorated_function + return decorated_function # type: ignore[return-value] return decorator +def infer_valid_type_from_type_annotation(annotation: t.Any) -> tuple[t.Any, ...]: + """Infer the value for the ``valid_type`` of an input port from the given function argument annotation. + + :param annotation: The annotation of a function argument as returned by ``inspect.get_annotation``. + :returns: A tuple of valid types. If no valid types were defined or they could not be successfully parsed, an empty + tuple is returned. + """ + + def get_type_from_annotation(annotation): + valid_type_map = { + bool: Bool, + dict: Dict, + t.Dict: Dict, + float: Float, + int: Int, + list: List, + t.List: List, + str: Str, + } + + if inspect.isclass(annotation) and issubclass(annotation, Data): + return annotation + + return valid_type_map.get(annotation) + + inferred_valid_type: tuple[t.Any, ...] = () + + if inspect.isclass(annotation): + inferred_valid_type = (get_type_from_annotation(annotation),) + elif t.get_origin(annotation) is t.Union or t.get_origin(annotation) is UnionType: + inferred_valid_type = tuple(get_type_from_annotation(valid_type) for valid_type in t.get_args(annotation)) + elif t.get_origin(annotation) is t.Optional: + inferred_valid_type = (t.get_args(annotation),) + + return tuple(valid_type for valid_type in inferred_valid_type if valid_type is not None) + + class FunctionProcess(Process): """Function process class used for turning functions into a Process""" - _func_args: Sequence[str] = () + _func_args: t.Sequence[str] = () _varargs: str | None = None @staticmethod @@ -205,7 +268,7 @@ def _func(*_args, **_kwargs) -> dict: return {} @staticmethod - def build(func: Callable[..., Any], node_class: Type['ProcessNode']) -> Type['FunctionProcess']: + def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['FunctionProcess']: """ Build a Process from the given function. @@ -222,10 +285,30 @@ def build(func: Callable[..., Any], node_class: Type['ProcessNode']) -> Type['Fu if not issubclass(node_class, ProcessNode) or not issubclass(node_class, FunctionCalculationMixin): raise TypeError('the node_class should be a sub class of `ProcessNode` and `FunctionCalculationMixin`') - args, varargs, keywords, defaults, _, _, _ = inspect.getfullargspec(func) - nargs = len(args) - ndefaults = len(defaults) if defaults else 0 - first_default_pos = nargs - ndefaults + signature = inspect.signature(func) + + args: list[str] = [] + varargs: str | None = None + keywords: str | None = None + + try: + annotations = get_annotations(func, eval_str=True) + except Exception as exception: # pylint: disable=broad-except + # Since we are running with ``eval_str=True`` to unstringize the annotations, the call can except if the + # annotations are incorrect. In this case we simply want to log a warning and continue with type inference. + LOGGER.warning(f'function `{func.__name__}` has invalid type hints: {exception}') + annotations = {} + + for key, parameter in signature.parameters.items(): + + if parameter.kind in [parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD, parameter.KEYWORD_ONLY]: + args.append(key) + + if parameter.kind is parameter.VAR_POSITIONAL: + varargs = key + + if parameter.kind is parameter.VAR_KEYWORD: + varargs = key def _define(cls, spec): # pylint: disable=unused-argument """Define the spec dynamically""" @@ -233,37 +316,39 @@ def _define(cls, spec): # pylint: disable=unused-argument super().define(spec) - for i, arg in enumerate(args): + for parameter in signature.parameters.values(): + + if parameter.kind in [parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD]: + continue - default = UNSPECIFIED + annotation = annotations.get(parameter.name) + valid_type = infer_valid_type_from_type_annotation(annotation) or (Data,) - if defaults and i >= first_default_pos: - default = defaults[i - first_default_pos] + default = parameter.default if parameter.default is not parameter.empty else UNSPECIFIED # If the keyword was already specified, simply override the default - if spec.has_input(arg): - spec.inputs[arg].default = default + if spec.has_input(parameter.name): + spec.inputs[parameter.name].default = default + continue + + # If the default is ``None`` make sure that the port also accepts a ``NoneType``. Note that we cannot + # use ``None`` because the validation will call ``isinstance`` which does not work when passing ``None`` + # but it does work with ``NoneType`` which is returned by calling ``type(None)``. + if default is None: + valid_type += (type(None),) + + # If a default is defined and it is not a ``Data`` instance it should be serialized, but this should be + # done lazily using a lambda, just as any port defaults should not define node instances directly as is + # also checked by the ``spec.input`` call. + if ( + default is not None and default != UNSPECIFIED and not isinstance(default, Data) and + not callable(default) + ): + indirect_default = lambda value=default: to_aiida_type(value) else: - # If the default is `None` make sure that the port also accepts a `NoneType` - # Note that we cannot use `None` because the validation will call `isinstance` which does not work - # when passing `None`, but it does work with `NoneType` which is returned by calling `type(None)` - if default is None: - valid_type = (Data, type(None)) - else: - valid_type = (Data,) - - # If a default is defined and it is not a ``Data`` instance it should be serialized, but this should - # be done lazily using a lambda, just as any port defaults should not define node instances directly - # as is also checked by the ``spec.input`` call. - if ( - default is not None and default != UNSPECIFIED and not isinstance(default, Data) and - not callable(default) - ): - indirect_default = lambda value=default: to_aiida_type(value) - else: - indirect_default = default # type: ignore[assignment] - - spec.input(arg, valid_type=valid_type, default=indirect_default, serializer=to_aiida_type) + indirect_default = default + + spec.input(parameter.name, valid_type=valid_type, default=indirect_default, serializer=to_aiida_type) # Set defaults for label and description based on function name and docstring, if not explicitly defined port_label = spec.inputs['metadata']['label'] @@ -293,7 +378,7 @@ def _define(cls, spec): # pylint: disable=unused-argument ) @classmethod - def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused-argument + def validate_inputs(cls, *args: t.Any, **kwargs: t.Any) -> None: # pylint: disable=unused-argument """ Validate the positional and keyword arguments passed in the function call. @@ -314,7 +399,7 @@ def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable= raise TypeError(f'{name}() takes {nparameters} positional arguments but {nargs} were given') @classmethod - def create_inputs(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: + def create_inputs(cls, *args: t.Any, **kwargs: t.Any) -> dict[str, t.Any]: """Create the input args for the FunctionProcess.""" cls.validate_inputs(*args, **kwargs) @@ -326,7 +411,7 @@ def create_inputs(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: return ins @classmethod - def args_to_dict(cls, *args: Any) -> Dict[str, Any]: + def args_to_dict(cls, *args: t.Any) -> dict[str, t.Any]: """ Create an input dictionary (of form label -> value) from supplied args. @@ -375,7 +460,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(enable_persistence=False, *args, **kwargs) # type: ignore @property - def process_class(self) -> Callable[..., Any]: + def process_class(self) -> t.Callable[..., t.Any]: """ Return the class that represents this Process, for the FunctionProcess this is the function itself. @@ -388,7 +473,7 @@ class that really represents what was being executed. """ return self._func - def execute(self) -> Optional[Dict[str, Any]]: + def execute(self) -> dict[str, t.Any] | None: """Execute the process.""" result = super().execute() @@ -405,7 +490,7 @@ def _setup_db_record(self) -> None: self.node.store_source_info(self._func) @override - def run(self) -> Optional['ExitCode']: + def run(self) -> 'ExitCode' | None: """Run the process.""" from .exit_code import ExitCode @@ -414,7 +499,7 @@ def run(self) -> Optional['ExitCode']: # been overridden by the engine to `Running` so we cannot check that, but if the `exit_status` is anything other # than `None`, it should mean this node was taken from the cache, so the process should not be rerun. if self.node.exit_status is not None: - return self.node.exit_status + return ExitCode(self.node.exit_status, self.node.exit_message) # Split the inputs into positional and keyword arguments args = [None] * len(self._func_args) diff --git a/docs/source/topics/calculations/concepts.rst b/docs/source/topics/calculations/concepts.rst index c6a00a9fea..afb178e9ae 100644 --- a/docs/source/topics/calculations/concepts.rst +++ b/docs/source/topics/calculations/concepts.rst @@ -55,6 +55,8 @@ To solve this, one only has to wrap them in the :py:class:`~aiida.orm.nodes.data The only difference with the previous snippet is that all inputs have been wrapped in the :py:class:`~aiida.orm.nodes.data.int.Int` class. The result that is returned by the function, is now also an :py:class:`~aiida.orm.nodes.data.int.Int` node that can be stored in the provenance graph, and contains the result of the computation. +.. _topics:calculations:concepts:calcfunctions:automatic-serialization: + .. versionadded:: 2.1 If a function argument is a Python base type (i.e. a value of type ``bool``, ``dict``, ``Enum``, ``float``, ``int``, ``list`` or ``str``), it can be passed straight away to the function, without first having to wrap it in the corresponding AiiDA data type. diff --git a/docs/source/topics/processes/functions.rst b/docs/source/topics/processes/functions.rst index 23b067f9e9..6a9d6ae8c0 100644 --- a/docs/source/topics/processes/functions.rst +++ b/docs/source/topics/processes/functions.rst @@ -116,6 +116,57 @@ The link labels for the example above will therefore be ``args_0``, ``args_1`` a If any of these labels were to overlap with the label of a positional or keyword argument, a ``RuntimeError`` will be raised. In this case, the conflicting argument name needs to be changed to something that does not overlap with the automatically generated labels for the variadic arguments. +Type validation +=============== + +.. versionadded:: 2.3 + +Type hints (introduced with `PEP 484 `_ in Python 3.5) can be used to add automatic type validation of process function arguments. +For example, the following will raise a ``ValueError`` exception: + +.. include:: include/snippets/functions/typing_call_raise.py + :code: python + +When the process function is declared, the process specification (``ProcessSpec``) is built dynamically. +For each function argument, if a correct type hint is provided, it is set as the ``valid_type`` attribute of the corresponding input port. +In the example above, the ``x`` and ``y`` inputs have ``Int`` as type hint, which is why the call that passes a ``Float`` raises a ``ValueError``. + +.. note:: + + Type hints for return values are currently not parsed. + +If an argument accepts multiple types, the ``typing.Union`` class can be used as normal: + +.. include:: include/snippets/functions/typing_union.py + :code: python + +The call with an ``Int`` and a ``Float`` will now finish correctly. +Similarly, optional arguments, with ``None`` as a default, can be declared using ``typing.Optional``: + +.. include:: include/snippets/functions/typing_none.py + :code: python + +The `postponed evaluation of annotations introduced by PEP 563 `_ is also supported. +This means it is possible to use Python base types for the type hint instead of AiiDA's ``Data`` node equivalent: + +.. include:: include/snippets/functions/typing_pep_563.py + :code: python + +The type hints are automatically serialized just as the actual inputs are when the function is called, :ref:`as introduced in v2.1`. + +The alternative syntax for union types ``X | Y`` `as introduced by PEP 604 `_ is also supported: + +.. include:: include/snippets/functions/typing_pep_604.py + :code: python + +.. warning:: + + The usage of notation as defined by PEP 563 and PEP 604 are not supported for Python versions older than 3.10, even if the ``from __future__ import annotations`` statement is added. + The reason is that the type inference uses the `inspect.get_annotations `_ method, which was introduced in Python 3.10. + For older Python versions, the `get-annotations `_ backport is used, but that does not work with PEP 563 and PEP 604, so the constructs from the ``typing`` module have to be used instead. + +If a process function has invalid type hints, they will simply be ignored and a warning message is logged: ``function 'function_name' has invalid type hints``. +This ensures backwards compatibility in the case existing process functions had invalid type hints. Return values ============= diff --git a/docs/source/topics/processes/include/snippets/functions/typing_call_raise.py b/docs/source/topics/processes/include/snippets/functions/typing_call_raise.py new file mode 100644 index 0000000000..c62cf149e8 --- /dev/null +++ b/docs/source/topics/processes/include/snippets/functions/typing_call_raise.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +from aiida.engine import calcfunction +from aiida.orm import Float, Int + + +@calcfunction +def add(x: Int, y: Int): + return x + y + +add(Int(1), Float(1.0)) diff --git a/docs/source/topics/processes/include/snippets/functions/typing_none.py b/docs/source/topics/processes/include/snippets/functions/typing_none.py new file mode 100644 index 0000000000..00405051f9 --- /dev/null +++ b/docs/source/topics/processes/include/snippets/functions/typing_none.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +import typing as t + +from aiida.engine import calcfunction +from aiida.orm import Int + + +@calcfunction +def add_multiply(x: Int, y: Int, z: typing.Optional[Int] = None): + if z is None: + z = Int(3) + + return (x + y) * z + +result = add_multiply(Int(1), Int(2)) +result = add_multiply(Int(1), Int(2), Int(3)) diff --git a/docs/source/topics/processes/include/snippets/functions/typing_pep_563.py b/docs/source/topics/processes/include/snippets/functions/typing_pep_563.py new file mode 100644 index 0000000000..cc52ca3ddf --- /dev/null +++ b/docs/source/topics/processes/include/snippets/functions/typing_pep_563.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +from aiida.engine import calcfunction + + +@calcfunction +def add(x: int, y: int): + return x + y + +add(1, 2) diff --git a/docs/source/topics/processes/include/snippets/functions/typing_pep_604.py b/docs/source/topics/processes/include/snippets/functions/typing_pep_604.py new file mode 100644 index 0000000000..64f9c30290 --- /dev/null +++ b/docs/source/topics/processes/include/snippets/functions/typing_pep_604.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +from aiida.engine import calcfunction +from aiida.orm import Int + + +@calcfunction +def add_multiply(x: int, y: int, z: int | None = None): + if z is None: + z = Int(3) + + return (x + y) * z + +result = add_multiply(1, 2) +result = add_multiply(1, 2, 3) diff --git a/docs/source/topics/processes/include/snippets/functions/typing_union.py b/docs/source/topics/processes/include/snippets/functions/typing_union.py new file mode 100644 index 0000000000..3811cad719 --- /dev/null +++ b/docs/source/topics/processes/include/snippets/functions/typing_union.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +import typing as t + +from aiida.engine import calcfunction +from aiida.orm import Float, Int + + +@calcfunction +def add(x: t.Union[Int, Float], y: t.Union[Int, Float]): + return x + y + +add(Int(1), Float(1.0)) diff --git a/environment.yml b/environment.yml index cffb828ef4..50a4acde10 100644 --- a/environment.yml +++ b/environment.yml @@ -14,6 +14,7 @@ dependencies: - click-spinner~=0.1.8 - click~=8.1 - disk-objectstore~=0.6.0 +- get-annotations~=0.1 - python-graphviz~=0.13 - ipython<9,>=7 - jinja2~=3.0 diff --git a/pyproject.toml b/pyproject.toml index e387fc424e..4e05598fe1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "click-spinner~=0.1.8", "click~=8.1", "disk-objectstore~=0.6.0", + "get-annotations~=0.1;python_version<'3.10'", "graphviz~=0.13", "ipython>=7,<9", "jinja2~=3.0", @@ -393,6 +394,7 @@ module = [ 'docutils.*', 'flask_cors.*', 'flask_restful.*', + 'get_annotations.*', 'graphviz.*', 'importlib._bootstrap.*', 'IPython.*', diff --git a/requirements/requirements-py-3.8.txt b/requirements/requirements-py-3.8.txt index 2f71868b15..b3be3fadbc 100644 --- a/requirements/requirements-py-3.8.txt +++ b/requirements/requirements-py-3.8.txt @@ -38,6 +38,7 @@ Flask-Cors==3.0.10 Flask-RESTful==0.3.9 fonttools==4.28.2 future==0.18.3 +get-annotations==0.1.2 graphviz==0.19 greenlet==1.1.2 idna==3.3 diff --git a/requirements/requirements-py-3.9.txt b/requirements/requirements-py-3.9.txt index 95162db0a7..f5d0477403 100644 --- a/requirements/requirements-py-3.9.txt +++ b/requirements/requirements-py-3.9.txt @@ -38,6 +38,7 @@ Flask-Cors==3.0.10 Flask-RESTful==0.3.9 fonttools==4.28.2 future==0.18.3 +get-annotations==0.1.2 graphviz==0.19 greenlet==1.1.2 idna==3.3 diff --git a/tests/engine/test_process_function.py b/tests/engine/test_process_function.py index 9bf6d90fc3..11767c4783 100644 --- a/tests/engine/test_process_function.py +++ b/tests/engine/test_process_function.py @@ -16,8 +16,12 @@ fly, but then anytime inputs or outputs would be attached to it in the tests, the ``validate_link`` function would complain as the dummy node class is not recognized as a valid process node. """ +from __future__ import annotations + import enum import re +import sys +import typing as t import pytest @@ -616,3 +620,96 @@ def function(**kwargs): } with pytest.raises(ValueError): function.run_get_node(**inputs) + + +def test_type_hinting_spec_inference(): + """Test the parsing of type hinting to define the valid types of the dynamically generated input ports.""" + + @calcfunction # type: ignore[misc] + def function( + a, + b: str, + c: bool, + d: orm.Str, + e: t.Union[orm.Str, orm.Int], + f: t.Union[str, int], + g: t.Optional[t.Dict] = None, + ): + # pylint: disable=invalid-name,unused-argument + pass + + input_namespace = function.spec().inputs + + expected = ( + ('a', (orm.Data,)), + ('b', (orm.Str,)), + ('c', (orm.Bool,)), + ('d', (orm.Str,)), + ('e', (orm.Str, orm.Int)), + ('f', (orm.Str, orm.Int)), + ('g', (orm.Dict, type(None))), + ) + + for key, valid_types in expected: + assert key in input_namespace + assert input_namespace[key].valid_type == valid_types, key + + +def test_type_hinting_spec_inference_pep_604(aiida_caplog): + """Test the parsing of type hinting that uses union typing of PEP 604 which is only available to Python 3.10 and up. + + Even though adding ``from __future__ import annotations`` should backport this functionality to Python 3.9 and older + the ``get_annotations`` method (which was also added to the ``inspect`` package in Python 3.10) as provided by the + ``get-annotations`` backport package fails for this new syntax when called with ``eval_str=True``. Therefore type + inference using this syntax only works on Python 3.10 and up. + + See https://peps.python.org/pep-0604 + """ + + @calcfunction # type: ignore[misc] + def function( + a: str | int, + b: orm.Str | orm.Int, + c: dict | None = None, + ): + # pylint: disable=invalid-name,unused-argument + pass + + input_namespace = function.spec().inputs + + # Since the PEP 604 union syntax is only available starting from Python 3.10 the type inference will not be + # available for older versions, and so the valid type will be the default ``(orm.Data,)``. + if sys.version_info[:2] >= (3, 10): + expected = ( + ('a', (orm.Str, orm.Int)), + ('b', (orm.Str, orm.Int)), + ('c', (orm.Dict, type(None))), + ) + else: + assert 'function `function` has invalid type hints: unsupported operand type' in aiida_caplog.records[0].message + expected = ( + ('a', (orm.Data,)), + ('b', (orm.Data,)), + ('c', (orm.Data, type(None))), + ) + + for key, valid_types in expected: + assert key in input_namespace + assert input_namespace[key].valid_type == valid_types, key + + +def test_type_hinting_validation(): + """Test that type hints are converted to automatic type checking through the process specification.""" + + @calcfunction # type: ignore[misc] + def function_type_hinting(a: t.Union[int, float]): + # pylint: disable=invalid-name + return a + 1 + + with pytest.raises(ValueError, match=r'.*value \'a\' is not of the right type.*'): + function_type_hinting('string') + + assert function_type_hinting(1) == 2 + assert function_type_hinting(orm.Int(1)) == 2 + assert function_type_hinting(1.0) == 2.0 + assert function_type_hinting(orm.Float(1)) == 2.0