diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..a6e9fd087d --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: minor + +Support aliases (TypeVar passthrough) in `get_specialized_type_var_map`. diff --git a/strawberry/utils/inspect.py b/strawberry/utils/inspect.py index 650c53b498..024b004b15 100644 --- a/strawberry/utils/inspect.py +++ b/strawberry/utils/inspect.py @@ -1,17 +1,20 @@ import asyncio import inspect +from collections import OrderedDict from functools import lru_cache +from itertools import zip_longest from typing import ( Any, Callable, + Generic, Optional, + Protocol, TypeVar, + Union, get_origin, ) from typing_extensions import get_args -from strawberry.utils.typing import is_generic_alias - def in_async_context() -> bool: # Based on the way django checks if there's an event loop in the current thread @@ -67,13 +70,13 @@ class IntBarFoo(IntBar, Foo[str]): ... # {} get_specialized_type_var_map(Bar) - # {~T: ~T} + # {} get_specialized_type_var_map(IntBar) - # {~T: int} + # {~T: int, ~K: int} get_specialized_type_var_map(IntBarSubclass) - # {~T: int} + # {~T: int, ~K: int} get_specialized_type_var_map(IntBarFoo) # {~T: int, ~K: str} @@ -81,43 +84,46 @@ class IntBarFoo(IntBar, Foo[str]): ... """ from strawberry.types.base import has_object_definition - orig_bases = getattr(cls, "__orig_bases__", None) - if orig_bases is None: - # Specialized generic aliases will not have __orig_bases__ - if get_origin(cls) is not None and is_generic_alias(cls): - orig_bases = (cls,) - else: - # Not a specialized type - return None - - type_var_map = {} - - # only get type vars for base generics (ie. Generic[T]) and for strawberry types + param_args = OrderedDict[TypeVar, Union[None, TypeVar, type]]() - orig_bases = [b for b in orig_bases if has_object_definition(b)] + types: list[type] = [cls] + while types: + tp = types.pop(0) + if (origin := get_origin(tp)) is None or origin in (Generic, Protocol): + origin = tp - for base in orig_bases: - # Recursively get type var map from base classes - if base is not cls: - base_type_var_map = get_specialized_type_var_map(base) - if base_type_var_map is not None: - type_var_map.update(base_type_var_map) - - args = get_args(base) - origin = getattr(base, "__origin__", None) - - params = origin and getattr(origin, "__parameters__", None) - if params is None: - params = getattr(base, "__parameters__", None) - - if not params: + # only get type vars for base generics (i.e. Generic[T]) and for strawberry types + if not has_object_definition(origin): continue - type_var_map.update( - {p.__name__: a for p, a in zip(params, args) if not isinstance(a, TypeVar)} - ) - - return type_var_map + if (type_params := getattr(origin, "__parameters__", None)) is not None: + args = get_args(tp) + for type_param, arg in zip_longest(type_params, args): + if type_param not in param_args: + param_args[type_param] = arg + + if orig_bases := getattr(origin, "__orig_bases__", None): + types.extend(orig_bases) + if not param_args: + return None + + resolve = True + while resolve: + resolve = False + for type_param, arg in list(param_args.items()): + if arg is None or not isinstance(arg, TypeVar): + continue + resolved_arg = param_args.get(arg, None) if arg is not type_param else None + param_args[type_param] = resolved_arg + + if resolved_arg: + resolve = True + + return { + k.__name__: v + for k, v in reversed(param_args.items()) + if v is not None and not isinstance(v, TypeVar) + } __all__ = ["get_func_args", "get_specialized_type_var_map", "in_async_context"]