Skip to content

Commit

Permalink
refactor: Don't always try to find a module as a relative path
Browse files Browse the repository at this point in the history
  • Loading branch information
pawamoy committed Jan 2, 2022
1 parent a498db0 commit e6df277
Showing 1 changed file with 26 additions and 11 deletions.
37 changes: 26 additions & 11 deletions src/griffe/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,15 @@ def load_module(
module: str | Path,
submodules: bool = True,
search_paths: Sequence[str | Path] | None = None,
try_relative_path: bool = True,
) -> Module:
"""Load a module.
Parameters:
module: The module name or path.
submodules: Whether to recurse on the submodules.
search_paths: The paths to search into.
try_relative_path: Whether to try finding the module as a relative path.
Returns:
A module.
Expand All @@ -177,10 +179,10 @@ def load_module(
module_name = module
top_module = self._inspect_module(module) # type: ignore[arg-type]
else:
# TODO: maybe don't try each time to find a relative path,
# to improve recursion when following aliases / expanding wildcards
try:
module_name, top_module_name, top_module_path = _top_name_and_path(module, search_paths)
module_name, top_module_name, top_module_path = _top_name_and_path(
module, search_paths, try_relative_path
)
except ModuleNotFoundError:
logger.debug(f"Could not find {module}: trying inspection")
module_name = module
Expand Down Expand Up @@ -211,7 +213,7 @@ def follow_aliases(self, obj: Object, only_exported: bool = True) -> bool: # no
package = member.wildcard.split(".", 1)[0] # type: ignore[union-attr]
if obj.package.path != package and package not in self.modules_collection:
try:
self.load_module(package)
self.load_module(package, try_relative_path=False)
except ImportError as error:
logger.warning(f"Could not expand wildcard import {member.name} in {obj.path}: {error}")
else:
Expand All @@ -235,7 +237,7 @@ def follow_aliases(self, obj: Object, only_exported: bool = True) -> bool: # no
package = error.target_path.split(".", 1)[0]
if obj.package.path != package and package not in self.modules_collection:
try: # noqa: WPS505
self.load_module(package)
self.load_module(package, try_relative_path=False)
except ImportError as error: # noqa: WPS440
logger.warning(f"Could not follow alias {member.path}: {error}")
elif member.kind in {Kind.MODULE, Kind.CLASS}:
Expand Down Expand Up @@ -289,13 +291,15 @@ async def load_module(
module: str | Path,
submodules: bool = True,
search_paths: Sequence[str | Path] | None = None,
try_relative_path: bool = True,
) -> Module:
"""Load a module.
Parameters:
module: The module name or path.
submodules: Whether to recurse on the submodules.
search_paths: The paths to search into.
try_relative_path: Whether to try finding the module as a relative path.
Returns:
A module.
Expand All @@ -306,7 +310,9 @@ async def load_module(
top_module = self._inspect_module(module) # type: ignore[arg-type]
else:
try:
module_name, top_module_name, top_module_path = _top_name_and_path(module, search_paths)
module_name, top_module_name, top_module_path = _top_name_and_path(
module, search_paths, try_relative_path
)
except ModuleNotFoundError:
logger.debug(f"Could not find {module}: trying inspection")
module_name = module
Expand Down Expand Up @@ -337,7 +343,7 @@ async def follow_aliases(self, obj: Object, only_exported: bool = True) -> bool:
package = member.wildcard.split(".", 1)[0] # type: ignore[union-attr]
if obj.package.path != package and package not in self.modules_collection:
try:
await self.load_module(package)
await self.load_module(package, try_relative_path=False)
except ImportError as error:
logger.warning(f"Could not expand wildcard import {member.name} in {obj.path}: {error}")
else:
Expand All @@ -361,7 +367,7 @@ async def follow_aliases(self, obj: Object, only_exported: bool = True) -> bool:
package = error.target_path.split(".", 1)[0]
if obj.package.path != package and package not in self.modules_collection:
try: # noqa: WPS505
await self.load_module(package)
await self.load_module(package, try_relative_path=False)
except ImportError as error: # noqa: WPS440
logger.warning(f"Could not follow alias {member.path}: {error}")
elif member.kind in {Kind.MODULE, Kind.CLASS}:
Expand Down Expand Up @@ -410,8 +416,9 @@ async def _load_submodule(self, module: Module, subparts: NamePartsType, subpath
def _top_name_and_path(
module: str | Path,
search_paths: Sequence[str | Path] | None = None,
try_relative_path: bool = True,
) -> tuple[str, str, Path]:
module_name, module_path = find_module_or_path(module, search_paths)
module_name, module_path = find_module_or_path(module, search_paths, try_relative_path)
module_parts = module_name.split(".")
top_module_name = module_parts[0]
top_module_path = module_path
Expand All @@ -423,6 +430,7 @@ def _top_name_and_path(
def find_module_or_path(
module: str | Path,
search_paths: Sequence[str | Path] | None = None,
try_relative_path: bool = True,
) -> tuple[str, Path]:
"""Find the name and path of a module.
Expand All @@ -433,6 +441,8 @@ def find_module_or_path(
Parameters:
module: The module name or path.
search_paths: The paths to search into.
try_relative_path: Whether to try finding the module as a relative path,
when the given module is not already a path.
Raises:
FileNotFoundError: When a Path was passed and the module could not be found:
Expand All @@ -454,13 +464,16 @@ def find_module_or_path(
if isinstance(module, Path):
# programatically passed a Path, try only that
module_name, module_path = _module_name_path(module)
else:
elif try_relative_path:
# passed a string (from CLI or Python code), try both
try:
module_name, module_path = _module_name_path(Path(module))
except FileNotFoundError:
module_name = module
module_path = find_module(module_name, search_paths=search_paths)
else:
module_name = module
module_path = find_module(module_name, search_paths=search_paths)
return module_name, module_path


Expand Down Expand Up @@ -506,7 +519,7 @@ def find_module(module_name: str, search_paths: Sequence[str | Path] | None = No
if abs_top_pth.exists():
with suppress(UnhandledPthFileError):
location = _handle_pth_file(abs_top_pth)
if location.suffix == ".py":
if location.suffix:
location = location.parent
search = [location.parent]
# TODO: possible optimization
Expand All @@ -515,6 +528,7 @@ def find_module(module_name: str, search_paths: Sequence[str | Path] | None = No

# resume regular search
filepaths = [
# TODO: handle .py[cod] and .so files?
Path(*parts, "__init__.py"),
Path(*parts[:-1], f"{parts[-1]}.py"),
Path(*parts[:-1], f"{parts[-1]}.pth"),
Expand Down Expand Up @@ -542,6 +556,7 @@ def _handle_pth_file(path):
instructions = path.read_text().strip("\n").split(";")

filepaths = [
# TODO: handle .py[cod] and .so files?
Path(instructions[0], path.stem, "__init__.py"),
Path(instructions[0], path.stem), # namespace packages, try last
]
Expand Down

0 comments on commit e6df277

Please sign in to comment.