Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Process functions: Infer argument valid_type from type hints #5900

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ repos:
aiida/engine/processes/calcjobs/monitors.py|
aiida/engine/processes/calcjobs/tasks.py|
aiida/engine/processes/control.py|
aiida/engine/processes/functions.py|
aiida/engine/processes/ports.py|
aiida/manage/configuration/__init__.py|
aiida/manage/configuration/config.py|
Expand Down
179 changes: 132 additions & 47 deletions aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,49 @@
import inspect
import logging
import signal
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Tuple, Type, TypeVar
import types
import typing as t
from typing import TYPE_CHECKING

from aiida.common.lang import override
from aiida.manage import get_manager
from aiida.orm import CalcFunctionNode, Data, ProcessNode, WorkFunctionNode, to_aiida_type
from aiida.orm import (
Bool,
CalcFunctionNode,
Data,
Dict,
Float,
Int,
List,
ProcessNode,
Str,
WorkFunctionNode,
to_aiida_type,
)
from aiida.orm.utils.mixins import FunctionCalculationMixin

from .process import Process

try:
UnionType = types.UnionType # type: ignore[attr-defined]
except AttributeError:
# This type is not available for Python 3.9 and older
UnionType = None # pylint: disable=invalid-name

try:
get_annotations = inspect.get_annotations # type: ignore[attr-defined]
except AttributeError:
# This is the backport for Python 3.9 and older
from get_annotations import get_annotations # type: ignore[no-redef]

if TYPE_CHECKING:
from .exit_code import ExitCode

__all__ = ('calcfunction', 'workfunction', 'FunctionProcess')

LOGGER = logging.getLogger(__name__)

FunctionType = TypeVar('FunctionType', bound=Callable[..., Any])
FunctionType = t.TypeVar('FunctionType', bound=t.Callable[..., t.Any])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note in #5901 I have started naming these with a Tv suffix, like FunctionTv, because pylint no longer allows Type pylint-dev/pylint#6003

Also can you make sure this file is no longer in the exclude list in the pre-commit-config

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I ask the reason for choosing Tv as the suffix? When going over #5901 it already struck me a bit weird, since I read it automatically as "television" 😅 I wasn't going to mention it, since in the end it is an arbitrary choice, but if it is pylint that complains, why not use what they suggest as naming namely T as suffix. This also makes more sense to me. I would propose we just maintain that convention. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chrisjsewell I have removed the file from the black list. Since the naming of the type var is still an open discussion and the latest version of pylint is anyway not being used, I will leave that for another PR. Since you had no other comments, I take it this is good to go and will merge tomorrow.



def calcfunction(function: FunctionType) -> FunctionType:
Expand Down Expand Up @@ -88,14 +114,14 @@ def workfunction(function: FunctionType) -> FunctionType:
return process_function(node_class=WorkFunctionNode)(function)


def process_function(node_class: Type['ProcessNode']) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def process_function(node_class: t.Type['ProcessNode']) -> t.Callable[[FunctionType], FunctionType]:
"""
The base function decorator to create a FunctionProcess out of a normal python function.

:param node_class: the ORM class to be used as the Node record for the FunctionProcess
"""

def decorator(function: Callable[..., Any]) -> Callable[..., Any]:
def decorator(function: FunctionType) -> FunctionType:
"""
Turn the decorated function into a FunctionProcess.

Expand All @@ -104,7 +130,7 @@ def decorator(function: Callable[..., Any]) -> Callable[..., Any]:
"""
process_class = FunctionProcess.build(function, node_class=node_class)

def run_get_node(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], 'ProcessNode']:
def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode']:
"""
Run the FunctionProcess with the supplied inputs in a local runner.

Expand Down Expand Up @@ -159,7 +185,7 @@ def kill_process(_num, _frame):

return result, process.node

def run_get_pk(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], int]:
def run_get_pk(*args, **kwargs) -> tuple[dict[str, t.Any] | None, int]:
"""Recreate the `run_get_pk` utility launcher.

