diff --git a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html index e333ec04e6..7cab25d12d 100644 --- a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html +++ b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html @@ -246,13 +246,6 @@ --paper-slider-active-color: #fa7817; } - .pr-line-chart { - margin: 0; - height: 200px; - width: 280px; - display: inline-block; - } - paper-dialog.inference-settings { padding: 20px; width: 40%; @@ -378,16 +371,6 @@ display: block; } - .conf-matrix-holder { - margin-top: 20px; - margin-bottom: 18px; - margin-right: 24px; - } - - .conf-matrix { - margin-bottom: 18px; - } - .datapoint-controls-holder.datapoint-control-buttons-holder { padding-left: 2px; } @@ -569,57 +552,41 @@ font-weight: 500; } - .curves-holder { + .perfs-holder { display: flex; + justify-content: center; + width: 100%; flex-wrap: wrap; - margin-top: 20px; + margin: 4px; position: relative; } - .curve-holder { - width: 300px; - height: 235px; - margin-bottom: 20px; - margin-right: 20px; + .perf-holder { + margin: 8px; position: relative; } - .roc-x-label { - position: absolute; - bottom: 0; - left: 120px; - font-size: 12px; - color: #5f6368; - padding: 0px; - } - - .roc-y-label { - position: absolute; - left: -36px; - bottom: 110px; - transform: rotate(270deg); - font-size: 12px; - color: #5f6368; - padding: 0px; - } - - .pr-x-label { + .perf-curve-x-label { position: absolute; - bottom: 0; - left: 140px; + bottom: 2px; + left: 138px; font-size: 12px; color: #5f6368; padding: 0px; + width: 120px; + text-align: center; } - .pr-y-label { + .perf-curve-y-label { position: absolute; - left: -14px; - bottom: 110px; + left: 44px; + bottom: 54px; transform: rotate(270deg); + transform-origin: left bottom; font-size: 12px; color: #5f6368; - padding: 0px; + width: 120px; + text-align: center; } .flex { @@ -695,6 +662,7 @@ .main-bottom-bar { height: 52px; min-height: 52px; + margin-left: 44px; flex-grow: 0; display: flex; /* box-shadow: 0 2px 5px grey; @@ -932,14 +900,15 @@ cursor: pointer; } - .roc-text { + .perf-curve-text { color: #3c4043; font-size: 16px; margin-left: 44px; + margin-bottom: -10px; } .conf-text { - margin-bottom: 12px; + margin-bottom: 8px; color: #3c4043; font-size: 16px; } @@ -1099,6 +1068,7 @@ .perf-table-entry-expanded { display: flex; flex-wrap: wrap; + width: 100%; margin: 0 12px; border-left: 1px solid var(--wit-color-gray300); border-bottom: 1px solid var(--wit-color-gray300); @@ -2319,7 +2289,7 @@

Show similarity to selected datapoint

Demographic parity Show similarity to selected datapoint Equal opportunity Show similarity to selected datapoint Equal accuracy Show similarity to selected datapoint Group thresholds Show similarity to selected datapoint
@@ -2879,27 +2846,11 @@

Show similarity to selected datapoint

@@ -3120,12 +3088,12 @@

Show similarity to selected datapoint

-
- @@ -3197,25 +3267,118 @@

Show similarity to selected datapoint

>
@@ -4397,7 +4560,7 @@

Show similarity to selected datapoint

