diff --git a/docs/quickstart.md b/docs/quickstart.md index 93eef77e..b50ef1cd 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -51,7 +51,7 @@ reporter = ConsoleReporter() # To collect in the current file, pass "__main__" as module name. result = r.run("__main__", params={"model": model, "X_test": X_test, "y_test": y_test}) -reporter.report(result) +reporter.write(result) ``` The resulting output might look like this: diff --git a/examples/mnist/mnist.py b/examples/mnist/mnist.py index d3b7bc57..b53dfea8 100644 --- a/examples/mnist/mnist.py +++ b/examples/mnist/mnist.py @@ -221,7 +221,7 @@ def mnist_jax(): reporter = ConsoleReporter() params = MNISTTestParameters(params=state.params, data=data) result = runner.run(HERE, params=params) - reporter.report(result) + reporter.write(result) if __name__ == "__main__": diff --git a/src/nnbench/__init__.py b/src/nnbench/__init__.py index 5b8942c1..26615356 100644 --- a/src/nnbench/__init__.py +++ b/src/nnbench/__init__.py @@ -8,7 +8,6 @@ # package is not installed pass -# TODO: This naming is unfortunate from .core import benchmark, parametrize, product from .reporter import BenchmarkReporter, register_reporter from .runner import BenchmarkRunner diff --git a/src/nnbench/reporter.py b/src/nnbench/reporter.py deleted file mode 100644 index d88080d9..00000000 --- a/src/nnbench/reporter.py +++ /dev/null @@ -1,184 +0,0 @@ -""" -A lightweight interface for refining, displaying, and streaming benchmark results to various sinks. -""" -from __future__ import annotations - -import collections -import importlib -import re -import sys -import types -from typing import Any, Callable - -from nnbench.types import BenchmarkRecord - - -def nullcols(_benchmarks: list[dict[str, Any]]) -> tuple[str, ...]: - """ - Extracts columns that only contain false-ish data from a list of benchmarks. - - Since this data is most often not interesting, the result of this - can be used to filter out these columns from the benchmark dictionaries. - - Parameters - ---------- - _benchmarks: list[dict[str, Any]] - The benchmarks to filter. - - Returns - ------- - tuple[str, ...] - Tuple of the columns (key names) that only contain false-ish values - across all benchmarks. - """ - nulls: dict[str, bool] = collections.defaultdict(bool) - for bm in _benchmarks: - for k, v in bm.items(): - nulls[k] = nulls[k] or bool(v) - return tuple(k for k, v in nulls.items() if not v) - - -def flatten(d: dict[str, Any], prefix: str = "", sep: str = ".") -> dict[str, Any]: - """ - Turn a nested dictionary into a flattened dictionary. - - Parameters - ---------- - d: dict[str, Any] - (Possibly) nested dictionary to flatten. - prefix: str - Key prefix to apply at the top-level (nesting level 0). - sep: str - Separator on which to join keys, "." by default. - - Returns - ------- - dict[str, Any] - The flattened dictionary. - """ - - items: list[tuple[str, Any]] = [] - for key, value in d.items(): - new_key = prefix + sep + key if prefix else key - if isinstance(value, dict): - items.extend(flatten(value, new_key, sep).items()) - else: - items.append((new_key, value)) - return dict(items) - - -# TODO: Add IO mixins for database, file, and HTTP IO -class BenchmarkReporter: - """ - The base interface for a benchmark reporter class. - - A benchmark reporter consumes benchmark results from a previous run, and subsequently - reports them in the way specified by the respective implementation's `report_result()` - method. - - For example, to write benchmark results to a database, you could save the credentials - for authentication on the class, and then stream the results directly to - the database in `report_result()`, with preprocessing if necessary. - """ - - merge: bool = False - """Whether to merge multiple BenchmarkRecords before reporting.""" - - def report_result(self, record: BenchmarkRecord) -> None: - raise NotImplementedError - - def report(self, *records: BenchmarkRecord) -> None: - if self.merge: - raise NotImplementedError - for record in records: - self.report_result(record) - - -class ConsoleReporter(BenchmarkReporter): - def __init__( - self, - tablefmt: str = "simple", - custom_formatters: dict[str, Callable[[Any], Any]] | None = None, - ): - self.tablefmt = tablefmt - self.custom_formatters: dict[str, Callable[[Any], Any]] = custom_formatters or {} - - def report_result( - self, - record: BenchmarkRecord, - benchmark_filter: str | None = None, - include_context: tuple[str, ...] = (), - exclude_empty: bool = True, - ) -> None: - try: - from tabulate import tabulate - except ModuleNotFoundError: - raise ValueError( - f"class {self.__class__.__name__}() requires `tabulate` to be installed. " - f"To install, run `{sys.executable} -m pip install --upgrade tabulate`." - ) - - ctx, benchmarks = record["context"], record["benchmarks"] - - nulls = set() if not exclude_empty else nullcols(benchmarks) - - if benchmark_filter is not None: - regex = re.compile(benchmark_filter, flags=re.IGNORECASE) - else: - regex = None - - filtered = [] - for bm in benchmarks: - if regex is not None and regex.search(bm["name"]) is None: - continue - filteredctx = { - k: v - for k, v in flatten(ctx).items() - if any(k.startswith(i) for i in include_context) - } - filteredbm = {k: v for k, v in bm.items() if k not in nulls} - filteredbm.update(filteredctx) - # only apply custom formatters after context merge - # to allow custom formatting of context values. - filteredbm = { - k: self.custom_formatters.get(k, lambda x: x)(v) for k, v in filteredbm.items() - } - filtered.append(filteredbm) - - print(tabulate(filtered, headers="keys", tablefmt=self.tablefmt)) - - -# internal, mutable -_reporter_registry: dict[str, type[BenchmarkReporter]] = { - "console": ConsoleReporter, -} - -# external, immutable -reporter_registry: types.MappingProxyType[str, type[BenchmarkReporter]] = types.MappingProxyType( - _reporter_registry -) - - -def register_reporter(key: str, cls_or_name: str | type[BenchmarkReporter]) -> None: - """ - Register a reporter class by its fully qualified module path. - - Parameters - ---------- - key: str - The key to register the reporter under. Subsequently, this key can be used in place - of reporter classes in code. - cls_or_name: str | type[BenchmarkReporter] - Name of or full module path to the reporter class. For example, when registering a class - ``MyReporter`` located in ``my_module``, ``name`` should be ``my_module.MyReporter``. - """ - - if isinstance(cls_or_name, str): - name = cls_or_name - modname, clsname = name.rsplit(".", 1) - mod = importlib.import_module(modname) - cls = getattr(mod, clsname) - _reporter_registry[key] = cls - else: - # name = cls_or_name.__module__ + "." + cls_or_name.__qualname__ - _reporter_registry[key] = cls_or_name diff --git a/src/nnbench/reporter/__init__.py b/src/nnbench/reporter/__init__.py new file mode 100644 index 00000000..79196fc4 --- /dev/null +++ b/src/nnbench/reporter/__init__.py @@ -0,0 +1,45 @@ +""" +A lightweight interface for refining, displaying, and streaming benchmark results to various sinks. +""" +from __future__ import annotations + +import importlib +import types + +from .base import BenchmarkReporter +from .console import ConsoleReporter + +# internal, mutable +_reporter_registry: dict[str, type[BenchmarkReporter]] = { + "console": ConsoleReporter, +} + +# external, immutable +reporter_registry: types.MappingProxyType[str, type[BenchmarkReporter]] = types.MappingProxyType( + _reporter_registry +) + + +def register_reporter(key: str, cls_or_name: str | type[BenchmarkReporter]) -> None: + """ + Register a reporter class by its fully qualified module path. + + Parameters + ---------- + key: str + The key to register the reporter under. Subsequently, this key can be used in place + of reporter classes in code. + cls_or_name: str | type[BenchmarkReporter] + Name of or full module path to the reporter class. For example, when registering a class + ``MyReporter`` located in ``my_module``, ``name`` should be ``my_module.MyReporter``. + """ + + if isinstance(cls_or_name, str): + name = cls_or_name + modname, clsname = name.rsplit(".", 1) + mod = importlib.import_module(modname) + cls = getattr(mod, clsname) + _reporter_registry[key] = cls + else: + # name = cls_or_name.__module__ + "." + cls_or_name.__qualname__ + _reporter_registry[key] = cls_or_name diff --git a/src/nnbench/reporter/base.py b/src/nnbench/reporter/base.py new file mode 100644 index 00000000..81c1fa64 --- /dev/null +++ b/src/nnbench/reporter/base.py @@ -0,0 +1,92 @@ +from typing import Sequence + +from nnbench.types import BenchmarkRecord + + +def default_merge(records: Sequence[BenchmarkRecord]) -> BenchmarkRecord: + """ + Merges a number of benchmark records into one. + + The resulting record has an empty top-level context, since the context + values might be different in all respective records. + + TODO: Think about merging contexts here to preserve the record model, + padding missing values with a placeholder if not present. + -> Might be easier with an OOP Context class. + + Parameters + ---------- + records: Sequence[BenchmarkRecord] + The records to merge. + + Returns + ------- + BenchmarkRecord + The merged record, with all benchmark contexts inlined into their + respective benchmarks. + + """ + merged = BenchmarkRecord(context=dict(), benchmarks=[]) + for record in records: + ctx, benchmarks = record["context"], record["benchmarks"] + for bm in benchmarks: + bm["context"] = ctx + merged["benchmarks"].extend(benchmarks) + return merged + + +# TODO: Add IO mixins for database, file, and HTTP IO +class BenchmarkReporter: + """ + The base interface for a benchmark reporter class. + + A benchmark reporter consumes benchmark results from a previous run, and subsequently + reports them in the way specified by the respective implementation's ``report_result()`` + method. + + For example, to write benchmark results to a database, you could save the credentials + for authentication on the class, and then stream the results directly to + the database in ``report_result()``, with preprocessing if necessary. + """ + + merge: bool = False + """Whether to merge multiple BenchmarkRecords before reporting.""" + + def initialize(self): + """ + Initialize the reporter's state. + + This is the place where to create a result directory, a database connection, + or a HTTP client. + """ + pass + + def finalize(self): + """ + Finalize the reporter's state. + + This is the place to destroy / release resources that were previously + acquired in ``initialize()``. + """ + pass + + merge_records = staticmethod(default_merge) + + def read(self) -> BenchmarkRecord: + raise NotImplementedError + + def read_batched(self) -> list[BenchmarkRecord]: + raise NotImplementedError + + def write(self, record: BenchmarkRecord) -> None: + raise NotImplementedError + + def write_batched(self, records: Sequence[BenchmarkRecord]) -> None: + # by default, merge first and then write. + if self.merge: + merged = self.merge_records(records) + self.write(merged) + else: + # write everything in a loop. + for record in records: + self.write(record) diff --git a/src/nnbench/reporter/console.py b/src/nnbench/reporter/console.py new file mode 100644 index 00000000..74263418 --- /dev/null +++ b/src/nnbench/reporter/console.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import re +import sys +from typing import Any, Callable + +from nnbench.reporter.base import BenchmarkReporter +from nnbench.reporter.util import flatten, nullcols +from nnbench.types import BenchmarkRecord + + +class ConsoleReporter(BenchmarkReporter): + def __init__( + self, + tablefmt: str = "simple", + custom_formatters: dict[str, Callable[[Any], Any]] | None = None, + ): + self.tablefmt = tablefmt + self.custom_formatters: dict[str, Callable[[Any], Any]] = custom_formatters or {} + + def write( + self, + record: BenchmarkRecord, + benchmark_filter: str | None = None, + include_context: tuple[str, ...] = (), + exclude_empty: bool = True, + ) -> None: + try: + from tabulate import tabulate + except ModuleNotFoundError: + raise ValueError( + f"class {self.__class__.__name__}() requires `tabulate` to be installed. " + f"To install, run `{sys.executable} -m pip install --upgrade tabulate`." + ) + + ctx, benchmarks = record["context"], record["benchmarks"] + + nulls = set() if not exclude_empty else nullcols(benchmarks) + + if benchmark_filter is not None: + regex = re.compile(benchmark_filter, flags=re.IGNORECASE) + else: + regex = None + + filtered = [] + for bm in benchmarks: + if regex is not None and regex.search(bm["name"]) is None: + continue + filteredctx = { + k: v + for k, v in flatten(ctx).items() + if any(k.startswith(i) for i in include_context) + } + filteredbm = {k: v for k, v in bm.items() if k not in nulls} + filteredbm.update(filteredctx) + # only apply custom formatters after context merge + # to allow custom formatting of context values. + filteredbm = { + k: self.custom_formatters.get(k, lambda x: x)(v) for k, v in filteredbm.items() + } + filtered.append(filteredbm) + + print(tabulate(filtered, headers="keys", tablefmt=self.tablefmt)) diff --git a/src/nnbench/reporter/util.py b/src/nnbench/reporter/util.py new file mode 100644 index 00000000..1c4abf2b --- /dev/null +++ b/src/nnbench/reporter/util.py @@ -0,0 +1,56 @@ +import collections +from typing import Any + + +def nullcols(_benchmarks: list[dict[str, Any]]) -> tuple[str, ...]: + """ + Extracts columns that only contain false-ish data from a list of benchmarks. + + Since this data is most often not interesting, the result of this + can be used to filter out these columns from the benchmark dictionaries. + + Parameters + ---------- + _benchmarks: list[dict[str, Any]] + The benchmarks to filter. + + Returns + ------- + tuple[str, ...] + Tuple of the columns (key names) that only contain false-ish values + across all benchmarks. + """ + nulls: dict[str, bool] = collections.defaultdict(bool) + for bm in _benchmarks: + for k, v in bm.items(): + nulls[k] = nulls[k] or bool(v) + return tuple(k for k, v in nulls.items() if not v) + + +def flatten(d: dict[str, Any], prefix: str = "", sep: str = ".") -> dict[str, Any]: + """ + Turn a nested dictionary into a flattened dictionary. + + Parameters + ---------- + d: dict[str, Any] + (Possibly) nested dictionary to flatten. + prefix: str + Key prefix to apply at the top-level (nesting level 0). + sep: str + Separator on which to join keys, "." by default. + + Returns + ------- + dict[str, Any] + The flattened dictionary. + """ + + items: list[tuple[str, Any]] = [] + for key, value in d.items(): + new_key = prefix + sep + key if prefix else key + if isinstance(value, dict): + items.extend(flatten(value, new_key, sep).items()) + else: + items.append((new_key, value)) + return dict(items)