Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tensorboard/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ py_library(
"//tensorboard/compat/proto:protos_all_py_pb2",
"//tensorboard/plugins/graph:metadata",
"//tensorboard/plugins/histogram:metadata",
"//tensorboard/plugins/hparams:metadata",
"//tensorboard/plugins/image:metadata",
"//tensorboard/plugins/scalar:metadata",
"//tensorboard/plugins/text:metadata",
Expand All @@ -512,6 +513,8 @@ py_test(
"//tensorboard/plugins/graph:metadata",
"//tensorboard/plugins/histogram:metadata",
"//tensorboard/plugins/histogram:summary",
"//tensorboard/plugins/hparams:metadata",
"//tensorboard/plugins/hparams:summary_v2",
"//tensorboard/plugins/scalar:metadata",
"//tensorboard/plugins/scalar:summary",
"//tensorboard/util:tensor_util",
Expand Down
11 changes: 11 additions & 0 deletions tensorboard/dataclass_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@

from tensorboard.compat.proto import event_pb2
from tensorboard.compat.proto import summary_pb2
from tensorboard.compat.proto import types_pb2
from tensorboard.plugins.graph import metadata as graphs_metadata
from tensorboard.plugins.histogram import metadata as histograms_metadata
from tensorboard.plugins.hparams import metadata as hparams_metadata
from tensorboard.plugins.image import metadata as images_metadata
from tensorboard.plugins.scalar import metadata as scalars_metadata
from tensorboard.plugins.text import metadata as text_metadata
Expand Down Expand Up @@ -95,6 +97,8 @@ def _migrate_value(value):
return _migrate_scalar_value(value)
if plugin_name == text_metadata.PLUGIN_NAME:
return _migrate_text_value(value)
if plugin_name == hparams_metadata.PLUGIN_NAME:
return _migrate_hparams_value(value)
return (value,)


Expand All @@ -116,3 +120,10 @@ def _migrate_image_value(value):
def _migrate_text_value(value):
value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR
return (value,)


def _migrate_hparams_value(value):
value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR
if not value.HasField("tensor"):
value.tensor.CopyFrom(hparams_metadata.NULL_TENSOR)
return (value,)
25 changes: 25 additions & 0 deletions tensorboard/dataclass_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from tensorboard.plugins.graph import metadata as graphs_metadata
from tensorboard.plugins.histogram import metadata as histogram_metadata
from tensorboard.plugins.histogram import summary as histogram_summary
from tensorboard.plugins.hparams import metadata as hparams_metadata
from tensorboard.plugins.hparams import summary_v2 as hparams_summary
from tensorboard.plugins.scalar import metadata as scalar_metadata
from tensorboard.plugins.scalar import summary as scalar_summary
from tensorboard.util import tensor_util
Expand Down Expand Up @@ -138,6 +140,29 @@ def test_histogram(self):
histogram_metadata.PLUGIN_NAME,
)

def test_hparams(self):
old_event = event_pb2.Event()
old_event.step = 0
old_event.wall_time = 456.75
hparams_pb = hparams_summary.hparams_pb({"optimizer": "adam"})
# Simulate legacy event with no tensor content
for v in hparams_pb.value:
v.ClearField("tensor")
old_event.summary.CopyFrom(hparams_pb)

new_events = self._migrate_event(old_event)
self.assertLen(new_events, 1)
self.assertLen(new_events[0].summary.value, 1)
value = new_events[0].summary.value[0]
self.assertEqual(value.tensor, hparams_metadata.NULL_TENSOR)
self.assertEqual(
value.metadata.data_class, summary_pb2.DATA_CLASS_TENSOR
)
self.assertEqual(
value.metadata.plugin_data,
hparams_pb.value[0].metadata.plugin_data,
)

def test_graph_def(self):
# Create a `GraphDef` and write it to disk as an event.
logdir = self.get_temp_dir()
Expand Down
1 change: 1 addition & 0 deletions tensorboard/plugins/hparams/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ py_library(
":error",
":protos_all_py_pb2",
"//tensorboard/compat/proto:protos_all_py_pb2",
"//tensorboard/util:tensor_util",
],
)

