diff --git a/tensorboard/BUILD b/tensorboard/BUILD index 671273f547..8c5dec4463 100644 --- a/tensorboard/BUILD +++ b/tensorboard/BUILD @@ -491,6 +491,7 @@ py_library( "//tensorboard/compat/proto:protos_all_py_pb2", "//tensorboard/plugins/graph:metadata", "//tensorboard/plugins/histogram:metadata", + "//tensorboard/plugins/hparams:metadata", "//tensorboard/plugins/image:metadata", "//tensorboard/plugins/scalar:metadata", "//tensorboard/plugins/text:metadata", @@ -512,6 +513,8 @@ py_test( "//tensorboard/plugins/graph:metadata", "//tensorboard/plugins/histogram:metadata", "//tensorboard/plugins/histogram:summary", + "//tensorboard/plugins/hparams:metadata", + "//tensorboard/plugins/hparams:summary_v2", "//tensorboard/plugins/scalar:metadata", "//tensorboard/plugins/scalar:summary", "//tensorboard/util:tensor_util", diff --git a/tensorboard/dataclass_compat.py b/tensorboard/dataclass_compat.py index 6d2b834b50..c21bd13daf 100644 --- a/tensorboard/dataclass_compat.py +++ b/tensorboard/dataclass_compat.py @@ -27,8 +27,10 @@ from tensorboard.compat.proto import event_pb2 from tensorboard.compat.proto import summary_pb2 +from tensorboard.compat.proto import types_pb2 from tensorboard.plugins.graph import metadata as graphs_metadata from tensorboard.plugins.histogram import metadata as histograms_metadata +from tensorboard.plugins.hparams import metadata as hparams_metadata from tensorboard.plugins.image import metadata as images_metadata from tensorboard.plugins.scalar import metadata as scalars_metadata from tensorboard.plugins.text import metadata as text_metadata @@ -95,6 +97,8 @@ def _migrate_value(value): return _migrate_scalar_value(value) if plugin_name == text_metadata.PLUGIN_NAME: return _migrate_text_value(value) + if plugin_name == hparams_metadata.PLUGIN_NAME: + return _migrate_hparams_value(value) return (value,) @@ -116,3 +120,10 @@ def _migrate_image_value(value): def _migrate_text_value(value): value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR return (value,) + + +def _migrate_hparams_value(value): + value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR + if not value.HasField("tensor"): + value.tensor.CopyFrom(hparams_metadata.NULL_TENSOR) + return (value,) diff --git a/tensorboard/dataclass_compat_test.py b/tensorboard/dataclass_compat_test.py index 9f4a44aaab..e4be74d94c 100644 --- a/tensorboard/dataclass_compat_test.py +++ b/tensorboard/dataclass_compat_test.py @@ -31,6 +31,8 @@ from tensorboard.plugins.graph import metadata as graphs_metadata from tensorboard.plugins.histogram import metadata as histogram_metadata from tensorboard.plugins.histogram import summary as histogram_summary +from tensorboard.plugins.hparams import metadata as hparams_metadata +from tensorboard.plugins.hparams import summary_v2 as hparams_summary from tensorboard.plugins.scalar import metadata as scalar_metadata from tensorboard.plugins.scalar import summary as scalar_summary from tensorboard.util import tensor_util @@ -138,6 +140,29 @@ def test_histogram(self): histogram_metadata.PLUGIN_NAME, ) + def test_hparams(self): + old_event = event_pb2.Event() + old_event.step = 0 + old_event.wall_time = 456.75 + hparams_pb = hparams_summary.hparams_pb({"optimizer": "adam"}) + # Simulate legacy event with no tensor content + for v in hparams_pb.value: + v.ClearField("tensor") + old_event.summary.CopyFrom(hparams_pb) + + new_events = self._migrate_event(old_event) + self.assertLen(new_events, 1) + self.assertLen(new_events[0].summary.value, 1) + value = new_events[0].summary.value[0] + self.assertEqual(value.tensor, hparams_metadata.NULL_TENSOR) + self.assertEqual( + value.metadata.data_class, summary_pb2.DATA_CLASS_TENSOR + ) + self.assertEqual( + value.metadata.plugin_data, + hparams_pb.value[0].metadata.plugin_data, + ) + def test_graph_def(self): # Create a `GraphDef` and write it to disk as an event. logdir = self.get_temp_dir() diff --git a/tensorboard/plugins/hparams/BUILD b/tensorboard/plugins/hparams/BUILD index b5a0f2b1a0..eac0b8bfed 100644 --- a/tensorboard/plugins/hparams/BUILD +++ b/tensorboard/plugins/hparams/BUILD @@ -316,6 +316,7 @@ py_library( ":error", ":protos_all_py_pb2", "//tensorboard/compat/proto:protos_all_py_pb2", + "//tensorboard/util:tensor_util", ], ) diff --git a/tensorboard/plugins/hparams/metadata.py b/tensorboard/plugins/hparams/metadata.py index 92f2dc4141..0b9677024e 100644 --- a/tensorboard/plugins/hparams/metadata.py +++ b/tensorboard/plugins/hparams/metadata.py @@ -19,13 +19,22 @@ from __future__ import print_function from tensorboard.compat.proto import summary_pb2 +from tensorboard.compat.proto import types_pb2 from tensorboard.plugins.hparams import error from tensorboard.plugins.hparams import plugin_data_pb2 +from tensorboard.util import tensor_util PLUGIN_NAME = "hparams" PLUGIN_DATA_VERSION = 0 +# Tensor value for use in summaries that really only need to store +# metadata. A length-0 float vector is of minimal serialized length +# (6 bytes) among valid tensors. Cache this: computing it takes on the +# order of tens of microseconds. +NULL_TENSOR = tensor_util.make_tensor_proto( + [], dtype=types_pb2.DT_FLOAT, shape=(0,) +) EXPERIMENT_TAG = "_hparams_/experiment" SESSION_START_INFO_TAG = "_hparams_/session_start_info" diff --git a/tensorboard/plugins/hparams/summary.py b/tensorboard/plugins/hparams/summary.py index 0d8234d0b6..489c2fbad5 100644 --- a/tensorboard/plugins/hparams/summary.py +++ b/tensorboard/plugins/hparams/summary.py @@ -193,5 +193,14 @@ def _summary(tag, hparams_plugin_data): tb_metadata = metadata.create_summary_metadata(hparams_plugin_data) raw_metadata = tb_metadata.SerializeToString() tf_metadata = tf.compat.v1.SummaryMetadata.FromString(raw_metadata) - summary.value.add(tag=tag, metadata=tf_metadata) + summary.value.add( + tag=tag, metadata=tf_metadata, tensor=_TF_NULL_TENSOR, + ) return summary + + +# Like `metadata.NULL_TENSOR`, but with the TensorFlow version of the +# proto. Slight kludge needed to expose the `TensorProto` type. +_TF_NULL_TENSOR = type(tf.make_tensor_proto(0)).FromString( + metadata.NULL_TENSOR.SerializeToString() +) diff --git a/tensorboard/plugins/hparams/summary_test.py b/tensorboard/plugins/hparams/summary_test.py index 3a64c8c901..8b72fc6348 100644 --- a/tensorboard/plugins/hparams/summary_test.py +++ b/tensorboard/plugins/hparams/summary_test.py @@ -70,6 +70,7 @@ def test_experiment_pb(self): value=[ tf.compat.v1.Summary.Value( tag="_hparams_/experiment", + tensor=summary._TF_NULL_TENSOR, metadata=tf.compat.v1.SummaryMetadata( plugin_data=tf.compat.v1.SummaryMetadata.PluginData( plugin_name="hparams", @@ -98,6 +99,7 @@ def test_session_end_pb(self): value=[ tf.compat.v1.Summary.Value( tag="_hparams_/session_end_info", + tensor=summary._TF_NULL_TENSOR, metadata=tf.compat.v1.SummaryMetadata( plugin_data=tf.compat.v1.SummaryMetadata.PluginData( plugin_name="hparams", diff --git a/tensorboard/plugins/hparams/summary_v2.py b/tensorboard/plugins/hparams/summary_v2.py index 6b3af6c151..ce2ee37c88 100644 --- a/tensorboard/plugins/hparams/summary_v2.py +++ b/tensorboard/plugins/hparams/summary_v2.py @@ -268,7 +268,9 @@ def _summary_pb(tag, hparams_plugin_data): """ summary = summary_pb2.Summary() summary_metadata = metadata.create_summary_metadata(hparams_plugin_data) - summary.value.add(tag=tag, metadata=summary_metadata) + value = summary.value.add( + tag=tag, metadata=summary_metadata, tensor=metadata.NULL_TENSOR + ) return summary diff --git a/tensorboard/plugins/hparams/summary_v2_test.py b/tensorboard/plugins/hparams/summary_v2_test.py index 89500a9d42..7d6f681f87 100644 --- a/tensorboard/plugins/hparams/summary_v2_test.py +++ b/tensorboard/plugins/hparams/summary_v2_test.py @@ -39,6 +39,7 @@ from tensorboard import test from tensorboard.compat import tf from tensorboard.compat.proto import summary_pb2 +from tensorboard.compat.proto import tensor_pb2 from tensorboard.plugins.hparams import api_pb2 from tensorboard.plugins.hparams import metadata from tensorboard.plugins.hparams import plugin_data_pb2 @@ -108,6 +109,12 @@ def _check_summary(self, summary_pb, check_group_name=False): self.assertEqual( actual_value.metadata.plugin_data.plugin_name, metadata.PLUGIN_NAME, ) + self.assertEqual( + tensor_pb2.TensorProto.FromString( + actual_value.tensor.SerializeToString() + ), + metadata.NULL_TENSOR, + ) plugin_content = actual_value.metadata.plugin_data.content info_pb = metadata.parse_session_start_info_plugin_data(plugin_content) # Usually ignore the `group_name` field; its properties are checked