diff --git a/tensorboard/plugins/interactive_inference/interactive_inference_plugin.py b/tensorboard/plugins/interactive_inference/interactive_inference_plugin.py index acaea24de3..007c1cc194 100644 --- a/tensorboard/plugins/interactive_inference/interactive_inference_plugin.py +++ b/tensorboard/plugins/interactive_inference/interactive_inference_plugin.py @@ -436,24 +436,24 @@ def _infer_mutants_handler(self, request): (inference_addresses, model_names, model_versions, model_signatures) = self._parse_request_arguments(request) - # TODO(tolgab) Generalize this to multiple models - model_num = 0 - serving_bundle = inference_utils.ServingBundle( - inference_addresses[model_num], - model_names[model_num], - request.args.get('model_type'), - model_versions[model_num], - model_signatures[model_num], - request.args.get('use_predict') == 'true', - request.args.get('predict_input_tensor'), - request.args.get('predict_output_tensor')) + serving_bundles = [] + for model_num in xrange(len(inference_addresses)): + serving_bundles.append(inference_utils.ServingBundle( + inference_addresses[model_num], + model_names[model_num], + request.args.get('model_type'), + model_versions[model_num], + model_signatures[model_num], + request.args.get('use_predict') == 'true', + request.args.get('predict_input_tensor'), + request.args.get('predict_output_tensor'))) viz_params = inference_utils.VizParams( request.args.get('x_min'), request.args.get('x_max'), self.examples[0:NUM_EXAMPLES_TO_SCAN], NUM_MUTANTS, request.args.get('feature_index_pattern')) json_mapping = inference_utils.mutant_charts_for_feature( - examples, feature_name, serving_bundle, viz_params) + examples, feature_name, serving_bundles, viz_params) return http_util.Respond(request, json_mapping, 'application/json') except common_utils.InvalidUserInputError as e: return http_util.Respond(request, {'error': e.message}, diff --git a/tensorboard/plugins/interactive_inference/interactive_inference_plugin_test.py b/tensorboard/plugins/interactive_inference/interactive_inference_plugin_test.py index 5124b020d8..843ae9f124 100644 --- a/tensorboard/plugins/interactive_inference/interactive_inference_plugin_test.py +++ b/tensorboard/plugins/interactive_inference/interactive_inference_plugin_test.py @@ -194,15 +194,15 @@ def test_infer_mutants_handler(self, mock_mutant_charts_for_feature): # A no-op that just passes the example passed to mutant_charts_for_feature # back through. This tests that the URL parameters get processed properly # within infer_mutants_handler. - def pass_through(example, feature_name, serving_bundle, viz_params): + def pass_through(example, feature_name, serving_bundles, viz_params): return { 'example': str(example), 'feature_name': feature_name, - 'serving_bundle': { - 'inference_address': serving_bundle.inference_address, - 'model_name': serving_bundle.model_name, - 'model_type': serving_bundle.model_type, - }, + 'serving_bundles': [{ + 'inference_address': serving_bundles[0].inference_address, + 'model_name': serving_bundles[0].model_name, + 'model_type': serving_bundles[0].model_type, + }], 'viz_params': { 'x_min': viz_params.x_min, 'x_max': viz_params.x_max @@ -229,10 +229,10 @@ def pass_through(example, feature_name, serving_bundle, viz_params): self.assertEqual(str([example]), result['example']) self.assertEqual('single_int', result['feature_name']) self.assertEqual('ml-serving-temp.prediction', - result['serving_bundle']['inference_address']) + result['serving_bundles'][0]['inference_address']) self.assertEqual('/ml/cassandrax/iris_classification', - result['serving_bundle']['model_name']) - self.assertEqual('classification', result['serving_bundle']['model_type']) + result['serving_bundles'][0]['model_name']) + self.assertEqual('classification', result['serving_bundles'][0]['model_type']) self.assertAlmostEqual(-10, result['viz_params']['x_min']) self.assertAlmostEqual(10, result['viz_params']['x_max']) diff --git a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-age-demo.html b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-age-demo.html index 45bf73018d..f432a906d4 100644 --- a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-age-demo.html +++ b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-age-demo.html @@ -45,6 +45,7 @@ }, ready: async function() { this.$.dash.settingsClicked_(); + this.$.dash.updateNumberOfModels_(); this.means = { 'age': 38.64358543876172, 'education-num': 10.078088530363212, @@ -267,7 +268,7 @@ results.push({step: step, scalar: adjustedScore}); } this.$.dash.makeChartForFeature_(isNum ? 'numeric' : 'categorical', - e.detail.feature_name, [{'value': results}]); + e.detail.feature_name, [[{'value': results}]]); }; setTimeout(method, 50); }); diff --git a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-demo.html b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-demo.html index be923ad3b6..95b6b8d425 100644 --- a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-demo.html +++ b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-demo.html @@ -46,6 +46,7 @@ }, ready: async function() { this.$.dash.settingsClicked_(); + this.$.dash.updateNumberOfModels_(); this.means = { 'age': 38.64358543876172, 'education-num': 10.078088530363212, @@ -276,7 +277,7 @@ this.$.dash.makeChartForFeature_( isNum ? 'numeric' : 'categorical', - e.detail.feature_name, [{'1': results}]); + e.detail.feature_name, [[{'1': results}]]); }; setTimeout(method, 50); }); diff --git a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-image-demo.html b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-image-demo.html index f004eb2d9e..ada4cd533b 100644 --- a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-image-demo.html +++ b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-image-demo.html @@ -53,7 +53,7 @@ }, ready: async function() { this.$.dash.settingsClicked_(); - + this.$.dash.updateNumberOfModels_(); this.model = await tf.loadModel(tf.io.browserHTTPRequest( 'data/images/model.json', {credentials: 'include'})); const DATA_PATH = "data/images/smile_examples.json"; diff --git a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-iris-demo.html b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-iris-demo.html index 9768b066e9..1dbca0ea13 100644 --- a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-iris-demo.html +++ b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-iris-demo.html @@ -47,6 +47,7 @@ ready: async function() { this.$.dash.multiClass = true; this.$.dash.settingsClicked_(); + this.$.dash.updateNumberOfModels_(); this.model = await tf.loadModel(tf.io.browserHTTPRequest( 'data/iris/model.json', {credentials: 'include'})); const DATA_PATH = "data/iris/iris.json"; @@ -209,7 +210,7 @@ this.$.dash.makeChartForFeature_( isNum ? 'numeric' : 'categorical', e.detail.feature_name, - [{'0': results[0]}, {'1': results[1]},{'2': results[2]}]); + [[{'0': results[0], '1': results[1], '2': results[2]}]]); }; setTimeout(method, 50); }) diff --git a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-multi-demo.html b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-multi-demo.html index b1f100d594..ac9928f551 100644 --- a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-multi-demo.html +++ b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/demo/tf-interactive-inference-multi-demo.html @@ -30,7 +30,7 @@ width: 100%; } - +