diff --git a/mypy/stubtest.py b/mypy/stubtest.py index 41b58cbbb636..949c73867f01 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -33,6 +33,9 @@ from typing_extensions import get_origin, is_typeddict import mypy.build +import mypy.checkexpr +import mypy.checkmember +import mypy.erasetype import mypy.modulefinder import mypy.nodes import mypy.state @@ -680,7 +683,11 @@ def _verify_arg_default_value( "has a default value but stub argument does not" ) else: - runtime_type = get_mypy_type_of_runtime_value(runtime_arg.default) + type_context = stub_arg.variable.type + runtime_type = get_mypy_type_of_runtime_value( + runtime_arg.default, type_context=type_context + ) + # Fallback to the type annotation type if var type is missing. The type annotation # is an UnboundType, but I don't know enough to know what the pros and cons here are. # UnboundTypes have ugly question marks following them, so default to var type. @@ -1115,7 +1122,7 @@ def verify_var( ): yield Error(object_path, "is read-only at runtime but not in the stub", stub, runtime) - runtime_type = get_mypy_type_of_runtime_value(runtime) + runtime_type = get_mypy_type_of_runtime_value(runtime, type_context=stub.type) if ( runtime_type is not None and stub.type is not None @@ -1612,7 +1619,18 @@ def is_subtype_helper(left: mypy.types.Type, right: mypy.types.Type) -> bool: return mypy.subtypes.is_subtype(left, right) -def get_mypy_type_of_runtime_value(runtime: Any) -> mypy.types.Type | None: +def get_mypy_node_for_name(module: str, type_name: str) -> mypy.nodes.SymbolNode | None: + stub = get_stub(module) + if stub is None: + return None + if type_name not in stub.names: + return None + return stub.names[type_name].node + + +def get_mypy_type_of_runtime_value( + runtime: Any, type_context: mypy.types.Type | None = None +) -> mypy.types.Type | None: """Returns a mypy type object representing the type of ``runtime``. Returns None if we can't find something that works. @@ -1673,14 +1691,45 @@ def anytype() -> mypy.types.AnyType: is_ellipsis_args=True, ) - # Try and look up a stub for the runtime object - stub = get_stub(type(runtime).__module__) - if stub is None: - return None - type_name = type(runtime).__name__ - if type_name not in stub.names: + skip_type_object_type = False + if type_context: + # Don't attempt to process the type object when context is generic + # This is related to issue #3737 + type_context = mypy.types.get_proper_type(type_context) + # Callable types with a generic return value + if isinstance(type_context, mypy.types.CallableType): + if isinstance(type_context.ret_type, mypy.types.TypeVarType): + skip_type_object_type = True + # Type[x] where x is generic + if isinstance(type_context, mypy.types.TypeType): + if isinstance(type_context.item, mypy.types.TypeVarType): + skip_type_object_type = True + + if isinstance(runtime, type) and not skip_type_object_type: + + def _named_type(name: str) -> mypy.types.Instance: + parts = name.rsplit(".", maxsplit=1) + node = get_mypy_node_for_name(parts[0], parts[1]) + assert isinstance(node, nodes.TypeInfo) + any_type = mypy.types.AnyType(mypy.types.TypeOfAny.special_form) + return mypy.types.Instance(node, [any_type] * len(node.defn.type_vars)) + + # Try and look up a stub for the runtime object itself + # The logic here is similar to ExpressionChecker.analyze_ref_expr + type_info = get_mypy_node_for_name(runtime.__module__, runtime.__name__) + if isinstance(type_info, nodes.TypeInfo): + result: mypy.types.Type | None = None + result = mypy.checkmember.type_object_type(type_info, _named_type) + if mypy.checkexpr.is_type_type_context(type_context): + # This is the type in a type[] expression, so substitute type + # variables with Any. + result = mypy.erasetype.erase_typevars(result) + return result + + # Try and look up a stub for the runtime object's type + type_info = get_mypy_node_for_name(type(runtime).__module__, type(runtime).__name__) + if type_info is None: return None - type_info = stub.names[type_name].node if isinstance(type_info, nodes.Var): return type_info.type if not isinstance(type_info, nodes.TypeInfo): diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index 101b6f65c45a..303c5fc099b6 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -2436,6 +2436,31 @@ def func2() -> None: ... error="func2", ) + @collect_cases + def test_type_default_protocol(self) -> Iterator[Case]: + yield Case( + stub=""" + from typing import Protocol + + class _FormatterClass(Protocol): + def __call__(self, *, prog: str) -> HelpFormatter: ... + + class ArgumentParser: + def __init__(self, formatter_class: _FormatterClass = ...) -> None: ... + + class HelpFormatter: + def __init__(self, prog: str, indent_increment: int = 2) -> None: ... + """, + runtime=""" + class HelpFormatter: + def __init__(self, prog, indent_increment=2) -> None: ... + + class ArgumentParser: + def __init__(self, formatter_class=HelpFormatter): ... + """, + error=None, + ) + def remove_color_code(s: str) -> str: return re.sub("\\x1b.*?m", "", s) # this works!