Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorboard/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ py_library(
"//tensorboard/plugins/histogram:metadata",
"//tensorboard/plugins/hparams:metadata",
"//tensorboard/plugins/image:metadata",
"//tensorboard/plugins/mesh:metadata",
"//tensorboard/plugins/pr_curve:metadata",
"//tensorboard/plugins/scalar:metadata",
"//tensorboard/plugins/text:metadata",
Expand Down
5 changes: 4 additions & 1 deletion tensorboard/data/server/data_compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub(crate) mod plugin_names {
pub const TEXT: &str = "text";
pub const PR_CURVES: &str = "pr_curves";
pub const HPARAMS: &str = "hparams";
pub const MESH: &str = "mesh";
}

/// The inner contents of a single value from an event.
Expand Down Expand Up @@ -350,7 +351,8 @@ impl SummaryValue {
Some(plugin_names::HISTOGRAMS)
| Some(plugin_names::TEXT)
| Some(plugin_names::HPARAMS)
| Some(plugin_names::PR_CURVES) => {
| Some(plugin_names::PR_CURVES)
| Some(plugin_names::MESH) => {
md.data_class = pb::DataClass::Tensor.into();
}
Some(plugin_names::IMAGES)
Expand Down Expand Up @@ -740,6 +742,7 @@ mod tests {
plugin_names::TEXT,
plugin_names::PR_CURVES,
plugin_names::HPARAMS,
plugin_names::MESH,
] {
let md = blank_with_plugin_content(
plugin_name,
Expand Down
9 changes: 9 additions & 0 deletions tensorboard/dataclass_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tensorboard.plugins.histogram import metadata as histograms_metadata
from tensorboard.plugins.hparams import metadata as hparams_metadata
from tensorboard.plugins.image import metadata as images_metadata
from tensorboard.plugins.mesh import metadata as mesh_metadata
from tensorboard.plugins.pr_curve import metadata as pr_curves_metadata
from tensorboard.plugins.scalar import metadata as scalars_metadata
from tensorboard.plugins.text import metadata as text_metadata
Expand Down Expand Up @@ -136,6 +137,8 @@ def _migrate_value(value, initial_metadata):
return _migrate_hparams_value(value)
if plugin_name == pr_curves_metadata.PLUGIN_NAME:
return _migrate_pr_curve_value(value)
if plugin_name == mesh_metadata.PLUGIN_NAME:
return _migrate_mesh_value(value)
if plugin_name in [
graphs_metadata.PLUGIN_NAME_RUN_METADATA,
graphs_metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
Expand Down Expand Up @@ -196,6 +199,12 @@ def _migrate_pr_curve_value(value):
return (value,)


def _migrate_mesh_value(value):
if value.HasField("metadata"):
value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR
return (value,)


def _migrate_graph_sub_plugin_value(value):
if value.HasField("metadata"):
value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE
Expand Down
3 changes: 2 additions & 1 deletion tensorboard/plugins/mesh/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ py_library(
"//tensorboard:expect_numpy_installed",
"//tensorboard:plugin_util",
"//tensorboard/backend:http_util",
"//tensorboard/data:provider",
"//tensorboard/plugins:base_plugin",
"//tensorboard/util:tensor_util",
"@org_pocoo_werkzeug",
],
)
Expand Down Expand Up @@ -76,6 +76,7 @@ py_test(
"//tensorboard:expect_numpy_installed",
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/backend:application",
"//tensorboard/backend/event_processing:data_provider",
"//tensorboard/backend/event_processing:event_multiplexer",
"//tensorboard/plugins:base_plugin",
"//tensorboard/util:test_util",
Expand Down
94 changes: 63 additions & 31 deletions tensorboard/plugins/mesh/mesh_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
from werkzeug import wrappers

from tensorboard.backend import http_util
from tensorboard.data import provider
from tensorboard.plugins import base_plugin
from tensorboard.plugins.mesh import metadata
from tensorboard.plugins.mesh import plugin_data_pb2
from tensorboard.util import tensor_util
from tensorboard import plugin_util

_DEFAULT_DOWNSAMPLING = 100 # meshes per time series


class MeshPlugin(base_plugin.TBPlugin):
Expand All @@ -36,28 +39,42 @@ def __init__(self, context):
context: A base_plugin.TBContext instance. A magic container that
TensorBoard uses to make objects available to the plugin.
"""
# Retrieve the multiplexer from the context and store a reference to it.
self._multiplexer = context.multiplexer
self._data_provider = context.data_provider
self._downsample_to = (context.sampling_hints or {}).get(
self.plugin_name, _DEFAULT_DOWNSAMPLING
)

def _instance_tag_metadata(self, run, instance_tag):
def _instance_tag_metadata(self, ctx, experiment, run, instance_tag):
"""Gets the `MeshPluginData` proto for an instance tag."""
summary_metadata = self._multiplexer.SummaryMetadata(run, instance_tag)
content = summary_metadata.plugin_data.content
results = self._data_provider.list_tensors(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
run_tag_filter=provider.RunTagFilter(
runs=[run], tags=[instance_tag]
),
)
content = results[run][instance_tag].plugin_content
return metadata.parse_plugin_metadata(content)

def _tag(self, run, instance_tag):
def _tag(self, ctx, experiment, run, instance_tag):
"""Gets the user-facing tag name for an instance tag."""
return self._instance_tag_metadata(run, instance_tag).name
return self._instance_tag_metadata(
ctx, experiment, run, instance_tag
).name

def _instance_tags(self, run, tag):
def _instance_tags(self, ctx, experiment, run, tag):
"""Gets the instance tag names for a user-facing tag."""
index = self._multiplexer.GetAccumulator(run).PluginTagToContent(
metadata.PLUGIN_NAME
index = self._data_provider.list_tensors(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
run_tag_filter=provider.RunTagFilter(runs=[run]),
)
return [
instance_tag
for (instance_tag, content) in index.items()
if tag == metadata.parse_plugin_metadata(content).name
for (instance_tag, ts) in index.get(run, {}).items()
if tag == metadata.parse_plugin_metadata(ts.plugin_content).name
]

@wrappers.Request.application
Expand All @@ -72,23 +89,27 @@ def _serve_tags(self, request):
are all the runs. Each run is mapped to a (potentially empty)
list of all tags that are relevant to this plugin.
"""
# This is a dictionary mapping from run to (tag to string content).
# To be clear, the values of the dictionary are dictionaries.
all_runs = self._multiplexer.PluginRunToTagToContent(
MeshPlugin.plugin_name
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
all_runs = self._data_provider.list_tensors(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
)

# tagToContent is itself a dictionary mapping tag name to string
# SummaryMetadata.plugin_data.content. Retrieve the keys of that dictionary
# to obtain a list of tags associated with each run. For each tag estimate
# number of samples.
response = dict()
for run, tag_to_content in all_runs.items():
for run, tags in all_runs.items():
response[run] = dict()
for instance_tag, _ in tag_to_content.items():
for instance_tag in tags:
# Make sure we only operate on user-defined tags here.
tag = self._tag(run, instance_tag)
meta = self._instance_tag_metadata(run, instance_tag)
tag = self._tag(ctx, experiment, run, instance_tag)
meta = self._instance_tag_metadata(
ctx, experiment, run, instance_tag
)
# Batch size must be defined, otherwise we don't know how many
# samples were there.
response[run][tag] = {"samples": meta.shape[0]}
Expand Down Expand Up @@ -117,20 +138,19 @@ def is_active(self):
def frontend_metadata(self):
return base_plugin.FrontendMetadata(element_name="mesh-dashboard")

def _get_sample(self, tensor_event, sample):
def _get_sample(self, tensor_datum, sample):
"""Returns a single sample from a batch of samples."""
data = tensor_util.make_ndarray(tensor_event.tensor_proto)
return data[sample].tolist()
return tensor_datum.numpy[sample].tolist()

def _get_tensor_metadata(
self, event, content_type, components, data_shape, config
):
"""Converts a TensorEvent into a JSON-compatible response.
"""Converts a TensorDatum into a JSON-compatible response.

Args:
event: TensorEvent object containing data in proto format.
event: TensorDatum object containing data in proto format.
content_type: enum plugin_data_pb2.MeshPluginData.ContentType value,
representing content type in TensorEvent.
representing content type in TensorDatum.
components: Bitmask representing all parts (vertices, colors, etc.) that
belong to the summary.
data_shape: list of dimensions sizes of the tensor.
Expand All @@ -149,19 +169,31 @@ def _get_tensor_metadata(
}

def _get_tensor_data(self, event, sample):
"""Convert a TensorEvent into a JSON-compatible response."""
"""Convert a TensorDatum into a JSON-compatible response."""
data = self._get_sample(event, sample)
return data

def _collect_tensor_events(self, request, step=None):
"""Collects list of tensor events based on request."""
ctx = plugin_util.context(request.environ)
experiment = plugin_util.experiment_id(request.environ)
run = request.args.get("run")
tag = request.args.get("tag")

tensor_events = [] # List of tuples (meta, tensor) that contain tag.
for instance_tag in self._instance_tags(run, tag):
tensors = self._multiplexer.Tensors(run, instance_tag)
meta = self._instance_tag_metadata(run, instance_tag)
for instance_tag in self._instance_tags(ctx, experiment, run, tag):
tensors = self._data_provider.read_tensors(
ctx,
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
run_tag_filter=provider.RunTagFilter(
runs=[run], tags=[instance_tag]
),
downsample=self._downsample_to,
)[run][instance_tag]
meta = self._instance_tag_metadata(
ctx, experiment, run, instance_tag
)
tensor_events += [(meta, tensor) for tensor in tensors]

if step is not None:
Expand Down
10 changes: 7 additions & 3 deletions tensorboard/plugins/mesh/mesh_plugin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from werkzeug import test as werkzeug_test
from werkzeug import wrappers
from tensorboard.backend import application
from tensorboard.backend.event_processing import data_provider
from tensorboard.backend.event_processing import (
plugin_event_multiplexer as event_multiplexer,
)
Expand Down Expand Up @@ -137,13 +138,16 @@ def setUp(self):
)

# Start a server that will receive requests.
self.multiplexer = event_multiplexer.EventMultiplexer(
multiplexer = event_multiplexer.EventMultiplexer(
{
"bar": bar_directory,
}
)
provider = data_provider.MultiplexerDataProvider(
multiplexer, self.log_dir
)
self.context = base_plugin.TBContext(
logdir=self.log_dir, multiplexer=self.multiplexer
logdir=self.log_dir, data_provider=provider
)
self.plugin = mesh_plugin.MeshPlugin(self.context)
# Wait until after plugin construction to reload the multiplexer because the
Expand All @@ -152,7 +156,7 @@ def setUp(self):
# TODO(https://github.com/tensorflow/tensorboard/issues/2579): Eliminate the
# caching of data at construction time and move this Reload() up to just
# after the multiplexer is created.
self.multiplexer.Reload()
multiplexer.Reload()
wsgi_app = application.TensorBoardWSGI([self.plugin])
self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse)
self.routes = self.plugin.get_plugin_apps()
Expand Down