diff --git a/src/duty/validation.py b/src/duty/validation.py index 4858916..6976cf0 100644 --- a/src/duty/validation.py +++ b/src/duty/validation.py @@ -8,6 +8,7 @@ from __future__ import annotations import sys +import textwrap from functools import cached_property from inspect import Parameter, Signature, signature from typing import Any, Callable, Sequence @@ -48,16 +49,14 @@ def cast_arg(arg: Any, annotation: Any) -> Any: class ParamsCaster: """A helper class to cast parameters based on a function's signature annotations.""" - def __init__(self, function: Callable) -> None: + def __init__(self, signature: Signature) -> None: """Initialize the object. Parameters: - function: The function to use to cast arguments. + signature: The signature to use to cast arguments. """ - self.function = function - self.signature = signature(function) - self.params_dict = self.signature.parameters - self.params_list = list(self.params_dict.values())[1:] + self.params_dict = signature.parameters + self.params_list = list(self.params_dict.values()) @cached_property def var_positional_position(self) -> int: @@ -168,6 +167,52 @@ def cast(self, *args: Any, **kwargs: Any) -> tuple[Sequence, dict[str, Any]]: return positional, keyword +def _get_params_caster(func: Callable, *args: Any, **kwargs: Any) -> ParamsCaster: + duties_module = sys.modules[func.__module__] + exec_globals = dict(duties_module.__dict__) + eval_str = False + for name in list(exec_globals.keys()): + if exec_globals[name] is annotations: + eval_str = True + del exec_globals[name] + exec_globals["__context_above"] = {} + + # Don't keep first parameter: context. + params = list(signature(func).parameters.values())[1:] + params_no_types = [Parameter(param.name, param.kind, default=param.default) for param in params] + code_sig = Signature(parameters=params_no_types) + if eval_str: + params_types = [ + Parameter( + param.name, + param.kind, + default=param.default, + annotation=eval( # noqa: PGH001 + param.annotation, + exec_globals, + ), + ) + for param in params + ] + else: + params_types = params + cast_sig = Signature(parameters=params_types) + + code = f""" + import inspect + def {func.__name__}{code_sig}: ... + __context_above['func'] = {func.__name__} + """ + + exec(textwrap.dedent(code), exec_globals) # noqa: S102 + func = exec_globals["__context_above"]["func"] + + # Trigger TypeError early. + func(*args, **kwargs) + + return ParamsCaster(cast_sig) + + def validate( func: Callable, *args: Any, @@ -190,14 +235,4 @@ def validate( Returns: The casted arguments. """ - name = func.__name__ - - # don't keep first parameter: context - params_list = list(signature(func).parameters.values())[1:] - params = [Parameter(param.name, param.kind, default=param.default) for param in params_list] - sig = Signature(parameters=params) - trixx: dict = {} - exec(f"def {name}{sig}: ...\ntrixx[0] = {name}") # noqa: S102 - trixx[0](*args, **kwargs) - caster = ParamsCaster(func) - return caster.cast(*args, **kwargs) + return _get_params_caster(func, *args, **kwargs).cast(*args, **kwargs) diff --git a/tests/test_validation.py b/tests/test_validation.py index e7fb3e2..fddc4ca 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -7,7 +7,7 @@ import pytest -from duty.validation import ParamsCaster, cast_arg, to_bool +from duty.validation import _get_params_caster, cast_arg, to_bool from tests.fixtures import validation as valfix @@ -130,7 +130,7 @@ def test_params_caster(func: Callable, args: tuple, kwargs: dict, expected_args: expected_args: The expected positional arguments result. expected_kwargs: The expected keyword arguments result. """ - caster = ParamsCaster(func) + caster = _get_params_caster(func, *args, **kwargs) new_args, new_kwargs = caster.cast(*args, **kwargs) assert new_args == expected_args assert new_kwargs == expected_kwargs