diff --git a/notebooks/lava_cifar10.ipynb b/notebooks/lava_cifar10.ipynb
new file mode 100644
index 000000000..0fe120a20
--- /dev/null
+++ b/notebooks/lava_cifar10.ipynb
@@ -0,0 +1,535 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
+ "source": [
+ "# LAVA\n",
+ "\n",
+ "This notebook explores the use of LAVA for data valuation."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "
\n",
+ "\n",
+ "If you are reading this in the documentation, some boilerplate has been omitted for convenience.\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports and setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "nbsphinx": "hidden",
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": [
+ "hide"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "nbsphinx": "hidden",
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": [
+ "hide"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "%autoreload\n",
+ "%matplotlib inline\n",
+ "\n",
+ "import logging\n",
+ "import os\n",
+ "from typing import Tuple\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay\n",
+ "from torch import nn\n",
+ "from torch.utils.data import DataLoader, Subset\n",
+ "from torchvision.datasets import CIFAR10\n",
+ "import torchvision.transforms as transforms\n",
+ "from torchvision.models.resnet import ResNet, BasicBlock, resnet18, ResNet18_Weights\n",
+ "from torchvision.utils import make_grid\n",
+ "from tqdm.auto import tqdm\n",
+ "\n",
+ "from support.common import (\n",
+ " plot_sample_images,\n",
+ " plot_losses,\n",
+ ")\n",
+ "from support.torch import (\n",
+ " TrainingManager,\n",
+ " MODEL_PATH,\n",
+ " new_resnet_model,\n",
+ ")\n",
+ "from support.types import Losses\n",
+ "\n",
+ "logging.basicConfig(level=logging.DEBUG)\n",
+ "\n",
+ "plt.rcParams[\"figure.figsize\"] = (7, 7)\n",
+ "plt.rcParams[\"font.size\"] = 12\n",
+ "plt.rcParams[\"xtick.labelsize\"] = 12\n",
+ "plt.rcParams[\"ytick.labelsize\"] = 10\n",
+ "plt.rcParams[\"axes.facecolor\"] = (1, 1, 1, 0)\n",
+ "plt.rcParams[\"figure.facecolor\"] = (1, 1, 1, 0)\n",
+ "\n",
+ "random_state = 42\n",
+ "np.random.seed(random_state)\n",
+ "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%autoreload\n",
+ "from pydvl.ot.lava import LAVA\n",
+ "from pydvl.utils.dataset import Dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Loading and preprocessing the dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "transform = transforms.Compose(\n",
+ " [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n",
+ ")\n",
+ "trainset = CIFAR10(root=\"/tmp/cifar10\", train=True, download=True, transform=transform)\n",
+ "valset = CIFAR10(root=\"/tmp/cifar10\", train=False, download=True, transform=transform)\n",
+ "classes = trainset.classes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainset = Subset(trainset, np.random.randint(low=0, high=len(trainset), size=100))\n",
+ "valset = Subset(valset, np.random.randint(low=0, high=len(valset), size=100))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": [
+ "hide-output"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "trainloader = DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0)\n",
+ "valloader = DataLoader(valset, batch_size=4, shuffle=True, num_workers=0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's take a closer look at a few image samples"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": [
+ "hide-input"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "def imshow(img):\n",
+ " img = img / 2 + 0.5 # unnormalize\n",
+ " npimg = img.numpy()\n",
+ " plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
+ " plt.show()\n",
+ "\n",
+ "\n",
+ "# get some random training images\n",
+ "dataiter = iter(trainloader)\n",
+ "images, labels = next(dataiter)\n",
+ "\n",
+ "# show images\n",
+ "imshow(make_grid(images))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Model definition and training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We now train a model on the validation data (This is the same as in the paper) in order to use it as a feature extractor."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = ResNet(BasicBlock, [1, 1, 1, 1], num_classes=10)\n",
+ "num_params = sum(p.numel() for p in model.parameters())\n",
+ "print(f\"Model has {num_params} parameters\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": [
+ "hide-output"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "mgr = TrainingManager(\n",
+ " \"model_lava_cifar10\",\n",
+ " model,\n",
+ " nn.CrossEntropyLoss(),\n",
+ " valloader,\n",
+ " trainloader,\n",
+ " MODEL_PATH,\n",
+ " device=DEVICE,\n",
+ ")\n",
+ "# Set use_cache=False to retrain the model\n",
+ "train_loss, val_loss = mgr.train(n_epochs=10, use_cache=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": [
+ "invertible-output"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "plot_losses(Losses(train_loss, val_loss))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The confusion matrix and $F_1$ score look good, especially considering the low resolution of the images and their complexity (they contain different objects)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": [
+ "invertible-output"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "y_test = []\n",
+ "y_pred = []\n",
+ "\n",
+ "for inputs, targets in tqdm(valloader, total=len(valloader)):\n",
+ " y_test.append(targets.cpu().numpy().ravel())\n",
+ " inputs = inputs.to(DEVICE)\n",
+ " pred = np.argmax(model(inputs).cpu().detach().numpy(), axis=1).ravel()\n",
+ " y_pred.append(pred)\n",
+ "\n",
+ "\n",
+ "y_test = np.concatenate(y_test)\n",
+ "y_pred = np.concatenate(y_pred)\n",
+ "cm = confusion_matrix(y_test, y_pred)\n",
+ "disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)\n",
+ "disp.plot();"
+ ]
+ },
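+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can also compute the macro-averaged $F_1$ score from the same predictions as a quick sanity check. This is a minimal sketch using scikit-learn's `f1_score`; the exact value will depend on the training run."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.metrics import f1_score\n",
+ "\n",
+ "# The macro average weighs all 10 classes equally\n",
+ "print(f\"F1 score (macro): {f1_score(y_test, y_pred, average='macro'):.3f}\")"
+ ]
+ },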
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Feature Extraction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We now remove the last layer in order to use the model as a feature extractor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.fc = torch.nn.Identity()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x_train = []\n",
+ "y_train = []\n",
+ "\n",
+ "for inputs, targets in tqdm(trainloader, total=len(trainloader)):\n",
+ " y_train.append(targets.cpu().numpy().ravel())\n",
+ " inputs = inputs.to(DEVICE)\n",
+ " pred = model(inputs).cpu().detach().numpy()\n",
+ " x_train.append(pred)\n",
+ "\n",
+ "x_train = np.concatenate(x_train)\n",
+ "y_train = np.concatenate(y_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x_test = []\n",
+ "y_test = []\n",
+ "\n",
+ "for inputs, targets in tqdm(valloader, total=len(valloader)):\n",
+ " y_test.append(targets.cpu().numpy().ravel())\n",
+ " inputs = inputs.to(DEVICE)\n",
+ " pred = model(inputs).cpu().detach().numpy()\n",
+ " x_test.append(pred)\n",
+ "\n",
+ "x_test = np.concatenate(x_test)\n",
+ "y_test = np.concatenate(y_test)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Computing Values"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = Dataset(x_train, y_train, x_test, y_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "all_values = []\n",
+ "\n",
+ "for regularization in [1, 0.5, 0.1, 0.01]:\n",
+ " for lambda_ in [3.0, 1.0, 0.1, 0.01, 0]:\n",
+ " print(f\"{regularization=}, {lambda_=}\")\n",
+ " lava = LAVA(\n",
+ " dataset,\n",
+ " inner_ot_method=\"exact\",\n",
+ " regularization=regularization,\n",
+ " lambda_=lambda_,\n",
+ " )\n",
+ " values = lava.compute_values()\n",
+ " all_values.append(values)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "values_df = pd.DataFrame(np.stack(all_values).T)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "values_df.plot.boxplot();"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "regularization = 1.0\n",
+ "lambda_ = 1.0\n",
+ "lava = LAVA(\n",
+ " dataset, inner_ot_method=\"gaussian\", regularization=regularization, lambda_=lambda_\n",
+ ")\n",
+ "values = lava.compute_values()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "feature_cost = lava._compute_feature_cost()\n",
+ "plt.boxplot(feature_cost.ravel());"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lava._compute_gaussian_label_distances()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lava._compute_exact_label_distances()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.hist(values)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Pre-Trained Model\n",
+ "\n",
+ "What if we use a pre-trained model instead?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)\n",
+ "model.fc = torch.nn.Identity()"
+ ]
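+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below is a minimal sketch of how the pipeline above could be repeated with the pre-trained extractor: extract features for the training and test sets, rebuild the `Dataset` and compute LAVA values again. The helper `extract_features` and the `*_pre` variables are only illustrative; the loaders and LAVA parameters are reused from the previous sections."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal sketch: repeat feature extraction and value computation with the\n",
+ "# pre-trained extractor. `extract_features` is an illustrative helper.\n",
+ "model = model.to(DEVICE).eval()\n",
+ "\n",
+ "\n",
+ "def extract_features(loader):\n",
+ "    xs, ys = [], []\n",
+ "    with torch.no_grad():\n",
+ "        for inputs, targets in tqdm(loader, total=len(loader)):\n",
+ "            ys.append(targets.numpy().ravel())\n",
+ "            xs.append(model(inputs.to(DEVICE)).cpu().numpy())\n",
+ "    return np.concatenate(xs), np.concatenate(ys)\n",
+ "\n",
+ "\n",
+ "x_train_pre, y_train_pre = extract_features(trainloader)\n",
+ "x_test_pre, y_test_pre = extract_features(valloader)\n",
+ "\n",
+ "pretrained_dataset = Dataset(x_train_pre, y_train_pre, x_test_pre, y_test_pre)\n",
+ "lava = LAVA(\n",
+ "    pretrained_dataset, inner_ot_method=\"gaussian\", regularization=1.0, lambda_=1.0\n",
+ ")\n",
+ "pretrained_values = lava.compute_values()\n",
+ "plt.hist(pretrained_values);"
+ ]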
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.16"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": false
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "b3369ace3ad477f5e763d9fa7767e0177027059e92a8b1ded9e92b707c0b1513"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/setup.py b/setup.py
index cac6ca5d6..c043acbfa 100644
--- a/setup.py
+++ b/setup.py
@@ -31,6 +31,7 @@
"zarr>=2.16.1",
],
"ray": ["ray>=0.8"],
+ "lava": ["pot>=0.9"],
},
author="appliedAI Institute gGmbH",
long_description=long_description,
diff --git a/src/pydvl/ot/__init__.py b/src/pydvl/ot/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/pydvl/ot/lava.py b/src/pydvl/ot/lava.py
new file mode 100644
index 000000000..dbc4585d3
--- /dev/null
+++ b/src/pydvl/ot/lava.py
@@ -0,0 +1,219 @@
+"""
+!!! tip "New in version 0.9.0"
+
+## References
+
+ [^1]: Just et al., 2023.
+ [LAVA: Data Valuation without Pre-Specified Learning Algorithms](https://arxiv.org/abs/2305.00054).
+ In: The Eleventh International Conference on Learning Representations (ICLR 2023).
+"""
+
+import itertools
+from typing import Callable, Literal, Tuple
+
+import numpy as np
+import ot
+from numpy.typing import NDArray
+from ot.bregman import empirical_sinkhorn2
+from ot.gaussian import empirical_bures_wasserstein_distance
+
+from pydvl.utils.dataset import Dataset
+
+__all__ = ["LAVA"]
+
+
+class LAVA:
+ """Computes Data values using LAVA.
+
+ This implements the method described in
+ (Just et al., 2023)1.
+
+ Args:
+ dataset: The dataset containing training and test samples.
+ regularization: Regularization parameter for sinkhorn iterations.
+ lambda_: Weight parameter for label distances.
+ inner_ot_method: Name of the method used to compute the inner (label-to-label)
+ OT problem. Must be one of 'gaussian' or 'exact'.
+ If set to 'gaussian', the class-conditional feature distributions
+ are approximated as Gaussians, and their distance is computed as
+ the Bures-Wasserstein distance. If set to 'exact', no Gaussian
+ approximation is used, and the distance is computed by solving the OT problem directly.
+
+ Examples:
+ >>> from pydvl.ot.lava import LAVA
+ >>> from pydvl.utils.dataset import Dataset
+ >>> from sklearn.datasets import load_iris
+ >>> dataset = Dataset.from_sklearn(load_iris())
+ >>> lava = LAVA(dataset)
+ >>> values = lava.compute_values()
+ >>> assert values.shape == (len(dataset.x_train),)
+ """
+
+ def __init__(
+ self,
+ dataset: Dataset,
+ *,
+ regularization: float = 0.1,
+ lambda_: float = 1.0,
+ inner_ot_method: Literal["exact", "gaussian"] = "exact",
+ ) -> None:
+ self.dataset = dataset
+ self.regularization = regularization
+ self.lambda_ = lambda_
+ self.inner_ot_method = inner_ot_method
+
+ def compute_values(self) -> NDArray:
+ """Compute calibrated gradients using Optimal Transport.
+
+ Returns:
+ Array of dimensions `(n_train)` that contains the calibrated gradients.
+ """
+ dual_solution = self._compute_ot_dual()
+ calibrated_gradients = self._compute_calibrated_gradients(dual_solution)
+ return calibrated_gradients
+
+ def _compute_calibrated_gradients(
+ self, dual_solution: Tuple[NDArray, NDArray]
+ ) -> NDArray:
+ r"""Compute calibrated gradients using the dual solution of Optimal Transport problem.
+
+ $$
+ \frac{\partial OT(\mu_t, \mu_v )}{\partial \mu_t(z_i)}
+ = f_i^∗ - \sum\limits_{j\in\{1,\dots,N\} \setminus i} \frac{f_j^∗}{N-1}
+ $$
+
+ Args:
+ dual_solution: Dual solution of Optimal Transport problem.
+
+ Returns:
+ Array of dimensions `(n_train)`
+ """
+ f1k = np.array(dual_solution[0])
+ training_size = len(self.dataset.x_train)
+ calibrated_gradients = (1 + 1 / (training_size - 1)) * f1k - f1k.sum() / (
+ training_size - 1
+ )
+ return calibrated_gradients
+
+ def _compute_ot_dual(self) -> Tuple[NDArray, NDArray]:
+ """Compute the dual solution of the Optimal Transport problem."""
+ ground_cost = self._compute_ground_cost()
+ a = ot.unif(len(self.dataset.x_train))
+ b = ot.unif(len(self.dataset.x_test))
+ gamma, log = ot.sinkhorn(
+ a,
+ b,
+ ground_cost,
+ self.regularization,
+ log=True,
+ verbose=False,
+ numItermax=10000,
+ )
+ u, v = log["u"], log["v"]
+ return u, v
+
+ def _compute_ground_cost(self) -> NDArray:
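+ """Combine feature and label costs into the ground cost matrix: `feature_cost + lambda_ * label_cost`."""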
+ label_cost = self._compute_label_cost()
+ feature_cost = self._compute_feature_cost()
+ ground_cost = feature_cost + self.lambda_ * label_cost
+ return ground_cost
+
+ def _compute_feature_cost(
+ self, metric: Literal["euclidean", "sqeuclidean"] = "sqeuclidean"
+ ) -> NDArray:
+ """Compute distance between the features of the training and test sets.
+
+ Both sets must have the same number of features: the first has dimensions `(n1, d)` and the second `(n2, d)`.
+
+ Args:
+ metric: Distance metric used by `ot.dist`, one of "euclidean" or "sqeuclidean".
+
+ Returns:
+ Array with dimensions `(n1, n2)`
+ """
+ distance = ot.dist(self.dataset.x_train, self.dataset.x_test, metric=metric)
+ return distance
+
+ def _compute_label_cost(self) -> NDArray:
+ """Compute distances between classes in the training and test sets.
+
+ The number of classes in the first set is `n_classes1` and the second is `n_classes2`.
+
+ Returns:
+ An array with dimensions `(n_classes1, n_classes2)`
+ """
+ if self.inner_ot_method == "exact":
+ ot_method = empirical_sinkhorn2
+ reg = 0.1
+ else:
+ ot_method = empirical_bures_wasserstein_distance
+ reg = 0.1
+
+ (
+ D_train_train,
+ D_train_test,
+ D_test_test,
+ ) = self._compute_label_distances(ot_method, reg=reg)
+
+ label_distances = np.concatenate(
+ [
+ np.concatenate([D_train_train, D_train_test], axis=1),
+ np.concatenate([D_train_test, D_test_test], axis=1),
+ ],
+ axis=0,
+ )
+
+ """
+ M = (
+ label_distances.shape[1] * self.dataset.y_train[..., np.newaxis]
+ + self.dataset.y_test[np.newaxis, ...]
+ )
+ label_cost = label_distances.ravel()[M.ravel()].reshape(
+ len(self.dataset.y_train), len(self.dataset.y_test)
+ )
+ """
+ indexing_array = (
+ D_train_test.shape[1] * self.dataset.y_train[..., np.newaxis]
+ + self.dataset.y_test[np.newaxis, ...]
+ )
+ label_cost = D_train_test.ravel()[indexing_array.ravel()].reshape(
+ len(self.dataset.y_train), len(self.dataset.y_test)
+ )
+
+ return label_cost
+
+ def _compute_label_distances(
+ self, ot_method: Callable, reg: float = 0.1
+ ) -> Tuple[NDArray, NDArray, NDArray]:
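+ """Compute pairwise OT distances between class-conditional feature distributions.
+
+ Returns:
+ A tuple `(D_train_train, D_train_test, D_test_test)` of distance matrices
+ between the classes of the training and test sets.
+ """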
+ c_train = np.sort(np.unique(self.dataset.y_train))
+ c_test = np.sort(np.unique(self.dataset.y_test))
+ n_train, n_test = len(c_train), len(c_test)
+
+ D_train_test = np.zeros((n_train, n_test))
+ for i, j in itertools.product(range(n_train), range(n_test)):
+ distance = ot_method(
+ self.dataset.x_train[self.dataset.y_train == c_train[i]],
+ self.dataset.x_test[self.dataset.y_test == c_test[j]],
+ reg=reg,
+ )
+ D_train_test[i, j] = distance
+
+ D_train_train = np.zeros((n_train, n_train))
+ for i, j in itertools.combinations(range(n_train), 2):
+ distance = ot_method(
+ self.dataset.x_train[self.dataset.y_train == c_train[i]],
+ self.dataset.x_train[self.dataset.y_train == c_train[j]],
+ reg=reg,
+ )
+ D_train_train[i, j] = D_train_train[j, i] = distance
+
+ D_test_test = np.zeros((n_test, n_test))
+ for i, j in itertools.combinations(range(n_test), 2):
+ distance = ot_method(
+ self.dataset.x_test[self.dataset.y_test == c_test[i]],
+ self.dataset.x_test[self.dataset.y_test == c_test[j]],
+ reg=reg,
+ )
+ D_test_test[i, j] = D_test_test[j, i] = distance
+
+ return D_train_train, D_train_test, D_test_test
diff --git a/tests/ot/__init__.py b/tests/ot/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/ot/test_lava.py b/tests/ot/test_lava.py
new file mode 100644
index 000000000..2a234b286
--- /dev/null
+++ b/tests/ot/test_lava.py
@@ -0,0 +1,74 @@
+import numpy as np
+import pytest
+from numpy.testing import assert_array_equal
+from sklearn.datasets import make_classification
+
+from pydvl.ot.lava import LAVA
+from pydvl.utils.dataset import Dataset
+
+
+@pytest.fixture
+def synthetic_dataset():
+ x, y = make_classification(
+ n_samples=8,
+ n_features=3,
+ n_informative=2,
+ n_repeated=0,
+ n_redundant=1,
+ n_classes=2,
+ random_state=16,
+ flip_y=0,
+ )
+ dataset = Dataset.from_arrays(
+ x, y, train_size=0.5, stratify_by_target=True, random_state=16
+ )
+ return dataset
+
+
+@pytest.fixture
+def synthetic_dataset_same_train_test(synthetic_dataset):
+ synthetic_dataset.x_test = synthetic_dataset.x_train
+ synthetic_dataset.y_test = synthetic_dataset.y_train
+ return synthetic_dataset
+
+
+@pytest.fixture
+def synthetic_dataset_flipped_labels(synthetic_dataset):
+ rng = np.random.default_rng(16)
+ flip_mask = rng.uniform(size=len(synthetic_dataset.y_train)) < 0.1
+ assert flip_mask.sum() > 0
+ synthetic_dataset.y_train[flip_mask] = np.invert(
+ synthetic_dataset.y_train[flip_mask].astype(bool)
+ ).astype(int)
+ return synthetic_dataset, flip_mask
+
+
+def test_lava_exact_and_gaussian(synthetic_dataset):
+ lava = LAVA(synthetic_dataset, inner_ot_method="gaussian")
+ gaussian_values = lava.compute_values()
+ lava = LAVA(synthetic_dataset, inner_ot_method="exact")
+ exact_values = lava.compute_values()
+ # Compare the values obtained with the gaussian and exact inner OT methods
+ assert_array_equal(gaussian_values, exact_values)
+
+
+def test_lava_not_all_same_values(synthetic_dataset):
+ lava = LAVA(synthetic_dataset, inner_ot_method="gaussian")
+ values = lava.compute_values()
+ # We make sure that values are not all the same
+ assert np.any(~np.isclose(values, values[0]))
+
+
+def test_lava_same_train_and_test(synthetic_dataset_same_train_test):
+ lava = LAVA(synthetic_dataset_same_train_test, inner_ot_method="gaussian")
+ values = lava.compute_values()
+ # We make sure that all values are zero
+ assert_array_equal(values, np.zeros_like(values))
+
+
+def test_lava_flipped_labels(synthetic_dataset_flipped_labels):
+ dataset, flip_mask = synthetic_dataset_flipped_labels
+ lava = LAVA(dataset, inner_ot_method="gaussian")
+ values = lava.compute_values()
+ # We make sure that values are not all the same
+ assert np.any(~np.isclose(values, values[0]))
diff --git a/tox.ini b/tox.ini
index d62cfe481..0a29d050f 100644
--- a/tox.ini
+++ b/tox.ini
@@ -10,6 +10,7 @@ extras =
ray
influence
memcached
+ lava
setenv =
COVERAGE_FILE = {env:COVERAGE_FILE:{toxinidir}/.coverage.{envname}}
passenv =