:param args: input arguments to construct the FunctionProcess
Expand All @@ -185,15 +211,52 @@ def decorated_function(*args, **kwargs):
decorated_function.recreate_from = process_class.recreate_from # type: ignore[attr-defined]
decorated_function.spec = process_class.spec # type: ignore[attr-defined]

return decorated_function
return decorated_function # type: ignore[return-value]

return decorator


def infer_valid_type_from_type_annotation(annotation: t.Any) -> tuple[t.Any, ...]:
"""Infer the value for the ``valid_type`` of an input port from the given function argument annotation.

:param annotation: The annotation of a function argument as returned by ``inspect.get_annotation``.
:returns: A tuple of valid types. If no valid types were defined or they could not be successfully parsed, an empty
tuple is returned.
"""

def get_type_from_annotation(annotation):
valid_type_map = {
bool: Bool,
dict: Dict,
t.Dict: Dict,
float: Float,
int: Int,
list: List,
t.List: List,
str: Str,
}

if inspect.isclass(annotation) and issubclass(annotation, Data):
return annotation

return valid_type_map.get(annotation)

inferred_valid_type: tuple[t.Any, ...] = ()

if inspect.isclass(annotation):
inferred_valid_type = (get_type_from_annotation(annotation),)
elif t.get_origin(annotation) is t.Union or t.get_origin(annotation) is UnionType:
inferred_valid_type = tuple(get_type_from_annotation(valid_type) for valid_type in t.get_args(annotation))
elif t.get_origin(annotation) is t.Optional:
inferred_valid_type = (t.get_args(annotation),)

return tuple(valid_type for valid_type in inferred_valid_type if valid_type is not None)


class FunctionProcess(Process):
"""Function process class used for turning functions into a Process"""

_func_args: Sequence[str] = ()
_func_args: t.Sequence[str] = ()
_varargs: str | None = None

@staticmethod
Expand All @@ -205,7 +268,7 @@ def _func(*_args, **_kwargs) -> dict:
return {}