return !(modelNames && modelNames.length > 1); }, - shouldShowOverallPrChart_: function( + shouldShowOverallPerfCharts_: function( selectedLabelFeature, selectedBreakdownFeature, inferences @@ -4409,7 +4572,7 @@

Show similarity to selected datapoint

); }, - shouldShowFeaturePrCharts_: function( + shouldShowFeaturePerfCharts_: function( selectedLabelFeature, selectedBreakdownFeature, inferences @@ -4794,6 +4957,55 @@

Show similarity to selected datapoint

} } this.allConfMatrixLabels = Array.from(allLabels.values()); + // Compute confusion data for perf curves + const thresholds = {}; + for (let i = 0; i < this.examplesAndInferences.length; i++) { + const item = this.examplesAndInferences[i]; + const trueLabel = this.visdata[i][ + this.selectedLabelFeature + ].toString(); + const scores = + item.inferences[item.inferences.length - 1][modelInd]; + const slice = + this.selectedBreakdownFeature != '' + ? this.getSliceKey_( + this.visdata[i], + this.selectedBreakdownFeature, + this.selectedSecondBreakdownFeature + ) + : ''; + if (!(slice in thresholds)) { + thresholds[slice] = {}; + } + for (let k = 0; k < scores.length; k++) { + const label = scores[k].label; + if (!(label in thresholds[slice])) { + thresholds[slice][label] = []; + for (let thresh = 0; thresh <= 100; thresh++) { + thresholds[slice][label].push({ + TP: 0, + FP: 0, + FN: 0, + TN: 0, + }); + } + } + const score = scores[k].score * 100; + let result = ''; + for (let thresh = 0; thresh <= 100; thresh++) { + if (label === trueLabel) { + result = score > thresh ? 'TP' : 'FN'; + } else { + result = score > thresh ? 'FP' : 'TN'; + } + thresholds[slice][label][thresh][result] += 1; + } + } + } + Object.values(thresholds).forEach((t) => + Object.values(t).forEach(this.calcThresholdStats) + ); + inferenceStats.allThresholds = thresholds; } else { // For regression models, calculate inference error. inferenceStats.results = {errors: []}; @@ -4854,7 +5066,7 @@

Show similarity to selected datapoint

this.featureValueThresholds = []; this.featureValueThresholds = this.sortFeatureValues(tempArray); - // ROC curves should only exist for the binary case + // ROC and PR curves for the binary case if (this.isBinaryClassification_(this.modelType, this.multiClass)) { for (let i = 0; i < this.featureValueThresholds.length; i++) { const plotStats = []; @@ -4913,6 +5125,63 @@

Show similarity to selected datapoint

false ); } + // ROC and PR curves for the multi-class case + else if (this.isMultiClass_(this.modelType, this.multiClass)) { + const isSliced = this.featureValueThresholds.length > 0; + const slices = isSliced + ? this.featureValueThresholds.map((slice) => + this.createCombinedValueString_(slice.value, slice.value2) + ) + : ['']; + for (let sliceInd = 0; sliceInd < slices.length; sliceInd++) { + const slice = slices[sliceInd]; + for ( + let labelInd = 0; + labelInd < this.allConfMatrixLabels.length; + labelInd++ + ) { + const label = this.allConfMatrixLabels[labelInd]; + const plotStats = []; + const plotThresholds = []; + const thresholds = isSliced + ? this.featureValueThresholds[sliceInd].threshold + : this.overallThresholds; + for ( + let modelInd = 0; + modelInd < this.inferenceStats_.length; + modelInd++ + ) { + plotStats.push( + this.inferenceStats_[modelInd].allThresholds[slice][label] + ); + plotThresholds.push(thresholds[modelInd].threshold); + } + this.plotChart( + this.$$( + '#' + + this.getRocChartLabelId( + labelInd, + isSliced ? sliceInd : '' + ) + ), + plotStats, + plotThresholds, + regenInferenceStats, + true + ); + this.plotChart( + this.$$( + '#' + + this.getPrChartLabelId(labelInd, isSliced ? sliceInd : '') + ), + plotStats, + plotThresholds, + regenInferenceStats, + false + ); + } + } + } this.updateCorrectness_(); }, @@ -5864,10 +6133,22 @@

Show similarity to selected datapoint

return 'rocchart' + index; }, + getRocChartLabelId: function(label, index) { + return this.getRocChartId(index) + '-' + label; + }, + + getLabel: function(index) { + return this.labelVocab[index] || index; + }, + getPrChartId: function(index) { return 'prchart' + index; }, + getPrChartLabelId: function(label, index) { + return this.getPrChartId(index) + '-' + label; + }, + /** * Returns a printable value for a breakdown item, meaning performance * statistics broken down by feature values or feature crosses.