-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added waterfall plots for change plotters
- Loading branch information
Maximilian
authored and
Maximilian
committed
Dec 30, 2022
1 parent
ec7d1e1
commit 7ff51c7
Showing
2 changed files
with
169 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import copy | ||
import typing | ||
|
||
from matplotlib import pyplot as plt | ||
|
||
from ixai.visualization.line_plots import plot_multi_line_graph | ||
from ixai.visualization.plotting import BasePlotter | ||
from ixai.visualization.waterfall_plots import plot_water_fall_graph | ||
|
||
|
||
class ChangePlotter(BasePlotter): | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.y_data = {} | ||
self.x_data = {} | ||
self.stored_feature_names = set() | ||
|
||
@property | ||
def n_features_stored(self): | ||
return len(self.stored_feature_names) | ||
|
||
def _store_new_feature(self, feature_name: str): | ||
self.stored_feature_names.add(feature_name) | ||
self.y_data[feature_name] = [] | ||
self.x_data[feature_name] = [] | ||
|
||
def update( | ||
self, | ||
importance_values: dict[str, typing.Union[int, float]], | ||
): | ||
self.seen_timesteps += 1 | ||
for feature_name, feature_value in importance_values.items(): | ||
if feature_name not in self.stored_feature_names: | ||
self._store_new_feature(feature_name) | ||
self.y_data[feature_name].append(feature_value) | ||
self.x_data[feature_name].append(self.seen_timesteps) | ||
|
||
def plot( | ||
self, | ||
figsize: typing.Optional[tuple[int, int]] = None, | ||
save_name: typing.Optional[str] = None, | ||
model_performances: typing.Optional[dict[str, typing.Sequence]] = None, | ||
performance_kw: typing.Optional[dict] = None, | ||
**plot_kw | ||
) -> None: | ||
|
||
n_features = self.n_features_stored | ||
line_names = list(self.stored_feature_names) | ||
if 'line_names' in plot_kw: | ||
line_names = plot_kw['line_names'] | ||
n_features = len(line_names) | ||
|
||
if model_performances is not None: | ||
|
||
fig = plt.figure(figsize=figsize) | ||
gs = fig.add_gridspec(2, n_features, height_ratios=[1, 4]) | ||
performance_axis = fig.add_subplot(gs[0, :]) | ||
fi_axis = [fig.add_subplot(gs[1, i]) for i in range(0, n_features)] | ||
for i in range(0, len(fi_axis) - 1): | ||
fi_axis[i].sharey(fi_axis[i + 1]) | ||
|
||
title = None | ||
if 'title' in plot_kw: | ||
title = copy.copy(plot_kw['title']) | ||
del plot_kw['title'] | ||
performance_kw = {} if performance_kw is None else performance_kw | ||
performance_axis = plot_multi_line_graph( | ||
axis=performance_axis, | ||
y_data=model_performances, | ||
title=title, | ||
**performance_kw | ||
) | ||
else: | ||
fig, fi_axis = plt.subplots(1, n_features, sharey='all') | ||
|
||
if figsize is not None: | ||
fig.set_figheight(figsize[0]) | ||
fig.set_figwidth(figsize[1]) | ||
|
||
fi_axis = plot_water_fall_graph( | ||
axes=fi_axis, | ||
y_data=self.y_data, | ||
x_data=self.x_data, | ||
**plot_kw | ||
) | ||
|
||
plt.tight_layout() | ||
|
||
if model_performances is not None: | ||
plt.subplots_adjust(wspace=0.000, hspace=0.3) | ||
else: | ||
plt.subplots_adjust(wspace=0.000) | ||
|
||
if save_name is not None: | ||
plt.savefig(save_name, dpi=200) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import copy | ||
import typing | ||
|
||
from matplotlib import pyplot as plt | ||
|
||
from ixai.visualization.color import BACKGROUND_COLOR | ||
|
||
WATERFALL_COLORS = {False: 'red', True: 'green'} | ||
|
||
|
||
def plot_water_fall_graph( | ||
axes: typing.List[plt.axis], | ||
y_data: typing.Dict[str, typing.List[float]], | ||
*, | ||
x_data: typing.Dict[str, typing.List[int]] = None, | ||
show_last_n: typing.Optional[int] = None, | ||
**kwargs | ||
): | ||
if 'h_lines' in kwargs: | ||
for h_line_props in kwargs['h_lines']: | ||
for axis in axes: | ||
axis.axhline(**h_line_props) | ||
|
||
color_dict = WATERFALL_COLORS | ||
if 'color' in kwargs: | ||
color_dict = kwargs['color'] | ||
|
||
line_names = list(y_data.keys()) | ||
|
||
min_line_value = min([min(values) for values in x_data.values()]) | ||
max_line_value = max([max(values) for values in x_data.values()]) | ||
|
||
x_values = {i: i - min_line_value for i in range(min_line_value, max_line_value + 1)} # {5:0, 6:1, 7:2} | ||
|
||
first_axis = axes[0] | ||
last_axis = axes[-1] | ||
|
||
for line_name, axis in zip(line_names, axes): | ||
line_values = [0., *y_data[line_name]] | ||
diffs = [line_values[i + 1] - line_values[i] for i in range(len(line_values) - 1)] | ||
bottoms = line_values[0:-1] | ||
colors = [color_dict[diffs[i] > 0] for i in range(len(diffs))] | ||
axis.bar(x_data[line_name], diffs, bottom=bottoms, color=colors) | ||
axis.set_xlabel(line_name) | ||
axis.grid(True, linestyle='dotted') | ||
axis.set_facecolor(BACKGROUND_COLOR) | ||
|
||
if 'ylabel' in kwargs: | ||
first_axis.set_ylabel(kwargs['ylabel']) | ||
|
||
if 'y_ticks' in kwargs: | ||
first_axis.set_yticks(kwargs['y_ticks']) | ||
first_axis.grid(True) | ||
for axis in axes[1:]: | ||
axis.set_yticks(kwargs['y_ticks']) | ||
axis.set_yticklabels([]) | ||
|
||
if 'x_ticks' in kwargs: | ||
for axis in axes: | ||
axis.set_xticks(kwargs['x_ticks']) | ||
|
||
if 'y_min' in kwargs: | ||
plt.ylim(bottom=kwargs['y_min']) | ||
|
||
if 'y_max' in kwargs: | ||
plt.ylim(top=kwargs['y_max']) | ||
|
||
if 'title' in kwargs: | ||
plt.suptitle(kwargs['title']) | ||
|
||
plt.subplots_adjust(wspace=0, hspace=0) | ||
return axes |