diff --git a/docs/image_summaries.ipynb b/docs/image_summaries.ipynb index e527860e5a..266729b42c 100644 --- a/docs/image_summaries.ipynb +++ b/docs/image_summaries.ipynb @@ -498,14 +498,14 @@ " plt.xticks(tick_marks, class_names, rotation=45)\n", " plt.yticks(tick_marks, class_names)\n", "\n", - " # Normalize the confusion matrix.\n", - " cm = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2)\n", + " # Compute the labels from the normalized confusion matrix.\n", + " labels = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2)\n", "\n", " # Use white text if squares are dark; otherwise black.\n", " threshold = cm.max() / 2.\n", " for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n", " color = \"white\" if cm[i, j] > threshold else \"black\"\n", - " plt.text(j, i, cm[i, j], horizontalalignment=\"center\", color=color)\n", + " plt.text(j, i, labels[i, j], horizontalalignment=\"center\", color=color)\n", "\n", " plt.tight_layout()\n", " plt.ylabel('True label')\n",