Skip to content
This repository has been archived by the owner on Jan 19, 2025. It is now read-only.

Commit

Permalink
fix(package-parser): missing public reexports
Browse files Browse the repository at this point in the history
  • Loading branch information
lars-reimann committed Jun 14, 2022
1 parent e660e08 commit 45d0573
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions package-parser/package_parser/processing/api/_ast_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@

class _AstVisitor:
def __init__(self, api: API) -> None:
self.reexported: dict[str, list[str]] = {}
# Key of dict is ID of declaration. Value is list of modules that re-export this declaration. For each module
# we store its ID and its qualified name.
self.reexported: dict[str, list[tuple[str, str, str]]] = {}
self.api: API = api
self.__declaration_stack: list[Union[Module, Class, Function]] = []

Expand Down Expand Up @@ -91,14 +93,15 @@ def enter_module(self, module_node: astroid.Module):
from_imports.append(FromImport(base_import_path, name, alias))

# Find re-exported declarations in __init__.py files
if _is_init_file(module_node.file) and is_public_module(module_node.qname()):
if _is_init_file(module_node.file):
for declaration, _ in global_node.names:
reexported_name = f"{base_import_path}.{declaration}"
reexported_id = f"{base_import_path}.{declaration}"

# if reexported_name.startswith(module_node.name):
if reexported_name not in self.reexported:
self.reexported[reexported_name] = []
self.reexported[reexported_name] += [id_]
if reexported_id not in self.reexported:
self.reexported[reexported_id] = []
self.reexported[reexported_id] += [(id_, module_node.qname())]


# Remember module, so we can later add classes and global functions
module = Module(
Expand Down Expand Up @@ -134,7 +137,7 @@ def enter_classdef(self, class_node: astroid.ClassDef) -> None:
decorator_names,
class_node.basenames,
self.is_public(class_node.name, qname),
self.reexported.get(qname, []),
self.reexported.get(qname, []), # TODO
_AstVisitor.__description(numpydoc),
class_node.doc,
)
Expand Down Expand Up @@ -164,6 +167,7 @@ def enter_functiondef(self, function_node: astroid.FunctionDef) -> None:

numpydoc = NumpyDocString(inspect.cleandoc(function_node.doc or ""))
is_public = self.is_public(function_node.name, qname)
reexports = self._transitive_hull_for_public_reexports(qname)

function = Function(
self.__get_function_id(function_node.name, decorator_names),
Expand All @@ -174,7 +178,7 @@ def enter_functiondef(self, function_node: astroid.FunctionDef) -> None:
),
[], # TODO: results
is_public,
self.reexported.get(qname, []),
,
_AstVisitor.__description(numpydoc),
function_node.doc,
)
Expand Down Expand Up @@ -316,16 +320,36 @@ def is_public(self, name: str, qualified_name: str) -> bool:
if name.startswith("_") and not name.endswith("__"):
return False

if qualified_name in self.reexported:
if self._is_publicly_reexported(qualified_name):
return True

# Containing class is re-exported (always false if the current API element is not a method)
if isinstance(self.__declaration_stack[-1], Class) and parent_qualified_name(qualified_name) in self.reexported:
# Containing class is reexported
if isinstance(self.__declaration_stack[-1], Class) and self._is_publicly_reexported(
parent_qualified_name(qualified_name)
):
return True

# The slicing is necessary so __init__ functions are not excluded (already handled in the first condition).
return all(not it.startswith("_") for it in qualified_name.split(".")[:-1])

def _is_publicly_reexported(self, id_: str) -> bool:
return len(self._transitive_hull_for_public_reexports(id_)) > 0

def _transitive_hull_for_public_reexports(self, id_: str) -> list[tuple[str, str]]:
return [
(id_, qualified_name)
for id_, qualified_name
in self._transitive_hull_for_reexports(id_)
if _is_public_module(qualified_name)
]

def _transitive_hull_for_reexports(self, id_: str) -> list[tuple[str, str]]:
result = []
for id_, id_ in self.reexported.get(id_, []):
result.append((id_, id_))
result.extend(self._transitive_hull_for_reexports(id_))
return result


def is_public_module(module_name: str) -> bool:
def _is_public_module(module_name: str) -> bool:
return all(not it.startswith("_") for it in module_name.split("."))

0 comments on commit 45d0573

Please sign in to comment.