Skip to content

Commit

Permalink
Merge pull request #70 from aai-institute/fix-collection
Browse files Browse the repository at this point in the history
Fix runner collection by checking file path first
  • Loading branch information
nicholasjng authored Feb 8, 2024
2 parents ff0dd94 + 4056143 commit bf4dd5e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 19 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
types_or: [ python, pyi ]
args: [--ignore-missing-imports, --scripts-are-modules]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.0
rev: v0.2.1
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix ]
Expand All @@ -29,6 +29,6 @@ repos:
args: [-c, pyproject.toml]
additional_dependencies: ["bandit[toml]"]
- repo: https://github.com/jsh9/pydoclint
rev: 0.3.9
rev: 0.4.0
hooks:
- id: pydoclint
4 changes: 2 additions & 2 deletions docs/guides/benchmarks.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ from nnbench.reporter import ConsoleReporter
r = nnbench.BenchmarkRunner()
reporter = ConsoleReporter()

result = r.run("./benchmarks.py", params={"n_estimators": 100, "max_depth": 5, "random_state": 42})
result = r.run("benchmarks.py", params={"n_estimators": 100, "max_depth": 5, "random_state": 42})
reporter.report(result)
```

Expand Down Expand Up @@ -127,7 +127,7 @@ from nnbench.reporter import ConsoleReporter
r = nnbench.BenchmarkRunner()
reporter = ConsoleReporter()

result = r.run("./benchmarks.py", params={"random_state": 42})
result = r.run("benchmarks.py", params={"random_state": 42})
reporter.report(result)
```

Expand Down
29 changes: 14 additions & 15 deletions src/nnbench/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,23 +144,22 @@ def collect(self, path_or_module: str | os.PathLike[str], tags: tuple[str, ...]
ValueError
If the given path is not a Python file, directory, or module name.
"""
if ismodule(path_or_module):
ppath = Path(path_or_module)
if ppath.is_dir():
pythonpaths = (p for p in ppath.iterdir() if p.suffix == ".py")
for py in pythonpaths:
logger.debug(f"Collecting benchmarks from submodule {py.name!r}.")
self.collect(py, tags)
return
elif ppath.is_file():
module = import_file_as_module(path_or_module)
elif ismodule(path_or_module):
module = sys.modules[str(path_or_module)]
else:
ppath = Path(path_or_module)
if ppath.is_dir():
pythonpaths = (p for p in ppath.iterdir() if p.suffix == ".py")
for py in pythonpaths:
logger.debug(f"Collecting benchmarks from submodule {py.name!r}.")
self.collect(py, tags)
return
elif ppath.is_file():
module = import_file_as_module(path_or_module)
else:
raise ValueError(
f"expected a module name, Python file, or directory, "
f"got {str(path_or_module)!r}"
)
raise ValueError(
f"expected a module name, Python file, or directory, "
f"got {str(path_or_module)!r}"
)

# iterate through the module dict members to register
for k, v in module.__dict__.items():
Expand Down

0 comments on commit bf4dd5e

Please sign in to comment.