Expand Down
9 changes: 9 additions & 0 deletions tensorboard/plugins/hparams/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,22 @@
from __future__ import print_function

from tensorboard.compat.proto import summary_pb2
from tensorboard.compat.proto import types_pb2
from tensorboard.plugins.hparams import error
from tensorboard.plugins.hparams import plugin_data_pb2
from tensorboard.util import tensor_util


PLUGIN_NAME = "hparams"
PLUGIN_DATA_VERSION = 0

# Tensor value for use in summaries that really only need to store
# metadata. A length-0 float vector is of minimal serialized length
# (6 bytes) among valid tensors. Cache this: computing it takes on the
# order of tens of microseconds.
NULL_TENSOR = tensor_util.make_tensor_proto(
[], dtype=types_pb2.DT_FLOAT, shape=(0,)
)

EXPERIMENT_TAG = "_hparams_/experiment"
SESSION_START_INFO_TAG = "_hparams_/session_start_info"
Expand Down
11 changes: 10 additions & 1 deletion tensorboard/plugins/hparams/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,5 +193,14 @@ def _summary(tag, hparams_plugin_data):
tb_metadata = metadata.create_summary_metadata(hparams_plugin_data)
raw_metadata = tb_metadata.SerializeToString()
tf_metadata = tf.compat.v1.SummaryMetadata.FromString(raw_metadata)
summary.value.add(tag=tag, metadata=tf_metadata)
summary.value.add(
tag=tag, metadata=tf_metadata, tensor=_TF_NULL_TENSOR,
)
return summary


# Like `metadata.NULL_TENSOR`, but with the TensorFlow version of the
# proto. Slight kludge needed to expose the `TensorProto` type.
_TF_NULL_TENSOR = type(tf.make_tensor_proto(0)).FromString(
metadata.NULL_TENSOR.SerializeToString()
)
2 changes: 2 additions & 0 deletions tensorboard/plugins/hparams/summary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def test_experiment_pb(self):
value=[
tf.compat.v1.Summary.Value(
tag="_hparams_/experiment",
tensor=summary._TF_NULL_TENSOR,
metadata=tf.compat.v1.SummaryMetadata(
plugin_data=tf.compat.v1.SummaryMetadata.PluginData(
plugin_name="hparams",
Expand Down Expand Up @@ -98,6 +99,7 @@ def test_session_end_pb(self):
value=[
tf.compat.v1.Summary.Value(
tag="_hparams_/session_end_info",
tensor=summary._TF_NULL_TENSOR,
metadata=tf.compat.v1.SummaryMetadata(
plugin_data=tf.compat.v1.SummaryMetadata.PluginData(
plugin_name="hparams",
Expand Down
4 changes: 3 additions & 1 deletion tensorboard/plugins/hparams/summary_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,9 @@ def _summary_pb(tag, hparams_plugin_data):
"""
summary = summary_pb2.Summary()
summary_metadata = metadata.create_summary_metadata(hparams_plugin_data)
summary.value.add(tag=tag, metadata=summary_metadata)
value = summary.value.add(
tag=tag, metadata=summary_metadata, tensor=metadata.NULL_TENSOR
)
return summary


Expand Down
7 changes: 7 additions & 0 deletions tensorboard/plugins/hparams/summary_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from tensorboard import test
from tensorboard.compat import tf
from tensorboard.compat.proto import summary_pb2
from tensorboard.compat.proto import tensor_pb2
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import metadata
from tensorboard.plugins.hparams import plugin_data_pb2
Expand Down Expand Up @@ -108,6 +109,12 @@ def _check_summary(self, summary_pb, check_group_name=False):
self.assertEqual(
actual_value.metadata.plugin_data.plugin_name, metadata.PLUGIN_NAME,
)
self.assertEqual(
tensor_pb2.TensorProto.FromString(
actual_value.tensor.SerializeToString()
),
metadata.NULL_TENSOR,
)
plugin_content = actual_value.metadata.plugin_data.content
info_pb = metadata.parse_session_start_info_plugin_data(plugin_content)
# Usually ignore the `group_name` field; its properties are checked
Expand Down