diff --git a/tensorboard/BUILD b/tensorboard/BUILD index 4afbffe164..6e7fa9e1f5 100644 --- a/tensorboard/BUILD +++ b/tensorboard/BUILD @@ -535,6 +535,7 @@ py_library( "//tensorboard/plugins/histogram:metadata", "//tensorboard/plugins/hparams:metadata", "//tensorboard/plugins/image:metadata", + "//tensorboard/plugins/mesh:metadata", "//tensorboard/plugins/pr_curve:metadata", "//tensorboard/plugins/scalar:metadata", "//tensorboard/plugins/text:metadata", diff --git a/tensorboard/data/server/data_compat.rs b/tensorboard/data/server/data_compat.rs index 790b5addb6..4ba9b8cd82 100644 --- a/tensorboard/data/server/data_compat.rs +++ b/tensorboard/data/server/data_compat.rs @@ -43,6 +43,7 @@ pub(crate) mod plugin_names { pub const TEXT: &str = "text"; pub const PR_CURVES: &str = "pr_curves"; pub const HPARAMS: &str = "hparams"; + pub const MESH: &str = "mesh"; } /// The inner contents of a single value from an event. @@ -350,7 +351,8 @@ impl SummaryValue { Some(plugin_names::HISTOGRAMS) | Some(plugin_names::TEXT) | Some(plugin_names::HPARAMS) - | Some(plugin_names::PR_CURVES) => { + | Some(plugin_names::PR_CURVES) + | Some(plugin_names::MESH) => { md.data_class = pb::DataClass::Tensor.into(); } Some(plugin_names::IMAGES) @@ -740,6 +742,7 @@ mod tests { plugin_names::TEXT, plugin_names::PR_CURVES, plugin_names::HPARAMS, + plugin_names::MESH, ] { let md = blank_with_plugin_content( plugin_name, diff --git a/tensorboard/dataclass_compat.py b/tensorboard/dataclass_compat.py index db05e5a719..2fa25c94ce 100644 --- a/tensorboard/dataclass_compat.py +++ b/tensorboard/dataclass_compat.py @@ -29,6 +29,7 @@ from tensorboard.plugins.histogram import metadata as histograms_metadata from tensorboard.plugins.hparams import metadata as hparams_metadata from tensorboard.plugins.image import metadata as images_metadata +from tensorboard.plugins.mesh import metadata as mesh_metadata from tensorboard.plugins.pr_curve import metadata as pr_curves_metadata from tensorboard.plugins.scalar import metadata as scalars_metadata from tensorboard.plugins.text import metadata as text_metadata @@ -136,6 +137,8 @@ def _migrate_value(value, initial_metadata): return _migrate_hparams_value(value) if plugin_name == pr_curves_metadata.PLUGIN_NAME: return _migrate_pr_curve_value(value) + if plugin_name == mesh_metadata.PLUGIN_NAME: + return _migrate_mesh_value(value) if plugin_name in [ graphs_metadata.PLUGIN_NAME_RUN_METADATA, graphs_metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH, @@ -196,6 +199,12 @@ def _migrate_pr_curve_value(value): return (value,) +def _migrate_mesh_value(value): + if value.HasField("metadata"): + value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR + return (value,) + + def _migrate_graph_sub_plugin_value(value): if value.HasField("metadata"): value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE diff --git a/tensorboard/plugins/mesh/BUILD b/tensorboard/plugins/mesh/BUILD index 014fecb975..bc5df8a387 100644 --- a/tensorboard/plugins/mesh/BUILD +++ b/tensorboard/plugins/mesh/BUILD @@ -46,8 +46,8 @@ py_library( "//tensorboard:expect_numpy_installed", "//tensorboard:plugin_util", "//tensorboard/backend:http_util", + "//tensorboard/data:provider", "//tensorboard/plugins:base_plugin", - "//tensorboard/util:tensor_util", "@org_pocoo_werkzeug", ], ) @@ -76,6 +76,7 @@ py_test( "//tensorboard:expect_numpy_installed", "//tensorboard:expect_tensorflow_installed", "//tensorboard/backend:application", + "//tensorboard/backend/event_processing:data_provider", "//tensorboard/backend/event_processing:event_multiplexer", "//tensorboard/plugins:base_plugin", "//tensorboard/util:test_util", diff --git a/tensorboard/plugins/mesh/mesh_plugin.py b/tensorboard/plugins/mesh/mesh_plugin.py index 34b4d0e609..a8f7ef9099 100644 --- a/tensorboard/plugins/mesh/mesh_plugin.py +++ b/tensorboard/plugins/mesh/mesh_plugin.py @@ -18,10 +18,13 @@ from werkzeug import wrappers from tensorboard.backend import http_util +from tensorboard.data import provider from tensorboard.plugins import base_plugin from tensorboard.plugins.mesh import metadata from tensorboard.plugins.mesh import plugin_data_pb2 -from tensorboard.util import tensor_util +from tensorboard import plugin_util + +_DEFAULT_DOWNSAMPLING = 100 # meshes per time series class MeshPlugin(base_plugin.TBPlugin): @@ -36,28 +39,42 @@ def __init__(self, context): context: A base_plugin.TBContext instance. A magic container that TensorBoard uses to make objects available to the plugin. """ - # Retrieve the multiplexer from the context and store a reference to it. - self._multiplexer = context.multiplexer + self._data_provider = context.data_provider + self._downsample_to = (context.sampling_hints or {}).get( + self.plugin_name, _DEFAULT_DOWNSAMPLING + ) - def _instance_tag_metadata(self, run, instance_tag): + def _instance_tag_metadata(self, ctx, experiment, run, instance_tag): """Gets the `MeshPluginData` proto for an instance tag.""" - summary_metadata = self._multiplexer.SummaryMetadata(run, instance_tag) - content = summary_metadata.plugin_data.content + results = self._data_provider.list_tensors( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, + run_tag_filter=provider.RunTagFilter( + runs=[run], tags=[instance_tag] + ), + ) + content = results[run][instance_tag].plugin_content return metadata.parse_plugin_metadata(content) - def _tag(self, run, instance_tag): + def _tag(self, ctx, experiment, run, instance_tag): """Gets the user-facing tag name for an instance tag.""" - return self._instance_tag_metadata(run, instance_tag).name + return self._instance_tag_metadata( + ctx, experiment, run, instance_tag + ).name - def _instance_tags(self, run, tag): + def _instance_tags(self, ctx, experiment, run, tag): """Gets the instance tag names for a user-facing tag.""" - index = self._multiplexer.GetAccumulator(run).PluginTagToContent( - metadata.PLUGIN_NAME + index = self._data_provider.list_tensors( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, + run_tag_filter=provider.RunTagFilter(runs=[run]), ) return [ instance_tag - for (instance_tag, content) in index.items() - if tag == metadata.parse_plugin_metadata(content).name + for (instance_tag, ts) in index.get(run, {}).items() + if tag == metadata.parse_plugin_metadata(ts.plugin_content).name ] @wrappers.Request.application @@ -72,10 +89,12 @@ def _serve_tags(self, request): are all the runs. Each run is mapped to a (potentially empty) list of all tags that are relevant to this plugin. """ - # This is a dictionary mapping from run to (tag to string content). - # To be clear, the values of the dictionary are dictionaries. - all_runs = self._multiplexer.PluginRunToTagToContent( - MeshPlugin.plugin_name + ctx = plugin_util.context(request.environ) + experiment = plugin_util.experiment_id(request.environ) + all_runs = self._data_provider.list_tensors( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, ) # tagToContent is itself a dictionary mapping tag name to string @@ -83,12 +102,14 @@ def _serve_tags(self, request): # to obtain a list of tags associated with each run. For each tag estimate # number of samples. response = dict() - for run, tag_to_content in all_runs.items(): + for run, tags in all_runs.items(): response[run] = dict() - for instance_tag, _ in tag_to_content.items(): + for instance_tag in tags: # Make sure we only operate on user-defined tags here. - tag = self._tag(run, instance_tag) - meta = self._instance_tag_metadata(run, instance_tag) + tag = self._tag(ctx, experiment, run, instance_tag) + meta = self._instance_tag_metadata( + ctx, experiment, run, instance_tag + ) # Batch size must be defined, otherwise we don't know how many # samples were there. response[run][tag] = {"samples": meta.shape[0]} @@ -117,20 +138,19 @@ def is_active(self): def frontend_metadata(self): return base_plugin.FrontendMetadata(element_name="mesh-dashboard") - def _get_sample(self, tensor_event, sample): + def _get_sample(self, tensor_datum, sample): """Returns a single sample from a batch of samples.""" - data = tensor_util.make_ndarray(tensor_event.tensor_proto) - return data[sample].tolist() + return tensor_datum.numpy[sample].tolist() def _get_tensor_metadata( self, event, content_type, components, data_shape, config ): - """Converts a TensorEvent into a JSON-compatible response. + """Converts a TensorDatum into a JSON-compatible response. Args: - event: TensorEvent object containing data in proto format. + event: TensorDatum object containing data in proto format. content_type: enum plugin_data_pb2.MeshPluginData.ContentType value, - representing content type in TensorEvent. + representing content type in TensorDatum. components: Bitmask representing all parts (vertices, colors, etc.) that belong to the summary. data_shape: list of dimensions sizes of the tensor. @@ -149,19 +169,31 @@ def _get_tensor_metadata( } def _get_tensor_data(self, event, sample): - """Convert a TensorEvent into a JSON-compatible response.""" + """Convert a TensorDatum into a JSON-compatible response.""" data = self._get_sample(event, sample) return data def _collect_tensor_events(self, request, step=None): """Collects list of tensor events based on request.""" + ctx = plugin_util.context(request.environ) + experiment = plugin_util.experiment_id(request.environ) run = request.args.get("run") tag = request.args.get("tag") tensor_events = [] # List of tuples (meta, tensor) that contain tag. - for instance_tag in self._instance_tags(run, tag): - tensors = self._multiplexer.Tensors(run, instance_tag) - meta = self._instance_tag_metadata(run, instance_tag) + for instance_tag in self._instance_tags(ctx, experiment, run, tag): + tensors = self._data_provider.read_tensors( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, + run_tag_filter=provider.RunTagFilter( + runs=[run], tags=[instance_tag] + ), + downsample=self._downsample_to, + )[run][instance_tag] + meta = self._instance_tag_metadata( + ctx, experiment, run, instance_tag + ) tensor_events += [(meta, tensor) for tensor in tensors] if step is not None: diff --git a/tensorboard/plugins/mesh/mesh_plugin_test.py b/tensorboard/plugins/mesh/mesh_plugin_test.py index 908c4f47cc..0a8a12dac8 100644 --- a/tensorboard/plugins/mesh/mesh_plugin_test.py +++ b/tensorboard/plugins/mesh/mesh_plugin_test.py @@ -25,6 +25,7 @@ from werkzeug import test as werkzeug_test from werkzeug import wrappers from tensorboard.backend import application +from tensorboard.backend.event_processing import data_provider from tensorboard.backend.event_processing import ( plugin_event_multiplexer as event_multiplexer, ) @@ -137,13 +138,16 @@ def setUp(self): ) # Start a server that will receive requests. - self.multiplexer = event_multiplexer.EventMultiplexer( + multiplexer = event_multiplexer.EventMultiplexer( { "bar": bar_directory, } ) + provider = data_provider.MultiplexerDataProvider( + multiplexer, self.log_dir + ) self.context = base_plugin.TBContext( - logdir=self.log_dir, multiplexer=self.multiplexer + logdir=self.log_dir, data_provider=provider ) self.plugin = mesh_plugin.MeshPlugin(self.context) # Wait until after plugin construction to reload the multiplexer because the @@ -152,7 +156,7 @@ def setUp(self): # TODO(https://github.com/tensorflow/tensorboard/issues/2579): Eliminate the # caching of data at construction time and move this Reload() up to just # after the multiplexer is created. - self.multiplexer.Reload() + multiplexer.Reload() wsgi_app = application.TensorBoardWSGI([self.plugin]) self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) self.routes = self.plugin.get_plugin_apps()