diff --git a/tensorboard/backend/event_processing/data_provider.py b/tensorboard/backend/event_processing/data_provider.py
index f3993a5df1..54c7f04038 100644
--- a/tensorboard/backend/event_processing/data_provider.py
+++ b/tensorboard/backend/event_processing/data_provider.py
@@ -21,6 +21,7 @@
 import base64
 import collections
 import json
+import random

 import six

@@ -57,6 +58,16 @@ def _validate_experiment_id(self, experiment_id):
                 % (str, type(experiment_id), experiment_id)
             )

+    def _validate_downsample(self, downsample):
+        if downsample is None:
+            raise TypeError("`downsample` required but not given")
+        if isinstance(downsample, int):
+            return  # OK
+        raise TypeError(
+            "`downsample` must be an int, but got %r: %r"
+            % (type(downsample), downsample)
+        )
+
     def _test_run_tag(self, run_tag_filter, run, tag):
         runs = run_tag_filter.runs
         if runs is not None and run not in runs:
@@ -109,14 +120,11 @@ def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None):
     def read_scalars(
         self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
     ):
-        # TODO(@wchargin): Downsampling not implemented, as the multiplexer
-        # is already downsampled. We could downsample on top of the existing
-        # sampling, which would be nice for testing.
-        del downsample  # ignored for now
+        self._validate_downsample(downsample)
         index = self.list_scalars(
             experiment_id, plugin_name, run_tag_filter=run_tag_filter
         )
-        return self._read(_convert_scalar_event, index)
+        return self._read(_convert_scalar_event, index, downsample)

     def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
         self._validate_experiment_id(experiment_id)
@@ -131,14 +139,11 @@ def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
     def read_tensors(
         self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
     ):
-        # TODO(@wchargin): Downsampling not implemented, as the multiplexer
-        # is already downsampled. We could downsample on top of the existing
-        # sampling, which would be nice for testing.
-        del downsample  # ignored for now
+        self._validate_downsample(downsample)
         index = self.list_tensors(
             experiment_id, plugin_name, run_tag_filter=run_tag_filter
         )
-        return self._read(_convert_tensor_event, index)
+        return self._read(_convert_tensor_event, index, downsample)

     def _list(
         self,
@@ -191,13 +196,15 @@ def _list(
         )
         return result

-    def _read(self, convert_event, index):
+    def _read(self, convert_event, index, downsample):
         """Helper to read scalar or tensor data from the multiplexer.

         Args:
           convert_event: Takes `plugin_event_accumulator.TensorEvent` to
             either `provider.ScalarDatum` or `provider.TensorDatum`.
           index: The result of `list_scalars` or `list_tensors`.
+          downsample: Non-negative `int`; how many samples to return per
+            time series.

         Returns:
           A dict of dicts of values returned by `convert_event` calls,
@@ -209,7 +216,8 @@ def _read(self, convert_event, index):
             result[run] = result_for_run
             for (tag, metadata) in six.iteritems(tags_for_run):
                 events = self._multiplexer.Tensors(run, tag)
-                result_for_run[tag] = [convert_event(e) for e in events]
+                data = [convert_event(e) for e in events]
+                result_for_run[tag] = _downsample(data, downsample)
         return result

     def list_blob_sequences(
@@ -258,6 +266,7 @@ def read_blob_sequences(
         self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
     ):
         self._validate_experiment_id(experiment_id)
+        self._validate_downsample(downsample)
         index = self.list_blob_sequences(
             experiment_id, plugin_name, run_tag_filter=run_tag_filter
         )
@@ -275,7 +284,7 @@ def read_blob_sequences(
                     experiment_id, plugin_name, run, tag, event
                 )
             data = [datum for (step, datum) in sorted(data_by_step.items())]
-            result_for_run[tag] = data
+            result_for_run[tag] = _downsample(data, downsample)
         return result

     def read_blob(self, blob_key):
@@ -411,3 +420,37 @@ def _tensor_size(tensor_proto):
     for dim in tensor_proto.tensor_shape.dim:
         result *= dim.size
     return result
+
+
+def _downsample(xs, k):
+    """Downsample `xs` to at most `k` elements.
+
+    If `k` is larger than `len(xs)`, then the contents of `xs` itself
+    will be returned. If `k` is smaller than `len(xs)`, the last
+    element of `xs` will always be included (unless `k` is `0`) and the
+    preceding elements will be selected uniformly at random.
+
+    This differs from `random.sample` in that it returns a subsequence
+    (i.e., order is preserved) and that it permits `k > len(xs)`.
+
+    The random number generator will always be `random.Random(0)`, so
+    this function is deterministic (within a Python process).
+
+    Args:
+      xs: A sequence (`collections.abc.Sequence`).
+      k: A non-negative integer.
+
+    Returns:
+      A new list whose elements are a subsequence of `xs` of length
+      `min(k, len(xs))` and that is guaranteed to include the last
+      element of `xs`, uniformly selected among such subsequences.
+    """
+
+    if k > len(xs):
+        return list(xs)
+    if k == 0:
+        return []
+    indices = random.Random(0).sample(six.moves.xrange(len(xs) - 1), k - 1)
+    indices.sort()
+    indices += [len(xs) - 1]
+    return [xs[i] for i in indices]
diff --git a/tensorboard/backend/event_processing/data_provider_test.py b/tensorboard/backend/event_processing/data_provider_test.py
index 1497f7fb0f..46a10f8ffe 100644
--- a/tensorboard/backend/event_processing/data_provider_test.py
+++ b/tensorboard/backend/event_processing/data_provider_test.py
@@ -249,7 +249,7 @@ def test_read_scalars(self):
             experiment_id="unused",
             plugin_name=scalar_metadata.PLUGIN_NAME,
             run_tag_filter=run_tag_filter,
-            downsample=None,  # not yet implemented
+            downsample=100,
         )

         self.assertItemsEqual(result.keys(), ["polynomials", "waves"])
@@ -267,6 +267,18 @@ def test_read_scalars(self):
                     tensor_util.make_ndarray(event.tensor_proto).item(),
                 )

+    def test_read_scalars_downsamples(self):
+        multiplexer = self.create_multiplexer()
+        provider = data_provider.MultiplexerDataProvider(
+            multiplexer, self.logdir
+        )
+        result = provider.read_scalars(
+            experiment_id="unused",
+            plugin_name=scalar_metadata.PLUGIN_NAME,
+            downsample=3,
+        )
+        self.assertLen(result["waves"]["sine"], 3)
+
     def test_read_scalars_but_not_rank_0(self):
         provider = self.create_provider()
         run_tag_filter = base_provider.RunTagFilter(["waves"], ["bad"])
@@ -280,6 +292,7 @@ def test_read_scalars_but_not_rank_0(self):
                 experiment_id="unused",
                 plugin_name="greetings",
                 run_tag_filter=run_tag_filter,
+                downsample=100,
             )

     def test_list_tensors_all(self):
@@ -329,7 +342,7 @@ def test_read_tensors(self):
             experiment_id="unused",
             plugin_name=histogram_metadata.PLUGIN_NAME,
             run_tag_filter=run_tag_filter,
-            downsample=None,  # not yet implemented
+            downsample=100,
         )

         self.assertItemsEqual(result.keys(), ["lebesgue"])
@@ -346,6 +359,46 @@ def test_read_tensors(self):
                     tensor_util.make_ndarray(event.tensor_proto),
                 )

+    def test_read_tensors_downsamples(self):
+        multiplexer = self.create_multiplexer()
+        provider = data_provider.MultiplexerDataProvider(
+            multiplexer, self.logdir
+        )
+        result = provider.read_tensors(
+            experiment_id="unused",
+            plugin_name=histogram_metadata.PLUGIN_NAME,
+            downsample=3,
+        )
+        self.assertLen(result["lebesgue"]["uniform"], 3)
+
+
+class DownsampleTest(tf.test.TestCase):
+    """Tests for the `_downsample` private helper function."""
+
+    def test_deterministic(self):
+        xs = "abcdefg"
+        expected = data_provider._downsample(xs, k=4)
+        for _ in range(100):
+            actual = data_provider._downsample(xs, k=4)
+            self.assertEqual(actual, expected)
+
+    def test_underlong_ok(self):
+        xs = list("abcdefg")
+        actual = data_provider._downsample(xs, k=10)
+        expected = list("abcdefg")
+        self.assertIsNot(actual, xs)
+        self.assertEqual(actual, expected)
+
+    def test_inorder(self):
+        xs = list(range(10000))
+        actual = data_provider._downsample(xs, k=100)
+        self.assertEqual(actual, sorted(actual))
+
+    def test_zero(self):
+        xs = "abcdefg"
+        actual = data_provider._downsample(xs, k=0)
+        self.assertEqual(actual, [])
+

 if __name__ == "__main__":
     tf.test.main()
diff --git a/tensorboard/plugins/graph/graphs_plugin.py b/tensorboard/plugins/graph/graphs_plugin.py
index c499ae1692..91913dff7a 100644
--- a/tensorboard/plugins/graph/graphs_plugin.py
+++ b/tensorboard/plugins/graph/graphs_plugin.py
@@ -209,6 +209,7 @@ def graph_impl(
             experiment_id=experiment,
             plugin_name=metadata.PLUGIN_NAME,
             run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
+            downsample=1,
         )
         blob_datum_list = graph_blob_sequences.get(run, {}).get(tag, ())
         try:
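
Reviewer note: the core of this change is the `_downsample` helper, so below is
a minimal standalone sketch of the same sampling strategy for experimentation.
The free-standing `downsample` function is hypothetical and not part of this
patch; it uses only the standard library and mirrors the patch's logic: keep
everything when `k` exceeds the series length, otherwise pick `k - 1` of the
non-final indices with a seeded RNG and always keep the last element,
preserving order throughout.

    import random

    def downsample(xs, k):
        # Fewer points than requested: return a copy of everything.
        if k > len(xs):
            return list(xs)
        if k == 0:
            return []
        # Sample k - 1 of the non-final indices with a fixed seed, then
        # always keep the last index so the most recent datum survives.
        rng = random.Random(0)
        indices = sorted(rng.sample(range(len(xs) - 1), k - 1))
        indices.append(len(xs) - 1)
        return [xs[i] for i in indices]

    series = list(range(10))
    print(downsample(series, 4))    # 4 in-order points, always ending in 9
    print(downsample(series, 100))  # [0, 1, ..., 9]: k exceeds the series

Because the seed is fixed, repeated reads of the same time series return the
same subsequence within a process, which is the determinism the docstring
promises.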