diff --git a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html
index 4ecd1ee720..994e1c7b11 100644
--- a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html
+++ b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html
@@ -497,6 +497,11 @@
max-width: 150px;
}
+ .num-buckets-input {
+ width: 45px;
+ margin-right: 8px;
+ }
+
.control-button {
background-color: white;
border: 1px solid var(--wit-color-gray300);
@@ -2122,6 +2127,18 @@
What does slicing do?
@@ -2158,6 +2175,18 @@
Show similarity to selected datapoint
+
+
+
@@ -3312,6 +3341,10 @@
Show similarity to selected datapoint
}
}
+ function deepClone(obj) {
+ return JSON.parse(JSON.stringify(obj));
+ }
+
(function() {
const PLUGIN_NAME = 'whatif';
@@ -3358,6 +3391,17 @@
Show similarity to selected datapoint
type: String,
},
maxInferenceEntriesPerRun: Number,
+ // Number of buckets when aggregating numeric features.
+ numPrimaryBuckets: {
+ type: Number,
+ value: 2,
+ observer: 'breakdownFeatureSelected_',
+ },
+ numSecondaryBuckets: {
+ type: Number,
+ value: 2,
+ observer: 'breakdownFeatureSelected_',
+ },
// Inferences from servo.
inferences: {
@@ -3769,6 +3813,14 @@
Show similarity to selected datapoint
type: Boolean,
value: false,
},
+ // This object maps each feature to a list of buckets delimiters,
+ // used for aggregating numeric features.
+ // For instance: `{'age': [10, 30, 60, 90]}` would mean values should
+ // be bucketed into either `[10, 30)`, `[30, 60)` or `[60, 90]`.
+ featureBucketEdges_: {
+ type: Object,
+ value: () => ({}),
+ },
},
observers: [
@@ -3977,10 +4029,8 @@
Show similarity to selected datapoint
adjustMaxCounterfactualValueDist_: function(selected, valueName) {
this.maxCounterfactualValueDist = Math.max(
- this.distanceStats_[valueName].max -
- this.visdata[selected][valueName],
- this.visdata[selected][valueName] -
- this.distanceStats_[valueName].min
+ this.stats[valueName].numberMax - this.visdata[selected][valueName],
+ this.visdata[selected][valueName] - this.stats[valueName].numberMin
);
},
@@ -4132,7 +4182,7 @@
Show similarity to selected datapoint
const maxLength = Math.max(aVals.length, bVals.length);
let featureTotalDist = 0;
for (let i = 0; i < maxLength; i++) {
- if (this.distanceStats_[feat].stdDev != null) {
+ if (this.distanceStats_[feat].isNumeric) {
featureTotalDist += this.getNumericDist(
aVals[i],
bVals[i],
@@ -4405,6 +4455,35 @@
Show similarity to selected datapoint
);
},
+ /**
+ * Get the key to which this example belongs when the dataset is sliced by the given feature(s).
+ */
+ getSliceKey_: function(example, feature1, feature2) {
+ const bucketEdges = this.featureBucketEdges_;
+ function maybeAggregate(feature) {
+ if (feature && example[feature] === undefined) {
+ return '?';
+ }
+ const edges = bucketEdges[feature];
+ if (edges) {
+ for (let i = 1; i < edges.length; i++) {
+ if (
+ example[feature] < edges[i] ||
+ (example[feature] === edges[i] && i === edges.length - 1)
+ ) {
+ const right = i < edges.length - 1 ? ')' : ']';
+ return '[' + edges[i - 1] + ', ' + edges[i] + right;
+ }
+ }
+ }
+ return example[feature];
+ }
+ return this.createCombinedValueString_(
+ maybeAggregate(feature1),
+ maybeAggregate(feature2)
+ );
+ },
+
/**
* Creates a list of all feature values of the selected breakdown feature
* or feature crosses if two breakdown features are selected, and gets
@@ -4414,80 +4493,39 @@
Show similarity to selected datapoint
// When features to slice by change, set optimization strategy
// back to custom (default).
this.resetOptimizationSelected_();
-
- const feature1 = this.selectedBreakdownFeature;
- if (feature1 == '') {
+ if (this.selectedBreakdownFeature == '') {
this.selectedSecondBreakdownFeature = '';
}
- const feature2 = this.selectedSecondBreakdownFeature;
+ this.calculateBucketEdges_(
+ this.selectedBreakdownFeature,
+ this.numPrimaryBuckets
+ );
+ this.calculateBucketEdges_(
+ this.selectedSecondBreakdownFeature,
+ this.numSecondaryBuckets
+ );
+
const thresholds = [];
- const thresholdsMap = {};
- if (feature1.length !== 0) {
- let feature1Values = this.stats[feature1].valueHash;
- // Only breakdown performance by features that don't contain fully-
- // unique values per example.
- if (
- this.stats[feature1].totalCount !=
- this.examplesAndInferences.length
- ) {
- feature1Values = Object.assign({}, feature1Values, {
- undefined: '',
+ const thresholdsIndexMap = {};
+ this.visdata.forEach((item) => {
+ const key1 = this.getSliceKey_(item, this.selectedBreakdownFeature);
+ const key2 = this.getSliceKey_(
+ item,
+ this.selectedSecondBreakdownFeature
+ );
+ const key = this.createCombinedValueString_(key1, key2);
+ if (!(key in thresholdsIndexMap)) {
+ thresholds.push({
+ value: key1,
+ value2: key2,
+ threshold: deepClone(this.overallThresholds),
+ opened: false,
});
+ thresholdsIndexMap[key] = thresholds.length - 1;
}
- let feature2Values = {undefined: ''};
- if (feature2.length != 0) {
- feature2Values = this.stats[feature2].valueHash;
- if (
- this.stats[feature2].totalCount !=
- this.examplesAndInferences.length
- ) {
- feature2Values = Object.assign({}, feature2Values, {
- undefined: '',
- });
- }
- }
-
- // For the selected feature, set up a dict of each feature value in
- // the dataset to the threshold. Add this to a list (for display
- // purposes) and create a map of feature value to entry in that list.
- for (var key1 in feature1Values) {
- if (feature1Values.hasOwnProperty(key1)) {
- for (var key2 in feature2Values) {
- if (feature2Values.hasOwnProperty(key2)) {
- const feature1Value =
- key1 == 'undefined'
- ? undefined
- : this.stats[feature1].valueHash[key1].value;
- const feature2Value =
- key2 == 'undefined'
- ? undefined
- : this.stats[feature2].valueHash[key2].value;
- // Deep copy thresholds for each facet
- const modelThresholds = [];
- for (let i = 0; i < this.overallThresholds.length; i++) {
- modelThresholds.push({
- threshold: this.overallThresholds[i].threshold,
- });
- }
- const thresh = {
- value: feature1Value,
- value2: feature2Value,
- threshold: modelThresholds,
- opened: false,
- };
- thresholds.push(thresh);
- const mapKey = this.createCombinedValueString_(
- feature1Value,
- feature2Value
- );
- thresholdsMap[mapKey] = thresholds.length - 1;
- }
- }
- }
- }
- }
+ });
this.set('featureValueThresholds', thresholds);
- this.set('featureValueThresholdsIndexMap', thresholdsMap);
+ this.set('featureValueThresholdsIndexMap', thresholdsIndexMap);
this.refreshInferences_(false);
},
@@ -4560,9 +4598,10 @@
Show similarity to selected datapoint
const item = this.visdata[i];
let facetedStats = null;
if (this.selectedBreakdownFeature != '') {
- const facetKey = this.createCombinedValueString_(
- item[this.selectedBreakdownFeature],
- item[this.selectedSecondBreakdownFeature]
+ const facetKey = this.getSliceKey_(
+ item,
+ this.selectedBreakdownFeature,
+ this.selectedSecondBreakdownFeature
);
facetedStats = inferenceStats.faceted[facetKey];
if (!facetedStats) {
@@ -4639,9 +4678,10 @@
Show similarity to selected datapoint
const item = this.visdata[i];
let facetedStats = null;
if (this.selectedBreakdownFeature != '') {
- const facetKey = this.createCombinedValueString_(
- item[this.selectedBreakdownFeature],
- item[this.selectedSecondBreakdownFeature]
+ const facetKey = this.getSliceKey_(
+ item,
+ this.selectedBreakdownFeature,
+ this.selectedSecondBreakdownFeature
);
facetedStats = inferenceStats.faceted[facetKey];
if (!facetedStats) {
@@ -4695,9 +4735,10 @@
Show similarity to selected datapoint
const item = this.visdata[i];
let facetedStats = null;
if (this.selectedBreakdownFeature != '') {
- const facetKey = this.createCombinedValueString_(
- item[this.selectedBreakdownFeature],
- item[this.selectedSecondBreakdownFeature]
+ const facetKey = this.getSliceKey_(
+ item,
+ this.selectedBreakdownFeature,
+ this.selectedSecondBreakdownFeature
);
facetedStats = inferenceStats.faceted[facetKey];
if (!facetedStats) {
@@ -4844,9 +4885,16 @@
Show similarity to selected datapoint
);
}
} else if (this.selectedFeatureSort == 'Alphabetical') {
- return this.getPrintableValue_(a).localeCompare(
- this.getPrintableValue_(b)
- );
+ const aValue = this.getPrintableValue_(a);
+ const bValue = this.getPrintableValue_(b);
+ // Handle numeric intervals
+ if (aValue[0] === '[' && bValue[0] === '[') {
+ return (
+ Number.parseFloat(aValue.substring(1)) -
+ Number.parseFloat(bValue.substring(1))
+ );
+ }
+ return aValue.localeCompare(bValue);
} else if (this.selectedFeatureSort == 'Accuracy') {
if (
this.isBinaryClassification_(this.modelType, this.multiClass)
@@ -4898,6 +4946,13 @@
Show similarity to selected datapoint
if (this.selectedFeatureSort == 'Count') {
return b.count - a.count;
} else if (this.selectedFeatureSort == 'Alphabetical') {
+ // Handle numeric intervals
+ if (a.name[0] === '[' && b.name[0] === '[') {
+ return (
+ Number.parseFloat(a.name.substring(1)) -
+ Number.parseFloat(b.name.substring(1))
+ );
+ }
return a.name.localeCompare(b.name);
} else if (this.selectedFeatureSort == 'Mean error') {
return b.meanError - a.meanError;
@@ -6264,9 +6319,10 @@
Show similarity to selected datapoint
// case), then get the appropriate threshold for this item's value for
// that feature. Otherwise the overall threshold will be used.
if (feature1.length !== 0) {
- let key = this.createCombinedValueString_(
- item[feature1],
- item[feature2]
+ let key = this.getSliceKey_(
+ item,
+ this.selectedBreakdownFeature,
+ this.selectedSecondBreakdownFeature
);
thresholds = this.featureValueThresholds[
this.featureValueThresholdsIndexMap[key]
@@ -6428,14 +6484,13 @@
Show similarity to selected datapoint
) {
const featureStats = statsProto.datasetsList[0].featuresList[i];
const feature = featureStats.name;
- this.distanceStats_[feature] = {};
- if (featureStats.numStats) {
+ this.distanceStats_[feature] = {
+ isNumeric: featureStats.numStats != null,
+ };
+ if (this.distanceStats_[feature].isNumeric) {
// Numeric features:
- this.distanceStats_[feature] = {
- stdDev: featureStats.numStats.stdDev,
- min: featureStats.numStats.min,
- max: featureStats.numStats.max,
- };
+ this.distanceStats_[feature].stdDev =
+ featureStats.numStats.stdDev;
} else {
// Categorical features: calculate and store the probability
// that any two feature values across all examples are the same.
@@ -6453,6 +6508,51 @@
Show similarity to selected datapoint
}
},
+ /**
+ * Whether the feature is numeric (as opposed to categorical).
+ */
+ isNumericFeature_: function(feature) {
+ return (
+ feature &&
+ this.distanceStats_ &&
+ this.distanceStats_[feature] &&
+ this.distanceStats_[feature].isNumeric
+ );
+ },
+
+ /**
+ * Calculate edges between buckets for aggregating numeric features.
+ * We do this beforehand to round numbers and avoid ugly interval labels.
+ */
+ calculateBucketEdges_: function(feature, numBuckets) {
+ if (
+ !this.isNumericFeature_(feature) ||
+ // No point in aggregating if less unique values than buckets.
+ this.stats[feature].uniqueCount < numBuckets ||
+ // Already done.
+ (this.featureBucketEdges_[feature] &&
+ this.featureBucketEdges_[feature].length == numBuckets + 1)
+ ) {
+ return;
+ }
+ const min = this.stats[feature].numberMin;
+ const max = this.stats[feature].numberMax;
+ const len = (max - min) / numBuckets;
+ const stdDev = this.distanceStats_[feature].stdDev;
+ function round(val) {
+ // Round to slightly more precise than the magnitude of standard deviation.
+ const precision = -Math.floor(Math.log10(stdDev)) + 1;
+ return Math.round(val * 10 ** precision) / 10 ** precision;
+ }
+ const bucketEdges = [];
+ bucketEdges.push(min);
+ for (let i = 1; i < numBuckets; i++) {
+ bucketEdges.push(round(min + i * len));
+ }
+ bucketEdges.push(max);
+ this.featureBucketEdges_[feature] = bucketEdges;
+ },
+
/**
* Calls the backend to update a changed example.
*/