
Commit c6fc8d1

Add sampling option to what-if tool example loading (#1504)
* Add flags to TBContext
* Add flags to TBContext in application
* Add flags to tbcontext
* gitignore changes
* jwexler updated rules_closure version
* Add sampling to example loading
* fix test
1 parent e55a65b commit c6fc8d1

File tree: 5 files changed, +66 −25 lines

tensorboard/plugins/interactive_inference/interactive_inference_plugin.py

Lines changed: 5 additions & 3 deletions
@@ -124,14 +124,16 @@ def _examples_from_path_handler(self, request):
     """
     examples_count = int(request.args.get('max_examples'))
     examples_path = request.args.get('examples_path')
+    sampling_odds = float(request.args.get('sampling_odds'))
     try:
       platform_utils.throw_if_file_access_not_allowed(examples_path,
                                                       self._logdir,
                                                       self._has_auth_group)
       example_strings = platform_utils.example_protos_from_path(
-          examples_path, examples_count, parse_examples=False)
+          examples_path, examples_count, parse_examples=False,
+          sampling_odds=sampling_odds)
       self.examples = [
-          tf.train.Example.FromString(ex) for ex in example_strings]
+          tf.train.Example.FromString(ex) for ex in example_strings]
       self.generate_sprite(example_strings)
       json_examples = [
           json_format.MessageToJson(example) for example in self.examples

@@ -404,4 +406,4 @@ def _infer_mutants_handler(self, request):
       return http_util.Respond(request, json_mapping, 'application/json')
     except common_utils.InvalidUserInputError as e:
       return http_util.Respond(request, {'error': e.message},
-                               'application/json', code=400)
+                               'application/json', code=400)
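As an aside, the handler trusts the frontend to always send `sampling_odds` and casts it directly with `float()`. A more defensive variant might look like the sketch below; the fallback default and clamping are assumptions for illustration, not part of this commit:

```python
# Hypothetical sketch only: the commit itself does
#   sampling_odds = float(request.args.get('sampling_odds'))
# with no fallback. The default and clamping below are illustrative.

def parse_sampling_odds(args, default=1.0):
    """Reads 'sampling_odds' from a dict-like of query args."""
    raw = args.get('sampling_odds')
    try:
        odds = float(raw)
    except (TypeError, ValueError):
        return default  # missing or malformed value: load everything
    return max(0.0, odds)  # negative odds make no sense

# A plain dict stands in for request.args here.
print(parse_sampling_odds({'sampling_odds': '0.2'}))  # 0.2
print(parse_sampling_odds({}))                        # 1.0
```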

tensorboard/plugins/interactive_inference/interactive_inference_plugin_test.py

Lines changed: 4 additions & 2 deletions
@@ -76,7 +76,8 @@ def test_examples_from_path(self):
         '/data/plugin/whatif/examples_from_path?' +
         urllib_parse.urlencode({
             'examples_path': examples_path,
-            'max_examples': 2
+            'max_examples': 2,
+            'sampling_odds': 1,
         }))
     self.assertEqual(200, response.status_code)
     example_strings = json.loads(response.get_data().decode('utf-8'))['examples']

@@ -94,7 +95,8 @@ def test_examples_from_path_if_path_does_not_exist(self):
         '/data/plugin/whatif/examples_from_path?' +
         urllib_parse.urlencode({
             'examples_path': 'does_not_exist',
-            'max_examples': 2
+            'max_examples': 2,
+            'sampling_odds': 1,
         }))
     error = json.loads(response.get_data().decode('utf-8'))['error']
     self.assertTrue(error)
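For reference, the query string these tests send can be reproduced outside the test harness. A small sketch using the standard-library urlencode (the test itself uses six.moves.urllib.parse for Python 2/3 compatibility, and the examples path below is made up):

```python
# Standalone illustration of the request the tests issue; not test code.
from urllib.parse import urlencode  # the test uses six.moves.urllib.parse

query = urlencode({
    'examples_path': '/tmp/examples.tfrecord',  # hypothetical path
    'max_examples': 2,
    'sampling_odds': 1,
})
print('/data/plugin/whatif/examples_from_path?' + query)
# /data/plugin/whatif/examples_from_path?examples_path=%2Ftmp%2Fexamples.tfrecord&max_examples=2&sampling_odds=1
```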

tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-inference-panel.html

Lines changed: 37 additions & 8 deletions
@@ -67,6 +67,9 @@
     .input-in-row {
       margin-right: 10px;
     }
+    .flex-grow {
+      flex-grow: 1;
+    }
     .model-type-label {
       padding-top: 10px;
     }

@@ -105,11 +108,16 @@
       <paper-input always-float-label label="Path to examples"
                    value="{{examplesPath}}">
       </paper-input>
-      <paper-input always-float-label type="number"
-                   label="Maximum number of examples to load"
-                   placeholder="[[maxExamples]]" value="{{maxExamples}}">
-      </paper-input>
-
+      <div class="flex-holder">
+        <paper-input always-float-label type="number" class="input-in-row flex-grow"
+                     label="Maximum number of examples to load"
+                     placeholder="[[maxExamples]]" value="{{maxExamples}}">
+        </paper-input>
+        <paper-input always-float-label type="number" class="input-in-row flex-grow"
+                     label="Sampling ratio (0.2 = sample ~20% of examples)"
+                     placeholder="[[samplingOdds]]" value="{{samplingOdds}}">
+        </paper-input>
+      </div>
       <paper-input always-float-label label="Path to label dictionary (optional)"
                    placeholder="[[labelVocabPath]]"
                    value="{{labelVocabPath}}"

@@ -129,7 +137,7 @@
       <div class="flex-holder">
         <paper-input always-float-label type="number" label="Max classes to display"
                      placeholder="[[maxClassesToDisplay]]" value="{{maxClassesToDisplay}}"
-                     class="input-in-row" disabled="[[shouldDisableClassificationControls_(modelType)]]">
+                     class="input-in-row" disabled="[[shouldDisableMultiClassControls_(multiClass)]]">
         </paper-input>
         <paper-checkbox disabled="[[shouldDisableClassificationControls_(modelType)]]"
                         checked="{{multiClass}}"

@@ -148,6 +156,8 @@
     const defaultModelType = 'classification';
     const defaultMaxExamples = '1000';
     const defaultLabelVocabPath = '';
+    const defaultMaxClassesToDisplay = '5';
+    const defaultSamplingOdds = '1';

     Polymer({
       is: "tf-inference-panel",

@@ -198,7 +208,7 @@
       maxExamples: {
         type: Number,
         value: tf_storage.getStringInitializer(
-            'maxExamples', {defaultValue: String(defaultMaxExamples)}),
+            'maxExamples', {defaultValue: defaultMaxExamples}),
         observer: 'maxExamplesChanged_',
         notify: true,
       },

@@ -216,7 +226,16 @@
       },
       maxClassesToDisplay: {
         type: Number,
-        value: 5,
+        value: tf_storage.getStringInitializer(
+            'maxClassesToDisplay', {defaultValue: defaultMaxClassesToDisplay}),
+        observer: 'maxClassesToDisplayChanged_',
+        notify: true,
+      },
+      samplingOdds: {
+        type: Number,
+        value: tf_storage.getStringInitializer(
+            'samplingOdds', {defaultValue: defaultSamplingOdds}),
+        observer: 'samplingOddsChanged_',
         notify: true,
       },
     },

@@ -245,9 +264,19 @@
     labelVocabPathChanged_: tf_storage.getStringObserver(
         'labelVocabPath', {defaultValue: defaultLabelVocabPath}),

+    maxClassesToDisplayChanged_: tf_storage.getStringObserver(
+        'maxClassesToDisplay', {defaultValue: defaultMaxClassesToDisplay}),
+
+    samplingOddsChanged_: tf_storage.getStringObserver(
+        'samplingOdds', {defaultValue: defaultSamplingOdds}),
+
     shouldDisableClassificationControls_: function(modelType) {
       return modelType == 'regression';
     },
+
+    shouldDisableMultiClassControls_: function(multiClass) {
+      return !multiClass;
+    }
   });

 </script>
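To make the new "Sampling ratio" label concrete: with a ratio below 1, each record is kept independently with that probability, so loading spreads across the file instead of taking only the first N records. A rough back-of-envelope sketch (all numbers here are made up):

```python
# Rough illustration of what the "Sampling ratio" field controls.
# All numbers are made up; only the relationship matters.
total_records = 50000   # hypothetical size of the examples file
sampling_odds = 0.2     # keep each record with ~20% probability
max_examples = 1000     # the "Maximum number of examples to load" field

expected_candidates = sampling_odds * total_records       # ~10000 records survive sampling
expected_loaded = min(max_examples, expected_candidates)  # the max-examples cap still applies
print(expected_loaded)  # 1000
```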

tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html

Lines changed: 7 additions & 2 deletions
@@ -831,6 +831,7 @@
           max-examples="{{maxExamples}}"
           label-vocab-path="{{labelVocabPath}}"
           multi-class="{{multiClass}}"
+          sampling-odds="{{samplingOdds}}"
           max-classes-to-display="{{maxInferenceEntriesPerRun}}">
         </tf-inference-panel>
         <div class="accept-button-holder">

@@ -1669,7 +1670,10 @@ <h2>Create a distance feature</h2>
        // If the classification model is a multi-class model.
        multiClass: {
          type: Boolean,
-         value: false,
+        },
+        // Sampling odds (1: load all examples, .2: sample 20% of examples)
+        samplingOdds: {
+          type: Number,
        },
        // Precision on charts for performance measuring.
        axisPrecision: {

@@ -3226,7 +3230,8 @@ <h2>Create a distance feature</h2>
      getExamples_: function(){
        var url = this.makeUrl_('/data/plugin/whatif/examples_from_path',
          {'examples_path': this.examplesPath,
-          'max_examples': this.maxExamples});
+          'max_examples': this.maxExamples,
+          'sampling_odds': this.samplingOdds});

        const updateExampleContents = result => {
          this.updateExampleContents_(

tensorboard/plugins/interactive_inference/utils/platform_utils.py

Lines changed: 13 additions & 10 deletions
@@ -16,6 +16,7 @@

 from glob import glob
 from grpc.beta import implementations
+import random
 from six.moves.urllib.parse import urlparse
 import tensorflow as tf


@@ -62,15 +63,17 @@ def throw_if_file_access_not_allowed(file_path, logdir, has_auth_group):
 def example_protos_from_path(cns_path,
                              num_examples=10,
                              start_index=0,
-                             parse_examples=True):
+                             parse_examples=True,
+                             sampling_odds=1):
   """Returns a number of tf.train.Examples from the CNS path.

   Args:
     cns_path: A string CNS path.
     num_examples: The maximum number of examples to return from the path.
-    start_index: The index of the first example to return.
     parse_examples: If true then parses the serialized proto from the path into
       proto objects. Defaults to True.
+    sampling_odds: Odds of loading an example, used for sampling. When >= 1
+      (the default), then all examples are loaded.

   Returns:
     A list of Example protos or serialized proto strings at the CNS path.

@@ -80,8 +83,8 @@ def append_examples_from_iterable(iterable, examples):
   """

   def append_examples_from_iterable(iterable, examples):
-    for i, value in enumerate(iterable):
-      if i >= start_index:
+    for value in iterable:
+      if sampling_odds >= 1 or random.random() < sampling_odds:
        examples.append(
            tf.train.Example.FromString(value) if parse_examples else value)
        if len(examples) >= num_examples:

@@ -90,19 +93,19 @@ def append_examples_from_iterable(iterable, examples):
   filenames = filepath_to_filepath_list(cns_path)
   examples = []
   compression_types = [
-      tf.python_io.TFRecordCompressionType.NONE,
-      tf.python_io.TFRecordCompressionType.GZIP,
-      tf.python_io.TFRecordCompressionType.ZLIB,
+      tf.python_io.TFRecordCompressionType.NONE,
+      tf.python_io.TFRecordCompressionType.GZIP,
+      tf.python_io.TFRecordCompressionType.ZLIB,
   ]
   current_compression_idx = 0
   current_file_index = 0
   while (current_file_index < len(filenames) and
          current_compression_idx < len(compression_types)):
     try:
       record_iterator = tf.python_io.tf_record_iterator(
-          path=filenames[current_file_index],
-          options=tf.python_io.TFRecordOptions(
-              compression_types[current_compression_idx]))
+          path=filenames[current_file_index],
+          options=tf.python_io.TFRecordOptions(
+              compression_types[current_compression_idx]))
       append_examples_from_iterable(record_iterator, examples)
       current_file_index += 1
       if len(examples) >= num_examples:
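The heart of the change is the per-record coin flip in append_examples_from_iterable. A self-contained sketch of the same idea, decoupled from TFRecord reading (the function name and the plain-list input are mine, for illustration only):

```python
import random

def sample_up_to(iterable, num_items, sampling_odds=1):
    """Keeps each item with probability `sampling_odds`, stopping at `num_items`.

    Mirrors the logic added above: odds >= 1 keep every record, otherwise
    each record is an independent Bernoulli draw.
    """
    kept = []
    for value in iterable:
        if sampling_odds >= 1 or random.random() < sampling_odds:
            kept.append(value)
            if len(kept) >= num_items:
                break
    return kept

random.seed(0)  # only so the illustration is repeatable
# Roughly 20% of 10,000 records survive, well under the 5,000 cap.
print(len(sample_up_to(range(10000), num_items=5000, sampling_odds=0.2)))  # ~2000
```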
