Skip to content

Commit

Permalink
stubtest: adjust symtable logic (#16823)
Browse files Browse the repository at this point in the history
  • Loading branch information
hauntsaninja authored Jan 27, 2024
1 parent 3838bff commit 717a263
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 29 deletions.
59 changes: 30 additions & 29 deletions mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import collections.abc
import copy
import enum
import functools
import importlib
import importlib.machinery
import inspect
Expand Down Expand Up @@ -310,35 +311,23 @@ def _verify_exported_names(
)


def _get_imported_symbol_names(runtime: types.ModuleType) -> frozenset[str] | None:
"""Retrieve the names in the global namespace which are known to be imported.
@functools.lru_cache
def _module_symbol_table(runtime: types.ModuleType) -> symtable.SymbolTable | None:
"""Retrieve the symbol table for the module (or None on failure).
1). Use inspect to retrieve the source code of the module
2). Use symtable to parse the source and retrieve names that are known to be imported
from other modules.
If either of the above steps fails, return `None`.
Note that if a set of names is returned,
it won't include names imported via `from foo import *` imports.
1) Use inspect to retrieve the source code of the module
2) Use symtable to parse the source (and use what symtable knows for its purposes)
"""
try:
source = inspect.getsource(runtime)
except (OSError, TypeError, SyntaxError):
return None

if not source.strip():
# The source code for the module was an empty file,
# no point in parsing it with symtable
return frozenset()

try:
module_symtable = symtable.symtable(source, runtime.__name__, "exec")
return symtable.symtable(source, runtime.__name__, "exec")
except SyntaxError:
return None

return frozenset(sym.get_name() for sym in module_symtable.get_symbols() if sym.is_imported())


@verify.register(nodes.MypyFile)
def verify_mypyfile(
Expand Down Expand Up @@ -369,25 +358,37 @@ def verify_mypyfile(
if not o.module_hidden and (not is_probably_private(m) or hasattr(runtime, m))
}

imported_symbols = _get_imported_symbol_names(runtime)

def _belongs_to_runtime(r: types.ModuleType, attr: str) -> bool:
"""Heuristics to determine whether a name originates from another module."""
obj = getattr(r, attr)
if isinstance(obj, types.ModuleType):
return False
if callable(obj):
# It's highly likely to be a class or a function if it's callable,
# so the __module__ attribute will give a good indication of which module it comes from

symbol_table = _module_symbol_table(r)
if symbol_table is not None:
try:
obj_mod = obj.__module__
except Exception:
symbol = symbol_table.lookup(attr)
except KeyError:
pass
else:
if isinstance(obj_mod, str):
return bool(obj_mod == r.__name__)
if imported_symbols is not None:
return attr not in imported_symbols
if symbol.is_imported():
# symtable says we got this from another module
return False
# But we can't just return True here, because symtable doesn't know about symbols
# that come from `from module import *`
if symbol.is_assigned():
# symtable knows we assigned this symbol in the module
return True

# The __module__ attribute is unreliable for anything except functions and classes,
# but it's our best guess at this point
try:
obj_mod = obj.__module__
except Exception:
pass
else:
if isinstance(obj_mod, str):
return bool(obj_mod == r.__name__)
return True

runtime_public_contents = (
Expand Down
18 changes: 18 additions & 0 deletions mypy/test/teststubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,24 @@ def test_missing_no_runtime_all(self) -> Iterator[Case]:
yield Case(stub="", runtime="from json.scanner import NUMBER_RE", error=None)
yield Case(stub="", runtime="from string import ascii_letters", error=None)

@collect_cases
def test_missing_no_runtime_all_terrible(self) -> Iterator[Case]:
yield Case(
stub="",
runtime="""
import sys
import types
import __future__
_m = types.SimpleNamespace()
_m.annotations = __future__.annotations
sys.modules["_terrible_stubtest_test_module"] = _m
from _terrible_stubtest_test_module import *
assert annotations
""",
error=None,
)

@collect_cases
def test_non_public_1(self) -> Iterator[Case]:
yield Case(
Expand Down

0 comments on commit 717a263

Please sign in to comment.