diff --git a/modelvshuman/plotting/plot.py b/modelvshuman/plotting/plot.py
index f6d68df..7fbe07e 100644
--- a/modelvshuman/plotting/plot.py
+++ b/modelvshuman/plotting/plot.py
@@ -417,6 +417,9 @@ def confusion_matrix_helper(data, output_filename,
 
     plt.savefig(output_filename, bbox_inches='tight', dpi=300)
     plt.close()
+    sns.reset_defaults()
+    sns.reset_orig()
+    plt.style.use('default')
 
 
 def plot_shape_bias_matrixplot(datasets,
@@ -632,21 +635,25 @@ def plot_matrix(datasets, analysis,
 
         f, ax = plt.subplots(figsize=(22, 18))
         cmap = sns.diverging_palette(230, 20, as_cmap=True)
-        heatmap = sns.heatmap(res["matrix"], mask=None, cmap=cmap, vmax=1.0, center=0,
+        sns.heatmap(res["matrix"], ax=ax, mask=None, cmap=cmap, vmax=1.0, center=0,
                     square=True, linewidths=2.0, cbar_kws={"shrink": .5},
                     xticklabels=True, yticklabels=True)
 
-        for i, tick_label in enumerate(heatmap.axes.get_yticklabels()):
+        for i, tick_label in enumerate(ax.axes.get_yticklabels()):
             tick_label.set_color(colors[i])
-        for i, tick_label in enumerate(heatmap.axes.get_xticklabels()):
+        for i, tick_label in enumerate(ax.axes.get_xticklabels()):
             tick_label.set_color(colors[i])
 
         figure_path = pjoin(result_dir,
                             f"{dataset.name}_{analysis.plotting_name.replace(' ', '-')}_matrix{by_mean_str}.pdf")
-        heatmap.figure.savefig(figure_path, bbox_inches='tight', pad_inches=0)
-        plt.cla()
-        plt.clf()
-        plt.close('all')
+        f.savefig(figure_path, bbox_inches='tight', pad_inches=0)
+        f.clear()
+        plt.cla()
+        plt.clf()
+        plt.close()
+        sns.reset_defaults()
+        sns.reset_orig()
+        plt.style.use('default')
 
 
 def sort_matrix_by_models_mean(result_dict):
diff --git a/setup.cfg b/setup.cfg
index c45ea58..ccea069 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -35,7 +35,7 @@ setup_requires =
     # setuptools >=38.3.0     # version with most `setup.cfg` bugfixes
 install_requires =
     torch==1.7.1
-    torchvision
+    torchvision==0.8.2
     requests
     gdown
     scikit-image
@@ -44,7 +44,7 @@ install_requires =
     PySocks
     tensorflow_hub
     tensorflow-gpu
-    tensorflow==2.0
+    tensorflow==2.5.0
     matplotlib>=3.3.2
     pandas
     seaborn
@@ -52,8 +52,8 @@ install_requires =
     regex
     tqdm
     CLIP @ git+https://github.com/openai/CLIP#egg=CLIP
-    figshare @ git+https://github.com/cognoma/figshare#egg=figshare
     pytorch_pretrained_vit
+    tensorflow-estimator==2.1.*
 tests_require =
     pytest
 dependency_links =