Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API docs for all direct nnbench submodules #178

Merged
merged 4 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ repos:
types_or: [ python, pyi ]
args: [--ignore-missing-imports, --explicit-package-bases]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.0
rev: v0.8.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.5.4
rev: 0.5.5
hooks:
- id: uv-lock
name: Lock project dependencies
11 changes: 1 addition & 10 deletions src/nnbench/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
"""A framework for organizing and running benchmark workloads on machine learning models."""

from .core import benchmark, parametrize, product
from .reporter import BenchmarkReporter, ConsoleReporter
from .reporter import BenchmarkReporter, ConsoleReporter, FileReporter
from .runner import BenchmarkRunner
from .types import Benchmark, BenchmarkRecord, Memo, Parameters

__version__ = "0.3.0"


# TODO: This isn't great, make it functional instead?
def default_runner() -> BenchmarkRunner:
return BenchmarkRunner()


def default_reporter() -> BenchmarkReporter:
return ConsoleReporter()
6 changes: 4 additions & 2 deletions src/nnbench/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""The ``nnbench`` command line interface."""

import argparse
import importlib
import logging
Expand Down Expand Up @@ -163,7 +165,7 @@ def construct_parser(config: nnbenchConfig) -> argparse.ArgumentParser:


def main() -> int:
"""The main nnbench CLI entry point."""
"""The main ``nnbench`` CLI entry point."""
config = parse_nnbench_config()
parser = construct_parser(config)
try:
Expand All @@ -183,7 +185,7 @@ def main() -> int:
builtin_providers[p.name] = klass(**p.arguments)
for val in args.context:
try:
k, v = val.split("=")
k, v = val.split("=", 1)
except ValueError:
raise ValueError("context values need to be of the form <key>=<value>")
if k == "provider":
Expand Down
42 changes: 42 additions & 0 deletions src/nnbench/compare.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Contains machinery to compare multiple benchmark records side by side."""

import copy
from collections.abc import Sequence

Expand All @@ -11,6 +13,30 @@


def get_value_by_name(record: BenchmarkRecord, name: str, missing: str) -> str:
"""
Get the value of a metric by name from a benchmark record, or a placeholder
if the metric name is not present in the record.

If the name is found, but the benchmark did not complete successfully
(i.e. the ``error_occurred`` value is set to ``True``), the returned value
will be set to the value of the ``error_message`` field.

Parameters
----------
record: BenchmarkRecord
The benchmark record to extract a metric value from.
name: str
The name of the target metric.
missing: str
A placeholder string to return in the event of a missing metric.

Returns
-------
str
A string containing the metric value (or error message) formatted
as rich text.

"""
metric_names = [b["name"] for b in record.benchmarks]
if name not in metric_names:
return missing
Expand All @@ -28,6 +54,22 @@ def compare(
contextvals: Sequence[str] | None = None,
missing: str = _MISSING,
) -> None:
"""
Compare a series of benchmark records, displaying their results in a table
side by side.

Parameters
----------
records: Sequence[BenchmarkRecord]
The benchmark records to compare.
parameters: Sequence[str] | None
Names of parameters to display as extra columns.
contextvals: Sequence[str] | None
Names of context values to display as extra columns. Supports nested access
via dotted syntax.
missing: str
A placeholder string to show in the event of a missing metric.
"""
t = Table()

rows: list[list[str]] = []
Expand Down
60 changes: 59 additions & 1 deletion src/nnbench/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Utilities for parsing an nnbench config block out of a pyproject.toml."""
"""Utilities for parsing an nnbench config block out of a pyproject.toml file."""

import logging
import os
Expand All @@ -20,29 +20,70 @@

@dataclass
class ContextProviderDef:
"""
A POD struct representing a custom context provider definition in a
pyproject.toml table.
"""

name: str
"""Name under which the provider should be registered by nnbench."""
classpath: str
"""Full path to the class or callable returning the context dict."""
arguments: dict[str, Any]
"""Arguments needed to instantiate the context provider class,
given as key-value pairs in the table."""


@dataclass(frozen=True)
class nnbenchConfig:
log_level: str
"""Log level to use for the ``nnbench`` module root logger."""
context: list[ContextProviderDef]
"""A list of context provider definitions found in pyproject.toml."""

@classmethod
def empty(cls) -> Self:
"""An empty default config, returned if no pyproject.toml is found."""
return cls(log_level="NOTSET", context=[])

@classmethod
def from_toml(cls, d: dict[str, Any]) -> Self:
"""
Returns an nnbench CLI config by parsing a [tool.nnbench] block from a
pyproject.toml file.

Parameters
----------
d: dict[str, Any]
Mapping containing the [tool.nnbench] block as obtained by
``tomllib.load``.

Returns
-------
Self
An nnbench config instance with the values from pyproject.toml,
with defaults for values that were not set explicitly.
"""
provider_map = d.get("context", {})
context = [ContextProviderDef(**cpd) for cpd in provider_map.values()]
log_level = d.get("log-level", "NOTSET")
return cls(log_level=log_level, context=context)


def locate_pyproject() -> os.PathLike[str]:
"""
Locate a pyproject.toml file by walking up from the current directory,
and checking for file existence, stopping at the current user home
directory.

If no pyproject.toml file can be found, a RuntimeError is raised.

Returns
-------
os.PathLike[str]
The path to pyproject.toml.

