Skip to content

Commit

Permalink
feat: Support wildcard imports
Browse files Browse the repository at this point in the history
  • Loading branch information
pawamoy committed Dec 31, 2021
1 parent 1446343 commit 77a3cb7
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/griffe/agents/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def handle_function(self, node: ObjectNode, labels: set | None = None): # noqa:
"""
try:
signature = getsignature(node.obj)
except (ValueError, TokenError):
except (ValueError, TokenError, TypeError):
parameters = None
returns = None
else:
Expand Down
27 changes: 27 additions & 0 deletions src/griffe/agents/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,33 @@ def _join(sequence, item):
return new_sequence


if sys.version_info < (3, 8):

def parse__all__(node: NodeAssign) -> set[str]: # noqa: WPS116,WPS120
"""Get the values declared in `__all__`.
Parameters:
node: The assignment node.
Returns:
A set of names.
"""
return {elt.s for elt in node.value.elts} # type: ignore[attr-defined]

else:

def parse__all__(node: NodeAssign) -> set[str]: # noqa: WPS116,WPS120,WPS440
"""Get the values declared in `__all__`.
Parameters:
node: The assignment node.
Returns:
A set of names.
"""
return {elt.value for elt in node.value.elts} # type: ignore[attr-defined]


# ==========================================================
# annotations
def _get_attribute_annotation(node: NodeAttribute, parent: Module | Class) -> Expression:
Expand Down
14 changes: 10 additions & 4 deletions src/griffe/agents/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
get_names,
get_parameter_default,
get_value,
parse__all__,
)
from griffe.collections import LinesCollection, ModulesCollection
from griffe.dataclasses import (
Expand Down Expand Up @@ -385,13 +386,18 @@ def visit_importfrom(self, node: ast.ImportFrom) -> None:
Parameters:
node: The node to visit.
"""
# TODO: does this handle relative imports?
for name in node.names:
alias_name = name.asname or name.name
alias_path = f"{node.module}.{name.name}"
self.current.imports[name.asname or name.name] = alias_path
if alias_name == "*":
alias_name = node.module.replace(".", "/") + "/*" # type: ignore[union-attr]
alias_path = node.module
else:
alias_path = f"{node.module}.{name.name}"
self.current.imports[alias_name] = alias_path
self.current[alias_name] = Alias(
alias_name,
alias_path,
alias_path, # type: ignore[arg-type]
lineno=node.lineno,
endlineno=node.end_lineno, # type: ignore[attr-defined]
)
Expand Down Expand Up @@ -468,7 +474,7 @@ def handle_attribute( # noqa: WPS231

if name == "__all__":
with suppress(AttributeError):
parent.exports = {elt.value for elt in node.value.elts} # type: ignore[union-attr]
parent.exports = parse__all__(node) # type: ignore[arg-type]

def visit_assign(self, node: ast.Assign) -> None:
"""Visit an assignment node.
Expand Down
6 changes: 6 additions & 0 deletions src/griffe/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ async def _load_packages_async(
logger.error(f"Tried but could not import package {package}")
else:
loaded[module.name] = module
for obj in loaded.values():
if not await loader.follow_aliases(obj):
logger.info("Not all aliases were resolved")
return loaded


Expand All @@ -89,6 +92,9 @@ def _load_packages(
logger.error(f"Tried but could not import package {package}")
else:
loaded[module.name] = module
for obj in loaded.values():
if not loader.follow_aliases(obj):
logger.info("Not all aliases were resolved")
return loaded


Expand Down
11 changes: 11 additions & 0 deletions src/griffe/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,17 @@ def resolved(self) -> bool:
"""
return self._target is not None

@cached_property
def wildcard(self) -> str | None:
"""Return the module on which the wildcard import is performed (if any).
Returns:
The wildcard imported module, or None.
"""
if self.name.endswith("/*"):
return self._target_path
return None

def as_dict(self, full: bool = False, **kwargs: Any) -> dict[str, Any]:
"""Return this alias' data as a dictionary.
Expand Down
87 changes: 74 additions & 13 deletions src/griffe/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from griffe.agents.inspector import inspect
from griffe.agents.visitor import patch_ast, visit
from griffe.collections import LinesCollection, ModulesCollection
from griffe.dataclasses import Module, Object
from griffe.dataclasses import Alias, Kind, Module, Object
from griffe.docstrings.parsers import Parser
from griffe.exceptions import AliasResolutionError, UnhandledPthFileError, UnimportableModuleError
from griffe.logger import get_logger
Expand Down Expand Up @@ -143,6 +143,15 @@ def _member_parent(self, module: Module, subparts: NamePartsType, subpath: Path)
return member_parent
raise UnimportableModuleError(f"{subpath} is not importable")

def _expand_wildcard(self, wildcard_obj: Alias) -> dict[str, Object | Alias]:
module = self.modules_collection[wildcard_obj.wildcard] # type: ignore[index] # we know it's a wildcard
explicitely = "__all__" in module.members
return {
name: imported_member
for name, imported_member in module.members.items()
if imported_member.is_exported(explicitely=explicitely)
}


class GriffeLoader(_BaseGriffeLoader):
"""The Griffe loader, allowing to load data from modules."""
Expand All @@ -168,6 +177,8 @@ def load_module(
module_name = module
top_module = self._inspect_module(module) # type: ignore[arg-type]
else:
# TODO: maybe don't try each time to find a relative path,
# to improve recursion when following aliases / expanding wildcards
try:
module_name, top_module_name, top_module_path = _top_name_and_path(module, search_paths)
except ModuleNotFoundError:
Expand All @@ -191,20 +202,45 @@ def follow_aliases(self, obj: Object, only_exported: bool = True) -> bool: # no
True if everything was resolved, False otherwise.
"""
success = True
expanded = {}
to_remove = []

# iterate a first time to expand wildcards
for member in obj.members.values():
if member.is_alias:
if only_exported and not obj.member_is_exported(member, explicitely=True):
if member.is_alias and member.wildcard: # type: ignore[union-attr] # we know it's an alias
package = member.wildcard.split(".", 1)[0] # type: ignore[union-attr]
if obj.package.path != package and package not in self.modules_collection:
try:
self.load_module(package)
except ImportError as error:
logger.warning(f"Could not expand wildcard import {member.name} in {obj.path}: {error}")
else:
expanded.update(self._expand_wildcard(member)) # type: ignore[arg-type]
to_remove.append(member.name)

for name in to_remove:
del obj[name] # noqa: WPS420
for new_member in expanded.values():
obj[new_member.name] = Alias(new_member.name, new_member)

# iterate a second time to resolve aliases and recurse
for member in obj.members.values(): # noqa: WPS440
if member.is_alias and not member.wildcard: # type: ignore[union-attr]
if only_exported and not member.is_explicitely_exported:
continue
try:
member.resolve_target() # type: ignore[union-attr] # we know it's an alias
except AliasResolutionError as error:
member.resolve_target() # type: ignore[union-attr]
except AliasResolutionError as error: # noqa: WPS440
success = False
package = error.target_path.split(".", 1)[0]
if obj.package.path != package and package not in self.modules_collection:
with suppress(ModuleNotFoundError):
try: # noqa: WPS505
self.load_module(package)
else:
except ImportError as error: # noqa: WPS440
logger.warning(f"Could not follow alias {member.path}: {error}")
elif member.kind in {Kind.MODULE, Kind.CLASS}:
success &= self.follow_aliases(member) # type: ignore[arg-type] # we know it's an object

return success

def _load_module_path(
Expand Down Expand Up @@ -292,20 +328,45 @@ async def follow_aliases(self, obj: Object, only_exported: bool = True) -> bool:
True if everything was resolved, False otherwise.
"""
success = True
expanded = {}
to_remove = []

# iterate a first time to expand wildcards
for member in obj.members.values():
if member.is_alias:
if only_exported and not obj.member_is_exported(member, explicitely=True):
if member.is_alias and member.wildcard: # type: ignore[union-attr] # we know it's an alias
package = member.wildcard.split(".", 1)[0] # type: ignore[union-attr]
if obj.package.path != package and package not in self.modules_collection:
try:
await self.load_module(package)
except ImportError as error:
logger.warning(f"Could not expand wildcard import {member.name} in {obj.path}: {error}")
else:
expanded.update(self._expand_wildcard(member)) # type: ignore[arg-type]
to_remove.append(member.name)

for name in to_remove:
del obj[name] # noqa: WPS420
for new_member in expanded.values():
obj[new_member.name] = Alias(new_member.name, new_member)

# iterate a second time to resolve aliases and recurse
for member in obj.members.values(): # noqa: WPS440
if member.is_alias and not member.wildcard: # type: ignore[union-attr]
if only_exported and not member.is_explicitely_exported:
continue
try:
member.resolve_target() # type: ignore[union-attr] # we know it's an alias
except AliasResolutionError as error:
member.resolve_target() # type: ignore[union-attr]
except AliasResolutionError as error: # noqa: WPS440
success = False
package = error.target_path.split(".", 1)[0]
if obj.package.path != package and package not in self.modules_collection:
with suppress(ModuleNotFoundError):
try: # noqa: WPS505
await self.load_module(package)
else:
except ImportError as error: # noqa: WPS440
logger.warning(f"Could not follow alias {member.path}: {error}")
elif member.kind in {Kind.MODULE, Kind.CLASS}:
success &= await self.follow_aliases(member) # type: ignore[arg-type] # we know it's an object

return success

async def _load_module_path(
Expand Down
18 changes: 15 additions & 3 deletions src/griffe/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,27 @@ def _get_parts(key: str | Sequence[str]) -> Sequence[str]:
return parts


class SetMembersMixin:
class DelMembersMixin:
"""This mixin adds a `__delitem__` method to a class."""

def __delitem__(self, key: str | Sequence[str]) -> None: # noqa: WPS603
parts = _get_parts(key)
if len(parts) == 1:
name = parts[0]
del self.members[name] # type: ignore[attr-defined] # noqa: WPS420
else:
del self.members[parts[0]][parts[1]] # type: ignore[attr-defined] # noqa: WPS420


class SetMembersMixin(DelMembersMixin):
"""This mixin adds a `__setitem__` method to a class.
It makes it easier to set members of an object.
The method expects a `members` attribute/property to be available on the instance.
Each time a member is set, its `parent` attribute is set as well.
"""

def __setitem__(self, key: str | Sequence[str], value):
def __setitem__(self, key: str | Sequence[str], value) -> None:
parts = _get_parts(key)
if len(parts) == 1:
name = parts[0]
Expand All @@ -61,7 +73,7 @@ def __setitem__(self, key: str | Sequence[str], value):
self.members[parts[0]][parts[1]] = value # type: ignore[attr-defined]


class SetCollectionMembersMixin:
class SetCollectionMembersMixin(DelMembersMixin):
"""This mixin adds a `__setitem__` method to a class.
It makes it easier to set members of an object.
Expand Down

0 comments on commit 77a3cb7

Please sign in to comment.