From dcfe16e1134e89d90bcc3e1fe43372a659d37e3a Mon Sep 17 00:00:00 2001
From: William Chargin
Date: Thu, 20 Feb 2020 13:30:17 -0800
Subject: [PATCH 1/2] data: expose downsampling preferences to plugins
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
We add a `sampling_hints` attribute to the `TBContext` magic container,
which is populated with the parsed form of the `--samples_per_plugin`
flag. Existing plugins’ generic data modes are updated to read from
this map instead of using hard-coded thresholds.

Test Plan:
This change is not actually observable as is, because the multiplexer
data provider ignores its downsampling argument. But after patching in
a change to make the data provider respect the downsampling argument,
this change has the effect that increasing the `--samples_per_plugin`
value for a plugin above its default (e.g., `images=20`) now properly
increases the number of samples shown in generic data mode, whereas
previously it had no effect.

wchargin-branch: data-downsampling-flag
wchargin-source: 50998be15abd790a0915458bac76091c79823f0f
---
 tensorboard/backend/application.py            | 21 ++++++++++++-------
 tensorboard/plugins/base_plugin.py            |  6 ++++++
 .../plugins/histogram/histograms_plugin.py    | 19 +++++++++++------
 tensorboard/plugins/image/images_plugin.py    | 10 ++++-----
 tensorboard/plugins/scalar/scalars_plugin.py  | 12 ++++++-----
 tensorboard/plugins/text/text_plugin.py       |  7 ++++++-
 6 files changed, 50 insertions(+), 25 deletions(-)

diff --git a/tensorboard/backend/application.py b/tensorboard/backend/application.py
index 37950020c2..a9aa08c280 100644
--- a/tensorboard/backend/application.py
+++ b/tensorboard/backend/application.py
@@ -97,17 +97,20 @@
 logger = tb_logging.get_logger()
 
 
-def tensor_size_guidance_from_flags(flags):
-    """Apply user per-summary size guidance overrides."""
-
-    tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE)
+def _parse_samples_per_plugin(flags):
+    result = {}
     if not flags or not flags.samples_per_plugin:
-        return tensor_size_guidance
-
+        return result
     for token in flags.samples_per_plugin.split(","):
         k, v = token.strip().split("=")
-        tensor_size_guidance[k] = int(v)
+        result[k] = int(v)
+    return result
+
+
+def _apply_tensor_size_guidance(sampling_hints):
+    """Apply user per-summary size guidance overrides."""
+    tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE)
+    tensor_size_guidance.update(sampling_hints)
     return tensor_size_guidance
 
 
@@ -151,9 +154,10 @@ def standard_tensorboard_wsgi(flags, plugin_loaders, assets_zip_provider):
         multiplexer = _DbModeMultiplexer(flags.db, db_connection_provider)
     else:
         # Regular logdir loading mode.
+        sampling_hints = _parse_samples_per_plugin(flags)
         multiplexer = event_multiplexer.EventMultiplexer(
             size_guidance=DEFAULT_SIZE_GUIDANCE,
-            tensor_size_guidance=tensor_size_guidance_from_flags(flags),
+            tensor_size_guidance=_apply_tensor_size_guidance(sampling_hints),
             purge_orphaned_data=flags.purge_orphaned_data,
             max_reload_threads=flags.max_reload_threads,
             event_file_active_filter=_get_event_file_active_filter(flags),
@@ -238,6 +242,7 @@ def TensorBoardWSGIApp(
         multiplexer=deprecated_multiplexer,
         assets_zip_provider=assets_zip_provider,
         plugin_name_to_instance=plugin_name_to_instance,
+        sampling_hints=_parse_samples_per_plugin(flags),
         window_title=flags.window_title,
     )
     tbplugins = []
diff --git a/tensorboard/plugins/base_plugin.py b/tensorboard/plugins/base_plugin.py
index 339ca9cb60..51134052ab 100644
--- a/tensorboard/plugins/base_plugin.py
+++ b/tensorboard/plugins/base_plugin.py
@@ -254,6 +254,7 @@ def __init__(
         logdir=None,
         multiplexer=None,
         plugin_name_to_instance=None,
+        sampling_hints=None,
         window_title=None,
     ):
         """Instantiates magic container.
@@ -291,6 +292,10 @@ def __init__(
             plugin may be absent from this mapping until it is registered.
             Plugin logic should handle cases in which a plugin is absent from
             this mapping, lest a KeyError be raised.
+          sampling_hints: Map from plugin name to `int` or `NoneType`, where
+            the value represents the user-specified downsampling limit as
+            given to the `--samples_per_plugin` flag, or `None` if none was
+            explicitly given for this plugin.
           window_title: A string specifying the window title.
         """
         self.assets_zip_provider = assets_zip_provider
@@ -301,6 +306,7 @@ def __init__(
         self.logdir = logdir
         self.multiplexer = multiplexer
         self.plugin_name_to_instance = plugin_name_to_instance
+        self.sampling_hints = sampling_hints
         self.window_title = window_title
 
 
diff --git a/tensorboard/plugins/histogram/histograms_plugin.py b/tensorboard/plugins/histogram/histograms_plugin.py
index b4e48e1fca..0890084529 100644
--- a/tensorboard/plugins/histogram/histograms_plugin.py
+++ b/tensorboard/plugins/histogram/histograms_plugin.py
@@ -39,6 +39,9 @@
 from tensorboard.util import tensor_util
 
 
+_DEFAULT_DOWNSAMPLING = 500  # histograms per time series
+
+
 class HistogramsPlugin(base_plugin.TBPlugin):
     """Histograms Plugin for TensorBoard.
 
@@ -62,6 +65,9 @@ def __init__(self, context):
         """
         self._multiplexer = context.multiplexer
         self._db_connection_provider = context.db_connection_provider
+        self._downsample_to = (context.sampling_hints or {}).get(
+            self.plugin_name, _DEFAULT_DOWNSAMPLING
+        )
         if context.flags and context.flags.generic_data == "true":
             self._data_provider = context.data_provider
         else:
@@ -174,20 +180,21 @@ def histograms_impl(self, tag, run, experiment, downsample_to=None):
         """Result of the form `(body, mime_type)`.
 
         At most `downsample_to` events will be returned. If this value is
-        `None`, then no downsampling will be performed.
+        `None`, then default downsampling will be performed.
 
         Raises:
           tensorboard.errors.PublicError: On invalid request.
         """
         if self._data_provider:
-            # Downsample reads to 500 histograms per time series, which is
-            # the default size guidance for histograms under the multiplexer
-            # loading logic.
-            SAMPLE_COUNT = downsample_to if downsample_to is not None else 500
+            sample_count = (
+                downsample_to
+                if downsample_to is not None
+                else self._downsample_to
+            )
             all_histograms = self._data_provider.read_tensors(
                 experiment_id=experiment,
                 plugin_name=metadata.PLUGIN_NAME,
-                downsample=SAMPLE_COUNT,
+                downsample=sample_count,
                 run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
             )
             histograms = all_histograms.get(run, {}).get(tag, None)
diff --git a/tensorboard/plugins/image/images_plugin.py b/tensorboard/plugins/image/images_plugin.py
index 7458f4ded0..bc06ec4864 100644
--- a/tensorboard/plugins/image/images_plugin.py
+++ b/tensorboard/plugins/image/images_plugin.py
@@ -43,6 +43,7 @@
 }
 
 _DEFAULT_IMAGE_MIMETYPE = "application/octet-stream"
+_DEFAULT_DOWNSAMPLING = 10  # images per time series
 
 
 # Extend imghdr.tests to include svg.
@@ -69,6 +70,9 @@ def __init__(self, context):
         """
         self._multiplexer = context.multiplexer
         self._db_connection_provider = context.db_connection_provider
+        self._downsample_to = (context.sampling_hints or {}).get(
+            self.plugin_name, _DEFAULT_DOWNSAMPLING
+        )
         if context.flags and context.flags.generic_data == "true":
             self._data_provider = context.data_provider
         else:
@@ -239,14 +243,10 @@ def _image_response_for_run(self, experiment, run, tag, sample):
           parameters.
         """
         if self._data_provider:
-            # Downsample reads to 10 images per time series, which is the
-            # default size guidance for images under the multiplexer loading
-            # logic.
-            SAMPLE_COUNT = 10
             all_images = self._data_provider.read_blob_sequences(
                 experiment_id=experiment,
                 plugin_name=metadata.PLUGIN_NAME,
-                downsample=SAMPLE_COUNT,
+                downsample=self._downsample_to,
                 run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
             )
             images = all_images.get(run, {}).get(tag, None)
diff --git a/tensorboard/plugins/scalar/scalars_plugin.py b/tensorboard/plugins/scalar/scalars_plugin.py
index 8a1faf4839..59ab6d10cc 100644
--- a/tensorboard/plugins/scalar/scalars_plugin.py
+++ b/tensorboard/plugins/scalar/scalars_plugin.py
@@ -40,6 +40,9 @@
 from tensorboard.util import tensor_util
 
 
+_DEFAULT_DOWNSAMPLING = 1000  # scalars per time series
+
+
 class OutputFormat(object):
     """An enum used to list the valid output formats for API calls."""
 
@@ -60,6 +63,9 @@ def __init__(self, context):
         """
         self._multiplexer = context.multiplexer
         self._db_connection_provider = context.db_connection_provider
+        self._downsample_to = (context.sampling_hints or {}).get(
+            self.plugin_name, _DEFAULT_DOWNSAMPLING
+        )
         if context.flags and context.flags.generic_data != "false":
             self._data_provider = context.data_provider
         else:
@@ -169,14 +175,10 @@ def index_impl(self, experiment=None):
 
     def scalars_impl(self, tag, run, experiment, output_format):
         """Result of the form `(body, mime_type)`."""
         if self._data_provider:
-            # Downsample reads to 1000 scalars per time series, which is the
-            # default size guidance for scalars under the multiplexer loading
-            # logic.
-            SAMPLE_COUNT = 1000
             all_scalars = self._data_provider.read_scalars(
                 experiment_id=experiment,
                 plugin_name=metadata.PLUGIN_NAME,
-                downsample=SAMPLE_COUNT,
+                downsample=self._downsample_to,
                 run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
             )
             scalars = all_scalars.get(run, {}).get(tag, None)
diff --git a/tensorboard/plugins/text/text_plugin.py b/tensorboard/plugins/text/text_plugin.py
index 248d76741e..6ff564973b 100644
--- a/tensorboard/plugins/text/text_plugin.py
+++ b/tensorboard/plugins/text/text_plugin.py
@@ -48,6 +48,8 @@
 2d tables are supported. Showing a 2d slice of the data instead."""
 )
 
+_DEFAULT_DOWNSAMPLING = 100  # text tensors per time series
+
 
 def make_table_row(contents, tag="td"):
     """Given an iterable of string contents, make a table row.
@@ -212,6 +214,9 @@ def __init__(self, context):
           context: A base_plugin.TBContext instance.
         """
         self._multiplexer = context.multiplexer
+        self._downsample_to = (context.sampling_hints or {}).get(
+            self.plugin_name, _DEFAULT_DOWNSAMPLING
+        )
         if context.flags and context.flags.generic_data == "true":
             self._data_provider = context.data_provider
         else:
@@ -261,7 +266,7 @@ def text_impl(self, run, tag, experiment):
         all_text = self._data_provider.read_tensors(
             experiment_id=experiment,
             plugin_name=metadata.PLUGIN_NAME,
-            downsample=100,
+            downsample=self._downsample_to,
             run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
         )
         text = all_text.get(run, {}).get(tag, None)

From 490440c422dd035b971e36e4657c7eafa5c611be Mon Sep 17 00:00:00 2001
From: William Chargin
Date: Thu, 20 Feb 2020 13:31:53 -0800
Subject: [PATCH 2/2] data: perform downsampling in multiplexer provider
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
The `MultiplexerDataProvider` now respects its `downsample` parameter,
even though the backing `PluginEventMultiplexer` already performs its
own sampling. This serves two purposes:

  - It enforces that clients always specify the `downsample` argument,
    which is required.
  - It enables us to test plugins’ downsampling parameters to verify
    that they will behave correctly with other data providers.

Test Plan:
Unit tests included. Note that changing the `_DEFAULT_DOWNSAMPLING`
constant in (e.g.) the scalars plugin to a small number (like `5`) now
actually causes charts in the frontend to be downsampled.

wchargin-branch: data-mux-downsample
wchargin-source: 116ce4c206613e25e09fec31102b90ad80282496
---
 .../backend/event_processing/data_provider.py | 69 +++++++++++++++----
 .../event_processing/data_provider_test.py    | 57 ++++++++++++++-
 tensorboard/plugins/graph/graphs_plugin.py    |  1 +
 3 files changed, 112 insertions(+), 15 deletions(-)

diff --git a/tensorboard/backend/event_processing/data_provider.py b/tensorboard/backend/event_processing/data_provider.py
index f3993a5df1..54c7f04038 100644
--- a/tensorboard/backend/event_processing/data_provider.py
+++ b/tensorboard/backend/event_processing/data_provider.py
@@ -21,6 +21,7 @@
 import base64
 import collections
 import json
+import random
 
 import six
 
@@ -57,6 +58,16 @@ def _validate_experiment_id(self, experiment_id):
                 % (str, type(experiment_id), experiment_id)
             )
 
+    def _validate_downsample(self, downsample):
+        if downsample is None:
+            raise TypeError("`downsample` required but not given")
+        if isinstance(downsample, int):
+            return  # OK
+        raise TypeError(
+            "`downsample` must be an int, but got %r: %r"
+            % (type(downsample), downsample)
+        )
+
     def _test_run_tag(self, run_tag_filter, run, tag):
         runs = run_tag_filter.runs
         if runs is not None and run not in runs:
@@ -109,14 +120,11 @@ def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None):
     def read_scalars(
         self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
     ):
-        # TODO(@wchargin): Downsampling not implemented, as the multiplexer
-        # is already downsampled. We could downsample on top of the existing
-        # sampling, which would be nice for testing.
-        del downsample  # ignored for now
+        self._validate_downsample(downsample)
         index = self.list_scalars(
             experiment_id, plugin_name, run_tag_filter=run_tag_filter
         )
-        return self._read(_convert_scalar_event, index)
+        return self._read(_convert_scalar_event, index, downsample)
 
     def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
         self._validate_experiment_id(experiment_id)
@@ -131,14 +139,11 @@ def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
     def read_tensors(
         self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
    ):
-        # TODO(@wchargin): Downsampling not implemented, as the multiplexer
-        # is already downsampled. We could downsample on top of the existing
-        # sampling, which would be nice for testing.
-        del downsample  # ignored for now
+        self._validate_downsample(downsample)
         index = self.list_tensors(
             experiment_id, plugin_name, run_tag_filter=run_tag_filter
         )
-        return self._read(_convert_tensor_event, index)
+        return self._read(_convert_tensor_event, index, downsample)
 
     def _list(
         self,
@@ -191,13 +196,15 @@ def _list(
         )
         return result
 
-    def _read(self, convert_event, index):
+    def _read(self, convert_event, index, downsample):
         """Helper to read scalar or tensor data from the multiplexer.
 
         Args:
           convert_event: Takes `plugin_event_accumulator.TensorEvent` to
             either `provider.ScalarDatum` or `provider.TensorDatum`.
           index: The result of `list_scalars` or `list_tensors`.
+          downsample: Non-negative `int`; how many samples to return per
+            time series.
 
         Returns:
           A dict of dicts of values returned by `convert_event` calls,
@@ -209,7 +216,8 @@ def _read(self, convert_event, index):
         for (run, tags_for_run) in six.iteritems(index):
             result_for_run = {}
             result[run] = result_for_run
             for (tag, metadata) in six.iteritems(tags_for_run):
                 events = self._multiplexer.Tensors(run, tag)
-                result_for_run[tag] = [convert_event(e) for e in events]
+                data = [convert_event(e) for e in events]
+                result_for_run[tag] = _downsample(data, downsample)
         return result
 
     def list_blob_sequences(
@@ -258,6 +266,7 @@ def read_blob_sequences(
         self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
     ):
         self._validate_experiment_id(experiment_id)
+        self._validate_downsample(downsample)
         index = self.list_blob_sequences(
             experiment_id, plugin_name, run_tag_filter=run_tag_filter
         )
@@ -275,7 +284,7 @@ def read_blob_sequences(
                         experiment_id, plugin_name, run, tag, event
                     )
                 data = [
                     datum for (step, datum) in sorted(data_by_step.items())
                 ]
-                result_for_run[tag] = data
+                result_for_run[tag] = _downsample(data, downsample)
         return result
 
     def read_blob(self, blob_key):
@@ -411,3 +420,37 @@ def _tensor_size(tensor_proto):
     for dim in tensor_proto.tensor_shape.dim:
         result *= dim.size
     return result
+
+
+def _downsample(xs, k):
+    """Downsample `xs` to at most `k` elements.
+
+    If `k` is larger than `len(xs)`, then the contents of `xs` itself
+    will be returned. If `k` is smaller than `len(xs)`, the last
+    element of `xs` will always be included (unless `k` is `0`) and the
+    preceding elements will be selected uniformly at random.
+
+    This differs from `random.sample` in that it returns a subsequence
+    (i.e., order is preserved) and that it permits `k > len(xs)`.
+
+    The random number generator will always be `random.Random(0)`, so
+    this function is deterministic (within a Python process).
+
+    Args:
+      xs: A sequence (`collections.abc.Sequence`).
+      k: A non-negative integer.
+
+    Returns:
+      A new list whose elements form a subsequence of `xs` of length
+      `min(k, len(xs))` and that is guaranteed to include the last
+      element of `xs`, uniformly selected among such subsequences.
+    """
+
+    if k > len(xs):
+        return list(xs)
+    if k == 0:
+        return []
+    indices = random.Random(0).sample(six.moves.xrange(len(xs) - 1), k - 1)
+    indices.sort()
+    indices += [len(xs) - 1]
+    return [xs[i] for i in indices]
diff --git a/tensorboard/backend/event_processing/data_provider_test.py b/tensorboard/backend/event_processing/data_provider_test.py
index 1497f7fb0f..46a10f8ffe 100644
--- a/tensorboard/backend/event_processing/data_provider_test.py
+++ b/tensorboard/backend/event_processing/data_provider_test.py
@@ -249,7 +249,7 @@ def test_read_scalars(self):
             experiment_id="unused",
             plugin_name=scalar_metadata.PLUGIN_NAME,
             run_tag_filter=run_tag_filter,
-            downsample=None,  # not yet implemented
+            downsample=100,
         )
 
         self.assertItemsEqual(result.keys(), ["polynomials", "waves"])
@@ -267,6 +267,18 @@ def test_read_scalars(self):
                     tensor_util.make_ndarray(event.tensor_proto).item(),
                 )
 
+    def test_read_scalars_downsamples(self):
+        multiplexer = self.create_multiplexer()
+        provider = data_provider.MultiplexerDataProvider(
+            multiplexer, self.logdir
+        )
+        result = provider.read_scalars(
+            experiment_id="unused",
+            plugin_name=scalar_metadata.PLUGIN_NAME,
+            downsample=3,
+        )
+        self.assertLen(result["waves"]["sine"], 3)
+
     def test_read_scalars_but_not_rank_0(self):
         provider = self.create_provider()
         run_tag_filter = base_provider.RunTagFilter(["waves"], ["bad"])
@@ -280,6 +292,7 @@ def test_read_scalars_but_not_rank_0(self):
                 experiment_id="unused",
                 plugin_name="greetings",
                 run_tag_filter=run_tag_filter,
+                downsample=100,
             )
 
     def test_list_tensors_all(self):
@@ -329,7 +342,7 @@ def test_read_tensors(self):
             experiment_id="unused",
             plugin_name=histogram_metadata.PLUGIN_NAME,
             run_tag_filter=run_tag_filter,
-            downsample=None,  # not yet implemented
+            downsample=100,
         )
 
         self.assertItemsEqual(result.keys(), ["lebesgue"])
@@ -346,6 +359,46 @@ def test_read_tensors(self):
                     tensor_util.make_ndarray(event.tensor_proto),
                 )
 
+    def test_read_tensors_downsamples(self):
+        multiplexer = self.create_multiplexer()
+        provider = data_provider.MultiplexerDataProvider(
+            multiplexer, self.logdir
+        )
+        result = provider.read_tensors(
+            experiment_id="unused",
+            plugin_name=histogram_metadata.PLUGIN_NAME,
+            downsample=3,
+        )
+        self.assertLen(result["lebesgue"]["uniform"], 3)
+
+
+class DownsampleTest(tf.test.TestCase):
+    """Tests for the `_downsample` private helper function."""
+
+    def test_deterministic(self):
+        xs = "abcdefg"
+        expected = data_provider._downsample(xs, k=4)
+        for _ in range(100):
+            actual = data_provider._downsample(xs, k=4)
+            self.assertEqual(actual, expected)
+
+    def test_underlong_ok(self):
+        xs = list("abcdefg")
+        actual = data_provider._downsample(xs, k=10)
+        expected = list("abcdefg")
+        self.assertIsNot(actual, xs)
+        self.assertEqual(actual, expected)
+
+    def test_inorder(self):
+        xs = list(range(10000))
+        actual = data_provider._downsample(xs, k=100)
+        self.assertEqual(actual, sorted(actual))
+
+    def test_zero(self):
+        xs = "abcdefg"
+        actual = data_provider._downsample(xs, k=0)
+        self.assertEqual(actual, [])
+
 
 if __name__ == "__main__":
     tf.test.main()
diff --git a/tensorboard/plugins/graph/graphs_plugin.py b/tensorboard/plugins/graph/graphs_plugin.py
index c499ae1692..91913dff7a 100644
--- a/tensorboard/plugins/graph/graphs_plugin.py
+++ b/tensorboard/plugins/graph/graphs_plugin.py
@@ -209,6 +209,7 @@ def graph_impl(
             experiment_id=experiment,
             plugin_name=metadata.PLUGIN_NAME,
             run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
+            downsample=1,
         )
         blob_datum_list = graph_blob_sequences.get(run, {}).get(tag, ())
         try:
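
Usage sketch (hypothetical; not part of the diffs above): a downstream
plugin could consume the new `sampling_hints` attribute from patch 1
together with the now-mandatory `downsample` argument from patch 2 as
follows. The plugin name (`widgets`) and its default limit are invented
for illustration; `TBContext.sampling_hints`, `read_tensors`, and
`RunTagFilter` are the APIs touched by this series.

    # Hypothetical plugin; "widgets" and its default limit are
    # illustrative only, not part of TensorBoard.
    from tensorboard.data import provider
    from tensorboard.plugins import base_plugin

    _DEFAULT_DOWNSAMPLING = 50  # widget tensors per time series


    class WidgetsPlugin(base_plugin.TBPlugin):
        plugin_name = "widgets"

        def __init__(self, context):
            # `context.sampling_hints` is the parsed form of the
            # `--samples_per_plugin` flag; fall back to this plugin's
            # own default when the user gave no explicit limit, and
            # guard against contexts constructed without any hints.
            self._downsample_to = (context.sampling_hints or {}).get(
                self.plugin_name, _DEFAULT_DOWNSAMPLING
            )
            self._data_provider = context.data_provider

        def _read_series(self, experiment, run, tag):
            # `downsample` is now required: after patch 2, passing
            # `None` raises a `TypeError` in `MultiplexerDataProvider`.
            return self._data_provider.read_tensors(
                experiment_id=experiment,
                plugin_name=self.plugin_name,
                downsample=self._downsample_to,
                run_tag_filter=provider.RunTagFilter(
                    runs=[run], tags=[tag]
                ),
            )

With this wiring, launching TensorBoard with
`--samples_per_plugin widgets=200` would raise the limit for the
hypothetical plugin from 50 to 200, matching the behavior that the
patch 1 test plan describes for `images=20`.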