@staticmethod
def build(func: Callable[..., Any], node_class: Type['ProcessNode']) -> Type['FunctionProcess']:
def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['FunctionProcess']:
"""
Build a Process from the given function.

Expand All @@ -222,48 +285,70 @@ def build(func: Callable[..., Any], node_class: Type['ProcessNode']) -> Type['Fu
if not issubclass(node_class, ProcessNode) or not issubclass(node_class, FunctionCalculationMixin):
raise TypeError('the node_class should be a sub class of `ProcessNode` and `FunctionCalculationMixin`')

args, varargs, keywords, defaults, _, _, _ = inspect.getfullargspec(func)
nargs = len(args)
ndefaults = len(defaults) if defaults else 0
first_default_pos = nargs - ndefaults
signature = inspect.signature(func)

args: list[str] = []
varargs: str | None = None
keywords: str | None = None

try:
annotations = get_annotations(func, eval_str=True)
except Exception as exception: # pylint: disable=broad-except
# Since we are running with ``eval_str=True`` to unstringize the annotations, the call can except if the
# annotations are incorrect. In this case we simply want to log a warning and continue with type inference.
LOGGER.warning(f'function `{func.__name__}` has invalid type hints: {exception}')
annotations = {}

for key, parameter in signature.parameters.items():

if parameter.kind in [parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD, parameter.KEYWORD_ONLY]:
args.append(key)

if parameter.kind is parameter.VAR_POSITIONAL:
varargs = key

if parameter.kind is parameter.VAR_KEYWORD:
varargs = key

def _define(cls, spec): # pylint: disable=unused-argument
"""Define the spec dynamically"""
from plumpy.ports import UNSPECIFIED

super().define(spec)

for i, arg in enumerate(args):
for parameter in signature.parameters.values():

if parameter.kind in [parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD]:
continue

default = UNSPECIFIED
annotation = annotations.get(parameter.name)
valid_type = infer_valid_type_from_type_annotation(annotation) or (Data,)

if defaults and i >= first_default_pos:
default = defaults[i - first_default_pos]
default = parameter.default if parameter.default is not parameter.empty else UNSPECIFIED

# If the keyword was already specified, simply override the default
if spec.has_input(arg):
spec.inputs[arg].default = default
if spec.has_input(parameter.name):
spec.inputs[parameter.name].default = default
continue

# If the default is ``None`` make sure that the port also accepts a ``NoneType``. Note that we cannot
# use ``None`` because the validation will call ``isinstance`` which does not work when passing ``None``
# but it does work with ``NoneType`` which is returned by calling ``type(None)``.
if default is None:
valid_type += (type(None),)

# If a default is defined and it is not a ``Data`` instance it should be serialized, but this should be
# done lazily using a lambda, just as any port defaults should not define node instances directly as is
# also checked by the ``spec.input`` call.
if (
default is not None and default != UNSPECIFIED and not isinstance(default, Data) and
not callable(default)
):
indirect_default = lambda value=default: to_aiida_type(value)
else:
# If the default is `None` make sure that the port also accepts a `NoneType`
# Note that we cannot use `None` because the validation will call `isinstance` which does not work
# when passing `None`, but it does work with `NoneType` which is returned by calling `type(None)`
if default is None:
valid_type = (Data, type(None))
else:
valid_type = (Data,)

# If a default is defined and it is not a ``Data`` instance it should be serialized, but this should
# be done lazily using a lambda, just as any port defaults should not define node instances directly
# as is also checked by the ``spec.input`` call.
if (
default is not None and default != UNSPECIFIED and not isinstance(default, Data) and
not callable(default)
):
indirect_default = lambda value=default: to_aiida_type(value)
else:
indirect_default = default # type: ignore[assignment]

spec.input(arg, valid_type=valid_type, default=indirect_default, serializer=to_aiida_type)
indirect_default = default

spec.input(parameter.name, valid_type=valid_type, default=indirect_default, serializer=to_aiida_type)

# Set defaults for label and description based on function name and docstring, if not explicitly defined
port_label = spec.inputs['metadata']['label']
Expand Down Expand Up @@ -293,7 +378,7 @@ def _define(cls, spec): # pylint: disable=unused-argument
)

@classmethod
def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused-argument
def validate_inputs(cls, *args: t.Any, **kwargs: t.Any) -> None: # pylint: disable=unused-argument
"""
Validate the positional and keyword arguments passed in the function call.

Expand All @@ -314,7 +399,7 @@ def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable=
raise TypeError(f'{name}() takes {nparameters} positional arguments but {nargs} were given')

