diff --git a/mypy/stubgen.py b/mypy/stubgen.py index bd1dbeb54a08b..dca7cbe22668b 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -347,9 +347,7 @@ def get_base_types(self, cdef: ClassDef) -> List[str]: if base.name != 'object': base_types.append(base.name) elif isinstance(base, MemberExpr): - modname = get_qualified_name(base.expr) - base_types.append('%s.%s' % (modname, base.name)) - self.add_import_line('import %s\n' % modname) + base_types.append(get_qualified_name(base)) return base_types def visit_assignment_stmt(self, o: AssignmentStmt) -> None: @@ -437,8 +435,9 @@ def visit_import_from(self, o: ImportFrom) -> None: exported_names.update(sub_names) self.import_and_export_names(o.id, o.relative, sub_names) # Import names used as base classes. + base_class_imports = [base_class.split('.')[0] for base_class in self._base_classes] base_names = [(name, alias) for name, alias in o.names - if alias or name in self._base_classes and name not in exported_names] + if alias or name in base_class_imports and name not in exported_names] if base_names: imp_names = [] # type: List[str] for name, alias in base_names: @@ -468,6 +467,12 @@ def visit_import(self, o: Import) -> None: '.' not in id): self.add_import_line('import %s as %s\n' % (id, target_name)) self.record_name(target_name) + base_class_imports = [base_class.split('.')[0] for base_class in self._base_classes] + if target_name in base_class_imports: + if as_id: + self.add_import_line('import %s as %s\n' % (id, as_id)) + else: + self.add_import_line('import %s\n' % (id, )) def get_init(self, lvalue: str, rvalue: Expression) -> Optional[str]: """Return initializer for a variable. diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 2795bb34538a7..c70e4057f30ce 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -530,6 +530,14 @@ import x.y class D(x.y.C): ... +[case testArbitraryBaseClassWithAlias] +import x as y +class D(y.C): ... +[out] +import x as y + +class D(y.C): ... + [case testUnqualifiedArbitraryBaseClassWithNoDef] class A(int): ... [out] @@ -628,5 +636,23 @@ class A: x = ... # type: Any def __init__(self, a: Optional[Any] = ...) -> None: ... +[case testImportAddedForQualifiedBaseClass] +from foo import bar + +class A(bar.fuzz.Baz): ... +[out] +from foo import bar + +class A(bar.fuzz.Baz): ... + +[case testImportAddedForQualifiedBaseClassWithAlias] +from foo import bar as baz + +class A(baz.Baz): ... +[out] +from foo import bar as baz + +class A(baz.Baz): ... + -- More features/fixes: -- do not export deleted names