diff --git a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html index 814ff272e3..4ecd1ee720 100644 --- a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html +++ b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html @@ -3365,6 +3365,21 @@

Show similarity to selected datapoint

observer: 'newInferences_', value: () => ({}), }, + // Extra outputs from inference. A dict with two fields: 'indices' and + // 'extra'. Indices contains a list of example indices that + // these new outputs apply to. Extra contains a list of extra output + // objects, one for each model being inferred. The object for each + // model is a dict of output data names to lists of the output values + // for that data, one entry for each example that was inferred upon. + // 'attributions' is one of these output data which is parsed into the + // 'attributions' object defined below as a special case. Any other extra + // data provided are displayed by WIT with each example. + // @type {indices: Array, + // extra: Array<{!Object>}>} + extraOutputs: { + type: Object, + observer: 'newExtraOutputs_', + }, // Attributions from inference. A dict with two fields: 'indices' and // 'attributions'. Indices contains a list of example indices that // these new attributions apply to. Attributions contains a list of @@ -3944,12 +3959,16 @@

Show similarity to selected datapoint

} else { this.comparedIndices = []; this.counterfactualExampleAndInference = null; - const temp = this.selectedExampleAndInference; - this.selectedExampleAndInference = null; - this.selectedExampleAndInference = temp; + this.refreshSelectedDatapoint_(); } }, + refreshSelectedDatapoint_: function() { + const temp = this.selectedExampleAndInference; + this.selectedExampleAndInference = null; + this.selectedExampleAndInference = temp; + }, + isSameInferenceClass_: function(val1, val2) { return this.isRegression_(this.modelType) ? Math.abs(val1 - val2) < this.minCounterfactualValueDist @@ -6108,6 +6127,77 @@

Show similarity to selected datapoint

this.updatedExample = false; }, + newExtraOutputs_: function(extraOutputs) { + // Set attributions from the extra outputs, if available. + const attributions = []; + for ( + let modelNum = 0; + modelNum < extraOutputs.extra.length; + modelNum++ + ) { + if ('attributions' in extraOutputs.extra[modelNum]) { + attributions.push(extraOutputs.extra[modelNum].attributions); + } + } + if (attributions.length > 0) { + this.attributions = { + indices: extraOutputs.indices, + attributions: attributions, + }; + } + + // Add extra output information to datapoints + for (let i = 0; i < extraOutputs.indices.length; i++) { + const idx = extraOutputs.indices[i]; + const datapoint = Object.assign({}, this.visdata[idx]); + for ( + let modelNum = 0; + modelNum < extraOutputs.extra.length; + modelNum++ + ) { + const keys = Object.keys(extraOutputs.extra[modelNum]); + for (let j = 0; j < keys.length; j++) { + const key = keys[j]; + // Skip attributions as they are handled separately above. + if (key == 'attributions') { + continue; + } + let val = extraOutputs.extra[modelNum][key][i]; + const datapointKey = this.strWithModelName_(key, modelNum); + + // Update the datapoint with the extra info for use in + // Facets Dive. + datapoint[datapointKey] = val; + + // Convert the extra output into an array if necessary, for + // insertion into tf.Example as a value list, for update of + // examplesAndInferences for the example viewer. + if (!Array.isArray(val)) { + val = [val]; + } + const isString = + val.length > 0 && + (typeof val[0] == 'string' || val[0] instanceof String); + const exampleJsonString = JSON.stringify( + this.examplesAndInferences[idx].example + ); + const copiedExample = JSON.parse(exampleJsonString); + copiedExample.features.feature[datapointKey] = isString + ? {bytesList: {value: val}} + : {floatList: {value: val}}; + this.examplesAndInferences[idx].example = copiedExample; + } + } + this.set(`visdata.${idx}`, datapoint); + } + this.refreshDive_(); + + // Update selected datapoint so that if a datapoint is being viewed, + // the display is updated with the appropriate extra output. + this.computeSelectedExampleAndInference(); + this.refreshSelectedDatapoint_(); + }, + newAttributions_: function(attributions) { if (Object.keys(attributions).length == 0) { return; diff --git a/tensorboard/plugins/interactive_inference/utils/inference_utils.py b/tensorboard/plugins/interactive_inference/utils/inference_utils.py index a0af0f8e67..3ff6ac7988 100644 --- a/tensorboard/plugins/interactive_inference/utils/inference_utils.py +++ b/tensorboard/plugins/interactive_inference/utils/inference_utils.py @@ -615,12 +615,12 @@ def get_example_features(example): def run_inference_for_inference_results(examples, serving_bundle): """Calls servo and wraps the inference results.""" - (inference_result_proto, attributions) = run_inference( + (inference_result_proto, extra_results) = run_inference( examples, serving_bundle) inferences = wrap_inference_results(inference_result_proto) infer_json = json_format.MessageToJson( inferences, including_default_value_fields=True) - return json.loads(infer_json), attributions + return json.loads(infer_json), extra_results def get_eligible_features(examples, num_mutants): """Returns a list of JSON objects for each feature in the examples. @@ -795,8 +795,8 @@ def run_inference(examples, serving_bundle): Returns: A tuple with the first entry being the ClassificationResponse or - RegressionResponse proto and the second entry being a list of the - attributions for each example, or None if no attributions exist. + RegressionResponse proto and the second entry being a dictionary of extra + data for each example, such as attributions, or None if no data exists. """ batch_size = 64 if serving_bundle.estimator and serving_bundle.feature_spec: @@ -822,14 +822,16 @@ def run_inference(examples, serving_bundle): # If custom_predict_fn is provided, pass examples directly for local # inference. values = serving_bundle.custom_predict_fn(examples) - attributions = None + extra_results = None # If the custom prediction function returned a dict, then parse out the - # prediction scores and the attributions. If it is just a list, then the - # results are the prediction results without attributions. + # prediction scores. If it is just a list, then the results are the + # prediction results without attributions or other data. if isinstance(values, dict): - attributions = values['attributions'] - values = values['predictions'] - return (common_utils.convert_prediction_values(values, serving_bundle), - attributions) + preds = values.pop('predictions') + extra_results = values + else: + preds = values + return (common_utils.convert_prediction_values(preds, serving_bundle), + extra_results) else: return (platform_utils.call_servo(examples, serving_bundle), None) diff --git a/tensorboard/plugins/interactive_inference/witwidget/notebook/base.py b/tensorboard/plugins/interactive_inference/witwidget/notebook/base.py index 297aa11172..72632da730 100644 --- a/tensorboard/plugins/interactive_inference/witwidget/notebook/base.py +++ b/tensorboard/plugins/interactive_inference/witwidget/notebook/base.py @@ -123,7 +123,7 @@ def infer_impl(self): examples_to_infer = [ self.json_to_proto(self.examples[index]) for index in indices_to_infer] infer_objs = [] - attribution_objs = [] + extra_output_objs = [] serving_bundle = inference_utils.ServingBundle( self.config.get('inference_address'), self.config.get('model_name'), @@ -136,11 +136,11 @@ def infer_impl(self): self.estimator_and_spec.get('estimator'), self.estimator_and_spec.get('feature_spec'), self.custom_predict_fn) - (predictions, attributions) = ( + (predictions, extra_output) = ( inference_utils.run_inference_for_inference_results( examples_to_infer, serving_bundle)) infer_objs.append(predictions) - attribution_objs.append(attributions) + extra_output_objs.append(extra_output) if ('inference_address_2' in self.config or self.compare_estimator_and_spec.get('estimator') or self.compare_custom_predict_fn): @@ -156,16 +156,16 @@ def infer_impl(self): self.compare_estimator_and_spec.get('estimator'), self.compare_estimator_and_spec.get('feature_spec'), self.compare_custom_predict_fn) - (predictions, attributions) = ( + (predictions, extra_output) = ( inference_utils.run_inference_for_inference_results( examples_to_infer, serving_bundle)) infer_objs.append(predictions) - attribution_objs.append(attributions) + extra_output_objs.append(extra_output) self.updated_example_indices = set() return { 'inferences': {'indices': indices_to_infer, 'results': infer_objs}, 'label_vocab': self.config.get('label_vocab'), - 'attributions': attribution_objs} + 'extra_outputs': extra_output_objs} def infer_mutants_impl(self, info): """Performs mutant inference on specified examples.""" diff --git a/tensorboard/plugins/interactive_inference/witwidget/notebook/colab/wit.py b/tensorboard/plugins/interactive_inference/witwidget/notebook/colab/wit.py index b50e08f7de..6ca2955547 100644 --- a/tensorboard/plugins/interactive_inference/witwidget/notebook/colab/wit.py +++ b/tensorboard/plugins/interactive_inference/witwidget/notebook/colab/wit.py @@ -120,8 +120,8 @@ def compute_custom_distance(wit_id, index, callback_name, params): window.inferenceCallback = inferences => {{ wit.labelVocab = inferences.label_vocab; wit.inferences = inferences.inferences; - wit.attributions = {{indices: wit.inferences.indices, - attributions: inferences.attributions}}; + wit.extraOutputs = {{indices: wit.inferences.indices, + extra: inferences.extra_outputs}}; }}; window.distanceCallback = callbackDict => {{ diff --git a/tensorboard/plugins/interactive_inference/witwidget/notebook/jupyter/js/lib/wit.js b/tensorboard/plugins/interactive_inference/witwidget/notebook/jupyter/js/lib/wit.js index f30d48d0aa..a4ffdc499a 100644 --- a/tensorboard/plugins/interactive_inference/witwidget/notebook/jupyter/js/lib/wit.js +++ b/tensorboard/plugins/interactive_inference/witwidget/notebook/jupyter/js/lib/wit.js @@ -178,9 +178,9 @@ var WITView = widgets.DOMWidgetView.extend({ const inferences = this.model.get('inferences'); this.view_.labelVocab = inferences['label_vocab']; this.view_.inferences = inferences['inferences']; - this.view_.attributions = { + this.view_.extraOutputs = { indices: this.view_.inferences.indices, - attributions: inferences['attributions'], + extra: inferences['extra_outputs'], }; }, eligibleFeaturesChanged: function() { diff --git a/tensorboard/plugins/interactive_inference/witwidget/notebook/visualization.py b/tensorboard/plugins/interactive_inference/witwidget/notebook/visualization.py index bbd7e6c270..48b1ac417a 100644 --- a/tensorboard/plugins/interactive_inference/witwidget/notebook/visualization.py +++ b/tensorboard/plugins/interactive_inference/witwidget/notebook/visualization.py @@ -414,11 +414,11 @@ def set_custom_predict_fn(self, predict_fn): - For regression: A 1D list of numbers, with a regression score for each example being predicted. - Optionally, if attributions can be returned by the model with each - prediction, then this method can return a dict with the key 'predictions' - containing the predictions result list described above, and with the key - 'attributions' containing a list of attributions for each example that was - predicted. + Optionally, if attributions or other prediction-time information + can be returned by the model with each prediction, then this method + can return a dict with the key 'predictions' containing the predictions + result list described above, and with the key 'attributions' containing + a list of attributions for each example that was predicted. For each example, the attributions list should contain a dict mapping input feature names to attribution values for that feature on that example. @@ -432,6 +432,12 @@ def set_custom_predict_fn(self, predict_fn): a list of attribution values for the corresponding feature values in the first list. + This dict can contain any other keys, with their values being a list of + prediction-time strings or numbers for each example being predicted. These + values will be displayed in WIT as extra information for each example, + usable in the same ways by WIT as normal input features (such as for + creating plots and slicing performance data). + Args: predict_fn: The custom python function which will be used for model inference. @@ -464,11 +470,11 @@ def set_compare_custom_predict_fn(self, predict_fn): - For regression: A 1D list of numbers, with a regression score for each example being predicted. - Optionally, if attributions can be returned by the model with each - prediction, then this method can return a dict with the key 'predictions' - containing the predictions result list described above, and with the key - 'attributions' containing a list of attributions for each example that was - predicted. + Optionally, if attributions or other prediction-time information + can be returned by the model with each prediction, then this method + can return a dict with the key 'predictions' containing the predictions + result list described above, and with the key 'attributions' containing + a list of attributions for each example that was predicted. For each example, the attributions list should contain a dict mapping input feature names to attribution values for that feature on that example. @@ -482,6 +488,12 @@ def set_compare_custom_predict_fn(self, predict_fn): a list of attribution values for the corresponding feature values in the first list. + This dict can contain any other keys, with their values being a list of + prediction-time strings or numbers for each example being predicted. These + values will be displayed in WIT as extra information for each example, + usable in the same ways by WIT as normal input features (such as for + creating plots and slicing performance data). + Args: predict_fn: The custom python function which will be used for model inference.