diff --git a/mypy/stubtest.py b/mypy/stubtest.py index a7a72235fed1..5946324d4619 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -15,6 +15,7 @@ import os import pkgutil import re +import symtable import sys import traceback import types @@ -283,6 +284,36 @@ def _verify_exported_names( ) +def _get_imported_symbol_names(runtime: types.ModuleType) -> frozenset[str] | None: + """Retrieve the names in the global namespace which are known to be imported. + + 1). Use inspect to retrieve the source code of the module + 2). Use symtable to parse the source and retrieve names that are known to be imported + from other modules. + + If either of the above steps fails, return `None`. + + Note that if a set of names is returned, + it won't include names imported via `from foo import *` imports. + """ + try: + source = inspect.getsource(runtime) + except (OSError, TypeError, SyntaxError): + return None + + if not source.strip(): + # The source code for the module was an empty file, + # no point in parsing it with symtable + return frozenset() + + try: + module_symtable = symtable.symtable(source, runtime.__name__, "exec") + except SyntaxError: + return None + + return frozenset(sym.get_name() for sym in module_symtable.get_symbols() if sym.is_imported()) + + @verify.register(nodes.MypyFile) def verify_mypyfile( stub: nodes.MypyFile, runtime: MaybeMissing[types.ModuleType], object_path: list[str] @@ -312,15 +343,22 @@ def verify_mypyfile( if not o.module_hidden and (not is_probably_private(m) or hasattr(runtime, m)) } + imported_symbols = _get_imported_symbol_names(runtime) + def _belongs_to_runtime(r: types.ModuleType, attr: str) -> bool: obj = getattr(r, attr) - try: - obj_mod = getattr(obj, "__module__", None) - except Exception: + if isinstance(obj, types.ModuleType): return False - if obj_mod is not None: - return bool(obj_mod == r.__name__) - return not isinstance(obj, types.ModuleType) + if callable(obj): + try: + obj_mod = getattr(obj, "__module__", None) + except Exception: + return False + if obj_mod is not None: + return bool(obj_mod == r.__name__) + if imported_symbols is not None: + return attr not in imported_symbols + return True runtime_public_contents = ( runtime_all_as_set diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index 812333e3feb4..5e59d8efec63 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -1082,6 +1082,9 @@ def test_missing_no_runtime_all(self) -> Iterator[Case]: yield Case(stub="", runtime="import sys", error=None) yield Case(stub="", runtime="def g(): ...", error="g") yield Case(stub="", runtime="CONSTANT = 0", error="CONSTANT") + yield Case(stub="", runtime="import re; constant = re.compile('foo')", error="constant") + yield Case(stub="", runtime="from json.scanner import NUMBER_RE", error=None) + yield Case(stub="", runtime="from string import ascii_letters", error=None) @collect_cases def test_non_public_1(self) -> Iterator[Case]: