Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
5f43a27
audio: remove labels from UI
wchargin Apr 8, 2020
54c2da2
uploader: inline graph filtering from `dataclass_compat`
wchargin Apr 14, 2020
2cde07d
backend: move compat transforms to event file loading
wchargin Apr 14, 2020
f339182
dataclass_compat: track initial tag metadata
wchargin Apr 14, 2020
61ab333
[update patch]
wchargin Apr 14, 2020
45b02e7
[update diffbase]
wchargin Apr 14, 2020
76843be
[update diffbase]
wchargin Apr 14, 2020
a12f351
[update patch]
wchargin Apr 14, 2020
a5e8fa5
audio: add generic data support
wchargin Apr 14, 2020
8a4247c
[update patch]
wchargin Apr 14, 2020
aac18b4
[bump ci]
wchargin Apr 14, 2020
8ef2886
[update diffbase]
wchargin Apr 14, 2020
f159fa5
[update patch]
wchargin Apr 14, 2020
6d11b7c
[update diffbase]
wchargin Apr 14, 2020
6c2d8ae
[update diffbase]
wchargin Apr 14, 2020
7ea1d0c
[update diffbase]
wchargin Apr 14, 2020
ab52a51
uploader_test: check logical equality of protos
wchargin Apr 14, 2020
d17e9f9
[update diffbase]
wchargin Apr 14, 2020
84fd685
[update diffbase]
wchargin Apr 14, 2020
ff15dc7
[update diffbase]
wchargin Apr 14, 2020
a0edd31
[update diffbase]
wchargin Apr 14, 2020
1a04730
[update diffbase]
wchargin Apr 14, 2020
94a974c
[update diffbase]
wchargin Apr 14, 2020
2ac4935
[update patch]
wchargin Apr 14, 2020
7c33f3e
[update diffbase]
wchargin Apr 14, 2020
a9adf93
[update diffbase]
wchargin Apr 14, 2020
8f469cc
[update patch]
wchargin Apr 14, 2020
ac58c6d
[update diffbase]
wchargin Apr 16, 2020
61cc397
[update diffbase]
wchargin Apr 16, 2020
524d0bc
[update patch]
wchargin Apr 16, 2020
94e1afe
[update diffbase]
wchargin Apr 16, 2020
748938a
[update diffbase]
wchargin Apr 16, 2020
82310b4
[update diffbase]
wchargin Apr 16, 2020
d553605
[update patch]
wchargin Apr 16, 2020
83b3404
[update diffbase]
wchargin Apr 16, 2020
254eb40
[update diffbase]
wchargin Apr 16, 2020
3d1d73f
[update patch]
wchargin Apr 17, 2020
0a6d49f
[update diffbase]
wchargin Apr 17, 2020
c1da800
[update patch]
wchargin Apr 17, 2020
12af0db
[update diffbase]
wchargin Apr 17, 2020
800a3b1
[update patch]
wchargin Apr 17, 2020
2b095ce
[update diffbase]
wchargin Apr 17, 2020
fd201f6
[update patch]
wchargin Apr 17, 2020
ba64453
[update diffbase]
wchargin Apr 17, 2020
1cc2861
[update diffbase]
wchargin Apr 17, 2020
28a8148
[update patch]
wchargin Apr 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tensorboard/plugins/audio/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorboard:errors",
"//tensorboard:plugin_util",
"//tensorboard/backend:http_util",
"//tensorboard/backend/event_processing:event_accumulator",
"//tensorboard/compat:tensorflow",
"//tensorboard/data:provider",
"//tensorboard/plugins:base_plugin",
"//tensorboard/util:tensor_util",
"@org_pocoo_werkzeug",
Expand All @@ -37,6 +39,7 @@ py_test(
"//tensorboard:expect_numpy_installed",
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/backend:application",
"//tensorboard/backend/event_processing:data_provider",
"//tensorboard/backend/event_processing:event_multiplexer",
"//tensorboard/plugins:base_plugin",
"//tensorboard/util:test_util",
Expand Down
181 changes: 73 additions & 108 deletions tensorboard/plugins/audio/audio_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,24 @@
from six.moves import urllib
from werkzeug import wrappers

from tensorboard import errors
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.compat import tf
from tensorboard.data import provider
from tensorboard.plugins import base_plugin
from tensorboard.plugins.audio import metadata
from tensorboard.util import tensor_util


_DEFAULT_MIME_TYPE = "application/octet-stream"
_DEFAULT_DOWNSAMPLING = 10 # audio clips per time series
_MIME_TYPES = {
metadata.Encoding.Value("WAV"): "audio/wav",
}
_ALLOWED_MIME_TYPES = frozenset(
list(_MIME_TYPES.values()) + [_DEFAULT_MIME_TYPE]
)


class AudioPlugin(base_plugin.TBPlugin):
Expand All @@ -47,7 +53,10 @@ def __init__(self, context):
Args:
context: A base_plugin.TBContext instance.
"""
self._multiplexer = context.multiplexer
self._data_provider = context.data_provider
self._downsample_to = (context.sampling_hints or {}).get(
self.plugin_name, _DEFAULT_DOWNSAMPLING
)

def get_plugin_apps(self):
return {
Expand All @@ -57,18 +66,12 @@ def get_plugin_apps(self):
}

def is_active(self):
"""The audio plugin is active iff any run has at least one relevant
tag."""
if not self._multiplexer:
return False
return bool(
self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME)
)
return False # `list_plugins` as called by TB core suffices

def frontend_metadata(self):
return base_plugin.FrontendMetadata(element_name="tf-audio-dashboard")

def _index_impl(self):
def _index_impl(self, experiment):
"""Return information about the tags in each run.

Result is a dictionary of the form
Expand All @@ -93,49 +96,22 @@ def _index_impl(self):
five audio clips at step 0 and ten audio clips at step 1, then the
dictionary for `"minibatch_input"` will contain `"samples": 10`.
"""
runs = self._multiplexer.Runs()
result = {run: {} for run in runs}

mapping = self._multiplexer.PluginRunToTagToContent(
metadata.PLUGIN_NAME
mapping = self._data_provider.list_blob_sequences(
experiment_id=experiment, plugin_name=metadata.PLUGIN_NAME,
)
for (run, tag_to_content) in six.iteritems(mapping):
for tag in tag_to_content:
summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
tensor_events = self._multiplexer.Tensors(run, tag)
samples = max(
[
self._number_of_samples(event.tensor_proto)
for event in tensor_events
]
+ [0]
result = {run: {} for run in mapping}
for (run, tag_to_time_series) in mapping.items():
for (tag, time_series) in tag_to_time_series.items():
description = plugin_util.markdown_to_safe_html(
time_series.description
)
result[run][tag] = {
"displayName": summary_metadata.display_name,
"description": plugin_util.markdown_to_safe_html(
summary_metadata.summary_description
),
"samples": samples,
"displayName": time_series.display_name,
"description": description,
"samples": time_series.max_length,
}

return result

def _number_of_samples(self, tensor_proto):
"""Count the number of samples of an audio TensorProto."""
# We directly inspect the `tensor_shape` of the proto instead of
# using the preferred `tensor_util.make_ndarray(...).shape`, because
# these protos can contain a large amount of encoded audio data,
# and we don't want to have to convert them all to numpy arrays
# just to look at their shape.
return tensor_proto.tensor_shape.dim[0].size

def _filter_by_sample(self, tensor_events, sample):
return [
tensor_event
for tensor_event in tensor_events
if self._number_of_samples(tensor_event.tensor_proto) > sample
]

@wrappers.Request.application
def _serve_audio_metadata(self, request):
"""Given a tag and list of runs, serve a list of metadata for audio.
Expand All @@ -151,24 +127,18 @@ def _serve_audio_metadata(self, request):
Returns:
A werkzeug.Response application.
"""
experiment = plugin_util.experiment_id(request.environ)
tag = request.args.get("tag")
run = request.args.get("run")
sample = int(request.args.get("sample", 0))

events = self._multiplexer.Tensors(run, tag)
try:
response = self._audio_response_for_run(events, run, tag, sample)
except KeyError:
return http_util.Respond(
request, "Invalid run or tag", "text/plain", code=400
)
response = self._audio_response_for_run(experiment, run, tag, sample)
return http_util.Respond(request, response, "application/json")

def _audio_response_for_run(self, tensor_events, run, tag, sample):
def _audio_response_for_run(self, experiment, run, tag, sample):
"""Builds a JSON-serializable object with information about audio.

Args:
tensor_events: A list of image event_accumulator.TensorEvent objects.
run: The name of the run.
tag: The name of the tag the audio entries all belong to.
sample: The zero-indexed sample of the audio sample for which to
Expand All @@ -178,78 +148,73 @@ def _audio_response_for_run(self, tensor_events, run, tag, sample):
the results.

Returns:
A list of dictionaries containing the wall time, step, URL, width, and
height for each audio entry.
A list of dictionaries containing the wall time, step, label,
content type, and query string for each audio entry.
"""
all_audio = self._data_provider.read_blob_sequences(
experiment_id=experiment,
plugin_name=metadata.PLUGIN_NAME,
downsample=self._downsample_to,
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
)
audio = all_audio.get(run, {}).get(tag, None)
if audio is None:
raise errors.NotFoundError(
"No audio data for run=%r, tag=%r" % (run, tag)
)
content_type = self._get_mime_type(experiment, run, tag)
response = []
index = 0
filtered_events = self._filter_by_sample(tensor_events, sample)
content_type = self._get_mime_type(run, tag)
for (index, tensor_event) in enumerate(filtered_events):
for datum in audio:
if len(datum.values) < sample:
continue
query = urllib.parse.urlencode(
{
"blob_key": datum.values[sample].blob_key,
"content_type": content_type,
}
)
response.append(
{
"wall_time": tensor_event.wall_time,
"step": tensor_event.step,
"wall_time": datum.wall_time,
"label": "",
"step": datum.step,
"contentType": content_type,
"query": self._query_for_individual_audio(
run, tag, sample, index
),
"query": query,
}
)
return response

def _query_for_individual_audio(self, run, tag, sample, index):
"""Builds a URL for accessing the specified audio.

This should be kept in sync with _serve_audio_metadata. Note that the URL is
*not* guaranteed to always return the same audio, since audio may be
unloaded from the reservoir as new audio entries come in.

Args:
run: The name of the run.
tag: The tag.
index: The index of the audio entry. Negative values are OK.

Returns:
A string representation of a URL that will load the index-th sampled audio
in the given run with the given tag.
"""
query_string = urllib.parse.urlencode(
{"run": run, "tag": tag, "sample": sample, "index": index,}
def _get_mime_type(self, experiment, run, tag):
# TODO(@wchargin): Move this call from `/audio` (called many
# times) to `/tags` (called few times) to reduce data provider
# calls.
self._data_provider.list_blob_sequences
mapping = self._data_provider.list_blob_sequences(
experiment_id=experiment, plugin_name=metadata.PLUGIN_NAME,
)
return query_string

def _get_mime_type(self, run, tag):
content = self._multiplexer.SummaryMetadata(
run, tag
).plugin_data.content
parsed = metadata.parse_plugin_metadata(content)
time_series = mapping.get(run, {}).get(tag, None)
if time_series is None:
raise errors.NotFoundError(
"No audio data for run=%r, tag=%r" % (run, tag)
)
parsed = metadata.parse_plugin_metadata(time_series.plugin_content)
return _MIME_TYPES.get(parsed.encoding, _DEFAULT_MIME_TYPE)

@wrappers.Request.application
def _serve_individual_audio(self, request):
"""Serve encoded audio data."""
tag = request.args.get("tag")
run = request.args.get("run")
index = int(request.args.get("index", "0"))
sample = int(request.args.get("sample", "0"))
try:
events = self._filter_by_sample(
self._multiplexer.Tensors(run, tag), sample
)
data = events[index].tensor_proto.string_val[sample]
except (KeyError, IndexError):
return http_util.Respond(
request,
"Invalid run, tag, index, or sample",
"text/plain",
code=400,
experiment = plugin_util.experiment_id(request.environ)
mime_type = request.args["content_type"]
if mime_type not in _ALLOWED_MIME_TYPES:
raise errors.InvalidArgumentError(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this possible issue! Can we add a test that exercises this to ensure that we refuse to return audio as, say, text/javascript?

I think the Werkzeug client get() API takes a content_type parameter so this should be straightforward.

In the longer term it would be nice if we didn't have to rely on the client behavior for this. Perhaps we should add a per-blob content_type onto the blob data model? We would have to go around adding it to the various data providers and being sure to set it correctly on upload, but it seems worth doing to properly handle cases like this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure: done, I think, though I didn’t understand what you meant about the
Werkzeug client API taking a content_type parameter. (Isn’t that for
the content type of a POST request body?)

> Perhaps we should add a per-blob `content_type` onto the blob data
> model?

Perhaps. Presumably the source of truth would be a new field on the
`SummaryMetadata` proto (`string content_type = 5;`), required to be
homogeneous within a time series across all steps and indices, and
defaulting to `application/octet-stream` if empty for backward
compatibility? I’ll keep it in mind.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I was confused and misread this and somehow thought we were getting the content type from a header that the browser was populating (even though yes, that isn't a thing for GET requests) rather than the query parameter. Never mind :P

Re: per-blob content types, yeah, I guess it could have a field in SummaryMetadata, although it's perhaps a little odd to have a field that is returned per-value stored in the per-tag SummaryMetadata (even if for now we don't anticipate allowing it to vary per-value).

"Illegal mime type %r" % mime_type
)
mime_type = self._get_mime_type(run, tag)
blob_key = request.args["blob_key"]
data = self._data_provider.read_blob(blob_key)
return http_util.Respond(request, data, mime_type)

@wrappers.Request.application
def _serve_tags(self, request):
index = self._index_impl()
experiment = plugin_util.experiment_id(request.environ)
index = self._index_impl(experiment)
return http_util.Respond(request, index, "application/json")
Loading