diff --git a/notebooks/lava_cifar10.ipynb b/notebooks/lava_cifar10.ipynb new file mode 100644 index 000000000..0fe120a20 --- /dev/null +++ b/notebooks/lava_cifar10.ipynb @@ -0,0 +1,535 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "# LAVA\n", + "\n", + "This notebook explores the use of LAVA for data valuation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "If you are reading this in the documentation, some boilerplate has been omitted for convenience.\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports and setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "nbsphinx": "hidden", + "slideshow": { + "slide_type": "" + }, + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "%load_ext autoreload" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "nbsphinx": "hidden", + "slideshow": { + "slide_type": "" + }, + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "%autoreload\n", + "%matplotlib inline\n", + "\n", + "import logging\n", + "import os\n", + "from typing import Tuple\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay\n", + "from torch import nn\n", + "from torch.utils.data import DataLoader, Subset\n", + "from torchvision.datasets import CIFAR10\n", + "import torchvision.transforms as transforms\n", + "from torchvision.models.resnet import ResNet, BasicBlock, resnet18, ResNet18_Weights\n", + "from torchvision.utils import make_grid\n", + "from tqdm.auto import tqdm\n", + "\n", + "from support.common import (\n", + " plot_sample_images,\n", + " plot_losses,\n", + ")\n", + "from support.torch import (\n", + " TrainingManager,\n", + " MODEL_PATH,\n", + " new_resnet_model,\n", + ")\n", + "from support.types import Losses\n", + "\n", + "logging.basicConfig(level=logging.DEBUG)\n", + "\n", + "plt.rcParams[\"figure.figsize\"] = (7, 7)\n", + "plt.rcParams[\"font.size\"] = 12\n", + "plt.rcParams[\"xtick.labelsize\"] = 12\n", + "plt.rcParams[\"ytick.labelsize\"] = 10\n", + "plt.rcParams[\"axes.facecolor\"] = (1, 1, 1, 0)\n", + "plt.rcParams[\"figure.facecolor\"] = (1, 1, 1, 0)\n", + "\n", + "random_state = 42\n", + "np.random.seed(random_state)\n", + "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%autoreload\n", + "from pydvl.ot.lava import LAVA\n", + "from pydvl.utils.dataset import Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading and preprocessing the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "transform = transforms.Compose(\n", + " [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n", + ")\n", + "trainset = CIFAR10(root=\"/tmp/cifar10\", train=True, download=True, transform=transform)\n", + "valset = CIFAR10(root=\"/tmp/cifar10\", train=False, download=True, transform=transform)\n", + "classes = trainset.classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainset = Subset(trainset, np.random.randint(low=0, high=len(trainset), size=100))\n", + "valset = Subset(valset, np.random.randint(low=0, high=len(valset), size=100))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "" + }, + "tags": [ + "hide-output" + ] + }, + "outputs": [], + "source": [ + "trainloader = DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0)\n", + "valloader = DataLoader(valset, batch_size=4, shuffle=True, num_workers=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's take a closer look at a few image samples" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "" + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "def imshow(img):\n", + " img = img / 2 + 0.5 # unnormalize\n", + " npimg = img.numpy()\n", + " plt.imshow(np.transpose(npimg, (1, 2, 0)))\n", + " plt.show()\n", + "\n", + "\n", + "# get some random training images\n", + "dataiter = iter(trainloader)\n", + "images, labels = next(dataiter)\n", + "\n", + "# show images\n", + "imshow(make_grid(images))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model definition and training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now train a model on the validation data (This is the same as in the paper) in order to use it as a feature extractor." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = ResNet(BasicBlock, [1, 1, 1, 1], num_classes=10)\n", + "num_params = sum(p.numel() for p in model.parameters())\n", + "print(f\"Model has {num_params} parameters\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "" + }, + "tags": [ + "hide-output" + ] + }, + "outputs": [], + "source": [ + "mgr = TrainingManager(\n", + " \"model_lava_cifar10\",\n", + " model,\n", + " nn.CrossEntropyLoss(),\n", + " valloader,\n", + " trainloader,\n", + " MODEL_PATH,\n", + " device=DEVICE,\n", + ")\n", + "# Set use_cache=False to retrain the model\n", + "train_loss, val_loss = mgr.train(n_epochs=10, use_cache=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "" + }, + "tags": [ + "invertible-output" + ] + }, + "outputs": [], + "source": [ + "plot_losses(Losses(train_loss, val_loss))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The confusion matrix and $F_1$ score look good, especially considering the low resolution of the images and their complexity (they contain different objects)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "" + }, + "tags": [ + "invertible-output" + ] + }, + "outputs": [], + "source": [ + "y_test = []\n", + "y_pred = []\n", + "\n", + "for inputs, targets in tqdm(valloader, total=len(valloader)):\n", + " y_test.append(targets.cpu().numpy().ravel())\n", + " inputs = inputs.to(DEVICE)\n", + " pred = np.argmax(model(inputs).cpu().detach().numpy(), axis=1).ravel()\n", + " y_pred.append(pred)\n", + "\n", + "\n", + "y_test = np.concatenate(y_test)\n", + "y_pred = np.concatenate(y_pred)\n", + "cm = confusion_matrix(y_test, y_pred)\n", + "disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)\n", + "disp.plot();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Feature Extraction" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now remove the last layer in order to use the model as a feature extractor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.fc = torch.nn.Identity()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x_train = []\n", + "y_train = []\n", + "\n", + "for inputs, targets in tqdm(trainloader, total=len(trainloader)):\n", + " y_train.append(targets.cpu().numpy().ravel())\n", + " inputs = 
+    "x_train = np.concatenate(x_train)\n",
+    "y_train = np.concatenate(y_train)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x_test = []\n",
+    "y_test = []\n",
+    "\n",
+    "for inputs, targets in tqdm(valloader, total=len(valloader)):\n",
+    "    y_test.append(targets.cpu().numpy().ravel())\n",
+    "    inputs = inputs.to(DEVICE)\n",
+    "    pred = model(inputs).cpu().detach().numpy()\n",
+    "    x_test.append(pred)\n",
+    "\n",
+    "x_test = np.concatenate(x_test)\n",
+    "y_test = np.concatenate(y_test)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Computing Values"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = Dataset(x_train, y_train, x_test, y_test)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "all_values = []\n",
+    "\n",
+    "for regularization in [1, 0.5, 0.1, 0.01]:\n",
+    "    for lambda_ in [3.0, 1.0, 0.1, 0.01, 0]:\n",
+    "        print(f\"{regularization=}, {lambda_=}\")\n",
+    "        lava = LAVA(\n",
+    "            dataset,\n",
+    "            inner_ot_method=\"exact\",\n",
+    "            regularization=regularization,\n",
+    "            lambda_=lambda_,\n",
+    "        )\n",
+    "        values = lava.compute_values()\n",
+    "        all_values.append(values)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "values_df = pd.DataFrame(np.stack(all_values).T)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
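+    "# each column holds the values obtained with one (regularization, lambda_) combination\n",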
+    "values_df.plot.box();"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "regularization = 1.0\n",
+    "lambda_ = 1.0\n",
+    "lava = LAVA(\n",
+    "    dataset, inner_ot_method=\"gaussian\", regularization=regularization, lambda_=lambda_\n",
+    ")\n",
+    "values = lava.compute_values()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "feature_cost = lava._compute_feature_cost()\n",
+    "plt.boxplot(feature_cost.ravel());"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from ot.gaussian import empirical_bures_wasserstein_distance\n",
+    "\n",
+    "# label-to-label distances with the Gaussian (Bures-Wasserstein) approximation\n",
+    "lava._compute_label_distances(empirical_bures_wasserstein_distance)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from ot.bregman import empirical_sinkhorn2\n",
+    "\n",
+    "# label-to-label distances with entropic OT on the raw features\n",
+    "lava._compute_label_distances(empirical_sinkhorn2)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.hist(values);"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Pre-Trained Model\n",
+    "\n",
+    "What if we use a pre-trained model instead?"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)\n",
+    "model.fc = torch.nn.Identity()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.16"
+  },
+  "toc": {
+   "base_numbering": 1,
+   "nav_menu": {},
+   "number_sections": true,
+   "sideBar": true,
+   "skip_h1_title": false,
+   "title_cell": "Table of Contents",
+   "title_sidebar": "Contents",
+   "toc_cell": false,
+   "toc_position": {},
+   "toc_section_display": true,
+   "toc_window_display": false
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "b3369ace3ad477f5e763d9fa7767e0177027059e92a8b1ded9e92b707c0b1513"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/setup.py b/setup.py
index cac6ca5d6..c043acbfa 100644
--- a/setup.py
+++ b/setup.py
@@ -31,6 +31,7 @@
            "zarr>=2.16.1",
        ],
        "ray": ["ray>=0.8"],
+       "lava": ["pot>=0.9"],
    },
    author="appliedAI Institute gGmbH",
    long_description=long_description,
diff --git a/src/pydvl/ot/__init__.py b/src/pydvl/ot/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/pydvl/ot/lava.py b/src/pydvl/ot/lava.py
new file mode 100644
index 000000000..dbc4585d3
--- /dev/null
+++ b/src/pydvl/ot/lava.py
@@ -0,0 +1,219 @@
+"""
+!!! tip "New in version 0.9.0"
+
+## References
+
+[^1]: Just et al.
+[LAVA: Data Valuation without Pre-Specified Learning Algorithms](https://arxiv.org/abs/2305.00054).
+In: International Conference on Learning Representations (ICLR 2023).
+"""
+
+import itertools
+from typing import Callable, Literal, Tuple
+
+import numpy as np
+import ot
+from numpy.typing import NDArray
+from ot.bregman import empirical_sinkhorn2
+from ot.gaussian import empirical_bures_wasserstein_distance
+
+from pydvl.utils.dataset import Dataset
+
+__all__ = ["LAVA"]
+
+
+class LAVA:
+    """Computes data values using LAVA.
+
+    This implements the method described in
+    (Just et al., 2023)[^1].
+
+    Args:
+        dataset: The dataset containing training and test samples.
+        regularization: Regularization parameter for the Sinkhorn iterations.
+        lambda_: Weight of the label distances in the ground cost.
+        inner_ot_method: Name of the method used to compute the inner (label-to-label)
+            OT problem. Must be one of 'gaussian' or 'exact'.
+            If set to 'gaussian', the label distributions
+            are approximated as Gaussians, and thus their distance is computed as
+            the Bures-Wasserstein distance. If set to 'exact', no Gaussian approximation
+            is used, and their distance is computed as a (regularized) Wasserstein
+            problem on the raw features.
+
+    Examples:
+        >>> from pydvl.ot.lava import LAVA
+        >>> from pydvl.utils.dataset import Dataset
+        >>> from sklearn.datasets import load_iris
+        >>> dataset = Dataset.from_sklearn(load_iris())
+        >>> lava = LAVA(dataset)
+        >>> values = lava.compute_values()
+        >>> assert values.shape == (len(dataset.x_train),)
+    """
+
+    def __init__(
+        self,
+        dataset: Dataset,
+        *,
+        regularization: float = 0.1,
+        lambda_: float = 1.0,
+        inner_ot_method: Literal["exact", "gaussian"] = "exact",
+    ) -> None:
+        self.dataset = dataset
+        self.regularization = regularization
+        self.lambda_ = lambda_
+        self.inner_ot_method = inner_ot_method
+
+    def compute_values(self) -> NDArray:
+        """Compute data values as calibrated gradients of the Optimal Transport dual.
+
+        Returns:
+            Array of dimensions `(n_train)` that contains the calibrated gradients.
+        """
+        dual_solution = self._compute_ot_dual()
+        calibrated_gradients = self._compute_calibrated_gradients(dual_solution)
+        return calibrated_gradients
+
+    def _compute_calibrated_gradients(
+        self, dual_solution: Tuple[NDArray, NDArray]
+    ) -> NDArray:
+        r"""Compute calibrated gradients from the dual solution of the Optimal Transport problem.
+
+        $$
+        \frac{\partial OT(\mu_t, \mu_v )}{\partial \mu_t(z_i)}
+        = f_i^∗ - \sum\limits_{j\in\{1,\dots,N\} \setminus i} \frac{f_j^∗}{N-1}
+        $$
+
+        Args:
+            dual_solution: Dual solution of the Optimal Transport problem.
+
+        Returns:
+            Array of dimensions `(n_train)`
+        """
+        f1k = np.array(dual_solution[0])
+        training_size = len(self.dataset.x_train)
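+        # Equivalently, each sample's dual potential minus the average potential of
+        # the other N - 1 training samples: f*_i - sum_{j != i} f*_j / (N - 1)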
+        calibrated_gradients = (1 + 1 / (training_size - 1)) * f1k - f1k.sum() / (
+            training_size - 1
+        )
+        return calibrated_gradients
+
+    def _compute_ot_dual(self) -> Tuple[NDArray, NDArray]:
+        """Compute the dual solution of the Optimal Transport problem."""
+        ground_cost = self._compute_ground_cost()
+        a = ot.unif(len(self.dataset.x_train))
+        b = ot.unif(len(self.dataset.x_test))
+        gamma, log = ot.sinkhorn(
+            a,
+            b,
+            ground_cost,
+            self.regularization,
+            log=True,
+            verbose=False,
+            numItermax=10000,
+        )
+        u, v = log["u"], log["v"]
+        return u, v
+
+    def _compute_ground_cost(self) -> NDArray:
+        label_cost = self._compute_label_cost()
+        feature_cost = self._compute_feature_cost()
+        ground_cost = feature_cost + self.lambda_ * label_cost
+        return ground_cost
+
+    def _compute_feature_cost(
+        self, metric: Literal["euclidean", "sqeuclidean"] = "sqeuclidean"
+    ) -> NDArray:
+        """Compute pairwise distances between training and test features.
+
+        The first set has dimensions `(n1, d)` and the second has dimensions `(n2, d)`.
+
+        Args:
+            metric: Metric for the pairwise distances ("euclidean" or "sqeuclidean").
+
+        Returns:
+            Array with dimensions `(n1, n2)`
+        """
+        distance = ot.dist(self.dataset.x_train, self.dataset.x_test, metric=metric)
+        return distance
+
+    def _compute_label_cost(self) -> NDArray:
+        """Compute the label distance between each pair of training and test samples.
+
+        The entry `(i, j)` is the OT distance between the feature distribution of the
+        class of training sample `i` and that of the class of test sample `j`.
+
+        Returns:
+            An array with dimensions `(n_train, n_test)`
+        """
+        if self.inner_ot_method == "exact":
+            ot_method = empirical_sinkhorn2
+            reg = 0.1
+        else:
+            ot_method = empirical_bures_wasserstein_distance
+            reg = 0.1
+
+        (
+            D_train_train,
+            D_train_test,
+            D_test_test,
+        ) = self._compute_label_distances(ot_method, reg=reg)
+
+        label_distances = np.concatenate(
+            [
+                np.concatenate([D_train_train, D_train_test], axis=1),
+                np.concatenate([D_train_test, D_test_test], axis=1),
+            ],
+            axis=0,
+        )
+
+        indexing_array = (
+            D_train_test.shape[1] * self.dataset.y_train[..., np.newaxis]
+            + self.dataset.y_test[np.newaxis, ...]
+        )
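+        # label_cost[i, j] looks up in D_train_test the distance between the class
+        # of training sample i and the class of test sample j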
+        label_cost = D_train_test.ravel()[indexing_array.ravel()].reshape(
+            len(self.dataset.y_train), len(self.dataset.y_test)
+        )
+
+        return label_cost
+
+    def _compute_label_distances(
+        self, ot_method: Callable, reg: float = 0.1
+    ) -> Tuple[NDArray, NDArray, NDArray]:
+        """Compute OT distances between the class-conditional feature distributions
+        within and across the training and test sets."""
+        c_train = np.sort(np.unique(self.dataset.y_train))
+        c_test = np.sort(np.unique(self.dataset.y_test))
+        n_train, n_test = len(c_train), len(c_test)
+
+        D_train_test = np.zeros((n_train, n_test))
+        for i, j in itertools.product(range(n_train), range(n_test)):
+            distance = ot_method(
+                self.dataset.x_train[self.dataset.y_train == c_train[i]],
+                self.dataset.x_test[self.dataset.y_test == c_test[j]],
+                reg=reg,
+            )
+            D_train_test[i, j] = distance
+
+        D_train_train = np.zeros((n_train, n_train))
+        for i, j in itertools.combinations(range(n_train), 2):
+            distance = empirical_sinkhorn2(
+                self.dataset.x_train[self.dataset.y_train == c_train[i]],
+                self.dataset.x_train[self.dataset.y_train == c_train[j]],
+                reg=reg,
+            )
+            D_train_train[i, j] = D_train_train[j, i] = distance
+
+        D_test_test = np.zeros((n_test, n_test))
+        for i, j in itertools.combinations(range(n_test), 2):
+            distance = empirical_sinkhorn2(
+                self.dataset.x_test[self.dataset.y_test == c_test[i]],
+                self.dataset.x_test[self.dataset.y_test == c_test[j]],
+                reg=reg,
+            )
+            D_test_test[i, j] = D_test_test[j, i] = distance
+
+        return D_train_train, D_train_test, D_test_test
diff --git a/tests/ot/__init__.py b/tests/ot/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/ot/test_lava.py b/tests/ot/test_lava.py
new file mode 100644
index 000000000..2a234b286
--- /dev/null
+++ b/tests/ot/test_lava.py
@@ -0,0 +1,74 @@
+import numpy as np
+import pytest
+from numpy.testing import assert_array_equal
+from sklearn.datasets import make_classification
+
+from pydvl.ot.lava import LAVA
+from pydvl.utils.dataset import Dataset
+
+
+@pytest.fixture
+def synthetic_dataset():
+    x, y = make_classification(
+        n_samples=8,
+        n_features=3,
+        n_informative=2,
+        n_repeated=0,
+        n_redundant=1,
+        n_classes=2,
+        random_state=16,
+        flip_y=0,
+    )
+    dataset = Dataset.from_arrays(
+        x, y, train_size=0.5, stratify_by_target=True, random_state=16
+    )
+    return dataset
+
+
+@pytest.fixture
+def synthetic_dataset_same_train_test(synthetic_dataset):
+    synthetic_dataset.x_test = synthetic_dataset.x_train
+    synthetic_dataset.y_test = synthetic_dataset.y_train
+    return synthetic_dataset
+
+
+@pytest.fixture
+def synthetic_dataset_flipped_labels(synthetic_dataset):
+    rng = np.random.default_rng(16)
+    flip_mask = rng.uniform(size=len(synthetic_dataset.y_train)) < 0.1
+    assert flip_mask.sum() > 0
+    synthetic_dataset.y_train[flip_mask] = np.invert(
+        synthetic_dataset.y_train[flip_mask].astype(bool)
+    ).astype(int)
+    return synthetic_dataset, flip_mask
+
+
+def test_lava_exact_and_gaussian(synthetic_dataset):
+    lava = LAVA(synthetic_dataset, inner_ot_method="gaussian")
+    gaussian_values = lava.compute_values()
+    lava = LAVA(synthetic_dataset, inner_ot_method="exact")
+    exact_values = lava.compute_values()
+    # We make sure that both inner OT methods yield the same values on this dataset
+    assert_array_equal(gaussian_values, exact_values)
+
+
+def test_lava_not_all_same_values(synthetic_dataset):
+    lava = LAVA(synthetic_dataset, inner_ot_method="gaussian")
+    values = lava.compute_values()
+    # We make sure that values are not all the same
+    assert np.any(~np.isclose(values, values[0]))
+
+
+def test_lava_same_train_and_test(synthetic_dataset_same_train_test):
+    lava = LAVA(synthetic_dataset_same_train_test,
inner_ot_method="gaussian") + values = lava.compute_values() + # We make sure that all values are zero + assert_array_equal(values, np.zeros_like(values)) + + +def test_lava_flipped_labels(synthetic_dataset_flipped_labels): + dataset, flip_mask = synthetic_dataset_flipped_labels + lava = LAVA(dataset, inner_ot_method="gaussian") + values = lava.compute_values() + # We make sure that values are not all the same + assert np.any(~np.isclose(values, values[0])) diff --git a/tox.ini b/tox.ini index d62cfe481..0a29d050f 100644 --- a/tox.ini +++ b/tox.ini @@ -10,6 +10,7 @@ extras = ray influence memcached + lava setenv = COVERAGE_FILE = {env:COVERAGE_FILE:{toxinidir}/.coverage.{envname}} passenv =