"""
cwd = Path.cwd()
for p in (cwd, *cwd.parents):
if (pyproject_cand := (p / "pyproject.toml")).exists():
Expand All @@ -53,6 +94,23 @@ def locate_pyproject() -> os.PathLike[str]:


def parse_nnbench_config(pyproject_path: str | os.PathLike[str] | None = None) -> nnbenchConfig:
"""
Load an nnbench config from a given pyproject.toml file.

If no path to the pyproject.toml file is given, an attempt at autodiscovery
will be made. If that is unsuccessful, an empty config is returned.

Parameters
----------
pyproject_path: str | os.PathLike[str] | None
Path to the current project's pyproject.toml file, optional.

Returns
-------
nnbenchConfig
The loaded config if found, or a default config.

"""
if pyproject_path is None:
try:
pyproject_path = locate_pyproject()
Expand Down
22 changes: 18 additions & 4 deletions src/nnbench/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ def __call__(self) -> dict[str, Any]:
result["buildno"] = buildno
result["buildtime"] = buildtime

dependencies: dict[str, str] = {}
packages: dict[str, str] = {}
for pkg in self.packages:
try:
dependencies[pkg] = version(pkg)
packages[pkg] = version(pkg)
except PackageNotFoundError:
dependencies[pkg] = ""
packages[pkg] = ""

result["packages"] = dependencies
result["packages"] = packages
return {self.key: result}


Expand Down Expand Up @@ -140,6 +140,20 @@ def git_subprocess(args: list[str]) -> subprocess.CompletedProcess:


class CPUInfo:
"""
A context helper providing information about the host machine's CPU
capabilities, operating system, and amount of memory.

Parameters
----------
memunit: Literal["kB", "MB", "GB"]
The unit to display memory size in (either "kB" for kilobytes,
"MB" for Megabytes, or "GB" for Gigabytes).
frequnit: Literal["kHz", "MHz", "GHz"]
The unit to display CPU clock speeds in (either "kHz" for kilohertz,
"MHz" for Megahertz, or "GHz" for Gigahertz).
"""

key = "cpu"

def __init__(
Expand Down
62 changes: 52 additions & 10 deletions src/nnbench/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
Collect values ('fixtures') by name for benchmark runs from certain files,
similarly to pytest and its ``conftest.py``.
"""

import inspect
import os
from collections.abc import Callable, Iterable
Expand Down Expand Up @@ -39,19 +44,25 @@ class FixtureManager:
A lean class responsible for resolving parameter values (aka 'fixtures')
of benchmarks from provider functions.

To resolve a benchmark parameter (in resolve()), the class does
the following:
To resolve a benchmark parameter (in ``FixtureManager.resolve()``), the class
does the following:

1. Obtain the path to the file containing the benchmark, as
the __file__ attribute of the benchmark function's origin module.
the ``__file__`` attribute of the benchmark function's origin module.

2. Look for a `conf.py` file in the same directory.

3. Import the `conf.py` module, look for a function named the same as
the benchmark parameter.
the benchmark parameter.

4. If necessary, resolve any named inputs to the function **within**
the module scope.
the module scope.

5. If no function member is found, and the benchmark file is not in `root`,
fall back to the parent directory, repeat steps 2-5, until `root` is reached.
fall back to the parent directory, repeat steps 2-5, until `root` is reached.

6. If no `conf.py` contains any function matching the name, throw an
error (TODO: ImportError? custom?)
error.
"""

def __init__(self, root: str | os.PathLike[str]) -> None:
Expand All @@ -65,10 +76,22 @@ def __init__(self, root: str | os.PathLike[str]) -> None:

def collect(self, mod: ModuleType, names: Iterable[str]) -> dict[str, Any]:
"""
Given a fixture module and a list of fixture names required (for a
Given a module containing fixtures (contents of a ``conf.py`` file imported
as a module), and a list of required fixture names (for a
selected benchmark), collect values, computing transitive closures in the
meantime (i.e., inputs required to compute certain fixtures), and add
the resulting values to the cache.
process (i.e., all inputs required to compute the set of fixtures).

Parameters
----------
mod: ModuleType
The module to import fixture values from.
names: Iterable[str]
Names of fixture values to compute and use in the invoking benchmark.

Returns
-------
dict[str, Any]
The mapping of fixture names to their values.
"""
res: dict[str, Any] = {}
for name in names:
Expand Down Expand Up @@ -100,6 +123,25 @@ def collect(self, mod: ModuleType, names: Iterable[str]) -> dict[str, Any]:
return res

def resolve(self, bm: Benchmark) -> dict[str, Any]:
"""
Resolve fixture values for a benchmark.

Fixtures will be resolved only for benchmark inputs that do not have a
default value in place in the interface.

Fixtures need to be functions in a ``conf.py`` module in the benchmark
directory structure, and must *exactly* match input parameters by name.

Parameters
----------
bm: Benchmark
The benchmark to resolve fixtures for.

Returns
-------
dict[str, Any]
The mapping of fixture values to use for the given benchmark.
"""
fixturevals: dict[str, Any] = {}
# first, get the candidate fixture names, aka the benchmark param names.
# We limit ourselves to names that do not have a default value.
Expand Down
3 changes: 2 additions & 1 deletion src/nnbench/reporter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""
A lightweight interface for refining, displaying, and streaming benchmark results to various sinks.
An interface for displaying, writing, or streaming benchmark results to
files, databases, or web services.
"""

from .base import BenchmarkReporter
Expand Down
Loading