diff --git a/tensorboard/backend/event_processing/BUILD b/tensorboard/backend/event_processing/BUILD index 01bf9552b5..a95438d966 100644 --- a/tensorboard/backend/event_processing/BUILD +++ b/tensorboard/backend/event_processing/BUILD @@ -35,9 +35,12 @@ py_library( srcs = ["data_provider.py"], srcs_version = "PY2AND3", deps = [ + ":event_accumulator", + "//tensorboard:errors", "//tensorboard/data:provider", "//tensorboard/util:tb_logging", "//tensorboard/util:tensor_util", + "//tensorboard/plugins/graph:metadata", "@org_pythonhosted_six", ], ) diff --git a/tensorboard/backend/event_processing/data_provider.py b/tensorboard/backend/event_processing/data_provider.py index 0c410acd7e..2679ca61c6 100644 --- a/tensorboard/backend/event_processing/data_provider.py +++ b/tensorboard/backend/event_processing/data_provider.py @@ -18,13 +18,19 @@ from __future__ import division from __future__ import print_function +import base64 +import collections +import json + import six +from tensorboard import errors +from tensorboard.backend.event_processing import plugin_event_accumulator from tensorboard.data import provider +from tensorboard.plugins.graph import metadata as graphs_metadata from tensorboard.util import tb_logging from tensorboard.util import tensor_util - logger = tb_logging.get_logger() @@ -166,6 +172,157 @@ def _read(self, convert_event, index): result_for_run[tag] = [convert_event(e) for e in events] return result + def list_blob_sequences( + self, experiment_id, plugin_name, run_tag_filter=None + ): + del experiment_id # ignored for now + if run_tag_filter is None: + run_tag_filter = provider.RunTagFilter(runs=None, tags=None) + + # TODO(davidsoergel, wchargin): consider images, etc. + # Note this plugin_name can really just be 'graphs' for now; the + # v2 cases are not handled yet. + if plugin_name != graphs_metadata.PLUGIN_NAME: + logger.warn("Directory has no blob data for plugin %r", plugin_name) + return {} + + result = collections.defaultdict(lambda: {}) + for (run, run_info) in six.iteritems(self._multiplexer.Runs()): + tag = None + if not self._test_run_tag(run_tag_filter, run, tag): + continue + if not run_info[plugin_event_accumulator.GRAPH]: + continue + result[run][tag] = provider.BlobSequenceTimeSeries( + max_step=0, + max_wall_time=0, + latest_max_index=0, # Graphs are always one blob at a time + plugin_content=None, + description=None, + display_name=None, + ) + return result + + def read_blob_sequences( + self, experiment_id, plugin_name, downsample=None, run_tag_filter=None + ): + # TODO(davidsoergel, wchargin): consider images, etc. + # Note this plugin_name can really just be 'graphs' for now; the + # v2 cases are not handled yet. + if plugin_name != graphs_metadata.PLUGIN_NAME: + logger.warn("Directory has no blob data for plugin %r", plugin_name) + return {} + + result = collections.defaultdict( + lambda: collections.defaultdict(lambda: [])) + for (run, run_info) in six.iteritems(self._multiplexer.Runs()): + tag = None + if not self._test_run_tag(run_tag_filter, run, tag): + continue + if not run_info[plugin_event_accumulator.GRAPH]: + continue + + time_series = result[run][tag] + + wall_time = 0. # dummy value for graph + step = 0 # dummy value for graph + index = 0 # dummy value for graph + + # In some situations these blobs may have directly accessible URLs. + # But, for now, we assume they don't. + graph_url = None + graph_blob_key = _encode_blob_key( + experiment_id, plugin_name, run, tag, step, index) + blob_ref = provider.BlobReference(graph_blob_key, graph_url) + + datum = provider.BlobSequenceDatum( + wall_time=wall_time, + step=step, + values=(blob_ref,), + ) + time_series.append(datum) + return result + + def read_blob(self, blob_key): + # note: ignoring nearly all key elements: there is only one graph per run. + (unused_experiment_id, plugin_name, run, unused_tag, unused_step, + unused_index) = _decode_blob_key(blob_key) + + # TODO(davidsoergel, wchargin): consider images, etc. + if plugin_name != graphs_metadata.PLUGIN_NAME: + logger.warn("Directory has no blob data for plugin %r", plugin_name) + raise errors.NotFoundError() + + serialized_graph = self._multiplexer.SerializedGraph(run) + + # TODO(davidsoergel): graph_defs have no step attribute so we don't filter + # on it. Other blob types might, though. + + if serialized_graph is None: + logger.warn("No blob found for key %r", blob_key) + raise errors.NotFoundError() + + # TODO(davidsoergel): consider internal structure of non-graphdef blobs. + # In particular, note we ignore the requested index, since it's always 0. + return serialized_graph + + +# TODO(davidsoergel): deduplicate with other implementations +def _encode_blob_key(experiment_id, plugin_name, run, tag, step, index): + """Generate a blob key: a short, URL-safe string identifying a blob. + + A blob can be located using a set of integer and string fields; here we + serialize these to allow passing the data through a URL. Specifically, we + 1) construct a tuple of the arguments in order; 2) represent that as an + ascii-encoded JSON string (without whitespace); and 3) take the URL-safe + base64 encoding of that, with no padding. For example: + + 1) Tuple: ("some_id", "graphs", "train", "graph_def", 2, 0) + 2) JSON: ["some_id","graphs","train","graph_def",2,0] + 3) base64: WyJzb21lX2lkIiwiZ3JhcGhzIiwidHJhaW4iLCJncmFwaF9kZWYiLDIsMF0K + + Args: + experiment_id: a string ID identifying an experiment. + plugin_name: string + run: string + tag: string + step: int + index: int + + Returns: + A URL-safe base64-encoded string representing the provided arguments. + """ + # Encodes the blob key as a URL-safe string, as required by the + # `BlobReference` API in `tensorboard/data/provider.py`, because these keys + # may be used to construct URLs for retrieving blobs. + stringified = json.dumps( + (experiment_id, plugin_name, run, tag, step, index), + separators=(",", ":")) + bytesified = stringified.encode("ascii") + encoded = base64.urlsafe_b64encode(bytesified) + return six.ensure_str(encoded).rstrip("=") + + +# Any changes to this function need not be backward-compatible, even though +# the current encoding was used to generate URLs. The reason is that the +# generated URLs are not considered permalinks: they need to be valid only +# within the context of the session that created them (via the matching +# `_encode_blob_key` function above). +def _decode_blob_key(key): + """Decode a blob key produced by `_encode_blob_key` into component fields. + + Args: + key: a blob key, as generated by `_encode_blob_key`. + + Returns: + A tuple of `(experiment_id, plugin_name, run, tag, step, index)`, with types + matching the arguments of `_encode_blob_key`. + """ + decoded = base64.urlsafe_b64decode(key + "==") # pad past a multiple of 4. + stringified = decoded.decode("ascii") + (experiment_id, plugin_name, run, tag, step, index) = json.loads(stringified) + return (experiment_id, plugin_name, run, tag, step, index) + def _convert_scalar_event(event): """Helper for `read_scalars`.""" diff --git a/tensorboard/backend/event_processing/event_multiplexer.py b/tensorboard/backend/event_processing/event_multiplexer.py index 340cf0bd01..5823a6412d 100644 --- a/tensorboard/backend/event_processing/event_multiplexer.py +++ b/tensorboard/backend/event_processing/event_multiplexer.py @@ -287,6 +287,22 @@ def Graph(self, run): accumulator = self.GetAccumulator(run) return accumulator.Graph() + def SerializedGraph(self, run): + """Retrieve the serialized graph associated with the provided run. + + Args: + run: A string name of a run to load the graph for. + + Raises: + KeyError: If the run is not found. + ValueError: If the run does not have an associated graph. + + Returns: + The serialized form of the `GraphDef` protobuf data structure. + """ + accumulator = self.GetAccumulator(run) + return accumulator.SerializedGraph() + def MetaGraph(self, run): """Retrieve the metagraph associated with the provided run. diff --git a/tensorboard/backend/event_processing/plugin_event_accumulator.py b/tensorboard/backend/event_processing/plugin_event_accumulator.py index 75155b6766..8fbea6f53a 100644 --- a/tensorboard/backend/event_processing/plugin_event_accumulator.py +++ b/tensorboard/backend/event_processing/plugin_event_accumulator.py @@ -387,6 +387,10 @@ def Graph(self): return graph raise ValueError('There is no graph in this EventAccumulator') + def SerializedGraph(self): + """Return the graph definition in serialized form, if there is one.""" + return self._graph + def MetaGraph(self): """Return the metagraph definition, if there is one. diff --git a/tensorboard/backend/event_processing/plugin_event_accumulator_test.py b/tensorboard/backend/event_processing/plugin_event_accumulator_test.py index 99e2fe8089..c6861a9838 100644 --- a/tensorboard/backend/event_processing/plugin_event_accumulator_test.py +++ b/tensorboard/backend/event_processing/plugin_event_accumulator_test.py @@ -75,7 +75,7 @@ def AddScalarTensor(self, tag, wall_time=0, step=0, value=0): def AddEvent(self, event): if self.zero_out_timestamps: - event.wall_time = 0 + event.wall_time = 0. self.items.append(event) def add_event(self, event): # pylint: disable=invalid-name @@ -633,6 +633,8 @@ def FakeScalarSummary(tag, value): expected_graph_def = graph_pb2.GraphDef.FromString( graph.as_graph_def(add_shapes=True).SerializeToString()) self.assertProtoEquals(expected_graph_def, acc.Graph()) + self.assertProtoEquals(expected_graph_def, + graph_pb2.GraphDef.FromString(acc.SerializedGraph())) expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString( meta_graph_def.SerializeToString()) @@ -667,6 +669,8 @@ def testGraphFromMetaGraphBecomesAvailable(self): expected_graph_def = graph_pb2.GraphDef.FromString( graph.as_graph_def(add_shapes=True).SerializeToString()) self.assertProtoEquals(expected_graph_def, acc.Graph()) + self.assertProtoEquals(expected_graph_def, + graph_pb2.GraphDef.FromString(acc.SerializedGraph())) expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString( meta_graph_def.SerializeToString()) diff --git a/tensorboard/backend/event_processing/plugin_event_multiplexer.py b/tensorboard/backend/event_processing/plugin_event_multiplexer.py index e69afb4c46..63a04d4d15 100644 --- a/tensorboard/backend/event_processing/plugin_event_multiplexer.py +++ b/tensorboard/backend/event_processing/plugin_event_multiplexer.py @@ -321,6 +321,22 @@ def Graph(self, run): accumulator = self.GetAccumulator(run) return accumulator.Graph() + def SerializedGraph(self, run): + """Retrieve the serialized graph associated with the provided run. + + Args: + run: A string name of a run to load the graph for. + + Raises: + KeyError: If the run is not found. + ValueError: If the run does not have an associated graph. + + Returns: + The serialized form of the `GraphDef` protobuf data structure. + """ + accumulator = self.GetAccumulator(run) + return accumulator.SerializedGraph() + def MetaGraph(self, run): """Retrieve the metagraph associated with the provided run. diff --git a/tensorboard/data/provider.py b/tensorboard/data/provider.py index aad91b392d..ad2dee8d49 100644 --- a/tensorboard/data/provider.py +++ b/tensorboard/data/provider.py @@ -34,7 +34,7 @@ class DataProvider(object): downsampling strategies or domain restriction by step or wall time. Unless otherwise noted, any methods on this class may raise errors - defined in `tensorboard.errors`, like `tensorboard.errors.NotFound`. + defined in `tensorboard.errors`, like `tensorboard.errors.NotFoundError`. """ def data_location(self, experiment_id): diff --git a/tensorboard/plugins/graph/BUILD b/tensorboard/plugins/graph/BUILD index 9133e890ea..6b91c67850 100644 --- a/tensorboard/plugins/graph/BUILD +++ b/tensorboard/plugins/graph/BUILD @@ -17,10 +17,12 @@ py_library( ":keras_util", ":graph_util", ":metadata", + "//tensorboard:plugin_util", "//tensorboard/backend:http_util", "//tensorboard/backend:process_graph", "//tensorboard/backend/event_processing:event_accumulator", "//tensorboard/compat:tensorflow", + "//tensorboard/data:provider", "//tensorboard/plugins:base_plugin", "@com_google_protobuf//:protobuf_python", "@org_pocoo_werkzeug", @@ -45,6 +47,7 @@ py_library( ":graphs_plugin", "//tensorboard:expect_tensorflow_installed", "//tensorboard/backend:application", + "//tensorboard/backend/event_processing:data_provider", "//tensorboard/backend/event_processing:event_multiplexer", "//tensorboard/compat/proto:protos_all_py_pb2", "//tensorboard/plugins:base_plugin", diff --git a/tensorboard/plugins/graph/graphs_plugin.py b/tensorboard/plugins/graph/graphs_plugin.py index 9d2b7e1130..eb77488aa6 100644 --- a/tensorboard/plugins/graph/graphs_plugin.py +++ b/tensorboard/plugins/graph/graphs_plugin.py @@ -22,11 +22,13 @@ import six from werkzeug import wrappers +from tensorboard import plugin_util from tensorboard.backend import http_util from tensorboard.backend import process_graph from tensorboard.backend.event_processing import plugin_event_accumulator as event_accumulator # pylint: disable=line-too-long from tensorboard.compat.proto import config_pb2 from tensorboard.compat.proto import graph_pb2 +from tensorboard.data import provider from tensorboard.plugins import base_plugin from tensorboard.plugins.graph import graph_util from tensorboard.plugins.graph import keras_util @@ -58,6 +60,10 @@ def __init__(self, context): context: A base_plugin.TBContext instance. """ self._multiplexer = context.multiplexer + if context.flags and context.flags.generic_data == 'true': + self._data_provider = context.data_provider + else: + self._data_provider = None def get_plugin_apps(self): return { @@ -67,8 +73,13 @@ def get_plugin_apps(self): } def is_active(self): - """The graphs plugin is active iff any run has a graph.""" - return bool(self._multiplexer and self.info_impl()) + """The graphs plugin is active iff any run has a graph or metadata.""" + if self._data_provider: + # We don't have an experiment ID, and modifying the backend core + # to provide one would break backward compatibility. Hack for now. + return True + + return bool(self.info_impl()) def frontend_metadata(self): return base_plugin.FrontendMetadata( @@ -77,8 +88,8 @@ def frontend_metadata(self): disable_reload=True, ) - def info_impl(self): - """Returns a dict of all runs and tags and their data availabilities.""" + def info_impl(self, experiment=None): + """Returns a dict of all runs and their data availabilities.""" result = {} def add_row_item(run, tag=None): run_item = result.setdefault(run, { @@ -97,6 +108,19 @@ def add_row_item(run, tag=None): 'profile': False}) return (run_item, tag_item) + if self._data_provider: + mapping = self._data_provider.list_blob_sequences( + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, + ) + for (run_name, tag_to_time_series) in six.iteritems(mapping): + for tag in tag_to_time_series: + (run_item, tag_item) = add_row_item(run_name, tag) + run_item['op_graph'] = True + if tag_item: + tag_item['op_graph'] = True + return result + mapping = self._multiplexer.PluginRunToTagToContent( _PLUGIN_NAME_RUN_METADATA_WITH_GRAPH) for run_name, tag_to_content in six.iteritems(mapping): @@ -148,14 +172,33 @@ def add_row_item(run, tag=None): return result - def graph_impl(self, run, tag, is_conceptual, limit_attr_size=None, large_attrs_key=None): + def graph_impl(self, run, tag, is_conceptual, experiment=None, limit_attr_size=None, large_attrs_key=None): """Result of the form `(body, mime_type)`, or `None` if no graph exists.""" - if is_conceptual: + if self._data_provider: + graph_blob_sequences = self._data_provider.read_blob_sequences( + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, + run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]), + ) + blob_datum_list = graph_blob_sequences.get(run, {}).get(tag, ()) + try: + blob_ref = blob_datum_list[0].values[0] + except IndexError: + return None + # Always use the blob_key approach for now, even if there is a direct url. + graph_raw = self._data_provider.read_blob(blob_ref.blob_key) + # This method ultimately returns pbtxt, but we have to deserialize and + # later reserialize this anyway, because a) this way we accept binary + # protobufs too, and b) below we run `prepare_graph_for_ui` on the graph. + graph = graph_pb2.GraphDef.FromString(graph_raw) + + elif is_conceptual: tensor_events = self._multiplexer.Tensors(run, tag) # Take the first event if there are multiple events written from different # steps. keras_model_config = json.loads(tensor_events[0].tensor_proto.string_val[0]) graph = keras_util.keras_model_to_graph_def(keras_model_config) + elif tag: tensor_events = self._multiplexer.Tensors(run, tag) # Take the first event if there are multiple events written from different @@ -176,6 +219,9 @@ def graph_impl(self, run, tag, is_conceptual, limit_attr_size=None, large_attrs_ def run_metadata_impl(self, run, tag): """Result of the form `(body, mime_type)`, or `None` if no data exists.""" + if self._data_provider: + # TODO(davidsoergel, wchargin): Consider plumbing run metadata through data providers. + return None try: run_metadata = self._multiplexer.RunMetadata(run, tag) except ValueError: @@ -194,12 +240,14 @@ def run_metadata_impl(self, run, tag): @wrappers.Request.application def info_route(self, request): - info = self.info_impl() + experiment = plugin_util.experiment_id(request.environ) + info = self.info_impl(experiment) return http_util.Respond(request, info, 'application/json') @wrappers.Request.application def graph_route(self, request): """Given a single run, return the graph definition in protobuf format.""" + experiment = plugin_util.experiment_id(request.environ) run = request.args.get('run') tag = request.args.get('tag', '') conceptual_arg = request.args.get('conceptual', False) @@ -221,7 +269,7 @@ def graph_route(self, request): large_attrs_key = request.args.get('large_attrs_key', None) try: - result = self.graph_impl(run, tag, is_conceptual, limit_attr_size, large_attrs_key) + result = self.graph_impl(run, tag, is_conceptual, experiment, limit_attr_size, large_attrs_key) except ValueError as e: return http_util.Respond(request, e.message, 'text/plain', code=400) else: diff --git a/tensorboard/plugins/graph/graphs_plugin_test.py b/tensorboard/plugins/graph/graphs_plugin_test.py index 0eac4111a6..ed0a9cd6b0 100644 --- a/tensorboard/plugins/graph/graphs_plugin_test.py +++ b/tensorboard/plugins/graph/graphs_plugin_test.py @@ -19,13 +19,16 @@ from __future__ import division from __future__ import print_function +import argparse import collections import math +import functools import os.path import tensorflow as tf from google.protobuf import text_format +from tensorboard.backend.event_processing import data_provider from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer # pylint: disable=line-too-long from tensorboard.compat.proto import config_pb2 from tensorboard.plugins import base_plugin @@ -38,57 +41,80 @@ # TODO(stephanwlee): Move more tests into the base class when v2 test # can write graph and metadata with a TF public API. -class GraphsPluginBaseTest(object): - _RUN_WITH_GRAPH = '_RUN_WITH_GRAPH' - _RUN_WITHOUT_GRAPH = '_RUN_WITHOUT_GRAPH' +_RUN_WITH_GRAPH_WITH_METADATA = ('_RUN_WITH_GRAPH_WITH_METADATA', True, True) +_RUN_WITHOUT_GRAPH_WITH_METADATA = ('_RUN_WITHOUT_GRAPH_WITH_METADATA', False, True) +_RUN_WITH_GRAPH_WITHOUT_METADATA = ('_RUN_WITH_GRAPH_WITHOUT_METADATA', True, False) +_RUN_WITHOUT_GRAPH_WITHOUT_METADATA = ('_RUN_WITHOUT_GRAPH_WITHOUT_METADATA', False, False) + +def with_runs(run_specs): + """Run a test with a bare multiplexer and with a `data_provider`. + + The decorated function will receive an initialized `GraphsPlugin` + object as its first positional argument. + + The receiver argument of the decorated function must be a `TestCase` instance + that also provides `load_runs`.` + """ + def decorator(fn): + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + (logdir, multiplexer) = self.load_runs(run_specs) + with self.subTest('bare multiplexer'): + ctx = base_plugin.TBContext(logdir=logdir, multiplexer=multiplexer) + fn(self, graphs_plugin.GraphsPlugin(ctx), *args, **kwargs) + with self.subTest('generic data provider'): + flags = argparse.Namespace(generic_data='true') + provider = data_provider.MultiplexerDataProvider(multiplexer, logdir) + ctx = base_plugin.TBContext( + flags=flags, + logdir=logdir, + multiplexer=multiplexer, + data_provider=provider, + ) + fn(self, graphs_plugin.GraphsPlugin(ctx), *args, **kwargs) + return wrapper + return decorator + +class GraphsPluginBaseTest(object): _METADATA_TAG = 'secret-stats' _MESSAGE_PREFIX_LENGTH_LOWER_BOUND = 1024 def __init__(self, *args, **kwargs): super(GraphsPluginBaseTest, self).__init__(*args, **kwargs) - self.logdir = None self.plugin = None def setUp(self): super(GraphsPluginBaseTest, self).setUp() - self.logdir = self.get_temp_dir() - def generate_run(self, run_name, include_graph, include_run_metadata): + def generate_run(self, logdir, run_name, include_graph, include_run_metadata): """Create a run""" raise NotImplementedError('Please implement generate_run') - def set_up_with_runs(self, with_graph=True, without_graph=True): - if with_graph: - self.generate_run(self._RUN_WITH_GRAPH, - include_graph=True, - include_run_metadata=True) - if without_graph: - self.generate_run(self._RUN_WITHOUT_GRAPH, - include_graph=False, - include_run_metadata=True) - self.bootstrap_plugin() - - def bootstrap_plugin(self): + def load_runs(self, run_specs): + logdir = self.get_temp_dir() + for run_spec in run_specs: + self.generate_run(logdir, *run_spec) + return self.bootstrap_plugin(logdir) + + def bootstrap_plugin(self, logdir): multiplexer = event_multiplexer.EventMultiplexer() - multiplexer.AddRunsFromDirectory(self.logdir) + multiplexer.AddRunsFromDirectory(logdir) multiplexer.Reload() - context = base_plugin.TBContext(logdir=self.logdir, multiplexer=multiplexer) - self.plugin = graphs_plugin.GraphsPlugin(context) + return (logdir, multiplexer) - def testRoutesProvided(self): + @with_runs([_RUN_WITH_GRAPH_WITH_METADATA, _RUN_WITHOUT_GRAPH_WITH_METADATA]) + def testRoutesProvided(self, plugin): """Tests that the plugin offers the correct routes.""" - self.set_up_with_runs() - routes = self.plugin.get_plugin_apps() + routes = plugin.get_plugin_apps() self.assertIsInstance(routes['/graph'], collections.Callable) self.assertIsInstance(routes['/run_metadata'], collections.Callable) self.assertIsInstance(routes['/info'], collections.Callable) - class GraphsPluginV1Test(GraphsPluginBaseTest, tf.test.TestCase): - def generate_run(self, run_name, include_graph, include_run_metadata): + def generate_run(self, logdir, run_name, include_graph, include_run_metadata): """Create a run with a text summary, metadata, and optionally a graph.""" tf.compat.v1.reset_default_graph() k1 = tf.constant(math.pi, name='k1') @@ -106,7 +132,7 @@ def generate_run(self, run_name, include_graph, include_run_metadata): summary_message = tf.compat.v1.summary.text('summary_message', error_message) sess = tf.compat.v1.Session() - writer = test_util.FileWriter(os.path.join(self.logdir, run_name)) + writer = test_util.FileWriter(os.path.join(logdir, run_name)) if include_graph: writer.add_graph(sess.graph) options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE) @@ -117,18 +143,22 @@ def generate_run(self, run_name, include_graph, include_run_metadata): writer.add_run_metadata(run_metadata, self._METADATA_TAG) writer.close() - def _get_graph(self, *args, **kwargs): + def _get_graph(self, plugin, *args, **kwargs): """Set up runs, then fetch and return the graph as a proto.""" - self.set_up_with_runs() - (graph_pbtxt, mime_type) = self.plugin.graph_impl( - self._RUN_WITH_GRAPH, *args, **kwargs) + (graph_pbtxt, mime_type) = plugin.graph_impl( + _RUN_WITH_GRAPH_WITH_METADATA[0], *args, **kwargs) self.assertEqual(mime_type, 'text/x-protobuf') return text_format.Parse(graph_pbtxt, tf.compat.v1.GraphDef()) - def test_info(self): + @with_runs([ + _RUN_WITH_GRAPH_WITH_METADATA, + _RUN_WITH_GRAPH_WITHOUT_METADATA, + _RUN_WITHOUT_GRAPH_WITH_METADATA, + _RUN_WITHOUT_GRAPH_WITHOUT_METADATA]) + def test_info(self, plugin): expected = { - 'w_graph_w_meta': { - 'run': 'w_graph_w_meta', + '_RUN_WITH_GRAPH_WITH_METADATA': { + 'run': '_RUN_WITH_GRAPH_WITH_METADATA', 'run_graph': True, 'tags': { 'secret-stats': { @@ -139,13 +169,17 @@ def test_info(self): }, }, }, - 'w_graph_wo_meta': { - 'run': 'w_graph_wo_meta', + '_RUN_WITH_GRAPH_WITHOUT_METADATA': { + 'run': '_RUN_WITH_GRAPH_WITHOUT_METADATA', 'run_graph': True, 'tags': {}, - }, - 'wo_graph_w_meta': { - 'run': 'wo_graph_w_meta', + } + } + + if not plugin._data_provider: + # Hack, for now. + expected['_RUN_WITHOUT_GRAPH_WITH_METADATA'] = { + 'run': '_RUN_WITHOUT_GRAPH_WITH_METADATA', 'run_graph': False, 'tags': { 'secret-stats': { @@ -155,27 +189,13 @@ def test_info(self): 'op_graph': False, }, }, - }, - } + } + + self.assertItemsEqual(expected, plugin.info_impl()) - self.generate_run('w_graph_w_meta', - include_graph=True, - include_run_metadata=True) - self.generate_run('w_graph_wo_meta', - include_graph=True, - include_run_metadata=False) - self.generate_run('wo_graph_w_meta', - include_graph=False, - include_run_metadata=True) - self.generate_run('wo_graph_wo_meta', - include_graph=False, - include_run_metadata=False) - self.bootstrap_plugin() - - self.assertItemsEqual(expected, self.plugin.info_impl()) - - def test_graph_simple(self): - graph = self._get_graph(tag=None, is_conceptual=False) + @with_runs([_RUN_WITH_GRAPH_WITH_METADATA]) + def test_graph_simple(self, plugin): + graph = self._get_graph(plugin, tag=None, is_conceptual=False) node_names = set(node.name for node in graph.node) self.assertEqual({ 'k1', 'k2', 'pow', 'sub', 'expected', 'sub_1', 'error', @@ -183,9 +203,11 @@ def test_graph_simple(self): 'summary_message/tag', 'summary_message/serialized_summary_metadata', }, node_names) - def test_graph_large_attrs(self): + @with_runs([_RUN_WITH_GRAPH_WITH_METADATA]) + def test_graph_large_attrs(self, plugin): key = 'o---;;-;' graph = self._get_graph( + plugin, tag=None, is_conceptual=False, limit_attr_size=self._MESSAGE_PREFIX_LENGTH_LOWER_BOUND, @@ -198,41 +220,38 @@ def test_graph_large_attrs(self): self.assertEqual({'message_prefix': [b'value']}, large_attrs) - def test_run_metadata(self): - self.set_up_with_runs() - (metadata_pbtxt, mime_type) = self.plugin.run_metadata_impl( - self._RUN_WITH_GRAPH, self._METADATA_TAG) - self.assertEqual(mime_type, 'text/x-protobuf') - text_format.Parse(metadata_pbtxt, config_pb2.RunMetadata()) - # If it parses, we're happy. - - def test_is_active_with_graph_without_run_metadata(self): - self.generate_run('w_graph_wo_meta', - include_graph=True, - include_run_metadata=False) - self.bootstrap_plugin() - self.assertTrue(self.plugin.is_active()) - - def test_is_active_without_graph_with_run_metadata(self): - self.generate_run('wo_graph_w_meta', - include_graph=False, - include_run_metadata=True) - self.bootstrap_plugin() - self.assertTrue(self.plugin.is_active()) - - def test_is_active_with_both(self): - self.generate_run('w_graph_w_meta', - include_graph=True, - include_run_metadata=True) - self.bootstrap_plugin() - self.assertTrue(self.plugin.is_active()) - - def test_is_active_without_both(self): - self.generate_run('wo_graph_wo_meta', - include_graph=False, - include_run_metadata=False) - self.bootstrap_plugin() - self.assertFalse(self.plugin.is_active()) + @with_runs([_RUN_WITH_GRAPH_WITH_METADATA]) + def test_run_metadata(self, plugin): + result = plugin.run_metadata_impl( + _RUN_WITH_GRAPH_WITH_METADATA[0], self._METADATA_TAG) + if plugin._data_provider: + # Hack, for now + self.assertEqual(result, None) + else: + (metadata_pbtxt, mime_type) = result + self.assertEqual(mime_type, 'text/x-protobuf') + text_format.Parse(metadata_pbtxt, config_pb2.RunMetadata()) + # If it parses, we're happy. + + @with_runs([_RUN_WITH_GRAPH_WITHOUT_METADATA]) + def test_is_active_with_graph_without_run_metadata(self, plugin): + self.assertTrue(plugin.is_active()) + + @with_runs([_RUN_WITHOUT_GRAPH_WITH_METADATA]) + def test_is_active_without_graph_with_run_metadata(self, plugin): + self.assertTrue(plugin.is_active()) + + @with_runs([_RUN_WITH_GRAPH_WITH_METADATA]) + def test_is_active_with_both(self, plugin): + self.assertTrue(plugin.is_active()) + + @with_runs([_RUN_WITHOUT_GRAPH_WITHOUT_METADATA]) + def test_is_inactive_without_both(self, plugin): + if plugin._data_provider: + # Hack, for now. + self.assertTrue(plugin.is_active()) + else: + self.assertFalse(plugin.is_active()) if __name__ == '__main__': tf.test.main() diff --git a/tensorboard/plugins/graph/graphs_plugin_v2_test.py b/tensorboard/plugins/graph/graphs_plugin_v2_test.py index bdef5dcc55..ca08955820 100644 --- a/tensorboard/plugins/graph/graphs_plugin_v2_test.py +++ b/tensorboard/plugins/graph/graphs_plugin_v2_test.py @@ -30,7 +30,7 @@ class GraphsPluginV2Test(graphs_plugin_test.GraphsPluginBaseTest, tf.test.TestCase): - def generate_run(self, run_name, include_graph, include_run_metadata): + def generate_run(self, logdir, run_name, include_graph, include_run_metadata): x, y = np.ones((10, 10)), np.ones((10, 1)) val_x, val_y = np.ones((4, 10)), np.ones((4, 1)) @@ -46,16 +46,19 @@ def generate_run(self, run_name, include_graph, include_run_metadata): batch_size=2, epochs=1, callbacks=[tf.compat.v2.keras.callbacks.TensorBoard( - log_dir=os.path.join(self.logdir, run_name), + log_dir=os.path.join(logdir, run_name), write_graph=include_graph)]) - def _get_graph(self, *args, **kwargs): + def _get_graph(self, plugin, *args, **kwargs): """Fetch and return the graph as a proto.""" - (graph_pbtxt, mime_type) = self.plugin.graph_impl(*args, **kwargs) + (graph_pbtxt, mime_type) = plugin.graph_impl(*args, **kwargs) self.assertEqual(mime_type, 'text/x-protobuf') return text_format.Parse(graph_pbtxt, graph_pb2.GraphDef()) - def test_info(self): + @graphs_plugin_test.with_runs([ + graphs_plugin_test._RUN_WITH_GRAPH_WITH_METADATA, + graphs_plugin_test._RUN_WITHOUT_GRAPH_WITH_METADATA]) + def test_info(self, plugin): raise self.skipTest('TODO: enable this after tf-nightly writes a conceptual graph.') expected = { @@ -81,7 +84,7 @@ def test_info(self): include_run_metadata=False) self.bootstrap_plugin() - self.assertEqual(expected, self.plugin.info_impl()) + self.assertEqual(expected, plugin.info_impl()) def test_graph_conceptual_graph(self): raise self.skipTest('TODO: enable this after tf-nightly writes a conceptual graph.') diff --git a/third_party/python.bzl b/third_party/python.bzl index b3c41e2c8a..360b0bf72e 100644 --- a/third_party/python.bzl +++ b/third_party/python.bzl @@ -127,11 +127,11 @@ def tensorboard_python_workspace(): http_archive( name = "org_pythonhosted_six", urls = [ - "http://mirror.tensorflow.org/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", - "http://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", + "http://mirror.tensorflow.org/pypi.python.org/packages/source/s/six/six-1.13.0.tar.gz", + "https://pypi.python.org/packages/source/s/six/six-1.13.0.tar.gz", ], - sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", - strip_prefix = "six-1.10.0", + sha256 = "30f610279e8b2578cab6db20741130331735c781b56053c59c4076da27f06b66", + strip_prefix = "six-1.13.0", build_file = str(Label("//third_party:six.BUILD")), )