From a7915af510ff2f37019e85d42ee46db92a56f47d Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Sun, 30 Oct 2022 18:24:31 -0700 Subject: [PATCH] Better story for import redefinitions This changes our importing logic to be more consistent and to treat import statements more like assignments. Fixes #13803 The primary motivation for this is when typing modules as protocols, as in the linked issue. But it turns out we already allowed redefinition with from imports, so this just seems like a nice consistency win. We move shared logic from visit_import_all and visit_import_from (via process_imported_symbol) into add_imported_symbol. We then reuse it in visit_import. To simplify stuff, we inline the code from add_module_symbol into visit_import. Then we copy over logic from add_symbol, because MypyFile is not a SymbolTableNode, but this isn't the worst thing ever. Finally, we now need to check non-from import statements like assignments, which was a thing we weren't doing earlier. --- mypy/checker.py | 4 +- mypy/semanal.py | 84 +++++++++++++-------------- test-data/unit/check-classes.test | 3 +- test-data/unit/check-incremental.test | 5 +- test-data/unit/check-protocols.test | 56 ++++++++++++++++++ test-data/unit/check-redefine.test | 2 +- 6 files changed, 100 insertions(+), 54 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index d66bf764b517..3d698bc6bba5 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2528,8 +2528,8 @@ def visit_import_from(self, node: ImportFrom) -> None: def visit_import_all(self, node: ImportAll) -> None: self.check_import(node) - def visit_import(self, s: Import) -> None: - pass + def visit_import(self, node: Import) -> None: + self.check_import(node) def check_import(self, node: ImportBase) -> None: for assign in node.assignments: diff --git a/mypy/semanal.py b/mypy/semanal.py index b8f708b22a92..af923ede1c07 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -2235,13 +2235,33 @@ def visit_import(self, i: Import) -> None: base_id = id.split(".")[0] imported_id = base_id module_public = use_implicit_reexport - self.add_module_symbol( - base_id, - imported_id, - context=i, - module_public=module_public, - module_hidden=not module_public, - ) + + if base_id in self.modules: + node = self.modules[base_id] + if self.is_func_scope(): + kind = LDEF + elif self.type is not None: + kind = MDEF + else: + kind = GDEF + symbol = SymbolTableNode( + kind, node, module_public=module_public, module_hidden=not module_public + ) + self.add_imported_symbol( + imported_id, + symbol, + context=i, + module_public=module_public, + module_hidden=not module_public, + ) + else: + self.add_unknown_imported_symbol( + imported_id, + context=i, + target_name=base_id, + module_public=module_public, + module_hidden=not module_public, + ) def visit_import_from(self, imp: ImportFrom) -> None: self.statement = imp @@ -2367,19 +2387,6 @@ def process_imported_symbol( module_hidden=module_hidden, becomes_typeinfo=True, ) - existing_symbol = self.globals.get(imported_id) - if ( - existing_symbol - and not isinstance(existing_symbol.node, PlaceholderNode) - and not isinstance(node.node, PlaceholderNode) - ): - # Import can redefine a variable. They get special treatment. - if self.process_import_over_existing_name(imported_id, existing_symbol, node, context): - return - if existing_symbol and isinstance(node.node, PlaceholderNode): - # Imports are special, some redefinitions are allowed, so wait until - # we know what is the new symbol node. - return # NOTE: we take the original node even for final `Var`s. This is to support # a common pattern when constants are re-exported (same applies to import *). self.add_imported_symbol( @@ -2495,14 +2502,9 @@ def visit_import_all(self, i: ImportAll) -> None: if isinstance(node.node, MypyFile): # Star import of submodule from a package, add it as a dependency. self.imports.add(node.node.fullname) - existing_symbol = self.lookup_current_scope(name) - if existing_symbol and not isinstance(node.node, PlaceholderNode): - # Import can redefine a variable. They get special treatment. - if self.process_import_over_existing_name(name, existing_symbol, node, i): - continue # `from x import *` always reexports symbols self.add_imported_symbol( - name, node, i, module_public=True, module_hidden=False + name, node, context=i, module_public=True, module_hidden=False ) else: @@ -5577,24 +5579,6 @@ def add_local(self, node: Var | FuncDef | OverloadedFuncDef, context: Context) - node._fullname = name self.add_symbol(name, node, context) - def add_module_symbol( - self, id: str, as_id: str, context: Context, module_public: bool, module_hidden: bool - ) -> None: - """Add symbol that is a reference to a module object.""" - if id in self.modules: - node = self.modules[id] - self.add_symbol( - as_id, node, context, module_public=module_public, module_hidden=module_hidden - ) - else: - self.add_unknown_imported_symbol( - as_id, - context, - target_name=id, - module_public=module_public, - module_hidden=module_hidden, - ) - def _get_node_for_class_scoped_import( self, name: str, symbol_node: SymbolNode | None, context: Context ) -> SymbolNode | None: @@ -5641,13 +5625,23 @@ def add_imported_symbol( self, name: str, node: SymbolTableNode, - context: Context, + context: ImportBase, module_public: bool, module_hidden: bool, ) -> None: """Add an alias to an existing symbol through import.""" assert not module_hidden or not module_public + existing_symbol = self.lookup_current_scope(name) + if ( + existing_symbol + and not isinstance(existing_symbol.node, PlaceholderNode) + and not isinstance(node.node, PlaceholderNode) + ): + # Import can redefine a variable. They get special treatment. + if self.process_import_over_existing_name(name, existing_symbol, node, context): + return + symbol_node: SymbolNode | None = node.node if self.is_class_scope(): diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index b16387f194d4..4ed44c90f275 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -7414,8 +7414,7 @@ class Foo: def meth1(self, a: str) -> str: ... # E: Name "meth1" already defined on line 5 def meth2(self, a: str) -> str: ... - from mod1 import meth2 # E: Unsupported class scoped import \ - # E: Name "meth2" already defined on line 8 + from mod1 import meth2 # E: Incompatible import of "meth2" (imported name has type "Callable[[int], int]", local name has type "Callable[[Foo, str], str]") class Bar: from mod1 import foo # E: Unsupported class scoped import diff --git a/test-data/unit/check-incremental.test b/test-data/unit/check-incremental.test index d4e6779403b4..3ec0ed2c63f5 100644 --- a/test-data/unit/check-incremental.test +++ b/test-data/unit/check-incremental.test @@ -1025,10 +1025,7 @@ import a.b [file a/b.py] -[rechecked b] -[stale] -[out2] -tmp/b.py:4: error: Name "a" already defined on line 3 +[stale b] [case testIncrementalSilentImportsAndImportsInClass] # flags: --ignore-missing-imports diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index 8cdfd2a3e0d9..113b2000fc22 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -3787,3 +3787,59 @@ from typing_extensions import Final a: Final = 1 [builtins fixtures/module.pyi] + + +[case testModuleAsProtocolRedefinitionTopLevel] +from typing import Protocol + +class P(Protocol): + def f(self) -> str: ... + +cond: bool +t: P +if cond: + import mod1 as t +else: + import mod2 as t + +import badmod as t # E: Incompatible import of "t" (imported name has type Module, local name has type "P") + +[file mod1.py] +def f() -> str: ... + +[file mod2.py] +def f() -> str: ... + +[file badmod.py] +def nothing() -> int: ... +[builtins fixtures/module.pyi] + +[case testModuleAsProtocolRedefinitionImportFrom] +from typing import Protocol + +class P(Protocol): + def f(self) -> str: ... + +cond: bool +t: P +if cond: + from package import mod1 as t +else: + from package import mod2 as t + +from package import badmod as t # E: Incompatible import of "t" (imported name has type Module, local name has type "P") + +package: int = 10 + +import package.mod1 as t +import package.mod1 # E: Incompatible import of "package" (imported name has type Module, local name has type "int") + +[file package/mod1.py] +def f() -> str: ... + +[file package/mod2.py] +def f() -> str: ... + +[file package/badmod.py] +def nothing() -> int: ... +[builtins fixtures/module.pyi] diff --git a/test-data/unit/check-redefine.test b/test-data/unit/check-redefine.test index e73f715c9ec0..e3f1b976d4e9 100644 --- a/test-data/unit/check-redefine.test +++ b/test-data/unit/check-redefine.test @@ -285,7 +285,7 @@ def f() -> None: import typing as m m = 1 # E: Incompatible types in assignment (expression has type "int", variable has type Module) n = 1 - import typing as n # E: Name "n" already defined on line 5 + import typing as n # E: Incompatible import of "n" (imported name has type Module, local name has type "int") [builtins fixtures/module.pyi] [case testRedefineLocalWithTypeAnnotation]