Skip to content

Commit ded6760

Browse files
authored
Fix What-If Tool PD plots in py3 (#2669)
1 parent 09298ed commit ded6760

File tree

4 files changed

+23
-11
lines changed

4 files changed

+23
-11
lines changed

tensorboard/plugins/interactive_inference/utils/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ py_library(
3131
":platform_utils",
3232
"//tensorboard:expect_absl_logging_installed",
3333
"//tensorboard:expect_tensorflow_installed",
34+
"@org_pythonhosted_six",
3435
"@org_tensorflow_serving_api",
3536
],
3637
)
@@ -55,6 +56,7 @@ py_test(
5556
"//tensorboard:expect_numpy_installed",
5657
"//tensorboard:expect_tensorflow_installed",
5758
"@org_pythonhosted_mock",
59+
"@org_pythonhosted_six",
5860
"@org_tensorflow_serving_api",
5961
],
6062
)

tensorboard/plugins/interactive_inference/utils/inference_utils.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
import numpy as np
2626
import tensorflow as tf
2727
from google.protobuf import json_format
28+
from six import binary_type, string_types, integer_types
2829
from six import iteritems
29-
from six import string_types, integer_types
3030
from six.moves import zip # pylint: disable=redefined-builtin
3131

3232
from tensorboard.plugins.interactive_inference.utils import common_utils
@@ -125,7 +125,8 @@ class OriginalFeatureList(object):
125125
def __init__(self, feature_name, original_value, feature_type):
126126
"""Inits OriginalFeatureList."""
127127
self.feature_name = feature_name
128-
self.original_value = original_value
128+
self.original_value = [
129+
ensure_not_binary(value) for value in original_value]
129130
self.feature_type = feature_type
130131

131132
# Derived attributes.
@@ -164,7 +165,8 @@ def __init__(self, original_feature, index, mutant_value):
164165
'index should be None or int, but had unexpected type: {}'.format(
165166
type(index)))
166167
self.index = index
167-
self.mutant_value = mutant_value
168+
self.mutant_value = (mutant_value.encode()
169+
if isinstance(mutant_value, string_types) else mutant_value)
168170

169171

170172
class ServingBundle(object):
@@ -226,6 +228,11 @@ def __init__(self, inference_address, model_name, model_type, model_version,
226228
self.custom_predict_fn = custom_predict_fn
227229

228230

231+
def ensure_not_binary(value):
232+
"""Return non-binary version of value."""
233+
return value.decode() if isinstance(value, binary_type) else value
234+
235+
229236
def proto_value_for_feature(example, feature_name):
230237
"""Get the value of a feature from Example regardless of feature type."""
231238
feature = get_example_features(example)[feature_name]
@@ -563,9 +570,10 @@ def make_json_formatted_for_single_chart(mutant_features,
563570
key += ' (index %d)' % index_to_mutate
564571
if not key in series:
565572
series[key] = {}
566-
if not mutant_feature.mutant_value in series[key]:
567-
series[key][mutant_feature.mutant_value] = []
568-
series[key][mutant_feature.mutant_value].append(
573+
mutant_val = ensure_not_binary(mutant_feature.mutant_value)
574+
if not mutant_val in series[key]:
575+
series[key][mutant_val] = []
576+
series[key][mutant_val].append(
569577
classification_class.score)
570578

571579
# Post-process points to have separate list for each class
@@ -589,9 +597,10 @@ def make_json_formatted_for_single_chart(mutant_features,
589597
# results. So, modding by len(mutant_features) allows us to correctly
590598
# lookup the mutant value for each inference.
591599
mutant_feature = mutant_features[idx % len(mutant_features)]
592-
if not mutant_feature.mutant_value in points:
593-
points[mutant_feature.mutant_value] = []
594-
points[mutant_feature.mutant_value].append(regression.value)
600+
mutant_val = ensure_not_binary(mutant_feature.mutant_value)
601+
if not mutant_val in points:
602+
points[mutant_val] = []
603+
points[mutant_val].append(regression.value)
595604
key = 'value'
596605
if (index_to_mutate != 0):
597606
key += ' (index %d)' % index_to_mutate

tensorboard/plugins/interactive_inference/utils/inference_utils_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def test_get_categorical_features_to_sampling(self):
177177
examples[0: 3], top_k=1)
178178
self.assertDictEqual({
179179
'non_numeric': {
180-
'samples': [b'cat']
180+
'samples': ['cat']
181181
}
182182
}, data)
183183

@@ -186,7 +186,7 @@ def test_get_categorical_features_to_sampling(self):
186186
examples[0: 20], top_k=2)
187187
self.assertDictEqual({
188188
'non_numeric': {
189-
'samples': [b'pony', b'cow']
189+
'samples': ['pony', 'cow']
190190
}
191191
}, data)
192192

tensorboard/plugins/interactive_inference/witwidget/pip_package/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
'google-api-python-client>=1.7.8',
4848
'ipywidgets>=7.0.0',
4949
'jupyter>=1.0,<2',
50+
'six>=1.12.0',
5051
] + _TF_REQ
5152

5253
def get_readme():

0 commit comments

Comments
 (0)