diff --git a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html
index 814ff272e3..4ecd1ee720 100644
--- a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html
+++ b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html
@@ -3365,6 +3365,21 @@
Show similarity to selected datapoint
observer: 'newInferences_',
value: () => ({}),
},
+ // Extra outputs from inference. A dict with two fields: 'indices' and
+ // 'extra'. Indices contains a list of example indices that
+ // these new outputs apply to. Extra contains a list of extra output
+ // objects, one for each model being inferred. The object for each
+ // model is a dict of output data names to lists of the output values
+ // for that data, one entry for each example that was inferred upon.
+ // 'attributions' is one of these output data which is parsed into the
+ // 'attributions' object defined below as a special case. Any other extra
+ // data provided are displayed by WIT with each example.
+ // @type {indices: Array,
+ // extra: Array<{!Object>}>}
+ extraOutputs: {
+ type: Object,
+ observer: 'newExtraOutputs_',
+ },
// Attributions from inference. A dict with two fields: 'indices' and
// 'attributions'. Indices contains a list of example indices that
// these new attributions apply to. Attributions contains a list of
@@ -3944,12 +3959,16 @@ Show similarity to selected datapoint
} else {
this.comparedIndices = [];
this.counterfactualExampleAndInference = null;
- const temp = this.selectedExampleAndInference;
- this.selectedExampleAndInference = null;
- this.selectedExampleAndInference = temp;
+ this.refreshSelectedDatapoint_();
}
},
+ refreshSelectedDatapoint_: function() {
+ const temp = this.selectedExampleAndInference;
+ this.selectedExampleAndInference = null;
+ this.selectedExampleAndInference = temp;
+ },
+
isSameInferenceClass_: function(val1, val2) {
return this.isRegression_(this.modelType)
? Math.abs(val1 - val2) < this.minCounterfactualValueDist
@@ -6108,6 +6127,77 @@ Show similarity to selected datapoint
this.updatedExample = false;
},
+ newExtraOutputs_: function(extraOutputs) {
+ // Set attributions from the extra outputs, if available.
+ const attributions = [];
+ for (
+ let modelNum = 0;
+ modelNum < extraOutputs.extra.length;
+ modelNum++
+ ) {
+ if ('attributions' in extraOutputs.extra[modelNum]) {
+ attributions.push(extraOutputs.extra[modelNum].attributions);
+ }
+ }
+ if (attributions.length > 0) {
+ this.attributions = {
+ indices: extraOutputs.indices,
+ attributions: attributions,
+ };
+ }
+
+ // Add extra output information to datapoints
+ for (let i = 0; i < extraOutputs.indices.length; i++) {
+ const idx = extraOutputs.indices[i];
+ const datapoint = Object.assign({}, this.visdata[idx]);
+ for (
+ let modelNum = 0;
+ modelNum < extraOutputs.extra.length;
+ modelNum++
+ ) {
+ const keys = Object.keys(extraOutputs.extra[modelNum]);
+ for (let j = 0; j < keys.length; j++) {
+ const key = keys[j];
+ // Skip attributions as they are handled separately above.
+ if (key == 'attributions') {
+ continue;
+ }
+ let val = extraOutputs.extra[modelNum][key][i];
+ const datapointKey = this.strWithModelName_(key, modelNum);
+
+ // Update the datapoint with the extra info for use in
+ // Facets Dive.
+ datapoint[datapointKey] = val;
+
+ // Convert the extra output into an array if necessary, for
+ // insertion into tf.Example as a value list, for update of
+ // examplesAndInferences for the example viewer.
+ if (!Array.isArray(val)) {
+ val = [val];
+ }
+ const isString =
+ val.length > 0 &&
+ (typeof val[0] == 'string' || val[0] instanceof String);
+ const exampleJsonString = JSON.stringify(
+ this.examplesAndInferences[idx].example
+ );
+ const copiedExample = JSON.parse(exampleJsonString);
+ copiedExample.features.feature[datapointKey] = isString
+ ? {bytesList: {value: val}}
+ : {floatList: {value: val}};
+ this.examplesAndInferences[idx].example = copiedExample;
+ }
+ }
+ this.set(`visdata.${idx}`, datapoint);
+ }
+ this.refreshDive_();
+
+ // Update selected datapoint so that if a datapoint is being viewed,
+ // the display is updated with the appropriate extra output.
+ this.computeSelectedExampleAndInference();
+ this.refreshSelectedDatapoint_();
+ },
+
newAttributions_: function(attributions) {
if (Object.keys(attributions).length == 0) {
return;
diff --git a/tensorboard/plugins/interactive_inference/utils/inference_utils.py b/tensorboard/plugins/interactive_inference/utils/inference_utils.py
index a0af0f8e67..3ff6ac7988 100644
--- a/tensorboard/plugins/interactive_inference/utils/inference_utils.py
+++ b/tensorboard/plugins/interactive_inference/utils/inference_utils.py
@@ -615,12 +615,12 @@ def get_example_features(example):
def run_inference_for_inference_results(examples, serving_bundle):
"""Calls servo and wraps the inference results."""
- (inference_result_proto, attributions) = run_inference(
+ (inference_result_proto, extra_results) = run_inference(
examples, serving_bundle)
inferences = wrap_inference_results(inference_result_proto)
infer_json = json_format.MessageToJson(
inferences, including_default_value_fields=True)
- return json.loads(infer_json), attributions
+ return json.loads(infer_json), extra_results
def get_eligible_features(examples, num_mutants):
"""Returns a list of JSON objects for each feature in the examples.
@@ -795,8 +795,8 @@ def run_inference(examples, serving_bundle):
Returns:
A tuple with the first entry being the ClassificationResponse or
- RegressionResponse proto and the second entry being a list of the
- attributions for each example, or None if no attributions exist.
+ RegressionResponse proto and the second entry being a dictionary of extra
+ data for each example, such as attributions, or None if no data exists.
"""
batch_size = 64
if serving_bundle.estimator and serving_bundle.feature_spec:
@@ -822,14 +822,16 @@ def run_inference(examples, serving_bundle):
# If custom_predict_fn is provided, pass examples directly for local
# inference.
values = serving_bundle.custom_predict_fn(examples)
- attributions = None
+ extra_results = None
# If the custom prediction function returned a dict, then parse out the
- # prediction scores and the attributions. If it is just a list, then the
- # results are the prediction results without attributions.
+ # prediction scores. If it is just a list, then the results are the
+ # prediction results without attributions or other data.
if isinstance(values, dict):
- attributions = values['attributions']
- values = values['predictions']
- return (common_utils.convert_prediction_values(values, serving_bundle),
- attributions)
+ preds = values.pop('predictions')
+ extra_results = values
+ else:
+ preds = values
+ return (common_utils.convert_prediction_values(preds, serving_bundle),
+ extra_results)
else:
return (platform_utils.call_servo(examples, serving_bundle), None)
diff --git a/tensorboard/plugins/interactive_inference/witwidget/notebook/base.py b/tensorboard/plugins/interactive_inference/witwidget/notebook/base.py
index 297aa11172..72632da730 100644
--- a/tensorboard/plugins/interactive_inference/witwidget/notebook/base.py
+++ b/tensorboard/plugins/interactive_inference/witwidget/notebook/base.py
@@ -123,7 +123,7 @@ def infer_impl(self):
examples_to_infer = [
self.json_to_proto(self.examples[index]) for index in indices_to_infer]
infer_objs = []
- attribution_objs = []
+ extra_output_objs = []
serving_bundle = inference_utils.ServingBundle(
self.config.get('inference_address'),
self.config.get('model_name'),
@@ -136,11 +136,11 @@ def infer_impl(self):
self.estimator_and_spec.get('estimator'),
self.estimator_and_spec.get('feature_spec'),
self.custom_predict_fn)
- (predictions, attributions) = (
+ (predictions, extra_output) = (
inference_utils.run_inference_for_inference_results(
examples_to_infer, serving_bundle))
infer_objs.append(predictions)
- attribution_objs.append(attributions)
+ extra_output_objs.append(extra_output)
if ('inference_address_2' in self.config or
self.compare_estimator_and_spec.get('estimator') or
self.compare_custom_predict_fn):
@@ -156,16 +156,16 @@ def infer_impl(self):
self.compare_estimator_and_spec.get('estimator'),
self.compare_estimator_and_spec.get('feature_spec'),
self.compare_custom_predict_fn)
- (predictions, attributions) = (
+ (predictions, extra_output) = (
inference_utils.run_inference_for_inference_results(
examples_to_infer, serving_bundle))
infer_objs.append(predictions)
- attribution_objs.append(attributions)
+ extra_output_objs.append(extra_output)
self.updated_example_indices = set()
return {
'inferences': {'indices': indices_to_infer, 'results': infer_objs},
'label_vocab': self.config.get('label_vocab'),
- 'attributions': attribution_objs}
+ 'extra_outputs': extra_output_objs}
def infer_mutants_impl(self, info):
"""Performs mutant inference on specified examples."""
diff --git a/tensorboard/plugins/interactive_inference/witwidget/notebook/colab/wit.py b/tensorboard/plugins/interactive_inference/witwidget/notebook/colab/wit.py
index b50e08f7de..6ca2955547 100644
--- a/tensorboard/plugins/interactive_inference/witwidget/notebook/colab/wit.py
+++ b/tensorboard/plugins/interactive_inference/witwidget/notebook/colab/wit.py
@@ -120,8 +120,8 @@ def compute_custom_distance(wit_id, index, callback_name, params):
window.inferenceCallback = inferences => {{
wit.labelVocab = inferences.label_vocab;
wit.inferences = inferences.inferences;
- wit.attributions = {{indices: wit.inferences.indices,
- attributions: inferences.attributions}};
+ wit.extraOutputs = {{indices: wit.inferences.indices,
+ extra: inferences.extra_outputs}};
}};
window.distanceCallback = callbackDict => {{
diff --git a/tensorboard/plugins/interactive_inference/witwidget/notebook/jupyter/js/lib/wit.js b/tensorboard/plugins/interactive_inference/witwidget/notebook/jupyter/js/lib/wit.js
index f30d48d0aa..a4ffdc499a 100644
--- a/tensorboard/plugins/interactive_inference/witwidget/notebook/jupyter/js/lib/wit.js
+++ b/tensorboard/plugins/interactive_inference/witwidget/notebook/jupyter/js/lib/wit.js
@@ -178,9 +178,9 @@ var WITView = widgets.DOMWidgetView.extend({
const inferences = this.model.get('inferences');
this.view_.labelVocab = inferences['label_vocab'];
this.view_.inferences = inferences['inferences'];
- this.view_.attributions = {
+ this.view_.extraOutputs = {
indices: this.view_.inferences.indices,
- attributions: inferences['attributions'],
+ extra: inferences['extra_outputs'],
};
},
eligibleFeaturesChanged: function() {
diff --git a/tensorboard/plugins/interactive_inference/witwidget/notebook/visualization.py b/tensorboard/plugins/interactive_inference/witwidget/notebook/visualization.py
index bbd7e6c270..48b1ac417a 100644
--- a/tensorboard/plugins/interactive_inference/witwidget/notebook/visualization.py
+++ b/tensorboard/plugins/interactive_inference/witwidget/notebook/visualization.py
@@ -414,11 +414,11 @@ def set_custom_predict_fn(self, predict_fn):
- For regression: A 1D list of numbers, with a regression score for each
example being predicted.
- Optionally, if attributions can be returned by the model with each
- prediction, then this method can return a dict with the key 'predictions'
- containing the predictions result list described above, and with the key
- 'attributions' containing a list of attributions for each example that was
- predicted.
+ Optionally, if attributions or other prediction-time information
+ can be returned by the model with each prediction, then this method
+ can return a dict with the key 'predictions' containing the predictions
+ result list described above, and with the key 'attributions' containing
+ a list of attributions for each example that was predicted.
For each example, the attributions list should contain a dict mapping
input feature names to attribution values for that feature on that example.
@@ -432,6 +432,12 @@ def set_custom_predict_fn(self, predict_fn):
a list of attribution values for the corresponding feature values in
the first list.
+ This dict can contain any other keys, with their values being a list of
+ prediction-time strings or numbers for each example being predicted. These
+ values will be displayed in WIT as extra information for each example,
+ usable in the same ways by WIT as normal input features (such as for
+ creating plots and slicing performance data).
+
Args:
predict_fn: The custom python function which will be used for model
inference.
@@ -464,11 +470,11 @@ def set_compare_custom_predict_fn(self, predict_fn):
- For regression: A 1D list of numbers, with a regression score for each
example being predicted.
- Optionally, if attributions can be returned by the model with each
- prediction, then this method can return a dict with the key 'predictions'
- containing the predictions result list described above, and with the key
- 'attributions' containing a list of attributions for each example that was
- predicted.
+ Optionally, if attributions or other prediction-time information
+ can be returned by the model with each prediction, then this method
+ can return a dict with the key 'predictions' containing the predictions
+ result list described above, and with the key 'attributions' containing
+ a list of attributions for each example that was predicted.
For each example, the attributions list should contain a dict mapping
input feature names to attribution values for that feature on that example.
@@ -482,6 +488,12 @@ def set_compare_custom_predict_fn(self, predict_fn):
a list of attribution values for the corresponding feature values in
the first list.
+ This dict can contain any other keys, with their values being a list of
+ prediction-time strings or numbers for each example being predicted. These
+ values will be displayed in WIT as extra information for each example,
+ usable in the same ways by WIT as normal input features (such as for
+ creating plots and slicing performance data).
+
Args:
predict_fn: The custom python function which will be used for model
inference.