diff --git a/tensorboard/BUILD b/tensorboard/BUILD index 6e7fa9e1f5..01da360f55 100644 --- a/tensorboard/BUILD +++ b/tensorboard/BUILD @@ -531,6 +531,7 @@ py_library( deps = [ "//tensorboard/compat/proto:protos_all_py_pb2", "//tensorboard/plugins/audio:metadata", + "//tensorboard/plugins/custom_scalar:metadata", "//tensorboard/plugins/graph:metadata", "//tensorboard/plugins/histogram:metadata", "//tensorboard/plugins/hparams:metadata", diff --git a/tensorboard/data/server/data_compat.rs b/tensorboard/data/server/data_compat.rs index 4ba9b8cd82..22ae733f1f 100644 --- a/tensorboard/data/server/data_compat.rs +++ b/tensorboard/data/server/data_compat.rs @@ -44,6 +44,7 @@ pub(crate) mod plugin_names { pub const PR_CURVES: &str = "pr_curves"; pub const HPARAMS: &str = "hparams"; pub const MESH: &str = "mesh"; + pub const CUSTOM_SCALARS: &str = "custom_scalars"; } /// The inner contents of a single value from an event. @@ -352,7 +353,8 @@ impl SummaryValue { | Some(plugin_names::TEXT) | Some(plugin_names::HPARAMS) | Some(plugin_names::PR_CURVES) - | Some(plugin_names::MESH) => { + | Some(plugin_names::MESH) + | Some(plugin_names::CUSTOM_SCALARS) => { md.data_class = pb::DataClass::Tensor.into(); } Some(plugin_names::IMAGES) @@ -743,6 +745,7 @@ mod tests { plugin_names::PR_CURVES, plugin_names::HPARAMS, plugin_names::MESH, + plugin_names::CUSTOM_SCALARS, ] { let md = blank_with_plugin_content( plugin_name, diff --git a/tensorboard/dataclass_compat.py b/tensorboard/dataclass_compat.py index 2fa25c94ce..f2f3a74434 100644 --- a/tensorboard/dataclass_compat.py +++ b/tensorboard/dataclass_compat.py @@ -25,6 +25,9 @@ from tensorboard.compat.proto import event_pb2 from tensorboard.compat.proto import summary_pb2 from tensorboard.plugins.audio import metadata as audio_metadata +from tensorboard.plugins.custom_scalar import ( + metadata as custom_scalars_metadata, +) from tensorboard.plugins.graph import metadata as graphs_metadata from tensorboard.plugins.histogram import metadata as histograms_metadata from tensorboard.plugins.hparams import metadata as hparams_metadata @@ -139,6 +142,8 @@ def _migrate_value(value, initial_metadata): return _migrate_pr_curve_value(value) if plugin_name == mesh_metadata.PLUGIN_NAME: return _migrate_mesh_value(value) + if plugin_name == custom_scalars_metadata.PLUGIN_NAME: + return _migrate_custom_scalars_value(value) if plugin_name in [ graphs_metadata.PLUGIN_NAME_RUN_METADATA, graphs_metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH, @@ -205,6 +210,12 @@ def _migrate_mesh_value(value): return (value,) +def _migrate_custom_scalars_value(value): + if value.HasField("metadata"): + value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR + return (value,) + + def _migrate_graph_sub_plugin_value(value): if value.HasField("metadata"): value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE diff --git a/tensorboard/plugins/custom_scalar/BUILD b/tensorboard/plugins/custom_scalar/BUILD index f83c9ee24f..8a8c4cd8b8 100644 --- a/tensorboard/plugins/custom_scalar/BUILD +++ b/tensorboard/plugins/custom_scalar/BUILD @@ -21,7 +21,6 @@ py_library( "//tensorboard/plugins:base_plugin", "//tensorboard/plugins/scalar:metadata", "//tensorboard/plugins/scalar:scalars_plugin", - "//tensorboard/util:tensor_util", "@org_pocoo_werkzeug", ], ) diff --git a/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py b/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py index 10dff8714a..697c36f969 100644 --- a/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py +++ b/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py @@ -31,12 +31,12 @@ from tensorboard import plugin_util from tensorboard.backend import http_util from tensorboard.compat import tf +from tensorboard.data import provider from tensorboard.plugins import base_plugin from tensorboard.plugins.custom_scalar import layout_pb2 from tensorboard.plugins.custom_scalar import metadata from tensorboard.plugins.scalar import metadata as scalars_metadata from tensorboard.plugins.scalar import scalars_plugin -from tensorboard.util import tensor_util # The name of the property in the response for whether the regex is valid. @@ -63,7 +63,7 @@ def __init__(self, context): context: A base_plugin.TBContext instance. """ self._logdir = context.logdir - self._multiplexer = context.multiplexer + self._data_provider = context.data_provider self._plugin_name_to_instance = context.plugin_name_to_instance def _get_scalars_plugin(self): @@ -214,8 +214,11 @@ def scalars_impl(self, ctx, run, tag_regex_string, experiment): } # Fetch the tags for the run. Filter for tags that match the regex. - run_to_data = self._multiplexer.PluginRunToTagToContent( - scalars_metadata.PLUGIN_NAME + run_to_data = self._data_provider.list_scalars( + ctx, + experiment_id=experiment, + plugin_name=scalars_metadata.PLUGIN_NAME, + run_tag_filter=provider.RunTagFilter(runs=[run]), ) tag_to_data = None @@ -264,29 +267,29 @@ def layout_route(self, request): The response is an empty object if no layout could be found. """ - body = self.layout_impl() + ctx = plugin_util.context(request.environ) + experiment = plugin_util.experiment_id(request.environ) + body = self.layout_impl(ctx, experiment) return http_util.Respond(request, body, "application/json") - def layout_impl(self): + def layout_impl(self, ctx, experiment): # Keep a mapping between and category so we do not create duplicate # categories. title_to_category = {} merged_layout = None - runs = list( - self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) + data = self._data_provider.read_tensors( + ctx, + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, + run_tag_filter=provider.RunTagFilter( + tags=[metadata.CONFIG_SUMMARY_TAG] + ), + downsample=1, ) - runs.sort() - for run in runs: - tensor_events = self._multiplexer.Tensors( - run, metadata.CONFIG_SUMMARY_TAG - ) - - # This run has a layout. Merge it with the ones currently found. - string_array = tensor_util.make_ndarray( - tensor_events[0].tensor_proto - ) - content = string_array.item() + for run in sorted(data): + points = data[run][metadata.CONFIG_SUMMARY_TAG] + content = points[0].numpy.item() layout_proto = layout_pb2.Layout() layout_proto.ParseFromString(tf.compat.as_bytes(content)) diff --git a/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py b/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py index f4def5fdcd..f6ca9f10c0 100644 --- a/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py +++ b/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py @@ -165,7 +165,6 @@ def createPlugin(self, logdir): plugin_name_to_instance = {} context = base_plugin.TBContext( logdir=logdir, - multiplexer=multiplexer, data_provider=provider, plugin_name_to_instance=plugin_name_to_instance, ) @@ -232,8 +231,9 @@ def testScalars(self): np.testing.assert_allclose(step + 1, entry[2]) def testMergedLayout(self): + ctx = context.RequestContext() parsed_layout = layout_pb2.Layout() - json_format.Parse(self.plugin.layout_impl(), parsed_layout) + json_format.Parse(self.plugin.layout_impl(ctx, "exp_id"), parsed_layout) correct_layout = layout_pb2.Layout( category=[ # A category with this name is also present in a layout for a @@ -293,15 +293,19 @@ def testMergedLayout(self): def testLayoutFromSingleRun(self): # The foo directory contains 1 single layout. + ctx = context.RequestContext() local_plugin = self.createPlugin(os.path.join(self.logdir, "foo")) parsed_layout = layout_pb2.Layout() - json_format.Parse(local_plugin.layout_impl(), parsed_layout) + json_format.Parse( + local_plugin.layout_impl(ctx, "exp_id"), parsed_layout + ) self.assertProtoEquals(self.foo_layout, parsed_layout) def testNoLayoutFound(self): # The bar directory contains no layout. + ctx = context.RequestContext() local_plugin = self.createPlugin(os.path.join(self.logdir, "bar")) - self.assertDictEqual({}, local_plugin.layout_impl()) + self.assertDictEqual({}, local_plugin.layout_impl(ctx, "exp_id")) def testIsActive(self): self.assertFalse(self.plugin.is_active())