
Confused logit callback #118

Merged
merged 24 commits into from
Jul 31, 2020
Binary file added docs/source/_images/vision/confused_logit.png
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -21,6 +21,7 @@ PyTorch-Lightning-Bolts documentation
    callbacks
    info_callbacks
    variational_callbacks
    vision_callbacks

.. toctree::
    :maxdepth: 2
6 changes: 6 additions & 0 deletions docs/source/variational_callbacks.rst
@@ -11,5 +11,11 @@ Latent Dim Interpolator
-----------------------
Interpolates latent dims.

Example output:

.. image:: _images/gans/basic_gan_interpolate.jpg
    :width: 400
    :alt: Example latent space interpolation

.. autoclass:: pl_bolts.callbacks.variational.LatentDimInterpolator
    :noindex:
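
A minimal wiring sketch (assuming the callback's default constructor arguments suffice and the trained module is a GAN-style LightningModule)::

    from pytorch_lightning import Trainer

    from pl_bolts.callbacks.variational import LatentDimInterpolator

    trainer = Trainer(callbacks=[LatentDimInterpolator()])
    # trainer.fit(my_gan)  # my_gan: a hypothetical GAN LightningModule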
22 changes: 22 additions & 0 deletions docs/source/vision_callbacks.rst
@@ -0,0 +1,22 @@
.. role:: hidden
    :class: hidden-section

Vision Callbacks
================
Useful callbacks for vision models.

---------------

Confused Logit
--------------
Shows how the input would have to change to move the model's prediction from one logit to the other.


Example outputs:

.. image:: _images/vision/confused_logit.png
    :width: 400
    :alt: Example of prediction confused between 5 and 8

.. autoclass:: pl_bolts.callbacks.vision.confused_logit.ConfusedLogitCallback
    :noindex:
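
The callback (documented in the source below) reads self.last_batch and self.last_logits from the LightningModule, so the module has to cache them itself. A minimal sketch of that contract, assuming an MNIST-style classifier (the MNISTClassifier module and its architecture are illustrative; only the two cached attributes and the callback wiring come from this PR)::

    import torch
    from torch import nn
    from torch.nn import functional as F
    from pytorch_lightning import LightningModule, Trainer

    from pl_bolts.callbacks.vision import ConfusedLogitCallback


    class MNISTClassifier(LightningModule):
        # illustrative module: any classifier works as long as it caches the two attributes below
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.layer(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)

            # cache what ConfusedLogitCallback expects to find on the module
            self.last_batch = batch
            self.last_logits = logits.detach()

            return F.cross_entropy(logits, y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)


    trainer = Trainer(callbacks=[ConfusedLogitCallback(top_k=2)])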
1 change: 1 addition & 0 deletions pl_bolts/callbacks/vision/__init__.py
@@ -0,0 +1 @@
from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback
129 changes: 129 additions & 0 deletions pl_bolts/callbacks/vision/confused_logit.py
@@ -0,0 +1,129 @@
import torch
from torch import nn
from pytorch_lightning import Callback


class ConfusedLogitCallback(Callback):

    def __init__(self, top_k, projection_factor=3):
        """
        Takes the logit predictions of a model. When the probabilities of the top two classes are very close,
        the model has low certainty about which of the two it should pick.

        This callback shows how the input would have to change to swing the model from one label prediction
        to the other.

        In the pictured example, the network predicts a 5 but assigns almost equal probability to an 8.
        The generated images show how the original 5 would have to change to look more like a 5 or more like an 8.

        For each confused logit $l_i$ the confused images are generated by $\partial l_i / \partial x$,
        where $i \in \{0, 1\}$ indexes the top two confused logits.

        Example::

            from pl_bolts.callbacks.vision import ConfusedLogitCallback
            trainer = Trainer(callbacks=[ConfusedLogitCallback(top_k=2)])

        .. note:: whenever called, this callback will look for self.last_batch and self.last_logits in the LightningModule

        .. note:: this callback supports tensorboard only right now

        Args:
            top_k: How many "offending" images we should plot
            projection_factor: How much to multiply the gradient by before adding it back to the input image,
                to make the projection look more like the given logit's label

        Authored by:

            - Alfredo Canziani

        """
        super().__init__()
        self.top_k = top_k
        self.projection_factor = projection_factor

    def on_batch_end(self, trainer, pl_module):

        # show images only every 20 batches
        if (trainer.batch_idx + 1) % 20 != 0:
            return

        # pick the last batch and logits
        # TODO: use context instead
        x, y = pl_module.last_batch
        l = pl_module.last_logits

        # only check when the model has opinions (i.e. the max logit > 5)
        if l.max() > 5.0:
            # pick the top two confused probs
            (values, idxs) = torch.topk(l, k=2, dim=1)

            # care about only the ones that are at most eps close to each other
            eps = 0.1
            mask = (values[:, 0] - values[:, 1]).abs() < eps

            if mask.sum() > 0:
                # pull out the ones we care about
                confusing_x = x[mask, ...]
                confusing_y = y[mask]
                confusing_l = l[mask]

                mask_idxs = idxs[mask]

                self._plot(confusing_x, confusing_y, trainer, pl_module, mask_idxs)

    def _plot(self, confusing_x, confusing_y, trainer, model, mask_idxs):
        from matplotlib import pyplot as plt

        # keep only the top_k most confusing examples and their top-2 logit indices
        confusing_x = confusing_x[:self.top_k]
        confusing_y = confusing_y[:self.top_k]
        mask_idxs = mask_idxs[:self.top_k]

        batch_size = confusing_x.size(0)

        model.eval()
        x_param_a = nn.Parameter(confusing_x)
        x_param_b = nn.Parameter(confusing_x)

        for logit_i, x_param in enumerate((x_param_a, x_param_b)):
            l = model(x_param.view(batch_size, -1))
            l[:, mask_idxs[:, logit_i]].sum().backward()

        # reshape grads
        grad_a = x_param_a.grad.view(batch_size, 28, 28)
        grad_b = x_param_b.grad.view(batch_size, 28, 28)

        for img_i in range(len(confusing_x)):
            x = confusing_x[img_i].squeeze(0)
            y = confusing_y[img_i]
            ga = grad_a[img_i]
            gb = grad_b[img_i]

            mask_idx = mask_idxs[img_i]

            fig = plt.figure(figsize=(15, 10))
Member: Rather use fig, axarr = plt.subplots(nrows=2, ncols=3) and then just call axarr[1, 2].imshow(...).

Contributor Author: I don't have a lot of time to mess with this. I don't know if that will work, but I know the current version works. If you feel strongly about it, mind making the changes and posting a Colab that shows it works?

Member: OK, I'll check it, because the plt. calls act on the last (current) figure, so if you plot two figures in parallel you will write everything into one.
            plt.subplot(231)
            plt.imshow(x)
            plt.colorbar()
            plt.title(f'True: {y}', fontsize=20)

            plt.subplot(232)
            plt.imshow(ga)
            plt.colorbar()
            plt.title(f'd{mask_idx[0]}-logit/dx', fontsize=20)

            plt.subplot(233)
            plt.imshow(gb)
            plt.colorbar()
            plt.title(f'd{mask_idx[1]}-logit/dx', fontsize=20)

            plt.subplot(235)
            plt.imshow(ga * self.projection_factor + x)
            plt.colorbar()
            plt.title(f'x + d{mask_idx[0]}-logit/dx', fontsize=20)

            plt.subplot(236)
            plt.imshow(gb * self.projection_factor + x)
            plt.colorbar()
            plt.title(f'x + d{mask_idx[1]}-logit/dx', fontsize=20)

            trainer.logger.experiment.add_figure('confusing_imgs', fig, global_step=trainer.global_step)
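
As suggested in the review thread above, the same five panels could be drawn on explicit axes instead of repeated plt.subplot calls on the implicit current figure. A sketch of that alternative as a hypothetical helper (plot_confused_pair is illustrative; the panel layout and titles mirror the code above)::

    from matplotlib import pyplot as plt


    def plot_confused_pair(x, y, ga, gb, mask_idx, projection_factor=3):
        # build one figure on explicit axes, so several figures can be assembled in parallel safely
        fig, axarr = plt.subplots(nrows=2, ncols=3, figsize=(15, 10))

        panels = [
            (axarr[0, 0], x, f'True: {y}'),
            (axarr[0, 1], ga, f'd{mask_idx[0]}-logit/dx'),
            (axarr[0, 2], gb, f'd{mask_idx[1]}-logit/dx'),
            (axarr[1, 1], ga * projection_factor + x, f'x + d{mask_idx[0]}-logit/dx'),
            (axarr[1, 2], gb * projection_factor + x, f'x + d{mask_idx[1]}-logit/dx'),
        ]
        for ax, img, title in panels:
            im = ax.imshow(img)
            fig.colorbar(im, ax=ax)
            ax.set_title(title, fontsize=20)

        axarr[1, 0].axis('off')  # the lower-left panel stays empty, as subplot 234 does above
        return fig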