-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #71 from aai-institute/extend-reporter-interface
Change `nnbench.reporter` file to submodule, add base reporter
- Loading branch information
Showing
8 changed files
with
258 additions
and
187 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
""" | ||
A lightweight interface for refining, displaying, and streaming benchmark results to various sinks. | ||
""" | ||
from __future__ import annotations | ||
|
||
import importlib | ||
import types | ||
|
||
from .base import BenchmarkReporter | ||
from .console import ConsoleReporter | ||
|
||
# internal, mutable | ||
_reporter_registry: dict[str, type[BenchmarkReporter]] = { | ||
"console": ConsoleReporter, | ||
} | ||
|
||
# external, immutable | ||
reporter_registry: types.MappingProxyType[str, type[BenchmarkReporter]] = types.MappingProxyType( | ||
_reporter_registry | ||
) | ||
|
||
|
||
def register_reporter(key: str, cls_or_name: str | type[BenchmarkReporter]) -> None: | ||
""" | ||
Register a reporter class by its fully qualified module path. | ||
Parameters | ||
---------- | ||
key: str | ||
The key to register the reporter under. Subsequently, this key can be used in place | ||
of reporter classes in code. | ||
cls_or_name: str | type[BenchmarkReporter] | ||
Name of or full module path to the reporter class. For example, when registering a class | ||
``MyReporter`` located in ``my_module``, ``name`` should be ``my_module.MyReporter``. | ||
""" | ||
|
||
if isinstance(cls_or_name, str): | ||
name = cls_or_name | ||
modname, clsname = name.rsplit(".", 1) | ||
mod = importlib.import_module(modname) | ||
cls = getattr(mod, clsname) | ||
_reporter_registry[key] = cls | ||
else: | ||
# name = cls_or_name.__module__ + "." + cls_or_name.__qualname__ | ||
_reporter_registry[key] = cls_or_name |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from typing import Sequence | ||
|
||
from nnbench.types import BenchmarkRecord | ||
|
||
|
||
def default_merge(records: Sequence[BenchmarkRecord]) -> BenchmarkRecord: | ||
""" | ||
Merges a number of benchmark records into one. | ||
The resulting record has an empty top-level context, since the context | ||
values might be different in all respective records. | ||
TODO: Think about merging contexts here to preserve the record model, | ||
padding missing values with a placeholder if not present. | ||
-> Might be easier with an OOP Context class. | ||
Parameters | ||
---------- | ||
records: Sequence[BenchmarkRecord] | ||
The records to merge. | ||
Returns | ||
------- | ||
BenchmarkRecord | ||
The merged record, with all benchmark contexts inlined into their | ||
respective benchmarks. | ||
""" | ||
merged = BenchmarkRecord(context=dict(), benchmarks=[]) | ||
for record in records: | ||
ctx, benchmarks = record["context"], record["benchmarks"] | ||
for bm in benchmarks: | ||
bm["context"] = ctx | ||
merged["benchmarks"].extend(benchmarks) | ||
return merged | ||
|
||
|
||
# TODO: Add IO mixins for database, file, and HTTP IO | ||
class BenchmarkReporter: | ||
""" | ||
The base interface for a benchmark reporter class. | ||
A benchmark reporter consumes benchmark results from a previous run, and subsequently | ||
reports them in the way specified by the respective implementation's ``report_result()`` | ||
method. | ||
For example, to write benchmark results to a database, you could save the credentials | ||
for authentication on the class, and then stream the results directly to | ||
the database in ``report_result()``, with preprocessing if necessary. | ||
""" | ||
|
||
merge: bool = False | ||
"""Whether to merge multiple BenchmarkRecords before reporting.""" | ||
|
||
def initialize(self): | ||
""" | ||
Initialize the reporter's state. | ||
This is the place where to create a result directory, a database connection, | ||
or a HTTP client. | ||
""" | ||
pass | ||
|
||
def finalize(self): | ||
""" | ||
Finalize the reporter's state. | ||
This is the place to destroy / release resources that were previously | ||
acquired in ``initialize()``. | ||
""" | ||
pass | ||
|
||
merge_records = staticmethod(default_merge) | ||
|
||
def read(self) -> BenchmarkRecord: | ||
raise NotImplementedError | ||
|
||
def read_batched(self) -> list[BenchmarkRecord]: | ||
raise NotImplementedError | ||
|
||
def write(self, record: BenchmarkRecord) -> None: | ||
raise NotImplementedError | ||
|
||
def write_batched(self, records: Sequence[BenchmarkRecord]) -> None: | ||
# by default, merge first and then write. | ||
if self.merge: | ||
merged = self.merge_records(records) | ||
self.write(merged) | ||
else: | ||
# write everything in a loop. | ||
for record in records: | ||
self.write(record) |
Oops, something went wrong.