Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typing and build fixes to benchmarking #909

Merged
merged 7 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
frames) (https://github.com/gchq/coreax/issues/892)
- `benchmark` dependency group for benchmarking dependencies.
(https://github.com/gchq/coreax/pull/888)
- `example` dependency group for running example scripts.
(https://github.com/gchq/coreax/pull/909)
- Added a method `SquaredExponentialKernel.get_sqrt_kernel` which returns a square
root kernel for the squared exponential kernel. (https://github.com/gchq/coreax/pull/883)

Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ pip install jax[cuda12]
```

There are optional sets of additional dependencies:
* `coreax[test]` is required to run the tests and examples;
* `coreax[test]` is required to run the tests;
* `coreax[example]` contains all dependencies for the example scripts;
* `coreax[benchmark]` is required to run benchmarking;
* `coreax[doc]` is for compiling the Sphinx documentation;
* `coreax[dev]` includes all tools and packages a developer of Coreax might need.
Expand Down
10 changes: 5 additions & 5 deletions benchmark/david_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
import matplotlib.pyplot as plt
import numpy as np
from jax import random
from mnist_benchmark import get_solver_name, initialise_solvers

from benchmark.mnist_benchmark import get_solver_name, initialise_solvers
from coreax import Data
from examples.david_map_reduce_weighted import downsample_opencv

Expand Down Expand Up @@ -85,15 +85,15 @@ def benchmark_coreset_algorithms(

# Initialize each coreset solver
key = random.PRNGKey(0)
solvers = initialise_solvers(data, key)
solver_factories = initialise_solvers(data, key)

# Dictionary to store coresets generated by each method
coresets = {}
solver_times = {}

for get_solver in solvers:
solver = get_solver(coreset_size)
solver_name = get_solver_name(solver)
for solver_creator in solver_factories:
solver = solver_creator(coreset_size)
solver_name = get_solver_name(solver_creator)
start_time = time.perf_counter()
coreset, _ = eqx.filter_jit(solver.reduce)(data)
duration = time.perf_counter() - start_time
Expand Down
37 changes: 19 additions & 18 deletions benchmark/mnist_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
import json
import os
import time
from typing import Any, Callable, NamedTuple, Optional, Union
from collections.abc import Callable
from typing import Any, NamedTuple, Optional, Union

import equinox as eqx
import jax
Expand All @@ -62,6 +63,7 @@
MapReduce,
RandomSample,
RPCholesky,
Solver,
SteinThinning,
)

Expand Down Expand Up @@ -426,7 +428,7 @@ def prepare_datasets() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarr

def initialise_solvers(
train_data_umap: Data, key: jax.random.PRNGKey
) -> list[Callable]:
) -> list[Callable[[int], Solver]]:
"""
Initialise and return a list of solvers for various coreset algorithms.

Expand All @@ -437,9 +439,8 @@ def initialise_solvers(
enabling easy integration in a loop for benchmarking.

:param train_data_umap: The UMAP-transformed training data used for
length scale estimation for ``SquareExponentialKernel``.
        length scale estimation for ``SquaredExponentialKernel``.
:param key: The random key for initialising random solvers.

    :return: A list of solver factory functions for different coreset algorithms.
"""
# Set up kernel using median heuristic
Expand Down Expand Up @@ -592,23 +593,23 @@ def save_results(results: dict) -> None:
print(f"Data has been saved to {file_name}")


def get_solver_name(_solver: Callable) -> str:
def get_solver_name(solver: Callable[[int], Solver]) -> str:
"""
Get the name of the solver.

This function extracts and returns the name of the solver class.
If the `_solver` is an instance of the `MapReduce` class, it retrieves the
name of the `base_solver` class instead.
    If the solver produced by ``solver`` is an instance of
    :class:`~coreax.solvers.MapReduce`, the name of its ``base_solver`` class is
    returned instead.

:param _solver: An instance of a solver, such as `MapReduce` or `RandomSample`.
    :param solver: A callable that takes a coreset size and returns a solver
        instance, such as `MapReduce` or `RandomSample`.
:return: The name of the solver class.
"""
solver_name = (
_solver.base_solver.__class__.__name__
if _solver.__class__.__name__ == "MapReduce"
else _solver.__class__.__name__
)
return solver_name
# Evaluate solver function to get an instance to interrogate
# Don't just inspect type annotations, as they may be incorrect - not robust
solver_instance = solver(1)
if isinstance(solver_instance, MapReduce):
return type(solver_instance.base_solver).__name__
return type(solver_instance).__name__


# pylint: disable=too-many-locals
Expand Down Expand Up @@ -646,11 +647,11 @@ def main() -> None:
for i in range(5):
print(f"Run {i + 1} of 5:")
key = jax.random.PRNGKey(i)
solvers = initialise_solvers(train_data_umap, key)
for getter in solvers:
solver_factories = initialise_solvers(train_data_umap, key)
for solver_creator in solver_factories:
for size in [25, 50, 100, 500, 1_000, 5_000]:
solver = getter(size)
solver_name = get_solver_name(solver)
solver = solver_creator(size)
solver_name = get_solver_name(solver_creator)
start_time = time.perf_counter()
# pylint: enable=duplicate-code
coreset, _ = eqx.filter_jit(solver.reduce)(train_data_umap)
Expand Down
17 changes: 7 additions & 10 deletions benchmark/mnist_benchmark_coresets_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@

import equinox as eqx
import jax
from mnist_benchmark import (

from benchmark.mnist_benchmark import (
density_preserving_umap,
get_solver_name,
initialise_solvers,
prepare_datasets,
)

from coreax import Data


Expand Down Expand Up @@ -100,15 +101,11 @@ def main() -> None:
# Run the experiment with 5 different random keys
for i in range(5):
key = jax.random.PRNGKey(i)
solvers = initialise_solvers(train_data_umap, key)
for getter in solvers:
solver_factories = initialise_solvers(train_data_umap, key)
for solver_creator in solver_factories:
for size in [25, 50, 100, 500, 1_000]:
solver = getter(size)
solver_name = (
solver.base_solver.__class__.__name__
if solver.__class__.__name__ == "MapReduce"
else solver.__class__.__name__
)
solver = solver_creator(size)
solver_name = get_solver_name(solver_creator)
start_time = time.perf_counter()
_, _ = eqx.filter_jit(solver.reduce)(train_data_umap)
time_taken = time.perf_counter() - start_time
Expand Down
21 changes: 11 additions & 10 deletions benchmark/pounce_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
import numpy as np
import umap
from jax import random
from mnist_benchmark import get_solver_name, initialise_solvers

from benchmark.mnist_benchmark import get_solver_name, initialise_solvers
from coreax.data import Data
from coreax.solvers import MapReduce


def benchmark_coreset_algorithms(
Expand Down Expand Up @@ -61,15 +62,15 @@ def benchmark_coreset_algorithms(
umap_model = umap.UMAP(densmap=True, n_components=25)
umap_data = umap_model.fit_transform(reshaped_data)

solvers = initialise_solvers(umap_data, random.PRNGKey(45))
# There is no need to use MapReduce as the data-size is small
solvers = [
solver.base_solver if solver.__class__.__name__ == "MapReduce" else solver
for solver in solvers
]
for get_solver in solvers:
solver = get_solver(coreset_size)
solver_name = get_solver_name(solver)
solver_factories = initialise_solvers(umap_data, random.PRNGKey(45))
for solver_creator in solver_factories:
solver = solver_creator(coreset_size)

# There is no need to use MapReduce as the data-size is small
if isinstance(solver, MapReduce):
solver = solver.base_solver

solver_name = get_solver_name(solver_creator)
data = Data(jnp.array(umap_data))

start_time = time.perf_counter()
Expand Down
4 changes: 2 additions & 2 deletions examples/david_map_reduce_weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

def downsample_opencv(image_path: str, downsampling_factor: int) -> np.ndarray:
"""
Downsample an image using `func: cv2.resize` and convert it to grayscale.
    Downsample an image using :func:`cv2.resize` and convert it to grayscale.

:param image_path: Path to the input image file.
:param downsampling_factor: Factor by which to downsample the image.
Expand Down Expand Up @@ -99,7 +99,7 @@ def main(
downsampling_factor: int = 1,
) -> tuple[float, float]:
"""
Run the 'david' example for image sampling.
Run the 'David' example for image sampling.

Take an image of the statue of David and then generate a coreset using
scalable Stein kernel herding.
Expand Down
18 changes: 11 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,16 @@ dependencies = [
]

[project.optional-dependencies]
# Benchmarking
# Example scripts
example = [
"imageio",
"matplotlib",
"numpy",
"opencv-python-headless", # WARNING: Incompatible with other versions of opencv
]
# Benchmarking - runs very similar code to examples with same dependencies plus more
benchmark = [
"coreax[example]",
"torch",
"torchvision",
"umap-learn>=0.5.7",
Expand All @@ -49,10 +57,6 @@ benchmark = [
test = [
"coreax[benchmark]",
"beartype",
"imageio",
"matplotlib",
"numpy",
"opencv-python-headless", # WARNING: Incompatible with other versions of opencv
"pytest-cov",
"pytest-rerunfailures",
"scipy",
Expand All @@ -70,8 +74,8 @@ doc = [
[dependency-groups]
# All tools for a developer including those for running pylint
dev = [
"coreax[benchmark, doc, test]",
"jupyter",
"coreax[example, benchmark, test, doc]",
"jupyter", # Include as developers may wish to write their own notebooks
"ruff",
"pre-commit>=3.7",
"pylint",
Expand Down
39 changes: 26 additions & 13 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading