From 0db8bfad6898d4498ad416a2b7acab85cda7b0ec Mon Sep 17 00:00:00 2001 From: William Chargin Date: Mon, 15 Mar 2021 10:52:02 -0700 Subject: [PATCH 1/5] custom_scalar: add generic data support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: This patch replaces the multiplexer code in the custom scalars plugin with data provider code. It is not gated behind a flag. We treat the layout protos as tensors, since they are of scalar shape and are expected to be small. A case could be made for representing them as blob sequences with a datacompat shape conversion, but this is simpler and the plugin doesn’t get a ton of use. Test Plan: The dashboard still works with the standard demo data, both with and without `--load_fast`. wchargin-branch: custom-scalars-generic wchargin-source: 91001190a543e99ba73c407df3bc801d81eaa2cf --- tensorboard/BUILD | 1 + tensorboard/data/server/data_compat.rs | 5 ++- tensorboard/dataclass_compat.py | 11 +++++ tensorboard/plugins/custom_scalar/BUILD | 1 - .../custom_scalar/custom_scalars_plugin.py | 41 ++++++++++--------- .../custom_scalars_plugin_test.py | 12 ++++-- 6 files changed, 46 insertions(+), 25 deletions(-) diff --git a/tensorboard/BUILD b/tensorboard/BUILD index 4afbffe164..3a54e53681 100644 --- a/tensorboard/BUILD +++ b/tensorboard/BUILD @@ -531,6 +531,7 @@ py_library( deps = [ "//tensorboard/compat/proto:protos_all_py_pb2", "//tensorboard/plugins/audio:metadata", + "//tensorboard/plugins/custom_scalar:metadata", "//tensorboard/plugins/graph:metadata", "//tensorboard/plugins/histogram:metadata", "//tensorboard/plugins/hparams:metadata", diff --git a/tensorboard/data/server/data_compat.rs b/tensorboard/data/server/data_compat.rs index 790b5addb6..82d49aef52 100644 --- a/tensorboard/data/server/data_compat.rs +++ b/tensorboard/data/server/data_compat.rs @@ -43,6 +43,7 @@ pub(crate) mod plugin_names { pub const TEXT: &str = "text"; pub const PR_CURVES: &str = "pr_curves"; pub const HPARAMS: &str = "hparams"; + pub const CUSTOM_SCALARS: &str = "custom_scalars"; } /// The inner contents of a single value from an event. @@ -350,7 +351,8 @@ impl SummaryValue { Some(plugin_names::HISTOGRAMS) | Some(plugin_names::TEXT) | Some(plugin_names::HPARAMS) - | Some(plugin_names::PR_CURVES) => { + | Some(plugin_names::PR_CURVES) + | Some(plugin_names::CUSTOM_SCALARS) => { md.data_class = pb::DataClass::Tensor.into(); } Some(plugin_names::IMAGES) @@ -740,6 +742,7 @@ mod tests { plugin_names::TEXT, plugin_names::PR_CURVES, plugin_names::HPARAMS, + plugin_names::CUSTOM_SCALARS, ] { let md = blank_with_plugin_content( plugin_name, diff --git a/tensorboard/dataclass_compat.py b/tensorboard/dataclass_compat.py index db05e5a719..8b25682e7e 100644 --- a/tensorboard/dataclass_compat.py +++ b/tensorboard/dataclass_compat.py @@ -25,6 +25,9 @@ from tensorboard.compat.proto import event_pb2 from tensorboard.compat.proto import summary_pb2 from tensorboard.plugins.audio import metadata as audio_metadata +from tensorboard.plugins.custom_scalar import ( + metadata as custom_scalars_metadata, +) from tensorboard.plugins.graph import metadata as graphs_metadata from tensorboard.plugins.histogram import metadata as histograms_metadata from tensorboard.plugins.hparams import metadata as hparams_metadata @@ -136,6 +139,8 @@ def _migrate_value(value, initial_metadata): return _migrate_hparams_value(value) if plugin_name == pr_curves_metadata.PLUGIN_NAME: return _migrate_pr_curve_value(value) + if plugin_name == custom_scalars_metadata.PLUGIN_NAME: + return _migrate_custom_scalars_value(value) if plugin_name in [ graphs_metadata.PLUGIN_NAME_RUN_METADATA, graphs_metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH, @@ -196,6 +201,12 @@ def _migrate_pr_curve_value(value): return (value,) +def _migrate_custom_scalars_value(value): + if value.HasField("metadata"): + value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR + return (value,) + + def _migrate_graph_sub_plugin_value(value): if value.HasField("metadata"): value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE diff --git a/tensorboard/plugins/custom_scalar/BUILD b/tensorboard/plugins/custom_scalar/BUILD index f83c9ee24f..8a8c4cd8b8 100644 --- a/tensorboard/plugins/custom_scalar/BUILD +++ b/tensorboard/plugins/custom_scalar/BUILD @@ -21,7 +21,6 @@ py_library( "//tensorboard/plugins:base_plugin", "//tensorboard/plugins/scalar:metadata", "//tensorboard/plugins/scalar:scalars_plugin", - "//tensorboard/util:tensor_util", "@org_pocoo_werkzeug", ], ) diff --git a/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py b/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py index 10dff8714a..697c36f969 100644 --- a/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py +++ b/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py @@ -31,12 +31,12 @@ from tensorboard import plugin_util from tensorboard.backend import http_util from tensorboard.compat import tf +from tensorboard.data import provider from tensorboard.plugins import base_plugin from tensorboard.plugins.custom_scalar import layout_pb2 from tensorboard.plugins.custom_scalar import metadata from tensorboard.plugins.scalar import metadata as scalars_metadata from tensorboard.plugins.scalar import scalars_plugin -from tensorboard.util import tensor_util # The name of the property in the response for whether the regex is valid. @@ -63,7 +63,7 @@ def __init__(self, context): context: A base_plugin.TBContext instance. """ self._logdir = context.logdir - self._multiplexer = context.multiplexer + self._data_provider = context.data_provider self._plugin_name_to_instance = context.plugin_name_to_instance def _get_scalars_plugin(self): @@ -214,8 +214,11 @@ def scalars_impl(self, ctx, run, tag_regex_string, experiment): } # Fetch the tags for the run. Filter for tags that match the regex. - run_to_data = self._multiplexer.PluginRunToTagToContent( - scalars_metadata.PLUGIN_NAME + run_to_data = self._data_provider.list_scalars( + ctx, + experiment_id=experiment, + plugin_name=scalars_metadata.PLUGIN_NAME, + run_tag_filter=provider.RunTagFilter(runs=[run]), ) tag_to_data = None @@ -264,29 +267,29 @@ def layout_route(self, request): The response is an empty object if no layout could be found. """ - body = self.layout_impl() + ctx = plugin_util.context(request.environ) + experiment = plugin_util.experiment_id(request.environ) + body = self.layout_impl(ctx, experiment) return http_util.Respond(request, body, "application/json") - def layout_impl(self): + def layout_impl(self, ctx, experiment): # Keep a mapping between and category so we do not create duplicate # categories. title_to_category = {} merged_layout = None - runs = list( - self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) + data = self._data_provider.read_tensors( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, + run_tag_filter=provider.RunTagFilter( + tags=[metadata.CONFIG_SUMMARY_TAG] + ), + downsample=1, ) - runs.sort() - for run in runs: - tensor_events = self._multiplexer.Tensors( - run, metadata.CONFIG_SUMMARY_TAG - ) - - # This run has a layout. Merge it with the ones currently found. - string_array = tensor_util.make_ndarray( - tensor_events[0].tensor_proto - ) - content = string_array.item() + for run in sorted(data): + points = data[run][metadata.CONFIG_SUMMARY_TAG] + content = points[0].numpy.item() layout_proto = layout_pb2.Layout() layout_proto.ParseFromString(tf.compat.as_bytes(content)) diff --git a/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py b/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py index f4def5fdcd..f6ca9f10c0 100644 --- a/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py +++ b/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py @@ -165,7 +165,6 @@ def createPlugin(self, logdir): plugin_name_to_instance = {} context = base_plugin.TBContext( logdir=logdir, - multiplexer=multiplexer, data_provider=provider, plugin_name_to_instance=plugin_name_to_instance, ) @@ -232,8 +231,9 @@ def testScalars(self): np.testing.assert_allclose(step + 1, entry[2]) def testMergedLayout(self): + ctx = context.RequestContext() parsed_layout = layout_pb2.Layout() - json_format.Parse(self.plugin.layout_impl(), parsed_layout) + json_format.Parse(self.plugin.layout_impl(ctx, "exp_id"), parsed_layout) correct_layout = layout_pb2.Layout( category=[ # A category with this name is also present in a layout for a @@ -293,15 +293,19 @@ def testMergedLayout(self): def testLayoutFromSingleRun(self): # The foo directory contains 1 single layout. + ctx = context.RequestContext() local_plugin = self.createPlugin(os.path.join(self.logdir, "foo")) parsed_layout = layout_pb2.Layout() - json_format.Parse(local_plugin.layout_impl(), parsed_layout) + json_format.Parse( + local_plugin.layout_impl(ctx, "exp_id"), parsed_layout + ) self.assertProtoEquals(self.foo_layout, parsed_layout) def testNoLayoutFound(self): # The bar directory contains no layout. + ctx = context.RequestContext() local_plugin = self.createPlugin(os.path.join(self.logdir, "bar")) - self.assertDictEqual({}, local_plugin.layout_impl()) + self.assertDictEqual({}, local_plugin.layout_impl(ctx, "exp_id")) def testIsActive(self): self.assertFalse(self.plugin.is_active()) From 403d165cc93cecf9fffd2242bcd208dc9d1154b5 Mon Sep 17 00:00:00 2001 From: William Chargin Date: Mon, 15 Mar 2021 10:54:13 -0700 Subject: [PATCH 2/5] mesh: add generic data support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: This patch replaces the multiplexer code in the mesh plugin with data provider code. It is not gated behind a flag. We treat all mesh data as tensors, since that’s most expedient at this time and the local data providers (multiplexer and RustBoard) don’t actually have any size constraints. See also discussion in #4734. Test Plan: The dashboard still works with the standard demo data, both with and without `--load_fast`. wchargin-branch: mesh-generic wchargin-source: 2a995b31efea5c58dee9d05fc7e66bbdb833a32e --- tensorboard/BUILD | 1 + tensorboard/data/server/data_compat.rs | 5 +- tensorboard/dataclass_compat.py | 9 ++ tensorboard/plugins/mesh/BUILD | 3 +- tensorboard/plugins/mesh/mesh_plugin.py | 94 +++++++++++++------- tensorboard/plugins/mesh/mesh_plugin_test.py | 10 ++- 6 files changed, 86 insertions(+), 36 deletions(-) diff --git a/tensorboard/BUILD b/tensorboard/BUILD index 3a54e53681..01da360f55 100644 --- a/tensorboard/BUILD +++ b/tensorboard/BUILD @@ -536,6 +536,7 @@ py_library( "//tensorboard/plugins/histogram:metadata", "//tensorboard/plugins/hparams:metadata", "//tensorboard/plugins/image:metadata", + "//tensorboard/plugins/mesh:metadata", "//tensorboard/plugins/pr_curve:metadata", "//tensorboard/plugins/scalar:metadata", "//tensorboard/plugins/text:metadata", diff --git a/tensorboard/data/server/data_compat.rs b/tensorboard/data/server/data_compat.rs index 82d49aef52..15751a85c8 100644 --- a/tensorboard/data/server/data_compat.rs +++ b/tensorboard/data/server/data_compat.rs @@ -44,6 +44,7 @@ pub(crate) mod plugin_names { pub const PR_CURVES: &str = "pr_curves"; pub const HPARAMS: &str = "hparams"; pub const CUSTOM_SCALARS: &str = "custom_scalars"; + pub const MESH: &str = "mesh"; } /// The inner contents of a single value from an event. @@ -352,7 +353,8 @@ impl SummaryValue { | Some(plugin_names::TEXT) | Some(plugin_names::HPARAMS) | Some(plugin_names::PR_CURVES) - | Some(plugin_names::CUSTOM_SCALARS) => { + | Some(plugin_names::CUSTOM_SCALARS) + | Some(plugin_names::MESH) => { md.data_class = pb::DataClass::Tensor.into(); } Some(plugin_names::IMAGES) @@ -743,6 +745,7 @@ mod tests { plugin_names::PR_CURVES, plugin_names::HPARAMS, plugin_names::CUSTOM_SCALARS, + plugin_names::MESH, ] { let md = blank_with_plugin_content( plugin_name, diff --git a/tensorboard/dataclass_compat.py b/tensorboard/dataclass_compat.py index 8b25682e7e..16c7bba643 100644 --- a/tensorboard/dataclass_compat.py +++ b/tensorboard/dataclass_compat.py @@ -32,6 +32,7 @@ from tensorboard.plugins.histogram import metadata as histograms_metadata from tensorboard.plugins.hparams import metadata as hparams_metadata from tensorboard.plugins.image import metadata as images_metadata +from tensorboard.plugins.mesh import metadata as mesh_metadata from tensorboard.plugins.pr_curve import metadata as pr_curves_metadata from tensorboard.plugins.scalar import metadata as scalars_metadata from tensorboard.plugins.text import metadata as text_metadata @@ -141,6 +142,8 @@ def _migrate_value(value, initial_metadata): return _migrate_pr_curve_value(value) if plugin_name == custom_scalars_metadata.PLUGIN_NAME: return _migrate_custom_scalars_value(value) + if plugin_name == mesh_metadata.PLUGIN_NAME: + return _migrate_mesh_value(value) if plugin_name in [ graphs_metadata.PLUGIN_NAME_RUN_METADATA, graphs_metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH, @@ -207,6 +210,12 @@ def _migrate_custom_scalars_value(value): return (value,) +def _migrate_mesh_value(value): + if value.HasField("metadata"): + value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR + return (value,) + + def _migrate_graph_sub_plugin_value(value): if value.HasField("metadata"): value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE diff --git a/tensorboard/plugins/mesh/BUILD b/tensorboard/plugins/mesh/BUILD index 014fecb975..bc5df8a387 100644 --- a/tensorboard/plugins/mesh/BUILD +++ b/tensorboard/plugins/mesh/BUILD @@ -46,8 +46,8 @@ py_library( "//tensorboard:expect_numpy_installed", "//tensorboard:plugin_util", "//tensorboard/backend:http_util", + "//tensorboard/data:provider", "//tensorboard/plugins:base_plugin", - "//tensorboard/util:tensor_util", "@org_pocoo_werkzeug", ], ) @@ -76,6 +76,7 @@ py_test( "//tensorboard:expect_numpy_installed", "//tensorboard:expect_tensorflow_installed", "//tensorboard/backend:application", + "//tensorboard/backend/event_processing:data_provider", "//tensorboard/backend/event_processing:event_multiplexer", "//tensorboard/plugins:base_plugin", "//tensorboard/util:test_util", diff --git a/tensorboard/plugins/mesh/mesh_plugin.py b/tensorboard/plugins/mesh/mesh_plugin.py index 34b4d0e609..a8f7ef9099 100644 --- a/tensorboard/plugins/mesh/mesh_plugin.py +++ b/tensorboard/plugins/mesh/mesh_plugin.py @@ -18,10 +18,13 @@ from werkzeug import wrappers from tensorboard.backend import http_util +from tensorboard.data import provider from tensorboard.plugins import base_plugin from tensorboard.plugins.mesh import metadata from tensorboard.plugins.mesh import plugin_data_pb2 -from tensorboard.util import tensor_util +from tensorboard import plugin_util + +_DEFAULT_DOWNSAMPLING = 100 # meshes per time series class MeshPlugin(base_plugin.TBPlugin): @@ -36,28 +39,42 @@ def __init__(self, context): context: A base_plugin.TBContext instance. A magic container that TensorBoard uses to make objects available to the plugin. """ - # Retrieve the multiplexer from the context and store a reference to it. - self._multiplexer = context.multiplexer + self._data_provider = context.data_provider + self._downsample_to = (context.sampling_hints or {}).get( + self.plugin_name, _DEFAULT_DOWNSAMPLING + ) - def _instance_tag_metadata(self, run, instance_tag): + def _instance_tag_metadata(self, ctx, experiment, run, instance_tag): """Gets the `MeshPluginData` proto for an instance tag.""" - summary_metadata = self._multiplexer.SummaryMetadata(run, instance_tag) - content = summary_metadata.plugin_data.content + results = self._data_provider.list_tensors( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, + run_tag_filter=provider.RunTagFilter( + runs=[run], tags=[instance_tag] + ), + ) + content = results[run][instance_tag].plugin_content return metadata.parse_plugin_metadata(content) - def _tag(self, run, instance_tag): + def _tag(self, ctx, experiment, run, instance_tag): """Gets the user-facing tag name for an instance tag.""" - return self._instance_tag_metadata(run, instance_tag).name + return self._instance_tag_metadata( + ctx, experiment, run, instance_tag + ).name - def _instance_tags(self, run, tag): + def _instance_tags(self, ctx, experiment, run, tag): """Gets the instance tag names for a user-facing tag.""" - index = self._multiplexer.GetAccumulator(run).PluginTagToContent( - metadata.PLUGIN_NAME + index = self._data_provider.list_tensors( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, + run_tag_filter=provider.RunTagFilter(runs=[run]), ) return [ instance_tag - for (instance_tag, content) in index.items() - if tag == metadata.parse_plugin_metadata(content).name + for (instance_tag, ts) in index.get(run, {}).items() + if tag == metadata.parse_plugin_metadata(ts.plugin_content).name ] @wrappers.Request.application @@ -72,10 +89,12 @@ def _serve_tags(self, request): are all the runs. Each run is mapped to a (potentially empty) list of all tags that are relevant to this plugin. """ - # This is a dictionary mapping from run to (tag to string content). - # To be clear, the values of the dictionary are dictionaries. - all_runs = self._multiplexer.PluginRunToTagToContent( - MeshPlugin.plugin_name + ctx = plugin_util.context(request.environ) + experiment = plugin_util.experiment_id(request.environ) + all_runs = self._data_provider.list_tensors( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, ) # tagToContent is itself a dictionary mapping tag name to string @@ -83,12 +102,14 @@ def _serve_tags(self, request): # to obtain a list of tags associated with each run. For each tag estimate # number of samples. response = dict() - for run, tag_to_content in all_runs.items(): + for run, tags in all_runs.items(): response[run] = dict() - for instance_tag, _ in tag_to_content.items(): + for instance_tag in tags: # Make sure we only operate on user-defined tags here. - tag = self._tag(run, instance_tag) - meta = self._instance_tag_metadata(run, instance_tag) + tag = self._tag(ctx, experiment, run, instance_tag) + meta = self._instance_tag_metadata( + ctx, experiment, run, instance_tag + ) # Batch size must be defined, otherwise we don't know how many # samples were there. response[run][tag] = {"samples": meta.shape[0]} @@ -117,20 +138,19 @@ def is_active(self): def frontend_metadata(self): return base_plugin.FrontendMetadata(element_name="mesh-dashboard") - def _get_sample(self, tensor_event, sample): + def _get_sample(self, tensor_datum, sample): """Returns a single sample from a batch of samples.""" - data = tensor_util.make_ndarray(tensor_event.tensor_proto) - return data[sample].tolist() + return tensor_datum.numpy[sample].tolist() def _get_tensor_metadata( self, event, content_type, components, data_shape, config ): - """Converts a TensorEvent into a JSON-compatible response. + """Converts a TensorDatum into a JSON-compatible response. Args: - event: TensorEvent object containing data in proto format. + event: TensorDatum object containing data in proto format. content_type: enum plugin_data_pb2.MeshPluginData.ContentType value, - representing content type in TensorEvent. + representing content type in TensorDatum. components: Bitmask representing all parts (vertices, colors, etc.) that belong to the summary. data_shape: list of dimensions sizes of the tensor. @@ -149,19 +169,31 @@ def _get_tensor_metadata( } def _get_tensor_data(self, event, sample): - """Convert a TensorEvent into a JSON-compatible response.""" + """Convert a TensorDatum into a JSON-compatible response.""" data = self._get_sample(event, sample) return data def _collect_tensor_events(self, request, step=None): """Collects list of tensor events based on request.""" + ctx = plugin_util.context(request.environ) + experiment = plugin_util.experiment_id(request.environ) run = request.args.get("run") tag = request.args.get("tag") tensor_events = [] # List of tuples (meta, tensor) that contain tag. - for instance_tag in self._instance_tags(run, tag): - tensors = self._multiplexer.Tensors(run, instance_tag) - meta = self._instance_tag_metadata(run, instance_tag) + for instance_tag in self._instance_tags(ctx, experiment, run, tag): + tensors = self._data_provider.read_tensors( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, + run_tag_filter=provider.RunTagFilter( + runs=[run], tags=[instance_tag] + ), + downsample=self._downsample_to, + )[run][instance_tag] + meta = self._instance_tag_metadata( + ctx, experiment, run, instance_tag + ) tensor_events += [(meta, tensor) for tensor in tensors] if step is not None: diff --git a/tensorboard/plugins/mesh/mesh_plugin_test.py b/tensorboard/plugins/mesh/mesh_plugin_test.py index 908c4f47cc..0a8a12dac8 100644 --- a/tensorboard/plugins/mesh/mesh_plugin_test.py +++ b/tensorboard/plugins/mesh/mesh_plugin_test.py @@ -25,6 +25,7 @@ from werkzeug import test as werkzeug_test from werkzeug import wrappers from tensorboard.backend import application +from tensorboard.backend.event_processing import data_provider from tensorboard.backend.event_processing import ( plugin_event_multiplexer as event_multiplexer, ) @@ -137,13 +138,16 @@ def setUp(self): ) # Start a server that will receive requests. - self.multiplexer = event_multiplexer.EventMultiplexer( + multiplexer = event_multiplexer.EventMultiplexer( { "bar": bar_directory, } ) + provider = data_provider.MultiplexerDataProvider( + multiplexer, self.log_dir + ) self.context = base_plugin.TBContext( - logdir=self.log_dir, multiplexer=self.multiplexer + logdir=self.log_dir, data_provider=provider ) self.plugin = mesh_plugin.MeshPlugin(self.context) # Wait until after plugin construction to reload the multiplexer because the @@ -152,7 +156,7 @@ def setUp(self): # TODO(https://github.com/tensorflow/tensorboard/issues/2579): Eliminate the # caching of data at construction time and move this Reload() up to just # after the multiplexer is created. - self.multiplexer.Reload() + multiplexer.Reload() wsgi_app = application.TensorBoardWSGI([self.plugin]) self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) self.routes = self.plugin.get_plugin_apps() From 0636b38cf56a7103fcd2f64b798f66bdd5739415 Mon Sep 17 00:00:00 2001 From: William Chargin Date: Mon, 15 Mar 2021 19:11:18 -0700 Subject: [PATCH 3/5] [mesh-generic: rotate below custom scalars] wchargin-branch: mesh-generic wchargin-source: 75e006cc59401e05b0aa2b4ec49536ccc9f0aa9e --- tensorboard/BUILD | 1 - tensorboard/data/server/data_compat.rs | 3 -- tensorboard/dataclass_compat.py | 11 ----- tensorboard/plugins/custom_scalar/BUILD | 1 + .../custom_scalar/custom_scalars_plugin.py | 41 +++++++++---------- .../custom_scalars_plugin_test.py | 12 ++---- 6 files changed, 24 insertions(+), 45 deletions(-) diff --git a/tensorboard/BUILD b/tensorboard/BUILD index 01da360f55..6e7fa9e1f5 100644 --- a/tensorboard/BUILD +++ b/tensorboard/BUILD @@ -531,7 +531,6 @@ py_library( deps = [ "//tensorboard/compat/proto:protos_all_py_pb2", "//tensorboard/plugins/audio:metadata", - "//tensorboard/plugins/custom_scalar:metadata", "//tensorboard/plugins/graph:metadata", "//tensorboard/plugins/histogram:metadata", "//tensorboard/plugins/hparams:metadata", diff --git a/tensorboard/data/server/data_compat.rs b/tensorboard/data/server/data_compat.rs index 15751a85c8..4ba9b8cd82 100644 --- a/tensorboard/data/server/data_compat.rs +++ b/tensorboard/data/server/data_compat.rs @@ -43,7 +43,6 @@ pub(crate) mod plugin_names { pub const TEXT: &str = "text"; pub const PR_CURVES: &str = "pr_curves"; pub const HPARAMS: &str = "hparams"; - pub const CUSTOM_SCALARS: &str = "custom_scalars"; pub const MESH: &str = "mesh"; } @@ -353,7 +352,6 @@ impl SummaryValue { | Some(plugin_names::TEXT) | Some(plugin_names::HPARAMS) | Some(plugin_names::PR_CURVES) - | Some(plugin_names::CUSTOM_SCALARS) | Some(plugin_names::MESH) => { md.data_class = pb::DataClass::Tensor.into(); } @@ -744,7 +742,6 @@ mod tests { plugin_names::TEXT, plugin_names::PR_CURVES, plugin_names::HPARAMS, - plugin_names::CUSTOM_SCALARS, plugin_names::MESH, ] { let md = blank_with_plugin_content( diff --git a/tensorboard/dataclass_compat.py b/tensorboard/dataclass_compat.py index 16c7bba643..2fa25c94ce 100644 --- a/tensorboard/dataclass_compat.py +++ b/tensorboard/dataclass_compat.py @@ -25,9 +25,6 @@ from tensorboard.compat.proto import event_pb2 from tensorboard.compat.proto import summary_pb2 from tensorboard.plugins.audio import metadata as audio_metadata -from tensorboard.plugins.custom_scalar import ( - metadata as custom_scalars_metadata, -) from tensorboard.plugins.graph import metadata as graphs_metadata from tensorboard.plugins.histogram import metadata as histograms_metadata from tensorboard.plugins.hparams import metadata as hparams_metadata @@ -140,8 +137,6 @@ def _migrate_value(value, initial_metadata): return _migrate_hparams_value(value) if plugin_name == pr_curves_metadata.PLUGIN_NAME: return _migrate_pr_curve_value(value) - if plugin_name == custom_scalars_metadata.PLUGIN_NAME: - return _migrate_custom_scalars_value(value) if plugin_name == mesh_metadata.PLUGIN_NAME: return _migrate_mesh_value(value) if plugin_name in [ @@ -204,12 +199,6 @@ def _migrate_pr_curve_value(value): return (value,) -def _migrate_custom_scalars_value(value): - if value.HasField("metadata"): - value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR - return (value,) - - def _migrate_mesh_value(value): if value.HasField("metadata"): value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR diff --git a/tensorboard/plugins/custom_scalar/BUILD b/tensorboard/plugins/custom_scalar/BUILD index 8a8c4cd8b8..f83c9ee24f 100644 --- a/tensorboard/plugins/custom_scalar/BUILD +++ b/tensorboard/plugins/custom_scalar/BUILD @@ -21,6 +21,7 @@ py_library( "//tensorboard/plugins:base_plugin", "//tensorboard/plugins/scalar:metadata", "//tensorboard/plugins/scalar:scalars_plugin", + "//tensorboard/util:tensor_util", "@org_pocoo_werkzeug", ], ) diff --git a/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py b/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py index 697c36f969..10dff8714a 100644 --- a/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py +++ b/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py @@ -31,12 +31,12 @@ from tensorboard import plugin_util from tensorboard.backend import http_util from tensorboard.compat import tf -from tensorboard.data import provider from tensorboard.plugins import base_plugin from tensorboard.plugins.custom_scalar import layout_pb2 from tensorboard.plugins.custom_scalar import metadata from tensorboard.plugins.scalar import metadata as scalars_metadata from tensorboard.plugins.scalar import scalars_plugin +from tensorboard.util import tensor_util # The name of the property in the response for whether the regex is valid. @@ -63,7 +63,7 @@ def __init__(self, context): context: A base_plugin.TBContext instance. """ self._logdir = context.logdir - self._data_provider = context.data_provider + self._multiplexer = context.multiplexer self._plugin_name_to_instance = context.plugin_name_to_instance def _get_scalars_plugin(self): @@ -214,11 +214,8 @@ def scalars_impl(self, ctx, run, tag_regex_string, experiment): } # Fetch the tags for the run. Filter for tags that match the regex. - run_to_data = self._data_provider.list_scalars( - ctx, - experiment_id=experiment, - plugin_name=scalars_metadata.PLUGIN_NAME, - run_tag_filter=provider.RunTagFilter(runs=[run]), + run_to_data = self._multiplexer.PluginRunToTagToContent( + scalars_metadata.PLUGIN_NAME ) tag_to_data = None @@ -267,29 +264,29 @@ def layout_route(self, request): The response is an empty object if no layout could be found. """ - ctx = plugin_util.context(request.environ) - experiment = plugin_util.experiment_id(request.environ) - body = self.layout_impl(ctx, experiment) + body = self.layout_impl() return http_util.Respond(request, body, "application/json") - def layout_impl(self, ctx, experiment): + def layout_impl(self): # Keep a mapping between and category so we do not create duplicate # categories. title_to_category = {} merged_layout = None - data = self._data_provider.read_tensors( - ctx, - experiment_id=experiment, - plugin_name=metadata.PLUGIN_NAME, - run_tag_filter=provider.RunTagFilter( - tags=[metadata.CONFIG_SUMMARY_TAG] - ), - downsample=1, + runs = list( + self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) ) - for run in sorted(data): - points = data[run][metadata.CONFIG_SUMMARY_TAG] - content = points[0].numpy.item() + runs.sort() + for run in runs: + tensor_events = self._multiplexer.Tensors( + run, metadata.CONFIG_SUMMARY_TAG + ) + + # This run has a layout. Merge it with the ones currently found. + string_array = tensor_util.make_ndarray( + tensor_events[0].tensor_proto + ) + content = string_array.item() layout_proto = layout_pb2.Layout() layout_proto.ParseFromString(tf.compat.as_bytes(content)) diff --git a/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py b/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py index f6ca9f10c0..f4def5fdcd 100644 --- a/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py +++ b/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py @@ -165,6 +165,7 @@ def createPlugin(self, logdir): plugin_name_to_instance = {} context = base_plugin.TBContext( logdir=logdir, + multiplexer=multiplexer, data_provider=provider, plugin_name_to_instance=plugin_name_to_instance, ) @@ -231,9 +232,8 @@ def testScalars(self): np.testing.assert_allclose(step + 1, entry[2]) def testMergedLayout(self): - ctx = context.RequestContext() parsed_layout = layout_pb2.Layout() - json_format.Parse(self.plugin.layout_impl(ctx, "exp_id"), parsed_layout) + json_format.Parse(self.plugin.layout_impl(), parsed_layout) correct_layout = layout_pb2.Layout( category=[ # A category with this name is also present in a layout for a @@ -293,19 +293,15 @@ def testMergedLayout(self): def testLayoutFromSingleRun(self): # The foo directory contains 1 single layout. - ctx = context.RequestContext() local_plugin = self.createPlugin(os.path.join(self.logdir, "foo")) parsed_layout = layout_pb2.Layout() - json_format.Parse( - local_plugin.layout_impl(ctx, "exp_id"), parsed_layout - ) + json_format.Parse(local_plugin.layout_impl(), parsed_layout) self.assertProtoEquals(self.foo_layout, parsed_layout) def testNoLayoutFound(self): # The bar directory contains no layout. - ctx = context.RequestContext() local_plugin = self.createPlugin(os.path.join(self.logdir, "bar")) - self.assertDictEqual({}, local_plugin.layout_impl(ctx, "exp_id")) + self.assertDictEqual({}, local_plugin.layout_impl()) def testIsActive(self): self.assertFalse(self.plugin.is_active()) From 84b18642e510d770fe85ff1aeec9b19d6479e75d Mon Sep 17 00:00:00 2001 From: William Chargin Date: Mon, 15 Mar 2021 19:11:50 -0700 Subject: [PATCH 4/5] [mesh-generic: no-op] [ci skip] wchargin-branch: mesh-generic wchargin-source: 75e006cc59401e05b0aa2b4ec49536ccc9f0aa9e From 8d6be6c9353c843c404b47eda805556a4072fc1b Mon Sep 17 00:00:00 2001 From: William Chargin Date: Mon, 15 Mar 2021 19:12:26 -0700 Subject: [PATCH 5/5] [mesh-generic: bump ci] wchargin-branch: mesh-generic wchargin-source: 75e006cc59401e05b0aa2b4ec49536ccc9f0aa9e