Skip to content

Commit

Permalink
fix: Fix logic for skipping already encountered modules when scanning…
Browse files Browse the repository at this point in the history
… namespace packages

Issue mkdocstrings#646: mkdocstrings/mkdocstrings#646
  • Loading branch information
pawamoy committed Jan 18, 2024
1 parent 3d66d0e commit 21a48d0
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 13 deletions.
29 changes: 18 additions & 11 deletions src/griffe/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from collections import defaultdict
from contextlib import suppress
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar, Iterator, Sequence, Tuple

Expand Down Expand Up @@ -235,8 +236,6 @@ def iter_submodules(
self,
path: Path | list[Path],
seen: set | None = None,
*,
additional: bool = True,
) -> Iterator[NamePartsAndPathType]:
"""Iterate on a module's submodules, if any.
Expand All @@ -255,11 +254,11 @@ def iter_submodules(
filepath (Path): A submodule filepath.
"""
if isinstance(path, list):
seen = seen if seen is not None else set()
# We never enter this condition again in recursive calls,
# so we just have to set `seen` once regardless of its value.
seen = set()
for path_elem in path:
if path_elem not in seen:
seen.add(path_elem)
yield from self.iter_submodules(path_elem, seen, additional=additional)
yield from self.iter_submodules(path_elem, seen)
return

if path.stem == "__init__":
Expand All @@ -269,7 +268,12 @@ def iter_submodules(
elif path.suffix in self.extensions_set:
return

skip = set(seen) if seen else set()
# `seen` is only set when we scan a list of paths (namespace package).
# `skip` is used to prevent yielding modules
# of a regular subpackage that we already yielded
# from another part of the namespace.
skip = set(seen or ())

for subpath in self._filter_py_modules(path):
rel_subpath = subpath.relative_to(path)
if rel_subpath.parent in skip:
Expand All @@ -294,9 +298,6 @@ def iter_submodules(
else:
yield rel_subpath.with_name(stem).parts, subpath

if additional:
yield from self.iter_submodules(self._always_scan_for[path.stem], seen=seen, additional=False)

def submodules(self, module: Module) -> list[NamePartsAndPathType]:
"""Return the list of a module's submodules.
Expand All @@ -306,7 +307,13 @@ def submodules(self, module: Module) -> list[NamePartsAndPathType]:
Returns:
A list of tuples containing the parts of the submodule name and its path.
"""
return sorted(self.iter_submodules(module.filepath), key=_module_depth)
return sorted(
chain(
self.iter_submodules(module.filepath),
self.iter_submodules(self._always_scan_for[module.name]),
),
key=_module_depth,
)

def _module_name_path(self, path: Path) -> tuple[str, Path]:
if path.is_dir():
Expand Down
7 changes: 5 additions & 2 deletions src/griffe/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def temporary_pypackage(
modules: Sequence[str] | Mapping[str, str] | None = None,
*,
init: bool = True,
inits: bool = True,
) -> Iterator[TmpPackage]:
"""Create a package containing the given modules in a temporary directory.
Expand All @@ -72,7 +73,8 @@ def temporary_pypackage(
If a list, simply touch the files: `["b.py", "c/d.py", "e/f"]`.
If a dict, keys are the file names and values their contents:
`{"b.py": "b = 1", "c/d.py": "print('hey from c')"}`.
init: Whether to create an `__init__` module in the leaf package.
init: Whether to create an `__init__` module in the top package.
inits: Whether to create `__init__` modules in subpackages.
Yields:
A temporary package.
Expand All @@ -96,7 +98,8 @@ def temporary_pypackage(
else:
current_path /= part
current_path.mkdir(**mkdir_kwargs)
current_path.joinpath("__init__.py").touch()
if inits:
current_path.joinpath("__init__.py").touch()
yield TmpPackage(tmpdirpath, package_name, package_path)


Expand Down
32 changes: 32 additions & 0 deletions tests/test_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from __future__ import annotations

import os
from pathlib import Path
from textwrap import dedent

import pytest

from griffe.dataclasses import Module
from griffe.finder import ModuleFinder, NamespacePackage, Package, _handle_editable_module, _handle_pth_file
from griffe.tests import temporary_pypackage

Expand Down Expand Up @@ -268,3 +270,33 @@ def test_finding_stubs_packages(
assert result.path.suffix == ".pyi"
assert result.path.parent.name.endswith("-stubs")
assert result.stubs is None


@pytest.mark.parametrize("namespace_package", [False, True])
def test_scanning_package_and_module_with_same_names(namespace_package: bool) -> None:
"""The finder correctly scans package and module having same the name.
Parameters:
namespace_package: Whether the temporary package is a namespace one.
"""
init = not namespace_package
with temporary_pypackage("pkg", ["pkg/mod.py", "mod/mod.py"], init=init, inits=init) as tmp_package:
# Here we must make sure that all paths are relative
# to correctly assert the finder's behavior,
# so we pass `.` and actually enter the temporary directory.
path = Path(tmp_package.name)
filepath: Path | list[Path] = [path] if namespace_package else path
old = os.getcwd()
os.chdir(tmp_package.path.parent)
try:
finder = ModuleFinder(search_paths=[])
found = [path for _, path in finder.submodules(Module("pkg", filepath=filepath))]
finally:
os.chdir(old)
check = (
path / "pkg/mod.py",
path / "mod/mod.py",
)
for mod in check:
assert mod in found

0 comments on commit 21a48d0

Please sign in to comment.