Skip to content

Commit 4670696

Browse files
authored
data: perform downsampling in multiplexer provider (#3272)
Summary: The `MultiplexerDataProvider` now respects its `downsample` parameter, even though the backing `PluginEventMultiplexer` already performs its own sampling. This serves two purposes: - It enforces that clients are always specifying the `downsample` argument, which is required. - It enables us to test plugins' downsampling parameters to verify that they will behave correctly with other data providers. Test Plan: Unit tests included. Note that changing the `_DEFAULT_DOWNSAMPLING` constant in (e.g.) the scalars plugin to a small number (like `5`) now actually causes charts in the frontend to be downsampled. wchargin-branch: data-mux-downsample
1 parent 5b7f9ad commit 4670696

File tree

3 files changed

+112
-15
lines changed

3 files changed

+112
-15
lines changed

tensorboard/backend/event_processing/data_provider.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import base64
2222
import collections
2323
import json
24+
import random
2425

2526
import six
2627

@@ -57,6 +58,16 @@ def _validate_experiment_id(self, experiment_id):
5758
% (str, type(experiment_id), experiment_id)
5859
)
5960

61+
def _validate_downsample(self, downsample):
62+
if downsample is None:
63+
raise TypeError("`downsample` required but not given")
64+
if isinstance(downsample, int):
65+
return # OK
66+
raise TypeError(
67+
"`downsample` must be an int, but got %r: %r"
68+
% (type(downsample), downsample)
69+
)
70+
6071
def _test_run_tag(self, run_tag_filter, run, tag):
6172
runs = run_tag_filter.runs
6273
if runs is not None and run not in runs:
@@ -109,14 +120,11 @@ def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None):
109120
def read_scalars(
    self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
):
    """Read scalar data for the given plugin, downsampled per time series.

    `downsample` is required; a `TypeError` is raised if it is omitted
    or is not an `int`.
    """
    self._validate_downsample(downsample)
    # Reuse the listing logic to find the (run, tag) index, then read
    # and downsample the actual values.
    return self._read(
        _convert_scalar_event,
        self.list_scalars(
            experiment_id, plugin_name, run_tag_filter=run_tag_filter
        ),
        downsample,
    )
120128

121129
def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
122130
self._validate_experiment_id(experiment_id)
@@ -131,14 +139,11 @@ def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
131139
def read_tensors(
    self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
):
    """Read tensor data for the given plugin, downsampled per time series.

    `downsample` is required; a `TypeError` is raised if it is omitted
    or is not an `int`.
    """
    self._validate_downsample(downsample)
    # Reuse the listing logic to find the (run, tag) index, then read
    # and downsample the actual values.
    return self._read(
        _convert_tensor_event,
        self.list_tensors(
            experiment_id, plugin_name, run_tag_filter=run_tag_filter
        ),
        downsample,
    )
142147

