diff --git a/mypy/plugins/attrs.py b/mypy/plugins/attrs.py index 12675042aa57..adc9a27c23ed 100644 --- a/mypy/plugins/attrs.py +++ b/mypy/plugins/attrs.py @@ -6,6 +6,8 @@ from typing_extensions import Final import mypy.plugin # To avoid circular imports. +from mypy.constraints import infer_constraints, SUBTYPE_OF +from mypy.expandtype import expand_type from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.fixup import lookup_qualified_stnode from mypy.nodes import ( @@ -18,6 +20,7 @@ from mypy.plugins.common import ( _get_argument, _get_bool_argument, _get_decorator_bool_argument, add_method ) +from mypy.solve import solve_constraints from mypy.types import ( Type, AnyType, TypeOfAny, CallableType, NoneType, TypeVarDef, TypeVarType, Overloaded, UnionType, FunctionLike, get_proper_type @@ -57,6 +60,43 @@ def __init__(self, self.is_attr_converters_optional = is_attr_converters_optional +def expand_arg_type( + callable_type: CallableType, + target_type: Optional[Type], +) -> Type: + # The result is based on the type of the first argument of the callable + arg_type = get_proper_type(callable_type.arg_types[0]) + ret_type = get_proper_type(callable_type.ret_type) + target_type = get_proper_type(target_type) + if target_type is None: + target_type = AnyType(TypeOfAny.unannotated) + + if ret_type == target_type or isinstance(ret_type, AnyType): + # If the callable has the exact same return type as the target + # we can directly return the type of the first argument. + # This is also the case if the return type is `Any`. + return arg_type + + # Find the constraints on the type vars given that the result must be a + # subtype of the target type. + constraints = infer_constraints(ret_type, target_type, SUBTYPE_OF) + # When this code gets run, callable_type.variables has not yet been + # properly initialized, so we instead simply construct a set of all unique + # type var ids in the constraints + type_var_ids = list({const.type_var for const in constraints}) + # Now we get the best solutions for these constraints + solutions = solve_constraints(type_var_ids, constraints) + type_map = { + tid: sol + for tid, sol in zip(type_var_ids, solutions) + if sol is not None + } + + # Now we can use these solutions to expand the generic arg type into a + # concrete type + return expand_type(arg_type, type_map) + + class Attribute: """The value of an attr.ib() call.""" @@ -94,10 +134,11 @@ def argument(self, ctx: 'mypy.plugin.ClassDefContext') -> Argument: elif converter and converter.type: converter_type = converter.type + orig_type = init_type init_type = None converter_type = get_proper_type(converter_type) if isinstance(converter_type, CallableType) and converter_type.arg_types: - init_type = ctx.api.anal_type(converter_type.arg_types[0]) + init_type = ctx.api.anal_type(expand_arg_type(converter_type, orig_type)) elif isinstance(converter_type, Overloaded): types = [] # type: List[Type] for item in converter_type.items(): @@ -107,7 +148,7 @@ def argument(self, ctx: 'mypy.plugin.ClassDefContext') -> Argument: continue if num_arg_types > 1 and any(kind == ARG_POS for kind in item.arg_kinds[1:]): continue - types.append(item.arg_types[0]) + types.append(expand_arg_type(item, orig_type)) # Make a union of all the valid types. if types: args = make_simplified_union(types) diff --git a/test-data/unit/check-attr.test b/test-data/unit/check-attr.test index 28613454d2ff..e680b1fb2201 100644 --- a/test-data/unit/check-attr.test +++ b/test-data/unit/check-attr.test @@ -714,6 +714,43 @@ reveal_type(A) # N: Revealed type is 'def (x: builtins.int) -> __main__.A' reveal_type(A(15).x) # N: Revealed type is 'builtins.str' [builtins fixtures/list.pyi] +[case testAttrsUsingTupleConverter] +from typing import Tuple +import attr + +@attr.s +class C: + t: Tuple[int, ...] = attr.ib(converter=tuple) + +o = C([1, 2, 3]) +o = C(['a']) # E: List item 0 has incompatible type "str"; expected "int" +[builtins fixtures/attr.pyi] + +[case testAttrsUsingListConverter] +from typing import List +import attr + +@attr.s +class C: + t: List[int] = attr.ib(converter=list) + +o = C([1, 2, 3]) +o = C(['a']) # E: List item 0 has incompatible type "str"; expected "int" +[builtins fixtures/list.pyi] + +[case testAttrsUsingDictConverter] +from typing import Dict +import attr + +@attr.s +class C(object): + values = attr.ib(type=Dict[str, int], converter=dict) + + +C(values=[('a', 1), ('b', 2)]) +C(values=[(1, 'a')]) # E: List item 0 has incompatible type "Tuple[int, str]"; expected "Tuple[str, int]" +[builtins fixtures/dict.pyi] + [case testAttrsUsingConverterWithTypes] from typing import overload import attr diff --git a/test-data/unit/fixtures/attr.pyi b/test-data/unit/fixtures/attr.pyi index deb1906d931e..00d871e0c4e0 100644 --- a/test-data/unit/fixtures/attr.pyi +++ b/test-data/unit/fixtures/attr.pyi @@ -1,5 +1,7 @@ # Builtins stub used to support @attr.s tests. -from typing import Union, overload +from typing import Union, overload, Sequence, Generic, TypeVar, Iterable, \ + Tuple, Iterator + class object: def __init__(self) -> None: pass @@ -22,6 +24,24 @@ class complex: @overload def __init__(self, real: str = ...) -> None: ... +Tco = TypeVar('Tco', covariant=True) + +class tuple(Sequence[Tco], Generic[Tco]): + @overload + def __init__(self) -> None: pass + @overload + def __init__(self, x: Iterable[Tco]) -> None: pass + def __iter__(self) -> Iterator[Tco]: pass + def __contains__(self, item: object) -> bool: pass + def __getitem__(self, x: int) -> Tco: pass + def __rmul__(self, n: int) -> Tuple[Tco, ...]: pass + def __add__(self, x: Tuple[Tco, ...]) -> Tuple[Tco, ...]: pass + def count(self, obj: object) -> int: pass + +T = TypeVar('T') + +class list(Sequence[T], Generic[T]): pass + class str: pass class unicode: pass class ellipsis: pass