core: Discover benchmark families as well
Requires some tweaking of the designated benchmark types in `nnbench.collect()` and
`nnbench.run()`. We now pick up every module member that is either a `Benchmark` or a
`BenchmarkFamily`. Finally, in `nnbench.run()`, we create a composite iterator over our findings
(currently most often a list, though single benchmarks / families are also supported),
which lets us keep the main loop intact no matter what the user chooses to pass.
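
A rough usage sketch of the resulting call shapes (the file path, tag, and run name below are illustrative, not from the commit):

    import nnbench

    # collect() may now return a mix of benchmarks and benchmark families;
    # run() accepts a single benchmark, a single family, or any iterable mix.
    benchmarks = nnbench.collect("benchmarks.py", tags=("ci",))
    record = nnbench.run(benchmarks, name="nightly")

    # equally valid, assuming `my_benchmark` / `my_family` exist:
    # record = nnbench.run(my_benchmark)
    # record = nnbench.run(my_family)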
nicholasjng committed Jan 20, 2025
1 parent 1461984 commit 3931d32
Showing 2 changed files with 25 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/nnbench/__init__.py
@@ -3,6 +3,6 @@
 from .core import benchmark, parametrize, product
 from .reporter import BenchmarkReporter, ConsoleReporter, FileReporter
 from .runner import collect, run
-from .types import Benchmark, BenchmarkRecord, Parameters
+from .types import Benchmark, BenchmarkFamily, BenchmarkRecord, Parameters
 
 __version__ = "0.4.0"
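
Since `BenchmarkFamily` is now part of the public API, user code can type-check collected members against it. A minimal sketch using only names confirmed by this diff (the file path is illustrative):

    import nnbench

    # collect() returns list[Benchmark | BenchmarkFamily] after this change.
    members = nnbench.collect("benchmarks.py")
    families = [m for m in members if isinstance(m, nnbench.BenchmarkFamily)]
    singles = [m for m in members if isinstance(m, nnbench.Benchmark)]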
38 changes: 24 additions & 14 deletions src/nnbench/runner.py
@@ -17,9 +17,11 @@
 
 from nnbench.context import ContextProvider
 from nnbench.fixtures import FixtureManager
-from nnbench.types import Benchmark, BenchmarkRecord, Parameters, State
+from nnbench.types import Benchmark, BenchmarkFamily, BenchmarkRecord, Parameters, State
 from nnbench.util import import_file_as_module, ismodule
 
+Benchmarkable = Benchmark | BenchmarkFamily
+
 logger = logging.getLogger("nnbench.runner")
 
@@ -92,7 +94,9 @@ def _jsonify(val):
     return json_params
 
 
-def collect(path_or_module: str | os.PathLike[str], tags: tuple[str, ...] = ()) -> list[Benchmark]:
+def collect(
+    path_or_module: str | os.PathLike[str], tags: tuple[str, ...] = ()
+) -> list[Benchmark | BenchmarkFamily]:
     # TODO: functools.cache this guy
     """
     Discover benchmarks in a module or source file.
@@ -134,20 +138,19 @@ def collect(path_or_module: str | os.PathLike[str], tags: tuple[str, ...] = ())
         if k.startswith("__") and k.endswith("__"):
             # dunder names are ignored.
             continue
-        elif isinstance(v, Benchmark):
+        elif isinstance(v, Benchmarkable):
             if not tags or set(tags) & set(v.tags):
                 benchmarks.append(v)
         elif isinstance(v, list | tuple | set | frozenset):
             for bm in v:
-                if isinstance(bm, Benchmark):
+                if isinstance(bm, Benchmarkable):
                     if not tags or set(tags) & set(bm.tags):
                         benchmarks.append(bm)
 
     return benchmarks
 
 
 def run(
-    benchmarks: Benchmark | Iterable[Benchmark],
+    benchmarks: Benchmark | BenchmarkFamily | Iterable[Benchmark | BenchmarkFamily],
     name: str | None = None,
     params: dict[str, Any] | Parameters | None = None,
     context: Sequence[ContextProvider] = (),
@@ -158,8 +161,8 @@
     Parameters
     ----------
-    benchmarks: Sequence[Benchmark]
-        The list of discovered benchmarks to run.
+    benchmarks: Benchmark | BenchmarkFamily | Iterable[Benchmark | BenchmarkFamily]
+        A benchmark, family of benchmarks, or collection of discovered benchmarks to run.
     name: str | None
         A name for the currently started run. If None, a name will be automatically generated.
     params: dict[str, Any] | Parameters | None
@@ -181,6 +184,16 @@
     "name" giving the benchmark run name, "context" holding the context information,
     and "benchmarks", holding an array with the benchmark results.
     """
+
+    def benchmark_iterator(
+        _bm: Iterable[Benchmark | BenchmarkFamily],
+    ) -> Generator[Benchmark, None, None]:
+        for _b in _bm:
+            if isinstance(_b, Benchmark):
+                yield _b
+            else:
+                yield from _b
+
     _run = name or "nnbench-" + platform.node() + "-" + uuid.uuid1().hex[:8]
 
     family_sizes: dict[str, Any] = collections.defaultdict(int)
@@ -195,15 +208,12 @@ def run(
             raise ValueError(f"got multiple values for context key {dupe!r}")
         ctx.update(val)
 
-    if isinstance(benchmarks, Benchmark):
-        benchmarks = [benchmarks]
-
     # if we didn't find any benchmarks, return an empty record.
     if not benchmarks:
         return BenchmarkRecord(run=_run, context=ctx, benchmarks=[])
 
-    # for bm in benchmarks:
-    #     family_sizes[bm.interface.funcname] += 1
+    if isinstance(benchmarks, Benchmarkable):
+        benchmarks = [benchmarks]
 
     if isinstance(params, Parameters):
         dparams = asdict(params)
@@ -212,7 +222,7 @@
 
     results: list[dict[str, Any]] = []
 
-    for benchmark in benchmarks:
+    for benchmark in benchmark_iterator(benchmarks):
         bm_family = benchmark.interface.funcname
         state = State(
             name=benchmark.name,
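
To illustrate the flattening behavior of the new `benchmark_iterator`, here is a self-contained sketch; `FakeBenchmark` and `FakeFamily` are stand-ins for the real types, assuming only that a `BenchmarkFamily` is iterable over `Benchmark` instances (as the `yield from` in the diff implies):

    from collections.abc import Generator, Iterable

    class FakeBenchmark:
        def __init__(self, name: str) -> None:
            self.name = name

    class FakeFamily:
        # stand-in for BenchmarkFamily: iterating it yields member benchmarks.
        def __init__(self, *members: FakeBenchmark) -> None:
            self.members = members

        def __iter__(self):
            return iter(self.members)

    def benchmark_iterator(
        bms: Iterable[FakeBenchmark | FakeFamily],
    ) -> Generator[FakeBenchmark, None, None]:
        for b in bms:
            if isinstance(b, FakeBenchmark):
                yield b  # plain benchmarks pass through unchanged
            else:
                yield from b  # families expand to their member benchmarks

    mixed = [FakeBenchmark("a"), FakeFamily(FakeBenchmark("b"), FakeBenchmark("c"))]
    print([b.name for b in benchmark_iterator(mixed)])  # -> ['a', 'b', 'c']

This is what keeps the main loop in run() a single flat iteration regardless of what the caller passes.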
