diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index be5aaa3cc9..4c9f688f07 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -110,7 +110,6 @@ repos:
aiida/engine/processes/calcjobs/monitors.py|
aiida/engine/processes/calcjobs/tasks.py|
aiida/engine/processes/control.py|
- aiida/engine/processes/functions.py|
aiida/engine/processes/ports.py|
aiida/manage/configuration/__init__.py|
aiida/manage/configuration/config.py|
diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py
index 051b5252b1..23d8a87522 100644
--- a/aiida/engine/processes/functions.py
+++ b/aiida/engine/processes/functions.py
@@ -15,15 +15,41 @@
import inspect
import logging
import signal
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Tuple, Type, TypeVar
+import types
+import typing as t
+from typing import TYPE_CHECKING
from aiida.common.lang import override
from aiida.manage import get_manager
-from aiida.orm import CalcFunctionNode, Data, ProcessNode, WorkFunctionNode, to_aiida_type
+from aiida.orm import (
+ Bool,
+ CalcFunctionNode,
+ Data,
+ Dict,
+ Float,
+ Int,
+ List,
+ ProcessNode,
+ Str,
+ WorkFunctionNode,
+ to_aiida_type,
+)
from aiida.orm.utils.mixins import FunctionCalculationMixin
from .process import Process
+try:
+ UnionType = types.UnionType # type: ignore[attr-defined]
+except AttributeError:
+ # This type is not available for Python 3.9 and older
+ UnionType = None # pylint: disable=invalid-name
+
+try:
+ get_annotations = inspect.get_annotations # type: ignore[attr-defined]
+except AttributeError:
+ # This is the backport for Python 3.9 and older
+ from get_annotations import get_annotations # type: ignore[no-redef]
+
if TYPE_CHECKING:
from .exit_code import ExitCode
@@ -31,7 +57,7 @@
LOGGER = logging.getLogger(__name__)
-FunctionType = TypeVar('FunctionType', bound=Callable[..., Any])
+FunctionType = t.TypeVar('FunctionType', bound=t.Callable[..., t.Any])
def calcfunction(function: FunctionType) -> FunctionType:
@@ -88,14 +114,14 @@ def workfunction(function: FunctionType) -> FunctionType:
return process_function(node_class=WorkFunctionNode)(function)
-def process_function(node_class: Type['ProcessNode']) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
+def process_function(node_class: t.Type['ProcessNode']) -> t.Callable[[FunctionType], FunctionType]:
"""
The base function decorator to create a FunctionProcess out of a normal python function.
:param node_class: the ORM class to be used as the Node record for the FunctionProcess
"""
- def decorator(function: Callable[..., Any]) -> Callable[..., Any]:
+ def decorator(function: FunctionType) -> FunctionType:
"""
Turn the decorated function into a FunctionProcess.
@@ -104,7 +130,7 @@ def decorator(function: Callable[..., Any]) -> Callable[..., Any]:
"""
process_class = FunctionProcess.build(function, node_class=node_class)
- def run_get_node(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], 'ProcessNode']:
+ def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode']:
"""
Run the FunctionProcess with the supplied inputs in a local runner.
@@ -159,7 +185,7 @@ def kill_process(_num, _frame):
return result, process.node
- def run_get_pk(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], int]:
+ def run_get_pk(*args, **kwargs) -> tuple[dict[str, t.Any] | None, int]:
"""Recreate the `run_get_pk` utility launcher.
:param args: input arguments to construct the FunctionProcess
@@ -185,15 +211,52 @@ def decorated_function(*args, **kwargs):
decorated_function.recreate_from = process_class.recreate_from # type: ignore[attr-defined]
decorated_function.spec = process_class.spec # type: ignore[attr-defined]
- return decorated_function
+ return decorated_function # type: ignore[return-value]
return decorator
+def infer_valid_type_from_type_annotation(annotation: t.Any) -> tuple[t.Any, ...]:
+ """Infer the value for the ``valid_type`` of an input port from the given function argument annotation.
+
+ :param annotation: The annotation of a function argument as returned by ``inspect.get_annotation``.
+ :returns: A tuple of valid types. If no valid types were defined or they could not be successfully parsed, an empty
+ tuple is returned.
+ """
+
+ def get_type_from_annotation(annotation):
+ valid_type_map = {
+ bool: Bool,
+ dict: Dict,
+ t.Dict: Dict,
+ float: Float,
+ int: Int,
+ list: List,
+ t.List: List,
+ str: Str,
+ }
+
+ if inspect.isclass(annotation) and issubclass(annotation, Data):
+ return annotation
+
+ return valid_type_map.get(annotation)
+
+ inferred_valid_type: tuple[t.Any, ...] = ()
+
+ if inspect.isclass(annotation):
+ inferred_valid_type = (get_type_from_annotation(annotation),)
+ elif t.get_origin(annotation) is t.Union or t.get_origin(annotation) is UnionType:
+ inferred_valid_type = tuple(get_type_from_annotation(valid_type) for valid_type in t.get_args(annotation))
+ elif t.get_origin(annotation) is t.Optional:
+ inferred_valid_type = (t.get_args(annotation),)
+
+ return tuple(valid_type for valid_type in inferred_valid_type if valid_type is not None)
+
+
class FunctionProcess(Process):
"""Function process class used for turning functions into a Process"""
- _func_args: Sequence[str] = ()
+ _func_args: t.Sequence[str] = ()
_varargs: str | None = None
@staticmethod
@@ -205,7 +268,7 @@ def _func(*_args, **_kwargs) -> dict:
return {}
@staticmethod
- def build(func: Callable[..., Any], node_class: Type['ProcessNode']) -> Type['FunctionProcess']:
+ def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['FunctionProcess']:
"""
Build a Process from the given function.
@@ -222,10 +285,30 @@ def build(func: Callable[..., Any], node_class: Type['ProcessNode']) -> Type['Fu
if not issubclass(node_class, ProcessNode) or not issubclass(node_class, FunctionCalculationMixin):
raise TypeError('the node_class should be a sub class of `ProcessNode` and `FunctionCalculationMixin`')
- args, varargs, keywords, defaults, _, _, _ = inspect.getfullargspec(func)
- nargs = len(args)
- ndefaults = len(defaults) if defaults else 0
- first_default_pos = nargs - ndefaults
+ signature = inspect.signature(func)
+
+ args: list[str] = []
+ varargs: str | None = None
+ keywords: str | None = None
+
+ try:
+ annotations = get_annotations(func, eval_str=True)
+ except Exception as exception: # pylint: disable=broad-except
+ # Since we are running with ``eval_str=True`` to unstringize the annotations, the call can except if the
+ # annotations are incorrect. In this case we simply want to log a warning and continue with type inference.
+ LOGGER.warning(f'function `{func.__name__}` has invalid type hints: {exception}')
+ annotations = {}
+
+ for key, parameter in signature.parameters.items():
+
+ if parameter.kind in [parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD, parameter.KEYWORD_ONLY]:
+ args.append(key)
+
+ if parameter.kind is parameter.VAR_POSITIONAL:
+ varargs = key
+
+ if parameter.kind is parameter.VAR_KEYWORD:
+ varargs = key
def _define(cls, spec): # pylint: disable=unused-argument
"""Define the spec dynamically"""
@@ -233,37 +316,39 @@ def _define(cls, spec): # pylint: disable=unused-argument
super().define(spec)
- for i, arg in enumerate(args):
+ for parameter in signature.parameters.values():
+
+ if parameter.kind in [parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD]:
+ continue
- default = UNSPECIFIED
+ annotation = annotations.get(parameter.name)
+ valid_type = infer_valid_type_from_type_annotation(annotation) or (Data,)
- if defaults and i >= first_default_pos:
- default = defaults[i - first_default_pos]
+ default = parameter.default if parameter.default is not parameter.empty else UNSPECIFIED
# If the keyword was already specified, simply override the default
- if spec.has_input(arg):
- spec.inputs[arg].default = default
+ if spec.has_input(parameter.name):
+ spec.inputs[parameter.name].default = default
+ continue
+
+ # If the default is ``None`` make sure that the port also accepts a ``NoneType``. Note that we cannot
+ # use ``None`` because the validation will call ``isinstance`` which does not work when passing ``None``
+ # but it does work with ``NoneType`` which is returned by calling ``type(None)``.
+ if default is None:
+ valid_type += (type(None),)
+
+ # If a default is defined and it is not a ``Data`` instance it should be serialized, but this should be
+ # done lazily using a lambda, just as any port defaults should not define node instances directly as is
+ # also checked by the ``spec.input`` call.
+ if (
+ default is not None and default != UNSPECIFIED and not isinstance(default, Data) and
+ not callable(default)
+ ):
+ indirect_default = lambda value=default: to_aiida_type(value)
else:
- # If the default is `None` make sure that the port also accepts a `NoneType`
- # Note that we cannot use `None` because the validation will call `isinstance` which does not work
- # when passing `None`, but it does work with `NoneType` which is returned by calling `type(None)`
- if default is None:
- valid_type = (Data, type(None))
- else:
- valid_type = (Data,)
-
- # If a default is defined and it is not a ``Data`` instance it should be serialized, but this should
- # be done lazily using a lambda, just as any port defaults should not define node instances directly
- # as is also checked by the ``spec.input`` call.
- if (
- default is not None and default != UNSPECIFIED and not isinstance(default, Data) and
- not callable(default)
- ):
- indirect_default = lambda value=default: to_aiida_type(value)
- else:
- indirect_default = default # type: ignore[assignment]
-
- spec.input(arg, valid_type=valid_type, default=indirect_default, serializer=to_aiida_type)
+ indirect_default = default
+
+ spec.input(parameter.name, valid_type=valid_type, default=indirect_default, serializer=to_aiida_type)
# Set defaults for label and description based on function name and docstring, if not explicitly defined
port_label = spec.inputs['metadata']['label']
@@ -293,7 +378,7 @@ def _define(cls, spec): # pylint: disable=unused-argument
)
@classmethod
- def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused-argument
+ def validate_inputs(cls, *args: t.Any, **kwargs: t.Any) -> None: # pylint: disable=unused-argument
"""
Validate the positional and keyword arguments passed in the function call.
@@ -314,7 +399,7 @@ def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable=
raise TypeError(f'{name}() takes {nparameters} positional arguments but {nargs} were given')
@classmethod
- def create_inputs(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+ def create_inputs(cls, *args: t.Any, **kwargs: t.Any) -> dict[str, t.Any]:
"""Create the input args for the FunctionProcess."""
cls.validate_inputs(*args, **kwargs)
@@ -326,7 +411,7 @@ def create_inputs(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
return ins
@classmethod
- def args_to_dict(cls, *args: Any) -> Dict[str, Any]:
+ def args_to_dict(cls, *args: t.Any) -> dict[str, t.Any]:
"""
Create an input dictionary (of form label -> value) from supplied args.
@@ -375,7 +460,7 @@ def __init__(self, *args, **kwargs) -> None:
super().__init__(enable_persistence=False, *args, **kwargs) # type: ignore
@property
- def process_class(self) -> Callable[..., Any]:
+ def process_class(self) -> t.Callable[..., t.Any]:
"""
Return the class that represents this Process, for the FunctionProcess this is the function itself.
@@ -388,7 +473,7 @@ class that really represents what was being executed.
"""
return self._func
- def execute(self) -> Optional[Dict[str, Any]]:
+ def execute(self) -> dict[str, t.Any] | None:
"""Execute the process."""
result = super().execute()
@@ -405,7 +490,7 @@ def _setup_db_record(self) -> None:
self.node.store_source_info(self._func)
@override
- def run(self) -> Optional['ExitCode']:
+ def run(self) -> 'ExitCode' | None:
"""Run the process."""
from .exit_code import ExitCode
@@ -414,7 +499,7 @@ def run(self) -> Optional['ExitCode']:
# been overridden by the engine to `Running` so we cannot check that, but if the `exit_status` is anything other
# than `None`, it should mean this node was taken from the cache, so the process should not be rerun.
if self.node.exit_status is not None:
- return self.node.exit_status
+ return ExitCode(self.node.exit_status, self.node.exit_message)
# Split the inputs into positional and keyword arguments
args = [None] * len(self._func_args)
diff --git a/docs/source/topics/calculations/concepts.rst b/docs/source/topics/calculations/concepts.rst
index c6a00a9fea..afb178e9ae 100644
--- a/docs/source/topics/calculations/concepts.rst
+++ b/docs/source/topics/calculations/concepts.rst
@@ -55,6 +55,8 @@ To solve this, one only has to wrap them in the :py:class:`~aiida.orm.nodes.data
The only difference with the previous snippet is that all inputs have been wrapped in the :py:class:`~aiida.orm.nodes.data.int.Int` class.
The result that is returned by the function, is now also an :py:class:`~aiida.orm.nodes.data.int.Int` node that can be stored in the provenance graph, and contains the result of the computation.
+.. _topics:calculations:concepts:calcfunctions:automatic-serialization:
+
.. versionadded:: 2.1
If a function argument is a Python base type (i.e. a value of type ``bool``, ``dict``, ``Enum``, ``float``, ``int``, ``list`` or ``str``), it can be passed straight away to the function, without first having to wrap it in the corresponding AiiDA data type.
diff --git a/docs/source/topics/processes/functions.rst b/docs/source/topics/processes/functions.rst
index 23b067f9e9..6a9d6ae8c0 100644
--- a/docs/source/topics/processes/functions.rst
+++ b/docs/source/topics/processes/functions.rst
@@ -116,6 +116,57 @@ The link labels for the example above will therefore be ``args_0``, ``args_1`` a
If any of these labels were to overlap with the label of a positional or keyword argument, a ``RuntimeError`` will be raised.
In this case, the conflicting argument name needs to be changed to something that does not overlap with the automatically generated labels for the variadic arguments.
+Type validation
+===============
+
+.. versionadded:: 2.3
+
+Type hints (introduced with `PEP 484 `_ in Python 3.5) can be used to add automatic type validation of process function arguments.
+For example, the following will raise a ``ValueError`` exception:
+
+.. include:: include/snippets/functions/typing_call_raise.py
+ :code: python
+
+When the process function is declared, the process specification (``ProcessSpec``) is built dynamically.
+For each function argument, if a correct type hint is provided, it is set as the ``valid_type`` attribute of the corresponding input port.
+In the example above, the ``x`` and ``y`` inputs have ``Int`` as type hint, which is why the call that passes a ``Float`` raises a ``ValueError``.
+
+.. note::
+
+ Type hints for return values are currently not parsed.
+
+If an argument accepts multiple types, the ``typing.Union`` class can be used as normal:
+
+.. include:: include/snippets/functions/typing_union.py
+ :code: python
+
+The call with an ``Int`` and a ``Float`` will now finish correctly.
+Similarly, optional arguments, with ``None`` as a default, can be declared using ``typing.Optional``:
+
+.. include:: include/snippets/functions/typing_none.py
+ :code: python
+
+The `postponed evaluation of annotations introduced by PEP 563 `_ is also supported.
+This means it is possible to use Python base types for the type hint instead of AiiDA's ``Data`` node equivalent:
+
+.. include:: include/snippets/functions/typing_pep_563.py
+ :code: python
+
+The type hints are automatically serialized just as the actual inputs are when the function is called, :ref:`as introduced in v2.1`.
+
+The alternative syntax for union types ``X | Y`` `as introduced by PEP 604 `_ is also supported:
+
+.. include:: include/snippets/functions/typing_pep_604.py
+ :code: python
+
+.. warning::
+
+ The usage of notation as defined by PEP 563 and PEP 604 are not supported for Python versions older than 3.10, even if the ``from __future__ import annotations`` statement is added.
+ The reason is that the type inference uses the `inspect.get_annotations `_ method, which was introduced in Python 3.10.
+ For older Python versions, the `get-annotations `_ backport is used, but that does not work with PEP 563 and PEP 604, so the constructs from the ``typing`` module have to be used instead.
+
+If a process function has invalid type hints, they will simply be ignored and a warning message is logged: ``function 'function_name' has invalid type hints``.
+This ensures backwards compatibility in the case existing process functions had invalid type hints.
Return values
=============
diff --git a/docs/source/topics/processes/include/snippets/functions/typing_call_raise.py b/docs/source/topics/processes/include/snippets/functions/typing_call_raise.py
new file mode 100644
index 0000000000..c62cf149e8
--- /dev/null
+++ b/docs/source/topics/processes/include/snippets/functions/typing_call_raise.py
@@ -0,0 +1,10 @@
+# -*- coding: utf-8 -*-
+from aiida.engine import calcfunction
+from aiida.orm import Float, Int
+
+
+@calcfunction
+def add(x: Int, y: Int):
+ return x + y
+
+add(Int(1), Float(1.0))
diff --git a/docs/source/topics/processes/include/snippets/functions/typing_none.py b/docs/source/topics/processes/include/snippets/functions/typing_none.py
new file mode 100644
index 0000000000..00405051f9
--- /dev/null
+++ b/docs/source/topics/processes/include/snippets/functions/typing_none.py
@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+import typing as t
+
+from aiida.engine import calcfunction
+from aiida.orm import Int
+
+
+@calcfunction
+def add_multiply(x: Int, y: Int, z: typing.Optional[Int] = None):
+ if z is None:
+ z = Int(3)
+
+ return (x + y) * z
+
+result = add_multiply(Int(1), Int(2))
+result = add_multiply(Int(1), Int(2), Int(3))
diff --git a/docs/source/topics/processes/include/snippets/functions/typing_pep_563.py b/docs/source/topics/processes/include/snippets/functions/typing_pep_563.py
new file mode 100644
index 0000000000..cc52ca3ddf
--- /dev/null
+++ b/docs/source/topics/processes/include/snippets/functions/typing_pep_563.py
@@ -0,0 +1,11 @@
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+from aiida.engine import calcfunction
+
+
+@calcfunction
+def add(x: int, y: int):
+ return x + y
+
+add(1, 2)
diff --git a/docs/source/topics/processes/include/snippets/functions/typing_pep_604.py b/docs/source/topics/processes/include/snippets/functions/typing_pep_604.py
new file mode 100644
index 0000000000..64f9c30290
--- /dev/null
+++ b/docs/source/topics/processes/include/snippets/functions/typing_pep_604.py
@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+from aiida.engine import calcfunction
+from aiida.orm import Int
+
+
+@calcfunction
+def add_multiply(x: int, y: int, z: int | None = None):
+ if z is None:
+ z = Int(3)
+
+ return (x + y) * z
+
+result = add_multiply(1, 2)
+result = add_multiply(1, 2, 3)
diff --git a/docs/source/topics/processes/include/snippets/functions/typing_union.py b/docs/source/topics/processes/include/snippets/functions/typing_union.py
new file mode 100644
index 0000000000..3811cad719
--- /dev/null
+++ b/docs/source/topics/processes/include/snippets/functions/typing_union.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+import typing as t
+
+from aiida.engine import calcfunction
+from aiida.orm import Float, Int
+
+
+@calcfunction
+def add(x: t.Union[Int, Float], y: t.Union[Int, Float]):
+ return x + y
+
+add(Int(1), Float(1.0))
diff --git a/environment.yml b/environment.yml
index cffb828ef4..50a4acde10 100644
--- a/environment.yml
+++ b/environment.yml
@@ -14,6 +14,7 @@ dependencies:
- click-spinner~=0.1.8
- click~=8.1
- disk-objectstore~=0.6.0
+- get-annotations~=0.1
- python-graphviz~=0.13
- ipython<9,>=7
- jinja2~=3.0
diff --git a/pyproject.toml b/pyproject.toml
index e387fc424e..4e05598fe1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,6 +33,7 @@ dependencies = [
"click-spinner~=0.1.8",
"click~=8.1",
"disk-objectstore~=0.6.0",
+ "get-annotations~=0.1;python_version<'3.10'",
"graphviz~=0.13",
"ipython>=7,<9",
"jinja2~=3.0",
@@ -393,6 +394,7 @@ module = [
'docutils.*',
'flask_cors.*',
'flask_restful.*',
+ 'get_annotations.*',
'graphviz.*',
'importlib._bootstrap.*',
'IPython.*',
diff --git a/requirements/requirements-py-3.8.txt b/requirements/requirements-py-3.8.txt
index 2f71868b15..b3be3fadbc 100644
--- a/requirements/requirements-py-3.8.txt
+++ b/requirements/requirements-py-3.8.txt
@@ -38,6 +38,7 @@ Flask-Cors==3.0.10
Flask-RESTful==0.3.9
fonttools==4.28.2
future==0.18.3
+get-annotations==0.1.2
graphviz==0.19
greenlet==1.1.2
idna==3.3
diff --git a/requirements/requirements-py-3.9.txt b/requirements/requirements-py-3.9.txt
index 95162db0a7..f5d0477403 100644
--- a/requirements/requirements-py-3.9.txt
+++ b/requirements/requirements-py-3.9.txt
@@ -38,6 +38,7 @@ Flask-Cors==3.0.10
Flask-RESTful==0.3.9
fonttools==4.28.2
future==0.18.3
+get-annotations==0.1.2
graphviz==0.19
greenlet==1.1.2
idna==3.3
diff --git a/tests/engine/test_process_function.py b/tests/engine/test_process_function.py
index 9bf6d90fc3..11767c4783 100644
--- a/tests/engine/test_process_function.py
+++ b/tests/engine/test_process_function.py
@@ -16,8 +16,12 @@
fly, but then anytime inputs or outputs would be attached to it in the tests, the ``validate_link`` function would
complain as the dummy node class is not recognized as a valid process node.
"""
+from __future__ import annotations
+
import enum
import re
+import sys
+import typing as t
import pytest
@@ -616,3 +620,96 @@ def function(**kwargs):
}
with pytest.raises(ValueError):
function.run_get_node(**inputs)
+
+
+def test_type_hinting_spec_inference():
+ """Test the parsing of type hinting to define the valid types of the dynamically generated input ports."""
+
+ @calcfunction # type: ignore[misc]
+ def function(
+ a,
+ b: str,
+ c: bool,
+ d: orm.Str,
+ e: t.Union[orm.Str, orm.Int],
+ f: t.Union[str, int],
+ g: t.Optional[t.Dict] = None,
+ ):
+ # pylint: disable=invalid-name,unused-argument
+ pass
+
+ input_namespace = function.spec().inputs
+
+ expected = (
+ ('a', (orm.Data,)),
+ ('b', (orm.Str,)),
+ ('c', (orm.Bool,)),
+ ('d', (orm.Str,)),
+ ('e', (orm.Str, orm.Int)),
+ ('f', (orm.Str, orm.Int)),
+ ('g', (orm.Dict, type(None))),
+ )
+
+ for key, valid_types in expected:
+ assert key in input_namespace
+ assert input_namespace[key].valid_type == valid_types, key
+
+
+def test_type_hinting_spec_inference_pep_604(aiida_caplog):
+ """Test the parsing of type hinting that uses union typing of PEP 604 which is only available to Python 3.10 and up.
+
+ Even though adding ``from __future__ import annotations`` should backport this functionality to Python 3.9 and older
+ the ``get_annotations`` method (which was also added to the ``inspect`` package in Python 3.10) as provided by the
+ ``get-annotations`` backport package fails for this new syntax when called with ``eval_str=True``. Therefore type
+ inference using this syntax only works on Python 3.10 and up.
+
+ See https://peps.python.org/pep-0604
+ """
+
+ @calcfunction # type: ignore[misc]
+ def function(
+ a: str | int,
+ b: orm.Str | orm.Int,
+ c: dict | None = None,
+ ):
+ # pylint: disable=invalid-name,unused-argument
+ pass
+
+ input_namespace = function.spec().inputs
+
+ # Since the PEP 604 union syntax is only available starting from Python 3.10 the type inference will not be
+ # available for older versions, and so the valid type will be the default ``(orm.Data,)``.
+ if sys.version_info[:2] >= (3, 10):
+ expected = (
+ ('a', (orm.Str, orm.Int)),
+ ('b', (orm.Str, orm.Int)),
+ ('c', (orm.Dict, type(None))),
+ )
+ else:
+ assert 'function `function` has invalid type hints: unsupported operand type' in aiida_caplog.records[0].message
+ expected = (
+ ('a', (orm.Data,)),
+ ('b', (orm.Data,)),
+ ('c', (orm.Data, type(None))),
+ )
+
+ for key, valid_types in expected:
+ assert key in input_namespace
+ assert input_namespace[key].valid_type == valid_types, key
+
+
+def test_type_hinting_validation():
+ """Test that type hints are converted to automatic type checking through the process specification."""
+
+ @calcfunction # type: ignore[misc]
+ def function_type_hinting(a: t.Union[int, float]):
+ # pylint: disable=invalid-name
+ return a + 1
+
+ with pytest.raises(ValueError, match=r'.*value \'a\' is not of the right type.*'):
+ function_type_hinting('string')
+
+ assert function_type_hinting(1) == 2
+ assert function_type_hinting(orm.Int(1)) == 2
+ assert function_type_hinting(1.0) == 2.0
+ assert function_type_hinting(orm.Float(1)) == 2.0