-
+
+
+
ROC curve
Show similarity to selected datapoint
-
False positive rate
-
True positive rate
+
+ False positive rate
+
+
+ True positive rate
+
-
-
+
+
PR curve
Show similarity to selected datapoint
-
Recall
-
Precision
+
Recall
+
Precision
+
+
Confusion matrix
+
+
+
+
+
@@ -3120,12 +3088,12 @@
Show similarity to selected datapoint
-
-
-
+
+
+
Confusion matrix
Show similarity to selected datapoint
label="[[getConfusionMatrixLabel(index, numModels)]]"
background="[[getConfusionMatrixColor(index)]]"
all-items="[[allConfMatrixLabels]]"
- class="conf-matrix"
>
+
+
+
+
+
+ ROC curve for [[getLabel(labelInd)]]
+
+
+
+
+ ROC curve
+
+
+ A receiver operating characteristic
+ (ROC) curve plots the true positive rate
+ (TPR) against the false positive rate
+ (FPR) at various classification
+ thresholds.
+
+
+ For this multi-class classification
+ problem, we plot one ROC curve for each
+ class, each time considering the
+ class in question as the positive one
+ and all the others as negatives (i.e.
+ binarized versions of the
+ problem).
+
+
+
+
+ False positive rate
+
+
+ True positive rate
+
+
+
+
+
+
+ PR curve for [[getLabel(labelInd)]]
+
+
+
+ PR curve
+
+ A precision-recall (PR) curve plots
+ precision against recall at various
+ classification thresholds.
+
+
+ For this multi-class classification
+ problem, we plot one PR curve for each
+ class, each time considering the
+ class in question as the positive one
+ and all the others as negatives (i.e.
+ binarized versions of the
+ problem).
+
+
+
+
Recall
+
+ Precision
+
+
+
+
+
-
+
@@ -3197,25 +3267,118 @@
Show similarity to selected datapoint
>
-
-
Confusion matrix
-
-
+
+
Confusion matrix
+
-
-
+
+
+
+
+
+
+
+
+ ROC curve for [[getLabel(labelInd)]]
+
+
+
+ ROC curve
+
+ A receiver operating characteristic (ROC)
+ curve plots the true positive rate (TPR)
+ against the false positive rate (FPR) at
+ various classification thresholds.
+
+
+ For this multi-class classification
+ problem, we plot one ROC curve for each
+ class, each time considering the class
+ in question as the positive one and all
+ the others as negatives (i.e.
+ binarized versions of the problem).
+
+
+
+
+ False positive rate
+
+
+ True positive rate
+
+
+
+
+
+
+ PR curve for [[getLabel(labelInd)]]
+
+
+
+ PR curve
+
+ A precision-recall (PR) curve plots
+ precision against recall at various
+ classification thresholds.
+
+
+ For this multi-class classification
+ problem, we plot one PR curve for each
+ class, each time considering the class
+ in question as the positive one and all
+ the others as negatives (i.e.
+ binarized versions of the problem).
+
+
+
+
Recall
+
Precision
+
+
+
+
+
@@ -4397,7 +4560,7 @@
Show similarity to selected datapoint
return !(modelNames && modelNames.length > 1);
},
- shouldShowOverallPrChart_: function(
+ shouldShowOverallPerfCharts_: function(
selectedLabelFeature,
selectedBreakdownFeature,
inferences
@@ -4409,7 +4572,7 @@
Show similarity to selected datapoint
);
},
- shouldShowFeaturePrCharts_: function(
+ shouldShowFeaturePerfCharts_: function(
selectedLabelFeature,
selectedBreakdownFeature,
inferences
@@ -4794,6 +4957,55 @@
Show similarity to selected datapoint
}
}
this.allConfMatrixLabels = Array.from(allLabels.values());
+ // Compute confusion data for perf curves
+ const thresholds = {};
+ for (let i = 0; i < this.examplesAndInferences.length; i++) {
+ const item = this.examplesAndInferences[i];
+ const trueLabel = this.visdata[i][
+ this.selectedLabelFeature
+ ].toString();
+ const scores =
+ item.inferences[item.inferences.length - 1][modelInd];
+ const slice =
+ this.selectedBreakdownFeature != ''
+ ? this.getSliceKey_(
+ this.visdata[i],
+ this.selectedBreakdownFeature,
+ this.selectedSecondBreakdownFeature
+ )
+ : '';
+ if (!(slice in thresholds)) {
+ thresholds[slice] = {};
+ }
+ for (let k = 0; k < scores.length; k++) {
+ const label = scores[k].label;
+ if (!(label in thresholds[slice])) {
+ thresholds[slice][label] = [];
+ for (let thresh = 0; thresh <= 100; thresh++) {
+ thresholds[slice][label].push({
+ TP: 0,
+ FP: 0,
+ FN: 0,
+ TN: 0,
+ });
+ }
+ }
+ const score = scores[k].score * 100;
+ let result = '';
+ for (let thresh = 0; thresh <= 100; thresh++) {
+ if (label === trueLabel) {
+ result = score > thresh ? 'TP' : 'FN';
+ } else {
+ result = score > thresh ? 'FP' : 'TN';
+ }
+ thresholds[slice][label][thresh][result] += 1;
+ }
+ }
+ }
+ Object.values(thresholds).forEach((t) =>
+ Object.values(t).forEach(this.calcThresholdStats)
+ );
+ inferenceStats.allThresholds = thresholds;
} else {
// For regression models, calculate inference error.
inferenceStats.results = {errors: []};
@@ -4854,7 +5066,7 @@
Show similarity to selected datapoint
this.featureValueThresholds = [];
this.featureValueThresholds = this.sortFeatureValues(tempArray);
- // ROC curves should only exist for the binary case
+ // ROC and PR curves for the binary case
if (this.isBinaryClassification_(this.modelType, this.multiClass)) {
for (let i = 0; i < this.featureValueThresholds.length; i++) {
const plotStats = [];
@@ -4913,6 +5125,63 @@
Show similarity to selected datapoint
false
);
}
+ // ROC and PR curves for the multi-class case
+ else if (this.isMultiClass_(this.modelType, this.multiClass)) {
+ const isSliced = this.featureValueThresholds.length > 0;
+ const slices = isSliced
+ ? this.featureValueThresholds.map((slice) =>
+ this.createCombinedValueString_(slice.value, slice.value2)
+ )
+ : [''];
+ for (let sliceInd = 0; sliceInd < slices.length; sliceInd++) {
+ const slice = slices[sliceInd];
+ for (
+ let labelInd = 0;
+ labelInd < this.allConfMatrixLabels.length;
+ labelInd++
+ ) {
+ const label = this.allConfMatrixLabels[labelInd];
+ const plotStats = [];
+ const plotThresholds = [];
+ const thresholds = isSliced
+ ? this.featureValueThresholds[sliceInd].threshold
+ : this.overallThresholds;
+ for (
+ let modelInd = 0;
+ modelInd < this.inferenceStats_.length;
+ modelInd++
+ ) {
+ plotStats.push(
+ this.inferenceStats_[modelInd].allThresholds[slice][label]
+ );
+ plotThresholds.push(thresholds[modelInd].threshold);
+ }
+ this.plotChart(
+ this.$$(
+ '#' +
+ this.getRocChartLabelId(
+ labelInd,
+ isSliced ? sliceInd : ''
+ )
+ ),
+ plotStats,
+ plotThresholds,
+ regenInferenceStats,
+ true
+ );
+ this.plotChart(
+ this.$$(
+ '#' +
+ this.getPrChartLabelId(labelInd, isSliced ? sliceInd : '')
+ ),
+ plotStats,
+ plotThresholds,
+ regenInferenceStats,
+ false
+ );
+ }
+ }
+ }
this.updateCorrectness_();
},
@@ -5864,10 +6133,22 @@
Show similarity to selected datapoint
return 'rocchart' + index;
},
+ getRocChartLabelId: function(label, index) {
+ return this.getRocChartId(index) + '-' + label;
+ },
+
+ getLabel: function(index) {
+ return this.labelVocab[index] || index;
+ },
+
getPrChartId: function(index) {
return 'prchart' + index;
},
+ getPrChartLabelId: function(label, index) {
+ return this.getPrChartId(index) + '-' + label;
+ },
+
/**
* Returns a printable value for a breakdown item, meaning performance
* statistics broken down by feature values or feature crosses.