diff --git a/mypy/argmap.py b/mypy/argmap.py index 8305a371e0a6..26548dd00228 100644 --- a/mypy/argmap.py +++ b/mypy/argmap.py @@ -193,3 +193,21 @@ def expand_actual_type(self, else: # No translation for other kinds -- 1:1 mapping. return actual_type + + +def expand_formal_type(formal_type: Type, actual_name: Optional[str], tuple_idx: int) -> Type: + if isinstance(formal_type, Instance) and formal_type.type.fullname() == 'mypy_extensions.Expand': + + formal_type_inner = formal_type.args[0] + if isinstance(formal_type_inner, Instance) and formal_type_inner.type.fullname() == 'builtins.dict': + return formal_type_inner.args[1] + elif isinstance(formal_type_inner, TypedDictType): + if actual_name is None or actual_name not in formal_type_inner.items: + return AnyType(TypeOfAny.from_error) + return formal_type_inner.items[actual_name] + elif isinstance(formal_type_inner, TupleType): + if len(formal_type_inner.items) <= tuple_idx: + return AnyType(TypeOfAny.from_error) + return formal_type_inner.items[tuple_idx] + + return formal_type diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 82d40138668e..385d332d9a97 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -50,7 +50,7 @@ from mypy import applytype from mypy import erasetype from mypy.checkmember import analyze_member_access, type_object_type -from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals +from mypy.argmap import ArgTypeExpander, expand_formal_type, map_actuals_to_formals, map_formals_to_actuals from mypy.checkstrformat import StringFormatterChecker from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars from mypy.util import split_module_names @@ -775,7 +775,7 @@ def check_callable_call(self, self.check_argument_count(callee, arg_types, arg_kinds, arg_names, formal_to_actual, context, self.msg) - self.check_argument_types(arg_types, arg_kinds, callee, formal_to_actual, context, + self.check_argument_types(arg_types, arg_kinds, arg_names, callee, formal_to_actual, context, messages=arg_messages) if (callee.is_type_obj() and (len(arg_types) == 1) @@ -1203,6 +1203,7 @@ def check_for_extra_actual_arguments(self, def check_argument_types(self, arg_types: List[Type], arg_kinds: List[int], + arg_names: Optional[Sequence[Optional[str]]], callee: CallableType, formal_to_actual: List[List[int]], context: Context, @@ -1233,8 +1234,13 @@ def check_argument_types(self, expanded_actual = mapper.expand_actual_type( actual_type, actual_kind, callee.arg_names[i], callee.arg_kinds[i]) + expanded_formal = expand_formal_type( + callee.arg_types[i], + arg_names[actual] if arg_names is not None else None, + actual - i, + ) check_arg(expanded_actual, actual_type, arg_kinds[actual], - callee.arg_types[i], + expanded_formal, actual + 1, i + 1, callee, context, messages) def check_arg(self, caller_type: Type, original_caller_type: Type, @@ -1713,7 +1719,7 @@ def check_arg(caller_type: Type, original_caller_type: Type, caller_kind: int, raise Finished try: - self.check_argument_types(arg_types, arg_kinds, callee, + self.check_argument_types(arg_types, arg_kinds, None, callee, formal_to_actual, context=context, check_arg=check_arg) return True except Finished: diff --git a/test-data/unit/lib-stub/mypy_extensions.pyi b/test-data/unit/lib-stub/mypy_extensions.pyi index 8fb5942cd983..8e546ee1bfb0 100644 --- a/test-data/unit/lib-stub/mypy_extensions.pyi +++ b/test-data/unit/lib-stub/mypy_extensions.pyi @@ -43,3 +43,5 @@ def TypedDict(typename: str, fields: Dict[str, Type[_T]], *, total: Any = ...) - def trait(cls: Any) -> Any: ... class FlexibleAlias(Generic[_T, _U]): ... + +def Expand(type: _T) -> _T: ...