
Commit 85519e7

tolga-b authored and jameswex committed
Add support for custom prediction function in colab and jupyter notebook modes. (#1842)
* Add custom_predict_fn support
* predict_fn model name fix
* Update README.md
* minor fixes
* minor fixes
1 parent 81d4d3d commit 85519e7

6 files changed: +153 -32 lines changed


tensorboard/plugins/interactive_inference/README.md

Lines changed: 8 additions & 1 deletion
@@ -57,6 +57,11 @@ You can use the What-If Tool to analyze a classification or regression
 that takes TensorFlow Example or SequenceExample protos
 (data points) as inputs directly in a jupyter or colab notebook.
 
+You can also use What-If-Tool with a custom prediction function that takes
+Tensorflow examples and produces predictions. In this mode, you can load any model
+(including non-TensorFlow models) as long as your custom function's input and output
+specifications are correct.
+
 If you want to train an ML model from a dataset and explore the dataset and
 model, check out the [What_If_Tool_Notebook_Usage.ipynb notebook](https://colab.research.google.com/github/tensorflow/tensorboard/blob/master/tensorboard/plugins/interactive_inference/What_If_Tool_Notebook_Usage.ipynb) in colab, which starts from a CSV file,
 converts the data to tf.Example protos, trains a classifier, and then uses the
@@ -254,11 +259,13 @@ The WitConfigBuilder object takes a list of tf.Example or tf.SequenceExample
 protos as a constructor argument. These protos will be shown in the tool and
 inferred in the specified model.
 
-The model to be used for inference by the tool can be specified one of two ways:
+The model to be used for inference by the tool can be specified one of three ways:
 - As a TensorFlow [Estimator](https://www.tensorflow.org/guide/estimators)
 object that is provided through the `set_estimator_and_feature_spec` method.
 In this case the inference will be done inside the notebook using the
 provided estimator.
+- As a custom prediction function provided through `set_custom_predict_fn` method.
+In this case WIT will directly call the function for inference.
 - As an endpoint for a model being served by [TensorFlow Serving](https://github.com/tensorflow/serving),
 through the `set_inference_address` and `set_model_name` methods. In this case
 the inference will be done on the model server specified. To query a model served
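To make the new mode concrete, below is a minimal sketch of wiring a custom prediction function into the tool via the `set_custom_predict_fn` builder method documented above. The import path, the `WitWidget` name, the `build_examples` helper, and the hard-coded scores are illustrative assumptions, not part of this commit.

from witwidget.notebook.visualization import WitConfigBuilder, WitWidget

def custom_predict(examples):
  # Receives the list of tf.Example protos loaded into the tool and must
  # return one prediction per example, in order (e.g. class scores).
  return [[0.3, 0.7] for _ in examples]

test_examples = build_examples()  # hypothetical helper that yields tf.Example protos
config_builder = WitConfigBuilder(test_examples)
config_builder.set_custom_predict_fn(custom_predict)
WitWidget(config_builder, height=800)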

tensorboard/plugins/interactive_inference/utils/inference_utils.py

Lines changed: 9 additions & 1 deletion
@@ -181,14 +181,16 @@ class ServingBundle(object):
       Predict API.
     estimator: An estimator to use instead of calling an external model.
     feature_spec: A feature spec for use with the estimator.
+    custom_predict_fn: A custom prediction function.
 
   Raises:
     ValueError: If ServingBundle fails init validation.
   """
 
   def __init__(self, inference_address, model_name, model_type, model_version,
                signature, use_predict, predict_input_tensor,
-               predict_output_tensor, estimator=None, feature_spec=None):
+               predict_output_tensor, estimator=None, feature_spec=None,
+               custom_predict_fn=None):
     """Inits ServingBundle."""
     if not isinstance(inference_address, string_types):
       raise ValueError('Invalid inference_address has type: {}'.format(
@@ -215,6 +217,7 @@ def __init__(self, inference_address, model_name, model_type, model_version,
     self.predict_output_tensor = predict_output_tensor
     self.estimator = estimator
     self.feature_spec = feature_spec
+    self.custom_predict_fn = custom_predict_fn
 
 
 def proto_value_for_feature(example, feature_name):
@@ -750,5 +753,10 @@ def run_inference(examples, serving_bundle):
     for pred in preds:
      values.append(pred[preds_key])
    return common_utils.convert_prediction_values(values, serving_bundle)
+  elif serving_bundle.custom_predict_fn:
+    # If custom_predict_fn is provided, pass examples directly for local
+    # inference.
+    values = serving_bundle.custom_predict_fn(examples)
+    return common_utils.convert_prediction_values(values, serving_bundle)
   else:
     return platform_utils.call_servo(examples, serving_bundle)
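The new `elif` branch above fixes the contract for a custom prediction function: it is called with the raw list of tf.Example protos, and its return value is handed straight to `common_utils.convert_prediction_values`. As a hedged sketch, an adapter around a non-TensorFlow model could look like the following; the feature-extraction scheme and the scikit-learn classifier are assumptions for illustration only.

def make_sklearn_predict_fn(model, feature_names):
  """Wraps a fitted scikit-learn classifier as a WIT custom prediction function."""
  def predict(examples):
    # Pull numeric features out of each tf.Example in a fixed column order.
    rows = [[ex.features.feature[name].float_list.value[0]
             for name in feature_names]
            for ex in examples]
    # One list of class probabilities per input example.
    return model.predict_proba(rows).tolist()
  return predict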

tensorboard/plugins/interactive_inference/witwidget/notebook/colab/wit.py

Lines changed: 24 additions & 6 deletions
@@ -195,6 +195,18 @@ def __init__(self, config_builder, height=1000):
     if 'compare_estimator_and_spec' in copied_config:
       del copied_config['compare_estimator_and_spec']
 
+    self.custom_predict_fn = (
+        config.get('custom_predict_fn')
+        if 'custom_predict_fn' in config else None)
+    self.compare_custom_predict_fn = (
+        config.get('compare_custom_predict_fn')
+        if 'compare_custom_predict_fn' in config else None)
+    if 'custom_predict_fn' in copied_config:
+      del copied_config['custom_predict_fn']
+    if 'compare_custom_predict_fn' in copied_config:
+      del copied_config['compare_custom_predict_fn']
+
+
     self._set_examples(config['examples'])
     del copied_config['examples']
 
@@ -248,11 +260,13 @@ def infer(self):
         self.config.get('predict_input_tensor'),
         self.config.get('predict_output_tensor'),
         self.estimator_and_spec.get('estimator'),
-        self.estimator_and_spec.get('feature_spec'))
+        self.estimator_and_spec.get('feature_spec'),
+        self.custom_predict_fn)
     infer_objs.append(inference_utils.run_inference_for_inference_results(
         examples_to_infer, serving_bundle))
     if ('inference_address_2' in self.config or
-        self.compare_estimator_and_spec.get('estimator')):
+        self.compare_estimator_and_spec.get('estimator') or
+        self.compare_custom_predict_fn):
       serving_bundle = inference_utils.ServingBundle(
           self.config.get('inference_address_2'),
           self.config.get('model_name_2'),
@@ -263,7 +277,8 @@ def infer(self):
           self.config.get('predict_input_tensor'),
           self.config.get('predict_output_tensor'),
           self.compare_estimator_and_spec.get('estimator'),
-          self.compare_estimator_and_spec.get('feature_spec'))
+          self.compare_estimator_and_spec.get('feature_spec'),
+          self.compare_custom_predict_fn)
       infer_objs.append(inference_utils.run_inference_for_inference_results(
           examples_to_infer, serving_bundle))
     self.updated_example_indices = set()
@@ -314,9 +329,11 @@ def infer_mutants(self, info):
         self.config.get('predict_input_tensor'),
         self.config.get('predict_output_tensor'),
         self.estimator_and_spec.get('estimator'),
-        self.estimator_and_spec.get('feature_spec')))
+        self.estimator_and_spec.get('feature_spec'),
+        self.custom_predict_fn))
     if ('inference_address_2' in self.config or
-        self.compare_estimator_and_spec.get('estimator')):
+        self.compare_estimator_and_spec.get('estimator') or
+        self.compare_custom_predict_fn):
       serving_bundles.append(inference_utils.ServingBundle(
           self.config.get('inference_address_2'),
           self.config.get('model_name_2'),
@@ -327,7 +344,8 @@ def infer_mutants(self, info):
           self.config.get('predict_input_tensor'),
           self.config.get('predict_output_tensor'),
           self.compare_estimator_and_spec.get('estimator'),
-          self.compare_estimator_and_spec.get('feature_spec')))
+          self.compare_estimator_and_spec.get('feature_spec'),
+          self.compare_custom_predict_fn))
     viz_params = inference_utils.VizParams(
         info['x_min'], info['x_max'],
         scan_examples, 10,
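The comparison branches above mean the widget also enters model-comparison mode when a `compare_custom_predict_fn` is present in the config, so two custom functions can be compared side by side. A sketch follows, assuming the builder exposes a `set_compare_custom_predict_fn` method mirroring `set_custom_predict_fn` (only the config key is confirmed by this diff); `predict_model_a` and `predict_model_b` are hypothetical functions with the same contract as `custom_predict` above.

config_builder = WitConfigBuilder(test_examples)
config_builder.set_custom_predict_fn(predict_model_a)          # primary model
config_builder.set_compare_custom_predict_fn(predict_model_b)  # comparison model (assumed setter)
WitWidget(config_builder)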

tensorboard/plugins/interactive_inference/witwidget/notebook/jupyter/wit.py

Lines changed: 23 additions & 6 deletions
@@ -73,6 +73,17 @@ def __init__(self, config_builder, height=1000):
     if 'compare_estimator_and_spec' in copied_config:
       del copied_config['compare_estimator_and_spec']
 
+    self.custom_predict_fn = (
+        config.get('custom_predict_fn')
+        if 'custom_predict_fn' in config else None)
+    self.compare_custom_predict_fn = (
+        config.get('compare_custom_predict_fn')
+        if 'compare_custom_predict_fn' in config else None)
+    if 'custom_predict_fn' in copied_config:
+      del copied_config['custom_predict_fn']
+    if 'compare_custom_predict_fn' in copied_config:
+      del copied_config['compare_custom_predict_fn']
+
     self._set_examples(config['examples'])
     del copied_config['examples']
 
@@ -109,11 +120,13 @@ def _infer(self, change):
         self.config.get('predict_input_tensor'),
         self.config.get('predict_output_tensor'),
         self.estimator_and_spec.get('estimator'),
-        self.estimator_and_spec.get('feature_spec'))
+        self.estimator_and_spec.get('feature_spec'),
+        self.custom_predict_fn)
     infer_objs.append(inference_utils.run_inference_for_inference_results(
         examples_to_infer, serving_bundle))
     if ('inference_address_2' in self.config or
-        self.compare_estimator_and_spec.get('estimator')):
+        self.compare_estimator_and_spec.get('estimator') or
+        self.compare_custom_predict_fn):
       serving_bundle = inference_utils.ServingBundle(
           self.config.get('inference_address_2'),
           self.config.get('model_name_2'),
@@ -124,7 +137,8 @@ def _infer(self, change):
           self.config.get('predict_input_tensor'),
           self.config.get('predict_output_tensor'),
           self.compare_estimator_and_spec.get('estimator'),
-          self.compare_estimator_and_spec.get('feature_spec'))
+          self.compare_estimator_and_spec.get('feature_spec'),
+          self.compare_custom_predict_fn)
       infer_objs.append(inference_utils.run_inference_for_inference_results(
           examples_to_infer, serving_bundle))
     self.updated_example_indices = set()
@@ -160,9 +174,11 @@ def _infer_mutants(self, change):
         self.config.get('predict_input_tensor'),
         self.config.get('predict_output_tensor'),
         self.estimator_and_spec.get('estimator'),
-        self.estimator_and_spec.get('feature_spec')))
+        self.estimator_and_spec.get('feature_spec'),
+        self.custom_predict_fn))
     if ('inference_address_2' in self.config or
-        self.compare_estimator_and_spec.get('estimator')):
+        self.compare_estimator_and_spec.get('estimator') or
+        self.compare_custom_predict_fn):
       serving_bundles.append(inference_utils.ServingBundle(
           self.config.get('inference_address_2'),
           self.config.get('model_name_2'),
@@ -173,7 +189,8 @@ def _infer_mutants(self, change):
           self.config.get('predict_input_tensor'),
           self.config.get('predict_output_tensor'),
           self.compare_estimator_and_spec.get('estimator'),
-          self.compare_estimator_and_spec.get('feature_spec')))
+          self.compare_estimator_and_spec.get('feature_spec'),
+          self.compare_custom_predict_fn))
     viz_params = inference_utils.VizParams(
         info['x_min'], info['x_max'],
         scan_examples, 10,
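For orientation, the jupyter widget's changed calls above amount to the pass-through sketched below: the user-supplied function rides along in the ServingBundle and `run_inference` dispatches to it. The placeholder argument values are illustrative only, and constructing a ServingBundle directly is not a supported entry point.

serving_bundle = inference_utils.ServingBundle(
    inference_address='', model_name='', model_type='classification',
    model_version='', signature='', use_predict=False,
    predict_input_tensor='', predict_output_tensor='',
    estimator=None, feature_spec=None,
    custom_predict_fn=custom_predict)  # the user-supplied function
infer_objs = [inference_utils.run_inference_for_inference_results(
    examples_to_infer, serving_bundle)]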
