Skip to content

Commit

Permalink
added waterfall plots for change plotters
Browse files Browse the repository at this point in the history
  • Loading branch information
Maximilian authored and Maximilian committed Dec 30, 2022
1 parent ec7d1e1 commit 7ff51c7
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 0 deletions.
97 changes: 97 additions & 0 deletions ixai/visualization/change_plotter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import copy
import typing

from matplotlib import pyplot as plt

from ixai.visualization.line_plots import plot_multi_line_graph
from ixai.visualization.plotting import BasePlotter
from ixai.visualization.waterfall_plots import plot_water_fall_graph


class ChangePlotter(BasePlotter):

def __init__(self):
super().__init__()
self.y_data = {}
self.x_data = {}
self.stored_feature_names = set()

@property
def n_features_stored(self):
return len(self.stored_feature_names)

def _store_new_feature(self, feature_name: str):
self.stored_feature_names.add(feature_name)
self.y_data[feature_name] = []
self.x_data[feature_name] = []

def update(
self,
importance_values: dict[str, typing.Union[int, float]],
):
self.seen_timesteps += 1
for feature_name, feature_value in importance_values.items():
if feature_name not in self.stored_feature_names:
self._store_new_feature(feature_name)
self.y_data[feature_name].append(feature_value)
self.x_data[feature_name].append(self.seen_timesteps)

def plot(
self,
figsize: typing.Optional[tuple[int, int]] = None,
save_name: typing.Optional[str] = None,
model_performances: typing.Optional[dict[str, typing.Sequence]] = None,
performance_kw: typing.Optional[dict] = None,
**plot_kw
) -> None:

n_features = self.n_features_stored
line_names = list(self.stored_feature_names)
if 'line_names' in plot_kw:
line_names = plot_kw['line_names']
n_features = len(line_names)

if model_performances is not None:

fig = plt.figure(figsize=figsize)
gs = fig.add_gridspec(2, n_features, height_ratios=[1, 4])
performance_axis = fig.add_subplot(gs[0, :])
fi_axis = [fig.add_subplot(gs[1, i]) for i in range(0, n_features)]
for i in range(0, len(fi_axis) - 1):
fi_axis[i].sharey(fi_axis[i + 1])

title = None
if 'title' in plot_kw:
title = copy.copy(plot_kw['title'])
del plot_kw['title']
performance_kw = {} if performance_kw is None else performance_kw
performance_axis = plot_multi_line_graph(
axis=performance_axis,
y_data=model_performances,
title=title,
**performance_kw
)
else:
fig, fi_axis = plt.subplots(1, n_features, sharey='all')

if figsize is not None:
fig.set_figheight(figsize[0])
fig.set_figwidth(figsize[1])

fi_axis = plot_water_fall_graph(
axes=fi_axis,
y_data=self.y_data,
x_data=self.x_data,
**plot_kw
)

plt.tight_layout()

if model_performances is not None:
plt.subplots_adjust(wspace=0.000, hspace=0.3)
else:
plt.subplots_adjust(wspace=0.000)

if save_name is not None:
plt.savefig(save_name, dpi=200)
plt.show()
72 changes: 72 additions & 0 deletions ixai/visualization/waterfall_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import copy
import typing

from matplotlib import pyplot as plt

from ixai.visualization.color import BACKGROUND_COLOR

WATERFALL_COLORS = {False: 'red', True: 'green'}


def plot_water_fall_graph(
axes: typing.List[plt.axis],
y_data: typing.Dict[str, typing.List[float]],
*,
x_data: typing.Dict[str, typing.List[int]] = None,
show_last_n: typing.Optional[int] = None,
**kwargs
):
if 'h_lines' in kwargs:
for h_line_props in kwargs['h_lines']:
for axis in axes:
axis.axhline(**h_line_props)

color_dict = WATERFALL_COLORS
if 'color' in kwargs:
color_dict = kwargs['color']

line_names = list(y_data.keys())

min_line_value = min([min(values) for values in x_data.values()])
max_line_value = max([max(values) for values in x_data.values()])

x_values = {i: i - min_line_value for i in range(min_line_value, max_line_value + 1)} # {5:0, 6:1, 7:2}

first_axis = axes[0]
last_axis = axes[-1]

for line_name, axis in zip(line_names, axes):
line_values = [0., *y_data[line_name]]
diffs = [line_values[i + 1] - line_values[i] for i in range(len(line_values) - 1)]
bottoms = line_values[0:-1]
colors = [color_dict[diffs[i] > 0] for i in range(len(diffs))]
axis.bar(x_data[line_name], diffs, bottom=bottoms, color=colors)
axis.set_xlabel(line_name)
axis.grid(True, linestyle='dotted')
axis.set_facecolor(BACKGROUND_COLOR)

if 'ylabel' in kwargs:
first_axis.set_ylabel(kwargs['ylabel'])

if 'y_ticks' in kwargs:
first_axis.set_yticks(kwargs['y_ticks'])
first_axis.grid(True)
for axis in axes[1:]:
axis.set_yticks(kwargs['y_ticks'])
axis.set_yticklabels([])

if 'x_ticks' in kwargs:
for axis in axes:
axis.set_xticks(kwargs['x_ticks'])

if 'y_min' in kwargs:
plt.ylim(bottom=kwargs['y_min'])

if 'y_max' in kwargs:
plt.ylim(top=kwargs['y_max'])

if 'title' in kwargs:
plt.suptitle(kwargs['title'])

plt.subplots_adjust(wspace=0, hspace=0)
return axes

0 comments on commit 7ff51c7

Please sign in to comment.