3 changes: 3 additions & 0 deletions tensorboard/backend/event_processing/BUILD
@@ -35,9 +35,12 @@ py_library(
srcs = ["data_provider.py"],
srcs_version = "PY2AND3",
deps = [
":event_accumulator",
"//tensorboard:errors",
"//tensorboard/data:provider",
"//tensorboard/util:tb_logging",
"//tensorboard/util:tensor_util",
"//tensorboard/plugins/graph:metadata",
"@org_pythonhosted_six",
],
)
159 changes: 158 additions & 1 deletion tensorboard/backend/event_processing/data_provider.py
@@ -18,13 +18,19 @@
from __future__ import division
from __future__ import print_function

import base64
import collections
import json

import six

from tensorboard import errors
from tensorboard.backend.event_processing import plugin_event_accumulator
from tensorboard.data import provider
from tensorboard.plugins.graph import metadata as graphs_metadata
from tensorboard.util import tb_logging
from tensorboard.util import tensor_util


logger = tb_logging.get_logger()


@@ -166,6 +172,157 @@ def _read(self, convert_event, index):
result_for_run[tag] = [convert_event(e) for e in events]
return result

def list_blob_sequences(
self, experiment_id, plugin_name, run_tag_filter=None
):
del experiment_id # ignored for now
if run_tag_filter is None:
run_tag_filter = provider.RunTagFilter(runs=None, tags=None)

# TODO(davidsoergel, wchargin): consider images, etc.
# Note this plugin_name can really just be 'graphs' for now; the
# v2 cases are not handled yet.
if plugin_name != graphs_metadata.PLUGIN_NAME:
logger.warn("Directory has no blob data for plugin %r", plugin_name)
return {}

result = collections.defaultdict(lambda: {})
for (run, run_info) in six.iteritems(self._multiplexer.Runs()):
tag = None
if not self._test_run_tag(run_tag_filter, run, tag):
continue
if not run_info[plugin_event_accumulator.GRAPH]:
continue
result[run][tag] = provider.BlobSequenceTimeSeries(
max_step=0,
max_wall_time=0,
latest_max_index=0, # Graphs are always one blob at a time
plugin_content=None,
description=None,
display_name=None,
)
return result

def read_blob_sequences(
self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
):
# TODO(davidsoergel, wchargin): consider images, etc.
# Note this plugin_name can really just be 'graphs' for now; the
# v2 cases are not handled yet.
if plugin_name != graphs_metadata.PLUGIN_NAME:
logger.warn("Directory has no blob data for plugin %r", plugin_name)
return {}

result = collections.defaultdict(
lambda: collections.defaultdict(lambda: []))
for (run, run_info) in six.iteritems(self._multiplexer.Runs()):
tag = None
if not self._test_run_tag(run_tag_filter, run, tag):
continue
if not run_info[plugin_event_accumulator.GRAPH]:
continue

time_series = result[run][tag]

wall_time = 0. # dummy value for graph
step = 0 # dummy value for graph
index = 0 # dummy value for graph

# In some situations these blobs may have directly accessible URLs.
# But, for now, we assume they don't.
graph_url = None
graph_blob_key = _encode_blob_key(
experiment_id, plugin_name, run, tag, step, index)
blob_ref = provider.BlobReference(graph_blob_key, graph_url)

datum = provider.BlobSequenceDatum(
wall_time=wall_time,
step=step,
values=(blob_ref,),
)
time_series.append(datum)
return result

def read_blob(self, blob_key):
# Note: we ignore nearly all key elements, since there is only one graph per run.
(unused_experiment_id, plugin_name, run, unused_tag, unused_step,
unused_index) = _decode_blob_key(blob_key)

# TODO(davidsoergel, wchargin): consider images, etc.
if plugin_name != graphs_metadata.PLUGIN_NAME:
logger.warn("Directory has no blob data for plugin %r", plugin_name)
raise errors.NotFoundError()

serialized_graph = self._multiplexer.SerializedGraph(run)

# TODO(davidsoergel): graph_defs have no step attribute so we don't filter
# on it. Other blob types might, though.

if serialized_graph is None:
logger.warn("No blob found for key %r", blob_key)
raise errors.NotFoundError()

# TODO(davidsoergel): consider internal structure of non-graphdef blobs.
# In particular, note we ignore the requested index, since it's always 0.
return serialized_graph


# TODO(davidsoergel): deduplicate with other implementations
def _encode_blob_key(experiment_id, plugin_name, run, tag, step, index):
"""Generate a blob key: a short, URL-safe string identifying a blob.

A blob can be located using a set of integer and string fields; here we
serialize these to allow passing the data through a URL. Specifically, we
1) construct a tuple of the arguments in order; 2) represent that as an
ascii-encoded JSON string (without whitespace); and 3) take the URL-safe
base64 encoding of that, with no padding. For example:

1) Tuple: ("some_id", "graphs", "train", "graph_def", 2, 0)
2) JSON: ["some_id","graphs","train","graph_def",2,0]
3) base64: WyJzb21lX2lkIiwiZ3JhcGhzIiwidHJhaW4iLCJncmFwaF9kZWYiLDIsMF0

Args:
experiment_id: a string ID identifying an experiment.
plugin_name: string
run: string
tag: string
step: int
index: int

Returns:
A URL-safe base64-encoded string representing the provided arguments.
"""
# Encodes the blob key as a URL-safe string, as required by the
# `BlobReference` API in `tensorboard/data/provider.py`, because these keys
# may be used to construct URLs for retrieving blobs.
stringified = json.dumps(
(experiment_id, plugin_name, run, tag, step, index),
separators=(",", ":"))
bytesified = stringified.encode("ascii")
encoded = base64.urlsafe_b64encode(bytesified)
return six.ensure_str(encoded).rstrip("=")


# Any changes to this function need not be backward-compatible, even though
# the current encoding was used to generate URLs. The reason is that the
# generated URLs are not considered permalinks: they need to be valid only
# within the context of the session that created them (via the matching
# `_encode_blob_key` function above).
def _decode_blob_key(key):
"""Decode a blob key produced by `_encode_blob_key` into component fields.

Args:
key: a blob key, as generated by `_encode_blob_key`.

Returns:
A tuple of `(experiment_id, plugin_name, run, tag, step, index)`, with types
matching the arguments of `_encode_blob_key`.
"""
decoded = base64.urlsafe_b64decode(key + "==")  # re-pad; excess "=" is ignored.
stringified = decoded.decode("ascii")
(experiment_id, plugin_name, run, tag, step, index) = json.loads(stringified)
return (experiment_id, plugin_name, run, tag, step, index)


def _convert_scalar_event(event):
"""Helper for `read_scalars`."""
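The two key helpers above are a matched pair, and their behavior is easy to check in isolation. The following self-contained sketch mirrors their logic for illustration only (the top-level names here are local to the example, not part of this change):

import base64
import json

def encode_blob_key(experiment_id, plugin_name, run, tag, step, index):
    # Tuple -> compact JSON -> URL-safe base64 with "=" padding stripped,
    # mirroring `_encode_blob_key`.
    stringified = json.dumps(
        (experiment_id, plugin_name, run, tag, step, index),
        separators=(",", ":"))
    encoded = base64.urlsafe_b64encode(stringified.encode("ascii"))
    return encoded.decode("ascii").rstrip("=")

def decode_blob_key(key):
    # Restore padding (the decoder ignores any excess "="), then invert the
    # base64 and JSON steps, mirroring `_decode_blob_key`.
    stringified = base64.urlsafe_b64decode(key + "==").decode("ascii")
    return tuple(json.loads(stringified))

# Graphs use a `None` tag and dummy step/index, as in `read_blob_sequences`.
key = encode_blob_key("some_id", "graphs", "train", None, 0, 0)
assert key == "WyJzb21lX2lkIiwiZ3JhcGhzIiwidHJhaW4iLG51bGwsMCwwXQ"
assert decode_blob_key(key) == ("some_id", "graphs", "train", None, 0, 0)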
16 changes: 16 additions & 0 deletions tensorboard/backend/event_processing/event_multiplexer.py
@@ -287,6 +287,22 @@ def Graph(self, run):
accumulator = self.GetAccumulator(run)
return accumulator.Graph()

def SerializedGraph(self, run):
"""Retrieve the serialized graph associated with the provided run.

Args:
run: A string name of a run to load the graph for.

Raises:
KeyError: If the run is not found.

Returns:
The serialized form of the `GraphDef` protobuf data structure, or
`None` if the run has no associated graph.
"""
accumulator = self.GetAccumulator(run)
return accumulator.SerializedGraph()

def MetaGraph(self, run):
"""Retrieve the metagraph associated with the provided run.

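A brief usage sketch for the new method (the multiplexer instance and the run name "train" are assumptions for illustration, not taken from this change); the accumulator tests below perform the same round trip:

from tensorboard.compat.proto import graph_pb2

# `multiplexer` is an EventMultiplexer that has already loaded its runs.
serialized = multiplexer.SerializedGraph("train")
if serialized is not None:  # `None` when the run recorded no graph
    graph_def = graph_pb2.GraphDef.FromString(serialized)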
@@ -387,6 +387,10 @@ def Graph(self):
return graph
raise ValueError('There is no graph in this EventAccumulator')

def SerializedGraph(self):
"""Return the graph definition in serialized form, if there is one."""
return self._graph

def MetaGraph(self):
"""Return the metagraph definition, if there is one.

@@ -75,7 +75,7 @@ def AddScalarTensor(self, tag, wall_time=0, step=0, value=0):

def AddEvent(self, event):
if self.zero_out_timestamps:
event.wall_time = 0
event.wall_time = 0.
self.items.append(event)

def add_event(self, event): # pylint: disable=invalid-name
@@ -633,6 +633,8 @@ def FakeScalarSummary(tag, value):
expected_graph_def = graph_pb2.GraphDef.FromString(
graph.as_graph_def(add_shapes=True).SerializeToString())
self.assertProtoEquals(expected_graph_def, acc.Graph())
self.assertProtoEquals(expected_graph_def,
graph_pb2.GraphDef.FromString(acc.SerializedGraph()))

expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString(
meta_graph_def.SerializeToString())
@@ -667,6 +669,8 @@ def testGraphFromMetaGraphBecomesAvailable(self):
expected_graph_def = graph_pb2.GraphDef.FromString(
graph.as_graph_def(add_shapes=True).SerializeToString())
self.assertProtoEquals(expected_graph_def, acc.Graph())
self.assertProtoEquals(expected_graph_def,
graph_pb2.GraphDef.FromString(acc.SerializedGraph()))

expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString(
meta_graph_def.SerializeToString())
16 changes: 16 additions & 0 deletions tensorboard/backend/event_processing/plugin_event_multiplexer.py
@@ -321,6 +321,22 @@ def Graph(self, run):
accumulator = self.GetAccumulator(run)
return accumulator.Graph()

def SerializedGraph(self, run):
"""Retrieve the serialized graph associated with the provided run.

Args:
run: A string name of a run to load the graph for.

Raises:
KeyError: If the run is not found.

Returns:
The serialized form of the `GraphDef` protobuf data structure, or
`None` if the run has no associated graph.
"""
accumulator = self.GetAccumulator(run)
return accumulator.SerializedGraph()

def MetaGraph(self, run):
"""Retrieve the metagraph associated with the provided run.

2 changes: 1 addition & 1 deletion tensorboard/data/provider.py
@@ -34,7 +34,7 @@ class DataProvider(object):
downsampling strategies or domain restriction by step or wall time.

Unless otherwise noted, any methods on this class may raise errors
defined in `tensorboard.errors`, like `tensorboard.errors.NotFound`.
defined in `tensorboard.errors`, like `tensorboard.errors.NotFoundError`.
"""

def data_location(self, experiment_id):
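Because `read_blob` in this change raises `errors.NotFoundError` for unknown or stale keys, callers are expected to catch it. A minimal sketch, assuming a `data_provider` instance and a `blob_key` obtained from `read_blob_sequences` (both hypothetical here):

from tensorboard import errors

try:
    blob_bytes = data_provider.read_blob(blob_key)
except errors.NotFoundError:
    blob_bytes = None  # e.g. the run has no graph, or the key went stale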
3 changes: 3 additions & 0 deletions tensorboard/plugins/graph/BUILD
@@ -17,10 +17,12 @@ py_library(
":keras_util",
":graph_util",
":metadata",
"//tensorboard:plugin_util",
"//tensorboard/backend:http_util",
"//tensorboard/backend:process_graph",
"//tensorboard/backend/event_processing:event_accumulator",
"//tensorboard/compat:tensorflow",
"//tensorboard/data:provider",
"//tensorboard/plugins:base_plugin",
"@com_google_protobuf//:protobuf_python",
"@org_pocoo_werkzeug",
@@ -45,6 +47,7 @@ py_library(
":graphs_plugin",
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/backend:application",
"//tensorboard/backend/event_processing:data_provider",
"//tensorboard/backend/event_processing:event_multiplexer",
"//tensorboard/compat/proto:protos_all_py_pb2",
"//tensorboard/plugins:base_plugin",