Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorboard/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ py_library(
deps = [
"//tensorboard/compat/proto:protos_all_py_pb2",
"//tensorboard/plugins/audio:metadata",
"//tensorboard/plugins/custom_scalar:metadata",
"//tensorboard/plugins/graph:metadata",
"//tensorboard/plugins/histogram:metadata",
"//tensorboard/plugins/hparams:metadata",
Expand Down
5 changes: 4 additions & 1 deletion tensorboard/data/server/data_compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub(crate) mod plugin_names {
pub const PR_CURVES: &str = "pr_curves";
pub const HPARAMS: &str = "hparams";
pub const MESH: &str = "mesh";
pub const CUSTOM_SCALARS: &str = "custom_scalars";
}

/// The inner contents of a single value from an event.
Expand Down Expand Up @@ -352,7 +353,8 @@ impl SummaryValue {
| Some(plugin_names::TEXT)
| Some(plugin_names::HPARAMS)
| Some(plugin_names::PR_CURVES)
| Some(plugin_names::MESH) => {
| Some(plugin_names::MESH)
| Some(plugin_names::CUSTOM_SCALARS) => {
md.data_class = pb::DataClass::Tensor.into();
}
Some(plugin_names::IMAGES)
Expand Down Expand Up @@ -743,6 +745,7 @@ mod tests {
plugin_names::PR_CURVES,
plugin_names::HPARAMS,
plugin_names::MESH,
plugin_names::CUSTOM_SCALARS,
] {
let md = blank_with_plugin_content(
plugin_name,
Expand Down
11 changes: 11 additions & 0 deletions tensorboard/dataclass_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from tensorboard.compat.proto import event_pb2
from tensorboard.compat.proto import summary_pb2
from tensorboard.plugins.audio import metadata as audio_metadata
from tensorboard.plugins.custom_scalar import (
metadata as custom_scalars_metadata,
)
from tensorboard.plugins.graph import metadata as graphs_metadata
from tensorboard.plugins.histogram import metadata as histograms_metadata
from tensorboard.plugins.hparams import metadata as hparams_metadata
Expand Down Expand Up @@ -139,6 +142,8 @@ def _migrate_value(value, initial_metadata):
return _migrate_pr_curve_value(value)
if plugin_name == mesh_metadata.PLUGIN_NAME:
return _migrate_mesh_value(value)
if plugin_name == custom_scalars_metadata.PLUGIN_NAME:
return _migrate_custom_scalars_value(value)
if plugin_name in [
graphs_metadata.PLUGIN_NAME_RUN_METADATA,
graphs_metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
Expand Down Expand Up @@ -205,6 +210,12 @@ def _migrate_mesh_value(value):
return (value,)


def _migrate_custom_scalars_value(value):
if value.HasField("metadata"):
value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR
return (value,)


def _migrate_graph_sub_plugin_value(value):
if value.HasField("metadata"):
value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE
Expand Down
1 change: 0 additions & 1 deletion tensorboard/plugins/custom_scalar/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ py_library(
"//tensorboard/plugins:base_plugin",
"//tensorboard/plugins/scalar:metadata",
"//tensorboard/plugins/scalar:scalars_plugin",
"//tensorboard/util:tensor_util",
"@org_pocoo_werkzeug",
],
)
Expand Down
41 changes: 22 additions & 19 deletions tensorboard/plugins/custom_scalar/custom_scalars_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.compat import tf
from tensorboard.data import provider
from tensorboard.plugins import base_plugin
from tensorboard.plugins.custom_scalar import layout_pb2
from tensorboard.plugins.custom_scalar import metadata
from tensorboard.plugins.scalar import metadata as scalars_metadata
from tensorboard.plugins.scalar import scalars_plugin
from tensorboard.util import tensor_util


# The name of the property in the response for whether the regex is valid.
Expand All @@ -63,7 +63,7 @@ def __init__(self, context):
context: A base_plugin.TBContext instance.
"""
self._logdir = context.logdir
self._multiplexer = context.multiplexer
self._data_provider = context.data_provider
self._plugin_name_to_instance = context.plugin_name_to_instance

def _get_scalars_plugin(self):
Expand Down Expand Up @@ -214,8 +214,11 @@ def scalars_impl(self, ctx, run, tag_regex_string, experiment):
}

# Fetch the tags for the run. Filter for tags that match the regex.
run_to_data = self._multiplexer.PluginRunToTagToContent(
scalars_metadata.PLUGIN_NAME
run_to_data = self._data_provider.list_scalars(
ctx,
experiment_id=experiment,
plugin_name=scalars_metadata.PLUGIN_NAME,
run_tag_filter=provider.RunTagFilter(runs=[run]),
)

tag_to_data = None
Expand Down Expand Up @@ -264,29 +267,29 @@ def layout_route(self, request):

The response is an empty object if no layout could be found.
"""
body = self.layout_impl()
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
body = self.layout_impl(ctx, experiment)
return http_util.Respond(request, body, "application/json")

def layout_impl(self):
def layout_impl(self, ctx, experiment):
# Keep a mapping between and category so we do not create duplicate
# categories.
title_to_category = {}

merged_layout = None
runs = list(
self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME)
data = self._data_provider.read_tensors(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
run_tag_filter=provider.RunTagFilter(
tags=[metadata.CONFIG_SUMMARY_TAG]
),
downsample=1,
)
runs.sort()
for run in runs:
tensor_events = self._multiplexer.Tensors(
run, metadata.CONFIG_SUMMARY_TAG
)

# This run has a layout. Merge it with the ones currently found.
string_array = tensor_util.make_ndarray(
tensor_events[0].tensor_proto
)
content = string_array.item()
for run in sorted(data):
points = data[run][metadata.CONFIG_SUMMARY_TAG]
content = points[0].numpy.item()
layout_proto = layout_pb2.Layout()
layout_proto.ParseFromString(tf.compat.as_bytes(content))

Expand Down
12 changes: 8 additions & 4 deletions tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def createPlugin(self, logdir):
plugin_name_to_instance = {}
context = base_plugin.TBContext(
logdir=logdir,
multiplexer=multiplexer,
data_provider=provider,
plugin_name_to_instance=plugin_name_to_instance,
)
Expand Down Expand Up @@ -232,8 +231,9 @@ def testScalars(self):
np.testing.assert_allclose(step + 1, entry[2])

def testMergedLayout(self):
ctx = context.RequestContext()
parsed_layout = layout_pb2.Layout()
json_format.Parse(self.plugin.layout_impl(), parsed_layout)
json_format.Parse(self.plugin.layout_impl(ctx, "exp_id"), parsed_layout)
correct_layout = layout_pb2.Layout(
category=[
# A category with this name is also present in a layout for a
Expand Down Expand Up @@ -293,15 +293,19 @@ def testMergedLayout(self):

def testLayoutFromSingleRun(self):
# The foo directory contains 1 single layout.
ctx = context.RequestContext()
local_plugin = self.createPlugin(os.path.join(self.logdir, "foo"))
parsed_layout = layout_pb2.Layout()
json_format.Parse(local_plugin.layout_impl(), parsed_layout)
json_format.Parse(
local_plugin.layout_impl(ctx, "exp_id"), parsed_layout
)
self.assertProtoEquals(self.foo_layout, parsed_layout)

def testNoLayoutFound(self):
# The bar directory contains no layout.
ctx = context.RequestContext()
local_plugin = self.createPlugin(os.path.join(self.logdir, "bar"))
self.assertDictEqual({}, local_plugin.layout_impl())
self.assertDictEqual({}, local_plugin.layout_impl(ctx, "exp_id"))

def testIsActive(self):
self.assertFalse(self.plugin.is_active())
Expand Down