@classmethod
def create_inputs(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
def create_inputs(cls, *args: t.Any, **kwargs: t.Any) -> dict[str, t.Any]:
"""Create the input args for the FunctionProcess."""
cls.validate_inputs(*args, **kwargs)

Expand All @@ -326,7 +411,7 @@ def create_inputs(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
return ins

@classmethod
def args_to_dict(cls, *args: Any) -> Dict[str, Any]:
def args_to_dict(cls, *args: t.Any) -> dict[str, t.Any]:
"""
Create an input dictionary (of form label -> value) from supplied args.

Expand Down Expand Up @@ -375,7 +460,7 @@ def __init__(self, *args, **kwargs) -> None:
super().__init__(enable_persistence=False, *args, **kwargs) # type: ignore

@property
def process_class(self) -> Callable[..., Any]:
def process_class(self) -> t.Callable[..., t.Any]:
"""
Return the class that represents this Process, for the FunctionProcess this is the function itself.

Expand All @@ -388,7 +473,7 @@ class that really represents what was being executed.
"""
return self._func

def execute(self) -> Optional[Dict[str, Any]]:
def execute(self) -> dict[str, t.Any] | None:
"""Execute the process."""
result = super().execute()

Expand All @@ -405,7 +490,7 @@ def _setup_db_record(self) -> None:
self.node.store_source_info(self._func)

@override
def run(self) -> Optional['ExitCode']:
def run(self) -> 'ExitCode' | None:
"""Run the process."""
from .exit_code import ExitCode

Expand All @@ -414,7 +499,7 @@ def run(self) -> Optional['ExitCode']:
# been overridden by the engine to `Running` so we cannot check that, but if the `exit_status` is anything other
# than `None`, it should mean this node was taken from the cache, so the process should not be rerun.
if self.node.exit_status is not None:
return self.node.exit_status
return ExitCode(self.node.exit_status, self.node.exit_message)

# Split the inputs into positional and keyword arguments
args = [None] * len(self._func_args)
Expand Down
2 changes: 2 additions & 0 deletions docs/source/topics/calculations/concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ To solve this, one only has to wrap them in the :py:class:`~aiida.orm.nodes.data
The only difference with the previous snippet is that all inputs have been wrapped in the :py:class:`~aiida.orm.nodes.data.int.Int` class.
The result that is returned by the function, is now also an :py:class:`~aiida.orm.nodes.data.int.Int` node that can be stored in the provenance graph, and contains the result of the computation.

.. _topics:calculations:concepts:calcfunctions:automatic-serialization:

.. versionadded:: 2.1

If a function argument is a Python base type (i.e. a value of type ``bool``, ``dict``, ``Enum``, ``float``, ``int``, ``list`` or ``str``), it can be passed straight away to the function, without first having to wrap it in the corresponding AiiDA data type.
Expand Down
51 changes: 51 additions & 0 deletions docs/source/topics/processes/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,57 @@ The link labels for the example above will therefore be ``args_0``, ``args_1`` a
If any of these labels were to overlap with the label of a positional or keyword argument, a ``RuntimeError`` will be raised.
In this case, the conflicting argument name needs to be changed to something that does not overlap with the automatically generated labels for the variadic arguments.

Type validation
===============

.. versionadded:: 2.3

Type hints (introduced with `PEP 484 <https://peps.python.org/pep-0484/>`_ in Python 3.5) can be used to add automatic type validation of process function arguments.
For example, the following will raise a ``ValueError`` exception:

.. include:: include/snippets/functions/typing_call_raise.py
:code: python

When the process function is declared, the process specification (``ProcessSpec``) is built dynamically.
For each function argument, if a correct type hint is provided, it is set as the ``valid_type`` attribute of the corresponding input port.
In the example above, the ``x`` and ``y`` inputs have ``Int`` as type hint, which is why the call that passes a ``Float`` raises a ``ValueError``.

.. note::

Type hints for return values are currently not parsed.

If an argument accepts multiple types, the ``typing.Union`` class can be used as normal:

.. include:: include/snippets/functions/typing_union.py
:code: python

The call with an ``Int`` and a ``Float`` will now finish correctly.
Similarly, optional arguments, with ``None`` as a default, can be declared using ``typing.Optional``:

.. include:: include/snippets/functions/typing_none.py
:code: python

The `postponed evaluation of annotations introduced by PEP 563 <https://peps.python.org/pep-0563/>`_ is also supported.
This means it is possible to use Python base types for the type hint instead of AiiDA's ``Data`` node equivalent:

.. include:: include/snippets/functions/typing_pep_563.py
:code: python

The type hints are automatically serialized just as the actual inputs are when the function is called, :ref:`as introduced in v2.1<topics:calculations:concepts:calcfunctions:automatic-serialization>`.

The alternative syntax for union types ``X | Y`` `as introduced by PEP 604 <https://peps.python.org/pep-0604/>`_ is also supported:

.. include:: include/snippets/functions/typing_pep_604.py
:code: python

.. warning::

The usage of notation as defined by PEP 563 and PEP 604 are not supported for Python versions older than 3.10, even if the ``from __future__ import annotations`` statement is added.
The reason is that the type inference uses the `inspect.get_annotations <https://docs.python.org/3/library/inspect.html#inspect.get_annotations>`_ method, which was introduced in Python 3.10.
For older Python versions, the `get-annotations <https://pypi.org/project/get-annotations>`_ backport is used, but that does not work with PEP 563 and PEP 604, so the constructs from the ``typing`` module have to be used instead.

If a process function has invalid type hints, they will simply be ignored and a warning message is logged: ``function 'function_name' has invalid type hints``.
This ensures backwards compatibility in the case existing process functions had invalid type hints.

Return values
=============
Expand Down
Loading