Skip to content

Commit b84980a

Browse files
smaelandfacebook-github-bot
authored andcommitted
Add time series visualization function (pytorch#980)
Summary: Add a convenience function to plot time series data with attributions overlaid (`captum.attr.visualization.visualize_timeseries_attr`). This addresses pytorch#958 . Comes with three styles, shown here for some dummy data: 1) Plot each channel in a separate panel, with separate heatmaps overlaid ![overlaid_individual](https://user-images.githubusercontent.com/30171842/174852816-f3c7d67f-d03f-4d04-91b4-6766052a640d.png) 2) Plot all channels in a single panel, with average heatmap overlaid ![overlaid_combined](https://user-images.githubusercontent.com/30171842/174852821-1ab089b2-9e30-4233-9726-dd3e3d9f03f5.png) 3) Plot each channel in a separate panel and color the graphs by attribution values at each time step ![colored_graph](https://user-images.githubusercontent.com/30171842/174852820-f0be8148-d432-43f3-a301-e783b98dece0.png) The function accepts matplotlib keyword arguments for additional styling. Pull Request resolved: pytorch#980 Reviewed By: vivekmig Differential Revision: D37495470 Pulled By: i-jones fbshipit-source-id: d218dc035d7158af39480a4df63a0bb9500f495c
1 parent 7d77c72 commit b84980a

File tree

1 file changed

+324
-6
lines changed

1 file changed

+324
-6
lines changed

captum/attr/_utils/visualization.py

+324-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
#!/usr/bin/env python3
22
import warnings
33
from enum import Enum
4-
from typing import Any, Iterable, List, Tuple, Union
4+
from typing import Any, Iterable, List, Optional, Tuple, Union
55

66
import numpy as np
7-
from matplotlib import pyplot as plt
7+
from matplotlib import cm, colors, pyplot as plt
8+
from matplotlib.collections import LineCollection
89
from matplotlib.colors import LinearSegmentedColormap
910
from matplotlib.figure import Figure
1011
from matplotlib.pyplot import axis, figure
@@ -27,6 +28,12 @@ class ImageVisualizationMethod(Enum):
2728
alpha_scaling = 5
2829

2930

31+
class TimeseriesVisualizationMethod(Enum):
32+
overlay_individual = 1
33+
overlay_combined = 2
34+
colored_graph = 3
35+
36+
3037
class VisualizeSign(Enum):
3138
positive = 1
3239
absolute_value = 2
@@ -61,10 +68,16 @@ def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]):
6168
return sorted_vals[threshold_id]
6269

6370

