1
1
#!/usr/bin/env python3
2
2
import warnings
3
3
from enum import Enum
4
- from typing import Any , Iterable , List , Tuple , Union
4
+ from typing import Any , Iterable , List , Optional , Tuple , Union
5
5
6
6
import numpy as np
7
- from matplotlib import pyplot as plt
7
+ from matplotlib import cm , colors , pyplot as plt
8
+ from matplotlib .collections import LineCollection
8
9
from matplotlib .colors import LinearSegmentedColormap
9
10
from matplotlib .figure import Figure
10
11
from matplotlib .pyplot import axis , figure
@@ -27,6 +28,12 @@ class ImageVisualizationMethod(Enum):
27
28
alpha_scaling = 5
28
29
29
30
31
class TimeseriesVisualizationMethod(Enum):
    """Display modes accepted by the ``method`` argument of timeseries plots.

    Members are looked up by name (``TimeseriesVisualizationMethod[method]``),
    so the member names are part of the public string API.
    """

    # One panel per channel, attributions drawn as a heat map behind the curve.
    overlay_individual = 1
    # Single panel with all channels, one merged attribution heat map.
    overlay_combined = 2
    # One panel per channel, the curve itself colored by attribution value.
    colored_graph = 3
35
+
36
+
30
37
class VisualizeSign (Enum ):
31
38
positive = 1
32
39
absolute_value = 2
@@ -61,10 +68,16 @@ def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]):
61
68
return sorted_vals [threshold_id ]
62
69
63
70
64
- def _normalize_image_attr (
65
- attr : ndarray , sign : str , outlier_perc : Union [int , float ] = 2
71
+ def _normalize_attr (
72
+ attr : ndarray ,
73
+ sign : str ,
74
+ outlier_perc : Union [int , float ] = 2 ,
75
+ reduction_axis : Optional [int ] = None ,
66
76
):
67
- attr_combined = np .sum (attr , axis = 2 )
77
+ attr_combined = attr
78
+ if reduction_axis is not None :
79
+ attr_combined = np .sum (attr , axis = reduction_axis )
80
+
68
81
# Choose appropriate signed values and rescale, removing given outlier percentage.
69
82
if VisualizeSign [sign ] == VisualizeSign .all :
70
83
threshold = _cumulative_sum_threshold (np .abs (attr_combined ), 100 - outlier_perc )
@@ -241,7 +254,7 @@ def visualize_image_attr(
241
254
plt_axis .imshow (original_image )
242
255
else :
243
256
# Choose appropriate signed attributions and normalize.
244
- norm_attr = _normalize_image_attr (attr , sign , outlier_perc )
257
+ norm_attr = _normalize_attr (attr , sign , outlier_perc , reduction_axis = 2 )
245
258
246
259
# Set default colormap and bounds based on sign.
247
260
if VisualizeSign [sign ] == VisualizeSign .all :
@@ -422,6 +435,311 @@ def visualize_image_attr_multiple(
422
435
return plt_fig , plt_axis
423
436
424
437
438
def visualize_timeseries_attr(
    attr: ndarray,
    data: ndarray,
    x_values: Optional[ndarray] = None,
    method: str = "overlay_individual",
    sign: str = "absolute_value",
    channel_labels: Optional[List[str]] = None,
    channels_last: bool = True,
    plt_fig_axis: Union[None, Tuple[figure, axis]] = None,
    outlier_perc: Union[int, float] = 2,
    cmap: Union[None, str] = None,
    alpha_overlay: float = 0.7,
    show_colorbar: bool = False,
    title: Union[None, str] = None,
    fig_size: Tuple[int, int] = (6, 6),
    use_pyplot: bool = True,
    **pyplot_kwargs,
):
    r"""
    Visualizes attribution for a given timeseries data by normalizing
    attribution values of the desired sign (positive, negative, absolute value,
    or all) and displaying them using the desired mode in a matplotlib figure.

    Args:

        attr (numpy.array): Numpy array corresponding to attributions to be
                    visualized. Shape must be in the form (N, C) with channels
                    as last dimension, unless `channels_last` is set to False.
                    Shape must also match that of the timeseries data.
        data (numpy.array): Numpy array corresponding to the original,
                    equidistant timeseries data. Shape must be in the form
                    (N, C) with channels as last dimension, unless
                    `channels_last` is set to False.
        x_values (numpy.array, optional): Numpy array corresponding to the
                    points on the x-axis. Shape must be in the form (N, ). If
                    not provided, integers from 0 to N-1 are used.
                    Default: None
        method (string, optional): Chosen method for visualizing attributions
                    overlaid onto data. Supported options are:

                    1. `overlay_individual` - Plot each channel individually in
                       a separate panel, and overlay the attributions for each
                       channel as a heat map. The `alpha_overlay` parameter
                       controls the alpha of the heat map.

                    2. `overlay_combined` - Plot all channels in the same panel,
                       and overlay the attributions merged (summed) across
                       channels as a heat map.

                    3. `colored_graph` - Plot each channel in a separate panel,
                       and color the graphs according to the attribution
                       values. Works best with color maps that do not contain
                       white or very bright colors.
                    Default: `overlay_individual`
        sign (string, optional): Chosen sign of attributions to visualize.
                    Supported options are:

                    1. `positive` - Displays only positive attributions.

                    2. `absolute_value` - Displays absolute value of
                       attributions.

                    3. `negative` - Displays only negative attributions.

                    4. `all` - Displays both positive and negative attribution
                       values.
                    Default: `absolute_value`
        channel_labels (list of strings, optional): List of labels
                    corresponding to each channel in data.
                    Default: None
        channels_last (bool, optional): If True, data is expected to have
                    channels as the last dimension, i.e. (N, C). If False, data
                    is expected to have channels first, i.e. (C, N).
                    Default: True
        plt_fig_axis (tuple, optional): Tuple of matplotlib.pyplot.figure and axis
                    on which to visualize. If None is provided, then a new figure
                    and axis are created.
                    Default: None
        outlier_perc (float or int, optional): Top attribution values which
                    correspond to a total of outlier_perc percentage of the
                    total attribution are set to 1 and scaling is performed
                    using the minimum of these values. For sign=`all`, outliers
                    and scale value are computed using absolute value of
                    attributions.
                    Default: 2
        cmap (string, optional): String corresponding to desired colormap for
                    heatmap visualization. This defaults to "Reds" for negative
                    sign, "Blues" for absolute value, "Greens" for positive sign,
                    and a spectrum from red to green for all.
                    Default: None
        alpha_overlay (float, optional): Alpha of the attribution heat map
                    drawn behind the curves in the `overlay_individual` and
                    `overlay_combined` modes.
                    Default: 0.7
        show_colorbar (boolean): Displays colorbar for heat map below
                    the visualization.
                    Default: False
        title (string, optional): Title string for plot. If None, no title is
                    set.
                    Default: None
        fig_size (tuple, optional): Size of figure created.
                    Default: (6,6)
        use_pyplot (boolean): If true, uses pyplot to create and show
                    figure and displays the figure after creating. If False,
                    uses Matplotlib object oriented API and simply returns a
                    figure object without showing.
                    Default: True.
        pyplot_kwargs: Keyword arguments forwarded to plt.plot, for example
                    `linewidth=3`, `color='black'`, etc

    Returns:
        2-element tuple of **figure**, **axis**:
        - **figure** (*matplotlib.pyplot.figure*):
                    Figure object on which visualization
                    is created. If plt_fig_axis argument is given, this is the
                    same figure provided.
        - **axis** (*numpy.array of matplotlib axes*):
                    Axes on which visualization is created, always returned
                    as a numpy array (a single axis is wrapped in a
                    one-element array).

    Raises:
        AssertionError: If `attr` or `data` is not 2-dimensional, if their
                    shapes differ, or if `x_values` does not have length N.
        KeyError: If `method` or `sign` is not a supported option name.

    Examples::

        >>> # Classifier takes input of shape (batch, length, channels)
        >>> model = Classifier()
        >>> dl = DeepLift(model)
        >>> attribution = dl.attribute(data, target=0)
        >>> # Pick the first sample and plot each channel in data in a separate
        >>> # panel, with attributions overlaid
        >>> visualize_timeseries_attr(
        ...     attribution[0], data[0], method="overlay_individual"
        ... )
    """

    # Check input dimensions
    assert len(attr.shape) == 2, "Expected attr of shape (N, C), got {}".format(
        attr.shape
    )
    assert len(data.shape) == 2, "Expected data of shape (N, C), got {}".format(
        data.shape
    )
    assert (
        attr.shape == data.shape
    ), "attr shape {} must match data shape {}".format(attr.shape, data.shape)

    # Convert to channels-first
    if channels_last:
        attr = np.transpose(attr)
        data = np.transpose(data)

    num_channels = attr.shape[0]
    timeseries_length = attr.shape[1]

    if num_channels > timeseries_length:
        warnings.warn(
            "Number of channels ({}) greater than time series length ({}), "
            "please verify input format".format(num_channels, timeseries_length)
        )

    # Resolve the method name once; an unknown name raises KeyError here.
    vis_method = TimeseriesVisualizationMethod[method]

    num_subplots = num_channels
    if vis_method == TimeseriesVisualizationMethod.overlay_combined:
        num_subplots = 1
        attr = np.sum(attr, axis=0)  # Merge attributions across channels

    if x_values is not None:
        assert (
            x_values.shape[0] == timeseries_length
        ), "x_values must have same length as data"
    else:
        x_values = np.arange(timeseries_length)

    # Create plot if figure, axis not provided
    if plt_fig_axis is not None:
        plt_fig, plt_axis = plt_fig_axis
    else:
        if use_pyplot:
            plt_fig, plt_axis = plt.subplots(
                figsize=fig_size, nrows=num_subplots, sharex=True
            )
        else:
            plt_fig = Figure(figsize=fig_size)
            plt_axis = plt_fig.subplots(nrows=num_subplots, sharex=True)

    # A single subplot comes back as a bare axis; always index an array below.
    if not isinstance(plt_axis, ndarray):
        plt_axis = np.array([plt_axis])

    norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=None)

    # Set default colormap and bounds based on sign.
    vis_sign = VisualizeSign[sign]
    if vis_sign == VisualizeSign.all:
        default_cmap = LinearSegmentedColormap.from_list(
            "RdWhGn", ["red", "white", "green"]
        )
        vmin, vmax = -1, 1
    elif vis_sign == VisualizeSign.positive:
        default_cmap = "Greens"
        vmin, vmax = 0, 1
    elif vis_sign == VisualizeSign.negative:
        default_cmap = "Reds"
        vmin, vmax = 0, 1
    elif vis_sign == VisualizeSign.absolute_value:
        default_cmap = "Blues"
        vmin, vmax = 0, 1
    else:
        raise AssertionError("Visualize Sign type is not valid.")
    cmap = cmap if cmap is not None else default_cmap
    # plt.get_cmap accepts both names and Colormap instances; cm.get_cmap
    # was deprecated in matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap(cmap)
    cm_norm = colors.Normalize(vmin, vmax)

    def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
        # Draw one colored vertical band per time step, centered on the sample.
        # Data is documented as equidistant, so the first spacing is used for
        # every band; a single-sample series falls back to unit spacing.
        if len(x_vals) > 1:
            half_col_width = (x_vals[1] - x_vals[0]) / 2.0
        else:
            half_col_width = 0.5
        for icol, col_center in enumerate(x_vals):
            left = col_center - half_col_width
            right = col_center + half_col_width
            ax.axvspan(
                xmin=left,
                xmax=right,
                facecolor=(cmap(cm_norm(attr_vals[icol]))),
                edgecolor=None,
                alpha=alpha_overlay,
            )

    if vis_method == TimeseriesVisualizationMethod.overlay_individual:

        for chan in range(num_channels):

            plt_axis[chan].plot(x_values, data[chan, :], **pyplot_kwargs)
            if channel_labels is not None:
                plt_axis[chan].set_ylabel(channel_labels[chan])

            _plot_attrs_as_axvspan(norm_attr[chan], x_values, plt_axis[chan])

        # Adjust the figure being drawn on, not pyplot's "current" figure,
        # which may be a different one when use_pyplot=False or when the
        # caller supplied plt_fig_axis.
        plt_fig.subplots_adjust(hspace=0)

    elif vis_method == TimeseriesVisualizationMethod.overlay_combined:

        # Dark colors are better in this case
        cycler = plt.cycler("color", cm.Dark2.colors)
        plt_axis[0].set_prop_cycle(cycler)

        for chan in range(num_channels):
            if channel_labels is not None:
                label = channel_labels[chan]
            else:
                label = None
            plt_axis[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs)

        _plot_attrs_as_axvspan(norm_attr, x_values, plt_axis[0])

        # Without labels there is nothing to list; skipping avoids matplotlib's
        # "no artists with labels" warning and an empty legend box.
        if channel_labels is not None:
            plt_axis[0].legend(loc="best")

    elif vis_method == TimeseriesVisualizationMethod.colored_graph:

        for chan in range(num_channels):

            # Build (N-1) line segments between consecutive samples so each
            # segment can be colored individually.
            points = np.array([x_values, data[chan, :]]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)

            lc = LineCollection(segments, cmap=cmap, norm=cm_norm, **pyplot_kwargs)
            lc.set_array(norm_attr[chan, :])
            plt_axis[chan].add_collection(lc)

            # add_collection does not autoscale, so set the y-limits manually.
            # Pad by 10% of the data range on each side; scaling min/max by a
            # factor would clip the curve for all-positive or all-negative
            # data. Constant data gets a fixed pad so the limits differ.
            y_min = np.min(data[chan, :])
            y_max = np.max(data[chan, :])
            y_pad = 0.1 * (y_max - y_min)
            if y_pad == 0:
                y_pad = 1.0
            plt_axis[chan].set_ylim(y_min - y_pad, y_max + y_pad)
            if channel_labels is not None:
                plt_axis[chan].set_ylabel(channel_labels[chan])

        plt_fig.subplots_adjust(hspace=0)

    else:
        raise AssertionError("Invalid visualization method: {}".format(method))

    # Clamp the x-range to the data on the axes we drew on. Skipped for a
    # single sample, where both limits would coincide.
    if timeseries_length > 1:
        for ax in plt_axis:
            ax.set_xlim(x_values[0], x_values[-1])

    if show_colorbar:
        axis_separator = make_axes_locatable(plt_axis[-1])
        colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.4)
        # The colored graph is drawn fully opaque, so its colorbar is as well.
        colorbar_alpha = alpha_overlay
        if vis_method == TimeseriesVisualizationMethod.colored_graph:
            colorbar_alpha = 1.0
        plt_fig.colorbar(
            cm.ScalarMappable(norm=cm_norm, cmap=cmap),
            orientation="horizontal",
            cax=colorbar_axis,
            alpha=colorbar_alpha,
        )
    if title:
        plt_axis[0].set_title(title)

    if use_pyplot:
        plt.show()

    return plt_fig, plt_axis
741
+
742
+
425
743
# These visualization methods are for text and are partially copied from
426
744
# experiments conducted by Davide Testuggine at Facebook.
427
745
0 commit comments