Confused logit callback (#118)
* flake8

* plotting

* added more dm tests

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
williamFalcon and Borda authored Jul 31, 2020
1 parent 07dcc0f commit 637a532
Showing 7 changed files with 164 additions and 4 deletions.
Binary file added docs/source/_images/vision/confused_logit.png
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -21,6 +21,7 @@ PyTorch-Lightning-Bolts documentation
callbacks
info_callbacks
variational_callbacks
vision_callbacks

.. toctree::
:maxdepth: 2
6 changes: 6 additions & 0 deletions docs/source/variational_callbacks.rst
@@ -11,5 +11,11 @@ Latent Dim Interpolator
-----------------------
Interpolates latent dims.

Example output:

.. image:: _images/gans/basic_gan_interpolate.jpg
:width: 400
:alt: Example latent space interpolation

.. autoclass:: pl_bolts.callbacks.variational.LatentDimInterpolator
:noindex:
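
A minimal usage sketch; ``gan`` below stands in for any hypothetical GAN-style LightningModule whose latent space the callback can walk:

.. code-block:: python

    from pl_bolts.callbacks.variational import LatentDimInterpolator
    from pytorch_lightning import Trainer

    # the callback interpolates the module's latent dims and logs the resulting images
    trainer = Trainer(callbacks=[LatentDimInterpolator()])
    trainer.fit(gan)  # `gan` is a hypothetical GAN-style LightningModule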
22 changes: 22 additions & 0 deletions docs/source/vision_callbacks.rst
@@ -0,0 +1,22 @@
.. role:: hidden
:class: hidden-section

Vision Callbacks
================
Useful callbacks for vision models.

---------------

Confused Logit
--------------
Shows how an input would have to change to move the model's prediction from one logit to the other.


Example outputs:

.. image:: _images/vision/confused_logit.png
:width: 400
:alt: Example of prediction confused between 5 and 8

.. autoclass:: pl_bolts.callbacks.vision.confused_logit.ConfusedLogitCallback
:noindex:
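
A minimal usage sketch; ``LitClassifier`` below is a hypothetical module, and the only real contract (per the notes in the class docstring) is that the module exposes ``last_batch`` and ``last_logits``:

.. code-block:: python

    from pl_bolts.callbacks.vision import ConfusedLogitCallback
    from pytorch_lightning import Trainer

    # plot up to 5 confused samples whenever the top-2 logits nearly tie
    trainer = Trainer(callbacks=[ConfusedLogitCallback(top_k=5)])
    trainer.fit(LitClassifier())  # LitClassifier is a hypothetical LightningModule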
1 change: 1 addition & 0 deletions pl_bolts/callbacks/vision/__init__.py
@@ -0,0 +1 @@
from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback
127 changes: 127 additions & 0 deletions pl_bolts/callbacks/vision/confused_logit.py
@@ -0,0 +1,127 @@
import torch
from torch import nn
from pytorch_lightning import Callback


class ConfusedLogitCallback(Callback):  # pragma: no cover

def __init__(
self,
top_k,
projection_factor=3,
min_logit_value=5.0,
logging_batch_interval=20,
max_logit_difference=0.1
):
"""
Takes the logit predictions of a model and when the probabilities of two classes are very close, the model
doesn't have high certainty that it should pick one vs the other class.
This callback shows how the input would have to change to swing the model from one label prediction
to the other.
In this case, the network predicts a 5... but gives almost equal probability to an 8.
The images show what about the original 5 would have to change to make it more like a 5 or more like an 8.
For each confused logit the confused images are generated by taking the gradient from a logit wrt an input
for the top two closest logits.
Example::
from pl_bolts.callbacks.vision import ConfusedLogitCallback
trainer = Trainer(callbacks=[ConfusedLogitCallback()])
.. note:: whenever called, this model will look for self.last_batch and self.last_logits in the LightningModule
.. note:: this callback supports tensorboard only right now
Args:
top_k: How many "offending" images we should plot
projection_factor: How much to multiply the input image to make it look more like this logit label
min_logit_value: Only consider logit values above this threshold
logging_batch_interval: how frequently to inspect/potentially plot something
max_logit_difference: when the top 2 logits are within this threshold we consider them confused
Authored by:
- Alfredo Canziani
"""
super().__init__()
self.top_k = top_k
self.projection_factor = projection_factor
self.max_logit_difference = max_logit_difference
self.logging_batch_interval = logging_batch_interval
self.min_logit_value = min_logit_value

def on_batch_end(self, trainer, pl_module):

        # show images only every `logging_batch_interval` batches (default 20)
if (trainer.batch_idx + 1) % self.logging_batch_interval != 0:
return

# pick the last batch and logits
x, y = pl_module.last_batch
logits = pl_module.last_logits

        # only check when the model has opinions (i.e. the max logit exceeds min_logit_value)
if logits.max() > self.min_logit_value:
# pick the top two confused probs
(values, idxs) = torch.topk(logits, k=2, dim=1)

# care about only the ones that are at most eps close to each other
eps = self.max_logit_difference
mask = (values[:, 0] - values[:, 1]).abs() < eps

if mask.sum() > 0:
# pull out the ones we care about
confusing_x = x[mask, ...]
confusing_y = y[mask]

mask_idxs = idxs[mask]

self._plot(confusing_x, confusing_y, trainer, pl_module, mask_idxs)

    def _plot(self, confusing_x, confusing_y, trainer, model, mask_idxs):
        from matplotlib import pyplot as plt

        # keep only the top-k most confused samples (and their logit indices)
        confusing_x = confusing_x[:self.top_k]
        confusing_y = confusing_y[:self.top_k]
        mask_idxs = mask_idxs[:self.top_k]

        batch_size, c, w, h = confusing_x.size()

        model.eval()
        x_param_a = nn.Parameter(confusing_x)
        x_param_b = nn.Parameter(confusing_x)

        # backprop each of the two confused logits into its own copy of the input
        for logit_i, x_param in enumerate((x_param_a, x_param_b)):
            logits = model(x_param.view(batch_size, -1))
            # gather each sample's own confused logit so gradients don't mix across samples
            logits.gather(1, mask_idxs[:, logit_i:logit_i + 1]).sum().backward()

        # reshape grads back to image shape
        grad_a = x_param_a.grad.view(batch_size, w, h)
        grad_b = x_param_b.grad.view(batch_size, w, h)

for img_i in range(len(confusing_x)):
x = confusing_x[img_i].squeeze(0)
y = confusing_y[img_i]
ga = grad_a[img_i]
gb = grad_b[img_i]

mask_idx = mask_idxs[img_i]

fig, axarr = plt.subplots(nrows=2, ncols=3, figsize=(15, 10))
self.__draw_sample(fig, axarr, 0, 0, x, f'True: {y}')
self.__draw_sample(fig, axarr, 0, 1, ga, f'd{mask_idx[0]}-logit/dx')
self.__draw_sample(fig, axarr, 0, 2, gb, f'd{mask_idx[1]}-logit/dx')
            self.__draw_sample(fig, axarr, 1, 1, ga * self.projection_factor + x, f'd{mask_idx[0]}-logit/dx')
            self.__draw_sample(fig, axarr, 1, 2, gb * self.projection_factor + x, f'd{mask_idx[1]}-logit/dx')

trainer.logger.experiment.add_figure('confusing_imgs', fig, global_step=trainer.global_step)

@staticmethod
def __draw_sample(fig, axarr, row_idx, col_idx, img, title):
        # detach and move to CPU so matplotlib can render tensors produced on the GPU
        im = axarr[row_idx, col_idx].imshow(img.detach().cpu())
fig.colorbar(im, ax=axarr[row_idx, col_idx])
axarr[row_idx, col_idx].set_title(title, fontsize=20)
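
For context, a minimal sketch of a LightningModule this callback could attach to. The module, its layers, and the training wiring below are hypothetical; the only real contract (per the notes in the docstring) is that the module exposes `self.last_batch` and `self.last_logits`:

import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl


class LitMNISTClassifier(pl.LightningModule):
    """Hypothetical classifier exposing the attributes ConfusedLogitCallback reads."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)

        # the contract ConfusedLogitCallback relies on:
        self.last_batch = batch
        self.last_logits = logits.detach()  # detached: the callback only ranks logits

        return F.cross_entropy(logits, y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)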
11 changes: 7 additions & 4 deletions setup.cfg
@@ -16,6 +16,7 @@ exclude_lines =
add_model_specific_args

[coverage:run]
# TODO: remove these ignores in the future
omit =
pl_bolts/datamodules/stl10_datamodule.py
pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -36,11 +37,13 @@ doctests = True
verbose = 2
# https://pep8.readthedocs.io/en/latest/intro.html#error-codes
format = pylint
# see: https://www.flake8rules.com/
ignore =
E731
W504
F401
F841
E731 # Do not assign a lambda expression, use a def
W504 # Line break occurred after a binary operator
F401 # Module imported but unused
F841 # Local variable name is assigned to but never used
W605 # Invalid escape sequence 'x'

# setup.cfg or tox.ini
[check-manifest]