64-
def _normalize_image_attr(
65-
attr: ndarray, sign: str, outlier_perc: Union[int, float] = 2
71+
def _normalize_attr(
72+
attr: ndarray,
73+
sign: str,
74+
outlier_perc: Union[int, float] = 2,
75+
reduction_axis: Optional[int] = None,
6676
):
67-
attr_combined = np.sum(attr, axis=2)
77+
attr_combined = attr
78+
if reduction_axis is not None:
79+
attr_combined = np.sum(attr, axis=reduction_axis)
80+
6881
# Choose appropriate signed values and rescale, removing given outlier percentage.
6982
if VisualizeSign[sign] == VisualizeSign.all:
7083
threshold = _cumulative_sum_threshold(np.abs(attr_combined), 100 - outlier_perc)
@@ -241,7 +254,7 @@ def visualize_image_attr(
241254
plt_axis.imshow(original_image)
242255
else:
243256
# Choose appropriate signed attributions and normalize.
244-
norm_attr = _normalize_image_attr(attr, sign, outlier_perc)
257+
norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=2)
245258

246259
# Set default colormap and bounds based on sign.
247260
if VisualizeSign[sign] == VisualizeSign.all:
@@ -422,6 +435,311 @@ def visualize_image_attr_multiple(
422435
return plt_fig, plt_axis
423436

424437

438+
def visualize_timeseries_attr(
439+
attr: ndarray,
440+
data: ndarray,
441+
x_values: Optional[ndarray] = None,
442+
method: str = "individual_channels",
443+
sign: str = "absolute_value",
444+
channel_labels: Optional[List[str]] = None,
445+
channels_last: bool = True,
446+
plt_fig_axis: Union[None, Tuple[figure, axis]] = None,
447+
outlier_perc: Union[int, float] = 2,
448+
cmap: Union[None, str] = None,
449+
alpha_overlay: float = 0.7,
450+
show_colorbar: bool = False,
451+
title: Union[None, str] = None,
452+
fig_size: Tuple[int, int] = (6, 6),
453+
use_pyplot: bool = True,
454+
**pyplot_kwargs,
455+
):
456+
r"""
457+
Visualizes attribution for a given timeseries data by normalizing
458+
attribution values of the desired sign (positive, negative, absolute value,
459+
or all) and displaying them using the desired mode in a matplotlib figure.
460+
461+
Args:
462+
463+
attr (numpy.array): Numpy array corresponding to attributions to be
464+
visualized. Shape must be in the form (N, C) with channels
465+
as last dimension, unless `channels_last` is set to True.
466+
Shape must also match that of the timeseries data.
467+
data (numpy.array): Numpy array corresponding to the original,
468+
equidistant timeseries data. Shape must be in the form
469+
(N, C) with channels as last dimension, unless
470+
`channels_last` is set to true.
471+
x_values (numpy.array, optional): Numpy array corresponding to the
472+
points on the x-axis. Shape must be in the form (N, ). If
473+
not provided, integers from 0 to N-1 are used.
474+
Default: None
475+
method (string, optional): Chosen method for visualizing attributions
476+
overlaid onto data. Supported options are:
477+
478+
1. `overlay_individual` - Plot each channel individually in
479+
a separate panel, and overlay the attributions for each
480+
channel as a heat map. The `alpha_overlay` parameter
481+
controls the alpha of the heat map.
482+
483+
2. `overlay_combined` - Plot all channels in the same panel,
484+
and overlay the average attributions as a heat map.
485+
486+
3. `colored_graph` - Plot each channel in a separate panel,
487+
and color the graphs according to the attribution
488+
values. Works best with color maps that does not contain
489+
white or very bright colors.
490+
Default: `overlay_individual`
491+
sign (string, optional): Chosen sign of attributions to visualize.
492+
Supported options are:
493+
494+
1. `positive` - Displays only positive pixel attributions.
495+
496+
2. `absolute_value` - Displays absolute value of
497+
attributions.
498+
499+
3. `negative` - Displays only negative pixel attributions.
500+
501+
4. `all` - Displays both positive and negative attribution
502+
values.
503+
Default: `absolute_value`
504+
channel_labels (list of strings, optional): List of labels
505+
corresponding to each channel in data.
506+
Default: None
507+
channels_last (bool, optional): If True, data is expected to have
508+
channels as the last dimension, i.e. (N, C). If False, data
509+
is expected to have channels first, i.e. (C, N).
510+
Default: True
511+
plt_fig_axis (tuple, optional): Tuple of matplotlib.pyplot.figure and axis
512+
on which to visualize. If None is provided, then a new figure
513+
and axis are created.
514+
Default: None
515+
outlier_perc (float or int, optional): Top attribution values which
516+
correspond to a total of outlier_perc percentage of the
517+
total attribution are set to 1 and scaling is performed
518+
using the minimum of these values. For sign=`all`, outliers
519+
and scale value are computed using absolute value of
520+
attributions.
521+
Default: 2
522+
cmap (string, optional): String corresponding to desired colormap for
523+
heatmap visualization. This defaults to "Reds" for negative
524+
sign, "Blues" for absolute value, "Greens" for positive sign,
525+
and a spectrum from red to green for all. Note that this
526+
argument is only used for visualizations displaying heatmaps.
527+
Default: None
528+
alpha_overlay (float, optional): Alpha to set for heatmap when using
529+
`blended_heat_map` visualization mode, which overlays the
530+
heat map over the greyscaled original image.
531+
Default: 0.7
532+
show_colorbar (boolean): Displays colorbar for heat map below
533+
the visualization.
534+
title (string, optional): Title string for plot. If None, no title is
535+
set.
536+
Default: None
537+
fig_size (tuple, optional): Size of figure created.
538+
Default: (6,6)
539+
use_pyplot (boolean): If true, uses pyplot to create and show
540+
figure and displays the figure after creating. If False,
541+
uses Matplotlib object oriented API and simply returns a
542+
figure object without showing.
543+
Default: True.
544+
pyplot_kwargs: Keyword arguments forwarded to plt.plot, for example
545+
`linewidth=3`, `color='black'`, etc
546+
547+
Returns:
548+
2-element tuple of **figure**, **axis**:
549+
- **figure** (*matplotlib.pyplot.figure*):
550+
Figure object on which visualization
551+
is created. If plt_fig_axis argument is given, this is the
552+
same figure provided.
553+
- **axis** (*matplotlib.pyplot.axis*):
554+
Axis object on which visualization
555+
is created. If plt_fig_axis argument is given, this is the
556+
same axis provided.
557+
558+
Examples::
559+
560+
>>> # Classifier takes input of shape (batch, length, channels)
561+
>>> model = Classifier()
562+
>>> dl = DeepLift(model)
563+
>>> attribution = dl.attribute(data, target=0)
564+
>>> # Pick the first sample and plot each channel in data in a separate
565+
>>> # panel, with attributions overlaid
566+
>>> visualize_timeseries_attr(attribution[0], data[0], "overlay_individual")
567+
"""
568+
569+
# Check input dimensions
570+
assert len(attr.shape) == 2, "Expected attr of shape (N, C), got {}".format(
571+
attr.shape
572+
)
573+
assert len(data.shape) == 2, "Expected data of shape (N, C), got {}".format(
574+
attr.shape
575+
)
576+
577+
# Convert to channels-first
578+
if channels_last:
579+
attr = np.transpose(attr)
580+
data = np.transpose(data)
581+
582+
num_channels = attr.shape[0]
583+
timeseries_length = attr.shape[1]
584+
585+
if num_channels > timeseries_length:
586+
warnings.warn(
587+
"Number of channels ({}) greater than time series length ({}), "
588+
"please verify input format".format(num_channels, timeseries_length)
589+
)
590+
591+
num_subplots = num_channels
592+
if (
593+
TimeseriesVisualizationMethod[method]
594+
== TimeseriesVisualizationMethod.overlay_combined
595+
):
596+
num_subplots = 1
597+
attr = np.sum(attr, axis=0) # Merge attributions across channels
598+
599+
if x_values is not None:
600+
assert (
601+
x_values.shape[0] == timeseries_length
602+
), "x_values must have same length as data"
603+
else:
604+
x_values = np.arange(timeseries_length)
605+
606+
# Create plot if figure, axis not provided
607+
if plt_fig_axis is not None:
608+
plt_fig, plt_axis = plt_fig_axis
609+
else:
610+
if use_pyplot:
611+
plt_fig, plt_axis = plt.subplots(
612+
figsize=fig_size, nrows=num_subplots, sharex=True
613+
)
614+
else:
615+
plt_fig = Figure(figsize=fig_size)
616+
plt_axis = plt_fig.subplots(nrows=num_subplots, sharex=True)
617+
618+
if not isinstance(plt_axis, ndarray):
619+
plt_axis = np.array([plt_axis])
620+
621+
norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=None)
622+
623+
# Set default colormap and bounds based on sign.
624+
if VisualizeSign[sign] == VisualizeSign.all:
625+
default_cmap = LinearSegmentedColormap.from_list(
626+
"RdWhGn", ["red", "white", "green"]
627+
)
628+
vmin, vmax = -1, 1
629+
elif VisualizeSign[sign] == VisualizeSign.positive:
630+
default_cmap = "Greens"
631+
vmin, vmax = 0, 1
632+
elif VisualizeSign[sign] == VisualizeSign.negative:
633+
default_cmap = "Reds"
634+
vmin, vmax = 0, 1
635+
elif VisualizeSign[sign] == VisualizeSign.absolute_value:
636+
default_cmap = "Blues"
637+
vmin, vmax = 0, 1
638+
else:
639+
raise AssertionError("Visualize Sign type is not valid.")
640+
cmap = cmap if cmap is not None else default_cmap
641+
cmap = cm.get_cmap(cmap)
642+
cm_norm = colors.Normalize(vmin, vmax)
643+
644+
def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
645+
646+
half_col_width = (x_values[1] - x_values[0]) / 2.0
647+
for icol, col_center in enumerate(x_vals):
648+
left = col_center - half_col_width
649+
right = col_center + half_col_width
650+
ax.axvspan(
651+
xmin=left,
652+
xmax=right,
653+
facecolor=(cmap(cm_norm(attr_vals[icol]))),
654+
edgecolor=None,
655+
alpha=alpha_overlay,
656+
)
657+
658+
if (
659+
TimeseriesVisualizationMethod[method]
660+
== TimeseriesVisualizationMethod.overlay_individual
661+
):
662+
663+
for chan in range(num_channels):
664+
665+
plt_axis[chan].plot(x_values, data[chan, :], **pyplot_kwargs)
666+
if channel_labels is not None:
667+
plt_axis[chan].set_ylabel(channel_labels[chan])
668+
669+
_plot_attrs_as_axvspan(norm_attr[chan], x_values, plt_axis[chan])
670+
671+
plt.subplots_adjust(hspace=0)
672+
673+
elif (
674+
TimeseriesVisualizationMethod[method]
675+
== TimeseriesVisualizationMethod.overlay_combined
676+
):
677+
678+
# Dark colors are better in this case
679+
cycler = plt.cycler("color", cm.Dark2.colors)
680+
plt_axis[0].set_prop_cycle(cycler)
681+
682+
for chan in range(num_channels):
683+
if channel_labels is not None:
684+
label = channel_labels[chan]
685+
else:
686+
label = None
687+
plt_axis[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs)
688+
689+
_plot_attrs_as_axvspan(norm_attr, x_values, plt_axis[0])
690+
691+
plt_axis[0].legend(loc="best")
692+
693+
elif (
694+
TimeseriesVisualizationMethod[method]
695+
== TimeseriesVisualizationMethod.colored_graph
696+
):
697+
698+
for chan in range(num_channels):
699+
700+
points = np.array([x_values, data[chan, :]]).T.reshape(-1, 1, 2)
701+
segments = np.concatenate([points[:-1], points[1:]], axis=1)
702+
703+
lc = LineCollection(segments, cmap=cmap, norm=cm_norm, **pyplot_kwargs)
704+
lc.set_array(norm_attr[chan, :])
705+
plt_axis[chan].add_collection(lc)
706+
plt_axis[chan].set_ylim(
707+
1.2 * np.min(data[chan, :]), 1.2 * np.max(data[chan, :])
708+
)
709+
if channel_labels is not None:
710+
plt_axis[chan].set_ylabel(channel_labels[chan])
711+
712+
plt.subplots_adjust(hspace=0)
713+
714+
else:
715+
raise AssertionError("Invalid visualization method: {}".format(method))
716+
717+
plt.xlim([x_values[0], x_values[-1]])
718+
719+
if show_colorbar:
720+
axis_separator = make_axes_locatable(plt_axis[-1])
721+
colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.4)
722+
colorbar_alpha = alpha_overlay
723+
if (
724+
TimeseriesVisualizationMethod[method]
725+
== TimeseriesVisualizationMethod.colored_graph
726+
):
727+
colorbar_alpha = 1.0
728+
plt_fig.colorbar(
729+
cm.ScalarMappable(cm_norm, cmap),
730+
orientation="horizontal",
731+
cax=colorbar_axis,
732+
alpha=colorbar_alpha,
733+
)
734+
if title:
735+
plt_axis[0].set_title(title)
736+
737+
if use_pyplot:
738+
plt.show()
739+
740+
return plt_fig, plt_axis
741+
742+
425743
# These visualization methods are for text and are partially copied from
426744
# experiments conducted by Davide Testuggine at Facebook.
427745

0 commit comments

Comments
 (0)