diff --git a/utils/metrics.py b/utils/metrics.py index e8897bbc33d5..21a71df51f60 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -1,5 +1,6 @@ # Model validation metrics +import warnings from pathlib import Path import matplotlib.pyplot as plt @@ -167,9 +168,11 @@ def plot(self, save_dir='', names=()): fig = plt.figure(figsize=(12, 9), tight_layout=True) sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels - sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True, - xticklabels=names + ['background FP'] if labels else "auto", - yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1)) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered + sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True, + xticklabels=names + ['background FP'] if labels else "auto", + yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1)) fig.axes[0].set_xlabel('True') fig.axes[0].set_ylabel('Predicted') fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) diff --git a/utils/plots.py b/utils/plots.py index 973b9ae19b54..66a30918190e 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -11,7 +11,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -import seaborn as sns +import seaborn as sn import torch import yaml from PIL import Image, ImageDraw, ImageFont @@ -291,7 +291,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height']) # seaborn correlogram - sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) + sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200) plt.close() @@ -306,8 +306,8 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): ax[0].set_xticklabels(names, rotation=90, fontsize=10) else: ax[0].set_xlabel('classes') - sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9) - sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9) + sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9) + sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9) # rectangles labels[:, 1:3] = 0.5 # center