143148
def _list(
144149
self,
@@ -191,13 +196,15 @@ def _list(
191196
)
192197
return result
193198

194-
def _read(self, convert_event, index):
199+
def _read(self, convert_event, index, downsample):
195200
"""Helper to read scalar or tensor data from the multiplexer.
196201
197202
Args:
198203
convert_event: Takes `plugin_event_accumulator.TensorEvent` to
199204
either `provider.ScalarDatum` or `provider.TensorDatum`.
200205
index: The result of `list_scalars` or `list_tensors`.
206+
downsample: Non-negative `int`; how many samples to return per
207+
time series.
201208
202209
Returns:
203210
A dict of dicts of values returned by `convert_event` calls,
@@ -209,7 +216,8 @@ def _read(self, convert_event, index):
209216
result[run] = result_for_run
210217
for (tag, metadata) in six.iteritems(tags_for_run):
211218
events = self._multiplexer.Tensors(run, tag)
212-
result_for_run[tag] = [convert_event(e) for e in events]
219+
data = [convert_event(e) for e in events]
220+
result_for_run[tag] = _downsample(data, downsample)
213221
return result
214222

215223
def list_blob_sequences(
@@ -258,6 +266,7 @@ def read_blob_sequences(
258266
self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
259267
):
260268
self._validate_experiment_id(experiment_id)
269+
self._validate_downsample(downsample)
261270
index = self.list_blob_sequences(
262271
experiment_id, plugin_name, run_tag_filter=run_tag_filter
263272
)
@@ -275,7 +284,7 @@ def read_blob_sequences(
275284
experiment_id, plugin_name, run, tag, event
276285
)
277286
data = [datum for (step, datum) in sorted(data_by_step.items())]
278-
result_for_run[tag] = data
287+
result_for_run[tag] = _downsample(data, downsample)
279288
return result
280289

281290
def read_blob(self, blob_key):
@@ -411,3 +420,37 @@ def _tensor_size(tensor_proto):
411420
for dim in tensor_proto.tensor_shape.dim:
412421
result *= dim.size
413422
return result
423+
424+
425+
def _downsample(xs, k):
426+
"""Downsample `xs` to at most `k` elements.
427+
428+
If `k` is larger than `xs`, then the contents of `xs` itself will be
429+
returned. If `k` is smaller than `xs`, the last element of `xs` will
430+
always be included (unless `k` is `0`) and the preceding elements
431+
will be selected uniformly at random.
432+
433+
This differs from `random.sample` in that it returns a subsequence
434+
(i.e., order is preserved) and that it permits `k > len(xs)`.
435+
436+
The random number generator will always be `random.Random(0)`, so
437+
this function is deterministic (within a Python process).
438+
439+
Args:
440+
xs: A sequence (`collections.abc.Sequence`).
441+
k: A non-negative integer.
442+
443+
Returns:
444+
A new list whose elements are a subsequence of `xs` of length
445+
`min(k, len(xs))` and that is guaranteed to include the last
446+
element of `xs`, uniformly selected among such subsequences.
447+
"""
448+
449+
if k > len(xs):
450+
return list(xs)
451+
if k == 0:
452+
return []
453+
indices = random.Random(0).sample(six.moves.xrange(len(xs) - 1), k - 1)
454+
indices.sort()
455+
indices += [len(xs) - 1]
456+
return [xs[i] for i in indices]

tensorboard/backend/event_processing/data_provider_test.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def test_read_scalars(self):
249249
experiment_id="unused",
250250
plugin_name=scalar_metadata.PLUGIN_NAME,
251251
run_tag_filter=run_tag_filter,
252-
downsample=None, # not yet implemented
252+
downsample=100,
253253
)
254254

255255
self.assertItemsEqual(result.keys(), ["polynomials", "waves"])
@@ -267,6 +267,18 @@ def test_read_scalars(self):
267267
tensor_util.make_ndarray(event.tensor_proto).item(),
268268
)
269269

270+
def test_read_scalars_downsamples(self):
    # A downsample budget smaller than the series length should be
    # honored exactly.
    provider = data_provider.MultiplexerDataProvider(
        self.create_multiplexer(), self.logdir
    )
    result = provider.read_scalars(
        experiment_id="unused",
        plugin_name=scalar_metadata.PLUGIN_NAME,
        downsample=3,
    )
    self.assertLen(result["waves"]["sine"], 3)
281+
270282
def test_read_scalars_but_not_rank_0(self):
271283
provider = self.create_provider()
272284
run_tag_filter = base_provider.RunTagFilter(["waves"], ["bad"])
@@ -280,6 +292,7 @@ def test_read_scalars_but_not_rank_0(self):
280292
experiment_id="unused",
281293
plugin_name="greetings",
282294
run_tag_filter=run_tag_filter,
295+
downsample=100,
283296
)
284297

285298
def test_list_tensors_all(self):
@@ -329,7 +342,7 @@ def test_read_tensors(self):
329342
experiment_id="unused",
330343
plugin_name=histogram_metadata.PLUGIN_NAME,
331344
run_tag_filter=run_tag_filter,
332-
downsample=None, # not yet implemented
345+
downsample=100,
333346
)
334347

335348
self.assertItemsEqual(result.keys(), ["lebesgue"])
@@ -346,6 +359,46 @@ def test_read_tensors(self):
346359
tensor_util.make_ndarray(event.tensor_proto),
347360
)
348361

362+
def test_read_tensors_downsamples(self):
    # A downsample budget smaller than the series length should be
    # honored exactly.
    provider = data_provider.MultiplexerDataProvider(
        self.create_multiplexer(), self.logdir
    )
    result = provider.read_tensors(
        experiment_id="unused",
        plugin_name=histogram_metadata.PLUGIN_NAME,
        downsample=3,
    )
    self.assertLen(result["lebesgue"]["uniform"], 3)
373+
374+
375+
class DownsampleTest(tf.test.TestCase):
    """Tests for the `_downsample` private helper function."""

    def test_deterministic(self):
        # Repeated calls with the same inputs must agree, since the RNG
        # is freshly seeded on every call.
        xs = "abcdefg"
        expected = data_provider._downsample(xs, k=4)
        for _ in range(100):
            self.assertEqual(data_provider._downsample(xs, k=4), expected)

    def test_underlong_ok(self):
        # Asking for more elements than exist returns a fresh copy.
        xs = list("abcdefg")
        actual = data_provider._downsample(xs, k=10)
        self.assertIsNot(actual, xs)
        self.assertEqual(actual, list("abcdefg"))

    def test_inorder(self):
        # The result is a subsequence: original ordering is preserved.
        actual = data_provider._downsample(list(range(10000)), k=100)
        self.assertEqual(actual, sorted(actual))

    def test_zero(self):
        # k == 0 yields an empty list, even though the last element is
        # normally always retained.
        self.assertEqual(data_provider._downsample("abcdefg", k=0), [])
401+
349402

350403
if __name__ == "__main__":
351404
tf.test.main()

tensorboard/plugins/graph/graphs_plugin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def graph_impl(
209209
experiment_id=experiment,
210210
plugin_name=metadata.PLUGIN_NAME,
211211
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
212+
downsample=1,
212213
)
213214
blob_datum_list = graph_blob_sequences.get(run, {}).get(tag, ())
214215
try:

0 commit comments

Comments
 (0)