From f3b6ba20acefc15ff9af397da4ebe9edcba1f2a2 Mon Sep 17 00:00:00 2001 From: Jacob Kauffmann Date: Tue, 15 Aug 2023 16:07:00 +0200 Subject: [PATCH] Neuralized K-Means - documentation in numpydoc format - pylint + flake8 stuff - KMeansCanonizer - NeuralizedKMeans layer - LogMeanExpPool layer - Distance layer - Distance type --- src/zennit/canonizers.py | 101 +++++++++++++++++++++++++++++- src/zennit/layer.py | 129 +++++++++++++++++++++++++++++++++++++++ src/zennit/types.py | 9 +++ 3 files changed, 238 insertions(+), 1 deletion(-) diff --git a/src/zennit/canonizers.py b/src/zennit/canonizers.py index fc4a4f2..fce76cc 100644 --- a/src/zennit/canonizers.py +++ b/src/zennit/canonizers.py @@ -18,10 +18,12 @@ '''Functions to produce a canonical form of models fit for LRP''' from abc import ABCMeta, abstractmethod +import copy import torch from .core import collect_leaves -from .types import Linear, BatchNorm, ConvolutionTranspose +from .types import Linear, BatchNorm, ConvolutionTranspose, Distance +from .layer import NeuralizedKMeans, LogMeanExpPool class Canonizer(metaclass=ABCMeta): @@ -329,3 +331,100 @@ def register(self): def remove(self): '''Remove this Canonizer. Nothing to do for a CompositeCanonizer.''' + + +class KMeansCanonizer(Canonizer): + '''Canonizer for k-means. + + This canonizer replaces a :py:obj:`Distance` layer with power 2 with a :py:obj:`NeuralizedKMeans` layer followed by + a :py:obj:`LogMeanExpPool` + + Parameters + ---------- + beta : float + stiffness of the :py:obj:`LogMeanExpPool` layer. Should be smaller than 0 in order to approximate the min + function. Default is -1. + + Examples + -------- + >>> from sklearn.cluster import KMeans + >>> centroids = KMeans(n_clusters=10).fit(X).cluster_centers_ + >>> model = torch.nn.Sequential(Distance(torch.from_numpy(centroids).float(), power=2)) + >>> cluster_assignment = model(x).argmin() + >>> canonizer = KMeansCanonizer(beta=-1.) + >>> with Gradient(model, canonizer=[canonizer]) as attributor: + >>> output, attribution = attributor(x, torch.eye(len(centroids))[[cluster_assignment]]) + ''' + def __init__(self, beta=-1.): + self.distance = None + self.distance_unchanged = None + self.beta = beta + self.parent_module = None + self.child_name = None + + def apply(self, root_module): + '''Apply this canonizer recursively on all applicable modules. + + Iterates over all modules of the root module and applies this canonizer to all :py:obj:`Distance` layers with + power 2. + + Parameters + ---------- + root_module : :py:obj:`torch.nn.Module` + Root module containing a :py:obj:`Distance` layer with power 2 as a submodule. + ''' + instances = [] + + for full_name, module in root_module.named_modules(): + if isinstance(module, Distance) and module.power == 2: + instance = self.copy() + if '.' in full_name: + parent_name, child_name = full_name.rsplit('.', 1) + parent_module = getattr(root_module, parent_name) + else: + parent_module = root_module + child_name = full_name + + instance.parent_module = parent_module + instance.child_name = child_name + + instance.register(module) + instances.append(instance) + + return instances + + def register(self, distance_module): + '''Register the :py:obj:`Distance` layer and replace it with a :py:obj:`NeuralizedKMeans` layer followed by a + :py:obj:`LogMeanExpPool` layer. + + compute :math:`w_{ck} = 2(\\mathbf{\\mu}_c - \\mathbf{\\mu}_k)` and :math:`b_{ck} = \\|\\mathbf{\\mu}_k\\|^2 - + \\|\\mathbf{\\mu}_c\\|^2`. Weights are stored in a tensor :math:`W \\in \\mathbb{R}^{K \\times (K - 1) + \\times D}` and biases in a vector :math:`b \\in \\mathbb{R}^{K \\times (K - 1)}`. + + A :py:obj:`NeuralizedKMeans` layer is created with these weights and biases. The :py:obj:`LogMeanExpPool` layer + is created with the beta value supplied to the constructor. + + Parameters + ---------- + distance_module : list of :py:obj:`Distance` + Distance layers to replace. + ''' + self.distance = distance_module + self.distance_unchanged = copy.deepcopy(self.distance) + + n_clusters, n_dims = self.distance.centroids.shape + mask = ~torch.eye(n_clusters, dtype=bool) + weight = 2 * (self.distance.centroids[:, None, :] - self.distance.centroids[None, :, :]) + weight = weight[mask].reshape(n_clusters, n_clusters - 1, n_dims) + norms = torch.norm(self.distance.centroids, dim=-1) + bias = (norms[None, :]**2 - norms[:, None]**2)[mask].reshape(n_clusters, n_clusters - 1) + setattr(self.parent_module, self.child_name, + torch.nn.Sequential(NeuralizedKMeans(weight, bias), + LogMeanExpPool(self.beta))) + + def remove(self): + """Revert the changes introduced by this canonizer.""" + setattr(self.parent_module, self.child_name, self.distance_unchanged) + + def copy(self): + return KMeansCanonizer(self.beta) diff --git a/src/zennit/layer.py b/src/zennit/layer.py index bd93d90..fd29d82 100644 --- a/src/zennit/layer.py +++ b/src/zennit/layer.py @@ -34,3 +34,132 @@ def __init__(self, dim=-1): def forward(self, input): '''Computes the sum along a dimension.''' return torch.sum(input, dim=self.dim) + + +class Distance(torch.nn.Module): + '''Compute pairwise distances between two sets of points. + + Initialized with a set of centroids, this layer computes the pairwise distance between the input and the centroids. + + Parameters + ---------- + centroids : :py:obj:`torch.Tensor` + shape (K, D) tensor of centroids + power : float + power to raise the distance to + + Examples + -------- + >>> centroids = torch.randn(10, 2) + >>> distance = Distance(centroids) + >>> x = torch.randn(100, 2) + >>> distance(x) + + ''' + def __init__(self, centroids, power=2): + super().__init__() + self.centroids = torch.nn.Parameter(centroids) + self.power = power + + def forward(self, input): + '''Computes the pairwise distance between `input` and `self.centroids` and raises to the power `self.power`. + + Parameters + ---------- + input : :py:obj:`torch.Tensor` + shape (N, D) tensor of points + + Returns + ------- + :py:obj:`torch.Tensor` + shape (N, K) tensor of distances + ''' + distance = torch.cdist(input, self.centroids)**self.power + return distance + + +class NeuralizedKMeans(torch.nn.Module): + '''Compute the k-means discriminants for a set of points. + + Technically, this is a tensor-matrix product with a bias. + + Parameters + ---------- + weight : :py:obj:`torch.Tensor` + shape (K, K-1, D) tensor of weights + bias : :py:obj:`torch.Tensor` + shape (K, K-1) tensor of biases + + Examples + -------- + >>> weight = torch.randn(10, 9, 2) + >>> bias = torch.randn(10, 9) + >>> neuralized_kmeans = NeuralizedKMeans(weight, bias) + + ''' + def __init__(self, weight, bias): + super().__init__() + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) + + def forward(self, x): + '''Computes the tensor-matrix product of `x` and `self.weight` and adds `self.bias`. + + Parameters + ---------- + x : :py:obj:`torch.Tensor` + shape (N, D) tensor of points + + Returns + ------- + :py:obj:`torch.Tensor` + shape (N, K, K-1) tensor of k-means discriminants + ''' + x = torch.einsum('nd,kjd->nkj', x, self.weight) + self.bias + return x + + +class LogMeanExpPool(torch.nn.Module): + '''Computes a log-mean-exp pool along an axis. + + LogMeanExpPool computes :math:`\\frac{1}{\\beta} \\log \\frac{1}{N} \\sum_{i=1}^N \\exp(\\beta x_i)` + + Parameters + ---------- + beta : float + stiffness of the pool. Positive values make the pool more like a max pool, negative values make the pool + more like a min pool. Default value is -1. + dim : int + dimension over which to pool + + Examples + -------- + >>> x = torch.randn(10, 2) + >>> pool = LogMeanExpPool() + >>> pool(x) + + ''' + def __init__(self, beta=1., dim=-1): + super().__init__() + self.dim = dim + self.beta = beta + + def forward(self, input): + '''Computes the LogMeanExpPool of `input`. + + If the input has shape (N1, N2, ..., Nk) and `self.dim` is `j`, then the output has shape + (N1, N2, ..., Nj-1, Nj+1, ..., Nk). + + Parameters + ---------- + input : :py:obj:`torch.Tensor` + the input tensor + + Returns + ------- + :py:obj:`torch.Tensor` + the LogMeanExpPool of `input` + ''' + n_dims = input.shape[self.dim] + return (torch.logsumexp(self.beta * input, dim=self.dim) + - torch.log(torch.tensor(n_dims, dtype=input.dtype))) / self.beta diff --git a/src/zennit/types.py b/src/zennit/types.py index 76cf78e..9641572 100644 --- a/src/zennit/types.py +++ b/src/zennit/types.py @@ -18,6 +18,8 @@ '''Type definitions for convenience.''' import torch +from .layer import Distance as DistanceLayer + class SubclassMeta(type): '''Meta class to bundle multiple subclasses.''' @@ -124,3 +126,10 @@ class Activation(metaclass=SubclassMeta): torch.nn.modules.activation.Tanhshrink, torch.nn.modules.activation.Threshold, ) + + +class Distance(metaclass=SubclassMeta): + '''Abstract base class that describes distance modules.''' + __subclass__ = ( + DistanceLayer, + )