diff --git a/tensorboard/BUILD b/tensorboard/BUILD index f0ed46decf..f30f7e6c40 100644 --- a/tensorboard/BUILD +++ b/tensorboard/BUILD @@ -498,6 +498,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorboard/compat/proto:protos_all_py_pb2", + "//tensorboard/plugins/audio:metadata", "//tensorboard/plugins/graph:metadata", "//tensorboard/plugins/histogram:metadata", "//tensorboard/plugins/hparams:metadata", @@ -519,6 +520,8 @@ py_test( "//tensorboard:expect_tensorflow_installed", "//tensorboard/backend/event_processing:event_file_loader", "//tensorboard/compat/proto:protos_all_py_pb2", + "//tensorboard/plugins/audio:metadata", + "//tensorboard/plugins/audio:summary", "//tensorboard/plugins/graph:metadata", "//tensorboard/plugins/histogram:metadata", "//tensorboard/plugins/histogram:summary", diff --git a/tensorboard/dataclass_compat.py b/tensorboard/dataclass_compat.py index e42e288b6d..3a7e065f1e 100644 --- a/tensorboard/dataclass_compat.py +++ b/tensorboard/dataclass_compat.py @@ -28,6 +28,7 @@ from tensorboard.compat.proto import event_pb2 from tensorboard.compat.proto import summary_pb2 from tensorboard.compat.proto import types_pb2 +from tensorboard.plugins.audio import metadata as audio_metadata from tensorboard.plugins.graph import metadata as graphs_metadata from tensorboard.plugins.histogram import metadata as histograms_metadata from tensorboard.plugins.hparams import metadata as hparams_metadata @@ -110,6 +111,8 @@ def _migrate_value(value, initial_metadata): return _migrate_histogram_value(value) if plugin_name == images_metadata.PLUGIN_NAME: return _migrate_image_value(value) + if plugin_name == audio_metadata.PLUGIN_NAME: + return _migrate_audio_value(value) if plugin_name == scalars_metadata.PLUGIN_NAME: return _migrate_scalar_value(value) if plugin_name == text_metadata.PLUGIN_NAME: @@ -143,6 +146,19 @@ def _migrate_text_value(value): return (value,) +def _migrate_audio_value(value): + if value.HasField("metadata"): + 
value.metadata.data_class = summary_pb2.DATA_CLASS_BLOB_SEQUENCE + tensor = value.tensor + # Project out just the first axis: actual audio clips. + stride = 1 + while len(tensor.tensor_shape.dim) > 1: + stride *= tensor.tensor_shape.dim.pop().size + if stride != 1: + tensor.string_val[:] = tensor.string_val[::stride] + return (value,) + + def _migrate_hparams_value(value): if value.HasField("metadata"): value.metadata.data_class = summary_pb2.DATA_CLASS_TENSOR diff --git a/tensorboard/dataclass_compat_test.py b/tensorboard/dataclass_compat_test.py index 34e51b38b9..cbd29e62e7 100644 --- a/tensorboard/dataclass_compat_test.py +++ b/tensorboard/dataclass_compat_test.py @@ -31,6 +31,8 @@ from tensorboard.compat.proto import graph_pb2 from tensorboard.compat.proto import node_def_pb2 from tensorboard.compat.proto import summary_pb2 +from tensorboard.plugins.audio import metadata as audio_metadata +from tensorboard.plugins.audio import summary as audio_summary from tensorboard.plugins.graph import metadata as graphs_metadata from tensorboard.plugins.histogram import metadata as histogram_metadata from tensorboard.plugins.histogram import summary as histogram_summary @@ -41,11 +43,7 @@ from tensorboard.util import tensor_util from tensorboard.util import test_util -try: - # python version >= 3.3 - from unittest import mock -except ImportError: - import mock # pylint: disable=unused-import +tf.compat.v1.enable_eager_execution() class MigrateEventTest(tf.test.TestCase): @@ -182,6 +180,68 @@ def test_histogram(self): histogram_metadata.PLUGIN_NAME, ) + def test_audio(self): + logdir = self.get_temp_dir() + steps = (0, 1, 2) + with test_util.FileWriter(logdir) as writer: + for step in steps: + event = event_pb2.Event() + event.step = step + event.wall_time = 456.75 * step + audio = tf.reshape( + tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2) + ) + audio_pb = audio_summary.pb( + "foo", + audio, + labels=["one", "two", "three", "four"], + sample_rate=44100, + 
display_name="bar", + description="baz", + ) + writer.add_summary( + audio_pb.SerializeToString(), global_step=step + ) + files = os.listdir(logdir) + self.assertLen(files, 1) + event_file = os.path.join(logdir, files[0]) + loader = event_file_loader.RawEventFileLoader(event_file) + input_events = [event_pb2.Event.FromString(x) for x in loader.Load()] + + new_events = [] + initial_metadata = {} + for input_event in input_events: + migrated = self._migrate_event( + input_event, initial_metadata=initial_metadata + ) + new_events.extend(migrated) + + self.assertLen(new_events, 4) + self.assertEqual(new_events[0].WhichOneof("what"), "file_version") + for step in steps: + with self.subTest("step %d" % step): + new_event = new_events[step + 1] + self.assertLen(new_event.summary.value, 1) + value = new_event.summary.value[0] + tensor = tensor_util.make_ndarray(value.tensor) + self.assertEqual( + tensor.shape, (3,) + ) # 4 clipped to max_outputs=3 + self.assertStartsWith(tensor[0], b"RIFF") + self.assertStartsWith(tensor[1], b"RIFF") + if step == min(steps): + metadata = value.metadata + self.assertEqual( + metadata.data_class, + summary_pb2.DATA_CLASS_BLOB_SEQUENCE, + ) + self.assertEqual( + metadata.plugin_data.plugin_name, + audio_metadata.PLUGIN_NAME, + ) + else: + self.assertFalse(value.HasField("metadata")) + def test_hparams(self): old_event = event_pb2.Event() old_event.step = 0 diff --git a/tensorboard/plugins/audio/audio_plugin.py b/tensorboard/plugins/audio/audio_plugin.py index 906e12aeb0..ac4d2ea54b 100644 --- a/tensorboard/plugins/audio/audio_plugin.py +++ b/tensorboard/plugins/audio/audio_plugin.py @@ -186,13 +186,11 @@ def _audio_response_for_run(self, tensor_events, run, tag, sample): filtered_events = self._filter_by_sample(tensor_events, sample) content_type = self._get_mime_type(run, tag) for (index, tensor_event) in enumerate(filtered_events): - data = tensor_util.make_ndarray(tensor_event.tensor_proto) - label = data[sample, 1] response.append( { 
"wall_time": tensor_event.wall_time, "step": tensor_event.step, - "label": plugin_util.markdown_to_safe_html(label), + "label": "", "contentType": content_type, "query": self._query_for_individual_audio( run, tag, sample, index @@ -240,9 +238,7 @@ def _serve_individual_audio(self, request): events = self._filter_by_sample( self._multiplexer.Tensors(run, tag), sample ) - data = tensor_util.make_ndarray(events[index].tensor_proto)[ - sample, 0 - ] + data = events[index].tensor_proto.string_val[sample] except (KeyError, IndexError): return http_util.Respond( request, diff --git a/tensorboard/plugins/audio/audio_plugin_test.py b/tensorboard/plugins/audio/audio_plugin_test.py index ec9e3920d1..ab26f1c807 100644 --- a/tensorboard/plugins/audio/audio_plugin_test.py +++ b/tensorboard/plugins/audio/audio_plugin_test.py @@ -196,10 +196,7 @@ def testNewStyleAudioRoute(self): # Verify that the 1st entry is correct. entry = entries[0] self.assertEqual("audio/wav", entry["contentType"]) - self.assertEqual( - "
<p>step %s, sample 0</p>
" % entry["step"], - entry["label"], - ) + self.assertEqual("", entry["label"]) self.assertEqual(0, entry["step"]) parsed_query = urllib.parse.parse_qs(entry["query"]) self.assertListEqual(["bar"], parsed_query["run"]) @@ -210,10 +207,7 @@ def testNewStyleAudioRoute(self): # Verify that the 2nd entry is correct. entry = entries[1] self.assertEqual("audio/wav", entry["contentType"]) - self.assertEqual( - "<p>step %s, sample 0</p>
" % entry["step"], - entry["label"], - ) + self.assertEqual("", entry["label"]) self.assertEqual(1, entry["step"]) parsed_query = urllib.parse.parse_qs(entry["query"]) self.assertListEqual(["bar"], parsed_query["run"]) diff --git a/tensorboard/plugins/audio/plugin_data.proto b/tensorboard/plugins/audio/plugin_data.proto index 0431eeb62a..015086ef14 100644 --- a/tensorboard/plugins/audio/plugin_data.proto +++ b/tensorboard/plugins/audio/plugin_data.proto @@ -28,7 +28,14 @@ message AudioPluginData { WAV = 11; } - // Version `0` is the only supported version. + // Version `0` is the only supported version. It has the following + // semantics: + // + // - If the tensor shape is rank-2, then `t[:, 0]` represent encoded + // audio data, and `t[:, 1]` represent corresponding UTF-8 encoded + // Markdown labels. + // - If the tensor shape is rank-1, then `t[:]` represent encoded + // audio data. There are no labels. int32 version = 1; Encoding encoding = 2; diff --git a/tensorboard/plugins/audio/summary.py b/tensorboard/plugins/audio/summary.py index 506a07e6c5..36e7b504f6 100644 --- a/tensorboard/plugins/audio/summary.py +++ b/tensorboard/plugins/audio/summary.py @@ -29,6 +29,7 @@ from __future__ import print_function import functools +import warnings import numpy as np @@ -41,6 +42,12 @@ audio = summary_v2.audio +_LABELS_WARNING = ( + "Labels on audio summaries are deprecated and will be removed. " + "See