-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Explaining Deep Cluster Assignments with Neuralized K-Means on Image Data - I tried to adhere to guidelines - That means: random data, random weights - Code for real data and real weights in comments - Runs on colab, did not test blender - also adds the reference to docs/source/tutorial/index.rst
- Loading branch information
1 parent
f3b6ba2
commit 7047178
Showing
2 changed files
with
313 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,312 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a0661ff4-9f41-405c-8453-f009c31e6a0e", | ||
"metadata": {}, | ||
"source": [ | ||
"## Explaining Deep Cluster Assignments with Neuralized K-Means on Image Data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b3aef718-d2a0-4f30-9b91-b53f5b288299", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# for colab folks\n", | ||
"# %pip install zennit" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "aa6d0ce7-ea3d-46e5-a8d7-e9a8b31d9239", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Basic boilerplate code\n", | ||
"from torchvision import datasets, transforms\n", | ||
"from torchvision.models import vgg16\n", | ||
"import torch\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"transform_img = transforms.Compose([transforms.Resize(224), transforms.CenterCrop(224)])\n", | ||
"transform_norm = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n", | ||
"\n", | ||
"transform = transforms.Compose([\n", | ||
" transform_img,\n", | ||
" transforms.ToTensor(),\n", | ||
" transform_norm\n", | ||
"])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d73397bd-14a2-48ee-8c42-46d6b5104115", | ||
"metadata": {}, | ||
"source": [ | ||
"### Real data and weights" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d5b258b8-c670-473f-858e-2f8464863e29", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# uncomment this cell for an example on real data and real weights\n", | ||
"### Data loading\n", | ||
"# from torch.utils.data import SubsetRandomSampler, DataLoader\n", | ||
"\n", | ||
"# # Attention: the next row downloads a dataset into the current folder!\n", | ||
"# dataset = datasets.Caltech101(root='.', transform=transform, download=True)\n", | ||
"\n", | ||
"# categories = ['cougar_body', 'Leopards', 'wild_cat']\n", | ||
"\n", | ||
"# all_indices = []\n", | ||
"# for category in categories:\n", | ||
"# category_idx = dataset.categories.index(category)\n", | ||
"# category_indices = [i for i, label in enumerate(dataset.y) if label == category_idx]\n", | ||
"\n", | ||
"# num_samples = min(7, len(category_indices))\n", | ||
"\n", | ||
"# selected_indices = np.random.choice(category_indices, num_samples, replace=False)\n", | ||
"# all_indices.extend(selected_indices)\n", | ||
"\n", | ||
"# sampler = SubsetRandomSampler(all_indices)\n", | ||
"# loader = DataLoader(dataset, batch_size=21, sampler=sampler)\n", | ||
"\n", | ||
"# # If this line throws a shape error, just run this cell again (some images in Caltech101 are grayscale)\n", | ||
"# images, labels = next(iter(loader))\n", | ||
"\n", | ||
"### Feature extractor\n", | ||
"# features = vgg16(weights='IMAGENET1K_V1').eval()._modules['features']" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "be3a0e0d-afa0-4af1-b8c2-a3f6525dcb03", | ||
"metadata": {}, | ||
"source": [ | ||
"### Random data and weights for online preview" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "736252a7-4445-408c-a946-163ea44903da", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# for zennit contribution guidelines\n", | ||
"# some random data and weights\n", | ||
"images, labels = transform_norm(torch.randn(3, 3, 224, 224).clamp(min=0, max=1)), torch.tensor([0,1,2])\n", | ||
"features = vgg16(weights=None).eval()._modules['features']" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e7f02b4d-1da8-44ea-a887-6413d150b355", | ||
"metadata": {}, | ||
"source": [ | ||
"### The fun begins here\n", | ||
"\n", | ||
"We construct a feature map $\\phi$ from image space to feature space.\n", | ||
"Here, we sum over spatial locations in feature space to get more or less translation invariance in pixel space." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "eef79eae-f9c7-4b77-8d7c-5edff8e84aeb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from zennit.layer import Sum\n", | ||
"\n", | ||
"phi = torch.nn.Sequential(\n", | ||
" features,\n", | ||
" Sum((2,3))\n", | ||
")\n", | ||
"\n", | ||
"Z = phi(images).detach()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "97b43d41-322a-483c-8506-93e3fa0a852d", | ||
"metadata": {}, | ||
"source": [ | ||
"Use simple `scikit-learn.KMeans` on the features:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "87c058d4-a3e4-4d29-af50-a7f2235e78c3", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# initialize on class means\n", | ||
"# because we have very few data points here\n", | ||
"centroids = np.stack([Z[labels == y].mean(0) for y in labels.unique()])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "309e1158-de08-4493-af07-32a592622a94", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"### uncomment for real fun\n", | ||
"# from sklearn.cluster import KMeans\n", | ||
"# standard_kmeans = KMeans(n_clusters=3, n_init='auto', init=centroids).fit(Z)\n", | ||
"# centroids = standard_kmeans.cluster_centers_" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "5d65f068-b651-4f87-81d4-54508b71c841", | ||
"metadata": {}, | ||
"source": [ | ||
"Now build a deep clustering model that takes images as input and predicts the k-means assignments\n", | ||
"\n", | ||
"We also apply a little scaling trick that makes heatmaps nicer, but usually does not change the cluster assignments." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ce2dbb2a-8a97-488d-9f88-25426881ee10", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from zennit.layer import Distance\n", | ||
"\n", | ||
"# it's not necessary, just looks a bit nicer\n", | ||
"s = ((centroids**2).sum(-1, keepdims=True)**.5)\n", | ||
"s = s / s.mean()\n", | ||
"\n", | ||
"model = torch.nn.Sequential(\n", | ||
" phi,\n", | ||
" Distance(torch.from_numpy(centroids / s).float())\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f177bbce-fe8f-46b8-b7a9-b9bfb9048145", | ||
"metadata": {}, | ||
"source": [ | ||
"### Enter zennit." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "06892de9-0add-448d-8b76-0f6ea3a0ccd7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# import zennit\n", | ||
"from zennit.attribution import Gradient\n", | ||
"from zennit.composites import EpsilonGammaBox\n", | ||
"from zennit.image import imgify\n", | ||
"from zennit.torchvision import VGGCanonizer\n", | ||
"from zennit.canonizers import KMeansCanonizer\n", | ||
"from zennit.composites import LayerMapComposite, MixedComposite\n", | ||
"from zennit.layer import NeuralizedKMeans\n", | ||
"from zennit.rules import ZPlus, Gamma\n", | ||
"\n", | ||
"def data2img(x):\n", | ||
" return (x.squeeze().permute(1,2,0) * torch.tensor([0.229, 0.224, 0.225])) + torch.tensor([0.485, 0.456, 0.406])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "aac5b8af-61cc-400b-a0fc-b036148104ad", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# compute cluster assignments and check if they are equal\n", | ||
"# without the scaling trick above, the are definitely equal (trust me)\n", | ||
"ypred = model(images).argmin(1)\n", | ||
"assert (ypred.numpy() == standard_kmeans.predict(Z)).all()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "47e38917-b4ee-499f-ba9e-55cce7cb8163", | ||
"metadata": {}, | ||
"source": [ | ||
"### Everything is ready.\n", | ||
"\n", | ||
"You can play around with the `beta` parameter in `KMeansCanonizer` and the `gamma` parameter in `Gamma`.\n", | ||
"\n", | ||
"`beta` is a contrast parameter. Keep `beta < 0`.\n", | ||
"Small negative `beta` can be seen as *one-vs-all* explanation whereas large negative `beta` is more like *one-vs-nearest-competitor*.\n", | ||
"\n", | ||
"The `gamma` parameter controls the contribution of negative weights. Keep `gamma >= 0`.\n", | ||
"In practice, small (positive) `gamma` can result in entirely negative heatmaps. Think of thousand negative weights and a single positive weight. The positive weight could be enough to win the k-means assignment in feature space, but it's lost after a few layers because the graph is flooded with negative contributions.\n", | ||
"\n", | ||
"If you are trying to explain contribution to another cluster (say, $x$ is assigned to cluster $1$, but you want to see if there is some evidence for cluster $2$ in the image), then definitely cramp up `gamma` or even use `ZPlus` instead of `Gamma`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "aa0f7ca6-3e73-4254-ba31-26a6de28e690", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"canonizer = KMeansCanonizer(beta=-1e-12)\n", | ||
"\n", | ||
"low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))\n", | ||
"\n", | ||
"composite = MixedComposite([\n", | ||
" EpsilonGammaBox(low=low, high=high, canonizers=[canonizer]),\n", | ||
" LayerMapComposite([(NeuralizedKMeans, Gamma(gamma=1.))])\n", | ||
"])\n", | ||
"\n", | ||
"with Gradient(model=model, composite=composite) as attributor:\n", | ||
" for c in range(len(centroids)):\n", | ||
" print(\"Cluster %d\"%c)\n", | ||
" cluster_members = (ypred == c).nonzero()[:,0]\n", | ||
" for i in cluster_members:\n", | ||
" img = images[i].unsqueeze(0)\n", | ||
" target = torch.eye(len(centroids))[[c]]\n", | ||
" output, attribution = attributor(img, target)\n", | ||
" relevance = attribution[0].sum(0)\n", | ||
"\n", | ||
" heatmap = np.array(imgify(relevance, symmetric=True, cmap='seismic').convert('RGB'))\n", | ||
" display(imgify(np.stack([data2img(img).numpy(), heatmap]), grid=(1,2)))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters