Skip to content

Commit

Permalink
Merge pull request #71 from aai-institute/extend-reporter-interface
Browse files Browse the repository at this point in the history
Change `nnbench.reporter` file to submodule, add base reporter
  • Loading branch information
nicholasjng authored Feb 8, 2024
2 parents bf4dd5e + 703ac37 commit d818d62
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 187 deletions.
2 changes: 1 addition & 1 deletion docs/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ reporter = ConsoleReporter()

# To collect in the current file, pass "__main__" as module name.
result = r.run("__main__", params={"model": model, "X_test": X_test, "y_test": y_test})
reporter.report(result)
reporter.write(result)
```

The resulting output might look like this:
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def mnist_jax():
reporter = ConsoleReporter()
params = MNISTTestParameters(params=state.params, data=data)
result = runner.run(HERE, params=params)
reporter.report(result)
reporter.write(result)


if __name__ == "__main__":
Expand Down
1 change: 0 additions & 1 deletion src/nnbench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
# package is not installed
pass

# TODO: This naming is unfortunate
from .core import benchmark, parametrize, product
from .reporter import BenchmarkReporter, register_reporter
from .runner import BenchmarkRunner
Expand Down
184 changes: 0 additions & 184 deletions src/nnbench/reporter.py

This file was deleted.

45 changes: 45 additions & 0 deletions src/nnbench/reporter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
A lightweight interface for refining, displaying, and streaming benchmark results to various sinks.
"""
from __future__ import annotations

import importlib
import types

from .base import BenchmarkReporter
from .console import ConsoleReporter

# internal, mutable
_reporter_registry: dict[str, type[BenchmarkReporter]] = {
"console": ConsoleReporter,
}

# external, immutable
reporter_registry: types.MappingProxyType[str, type[BenchmarkReporter]] = types.MappingProxyType(
_reporter_registry
)


def register_reporter(key: str, cls_or_name: str | type[BenchmarkReporter]) -> None:
"""
Register a reporter class by its fully qualified module path.
Parameters
----------
key: str
The key to register the reporter under. Subsequently, this key can be used in place
of reporter classes in code.
cls_or_name: str | type[BenchmarkReporter]
Name of or full module path to the reporter class. For example, when registering a class
``MyReporter`` located in ``my_module``, ``name`` should be ``my_module.MyReporter``.
"""

if isinstance(cls_or_name, str):
name = cls_or_name
modname, clsname = name.rsplit(".", 1)
mod = importlib.import_module(modname)
cls = getattr(mod, clsname)
_reporter_registry[key] = cls
else:
# name = cls_or_name.__module__ + "." + cls_or_name.__qualname__
_reporter_registry[key] = cls_or_name
92 changes: 92 additions & 0 deletions src/nnbench/reporter/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import Sequence

from nnbench.types import BenchmarkRecord


def default_merge(records: Sequence[BenchmarkRecord]) -> BenchmarkRecord:
"""
Merges a number of benchmark records into one.
The resulting record has an empty top-level context, since the context
values might be different in all respective records.
TODO: Think about merging contexts here to preserve the record model,
padding missing values with a placeholder if not present.
-> Might be easier with an OOP Context class.
Parameters
----------
records: Sequence[BenchmarkRecord]
The records to merge.
Returns
-------
BenchmarkRecord
The merged record, with all benchmark contexts inlined into their
respective benchmarks.
"""
merged = BenchmarkRecord(context=dict(), benchmarks=[])
for record in records:
ctx, benchmarks = record["context"], record["benchmarks"]
for bm in benchmarks:
bm["context"] = ctx
merged["benchmarks"].extend(benchmarks)
return merged


# TODO: Add IO mixins for database, file, and HTTP IO
class BenchmarkReporter:
"""
The base interface for a benchmark reporter class.
A benchmark reporter consumes benchmark results from a previous run, and subsequently
reports them in the way specified by the respective implementation's ``report_result()``
method.
For example, to write benchmark results to a database, you could save the credentials
for authentication on the class, and then stream the results directly to
the database in ``report_result()``, with preprocessing if necessary.
"""

merge: bool = False
"""Whether to merge multiple BenchmarkRecords before reporting."""

def initialize(self):
"""
Initialize the reporter's state.
This is the place where to create a result directory, a database connection,
or a HTTP client.
"""
pass

def finalize(self):
"""
Finalize the reporter's state.
This is the place to destroy / release resources that were previously
acquired in ``initialize()``.
"""
pass

merge_records = staticmethod(default_merge)

def read(self) -> BenchmarkRecord:
raise NotImplementedError

def read_batched(self) -> list[BenchmarkRecord]:
raise NotImplementedError

def write(self, record: BenchmarkRecord) -> None:
raise NotImplementedError

def write_batched(self, records: Sequence[BenchmarkRecord]) -> None:
# by default, merge first and then write.
if self.merge:
merged = self.merge_records(records)
self.write(merged)
else:
# write everything in a loop.
for record in records:
self.write(record)
Loading

0 comments on commit d818d62

Please sign in to comment.