From cc7b062026311287c91b226960ee0ad2e447a7ea Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Wed, 12 Apr 2023 03:59:42 -0400 Subject: [PATCH] Fix attrs.evolve on bound TypeVar (#15022) Fixes the error on the last line of this example: ```python @attrs.define class A: x: int T = TypeVar('T', bound=A) def f(t: T) -> None: _ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has incompatible type "T"; expected an attrs class ``` Since `T` is bounded by `A`, we know it can be treated as `A`. --- mypy/plugins/attrs.py | 24 ++++++++---- test-data/unit/check-attr.test | 69 ++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 8 deletions(-) diff --git a/mypy/plugins/attrs.py b/mypy/plugins/attrs.py index f59eb2e36e4c..4a43c2a16d52 100644 --- a/mypy/plugins/attrs.py +++ b/mypy/plugins/attrs.py @@ -62,6 +62,7 @@ LiteralType, NoneType, Overloaded, + ProperType, TupleType, Type, TypeOfAny, @@ -929,13 +930,10 @@ def add_method( add_method(self.ctx, method_name, args, ret_type, self_type, tvd) -def _get_attrs_init_type(typ: Type) -> CallableType | None: +def _get_attrs_init_type(typ: Instance) -> CallableType | None: """ If `typ` refers to an attrs class, gets the type of its initializer method. """ - typ = get_proper_type(typ) - if not isinstance(typ, Instance): - return None magic_attr = typ.type.get(MAGIC_ATTR_NAME) if magic_attr is None or not magic_attr.plugin_generated: return None @@ -945,6 +943,14 @@ def _get_attrs_init_type(typ: Type) -> CallableType | None: return init_method.type +def _get_attrs_cls_and_init(typ: ProperType) -> tuple[Instance | None, CallableType | None]: + if isinstance(typ, TypeVarType): + typ = get_proper_type(typ.upper_bound) + if not isinstance(typ, Instance): + return None, None + return typ, _get_attrs_init_type(typ) + + def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType: """ Generates a signature for the 'attr.evolve' function that's specific to the call site @@ -967,13 +973,15 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl inst_type = get_proper_type(inst_type) if isinstance(inst_type, AnyType): - return ctx.default_signature + return ctx.default_signature # evolve(Any, ....) -> Any inst_type_str = format_type_bare(inst_type) - attrs_init_type = _get_attrs_init_type(inst_type) - if not attrs_init_type: + attrs_type, attrs_init_type = _get_attrs_cls_and_init(inst_type) + if attrs_type is None or attrs_init_type is None: ctx.api.fail( - f'Argument 1 to "evolve" has incompatible type "{inst_type_str}"; expected an attrs class', + f'Argument 1 to "evolve" has a variable type "{inst_type_str}" not bound to an attrs class' + if isinstance(inst_type, TypeVarType) + else f'Argument 1 to "evolve" has incompatible type "{inst_type_str}"; expected an attrs class', ctx.context, ) return ctx.default_signature diff --git a/test-data/unit/check-attr.test b/test-data/unit/check-attr.test index 3ca804943010..45c673b269c5 100644 --- a/test-data/unit/check-attr.test +++ b/test-data/unit/check-attr.test @@ -1970,6 +1970,75 @@ reveal_type(ret) # N: Revealed type is "Any" [typing fixtures/typing-medium.pyi] +[case testEvolveTypeVarBound] +import attrs +from typing import TypeVar + +@attrs.define +class A: + x: int + +@attrs.define +class B(A): + pass + +TA = TypeVar('TA', bound=A) + +def f(t: TA) -> TA: + t2 = attrs.evolve(t, x=42) + reveal_type(t2) # N: Revealed type is "TA`-1" + t3 = attrs.evolve(t, x='42') # E: Argument "x" to "evolve" of "TA" has incompatible type "str"; expected "int" + return t2 + +f(A(x=42)) +f(B(x=42)) + +[builtins fixtures/attr.pyi] + +[case testEvolveTypeVarBoundNonAttrs] +import attrs +from typing import TypeVar + +TInt = TypeVar('TInt', bound=int) +TAny = TypeVar('TAny') +TNone = TypeVar('TNone', bound=None) + +def f(t: TInt) -> None: + _ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has a variable type "TInt" not bound to an attrs class + +def g(t: TAny) -> None: + _ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has a variable type "TAny" not bound to an attrs class + +def h(t: TNone) -> None: + _ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has a variable type "TNone" not bound to an attrs class + +[builtins fixtures/attr.pyi] + +[case testEvolveTypeVarConstrained] +import attrs +from typing import TypeVar + +@attrs.define +class A: + x: int + +@attrs.define +class B: + x: str # conflicting with A.x + +T = TypeVar('T', A, B) + +def f(t: T) -> T: + t2 = attrs.evolve(t, x=42) # E: Argument "x" to "evolve" of "B" has incompatible type "int"; expected "str" + reveal_type(t2) # N: Revealed type is "__main__.A" # N: Revealed type is "__main__.B" + t2 = attrs.evolve(t, x='42') # E: Argument "x" to "evolve" of "A" has incompatible type "str"; expected "int" + return t2 + +f(A(x=42)) +f(B(x='42')) + +[builtins fixtures/attr.pyi] + [case testEvolveVariants] from typing import Any import attr