Skip to content
15 changes: 15 additions & 0 deletions tensorboard/dataclass_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ def _migrate_value(value, initial_metadata):
return _migrate_hparams_value(value)
if plugin_name == pr_curves_metadata.PLUGIN_NAME:
return _migrate_pr_curve_value(value)
if plugin_name in [
graphs_metadata.PLUGIN_NAME_RUN_METADATA,
graphs_metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
graphs_metadata.PLUGIN_NAME_KERAS_MODEL,
]:
return _migrate_graph_sub_plugin_value(value)
return (value,)


Expand Down Expand Up @@ -191,3 +197,12 @@ def _migrate_pr_curve_value(value):
if value.HasField("metadata"):
value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR
return (value,)


def _migrate_graph_sub_plugin_value(value):
if value.HasField("metadata"):
value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE
shape = value.tensor.tensor_shape.dim
if not shape:
shape.add(size=1)
return (value,)
37 changes: 37 additions & 0 deletions tensorboard/dataclass_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,43 @@ def test_run_metadata(self):
graphs_metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
)

def test_graph_sub_plugins(self):
# Tests for `graph_run_metadata`, `graph_run_metadata_graph`,
# and `graph_keras_model` plugins. We fabricate these since it's
# not straightforward to get handles to them.
for plugin_name in [
graphs_metadata.PLUGIN_NAME_RUN_METADATA,
graphs_metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
graphs_metadata.PLUGIN_NAME_KERAS_MODEL,
]:
with self.subTest(plugin_name):
old_event = event_pb2.Event()
old_event.step = 123
old_event.wall_time = 456.75
old_value = old_event.summary.value.add()
old_value.metadata.plugin_data.plugin_name = plugin_name
old_value.metadata.plugin_data.content = b"1"
old_tensor = tensor_util.make_tensor_proto(b"2+2=4")
# input data are scalar tensors
self.assertEqual(tensor_util.make_ndarray(old_tensor).shape, ())
old_value.tensor.CopyFrom(old_tensor)

new_events = self._migrate_event(old_event)
self.assertLen(new_events, 1)
self.assertLen(new_events[0].summary.value, 1)
new_value = new_events[0].summary.value[0]
ndarray = tensor_util.make_ndarray(new_value.tensor)
self.assertEqual(ndarray.shape, (1,))
self.assertEqual(ndarray.item(), b"2+2=4")
self.assertEqual(
new_value.metadata.data_class,
summary_pb2.DATA_CLASS_BLOB_SEQUENCE,
)
self.assertEqual(
new_value.metadata.plugin_data.plugin_name, plugin_name
)
self.assertEqual(new_value.metadata.plugin_data.content, b"1")


if __name__ == "__main__":
tf.test.main()
2 changes: 1 addition & 1 deletion tensorboard/plugins/graph/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ py_library(
":graph_util",
":keras_util",
":metadata",
"//tensorboard:errors",
"//tensorboard:plugin_util",
"//tensorboard/backend:http_util",
"//tensorboard/backend:process_graph",
"//tensorboard/data:provider",
"//tensorboard/plugins:base_plugin",
"@com_google_protobuf//:protobuf_python",
"@org_pocoo_werkzeug",
"@org_pythonhosted_six",
],
)

Expand Down
Loading