diff --git a/.travis.yml b/.travis.yml
index 2015746981..1f55e9d1e7 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -53,7 +53,10 @@ install:
- pip install yamllint==1.17.0
# TensorBoard deps.
- pip install futures==3.1.1
- - pip install grpcio==1.6.3
+ - pip install grpcio==1.24.3
+ - pip install grpcio-testing==1.24.3
+ - pip install 'google-auth >= 1.6.3, < 2'
+ - pip install 'google-auth-oauthlib >= 0.4.1, < 0.5'
- yarn install --ignore-engines
# Uninstall older Travis numpy to avoid upgrade-in-place issues.
- pip uninstall -y numpy
@@ -66,13 +69,6 @@ install:
pip install "absl-py>=0.7.0" \
&& pip install "numpy<2.0,>=1.14.5"
fi
- - |
- # On TF 1.x nightlies, downgrade estimator temporarily.
- case "${TF_VERSION_ID}" in
- tf-nightly|tf-nightly==*)
- pip install tf-estimator-nightly==1.14.0.dev2019091701
- ;;
- esac
# Deps for gfile S3 test.
- pip install boto3==1.9.86
- pip install moto==1.3.7
diff --git a/RELEASE.md b/RELEASE.md
index 45497d83b8..ff521e13a5 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,9 @@
+# Release 2.0.1
+
+## Features
+- Preview of the TensorBoard.dev uploader! Check it out for
+  information and usage instructions.
+
# Release 2.0.0
The 2.0 minor series tracks TensorFlow 2.0.
diff --git a/tensorboard/BUILD b/tensorboard/BUILD
index ca9522823a..7d6df761f5 100644
--- a/tensorboard/BUILD
+++ b/tensorboard/BUILD
@@ -28,6 +28,7 @@ py_binary(
":program",
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/plugins:base_plugin",
+ "//tensorboard/uploader:uploader_main_lib",
"//tensorboard/util:tb_logging",
],
)
@@ -150,6 +151,7 @@ py_library(
"//tensorboard:expect_absl_logging_installed",
"//tensorboard/backend:application",
"//tensorboard/backend/event_processing:event_file_inspector",
+ "//tensorboard/util:argparse_util",
"@org_pocoo_werkzeug",
"@org_pythonhosted_six",
],
@@ -169,6 +171,7 @@ py_test(
"//tensorboard/plugins:base_plugin",
"//tensorboard/plugins/core:core_plugin",
"@org_pocoo_werkzeug",
+ "@org_pythonhosted_mock",
],
)
@@ -274,6 +277,22 @@ py_library(
visibility = ["//visibility:public"],
)
+py_library(
+ name = "expect_grpc_installed",
+ # This is a dummy rule used as a grpc dependency in open-source.
+ # We expect grpc to already be installed on the system, e.g. via
+ # `pip install grpcio`
+ visibility = ["//visibility:public"],
+)
+
+py_library(
+ name = "expect_grpc_testing_installed",
+ # This is a dummy rule used as a grpc_testing dependency in open-source.
+ # We expect grpc_testing to already be installed on the system, e.g. via
+ # `pip install grpcio_testing`
+ visibility = ["//visibility:public"],
+)
+
py_library(
name = "expect_sqlite3_installed",
# This is a dummy rule used as a sqlite3 dependency in open-source.
@@ -306,6 +325,14 @@ py_library(
visibility = ["//visibility:public"],
)
+py_library(
+ name = "expect_absl_flags_argparse_flags_installed",
+    # This is a dummy rule used as an absl-py dependency in open-source.
+ # We expect absl-py to already be installed on the system, e.g. via
+ # `pip install absl-py`
+ visibility = ["//visibility:public"],
+)
+
py_library(
name = "expect_absl_logging_installed",
# This is a dummy rule used as a absl-py dependency in open-source.
@@ -322,6 +349,22 @@ py_library(
visibility = ["//visibility:public"],
)
+py_library(
+ name = "expect_google_auth_installed",
+ # This is a dummy rule used as a google_auth dependency in open-source.
+ # We expect google_auth to already be installed on the system, e.g., via
+ # `pip install google-auth`.
+ visibility = ["//visibility:public"],
+)
+
+py_library(
+ name = "expect_google_auth_oauthlib_installed",
+    # This is a dummy rule used as a google_auth_oauthlib dependency in open-source.
+ # We expect google_auth_oauthlib to already be installed on the system, e.g., via
+ # `pip install google-auth-oauthlib`.
+ visibility = ["//visibility:public"],
+)
+
py_library(
name = "expect_pkg_resources_installed",
# This is a dummy rule used as a pkg-resources dependency in open-source.
diff --git a/tensorboard/defs/BUILD b/tensorboard/defs/BUILD
index f219f54c9c..368b2908ef 100644
--- a/tensorboard/defs/BUILD
+++ b/tensorboard/defs/BUILD
@@ -1,5 +1,7 @@
package(default_visibility = ["//tensorboard:internal"])
+load("//tensorboard/defs:protos.bzl", "tb_proto_library")
+
licenses(["notice"]) # Apache 2.0
filegroup(
@@ -13,4 +15,33 @@ filegroup(
visibility = ["//visibility:public"],
)
+py_test(
+ name = "tb_proto_library_test",
+ srcs = ["tb_proto_library_test.py"],
+ deps = [
+ ":test_base_py_pb2",
+ ":test_base_py_pb2_grpc",
+ ":test_downstream_py_pb2",
+ ":test_downstream_py_pb2_grpc",
+ "//tensorboard:test",
+ ],
+)
+
+tb_proto_library(
+ name = "test_base",
+ srcs = ["test_base.proto"],
+ has_services = True,
+ testonly = True,
+)
+
+tb_proto_library(
+ name = "test_downstream",
+ srcs = ["test_downstream.proto"],
+ deps = [
+ ":test_base",
+ ],
+ has_services = True,
+ testonly = True,
+)
+
exports_files(["web_test_python_stub.template.py"])
diff --git a/tensorboard/defs/protos.bzl b/tensorboard/defs/protos.bzl
index 58e8f27c15..3c18bc1000 100644
--- a/tensorboard/defs/protos.bzl
+++ b/tensorboard/defs/protos.bzl
@@ -12,16 +12,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-load("@com_google_protobuf//:protobuf.bzl", "py_proto_library")
+load("@com_google_protobuf//:protobuf.bzl", "proto_gen")
-def tb_proto_library(name, srcs=None, visibility=None, testonly=None):
- py_proto_library(
- name = name + "_py_pb2",
- srcs = srcs,
- srcs_version = "PY2AND3",
- deps = ["@com_google_protobuf//:protobuf_python"],
- protoc = "@com_google_protobuf//:protoc",
- visibility = visibility,
- default_runtime = "@com_google_protobuf//:protobuf_python",
- testonly = testonly,
- )
+def tb_proto_library(
+ name,
+ srcs = None,
+ deps = [],
+ visibility = None,
+ testonly = None,
+ has_services = False):
+ outs_proto = _PyOuts(srcs, grpc = False)
+ outs_grpc = _PyOuts(srcs, grpc = True) if has_services else []
+ outs_all = outs_proto + outs_grpc
+
+ runtime = "@com_google_protobuf//:protobuf_python"
+
+ proto_gen(
+ name = name + "_genproto",
+ srcs = srcs,
+ deps = [s + "_genproto" for s in deps] + [runtime + "_genproto"],
+ includes = [],
+ protoc = "@com_google_protobuf//:protoc",
+ gen_py = True,
+ outs = outs_all,
+ visibility = ["//visibility:public"],
+ plugin = "//external:grpc_python_plugin" if has_services else None,
+ plugin_language = "grpc",
+ )
+
+ py_deps = [s + "_py_pb2" for s in deps] + [runtime]
+ native.py_library(
+ name = name + "_py_pb2",
+ srcs = outs_proto,
+ imports = [],
+ srcs_version = "PY2AND3",
+ deps = py_deps,
+ testonly = testonly,
+ visibility = visibility,
+ )
+ if has_services:
+ native.py_library(
+ name = name + "_py_pb2_grpc",
+ srcs = outs_grpc,
+ imports = [],
+ srcs_version = "PY2AND3",
+ deps = [name + "_py_pb2"] + py_deps,
+ testonly = testonly,
+ visibility = visibility,
+ )
+
+def _PyOuts(srcs, grpc):
+ # Adapted from @com_google_protobuf//:protobuf.bzl.
+ ext = "_pb2.py" if not grpc else "_pb2_grpc.py"
+ return [s[:-len(".proto")] + ext for s in srcs]
diff --git a/tensorboard/defs/tb_proto_library_test.py b/tensorboard/defs/tb_proto_library_test.py
new file mode 100644
index 0000000000..425e1d06f9
--- /dev/null
+++ b/tensorboard/defs/tb_proto_library_test.py
@@ -0,0 +1,44 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the `tb_proto_library` build macro."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorboard import test as tb_test
+from tensorboard.defs import test_base_pb2
+from tensorboard.defs import test_base_pb2_grpc
+from tensorboard.defs import test_downstream_pb2
+from tensorboard.defs import test_downstream_pb2_grpc
+
+
+class TbProtoLibraryTest(tb_test.TestCase):
+ """Tests for `tb_proto_library`."""
+
+ def tests_with_deps(self):
+ foo = test_base_pb2.Foo()
+ foo.foo = 1
+ bar = test_downstream_pb2.Bar()
+ bar.foo.foo = 1
+ self.assertEqual(foo, bar.foo)
+
+ def test_service_deps(self):
+ self.assertIsInstance(test_base_pb2_grpc.FooServiceServicer, type)
+ self.assertIsInstance(test_downstream_pb2_grpc.BarServiceServicer, type)
+
+
+if __name__ == "__main__":
+ tb_test.main()
diff --git a/tensorboard/defs/test_base.proto b/tensorboard/defs/test_base.proto
new file mode 100644
index 0000000000..6795667916
--- /dev/null
+++ b/tensorboard/defs/test_base.proto
@@ -0,0 +1,33 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+message Foo {
+ int32 foo = 1;
+}
+
+service FooService {
+ // Loads some objects.
+ rpc GetFoo(GetFooRequest) returns (GetFooResponse);
+}
+
+message GetFooRequest {
+ int32 count = 1;
+}
+
+message GetFooResponse {
+ repeated Foo foo = 1;
+}
diff --git a/tensorboard/defs/test_downstream.proto b/tensorboard/defs/test_downstream.proto
new file mode 100644
index 0000000000..0ff23a0b82
--- /dev/null
+++ b/tensorboard/defs/test_downstream.proto
@@ -0,0 +1,36 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+import "tensorboard/defs/test_base.proto";
+
+message Bar {
+ Foo foo = 1;
+ int32 bar = 2;
+}
+
+service BarService {
+ // Loads some objects.
+ rpc GetBar(GetBarRequest) returns (GetBarResponse);
+}
+
+message GetBarRequest {
+ int32 count = 1;
+}
+
+message GetBarResponse {
+ repeated Bar bar = 1;
+}
diff --git a/tensorboard/main.py b/tensorboard/main.py
index da0521a852..5ce413cfe1 100644
--- a/tensorboard/main.py
+++ b/tensorboard/main.py
@@ -41,6 +41,7 @@
from tensorboard import program
from tensorboard.compat import tf
from tensorboard.plugins import base_plugin
+from tensorboard.uploader import uploader_main
from tensorboard.util import tb_logging
@@ -56,7 +57,8 @@ def run_main():
tensorboard = program.TensorBoard(
default.get_plugins() + default.get_dynamic_plugins(),
- program.get_default_assets_zip_provider())
+ program.get_default_assets_zip_provider(),
+ subcommands=[uploader_main.UploaderSubcommand()])
try:
from absl import app
# Import this to check that app.run() will accept the flags_parser argument.
diff --git a/tensorboard/pip_package/setup.py b/tensorboard/pip_package/setup.py
index 22141d92c8..a33543f872 100644
--- a/tensorboard/pip_package/setup.py
+++ b/tensorboard/pip_package/setup.py
@@ -26,7 +26,9 @@
'absl-py >= 0.4',
# futures is a backport of the python 3.2+ concurrent.futures module
'futures >= 3.1.1; python_version < "3"',
- 'grpcio >= 1.6.3',
+ 'grpcio >= 1.24.3',
+ 'google-auth >= 1.6.3, < 2',
+ 'google-auth-oauthlib >= 0.4.1, < 0.5',
'markdown >= 2.6.8',
'numpy >= 1.12.0',
'protobuf >= 3.6.0',
diff --git a/tensorboard/program.py b/tensorboard/program.py
index 2bb34581bb..d6069802f6 100644
--- a/tensorboard/program.py
+++ b/tensorboard/program.py
@@ -56,6 +56,7 @@
from tensorboard.backend.event_processing import event_file_inspector as efi
from tensorboard.plugins import base_plugin
from tensorboard.plugins.core import core_plugin
+from tensorboard.util import argparse_util
from tensorboard.util import tb_logging
try:
@@ -68,6 +69,11 @@
logger = tb_logging.get_logger()
+# Default subcommand name. This is part of the user-facing CLI and should not change.
+_SERVE_SUBCOMMAND_NAME = 'serve'
+# Internal flag name used to store which subcommand was invoked.
+_SUBCOMMAND_FLAG = '__tensorboard_subcommand'
+
def setup_environment():
"""Makes recommended modifications to the environment.
@@ -111,10 +117,13 @@ class TensorBoard(object):
cache_key: As `manager.cache_key`; set by the configure() method.
"""
- def __init__(self,
- plugins=None,
- assets_zip_provider=None,
- server_class=None):
+ def __init__(
+ self,
+ plugins=None,
+ assets_zip_provider=None,
+ server_class=None,
+ subcommands=None,
+ ):
"""Creates new instance.
Args:
@@ -133,9 +142,17 @@ def __init__(self,
assets_zip_provider = get_default_assets_zip_provider()
if server_class is None:
server_class = create_port_scanning_werkzeug_server
+ if subcommands is None:
+ subcommands = []
self.plugin_loaders = [application.make_plugin_loader(p) for p in plugins]
self.assets_zip_provider = assets_zip_provider
self.server_class = server_class
+ self.subcommands = {}
+ for subcommand in subcommands:
+ name = subcommand.name()
+ if name in self.subcommands or name == _SERVE_SUBCOMMAND_NAME:
+ raise ValueError("Duplicate subcommand name: %r" % name)
+ self.subcommands[name] = subcommand
self.flags = None
def configure(self, argv=('',), **kwargs):
@@ -159,15 +176,48 @@ def configure(self, argv=('',), **kwargs):
Raises:
ValueError: If flag values are invalid.
"""
- parser = argparse_flags.ArgumentParser(
+
+ base_parser = argparse_flags.ArgumentParser(
prog='tensorboard',
description=('TensorBoard is a suite of web applications for '
'inspecting and understanding your TensorFlow runs '
'and graphs. https://github.com/tensorflow/tensorboard '))
+ subparsers = base_parser.add_subparsers(
+ help="TensorBoard subcommand (defaults to %r)" % _SERVE_SUBCOMMAND_NAME)
+
+ serve_subparser = subparsers.add_parser(
+ _SERVE_SUBCOMMAND_NAME,
+ help='start local TensorBoard server (default subcommand)')
+ serve_subparser.set_defaults(**{_SUBCOMMAND_FLAG: _SERVE_SUBCOMMAND_NAME})
+
+ if len(argv) < 2 or argv[1].startswith('-'):
+ # This invocation, if valid, must not use any subcommands: we
+ # don't permit flags before the subcommand name.
+ serve_parser = base_parser
+ else:
+ # This invocation, if valid, must use a subcommand: we don't take
+ # any positional arguments to `serve`.
+ serve_parser = serve_subparser
+
+ for (name, subcommand) in six.iteritems(self.subcommands):
+ subparser = subparsers.add_parser(
+ name, help=subcommand.help(), description=subcommand.description())
+ subparser.set_defaults(**{_SUBCOMMAND_FLAG: name})
+ subcommand.define_flags(subparser)
+
for loader in self.plugin_loaders:
- loader.define_flags(parser)
+ loader.define_flags(serve_parser)
+
arg0 = argv[0] if argv else ''
- flags = parser.parse_args(argv[1:]) # Strip binary name from argv.
+
+ with argparse_util.allow_missing_subcommand():
+ flags = base_parser.parse_args(argv[1:]) # Strip binary name from argv.
+ if getattr(flags, _SUBCOMMAND_FLAG, None) is None:
+ # Manually assign default value rather than using `set_defaults`
+ # on the base parser to work around Python bug #9351 on old
+ # versions of `argparse`:
+ setattr(flags, _SUBCOMMAND_FLAG, _SERVE_SUBCOMMAND_NAME)
+
self.cache_key = manager.cache_key(
working_directory=os.getcwd(),
arguments=argv[1:],
@@ -185,8 +235,9 @@ def configure(self, argv=('',), **kwargs):
if not hasattr(flags, k):
raise ValueError('Unknown TensorBoard flag: %s' % k)
setattr(flags, k, v)
- for loader in self.plugin_loaders:
- loader.fix_flags(flags)
+ if getattr(flags, _SUBCOMMAND_FLAG) == _SERVE_SUBCOMMAND_NAME:
+ for loader in self.plugin_loaders:
+ loader.fix_flags(flags)
self.flags = flags
return [arg0]
@@ -208,14 +259,24 @@ def main(self, ignored_argv=('',)):
:rtype: int
"""
self._install_signal_handler(signal.SIGTERM, "SIGTERM")
- if self.flags.inspect:
- logger.info('Not bringing up TensorBoard, but inspecting event files.')
- event_file = os.path.expanduser(self.flags.event_file)
- efi.inspect(self.flags.logdir, event_file, self.flags.tag)
- return 0
- if self.flags.version_tb:
+ subcommand_name = getattr(self.flags, _SUBCOMMAND_FLAG)
+ if subcommand_name == _SERVE_SUBCOMMAND_NAME:
+ runner = self._run_serve_subcommand
+ else:
+ runner = self.subcommands[subcommand_name].run
+ return runner(self.flags) or 0
+
+ def _run_serve_subcommand(self, flags):
+ # TODO(#2801): Make `--version` a flag on only the base parser, not `serve`.
+ if flags.version_tb:
print(version.VERSION)
return 0
+ if flags.inspect:
+ # TODO(@wchargin): Convert `inspect` to a normal subcommand?
+ logger.info('Not bringing up TensorBoard, but inspecting event files.')
+ event_file = os.path.expanduser(flags.event_file)
+ efi.inspect(flags.logdir, event_file, flags.tag)
+ return 0
try:
server = self._make_server()
server.print_serving_message()
@@ -300,6 +361,56 @@ def _make_server(self):
return self.server_class(app, self.flags)
+@six.add_metaclass(ABCMeta)
+class TensorBoardSubcommand(object):
+ """Experimental private API for defining subcommands to tensorboard(1)."""
+
+ @abstractmethod
+ def name(self):
+ """Name of this subcommand, as specified on the command line.
+
+ This must be unique across all subcommands.
+
+ Returns:
+ A string.
+ """
+ pass
+
+ @abstractmethod
+ def define_flags(self, parser):
+ """Configure an argument parser for this subcommand.
+
+ Flags whose names start with two underscores (e.g., `__foo`) are
+ reserved for use by the runtime and must not be defined by
+ subcommands.
+
+ Args:
+ parser: An `argparse.ArgumentParser` scoped to this subcommand,
+ which this function should mutate.
+ """
+ pass
+
+ @abstractmethod
+ def run(self, flags):
+ """Execute this subcommand with user-provided flags.
+
+ Args:
+ flags: An `argparse.Namespace` object with all defined flags.
+
+ Returns:
+ An `int` exit code, or `None` as an alias for `0`.
+ """
+ pass
+
+ def help(self):
+ """Short, one-line help text to display on `tensorboard --help`."""
+ return None
+
+ def description(self):
+ """Description to display on `tensorboard SUBCOMMAND --help`."""
+ return None
+
+
@six.add_metaclass(ABCMeta)
class TensorBoardServer(object):
"""Class for customizing TensorBoard WSGI app serving."""
diff --git a/tensorboard/program_test.py b/tensorboard/program_test.py
index 919fa3d493..7a4d919072 100644
--- a/tensorboard/program_test.py
+++ b/tensorboard/program_test.py
@@ -19,9 +19,17 @@
from __future__ import print_function
import argparse
+import contextlib
+import sys
import six
+try:
+ # python version >= 3.3
+ from unittest import mock # pylint: disable=g-import-not-at-top
+except ImportError:
+ import mock # pylint: disable=g-import-not-at-top,unused-import
+
from tensorboard import program
from tensorboard import test as tb_test
from tensorboard.plugins import base_plugin
@@ -120,5 +128,135 @@ def testSpecifiedHost(self):
self.assertTrue(one_passed) # We expect either IPv4 or IPv6 to be supported
+class SubcommandTest(tb_test.TestCase):
+
+ def setUp(self):
+ super(SubcommandTest, self).setUp()
+ self.stderr = six.StringIO()
+ patchers = [
+ mock.patch.object(program.TensorBoard, '_install_signal_handler'),
+ mock.patch.object(program.TensorBoard, '_run_serve_subcommand'),
+ mock.patch.object(_TestSubcommand, 'run'),
+ mock.patch.object(sys, 'stderr', self.stderr),
+ ]
+ for p in patchers:
+ p.start()
+ self.addCleanup(p.stop)
+ _TestSubcommand.run.return_value = None
+
+ def tearDown(self):
+ stderr = self.stderr.getvalue()
+ if stderr:
+ # In case of failing tests, let there be debug info.
+ print('Stderr:\n%s' % stderr)
+
+ def testImplicitServe(self):
+ tb = program.TensorBoard(
+ plugins=[core_plugin.CorePluginLoader],
+ subcommands=[_TestSubcommand(lambda parser: None)],
+ )
+ tb.configure(('tb', '--logdir', 'logs', '--path_prefix', '/x/'))
+ tb.main()
+ program.TensorBoard._run_serve_subcommand.assert_called_once()
+ flags = program.TensorBoard._run_serve_subcommand.call_args[0][0]
+ self.assertEqual(flags.logdir, 'logs')
+ self.assertEqual(flags.path_prefix, '/x') # fixed by core_plugin
+
+ def testExplicitServe(self):
+ tb = program.TensorBoard(
+ plugins=[core_plugin.CorePluginLoader],
+ subcommands=[_TestSubcommand()],
+ )
+ tb.configure(('tb', 'serve', '--logdir', 'logs', '--path_prefix', '/x/'))
+ tb.main()
+ program.TensorBoard._run_serve_subcommand.assert_called_once()
+ flags = program.TensorBoard._run_serve_subcommand.call_args[0][0]
+ self.assertEqual(flags.logdir, 'logs')
+ self.assertEqual(flags.path_prefix, '/x') # fixed by core_plugin
+
+ def testSubcommand(self):
+ def define_flags(parser):
+ parser.add_argument('--hello')
+
+ tb = program.TensorBoard(
+ plugins=[core_plugin.CorePluginLoader],
+ subcommands=[_TestSubcommand(define_flags=define_flags)],
+ )
+ tb.configure(('tb', 'test', '--hello', 'world'))
+ self.assertEqual(tb.main(), 0)
+ _TestSubcommand.run.assert_called_once()
+ flags = _TestSubcommand.run.call_args[0][0]
+ self.assertEqual(flags.hello, 'world')
+
+ def testSubcommand_ExitCode(self):
+ tb = program.TensorBoard(
+ plugins=[core_plugin.CorePluginLoader],
+ subcommands=[_TestSubcommand()],
+ )
+ _TestSubcommand.run.return_value = 77
+ tb.configure(('tb', 'test'))
+ self.assertEqual(tb.main(), 77)
+
+ def testSubcommand_DoesNotInheritBaseArgs(self):
+ tb = program.TensorBoard(
+ plugins=[core_plugin.CorePluginLoader],
+ subcommands=[_TestSubcommand()],
+ )
+ with self.assertRaises(SystemExit):
+ tb.configure(('tb', 'test', '--logdir', 'logs'))
+ self.assertIn(
+ 'unrecognized arguments: --logdir logs', self.stderr.getvalue())
+ self.stderr.truncate(0)
+
+ def testSubcommand_MayRequirePositionals(self):
+ def define_flags(parser):
+ parser.add_argument('payload')
+
+ tb = program.TensorBoard(
+ plugins=[core_plugin.CorePluginLoader],
+ subcommands=[_TestSubcommand(define_flags=define_flags)],
+ )
+ with self.assertRaises(SystemExit):
+ tb.configure(('tb', 'test'))
+ self.assertIn('required', self.stderr.getvalue())
+ self.assertIn('payload', self.stderr.getvalue())
+ self.stderr.truncate(0)
+
+ def testConflictingNames_AmongSubcommands(self):
+ with self.assertRaises(ValueError) as cm:
+ tb = program.TensorBoard(
+ plugins=[core_plugin.CorePluginLoader],
+ subcommands=[_TestSubcommand(), _TestSubcommand()],
+ )
+ self.assertIn('Duplicate subcommand name:', str(cm.exception))
+ self.assertIn('test', str(cm.exception))
+
+ def testConflictingNames_WithServe(self):
+ with self.assertRaises(ValueError) as cm:
+ tb = program.TensorBoard(
+ plugins=[core_plugin.CorePluginLoader],
+ subcommands=[_TestSubcommand(name='serve')],
+ )
+ self.assertIn('Duplicate subcommand name:', str(cm.exception))
+ self.assertIn('serve', str(cm.exception))
+
+
+class _TestSubcommand(program.TensorBoardSubcommand):
+
+ def __init__(self, name=None, define_flags=None):
+ self._name = name
+ self._define_flags = define_flags
+
+ def name(self):
+ return self._name or 'test'
+
+ def define_flags(self, parser):
+ if self._define_flags:
+ self._define_flags(parser)
+
+ def run(self, flags):
+ pass
+
+
if __name__ == '__main__':
tb_test.main()
diff --git a/tensorboard/tools/whitespace_hygiene_test.py b/tensorboard/tools/whitespace_hygiene_test.py
index 7a46a300a9..3e127811af 100755
--- a/tensorboard/tools/whitespace_hygiene_test.py
+++ b/tensorboard/tools/whitespace_hygiene_test.py
@@ -28,8 +28,10 @@
import sys
-# Remove files from this list as whitespace errors are fixed.
exceptions = frozenset([
+ # End-of-line whitespace is semantic in patch files when a line
+ # contains a single space.
+ "third_party/mock_call_assertions.patch",
])
diff --git a/tensorboard/uploader/BUILD b/tensorboard/uploader/BUILD
new file mode 100644
index 0000000000..a0b25720ba
--- /dev/null
+++ b/tensorboard/uploader/BUILD
@@ -0,0 +1,203 @@
+# Description:
+# Uploader for TensorBoard.dev
+
+package(default_visibility = ["//tensorboard:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "dev_creds",
+ srcs = ["dev_creds.py"],
+)
+
+py_library(
+ name = "exporter_lib",
+ srcs = ["exporter.py"],
+ deps = [
+ ":util",
+ "//tensorboard:expect_grpc_installed",
+ "//tensorboard/uploader/proto:protos_all_py_pb2",
+ "//tensorboard/util:grpc_util",
+ ],
+)
+
+py_test(
+ name = "exporter_test",
+ srcs = ["exporter_test.py"],
+ deps = [
+ ":exporter_lib",
+ ":test_util",
+ "//tensorboard:expect_grpc_installed",
+ "//tensorboard:expect_grpc_testing_installed",
+ "//tensorboard:test",
+ "//tensorboard/compat/proto:protos_all_py_pb2",
+ "//tensorboard/uploader/proto:protos_all_py_pb2",
+ "//tensorboard/uploader/proto:protos_all_py_pb2_grpc",
+ "//tensorboard/util:grpc_util",
+ "@org_pythonhosted_mock",
+ ],
+)
+
+py_binary(
+ name = "uploader",
+ srcs = ["uploader_main.py"],
+ main = "uploader_main.py",
+ python_version = "PY2",
+ deps = [":uploader_main_lib"],
+)
+
+py_library(
+ name = "uploader_main_lib",
+ srcs = ["uploader_main.py"],
+ visibility = ["//tensorboard:internal"],
+ deps = [
+ ":auth",
+ ":dev_creds",
+ ":exporter_lib",
+ ":uploader_lib",
+ "//tensorboard:expect_absl_app_installed",
+ "//tensorboard:expect_absl_flags_argparse_flags_installed",
+ "//tensorboard:expect_absl_flags_installed",
+ "//tensorboard:expect_absl_logging_installed",
+ "//tensorboard:expect_grpc_installed",
+ "//tensorboard:program",
+ "//tensorboard/plugins:base_plugin",
+ "//tensorboard/uploader/proto:protos_all_py_pb2_grpc",
+ "@org_pythonhosted_six",
+ ],
+)
+
+py_library(
+ name = "uploader_lib",
+ srcs = ["uploader.py"],
+ deps = [
+ ":logdir_loader",
+ ":peekable_iterator",
+ ":util",
+ "//tensorboard:data_compat",
+ "//tensorboard:expect_grpc_installed",
+ "//tensorboard/backend/event_processing:directory_loader",
+ "//tensorboard/backend/event_processing:event_file_loader",
+ "//tensorboard/backend/event_processing:io_wrapper",
+ "//tensorboard/plugins/scalar:metadata",
+ "//tensorboard/uploader/proto:protos_all_py_pb2",
+ "//tensorboard/util:grpc_util",
+ "//tensorboard/util:tb_logging",
+ "//tensorboard/util:tensor_util",
+ "@org_pythonhosted_six",
+ ],
+)
+
+py_test(
+ name = "uploader_test",
+ srcs = ["uploader_test.py"],
+ deps = [
+ ":test_util",
+ ":uploader_lib",
+ ":util",
+ "//tensorboard:expect_grpc_installed",
+ "//tensorboard:expect_grpc_testing_installed",
+ "//tensorboard:expect_tensorflow_installed",
+ "//tensorboard/compat/proto:protos_all_py_pb2",
+ "//tensorboard/plugins/histogram:summary_v2",
+ "//tensorboard/plugins/scalar:summary_v2",
+ "//tensorboard/summary:summary_v1",
+ "//tensorboard/uploader/proto:protos_all_py_pb2",
+ "//tensorboard/uploader/proto:protos_all_py_pb2_grpc",
+ "//tensorboard/util:test_util",
+ "@org_pythonhosted_mock",
+ ],
+)
+
+py_library(
+ name = "auth",
+ srcs = ["auth.py"],
+ deps = [
+ ":util",
+ "//tensorboard:expect_google_auth_installed",
+ "//tensorboard:expect_google_auth_oauthlib_installed",
+ "//tensorboard:expect_grpc_installed",
+ "//tensorboard/util:tb_logging",
+ ],
+)
+
+py_test(
+ name = "auth_test",
+ srcs = ["auth_test.py"],
+ deps = [
+ ":auth",
+ "//tensorboard:expect_google_auth_installed",
+ "//tensorboard:test",
+ "@org_pythonhosted_mock",
+ ],
+)
+
+py_library(
+ name = "logdir_loader",
+ srcs = ["logdir_loader.py"],
+ deps = [
+ "//tensorboard/backend/event_processing:directory_watcher",
+ "//tensorboard/backend/event_processing:io_wrapper",
+ "//tensorboard/util:tb_logging",
+ ],
+)
+
+py_test(
+ name = "logdir_loader_test",
+ srcs = ["logdir_loader_test.py"],
+ deps = [
+ ":logdir_loader",
+ "//tensorboard:test",
+ "//tensorboard/backend/event_processing:directory_loader",
+ "//tensorboard/backend/event_processing:event_file_loader",
+ "//tensorboard/backend/event_processing:io_wrapper",
+ "//tensorboard/util:test_util",
+ "@org_pythonhosted_six",
+ ],
+)
+
+py_library(
+ name = "test_util",
+ testonly = 1,
+ srcs = ["test_util.py"],
+ deps = [
+ "//tensorboard:expect_grpc_installed",
+ "//tensorboard/compat/proto:protos_all_py_pb2",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "util",
+ srcs = ["util.py"],
+)
+
+py_test(
+ name = "util_test",
+ srcs = ["util_test.py"],
+ deps = [
+ ":test_util",
+ ":util",
+ "//tensorboard:test",
+ "@com_google_protobuf//:protobuf_python",
+ "@org_pythonhosted_mock",
+ ],
+)
+
+py_library(
+ name = "peekable_iterator",
+ srcs = ["peekable_iterator.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "peekable_iterator_test",
+ size = "small",
+ srcs = ["peekable_iterator_test.py"],
+ deps = [
+ ":peekable_iterator",
+ "//tensorboard:test",
+ ],
+)
diff --git a/tensorboard/uploader/auth.py b/tensorboard/uploader/auth.py
new file mode 100644
index 0000000000..a435c6ff5a
--- /dev/null
+++ b/tensorboard/uploader/auth.py
@@ -0,0 +1,227 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Lint as: python3
+"""Provides authentication support for TensorBoardUploader."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import errno
+import json
+import os
+import sys
+import webbrowser
+
+import google_auth_oauthlib.flow
+import grpc
+import google.auth
+import google.auth.transport.requests
+import google.oauth2.credentials
+
+from tensorboard.uploader import util
+from tensorboard.util import tb_logging
+
+
+logger = tb_logging.get_logger()
+
+
+# OAuth2 scopes used for OpenID Connect:
+# https://developers.google.com/identity/protocols/OpenIDConnect#scope-param
+OPENID_CONNECT_SCOPES = (
+ "openid",
+ "https://www.googleapis.com/auth/userinfo.email",
+)
+
+
+# The client "secret" is public by design for installed apps. See
+# https://developers.google.com/identity/protocols/OAuth2?csw=1#installed
+OAUTH_CLIENT_CONFIG = u"""
+{
+ "installed": {
+ "client_id": "373649185512-8v619h5kft38l4456nm2dj4ubeqsrvh6.apps.googleusercontent.com",
+ "project_id": "hosted-tensorboard-prod",
+ "auth_uri": "https://accounts.google.com/o/oauth2/auth",
+ "token_uri": "https://oauth2.googleapis.com/token",
+ "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
+ "client_secret": "pOyAuU2yq2arsM98Bw5hwYtr",
+ "redirect_uris": [
+ "urn:ietf:wg:oauth:2.0:oob",
+ "http://localhost"
+ ]
+ }
+}
+"""
+
+
+# Components of the relative path (within the user settings directory) at which
+# to store TensorBoard uploader credentials.
+TENSORBOARD_CREDENTIALS_FILEPATH_PARTS = [
+ "tensorboard", "credentials", "uploader-creds.json"]
+
+
+class CredentialsStore(object):
+ """Private file store for a `google.oauth2.credentials.Credentials`."""
+
+ _DEFAULT_CONFIG_DIRECTORY = object() # Sentinel value.
+
+ def __init__(self, user_config_directory=_DEFAULT_CONFIG_DIRECTORY):
+ """Creates a CredentialsStore.
+
+ Args:
+ user_config_directory: Optional absolute path to the root directory for
+ storing user configs, under which to store the credentials file. If not
+ set, defaults to a platform-specific location. If set to None, the
+ store is disabled (reads return None; write and clear are no-ops).
+ """
+ if user_config_directory is CredentialsStore._DEFAULT_CONFIG_DIRECTORY:
+ user_config_directory = util.get_user_config_directory()
+ if user_config_directory is None:
+ logger.warning(
+ "Credentials caching disabled - no private config directory found")
+ if user_config_directory is None:
+ self._credentials_filepath = None
+ else:
+ self._credentials_filepath = os.path.join(
+ user_config_directory, *TENSORBOARD_CREDENTIALS_FILEPATH_PARTS)
+
+ def read_credentials(self):
+ """Returns the current `google.oauth2.credentials.Credentials`, or None."""
+ if self._credentials_filepath is None:
+ return None
+ if os.path.exists(self._credentials_filepath):
+ return google.oauth2.credentials.Credentials.from_authorized_user_file(
+ self._credentials_filepath)
+ return None
+
+ def write_credentials(self, credentials):
+ """Writes a `google.oauth2.credentials.Credentials` to the store."""
+ if not isinstance(credentials, google.oauth2.credentials.Credentials):
+ raise TypeError("Cannot write credentials of type %s" % type(credentials))
+ if self._credentials_filepath is None:
+ return
+ # Make the credential file private if not on Windows; on Windows we rely on
+ # the default user config settings directory being private since we don't
+ # have a straightforward way to make an individual file private.
+ private = os.name != "nt"
+ util.make_file_with_directories(self._credentials_filepath, private=private)
+ data = {
+ "refresh_token": credentials.refresh_token,
+ "token_uri": credentials.token_uri,
+ "client_id": credentials.client_id,
+ "client_secret": credentials.client_secret,
+ "scopes": credentials.scopes,
+ "type": "authorized_user",
+ }
+ with open(self._credentials_filepath, "w") as f:
+ json.dump(data, f)
+
+ def clear(self):
+ """Clears the store of any persisted credentials information."""
+ if self._credentials_filepath is None:
+ return
+ try:
+ os.remove(self._credentials_filepath)
+ except OSError as e:
+ if e.errno != errno.ENOENT:
+ raise
+
+
+def build_installed_app_flow(client_config):
+ """Returns a `CustomInstalledAppFlow` for the given config.
+
+ Args:
+ client_config (Mapping[str, Any]): The client configuration in the Google
+ client secrets format.
+
+ Returns:
+ CustomInstalledAppFlow: the constructed flow.
+ """
+ return CustomInstalledAppFlow.from_client_config(
+ client_config, scopes=OPENID_CONNECT_SCOPES)
+
+
+class CustomInstalledAppFlow(google_auth_oauthlib.flow.InstalledAppFlow):
+ """Customized version of the Installed App OAuth2 flow."""
+
+ def run(self, force_console=False):
+ """Run the flow using a local server if possible, otherwise the console."""
+ # TODO(b/141721828): make auto-detection smarter, especially for macOS.
+ if not force_console and os.getenv("DISPLAY"):
+ try:
+ return self.run_local_server(port=0)
+ except webbrowser.Error:
+ sys.stderr.write("Falling back to console authentication flow...\n")
+ return self.run_console()
+
+
+class IdTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin):
+ """A `gRPC AuthMetadataPlugin` that uses ID tokens.
+
+ This works like the existing `google.auth.transport.grpc.AuthMetadataPlugin`
+ except that instead of always using access tokens, it preferentially uses the
+ `Credentials.id_token` property if available (and logs an error otherwise).
+
+ See http://www.grpc.io/grpc/python/grpc.html#grpc.AuthMetadataPlugin
+ """
+
+ def __init__(self, credentials, request):
+ """Constructs an IdTokenAuthMetadataPlugin.
+
+ Args:
+ credentials (google.auth.credentials.Credentials): The credentials to
+ add to requests.
+ request (google.auth.transport.Request): A HTTP transport request object
+ used to refresh credentials as needed.
+ """
+ super(IdTokenAuthMetadataPlugin, self).__init__()
+ if not isinstance(credentials, google.oauth2.credentials.Credentials):
+ raise TypeError(
+ "Cannot get ID tokens from credentials type %s" % type(credentials))
+ self._credentials = credentials
+ self._request = request
+
+ def __call__(self, context, callback):
+ """Passes authorization metadata into the given callback.
+
+ Args:
+ context (grpc.AuthMetadataContext): The RPC context.
+ callback (grpc.AuthMetadataPluginCallback): The callback that will
+ be invoked to pass in the authorization metadata.
+ """
+ headers = {}
+ self._credentials.before_request(
+ self._request, context.method_name, context.service_url, headers)
+ id_token = getattr(self._credentials, "id_token", None)
+ if id_token:
+ self._credentials.apply(headers, token=id_token)
+ else:
+ logger.error("Failed to find ID token credentials")
+ # Pass headers as key-value pairs to match CallCredentials metadata.
+ callback(list(headers.items()), None)
+
+
+def id_token_call_credentials(credentials):
+ """Constructs `grpc.CallCredentials` using `google.auth.Credentials.id_token`.
+
+ Args:
+ credentials (google.auth.credentials.Credentials): The credentials to use.
+
+ Returns:
+ grpc.CallCredentials: The call credentials.
+ """
+ request = google.auth.transport.requests.Request()
+ return grpc.metadata_call_credentials(
+ IdTokenAuthMetadataPlugin(credentials, request))
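
Taken together, the pieces of `auth.py` support a flow like: read cached credentials from the `CredentialsStore`, fall back to the installed-app OAuth flow when none are cached, then wrap the result in ID-token call credentials for gRPC. A rough sketch under those assumptions (the endpoint below is a placeholder; the actual CLI wiring lives in `uploader_main.py`, which is not shown in this diff):

    import json

    import grpc

    from tensorboard.uploader import auth

    store = auth.CredentialsStore()
    credentials = store.read_credentials()
    if credentials is None:
      client_config = json.loads(auth.OAUTH_CLIENT_CONFIG)
      flow = auth.build_installed_app_flow(client_config)
      credentials = flow.run(force_console=False)  # browser flow if DISPLAY is set
      store.write_credentials(credentials)  # cache for subsequent runs

    # Pair transport-level TLS with per-call ID-token credentials.
    channel_creds = grpc.composite_channel_credentials(
        grpc.ssl_channel_credentials(),
        auth.id_token_call_credentials(credentials))
    channel = grpc.secure_channel("example.invalid:443", channel_creds)  # placeholder endpoint
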
diff --git a/tensorboard/uploader/auth_test.py b/tensorboard/uploader/auth_test.py
new file mode 100644
index 0000000000..587109f625
--- /dev/null
+++ b/tensorboard/uploader/auth_test.py
@@ -0,0 +1,140 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Lint as: python3
+"""Tests for tensorboard.uploader.auth."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+
+import google.auth.credentials
+import google.oauth2.credentials
+
+from tensorboard.uploader import auth
+from tensorboard import test as tb_test
+
+
+class CredentialsStoreTest(tb_test.TestCase):
+
+ def test_no_config_dir(self):
+ store = auth.CredentialsStore(user_config_directory=None)
+ self.assertIsNone(store.read_credentials())
+ creds = google.oauth2.credentials.Credentials(token=None)
+ store.write_credentials(creds)
+ store.clear()
+
+ def test_clear_existent_file(self):
+ root = self.get_temp_dir()
+ path = os.path.join(
+ root, "tensorboard", "credentials", "uploader-creds.json")
+ os.makedirs(os.path.dirname(path))
+ open(path, mode="w").close()
+ self.assertTrue(os.path.exists(path))
+ auth.CredentialsStore(user_config_directory=root).clear()
+ self.assertFalse(os.path.exists(path))
+
+ def test_clear_nonexistent_file(self):
+ root = self.get_temp_dir()
+ path = os.path.join(
+ root, "tensorboard", "credentials", "uploader-creds.json")
+ self.assertFalse(os.path.exists(path))
+ auth.CredentialsStore(user_config_directory=root).clear()
+ self.assertFalse(os.path.exists(path))
+
+ def test_write_wrong_type(self):
+ creds = google.auth.credentials.AnonymousCredentials()
+ with self.assertRaisesRegex(TypeError, "google.auth.credentials"):
+ auth.CredentialsStore(user_config_directory=None).write_credentials(creds)
+
+ def test_write_creates_private_file(self):
+ root = self.get_temp_dir()
+ auth.CredentialsStore(user_config_directory=root).write_credentials(
+ google.oauth2.credentials.Credentials(
+ token=None, refresh_token="12345"))
+ path = os.path.join(
+ root, "tensorboard", "credentials", "uploader-creds.json")
+ self.assertTrue(os.path.exists(path))
+ # Skip permissions check on Windows.
+ if os.name != "nt":
+ self.assertEqual(0o600, os.stat(path).st_mode & 0o777)
+ with open(path) as f:
+ contents = json.load(f)
+ self.assertEqual("12345", contents["refresh_token"])
+
+ def test_write_overwrites_file(self):
+ root = self.get_temp_dir()
+ store = auth.CredentialsStore(user_config_directory=root)
+ # Write twice to ensure that we're overwriting correctly.
+ store.write_credentials(google.oauth2.credentials.Credentials(
+ token=None, refresh_token="12345"))
+ store.write_credentials(google.oauth2.credentials.Credentials(
+ token=None, refresh_token="67890"))
+ path = os.path.join(
+ root, "tensorboard", "credentials", "uploader-creds.json")
+ self.assertTrue(os.path.exists(path))
+ with open(path) as f:
+ contents = json.load(f)
+ self.assertEqual("67890", contents["refresh_token"])
+
+ def test_write_and_read_roundtrip(self):
+ orig_creds = google.oauth2.credentials.Credentials(
+ token="12345",
+ refresh_token="67890",
+ token_uri="https://oauth2.googleapis.com/token",
+ client_id="my-client",
+ client_secret="123abc456xyz",
+ scopes=["userinfo", "email"])
+ root = self.get_temp_dir()
+ store = auth.CredentialsStore(user_config_directory=root)
+ store.write_credentials(orig_creds)
+ creds = store.read_credentials()
+ self.assertEqual(orig_creds.refresh_token, creds.refresh_token)
+ self.assertEqual(orig_creds.token_uri, creds.token_uri)
+ self.assertEqual(orig_creds.client_id, creds.client_id)
+ self.assertEqual(orig_creds.client_secret, creds.client_secret)
+
+ def test_read_nonexistent_file(self):
+ root = self.get_temp_dir()
+ store = auth.CredentialsStore(user_config_directory=root)
+ self.assertIsNone(store.read_credentials())
+
+ def test_read_non_json_file(self):
+ root = self.get_temp_dir()
+ store = auth.CredentialsStore(user_config_directory=root)
+ path = os.path.join(
+ root, "tensorboard", "credentials", "uploader-creds.json")
+ os.makedirs(os.path.dirname(path))
+ with open(path, mode="w") as f:
+ f.write("foobar")
+ with self.assertRaises(ValueError):
+ store.read_credentials()
+
+ def test_read_invalid_json_file(self):
+ root = self.get_temp_dir()
+ store = auth.CredentialsStore(user_config_directory=root)
+ path = os.path.join(
+ root, "tensorboard", "credentials", "uploader-creds.json")
+ os.makedirs(os.path.dirname(path))
+ with open(path, mode="w") as f:
+ f.write("{}")
+ with self.assertRaises(ValueError):
+ store.read_credentials()
+
+
+if __name__ == "__main__":
+ tb_test.main()
diff --git a/tensorboard/uploader/dev_creds.py b/tensorboard/uploader/dev_creds.py
new file mode 100644
index 0000000000..fe4617e3ec
--- /dev/null
+++ b/tensorboard/uploader/dev_creds.py
@@ -0,0 +1,29 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module providing access to dev GRPC SSL credentials."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+DEV_SSL_CERT = u"""
+"""
+
+DEV_SSL_CERT_KEY = u"""
+"""
+
+DEV_OAUTH_CLIENT_CONFIG = u"""
+"""
diff --git a/tensorboard/uploader/exporter.py b/tensorboard/uploader/exporter.py
new file mode 100644
index 0000000000..e7421f9a24
--- /dev/null
+++ b/tensorboard/uploader/exporter.py
@@ -0,0 +1,211 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Downloads experiment data from TensorBoard.dev."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import base64
+import errno
+import grpc
+import json
+import os
+import string
+import time
+
+from tensorboard.uploader.proto import export_service_pb2
+from tensorboard.uploader import util
+from tensorboard.util import grpc_util
+
+# Characters that are assumed to be safe in filenames. Note that the
+# server's experiment IDs are base64 encodings of 16-byte blobs, so they
+# can theoretically collide on case-insensitive filesystems. Each
+# character has a ~3% chance of colliding, and so two random IDs have
+# about a ~10^-33 chance of colliding. As a precaution, we'll still
+# detect collision and fail fast rather than overwriting data.
+_FILENAME_SAFE_CHARS = frozenset(string.ascii_letters + string.digits + "-_")
+
+# Maximum value of a signed 64-bit integer.
+_MAX_INT64 = 2**63 - 1
+
+class TensorBoardExporter(object):
+ """Exports all of the user's experiment data from TensorBoard.dev.
+
+ Data is exported into a directory, with one file per experiment. Each
+ experiment file is a sequence of time series, represented as a stream
+ of JSON objects, one per line. Each JSON object includes a run name,
+ tag name, `tensorboard.compat.proto.summary_pb2.SummaryMetadata` proto
+ (base64-encoded, standard RFC 4648 alphabet), and set of points.
+ Points are stored in three equal-length lists of steps, wall times (as
+ seconds since epoch), and scalar values, for storage efficiency.
+
+ Such streams of JSON objects may be conveniently processed with tools
+ like jq(1).
+
+  For example, one line of an experiment file might read (when
+ pretty-printed):
+
+ {
+ "points": {
+ "steps": [0, 5],
+ "values": [4.8935227394104, 2.5438034534454346],
+ "wall_times": [1563406522.669238, 1563406523.0268838]
+ },
+ "run": "lr_1E-04,conv=1,fc=2",
+ "summary_metadata": "CgkKB3NjYWxhcnMSC3hlbnQveGVudF8x",
+ "tag": "xent/xent_1"
+ }
+
+ This is a time series with two points, both logged on 2019-07-17, one
+ about 0.36 seconds after the other.
+ """
+
+ def __init__(self, reader_service_client, output_directory):
+ """Constructs a TensorBoardExporter.
+
+ Args:
+ reader_service_client: A TensorBoardExporterService stub instance.
+ output_directory: Path to a directory into which to write data. The
+ directory must not exist, to avoid stomping existing or concurrent
+ output. Its ancestors will be created if needed.
+ """
+ self._api = reader_service_client
+ self._outdir = output_directory
+ parent_dir = os.path.dirname(self._outdir)
+ if parent_dir:
+ _mkdir_p(parent_dir)
+ try:
+ os.mkdir(self._outdir)
+ except OSError as e:
+      if e.errno == errno.EEXIST:
+        # Bail to avoid stomping existing output.
+        raise OutputDirectoryExistsError()
+      else:
+        # Propagate unexpected errors (e.g., permission denied).
+        raise
+
+ def export(self, read_time=None):
+ """Executes the export flow.
+
+ Args:
+ read_time: A fixed timestamp from which to export data, as float seconds
+ since epoch (like `time.time()`). Optional; defaults to the current
+ time.
+
+ Yields:
+ After each experiment is successfully downloaded, the ID of that
+ experiment, as a string.
+ """
+ if read_time is None:
+ read_time = time.time()
+ for experiment_id in self._request_experiment_ids(read_time):
+ filepath = _scalars_filepath(self._outdir, experiment_id)
+ try:
+ with _open_excl(filepath) as outfile:
+ data = self._request_scalar_data(experiment_id, read_time)
+ for block in data:
+ json.dump(block, outfile, sort_keys=True)
+ outfile.write("\n")
+ outfile.flush()
+ yield experiment_id
+ except grpc.RpcError as e:
+ if e.code() == grpc.StatusCode.CANCELLED:
+ raise GrpcTimeoutException(experiment_id)
+ else:
+ raise
+
+ def _request_experiment_ids(self, read_time):
+ """Yields all of the calling user's experiment IDs, as strings."""
+ request = export_service_pb2.StreamExperimentsRequest(limit=_MAX_INT64)
+ util.set_timestamp(request.read_timestamp, read_time)
+ stream = self._api.StreamExperiments(
+ request, metadata=grpc_util.version_metadata())
+ for response in stream:
+ for experiment_id in response.experiment_ids:
+ yield experiment_id
+
+ def _request_scalar_data(self, experiment_id, read_time):
+ """Yields JSON-serializable blocks of scalar data."""
+ request = export_service_pb2.StreamExperimentDataRequest()
+ request.experiment_id = experiment_id
+ util.set_timestamp(request.read_timestamp, read_time)
+ # No special error handling as we don't expect any errors from these
+ # calls: all experiments should exist (read consistency timestamp)
+ # and be owned by the calling user (only queried for own experiment
+ # IDs). Any non-transient errors would be internal, and we have no
+ # way to efficiently resume from transient errors because the server
+ # does not support pagination.
+ stream = self._api.StreamExperimentData(
+ request, metadata=grpc_util.version_metadata())
+ for response in stream:
+ metadata = base64.b64encode(
+ response.tag_metadata.SerializeToString()).decode("ascii")
+ wall_times = [t.ToNanoseconds() / 1e9 for t in response.points.wall_times]
+ yield {
+ u"run": response.run_name,
+ u"tag": response.tag_name,
+ u"summary_metadata": metadata,
+ u"points": {
+ u"steps": list(response.points.steps),
+ u"wall_times": wall_times,
+ u"values": list(response.points.values),
+ },
+ }
+
+
+class OutputDirectoryExistsError(ValueError):
+ pass
+
+
+class OutputFileExistsError(ValueError):
+ # Like Python 3's `__builtins__.FileExistsError`.
+ pass
+
+class GrpcTimeoutException(Exception):
+ def __init__(self, experiment_id):
+ super(GrpcTimeoutException, self).__init__(experiment_id)
+ self.experiment_id = experiment_id
+
+def _scalars_filepath(base_dir, experiment_id):
+ """Gets file path in which to store scalars for the given experiment."""
+ # Experiment IDs from the server should be filename-safe; verify
+ # this before creating any files.
+ bad_chars = frozenset(experiment_id) - _FILENAME_SAFE_CHARS
+ if bad_chars:
+ raise RuntimeError(
+ "Unexpected characters ({bad_chars!r}) in experiment ID {eid!r}".format(
+ bad_chars=sorted(bad_chars), eid=experiment_id))
+ return os.path.join(base_dir, "scalars_%s.json" % experiment_id)
+
+
+def _mkdir_p(path):
+ """Like `os.makedirs(path, exist_ok=True)`, but Python 2-compatible."""
+ try:
+ os.makedirs(path)
+ except OSError as e:
+ if e.errno != errno.EEXIST or not os.path.isdir(path):
+ raise
+
+
+def _open_excl(path):
+ """Like `open(path, "x")`, but Python 2-compatible."""
+ try:
+ # `os.O_EXCL` works on Windows as well as POSIX-compliant systems.
+ # See:
+ fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_EXCL)
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ raise OutputFileExistsError(path)
+ else:
+ raise
+ return os.fdopen(fd, "w")
diff --git a/tensorboard/uploader/exporter_test.py b/tensorboard/uploader/exporter_test.py
new file mode 100644
index 0000000000..f4e5213a86
--- /dev/null
+++ b/tensorboard/uploader/exporter_test.py
@@ -0,0 +1,388 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorboard.uploader.exporter."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import base64
+import errno
+import json
+import os
+
+import grpc
+import grpc_testing
+
+try:
+ # python version >= 3.3
+ from unittest import mock # pylint: disable=g-import-not-at-top
+except ImportError:
+ import mock # pylint: disable=g-import-not-at-top,unused-import
+
+
+from tensorboard.uploader.proto import export_service_pb2
+from tensorboard.uploader.proto import export_service_pb2_grpc
+from tensorboard.uploader import exporter as exporter_lib
+from tensorboard.uploader import test_util
+from tensorboard.util import grpc_util
+from tensorboard import test as tb_test
+from tensorboard.compat.proto import summary_pb2
+
+
+class TensorBoardExporterTest(tb_test.TestCase):
+
+ def _create_mock_api_client(self):
+ # Create a stub instance (using a test channel) in order to derive a mock
+ # from it with autospec enabled. Mocking TensorBoardExporterServiceStub
+ # itself doesn't work with autospec because grpc constructs stubs via
+ # metaclassing.
+ test_channel = grpc_testing.channel(
+ service_descriptors=[], time=grpc_testing.strict_real_time())
+ stub = export_service_pb2_grpc.TensorBoardExporterServiceStub(test_channel)
+ mock_api_client = mock.create_autospec(stub)
+ return mock_api_client
+
+ def _make_experiments_response(self, eids):
+ return export_service_pb2.StreamExperimentsResponse(experiment_ids=eids)
+
+ def test_e2e_success_case(self):
+ mock_api_client = self._create_mock_api_client()
+ mock_api_client.StreamExperiments.return_value = iter([
+ export_service_pb2.StreamExperimentsResponse(experiment_ids=["789"]),
+ ])
+
+ def stream_experiments(request, **kwargs):
+ del request # unused
+ self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
+ yield export_service_pb2.StreamExperimentsResponse(
+ experiment_ids=["123", "456"])
+ yield export_service_pb2.StreamExperimentsResponse(experiment_ids=["789"])
+
+ def stream_experiment_data(request, **kwargs):
+ self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
+ for run in ("train", "test"):
+ for tag in ("accuracy", "loss"):
+ response = export_service_pb2.StreamExperimentDataResponse()
+ response.run_name = run
+ response.tag_name = tag
+ display_name = "%s:%s" % (request.experiment_id, tag)
+ response.tag_metadata.CopyFrom(
+ test_util.scalar_metadata(display_name))
+ for step in range(10):
+ response.points.steps.append(step)
+ response.points.values.append(2.0 * step)
+ response.points.wall_times.add(
+ seconds=1571084520 + step, nanos=862939144)
+ yield response
+
+ mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
+ mock_api_client.StreamExperimentData = mock.Mock(
+ wraps=stream_experiment_data)
+
+ outdir = os.path.join(self.get_temp_dir(), "outdir")
+ exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
+ start_time = 1571084846.25
+ start_time_pb = test_util.timestamp_pb(1571084846250000000)
+
+ generator = exporter.export(read_time=start_time)
+ expected_files = []
+ self.assertTrue(os.path.isdir(outdir))
+ self.assertCountEqual(expected_files, os.listdir(outdir))
+ mock_api_client.StreamExperiments.assert_not_called()
+ mock_api_client.StreamExperimentData.assert_not_called()
+
+ # The first iteration should request the list of experiments and
+ # data for one of them.
+ self.assertEqual(next(generator), "123")
+ expected_files.append("scalars_123.json")
+ self.assertCountEqual(expected_files, os.listdir(outdir))
+
+ expected_eids_request = export_service_pb2.StreamExperimentsRequest()
+ expected_eids_request.read_timestamp.CopyFrom(start_time_pb)
+ expected_eids_request.limit = 2**63 - 1
+ mock_api_client.StreamExperiments.assert_called_once_with(
+ expected_eids_request, metadata=grpc_util.version_metadata())
+
+ expected_data_request = export_service_pb2.StreamExperimentDataRequest()
+ expected_data_request.experiment_id = "123"
+ expected_data_request.read_timestamp.CopyFrom(start_time_pb)
+ mock_api_client.StreamExperimentData.assert_called_once_with(
+ expected_data_request, metadata=grpc_util.version_metadata())
+
+ # The next iteration should just request data for the next experiment.
+ mock_api_client.StreamExperiments.reset_mock()
+ mock_api_client.StreamExperimentData.reset_mock()
+ self.assertEqual(next(generator), "456")
+
+ expected_files.append("scalars_456.json")
+ self.assertCountEqual(expected_files, os.listdir(outdir))
+ mock_api_client.StreamExperiments.assert_not_called()
+ expected_data_request.experiment_id = "456"
+ mock_api_client.StreamExperimentData.assert_called_once_with(
+ expected_data_request, metadata=grpc_util.version_metadata())
+
+ # Again, request data for the next experiment; this experiment ID
+ # was in the second response batch in the list of IDs.
+ expected_files.append("scalars_789.json")
+ mock_api_client.StreamExperiments.reset_mock()
+ mock_api_client.StreamExperimentData.reset_mock()
+ self.assertEqual(next(generator), "789")
+
+ self.assertCountEqual(expected_files, os.listdir(outdir))
+ mock_api_client.StreamExperiments.assert_not_called()
+ expected_data_request.experiment_id = "789"
+ mock_api_client.StreamExperimentData.assert_called_once_with(
+ expected_data_request, metadata=grpc_util.version_metadata())
+
+ # The final continuation shouldn't need to send any RPCs.
+ mock_api_client.StreamExperiments.reset_mock()
+ mock_api_client.StreamExperimentData.reset_mock()
+ self.assertEqual(list(generator), [])
+
+ self.assertCountEqual(expected_files, os.listdir(outdir))
+ mock_api_client.StreamExperiments.assert_not_called()
+ mock_api_client.StreamExperimentData.assert_not_called()
+
+ # Spot-check one of the files.
+ with open(os.path.join(outdir, "scalars_456.json")) as infile:
+ jsons = [json.loads(line) for line in infile]
+ self.assertLen(jsons, 4)
+ datum = jsons[2]
+ self.assertEqual(datum.pop("run"), "test")
+ self.assertEqual(datum.pop("tag"), "accuracy")
+ summary_metadata = summary_pb2.SummaryMetadata.FromString(
+ base64.b64decode(datum.pop("summary_metadata")))
+ expected_summary_metadata = test_util.scalar_metadata("456:accuracy")
+ self.assertEqual(summary_metadata, expected_summary_metadata)
+ points = datum.pop("points")
+ expected_steps = [x for x in range(10)]
+ expected_values = [2.0 * x for x in range(10)]
+ expected_wall_times = [1571084520.862939144 + x for x in range(10)]
+ self.assertEqual(points.pop("steps"), expected_steps)
+ self.assertEqual(points.pop("values"), expected_values)
+ self.assertEqual(points.pop("wall_times"), expected_wall_times)
+ self.assertEqual(points, {})
+ self.assertEqual(datum, {})
+
+ def test_rejects_dangerous_experiment_ids(self):
+ mock_api_client = self._create_mock_api_client()
+
+ def stream_experiments(request, **kwargs):
+ del request # unused
+ yield export_service_pb2.StreamExperimentsResponse(
+ experiment_ids=["../authorized_keys"])
+
+ mock_api_client.StreamExperiments = stream_experiments
+
+ outdir = os.path.join(self.get_temp_dir(), "outdir")
+ exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
+ generator = exporter.export()
+
+ with self.assertRaises(RuntimeError) as cm:
+ next(generator)
+
+ msg = str(cm.exception)
+ self.assertIn("Unexpected characters", msg)
+ self.assertIn(repr(sorted([u".", u"/"])), msg)
+ self.assertIn("../authorized_keys", msg)
+ mock_api_client.StreamExperimentData.assert_not_called()
+
+ def test_fails_nicely_on_stream_experiment_data_timeout(self):
+ # Setup: Client where:
+ # 1. stream_experiments will say there is one experiment_id.
+    # 2. stream_experiment_data will raise a grpc CANCELLED, as would
+    #    happen on a timeout.
+ mock_api_client = self._create_mock_api_client()
+    experiment_id = "123"
+
+ def stream_experiments(request, **kwargs):
+ del request # unused
+ yield export_service_pb2.StreamExperimentsResponse(
+ experiment_ids=[experiment_id])
+
+ def stream_experiment_data(request, **kwargs):
+ raise test_util.grpc_error(grpc.StatusCode.CANCELLED, "details string")
+
+ mock_api_client.StreamExperiments = stream_experiments
+ mock_api_client.StreamExperimentData = stream_experiment_data
+
+ outdir = os.path.join(self.get_temp_dir(), "outdir")
+ # Execute: exporter.export()
+ exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
+ generator = exporter.export()
+ # Expect: A nice exception of the right type and carrying the right
+ # experiment_id.
+ with self.assertRaises(exporter_lib.GrpcTimeoutException) as cm:
+ next(generator)
+    self.assertEqual(cm.exception.experiment_id, experiment_id)
+
+ def test_stream_experiment_data_passes_through_unexpected_exception(self):
+ # Setup: Client where:
+ # 1. stream_experiments will say there is one experiment_id.
+ # 2. stream_experiment_data will throw an internal error.
+ mock_api_client = self._create_mock_api_client()
+ experiment_id = "123"
+
+ def stream_experiments(request, **kwargs):
+ del request # unused
+ yield export_service_pb2.StreamExperimentsResponse(
+ experiment_ids=[experiment_id])
+
+ def stream_experiment_data(request, **kwargs):
+ del request # unused
+ raise test_util.grpc_error(grpc.StatusCode.INTERNAL, "details string")
+
+ mock_api_client.StreamExperiments = stream_experiments
+ mock_api_client.StreamExperimentData = stream_experiment_data
+
+ outdir = os.path.join(self.get_temp_dir(), "outdir")
+ # Execute: exporter.export().
+ exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
+ generator = exporter.export()
+ # Expect: The internal error is passed through.
+ with self.assertRaises(grpc.RpcError) as cm:
+ next(generator)
+    self.assertEqual(cm.exception.details(), "details string")
+
+ def test_handles_outdir_with_no_slash(self):
+ oldcwd = os.getcwd()
+ try:
+ os.chdir(self.get_temp_dir())
+ mock_api_client = self._create_mock_api_client()
+ mock_api_client.StreamExperiments.return_value = iter([
+ export_service_pb2.StreamExperimentsResponse(experiment_ids=["123"]),
+ ])
+ mock_api_client.StreamExperimentData.return_value = iter([
+ export_service_pb2.StreamExperimentDataResponse()
+ ])
+
+ exporter = exporter_lib.TensorBoardExporter(mock_api_client, "outdir")
+ generator = exporter.export()
+ self.assertEqual(list(generator), ["123"])
+ self.assertTrue(os.path.isdir("outdir"))
+ finally:
+ os.chdir(oldcwd)
+
+ def test_rejects_existing_directory(self):
+ mock_api_client = self._create_mock_api_client()
+ outdir = os.path.join(self.get_temp_dir(), "outdir")
+ os.mkdir(outdir)
+ with open(os.path.join(outdir, "scalars_999.json"), "w"):
+ pass
+
+ with self.assertRaises(exporter_lib.OutputDirectoryExistsError):
+ exporter_lib.TensorBoardExporter(mock_api_client, outdir)
+
+ mock_api_client.StreamExperiments.assert_not_called()
+ mock_api_client.StreamExperimentData.assert_not_called()
+
+ def test_rejects_existing_file(self):
+ mock_api_client = self._create_mock_api_client()
+
+ def stream_experiments(request, **kwargs):
+ del request # unused
+ yield export_service_pb2.StreamExperimentsResponse(experiment_ids=["123"])
+
+ mock_api_client.StreamExperiments = stream_experiments
+
+ outdir = os.path.join(self.get_temp_dir(), "outdir")
+ exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
+ generator = exporter.export()
+
+ with open(os.path.join(outdir, "scalars_123.json"), "w"):
+ pass
+
+ with self.assertRaises(exporter_lib.OutputFileExistsError):
+ next(generator)
+
+ mock_api_client.StreamExperimentData.assert_not_called()
+
+ def test_propagates_mkdir_errors(self):
+ mock_api_client = self._create_mock_api_client()
+ outdir = os.path.join(self.get_temp_dir(), "some_file", "outdir")
+ with open(os.path.join(self.get_temp_dir(), "some_file"), "w"):
+ pass
+
+ with self.assertRaises(OSError):
+ exporter_lib.TensorBoardExporter(mock_api_client, outdir)
+
+ mock_api_client.StreamExperiments.assert_not_called()
+ mock_api_client.StreamExperimentData.assert_not_called()
+
+
+class MkdirPTest(tb_test.TestCase):
+
+ def test_makes_full_chain(self):
+ path = os.path.join(self.get_temp_dir(), "a", "b", "c")
+ exporter_lib._mkdir_p(path)
+ self.assertTrue(os.path.isdir(path))
+
+ def test_makes_leaf(self):
+ base = os.path.join(self.get_temp_dir(), "a", "b")
+ exporter_lib._mkdir_p(base)
+ leaf = os.path.join(self.get_temp_dir(), "a", "b", "c")
+ exporter_lib._mkdir_p(leaf)
+ self.assertTrue(os.path.isdir(leaf))
+
+ def test_fails_when_path_is_a_normal_file(self):
+ path = os.path.join(self.get_temp_dir(), "somefile")
+ with open(path, "w"):
+ pass
+ with self.assertRaises(OSError) as cm:
+ exporter_lib._mkdir_p(path)
+ self.assertEqual(cm.exception.errno, errno.EEXIST)
+
+ def test_propagates_other_errors(self):
+ base = os.path.join(self.get_temp_dir(), "somefile")
+ with open(base, "w"):
+ pass
+ leaf = os.path.join(self.get_temp_dir(), "somefile", "somedir")
+ with self.assertRaises(OSError) as cm:
+ exporter_lib._mkdir_p(leaf)
+ self.assertNotEqual(cm.exception.errno, errno.EEXIST)
+ if os.name == "nt":
+ expected_errno = errno.ENOENT
+ else:
+ expected_errno = errno.ENOTDIR
+ self.assertEqual(cm.exception.errno, expected_errno)
+
+
+class OpenExclTest(tb_test.TestCase):
+
+ def test_success(self):
+ path = os.path.join(self.get_temp_dir(), "test.txt")
+ with exporter_lib._open_excl(path) as outfile:
+ outfile.write("hello\n")
+ with open(path) as infile:
+ self.assertEqual(infile.read(), "hello\n")
+
+ def test_fails_when_file_exists(self):
+ path = os.path.join(self.get_temp_dir(), "test.txt")
+ with open(path, "w"):
+ pass
+ with self.assertRaises(exporter_lib.OutputFileExistsError) as cm:
+ exporter_lib._open_excl(path)
+ self.assertEqual(str(cm.exception), path)
+
+ def test_propagates_other_errors(self):
+ path = os.path.join(self.get_temp_dir(), "enoent", "test.txt")
+ with self.assertRaises(OSError) as cm:
+ exporter_lib._open_excl(path)
+ self.assertEqual(cm.exception.errno, errno.ENOENT)
+
+
+if __name__ == "__main__":
+ tb_test.main()
diff --git a/tensorboard/uploader/logdir_loader.py b/tensorboard/uploader/logdir_loader.py
new file mode 100644
index 0000000000..fc83d2428a
--- /dev/null
+++ b/tensorboard/uploader/logdir_loader.py
@@ -0,0 +1,107 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Loader for event file data for an entire TensorBoard log directory."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+
+from tensorboard.backend.event_processing import directory_watcher
+from tensorboard.backend.event_processing import io_wrapper
+from tensorboard.util import tb_logging
+
+
+logger = tb_logging.get_logger()
+
+
+class LogdirLoader(object):
+ """Loader for a root log directory, maintaining multiple DirectoryLoaders.
+
+ This class takes a root log directory and a factory for DirectoryLoaders, and
+ maintains one DirectoryLoader per "logdir subdirectory" of the root logdir.
+
+ Note that this class is not thread-safe.
+ """
+
+ def __init__(self, logdir, directory_loader_factory):
+ """Constructs a new LogdirLoader.
+
+ Args:
+ logdir: The root log directory to load from.
+ directory_loader_factory: A factory for creating DirectoryLoaders. The
+ factory should take a path and return a DirectoryLoader.
+
+ Raises:
+ ValueError: If logdir or directory_loader_factory are None.
+ """
+ if logdir is None:
+ raise ValueError('A logdir is required')
+ if directory_loader_factory is None:
+ raise ValueError('A directory loader factory is required')
+ self._logdir = logdir
+ self._directory_loader_factory = directory_loader_factory
+ # Maps run names to corresponding DirectoryLoader instances.
+ self._directory_loaders = {}
+
+ def synchronize_runs(self):
+ """Finds new runs within `logdir` and makes `DirectoryLoaders` for them.
+
+ In addition, any existing `DirectoryLoader` whose run directory no longer
+ exists will be deleted.
+ """
+ logger.info('Starting logdir traversal of %s', self._logdir)
+ runs_seen = set()
+ for subdir in io_wrapper.GetLogdirSubdirectories(self._logdir):
+ run = os.path.relpath(subdir, self._logdir)
+ runs_seen.add(run)
+ if run not in self._directory_loaders:
+ logger.info('- Adding run for relative directory %s', run)
+ self._directory_loaders[run] = self._directory_loader_factory(subdir)
+ stale_runs = set(self._directory_loaders) - runs_seen
+ if stale_runs:
+ for run in stale_runs:
+ logger.info('- Removing run for relative directory %s', run)
+ del self._directory_loaders[run]
+ logger.info('Ending logdir traversal of %s', self._logdir)
+
+ def get_run_events(self):
+ """Returns tf.Event generators for each run's `DirectoryLoader`.
+
+ Warning: the generators are stateful and consuming them will affect the
+ results of any other existing generators for that run; calling code should
+ ensure it takes events from only a single generator per run at a time.
+
+ Returns:
+ Dictionary containing an entry for each run, mapping the run name to a
+ generator yielding tf.Event protobuf objects loaded from that run.
+ """
+ runs = list(self._directory_loaders)
+ logger.info('Creating event loading generators for %d runs', len(runs))
+ run_to_loader = collections.OrderedDict()
+ for run_name in sorted(runs):
+ loader = self._directory_loaders[run_name]
+ run_to_loader[run_name] = self._wrap_loader_generator(loader.Load())
+ return run_to_loader
+
+ def _wrap_loader_generator(self, loader_generator):
+ """Wraps `DirectoryLoader` generator to swallow `DirectoryDeletedError`."""
+ try:
+ for item in loader_generator:
+ yield item
+ except directory_watcher.DirectoryDeletedError:
+ return
diff --git a/tensorboard/uploader/logdir_loader_test.py b/tensorboard/uploader/logdir_loader_test.py
new file mode 100644
index 0000000000..c3d8e09f3c
--- /dev/null
+++ b/tensorboard/uploader/logdir_loader_test.py
@@ -0,0 +1,154 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorboard.uploader.logdir_loader."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+import shutil
+import six
+
+from tensorboard.uploader import logdir_loader
+from tensorboard import test as tb_test
+from tensorboard.backend.event_processing import directory_loader
+from tensorboard.backend.event_processing import event_file_loader
+from tensorboard.backend.event_processing import io_wrapper
+from tensorboard.util import test_util
+
+
+class LogdirLoaderTest(tb_test.TestCase):
+
+ def _create_logdir_loader(self, logdir):
+ def directory_loader_factory(path):
+ return directory_loader.DirectoryLoader(
+ path,
+ event_file_loader.TimestampedEventFileLoader,
+ path_filter=io_wrapper.IsTensorFlowEventsFile)
+ return logdir_loader.LogdirLoader(logdir, directory_loader_factory)
+
+ def _extract_tags(self, event_generator):
+ """Converts a generator of tf.Events into a list of event tags."""
+ return [
+ event.summary.value[0].tag for event in event_generator
+ if not event.file_version
+ ]
+
+ def _extract_run_to_tags(self, run_to_events):
+ """Returns run-to-tags dict from run-to-event-generator dict."""
+ run_to_tags = {}
+ for run_name, event_generator in six.iteritems(run_to_events):
+ # There should be no duplicate runs.
+ self.assertNotIn(run_name, run_to_tags)
+ run_to_tags[run_name] = self._extract_tags(event_generator)
+ return run_to_tags
+
+ def test_empty_logdir(self):
+ logdir = self.get_temp_dir()
+ loader = self._create_logdir_loader(logdir)
+ # Default state is empty.
+ self.assertEmpty(list(loader.get_run_events()))
+ loader.synchronize_runs()
+ # Still empty, since there's no data.
+ self.assertEmpty(list(loader.get_run_events()))
+
+ def test_single_event_logdir(self):
+ logdir = self.get_temp_dir()
+ with test_util.FileWriter(logdir) as writer:
+ writer.add_test_summary("foo")
+ loader = self._create_logdir_loader(logdir)
+ loader.synchronize_runs()
+ self.assertEqual(
+ self._extract_run_to_tags(loader.get_run_events()), {".": ["foo"]})
+ # A second load should indicate no new data for the run.
+ self.assertEqual(
+ self._extract_run_to_tags(loader.get_run_events()), {".": []})
+
+ def test_multiple_writes_to_logdir(self):
+ logdir = self.get_temp_dir()
+ with test_util.FileWriter(os.path.join(logdir, "a")) as writer:
+ writer.add_test_summary("tag_a")
+ with test_util.FileWriter(os.path.join(logdir, "b")) as writer:
+ writer.add_test_summary("tag_b")
+ with test_util.FileWriter(os.path.join(logdir, "b", "x")) as writer:
+ writer.add_test_summary("tag_b_x")
+ writer_c = test_util.FileWriter(os.path.join(logdir, "c"))
+ writer_c.add_test_summary("tag_c")
+ writer_c.flush()
+ loader = self._create_logdir_loader(logdir)
+ loader.synchronize_runs()
+ self.assertEqual(
+ self._extract_run_to_tags(loader.get_run_events()),
+ {"a": ["tag_a"], "b": ["tag_b"], "b/x": ["tag_b_x"], "c": ["tag_c"]})
+ # A second load should indicate no new data.
+ self.assertEqual(
+ self._extract_run_to_tags(loader.get_run_events()),
+ {"a": [], "b": [], "b/x": [], "c": []})
+ # Write some new data to both new and pre-existing event files.
+ with test_util.FileWriter(
+ os.path.join(logdir, "a"), filename_suffix=".other") as writer:
+ writer.add_test_summary("tag_a_2")
+ writer.add_test_summary("tag_a_3")
+ writer.add_test_summary("tag_a_4")
+ with test_util.FileWriter(
+ os.path.join(logdir, "b", "x"), filename_suffix=".other") as writer:
+ writer.add_test_summary("tag_b_x_2")
+ with writer_c as writer:
+ writer.add_test_summary("tag_c_2")
+ # New data should appear on the next load.
+ self.assertEqual(
+ self._extract_run_to_tags(loader.get_run_events()), {
+ "a": ["tag_a_2", "tag_a_3", "tag_a_4"],
+ "b": [],
+ "b/x": ["tag_b_x_2"],
+ "c": ["tag_c_2"]
+ })
+
+ def test_directory_deletion(self):
+ logdir = self.get_temp_dir()
+ with test_util.FileWriter(os.path.join(logdir, "a")) as writer:
+ writer.add_test_summary("tag_a")
+ with test_util.FileWriter(os.path.join(logdir, "b")) as writer:
+ writer.add_test_summary("tag_b")
+ with test_util.FileWriter(os.path.join(logdir, "c")) as writer:
+ writer.add_test_summary("tag_c")
+ loader = self._create_logdir_loader(logdir)
+ loader.synchronize_runs()
+ self.assertEqual(list(loader.get_run_events().keys()), ["a", "b", "c"])
+ shutil.rmtree(os.path.join(logdir, "b"))
+ loader.synchronize_runs()
+ self.assertEqual(list(loader.get_run_events().keys()), ["a", "c"])
+ shutil.rmtree(logdir)
+ loader.synchronize_runs()
+ self.assertEmpty(loader.get_run_events())
+
+ def test_directory_deletion_during_event_loading(self):
+ logdir = self.get_temp_dir()
+ with test_util.FileWriter(logdir) as writer:
+ writer.add_test_summary("foo")
+ loader = self._create_logdir_loader(logdir)
+ loader.synchronize_runs()
+ self.assertEqual(
+ self._extract_run_to_tags(loader.get_run_events()), {".": ["foo"]})
+ shutil.rmtree(logdir)
+ runs_to_events = loader.get_run_events()
+ self.assertEqual(list(runs_to_events.keys()), ["."])
+ events = runs_to_events["."]
+ self.assertEqual(self._extract_tags(events), [])
+
+
+if __name__ == "__main__":
+ tb_test.main()
diff --git a/tensorboard/uploader/peekable_iterator.py b/tensorboard/uploader/peekable_iterator.py
new file mode 100644
index 0000000000..59dea6a3ae
--- /dev/null
+++ b/tensorboard/uploader/peekable_iterator.py
@@ -0,0 +1,87 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Iterator adapter that supports peeking ahead."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class PeekableIterator(object):
+ """Iterator adapter that supports peeking ahead.
+
+ As with most Python iterators, this is also iterable; its `__iter__`
+ returns itself.
+
+ This class is not thread-safe. Use external synchronization if
+ iterating concurrently.
+ """
+
+ def __init__(self, iterable):
+ """Initializes a peeking iterator wrapping the provided iterable.
+
+ Args:
+ iterable: An iterable to wrap.
+ """
+ self._iterator = iter(iterable)
+ self._has_peeked = False
+ self._peeked_element = None
+
+ def has_next(self):
+ """Checks whether there are any more items in this iterator.
+
+ The next call to `next` or `peek` will raise `StopIteration` if and
+ only if this method returns `False`.
+
+ Returns:
+ `True` if there are any more items in this iterator, else `False`.
+ """
+ try:
+ self.peek()
+ return True
+ except StopIteration:
+ return False
+
+ def peek(self):
+ """Gets the next item in the iterator without consuming it.
+
+ Multiple consecutive calls will return the same element.
+
+ Returns:
+ The value that would be returned by `next`.
+
+ Raises:
+ StopIteration: If there are no more items in the iterator.
+ """
+ if not self._has_peeked:
+ self._peeked_element = next(self._iterator)
+ self._has_peeked = True
+ return self._peeked_element
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self._has_peeked:
+ self._has_peeked = False
+ result = self._peeked_element
+ self._peeked_element = None # allow GC
+ return result
+ else:
+ return next(self._iterator)
+
+ def next(self):
+ # (Like `__next__`, but Python 2.)
+ return self.__next__()
diff --git a/tensorboard/uploader/peekable_iterator_test.py b/tensorboard/uploader/peekable_iterator_test.py
new file mode 100644
index 0000000000..e2c52505d0
--- /dev/null
+++ b/tensorboard/uploader/peekable_iterator_test.py
@@ -0,0 +1,68 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorboard.uploader.peekable_iterator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorboard.uploader import peekable_iterator
+from tensorboard import test as tb_test
+
+
+class PeekableIteratorTest(tb_test.TestCase):
+ """Tests for `PeekableIterator`."""
+
+ def test_empty_iteration(self):
+ it = peekable_iterator.PeekableIterator([])
+ self.assertEqual(list(it), [])
+
+ def test_normal_iteration(self):
+ it = peekable_iterator.PeekableIterator([1, 2, 3])
+ self.assertEqual(list(it), [1, 2, 3])
+
+ def test_simple_peek(self):
+ it = peekable_iterator.PeekableIterator([1, 2, 3])
+ self.assertEqual(it.peek(), 1)
+ self.assertEqual(it.peek(), 1)
+ self.assertEqual(next(it), 1)
+ self.assertEqual(it.peek(), 2)
+ self.assertEqual(next(it), 2)
+ self.assertEqual(next(it), 3)
+ self.assertEqual(list(it), [])
+
+ def test_simple_has_next(self):
+ it = peekable_iterator.PeekableIterator([1, 2])
+ self.assertTrue(it.has_next())
+ self.assertEqual(it.peek(), 1)
+ self.assertTrue(it.has_next())
+ self.assertEqual(next(it), 1)
+ self.assertEqual(it.peek(), 2)
+ self.assertTrue(it.has_next())
+ self.assertEqual(next(it), 2)
+ self.assertFalse(it.has_next())
+ self.assertFalse(it.has_next())
+
+ def test_peek_after_end(self):
+ it = peekable_iterator.PeekableIterator([1, 2, 3])
+ self.assertEqual(list(it), [1, 2, 3])
+ with self.assertRaises(StopIteration):
+ it.peek()
+ with self.assertRaises(StopIteration):
+ it.peek()
+
+
+if __name__ == "__main__":
+ tb_test.main()
diff --git a/tensorboard/uploader/proto/BUILD b/tensorboard/uploader/proto/BUILD
new file mode 100644
index 0000000000..9353b897e6
--- /dev/null
+++ b/tensorboard/uploader/proto/BUILD
@@ -0,0 +1,20 @@
+package(default_visibility = ["//tensorboard:internal"])
+
+load("//tensorboard/defs:protos.bzl", "tb_proto_library")
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+tb_proto_library(
+ name = "protos_all",
+ srcs = [
+ "export_service.proto",
+ "scalar.proto",
+ "write_service.proto",
+ ],
+ has_services = True,
+ deps = [
+ "//tensorboard/compat/proto:protos_all",
+ ],
+)
diff --git a/tensorboard/uploader/proto/export_service.proto b/tensorboard/uploader/proto/export_service.proto
new file mode 100644
index 0000000000..54ad5ab751
--- /dev/null
+++ b/tensorboard/uploader/proto/export_service.proto
@@ -0,0 +1,87 @@
+syntax = "proto3";
+
+package tensorboard.service;
+
+import "google/protobuf/timestamp.proto";
+import "tensorboard/compat/proto/summary.proto";
+
+// Service for exporting data from TensorBoard.dev.
+service TensorBoardExporterService {
+ // Stream the experiment_id of all the experiments owned by the caller.
+ rpc StreamExperiments(StreamExperimentsRequest)
+ returns (stream StreamExperimentsResponse) {}
+ // Stream scalars for all the runs and tags in an experiment.
+ rpc StreamExperimentData(StreamExperimentDataRequest)
+ returns (stream StreamExperimentDataResponse) {}
+}
+
+// Request to stream the experiment_id of all the experiments owned by the
+// caller from TensorBoard.dev.
+message StreamExperimentsRequest {
+ // Timestamp to get a consistent snapshot of the data in the database.
+ // This is useful when making multiple read RPCs and needing the data to be
+ // consistent across the read calls.
+ google.protobuf.Timestamp read_timestamp = 1;
+ // User ID defaults to the caller, but may be set to a different user for
+ // internal Takeout processes operating on behalf of a user.
+ string user_id = 2;
+  // Limits the number of experiment IDs returned. This is useful to check
+  // whether the user might have any data (by setting limit=1) and to preview
+  // the list of experiments.
+ int64 limit = 3;
+ // TODO(@karthikv2k): Support pagination.
+}
+
+// Streams experiment IDs returned from TensorBoard.dev.
+message StreamExperimentsResponse {
+ // List of experiment IDs for the experiments owned by the user. The entire
+ // list of experiments owned by the user is streamed in batches and each batch
+ // contains a list of experiment IDs. A consumer of this stream needs to
+ // concatenate all these lists to get the full response. The order of
+ // experiment IDs in the stream is not defined.
+ repeated string experiment_ids = 1;
+}
+
+// Request to stream scalars from all the runs and tags in an experiment.
+message StreamExperimentDataRequest {
+ // The permanent ID of the experiment whose data need to be streamed.
+ string experiment_id = 1;
+ // Timestamp to get a consistent snapshot of the data in the database.
+ // This is useful when making multiple read RPCs and needing the data to be
+ // consistent across the read calls. Should be the same as the read timestamp
+ // used for the corresponding `StreamExperimentsRequest` for consistency.
+ google.protobuf.Timestamp read_timestamp = 2;
+}
+
+// Streams scalars from all the runs and tags in an experiment. Each stream
+// result only contains data for a single tag from a single run. For example,
+// if there are five runs and each run has two tags, the RPC returns a stream
+// of at least ten `StreamExperimentDataResponse`s, each one having the
+// scalars for one tag. The values from a single tag may be split among multiple
+// responses. Users need to aggregate information from the entire stream to get
+// data for the entire experiment. Empty experiments will have zero stream
+// results. Empty runs that don't have any tags need not be supported by a
+// hosted service.
+message StreamExperimentDataResponse {
+ // Name of the tag whose data is contained in this response.
+ string tag_name = 1;
+ // Name of the run that contains the tag `tag_name`.
+ string run_name = 2;
+ // The metadata of the tag `tag_name`.
+ .tensorboard.SummaryMetadata tag_metadata = 3;
+  // Data to store for the tag `tag_name`.
+ ScalarPoints points = 4;
+
+ // Data for the scalars are stored in a columnar fashion to optimize it for
+ // exporting the data into textual formats like JSON.
+ // The data for the ith scalar is { steps[i], wall_times[i], values[i] }.
+ // The data here is sorted by step values in ascending order.
+ message ScalarPoints {
+ // Step index within the run.
+ repeated int64 steps = 1;
+ // Timestamp of the creation of this point.
+ repeated google.protobuf.Timestamp wall_times = 2;
+ // Value of the point at this step / timestamp.
+ repeated double values = 3;
+ }
+}
diff --git a/tensorboard/uploader/proto/scalar.proto b/tensorboard/uploader/proto/scalar.proto
new file mode 100644
index 0000000000..462cf4df5c
--- /dev/null
+++ b/tensorboard/uploader/proto/scalar.proto
@@ -0,0 +1,28 @@
+syntax = "proto3";
+
+package tensorboard.service;
+
+import "google/protobuf/timestamp.proto";
+import "tensorboard/compat/proto/summary.proto";
+
+// One point viewable on a scalar metric plot.
+message ScalarPoint {
+ // Step index within the run.
+ int64 step = 1;
+ // Timestamp of the creation of this point.
+ google.protobuf.Timestamp wall_time = 2;
+ // Value of the point at this step / timestamp.
+ double value = 3;
+}
+
+// Metadata for the ScalarPoints stored for one (Experiment, Run, Tag).
+message ScalarPointMetadata {
+ // Maximum step recorded for the tag.
+ int64 max_step = 1;
+ // Timestamp corresponding to the max step.
+ google.protobuf.Timestamp max_wall_time = 2;
+ // Information about the plugin which created this scalar data.
+  // Note: The leading period is a required part of the type name here due
+  // to the package name resolution logic.
+ .tensorboard.SummaryMetadata summary_metadata = 3;
+}
diff --git a/tensorboard/uploader/proto/write_service.proto b/tensorboard/uploader/proto/write_service.proto
new file mode 100644
index 0000000000..4d63f22f69
--- /dev/null
+++ b/tensorboard/uploader/proto/write_service.proto
@@ -0,0 +1,130 @@
+syntax = "proto3";
+
+package tensorboard.service;
+
+import "tensorboard/uploader/proto/scalar.proto";
+import "tensorboard/compat/proto/summary.proto";
+
+// Service for writing data to TensorBoard.dev.
+service TensorBoardWriterService {
+ // Request for a new location to write TensorBoard readable events.
+ rpc CreateExperiment(CreateExperimentRequest)
+ returns (CreateExperimentResponse) {}
+ // Request that an experiment be deleted, along with all tags and scalars
+ // that it contains. This call may only be made by the original owner of the
+ // experiment.
+ rpc DeleteExperiment(DeleteExperimentRequest)
+ returns (DeleteExperimentResponse) {}
+ // Request that unreachable data be purged. Used only for testing;
+ // disabled in production.
+ rpc PurgeData(PurgeDataRequest) returns (PurgeDataResponse) {}
+ // Request additional data be stored in TensorBoard.dev.
+ rpc WriteScalar(WriteScalarRequest) returns (WriteScalarResponse) {}
+ // Request that the calling user and all their data be permanently deleted.
+ // Used for testing purposes.
+ rpc DeleteOwnUser(DeleteOwnUserRequest) returns (DeleteOwnUserResponse) {}
+}
+
+// This is currently empty on purpose. No information is necessary
+// to request a URL, except authorization, of course, which doesn't
+// come within the proto.
+message CreateExperimentRequest {
+ // This is empty on purpose.
+}
+
+// Carries all information necessary to:
+// 1. Inform the user where to navigate to see their TensorBoard.
+// 2. Subsequently load (Scalars, Tensors, etc.) to the specified location.
+message CreateExperimentResponse {
+ // Service-wide unique identifier of an uploaded log dir.
+ // eg: "1r9d0kQkh2laODSZcQXWP"
+ string experiment_id = 1;
+  // URL the user should navigate to in order to see their TensorBoard.
+ // eg: "https://example.com/public/1r9d0kQkh2laODSZcQXWP"
+ string url = 2;
+}
+
+message DeleteExperimentRequest {
+ // Service-wide unique identifier of an uploaded log dir.
+ // eg: "1r9d0kQkh2laODSZcQXWP"
+ string experiment_id = 1;
+}
+
+message DeleteExperimentResponse {
+ // This is empty on purpose.
+}
+
+// Only used for testing; corresponding RPC is disabled in prod.
+message PurgeDataRequest {
+ // Maximum number of entities of a given kind to purge at once (e.g.,
+ // maximum number of tags to purge). Required; must be positive.
+ int32 batch_limit = 1;
+}
+
+// Only used for testing; corresponding RPC is disabled in prod.
+message PurgeDataResponse {
+  // Stats about how many elements were purged. Compare to the batch
+ // limit specified in the request to estimate whether the backlog has
+ // any more items.
+ PurgeStats purge_stats = 1;
+}
+
+// Details about what actions were taken as a result of a purge request.
+// These values are upper bounds; they may exceed the true values.
+message PurgeStats {
+ // Number of tags deleted as a result of this request.
+ int32 tags = 1;
+ // Number of experiments marked as purged as a result of this request.
+ int32 experiments = 2;
+ // Number of users deleted as a result of this request.
+ int32 users = 3;
+}
+
+// Carries all that is needed to add additional run data to the hosted service.
+message WriteScalarRequest {
+ // All the data to store for one Run. This data will be stored under the
+ // corresponding run in the hosted storage. WriteScalarRequest is merged into
+ // the data store for the keyed run. The tags and included scalars will be
+ // the union of the data sent across all WriteScalarRequests. Metadata by
+ // default uses a 'first write wins' approach.
+ message Run {
+ // The name of this run. For example "/some/path/mnist_experiments/run1/"
+ string name = 1;
+ // Data to store for this Run/Tag combination.
+ repeated Tag tags = 2;
+ }
+
+ // All the data to store for one Tag of one Run. This data will be stored
+ // under the corresponding run/tag in the hosted storage. A tag corresponds to
+ // a single time series.
+ message Tag {
+ // The name of this tag. For example "loss"
+ string name = 1;
+ // Data to store for this Run/Tag combination.
+ repeated ScalarPoint points = 2;
+ // The metadata of this tag.
+ .tensorboard.SummaryMetadata metadata = 3;
+ }
+
+ // Which experiment to write to - corresponding to one hosted TensorBoard URL.
+ // The requester must have authorization to write to this location.
+ string experiment_id = 1;
+ // Data to append to the existing storage at the experiment_id.
+ repeated Run runs = 2;
+}
+
+// Everything the caller needs to know about how the writing went.
+// (Currently empty)
+message WriteScalarResponse {
+ // This is empty on purpose.
+}
+
+// Requests that the calling user and all their data be permanently deleted.
+message DeleteOwnUserRequest {
+ // This is empty on purpose.
+}
+
+// Everything the caller needs to know about how the deletion went.
+message DeleteOwnUserResponse {
+ // This is empty on purpose.
+}
diff --git a/tensorboard/uploader/test_util.py b/tensorboard/uploader/test_util.py
new file mode 100644
index 0000000000..d6c5cca9d1
--- /dev/null
+++ b/tensorboard/uploader/test_util.py
@@ -0,0 +1,63 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for testing."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+import grpc
+
+from google.protobuf import timestamp_pb2
+from tensorboard.compat.proto import summary_pb2
+
+
+class FakeTime(object):
+ """Thread-safe fake replacement for the `time` module."""
+
+ def __init__(self, current=0.0):
+ self._time = float(current)
+ self._lock = threading.Lock()
+
+ def time(self):
+ with self._lock:
+ return self._time
+
+ def sleep(self, secs):
+ with self._lock:
+ self._time += secs
+
+
+def scalar_metadata(display_name):
+ """Makes a scalar metadata proto, for constructing expected requests."""
+ metadata = summary_pb2.SummaryMetadata(display_name=display_name)
+ metadata.plugin_data.plugin_name = "scalars"
+ return metadata
+
+
+def grpc_error(code, details):
+  # Monkey-patch in the methods that a real grpc.RpcError would have.
+ error = grpc.RpcError("RPC error %r: %s" % (code, details))
+ error.code = lambda: code
+ error.details = lambda: details
+ return error
+
+
+def timestamp_pb(nanos):
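+  # For example, timestamp_pb(1571084520862939144) yields a Timestamp with
+  # seconds=1571084520 and nanos=862939144, as used in the exporter tests.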
+ result = timestamp_pb2.Timestamp()
+ result.FromNanoseconds(nanos)
+ return result
diff --git a/tensorboard/uploader/uploader.py b/tensorboard/uploader/uploader.py
new file mode 100644
index 0000000000..222a23977c
--- /dev/null
+++ b/tensorboard/uploader/uploader.py
@@ -0,0 +1,413 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Uploads a TensorBoard logdir to TensorBoard.dev."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import time
+
+import grpc
+import six
+
+from tensorboard.uploader.proto import write_service_pb2
+from tensorboard.uploader import logdir_loader
+from tensorboard.uploader import peekable_iterator
+from tensorboard.uploader import util
+from tensorboard import data_compat
+from tensorboard.backend.event_processing import directory_loader
+from tensorboard.backend.event_processing import event_file_loader
+from tensorboard.backend.event_processing import io_wrapper
+from tensorboard.plugins.scalar import metadata as scalar_metadata
+from tensorboard.util import grpc_util
+from tensorboard.util import tb_logging
+from tensorboard.util import tensor_util
+
+# Minimum length of an upload cycle in seconds; shorter cycles will sleep to
+# use up the rest of the time to avoid sending write RPCs too quickly.
+_MIN_UPLOAD_CYCLE_DURATION_SECS = 5
+
+# Age in seconds of last write after which an event file is considered inactive.
+# TODO(@nfelt): consolidate with TensorBoard --reload_multifile default logic.
+_EVENT_FILE_INACTIVE_SECS = 4000
+
+# Maximum length of a base-128 varint as used to encode a 64-bit value
+# (without the "msb of last byte is bit 63" optimization, to be
+# compatible with protobuf and golang varints).
+_MAX_VARINT64_LENGTH_BYTES = 10
+
+# Maximum outgoing request size. The server-side limit is 4 MiB [1]; we
+# should pad a bit to mitigate any errors in our bookkeeping. Currently,
+# we pad a lot, because using higher request sizes causes occasional
+# Deadline Exceeded errors in the RPC server.
+#
+# [1]: https://github.com/grpc/grpc/blob/e70d8582b4b0eedc45e3d25a57b58a08b94a9f4a/include/grpc/impl/codegen/grpc_types.h#L447 # pylint: disable=line-too-long
+_MAX_REQUEST_LENGTH_BYTES = 1024 * 128
+
+logger = tb_logging.get_logger()
+
+
+class TensorBoardUploader(object):
+ """Uploads a TensorBoard logdir to TensorBoard.dev."""
+
+ def __init__(self, writer_client, logdir, rate_limiter=None):
+ """Constructs a TensorBoardUploader.
+
+ Args:
+ writer_client: a TensorBoardWriterService stub instance
+ logdir: path of the log directory to upload
+ rate_limiter: a `RateLimiter` to use to limit upload cycle frequency
+ """
+ self._api = writer_client
+ self._logdir = logdir
+ self._request_builder = None
+ if rate_limiter is None:
+ self._rate_limiter = util.RateLimiter(_MIN_UPLOAD_CYCLE_DURATION_SECS)
+ else:
+ self._rate_limiter = rate_limiter
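+    # Treat an event file as active (i.e., possibly still being written to)
+    # if its last write was within the past _EVENT_FILE_INACTIVE_SECS seconds.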
+ active_filter = lambda secs: secs + _EVENT_FILE_INACTIVE_SECS >= time.time()
+ directory_loader_factory = functools.partial(
+ directory_loader.DirectoryLoader,
+ loader_factory=event_file_loader.TimestampedEventFileLoader,
+ path_filter=io_wrapper.IsTensorFlowEventsFile,
+ active_filter=active_filter)
+ self._logdir_loader = logdir_loader.LogdirLoader(
+ self._logdir, directory_loader_factory)
+
+ def create_experiment(self):
+ """Creates an Experiment for this upload session and returns the URL."""
+ logger.info("Creating experiment")
+ request = write_service_pb2.CreateExperimentRequest()
+ response = grpc_util.call_with_retries(self._api.CreateExperiment, request)
+ self._request_builder = _RequestBuilder(response.experiment_id)
+ return response.url
+
+ def start_uploading(self):
+ """Blocks forever to continuously upload data from the logdir.
+
+ Raises:
+ RuntimeError: If `create_experiment` has not yet been called.
+ ExperimentNotFoundError: If the experiment is deleted during the
+ course of the upload.
+ """
+ if self._request_builder is None:
+ raise RuntimeError(
+ "Must call create_experiment() before start_uploading()")
+ while True:
+ self._upload_once()
+
+ def _upload_once(self):
+ """Runs one upload cycle, sending zero or more RPCs."""
+ logger.info("Starting an upload cycle")
+ self._rate_limiter.tick()
+
+ sync_start_time = time.time()
+ self._logdir_loader.synchronize_runs()
+ sync_duration_secs = time.time() - sync_start_time
+ logger.info("Logdir sync took %.3f seconds", sync_duration_secs)
+
+ run_to_events = self._logdir_loader.get_run_events()
+ first_request = True
+ for request in self._request_builder.build_requests(run_to_events):
+ if not first_request:
+ self._rate_limiter.tick()
+ first_request = False
+ upload_start_time = time.time()
+ request_bytes = request.ByteSize()
+ logger.info("Trying request of %d bytes", request_bytes)
+ self._upload(request)
+ upload_duration_secs = time.time() - upload_start_time
+ logger.info(
+ "Upload for %d runs (%d bytes) took %.3f seconds",
+ len(request.runs),
+ request_bytes,
+ upload_duration_secs)
+
+ def _upload(self, request):
+ try:
+ # TODO(@nfelt): execute this RPC asynchronously.
+ grpc_util.call_with_retries(self._api.WriteScalar, request)
+ except grpc.RpcError as e:
+ if e.code() == grpc.StatusCode.NOT_FOUND:
+ raise ExperimentNotFoundError()
+ logger.error("Upload call failed with error %s", e)
+
+
+def delete_experiment(writer_client, experiment_id):
+ """Permanently deletes an experiment and all of its contents.
+
+ Args:
+ writer_client: a TensorBoardWriterService stub instance
+ experiment_id: string ID of the experiment to delete
+
+ Raises:
+ ExperimentNotFoundError: If no such experiment exists.
+ PermissionDeniedError: If the user is not authorized to delete this
+ experiment.
+ RuntimeError: On unexpected failure.
+ """
+ logger.info("Deleting experiment %r", experiment_id)
+ request = write_service_pb2.DeleteExperimentRequest()
+ request.experiment_id = experiment_id
+ try:
+ grpc_util.call_with_retries(writer_client.DeleteExperiment, request)
+ except grpc.RpcError as e:
+ if e.code() == grpc.StatusCode.NOT_FOUND:
+ raise ExperimentNotFoundError()
+ if e.code() == grpc.StatusCode.PERMISSION_DENIED:
+ raise PermissionDeniedError()
+ raise
+
+
+class ExperimentNotFoundError(RuntimeError):
+ pass
+
+
+class PermissionDeniedError(RuntimeError):
+ pass
+
+
+class _OutOfSpaceError(Exception):
+ """Action could not proceed without overflowing request budget.
+
+ This is a signaling exception (like `StopIteration`) used internally
+ by `_RequestBuilder`; it does not mean that anything has gone wrong.
+ """
+ pass
+
+
+class _RequestBuilder(object):
+ """Helper class for building requests that fit under a size limit.
+
+ This class is not threadsafe. Use external synchronization if calling
+ its methods concurrently.
+ """
+
+ _NON_SCALAR_TIME_SERIES = object() # sentinel
+
+ def __init__(self, experiment_id):
+ self._experiment_id = experiment_id
+ # The request currently being populated.
+ self._request = None # type: write_service_pb2.WriteScalarRequest
+ # A lower bound on the number of bytes that we may yet add to the
+ # request.
+ self._byte_budget = None # type: int
+ # Map from `(run_name, tag_name)` to `SummaryMetadata` if the time
+ # series is a scalar time series, else to `_NON_SCALAR_TIME_SERIES`.
+ self._tag_metadata = {}
+
+ def _new_request(self):
+ """Allocates a new request and refreshes the budget."""
+ self._request = write_service_pb2.WriteScalarRequest()
+ self._byte_budget = _MAX_REQUEST_LENGTH_BYTES
+ self._request.experiment_id = self._experiment_id
+ self._byte_budget -= self._request.ByteSize()
+ if self._byte_budget < 0:
+ raise RuntimeError("Byte budget too small for experiment ID")
+
+ def build_requests(self, run_to_events):
+ """Converts a stream of TF events to a stream of outgoing requests.
+
+ Each yielded request will be at most `_MAX_REQUEST_LENGTH_BYTES`
+ bytes long.
+
+ Args:
+ run_to_events: Mapping from run name to generator of `tf.Event`
+ values, as returned by `LogdirLoader.get_run_events`.
+
+ Yields:
+ A finite stream of `WriteScalarRequest` objects.
+
+ Raises:
+ RuntimeError: If no progress can be made because even a single
+ point is too large (say, due to a gigabyte-long tag name).
+ """
+
+ self._new_request()
+ runs = {} # cache: map from run name to `Run` proto in request
+ tags = {} # cache: map from `(run, tag)` to `Tag` proto in run in request
+ work_items = peekable_iterator.PeekableIterator(
+ self._run_values(run_to_events))
+
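+    # Greedily pack points into the current request; whenever the byte budget
+    # is exhausted, emit the pruned request and start a fresh one.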
+ while work_items.has_next():
+ (run_name, event, orig_value) = work_items.peek()
+ value = data_compat.migrate_value(orig_value)
+ time_series_key = (run_name, value.tag)
+
+ metadata = self._tag_metadata.get(time_series_key)
+ if metadata is None:
+ plugin_name = value.metadata.plugin_data.plugin_name
+ if plugin_name == scalar_metadata.PLUGIN_NAME:
+ metadata = value.metadata
+ else:
+ metadata = _RequestBuilder._NON_SCALAR_TIME_SERIES
+ self._tag_metadata[time_series_key] = metadata
+ if metadata is _RequestBuilder._NON_SCALAR_TIME_SERIES:
+ next(work_items)
+ continue
+ try:
+ run_proto = runs.get(run_name)
+ if run_proto is None:
+ run_proto = self._create_run(run_name)
+ runs[run_name] = run_proto
+ tag_proto = tags.get((run_name, value.tag))
+ if tag_proto is None:
+ tag_proto = self._create_tag(run_proto, value.tag, metadata)
+ tags[(run_name, value.tag)] = tag_proto
+ self._create_point(tag_proto, event, value)
+ next(work_items)
+ except _OutOfSpaceError:
+ # Flush request and start a new one.
+ request_to_emit = self._prune_request()
+ if request_to_emit is None:
+ raise RuntimeError("Could not make progress uploading data")
+ self._new_request()
+ runs.clear()
+ tags.clear()
+ yield request_to_emit
+
+ final_request = self._prune_request()
+ if final_request is not None:
+ yield final_request
+
+ def _run_values(self, run_to_events):
+ """Helper generator to create a single stream of work items."""
+ # Note that each of these joins in principle has deletion anomalies:
+ # if the input stream contains runs with no events, or events with
+ # no values, we'll lose that information. This is not a problem: we
+ # would need to prune such data from the request anyway.
+ for (run_name, events) in six.iteritems(run_to_events):
+ for event in events:
+ for value in event.summary.value:
+ yield (run_name, event, value)
+
+ def _prune_request(self):
+ """Removes empty runs and tags from the active request.
+
+ This does not refund `self._byte_budget`; it is assumed that the
+ request will be emitted immediately, anyway.
+
+ Returns:
+ The active request, or `None` if after pruning the request
+ contains no data.
+ """
+ request = self._request
+ for (run_idx, run) in reversed(list(enumerate(request.runs))):
+ for (tag_idx, tag) in reversed(list(enumerate(run.tags))):
+ if not tag.points:
+ del run.tags[tag_idx]
+ if not run.tags:
+        del request.runs[run_idx]
+ if not request.runs:
+ request = None
+ return request
+
+ def _create_run(self, run_name):
+ """Adds a run to the live request, if there's space.
+
+ Args:
+ run_name: String name of the run to add.
+
+ Returns:
+ The `WriteScalarRequest.Run` that was added to `request.runs`.
+
+ Raises:
+ _OutOfSpaceError: If adding the run would exceed the remaining
+ request budget.
+ """
+ run_proto = self._request.runs.add(name=run_name)
+ # We can't calculate the proto key cost exactly ahead of time, as
+ # it depends on the total size of all tags. Be conservative.
+ cost = run_proto.ByteSize() + _MAX_VARINT64_LENGTH_BYTES + 1
+ if cost > self._byte_budget:
+ raise _OutOfSpaceError()
+ self._byte_budget -= cost
+ return run_proto
+
+ def _create_tag(self, run_proto, tag_name, metadata):
+ """Adds a tag for the given value, if there's space.
+
+ Args:
+ run_proto: `WriteScalarRequest.Run` proto to which to add a tag.
+ tag_name: String name of the tag to add (as `value.tag`).
+ metadata: TensorBoard `SummaryMetadata` proto from the first
+ occurrence of this time series.
+
+ Returns:
+ The `WriteScalarRequest.Tag` that was added to `run_proto.tags`.
+
+ Raises:
+ _OutOfSpaceError: If adding the tag would exceed the remaining
+ request budget.
+ """
+ tag_proto = run_proto.tags.add(name=tag_name)
+ tag_proto.metadata.CopyFrom(metadata)
+ submessage_cost = tag_proto.ByteSize()
+ # We can't calculate the proto key cost exactly ahead of time, as
+ # it depends on the number of points. Be conservative.
+ cost = submessage_cost + _MAX_VARINT64_LENGTH_BYTES + 1
+ if cost > self._byte_budget:
+ raise _OutOfSpaceError()
+ self._byte_budget -= cost
+ return tag_proto
+
+ def _create_point(self, tag_proto, event, value):
+ """Adds a scalar point to the given tag, if there's space.
+
+ Args:
+ tag_proto: `WriteScalarRequest.Tag` proto to which to add a point.
+ event: Enclosing `Event` proto with the step and wall time data.
+ value: Scalar `Summary.Value` proto with the actual scalar data.
+
+ Returns:
+ The `ScalarPoint` that was added to `tag_proto.points`.
+
+ Raises:
+ _OutOfSpaceError: If adding the point would exceed the remaining
+ request budget.
+ """
+ point = tag_proto.points.add()
+ point.step = event.step
+ # TODO(@nfelt): skip tensor roundtrip for Value with simple_value set
+ point.value = tensor_util.make_ndarray(value.tensor).item()
+ util.set_timestamp(point.wall_time, event.wall_time)
+ submessage_cost = point.ByteSize()
+ cost = submessage_cost + _varint_cost(submessage_cost) + 1 # proto key
+ if cost > self._byte_budget:
+ tag_proto.points.pop()
+ raise _OutOfSpaceError()
+ self._byte_budget -= cost
+ return point
+
+
+def _varint_cost(n):
+ """Computes the size of `n` encoded as an unsigned base-128 varint.
+
+ This should be consistent with the proto wire format:
+
+
+ Args:
+ n: A non-negative integer.
+
+ Returns:
+ An integer number of bytes.
+ """
+ result = 1
+ while n >= 128:
+ result += 1
+ n >>= 7
+ return result
diff --git a/tensorboard/uploader/uploader_main.py b/tensorboard/uploader/uploader_main.py
new file mode 100644
index 0000000000..d9fc953847
--- /dev/null
+++ b/tensorboard/uploader/uploader_main.py
@@ -0,0 +1,474 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Main program for the TensorBoard.dev uploader."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import json
+import os
+import sys
+import textwrap
+
+from absl import app
+from absl import logging
+from absl.flags import argparse_flags
+import grpc
+import six
+
+from tensorboard.uploader import dev_creds
+from tensorboard.uploader.proto import export_service_pb2_grpc
+from tensorboard.uploader.proto import write_service_pb2_grpc
+from tensorboard.uploader import auth
+from tensorboard.uploader import exporter as exporter_lib
+from tensorboard.uploader import uploader as uploader_lib
+from tensorboard import program
+from tensorboard.plugins import base_plugin
+
+
+# Temporary integration point for absl compatibility; will go away once
+# migrated to TensorBoard subcommand.
+_FLAGS = None
+
+
+_MESSAGE_TOS = u"""\
+Your use of this service is subject to Google's Terms of Service
+ and Privacy Policy
+, and TensorBoard.dev's Terms of Service
+.
+
+This notice will not be shown again while you are logged into the uploader.
+To log out, run `tensorboard dev auth revoke`.
+"""
+
+
+_SUBCOMMAND_FLAG = '_uploader__subcommand'
+_SUBCOMMAND_KEY_UPLOAD = 'UPLOAD'
+_SUBCOMMAND_KEY_DELETE = 'DELETE'
+_SUBCOMMAND_KEY_EXPORT = 'EXPORT'
+_SUBCOMMAND_KEY_AUTH = 'AUTH'
+_AUTH_SUBCOMMAND_FLAG = '_uploader__subcommand_auth'
+_AUTH_SUBCOMMAND_KEY_REVOKE = 'REVOKE'
+
+
+def _prompt_for_user_ack(intent):
+ """Prompts for user consent, exiting the program if they decline."""
+ body = intent.get_ack_message_body()
+ header = '\n***** TensorBoard Uploader *****\n'
+ user_ack_message = '\n'.join((header, body, _MESSAGE_TOS))
+ sys.stderr.write(user_ack_message)
+ sys.stderr.write('\n')
+ response = six.moves.input('Continue? (yes/NO) ')
+ if response.lower() not in ('y', 'yes'):
+ sys.exit(0)
+ sys.stderr.write('\n')
+
+
+def _define_flags(parser):
+ """Configures flags on the provided argument parser.
+
+ Integration point for `tensorboard.program`'s subcommand system.
+
+ Args:
+ parser: An `argparse.ArgumentParser` to be mutated.
+ """
+
+ subparsers = parser.add_subparsers()
+
+ parser.add_argument(
+ '--endpoint',
+ type=str,
+ default='api.tensorboard.dev:443',
+ help='URL for the API server accepting write requests.')
+
+ parser.add_argument(
+ '--grpc_creds_type',
+ type=str,
+ default='ssl',
+ choices=('local', 'ssl', 'ssl_dev'),
+ help='The type of credentials to use for the gRPC client')
+
+ parser.add_argument(
+ '--auth_force_console',
+ action='store_true',
+ help='Set to true to force authentication flow to use the '
+ '--console rather than a browser redirect to localhost.')
+
+ upload = subparsers.add_parser(
+ 'upload', help='upload an experiment to TensorBoard.dev')
+ upload.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_UPLOAD})
+ upload.add_argument(
+ '--logdir',
+ metavar='PATH',
+ type=str,
+ default=None,
+ help='Directory containing the logs to process')
+
+ delete = subparsers.add_parser(
+ 'delete',
+ help='permanently delete an experiment',
+ inherited_absl_flags=None)
+ delete.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_DELETE})
+ # We would really like to call this next flag `--experiment` rather
+ # than `--experiment_id`, but this is broken inside Google due to a
+  # long-standing Python `argparse` prefix-matching bug. (Some Google-internal
+  # dependencies define `--experimental_*` flags.)
+ # This isn't exactly a principled fix, but it gets the job done.
+ delete.add_argument(
+ '--experiment_id',
+ metavar='EXPERIMENT_ID',
+ type=str,
+ default=None,
+ help='ID of an experiment to delete permanently')
+
+ export = subparsers.add_parser(
+ 'export', help='download all your experiment data')
+ export.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_EXPORT})
+ export.add_argument(
+ '--outdir',
+ metavar='OUTPUT_PATH',
+ type=str,
+ default=None,
+ help='Directory into which to download all experiment data; '
+ 'must not yet exist')
+
+ auth_parser = subparsers.add_parser('auth', help='log in, log out')
+ auth_parser.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_AUTH})
+ auth_subparsers = auth_parser.add_subparsers()
+
+ auth_revoke = auth_subparsers.add_parser(
+ 'revoke', help='revoke all existing credentials and log out')
+ auth_revoke.set_defaults(
+ **{_AUTH_SUBCOMMAND_FLAG: _AUTH_SUBCOMMAND_KEY_REVOKE})
+
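+
+# Example invocations of the subcommands defined by `_define_flags` above, once
+# the `dev` subcommand is wired into the `tensorboard` CLI via
+# `UploaderSubcommand` below (a sketch; paths and IDs are placeholders):
+#
+#   tensorboard dev upload --logdir /path/to/logs
+#   tensorboard dev delete --experiment_id EXPERIMENT_ID
+#   tensorboard dev export --outdir /path/to/output
+#   tensorboard dev auth revoke
+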
+
+def _parse_flags(argv=('',)):
+ """Integration point for `absl.app`.
+
+ Exits if flag values are invalid.
+
+ Args:
+ argv: CLI arguments, as with `sys.argv`, where the first argument is taken
+ to be the name of the program being executed.
+
+ Returns:
+ Either argv[:1] if argv was non-empty, or [''] otherwise, as a mechanism
+ for absl.app.run() compatibility.
+ """
+ parser = argparse_flags.ArgumentParser(
+ prog='uploader',
+ description=('Upload your TensorBoard experiments to TensorBoard.dev'))
+ _define_flags(parser)
+ arg0 = argv[0] if argv else ''
+ global _FLAGS
+ _FLAGS = parser.parse_args(argv[1:])
+ return [arg0]
+
+
+def _run(flags):
+ """Runs the main uploader program given parsed flags.
+
+ Args:
+ flags: An `argparse.Namespace`.
+ """
+
+ logging.set_stderrthreshold(logging.WARNING)
+ intent = _get_intent(flags)
+
+ store = auth.CredentialsStore()
+ if isinstance(intent, _AuthRevokeIntent):
+ store.clear()
+ sys.stderr.write('Logged out of uploader.\n')
+ sys.stderr.flush()
+ return
+ # TODO(b/141723268): maybe reconfirm Google Account prior to reuse.
+ credentials = store.read_credentials()
+ if not credentials:
+ _prompt_for_user_ack(intent)
+ client_config = json.loads(auth.OAUTH_CLIENT_CONFIG)
+ flow = auth.build_installed_app_flow(client_config)
+ credentials = flow.run(force_console=flags.auth_force_console)
+ sys.stderr.write('\n') # Extra newline after auth flow messages.
+ store.write_credentials(credentials)
+
+ channel_options = None
+ if flags.grpc_creds_type == 'local':
+ channel_creds = grpc.local_channel_credentials()
+ elif flags.grpc_creds_type == 'ssl':
+ channel_creds = grpc.ssl_channel_credentials()
+ elif flags.grpc_creds_type == 'ssl_dev':
+ channel_creds = grpc.ssl_channel_credentials(dev_creds.DEV_SSL_CERT)
+ channel_options = [('grpc.ssl_target_name_override', 'localhost')]
+ else:
+ msg = 'Invalid --grpc_creds_type %s' % flags.grpc_creds_type
+ raise base_plugin.FlagsError(msg)
+
+ composite_channel_creds = grpc.composite_channel_credentials(
+ channel_creds, auth.id_token_call_credentials(credentials))
+
+ # TODO(@nfelt): In the `_UploadIntent` case, consider waiting until
+ # logdir exists to open channel.
+ channel = grpc.secure_channel(
+ flags.endpoint, composite_channel_creds, options=channel_options)
+ with channel:
+ intent.execute(channel)
+
+
+@six.add_metaclass(abc.ABCMeta)
+class _Intent(object):
+ """A description of the user's intent in invoking this program.
+
+ Each valid set of CLI flags corresponds to one intent: e.g., "upload
+ data from this logdir", or "delete the experiment with that ID".
+ """
+
+ @abc.abstractmethod
+ def get_ack_message_body(self):
+ """Gets the message to show when executing this intent at first login.
+
+ This need not include the header (program name) or Terms of Service
+ notice.
+
+ Returns:
+ A Unicode string, potentially spanning multiple lines.
+ """
+ pass
+
+ @abc.abstractmethod
+ def execute(self, channel):
+ """Carries out this intent with the specified gRPC channel.
+
+ Args:
+ channel: A connected gRPC channel whose server provides the TensorBoard
+ reader and writer services.
+ """
+ pass
+
+
+class _AuthRevokeIntent(_Intent):
+ """The user intends to revoke credentials."""
+
+ def get_ack_message_body(self):
+ """Must not be called."""
+ raise AssertionError('No user ack needed to revoke credentials')
+
+ def execute(self, channel):
+ """Execute handled specially by `main`. Must not be called."""
+ raise AssertionError('_AuthRevokeIntent should not be directly executed')
+
+
+class _DeleteExperimentIntent(_Intent):
+ """The user intends to delete an experiment."""
+
+ _MESSAGE_TEMPLATE = textwrap.dedent(u"""\
+ This will delete the experiment on https://tensorboard.dev with the
+ following experiment ID:
+
+ {experiment_id}
+
+ You have chosen to delete an experiment. All experiments uploaded
+ to TensorBoard.dev are publicly visible. Do not upload sensitive
+ data.
+ """)
+
+ def __init__(self, experiment_id):
+ self.experiment_id = experiment_id
+
+ def get_ack_message_body(self):
+ return self._MESSAGE_TEMPLATE.format(experiment_id=self.experiment_id)
+
+ def execute(self, channel):
+ api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub(channel)
+ experiment_id = self.experiment_id
+ if not experiment_id:
+ raise base_plugin.FlagsError(
+ 'Must specify a non-empty experiment ID to delete.')
+ try:
+ uploader_lib.delete_experiment(api_client, experiment_id)
+ except uploader_lib.ExperimentNotFoundError:
+ _die(
+ 'No such experiment %s. Either it never existed or it has '
+ 'already been deleted.' % experiment_id)
+ except uploader_lib.PermissionDeniedError:
+ _die(
+ 'Cannot delete experiment %s because it is owned by a '
+ 'different user.' % experiment_id)
+ except grpc.RpcError as e:
+ _die('Internal error deleting experiment: %s' % e)
+ print('Deleted experiment %s.' % experiment_id)
+
+
+class _UploadIntent(_Intent):
+ """The user intends to upload an experiment from the given logdir."""
+
+ _MESSAGE_TEMPLATE = textwrap.dedent(u"""\
+ This will upload your TensorBoard logs to https://tensorboard.dev/ from
+ the following directory:
+
+ {logdir}
+
+ This TensorBoard will be visible to everyone. Do not upload sensitive
+ data.
+ """)
+
+ def __init__(self, logdir):
+ self.logdir = logdir
+
+ def get_ack_message_body(self):
+ return self._MESSAGE_TEMPLATE.format(logdir=self.logdir)
+
+ def execute(self, channel):
+ api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub(channel)
+ uploader = uploader_lib.TensorBoardUploader(api_client, self.logdir)
+ url = uploader.create_experiment()
+ print("Upload started and will continue reading any new data as it's added")
+ print("to the logdir. To stop uploading, press Ctrl-C.")
+ print("View your TensorBoard live at: %s" % url)
+ try:
+ uploader.start_uploading()
+ except uploader_lib.ExperimentNotFoundError:
+ print('Experiment was deleted; uploading has been cancelled')
+ return
+ except KeyboardInterrupt:
+ print()
+ print('Upload stopped. View your TensorBoard at %s' % url)
+ return
+ # TODO(@nfelt): make it possible for the upload cycle to end once we
+ # detect that no more runs are active, so this code can be reached.
+ print('Done! View your TensorBoard at %s' % url)
+
+
+class _ExportIntent(_Intent):
+ """The user intends to download all their experiment data."""
+
+ _MESSAGE_TEMPLATE = textwrap.dedent(u"""\
+ This will download all your experiment data from https://tensorboard.dev
+ and save it to the following directory:
+
+ {output_dir}
+
+ Downloading your experiment data does not delete it from the
+ service. All experiments uploaded to TensorBoard.dev are publicly
+ visible. Do not upload sensitive data.
+ """)
+
+ def __init__(self, output_dir):
+ self.output_dir = output_dir
+
+ def get_ack_message_body(self):
+ return self._MESSAGE_TEMPLATE.format(output_dir=self.output_dir)
+
+ def execute(self, channel):
+ api_client = export_service_pb2_grpc.TensorBoardExporterServiceStub(channel)
+ outdir = self.output_dir
+ try:
+ exporter = exporter_lib.TensorBoardExporter(api_client, outdir)
+ except exporter_lib.OutputDirectoryExistsError:
+ msg = 'Output directory already exists: %r' % outdir
+ raise base_plugin.FlagsError(msg)
+ num_experiments = 0
+ try:
+ for experiment_id in exporter.export():
+ num_experiments += 1
+ print('Downloaded experiment %s' % experiment_id)
+ except exporter_lib.GrpcTimeoutException as e:
+ print(
+ '\nUploader has failed because of a timeout error. Please reach '
+ 'out via e-mail to tensorboard.dev-support@google.com to get help '
+ 'completing your export of experiment %s.' % e.experiment_id)
+ print('Done. Downloaded %d experiments to: %s' % (num_experiments, outdir))
+
+
+def _get_intent(flags):
+ """Determines what the program should do (upload, delete, ...).
+
+ Args:
+ flags: An `argparse.Namespace` with the parsed flags.
+
+ Returns:
+ An `_Intent` instance.
+
+ Raises:
+ base_plugin.FlagsError: If the command-line `flags` do not correctly
+ specify an intent.
+ """
+ cmd = getattr(flags, _SUBCOMMAND_FLAG, None)
+ if cmd is None:
+ raise base_plugin.FlagsError('Must specify subcommand (try --help).')
+ if cmd == _SUBCOMMAND_KEY_UPLOAD:
+ if flags.logdir:
+ return _UploadIntent(os.path.expanduser(flags.logdir))
+ else:
+ raise base_plugin.FlagsError(
+ 'Must specify directory to upload via `--logdir`.')
+ elif cmd == _SUBCOMMAND_KEY_DELETE:
+ if flags.experiment_id:
+ return _DeleteExperimentIntent(flags.experiment_id)
+ else:
+ raise base_plugin.FlagsError(
+ 'Must specify experiment to delete via `--experiment_id`.')
+ elif cmd == _SUBCOMMAND_KEY_EXPORT:
+ if flags.outdir:
+ return _ExportIntent(flags.outdir)
+ else:
+ raise base_plugin.FlagsError(
+ 'Must specify output directory via `--outdir`.')
+ elif cmd == _SUBCOMMAND_KEY_AUTH:
+ auth_cmd = getattr(flags, _AUTH_SUBCOMMAND_FLAG, None)
+ if auth_cmd is None:
+ raise base_plugin.FlagsError('Must specify a subcommand to `auth`.')
+ if auth_cmd == _AUTH_SUBCOMMAND_KEY_REVOKE:
+ return _AuthRevokeIntent()
+ else:
+ raise AssertionError('Unknown auth subcommand %r' % (auth_cmd,))
+ else:
+ raise AssertionError('Unknown subcommand %r' % (cmd,))
+
+
+def _die(message):
+ sys.stderr.write('%s\n' % (message,))
+ sys.stderr.flush()
+ sys.exit(1)
+
+
+def main(unused_argv):
+ global _FLAGS
+ flags = _FLAGS
+ # Prevent accidental use of `_FLAGS` until migration to TensorBoard
+ # subcommand is complete, at which point `_FLAGS` goes away.
+ del _FLAGS
+ return _run(flags)
+
+
+class UploaderSubcommand(program.TensorBoardSubcommand):
+ """Integration point with `tensorboard` CLI."""
+
+ def name(self):
+ return 'dev'
+
+ def define_flags(self, parser):
+ _define_flags(parser)
+
+ def run(self, flags):
+ return _run(flags)
+
+ def help(self):
+ return 'upload data to TensorBoard.dev'
+
+
+if __name__ == '__main__':
+ app.run(main, flags_parser=_parse_flags)
diff --git a/tensorboard/uploader/uploader_test.py b/tensorboard/uploader/uploader_test.py
new file mode 100644
index 0000000000..427ce45ac8
--- /dev/null
+++ b/tensorboard/uploader/uploader_test.py
@@ -0,0 +1,662 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorboard.uploader.uploader."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+
+import grpc
+import grpc_testing
+
+try:
+ # python version >= 3.3
+ from unittest import mock # pylint: disable=g-import-not-at-top
+except ImportError:
+ import mock # pylint: disable=g-import-not-at-top,unused-import
+
+import tensorflow as tf
+
+from tensorboard.uploader.proto import scalar_pb2
+from tensorboard.uploader.proto import write_service_pb2
+from tensorboard.uploader.proto import write_service_pb2_grpc
+from tensorboard.uploader import test_util
+from tensorboard.uploader import uploader as uploader_lib
+from tensorboard.uploader import util
+from tensorboard.compat.proto import event_pb2
+from tensorboard.compat.proto import summary_pb2
+from tensorboard.plugins.histogram import summary_v2 as histogram_v2
+from tensorboard.plugins.scalar import summary_v2 as scalar_v2
+from tensorboard.summary import v1 as summary_v1
+from tensorboard.util import test_util as tb_test_util
+
+
+class AbortUploadError(Exception):
+ """Exception used in testing to abort the upload process."""
+
+
+class TensorboardUploaderTest(tf.test.TestCase):
+
+ def _create_mock_client(self):
+ # Create a stub instance (using a test channel) in order to derive a mock
+ # from it with autospec enabled. Mocking TensorBoardWriterServiceStub itself
+ # doesn't work with autospec because grpc constructs stubs via metaclassing.
+ test_channel = grpc_testing.channel(
+ service_descriptors=[], time=grpc_testing.strict_real_time())
+ stub = write_service_pb2_grpc.TensorBoardWriterServiceStub(test_channel)
+ mock_client = mock.create_autospec(stub)
+ fake_exp_response = write_service_pb2.CreateExperimentResponse(
+ experiment_id="123", url="https://example.com/123")
+ mock_client.CreateExperiment.return_value = fake_exp_response
+ return mock_client
+
+ def test_create_experiment(self):
+ logdir = "/logs/foo"
+ mock_client = self._create_mock_client()
+ uploader = uploader_lib.TensorBoardUploader(mock_client, logdir)
+ url = uploader.create_experiment()
+ self.assertEqual(url, "https://example.com/123")
+
+ def test_start_uploading_without_create_experiment_fails(self):
+ mock_client = self._create_mock_client()
+ uploader = uploader_lib.TensorBoardUploader(mock_client, "/logs/foo")
+ with self.assertRaisesRegex(RuntimeError, "call create_experiment()"):
+ uploader.start_uploading()
+
+ def test_start_uploading(self):
+ mock_client = self._create_mock_client()
+ mock_rate_limiter = mock.create_autospec(util.RateLimiter)
+ uploader = uploader_lib.TensorBoardUploader(
+ mock_client, "/logs/foo", mock_rate_limiter)
+ uploader.create_experiment()
+ mock_builder = mock.create_autospec(uploader_lib._RequestBuilder)
+ request = write_service_pb2.WriteScalarRequest()
+ mock_builder.build_requests.side_effect = [
+ iter([request, request]),
+ iter([request, request, request, request, request]),
+ AbortUploadError,
+ ]
+ # pylint: disable=g-backslash-continuation
+ with mock.patch.object(uploader, "_upload") as mock_upload, \
+ mock.patch.object(uploader, "_request_builder", mock_builder), \
+ self.assertRaises(AbortUploadError):
+ uploader.start_uploading()
+ # pylint: enable=g-backslash-continuation
+ self.assertEqual(7, mock_upload.call_count)
+ self.assertEqual(2 + 5 + 1, mock_rate_limiter.tick.call_count)
+
+ def test_upload_empty_logdir(self):
+ logdir = self.get_temp_dir()
+ mock_client = self._create_mock_client()
+ mock_rate_limiter = mock.create_autospec(util.RateLimiter)
+ uploader = uploader_lib.TensorBoardUploader(
+ mock_client, logdir, mock_rate_limiter)
+ uploader.create_experiment()
+ uploader._upload_once()
+ mock_client.WriteScalar.assert_not_called()
+
+ def test_upload_swallows_rpc_failure(self):
+ logdir = self.get_temp_dir()
+ with tb_test_util.FileWriter(logdir) as writer:
+ writer.add_test_summary("foo")
+ mock_client = self._create_mock_client()
+ mock_rate_limiter = mock.create_autospec(util.RateLimiter)
+ uploader = uploader_lib.TensorBoardUploader(
+ mock_client, logdir, mock_rate_limiter)
+ uploader.create_experiment()
+ error = test_util.grpc_error(grpc.StatusCode.INTERNAL, "Failure")
+ mock_client.WriteScalar.side_effect = error
+ uploader._upload_once()
+ mock_client.WriteScalar.assert_called_once()
+
+ def test_upload_propagates_experiment_deletion(self):
+ logdir = self.get_temp_dir()
+ with tb_test_util.FileWriter(logdir) as writer:
+ writer.add_test_summary("foo")
+ mock_client = self._create_mock_client()
+ mock_rate_limiter = mock.create_autospec(util.RateLimiter)
+ uploader = uploader_lib.TensorBoardUploader(
+ mock_client, logdir, mock_rate_limiter)
+ uploader.create_experiment()
+ error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope")
+ mock_client.WriteScalar.side_effect = error
+ with self.assertRaises(uploader_lib.ExperimentNotFoundError):
+ uploader._upload_once()
+
+ def test_upload_preserves_wall_time(self):
+ logdir = self.get_temp_dir()
+ with tb_test_util.FileWriter(logdir) as writer:
+ # Add a raw event so we can specify the wall_time value deterministically.
+ writer.add_event(
+ event_pb2.Event(
+ step=1,
+ wall_time=123.123123123,
+ summary=scalar_v2.scalar_pb("foo", 5.0)))
+ mock_client = self._create_mock_client()
+ mock_rate_limiter = mock.create_autospec(util.RateLimiter)
+ uploader = uploader_lib.TensorBoardUploader(
+ mock_client, logdir, mock_rate_limiter)
+ uploader.create_experiment()
+ uploader._upload_once()
+ mock_client.WriteScalar.assert_called_once()
+ request = mock_client.WriteScalar.call_args[0][0]
+ # Just check the wall_time value; everything else is covered in the full
+ # logdir test below.
+ self.assertEqual(
+ 123123123123,
+ request.runs[0].tags[0].points[0].wall_time.ToNanoseconds())
+
+ def test_upload_full_logdir(self):
+ logdir = self.get_temp_dir()
+ mock_client = self._create_mock_client()
+ mock_rate_limiter = mock.create_autospec(util.RateLimiter)
+ uploader = uploader_lib.TensorBoardUploader(
+ mock_client, logdir, mock_rate_limiter)
+ uploader.create_experiment()
+
+ # Convenience helpers for constructing expected requests.
+ run = write_service_pb2.WriteScalarRequest.Run
+ tag = write_service_pb2.WriteScalarRequest.Tag
+ point = scalar_pb2.ScalarPoint
+
+ # First round
+ writer = tb_test_util.FileWriter(logdir)
+ writer.add_test_summary("foo", simple_value=5.0, step=1)
+ writer.add_test_summary("foo", simple_value=6.0, step=2)
+ writer.add_test_summary("foo", simple_value=7.0, step=3)
+ writer.add_test_summary("bar", simple_value=8.0, step=3)
+ writer.flush()
+ writer_a = tb_test_util.FileWriter(os.path.join(logdir, "a"))
+ writer_a.add_test_summary("qux", simple_value=9.0, step=2)
+ writer_a.flush()
+ uploader._upload_once()
+ self.assertEqual(1, mock_client.WriteScalar.call_count)
+ request1 = mock_client.WriteScalar.call_args[0][0]
+ _clear_wall_times(request1)
+ expected_request1 = write_service_pb2.WriteScalarRequest(
+ experiment_id="123",
+ runs=[
+ run(name=".",
+ tags=[
+ tag(name="foo",
+ metadata=test_util.scalar_metadata("foo"),
+ points=[
+ point(step=1, value=5.0),
+ point(step=2, value=6.0),
+ point(step=3, value=7.0),
+ ]),
+ tag(name="bar",
+ metadata=test_util.scalar_metadata("bar"),
+ points=[
+ point(step=3, value=8.0),
+ ]),
+ ]),
+ run(name="a",
+ tags=[
+ tag(name="qux",
+ metadata=test_util.scalar_metadata("qux"),
+ points=[
+ point(step=2, value=9.0),
+ ]),
+ ]),
+ ])
+ self.assertProtoEquals(expected_request1, request1)
+ mock_client.WriteScalar.reset_mock()
+
+ # Second round
+ writer.add_test_summary("foo", simple_value=10.0, step=5)
+ writer.add_test_summary("baz", simple_value=11.0, step=1)
+ writer.flush()
+ writer_b = tb_test_util.FileWriter(os.path.join(logdir, "b"))
+ writer_b.add_test_summary("xyz", simple_value=12.0, step=1)
+ writer_b.flush()
+ uploader._upload_once()
+ self.assertEqual(1, mock_client.WriteScalar.call_count)
+ request2 = mock_client.WriteScalar.call_args[0][0]
+ _clear_wall_times(request2)
+ expected_request2 = write_service_pb2.WriteScalarRequest(
+ experiment_id="123",
+ runs=[
+ run(name=".",
+ tags=[
+ tag(name="foo",
+ metadata=test_util.scalar_metadata("foo"),
+ points=[
+ point(step=5, value=10.0),
+ ]),
+ tag(name="baz",
+ metadata=test_util.scalar_metadata("baz"),
+ points=[
+ point(step=1, value=11.0),
+ ]),
+ ]),
+ run(name="b",
+ tags=[
+ tag(name="xyz",
+ metadata=test_util.scalar_metadata("xyz"),
+ points=[
+ point(step=1, value=12.0),
+ ]),
+ ]),
+ ])
+ self.assertProtoEquals(expected_request2, request2)
+ mock_client.WriteScalar.reset_mock()
+
+ # Empty third round
+ uploader._upload_once()
+ mock_client.WriteScalar.assert_not_called()
+
+
+class RequestBuilderTest(tf.test.TestCase):
+
+ def _populate_run_from_events(self, run_proto, events):
+ builder = uploader_lib._RequestBuilder(experiment_id="123")
+ requests = builder.build_requests({"": events})
+ request = next(requests, None)
+ if request is not None:
+ self.assertLen(request.runs, 1)
+ run_proto.MergeFrom(request.runs[0])
+ self.assertIsNone(next(requests, None))
+
+ def test_empty_events(self):
+ run_proto = write_service_pb2.WriteScalarRequest.Run()
+ self._populate_run_from_events(run_proto, [])
+ self.assertProtoEquals(
+ run_proto, write_service_pb2.WriteScalarRequest.Run())
+
+ def test_aggregation_by_tag(self):
+ def make_event(step, wall_time, tag, value):
+ return event_pb2.Event(
+ step=step,
+ wall_time=wall_time,
+ summary=scalar_v2.scalar_pb(tag, value))
+ events = [
+ make_event(1, 1.0, "one", 11.0),
+ make_event(1, 2.0, "two", 22.0),
+ make_event(2, 3.0, "one", 33.0),
+ make_event(2, 4.0, "two", 44.0),
+ make_event(1, 5.0, "one", 55.0), # Should preserve duplicate step=1.
+ make_event(1, 6.0, "three", 66.0),
+ ]
+ run_proto = write_service_pb2.WriteScalarRequest.Run()
+ self._populate_run_from_events(run_proto, events)
+ tag_data = {
+ tag.name: [
+ (p.step, p.wall_time.ToSeconds(), p.value) for p in tag.points]
+ for tag in run_proto.tags}
+ self.assertEqual(
+ tag_data, {
+ "one": [(1, 1.0, 11.0), (2, 3.0, 33.0), (1, 5.0, 55.0)],
+ "two": [(1, 2.0, 22.0), (2, 4.0, 44.0)],
+ "three": [(1, 6.0, 66.0)],
+ })
+
+ def test_skips_non_scalar_events(self):
+ events = [
+ event_pb2.Event(file_version="brain.Event:2"),
+ event_pb2.Event(summary=scalar_v2.scalar_pb("scalar1", 5.0)),
+ event_pb2.Event(summary=scalar_v2.scalar_pb("scalar2", 5.0)),
+ event_pb2.Event(
+ summary=histogram_v2.histogram_pb("histogram", [5.0]))
+ ]
+ run_proto = write_service_pb2.WriteScalarRequest.Run()
+ self._populate_run_from_events(run_proto, events)
+ tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags}
+ self.assertEqual(tag_counts, {"scalar1": 1, "scalar2": 1})
+
+ def test_skips_scalar_events_in_non_scalar_time_series(self):
+ events = [
+ event_pb2.Event(file_version="brain.Event:2"),
+ event_pb2.Event(summary=scalar_v2.scalar_pb("scalar1", 5.0)),
+ event_pb2.Event(summary=scalar_v2.scalar_pb("scalar2", 5.0)),
+ event_pb2.Event(
+ summary=histogram_v2.histogram_pb("histogram", [5.0])),
+ event_pb2.Event(summary=scalar_v2.scalar_pb("histogram", 5.0)),
+ ]
+ run_proto = write_service_pb2.WriteScalarRequest.Run()
+ self._populate_run_from_events(run_proto, events)
+ tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags}
+ self.assertEqual(tag_counts, {"scalar1": 1, "scalar2": 1})
+
+ def test_remembers_first_metadata_in_scalar_time_series(self):
+ scalar_1 = event_pb2.Event(summary=scalar_v2.scalar_pb("loss", 4.0))
+ scalar_2 = event_pb2.Event(summary=scalar_v2.scalar_pb("loss", 3.0))
+ scalar_2.summary.value[0].ClearField("metadata")
+ events = [
+ event_pb2.Event(file_version="brain.Event:2"),
+ scalar_1,
+ scalar_2,
+ ]
+ run_proto = write_service_pb2.WriteScalarRequest.Run()
+ self._populate_run_from_events(run_proto, events)
+ tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags}
+ self.assertEqual(tag_counts, {"loss": 2})
+
+ def test_v1_summary_single_value(self):
+ event = event_pb2.Event(step=1, wall_time=123.456)
+ event.summary.value.add(tag="foo", simple_value=5.0)
+ run_proto = write_service_pb2.WriteScalarRequest.Run()
+ self._populate_run_from_events(run_proto, [event])
+ expected_run_proto = write_service_pb2.WriteScalarRequest.Run()
+ foo_tag = expected_run_proto.tags.add()
+ foo_tag.name = "foo"
+ foo_tag.metadata.display_name = "foo"
+ foo_tag.metadata.plugin_data.plugin_name = "scalars"
+ foo_tag.points.add(
+ step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0)
+ self.assertProtoEquals(run_proto, expected_run_proto)
+
+ def test_v1_summary_multiple_value(self):
+ event = event_pb2.Event(step=1, wall_time=123.456)
+ event.summary.value.add(tag="foo", simple_value=1.0)
+ event.summary.value.add(tag="foo", simple_value=2.0)
+ event.summary.value.add(tag="foo", simple_value=3.0)
+ run_proto = write_service_pb2.WriteScalarRequest.Run()
+ self._populate_run_from_events(run_proto, [event])
+ expected_run_proto = write_service_pb2.WriteScalarRequest.Run()
+ foo_tag = expected_run_proto.tags.add()
+ foo_tag.name = "foo"
+ foo_tag.metadata.display_name = "foo"
+ foo_tag.metadata.plugin_data.plugin_name = "scalars"
+ foo_tag.points.add(
+ step=1, wall_time=test_util.timestamp_pb(123456000000), value=1.0)
+ foo_tag.points.add(
+ step=1, wall_time=test_util.timestamp_pb(123456000000), value=2.0)
+ foo_tag.points.add(
+ step=1, wall_time=test_util.timestamp_pb(123456000000), value=3.0)
+ self.assertProtoEquals(run_proto, expected_run_proto)
+
+ def test_v1_summary_tb_summary(self):
+ tf_summary = summary_v1.scalar_pb("foo", 5.0)
+ tb_summary = summary_pb2.Summary.FromString(tf_summary.SerializeToString())
+ event = event_pb2.Event(step=1, wall_time=123.456, summary=tb_summary)
+ run_proto = write_service_pb2.WriteScalarRequest.Run()
+ self._populate_run_from_events(run_proto, [event])
+ expected_run_proto = write_service_pb2.WriteScalarRequest.Run()
+ foo_tag = expected_run_proto.tags.add()
+ foo_tag.name = "foo/scalar_summary"
+ foo_tag.metadata.display_name = "foo"
+ foo_tag.metadata.plugin_data.plugin_name = "scalars"
+ foo_tag.points.add(
+ step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0)
+ self.assertProtoEquals(run_proto, expected_run_proto)
+
+ def test_v2_summary(self):
+ event = event_pb2.Event(
+ step=1, wall_time=123.456, summary=scalar_v2.scalar_pb("foo", 5.0))
+ run_proto = write_service_pb2.WriteScalarRequest.Run()
+ self._populate_run_from_events(run_proto, [event])
+ expected_run_proto = write_service_pb2.WriteScalarRequest.Run()
+ foo_tag = expected_run_proto.tags.add()
+ foo_tag.name = "foo"
+ foo_tag.metadata.plugin_data.plugin_name = "scalars"
+ foo_tag.points.add(
+ step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0)
+ self.assertProtoEquals(run_proto, expected_run_proto)
+
+ def test_no_budget_for_experiment_id(self):
+ event = event_pb2.Event(step=1, wall_time=123.456)
+ event.summary.value.add(tag="foo", simple_value=1.0)
+ run_to_events = {"run_name": [event]}
+ long_experiment_id = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES
+ with self.assertRaises(RuntimeError) as cm:
+ builder = uploader_lib._RequestBuilder(long_experiment_id)
+ list(builder.build_requests(run_to_events))
+ self.assertEqual(
+ str(cm.exception), "Byte budget too small for experiment ID")
+
+ def test_no_room_for_single_point(self):
+ event = event_pb2.Event(step=1, wall_time=123.456)
+ event.summary.value.add(tag="foo", simple_value=1.0)
+ long_run_name = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES
+ run_to_events = {long_run_name: [event]}
+ with self.assertRaises(RuntimeError) as cm:
+ builder = uploader_lib._RequestBuilder("123")
+ list(builder.build_requests(run_to_events))
+ self.assertEqual(
+ str(cm.exception), "Could not make progress uploading data")
+
+ @mock.patch.object(uploader_lib, "_MAX_REQUEST_LENGTH_BYTES", 1024)
+ def test_break_at_run_boundary(self):
+ # Choose run name sizes such that one run fits, but not two.
+ long_run_1 = "A" * 768
+ long_run_2 = "B" * 768
+ event_1 = event_pb2.Event(step=1)
+ event_1.summary.value.add(tag="foo", simple_value=1.0)
+ event_2 = event_pb2.Event(step=2)
+ event_2.summary.value.add(tag="bar", simple_value=-2.0)
+ run_to_events = collections.OrderedDict([
+ (long_run_1, [event_1]),
+ (long_run_2, [event_2]),
+ ])
+
+ builder = uploader_lib._RequestBuilder("123")
+ requests = list(builder.build_requests(run_to_events))
+ for request in requests:
+ _clear_wall_times(request)
+
+ expected = [
+ write_service_pb2.WriteScalarRequest(experiment_id="123"),
+ write_service_pb2.WriteScalarRequest(experiment_id="123"),
+ ]
+ (expected[0].runs.add(name=long_run_1).tags.add(
+ name="foo", metadata=test_util.scalar_metadata("foo")).points.add(
+ step=1, value=1.0))
+ (expected[1].runs.add(name=long_run_2).tags.add(
+ name="bar", metadata=test_util.scalar_metadata("bar")).points.add(
+ step=2, value=-2.0))
+ self.assertEqual(requests, expected)
+
+ @mock.patch.object(uploader_lib, "_MAX_REQUEST_LENGTH_BYTES", 1024)
+ def test_break_at_tag_boundary(self):
+ # Choose tag name sizes such that one tag fits, but not two. Note
+ # that tag names appear in both `Tag.name` and the summary metadata.
+ long_tag_1 = "a" * 384
+ long_tag_2 = "b" * 384
+ event = event_pb2.Event(step=1)
+ event.summary.value.add(tag=long_tag_1, simple_value=1.0)
+ event.summary.value.add(tag=long_tag_2, simple_value=2.0)
+ run_to_events = {"train": [event]}
+
+ builder = uploader_lib._RequestBuilder("123")
+ requests = list(builder.build_requests(run_to_events))
+ for request in requests:
+ _clear_wall_times(request)
+
+ expected = [
+ write_service_pb2.WriteScalarRequest(experiment_id="123"),
+ write_service_pb2.WriteScalarRequest(experiment_id="123"),
+ ]
+ (expected[0].runs.add(name="train").tags.add(
+ name=long_tag_1,
+ metadata=test_util.scalar_metadata(long_tag_1)).points.add(
+ step=1, value=1.0))
+ (expected[1].runs.add(name="train").tags.add(
+ name=long_tag_2,
+ metadata=test_util.scalar_metadata(long_tag_2)).points.add(
+ step=1, value=2.0))
+ self.assertEqual(requests, expected)
+
+ @mock.patch.object(uploader_lib, "_MAX_REQUEST_LENGTH_BYTES", 1024)
+ def test_break_at_scalar_point_boundary(self):
+ point_count = 2000 # comfortably saturates a single 1024-byte request
+ events = []
+ for step in range(point_count):
+ summary = scalar_v2.scalar_pb("loss", -2.0 * step)
+ if step > 0:
+ summary.value[0].ClearField("metadata")
+ events.append(event_pb2.Event(summary=summary, step=step))
+ run_to_events = {"train": events}
+
+ builder = uploader_lib._RequestBuilder("123")
+ requests = list(builder.build_requests(run_to_events))
+ for request in requests:
+ _clear_wall_times(request)
+
+ self.assertGreater(len(requests), 1)
+ self.assertLess(len(requests), point_count)
+
+ total_points_in_result = 0
+ for request in requests:
+ self.assertLen(request.runs, 1)
+ run = request.runs[0]
+ self.assertEqual(run.name, "train")
+ self.assertLen(run.tags, 1)
+ tag = run.tags[0]
+ self.assertEqual(tag.name, "loss")
+ for point in tag.points:
+ self.assertEqual(point.step, total_points_in_result)
+ self.assertEqual(point.value, -2.0 * point.step)
+ total_points_in_result += 1
+ self.assertLessEqual(
+ request.ByteSize(), uploader_lib._MAX_REQUEST_LENGTH_BYTES)
+ self.assertEqual(total_points_in_result, point_count)
+
+ def test_prunes_tags_and_runs(self):
+ event_1 = event_pb2.Event(step=1)
+ event_1.summary.value.add(tag="foo", simple_value=1.0)
+ event_2 = event_pb2.Event(step=2)
+ event_2.summary.value.add(tag="bar", simple_value=-2.0)
+ run_to_events = collections.OrderedDict([
+ ("train", [event_1]),
+ ("test", [event_2]),
+ ])
+
+ real_create_point = uploader_lib._RequestBuilder._create_point
+
+ create_point_call_count_box = [0]
+
+ def mock_create_point(uploader_self, *args, **kwargs):
+ # Simulate out-of-space error the first time that we try to store
+ # the second point.
+ create_point_call_count_box[0] += 1
+ if create_point_call_count_box[0] == 2:
+ raise uploader_lib._OutOfSpaceError()
+ return real_create_point(uploader_self, *args, **kwargs)
+
+ with mock.patch.object(
+ uploader_lib._RequestBuilder, "_create_point", mock_create_point):
+ builder = uploader_lib._RequestBuilder("123")
+ requests = list(builder.build_requests(run_to_events))
+ for request in requests:
+ _clear_wall_times(request)
+
+ expected = [
+ write_service_pb2.WriteScalarRequest(experiment_id="123"),
+ write_service_pb2.WriteScalarRequest(experiment_id="123"),
+ ]
+ (expected[0].runs.add(name="train").tags.add(
+ name="foo", metadata=test_util.scalar_metadata("foo")).points.add(
+ step=1, value=1.0))
+ (expected[1].runs.add(name="test").tags.add(
+ name="bar", metadata=test_util.scalar_metadata("bar")).points.add(
+ step=2, value=-2.0))
+ self.assertEqual(expected, requests)
+
+ def test_wall_time_precision(self):
+ # Test a wall time that is exactly representable in float64 but has enough
+    # digits to incur error if converted to nanoseconds the naive way (* 1e9).
+ event1 = event_pb2.Event(step=1, wall_time=1567808404.765432119)
+ event1.summary.value.add(tag="foo", simple_value=1.0)
+    # Test a wall time where, as a float64, the fractional part on its own
+    # introduces error if truncated to 9 decimal places instead of rounded.
+ event2 = event_pb2.Event(step=2, wall_time=1.000000002)
+ event2.summary.value.add(tag="foo", simple_value=2.0)
+ run_proto = write_service_pb2.WriteScalarRequest.Run()
+ self._populate_run_from_events(run_proto, [event1, event2])
+ self.assertEqual(
+ test_util.timestamp_pb(1567808404765432119),
+ run_proto.tags[0].points[0].wall_time)
+ self.assertEqual(
+ test_util.timestamp_pb(1000000002),
+ run_proto.tags[0].points[1].wall_time)
+
+
+class DeleteExperimentTest(tf.test.TestCase):
+
+ def _create_mock_client(self):
+ # Create a stub instance (using a test channel) in order to derive a mock
+ # from it with autospec enabled. Mocking TensorBoardWriterServiceStub itself
+ # doesn't work with autospec because grpc constructs stubs via metaclassing.
+ test_channel = grpc_testing.channel(
+ service_descriptors=[], time=grpc_testing.strict_real_time())
+ stub = write_service_pb2_grpc.TensorBoardWriterServiceStub(test_channel)
+ mock_client = mock.create_autospec(stub)
+ return mock_client
+
+ def test_success(self):
+ mock_client = self._create_mock_client()
+ response = write_service_pb2.DeleteExperimentResponse()
+ mock_client.DeleteExperiment.return_value = response
+
+ uploader_lib.delete_experiment(mock_client, "123")
+
+ expected_request = write_service_pb2.DeleteExperimentRequest()
+ expected_request.experiment_id = "123"
+ mock_client.DeleteExperiment.assert_called_once()
+ (args, _) = mock_client.DeleteExperiment.call_args
+ self.assertEqual(args[0], expected_request)
+
+ def test_not_found(self):
+ mock_client = self._create_mock_client()
+ error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope")
+ mock_client.DeleteExperiment.side_effect = error
+
+ with self.assertRaises(uploader_lib.ExperimentNotFoundError):
+ uploader_lib.delete_experiment(mock_client, "123")
+
+ def test_unauthorized(self):
+ mock_client = self._create_mock_client()
+ error = test_util.grpc_error(grpc.StatusCode.PERMISSION_DENIED, "nope")
+ mock_client.DeleteExperiment.side_effect = error
+
+ with self.assertRaises(uploader_lib.PermissionDeniedError):
+ uploader_lib.delete_experiment(mock_client, "123")
+
+ def test_internal_error(self):
+ mock_client = self._create_mock_client()
+ error = test_util.grpc_error(grpc.StatusCode.INTERNAL, "travesty")
+ mock_client.DeleteExperiment.side_effect = error
+
+ with self.assertRaises(grpc.RpcError) as cm:
+ uploader_lib.delete_experiment(mock_client, "123")
+ msg = str(cm.exception)
+ self.assertIn("travesty", msg)
+
+
+class VarintCostTest(tf.test.TestCase):
+
+ def test_varint_cost(self):
+ self.assertEqual(uploader_lib._varint_cost(0), 1)
+ self.assertEqual(uploader_lib._varint_cost(7), 1)
+ self.assertEqual(uploader_lib._varint_cost(127), 1)
+ self.assertEqual(uploader_lib._varint_cost(128), 2)
+ self.assertEqual(uploader_lib._varint_cost(128 * 128 - 1), 2)
+ self.assertEqual(uploader_lib._varint_cost(128 * 128), 3)
+
+
+def _clear_wall_times(request):
+ """Clears the wall_time fields in a WriteScalarRequest to be deterministic."""
+ for run in request.runs:
+ for tag in run.tags:
+ for point in tag.points:
+ point.ClearField("wall_time")
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorboard/uploader/util.py b/tensorboard/uploader/util.py
new file mode 100644
index 0000000000..795758c4c1
--- /dev/null
+++ b/tensorboard/uploader/util.py
@@ -0,0 +1,114 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for use by the uploader command line tool."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import errno
+import os
+import os.path
+import time
+
+
+class RateLimiter(object):
+ """Helper class for rate-limiting using a fixed minimum interval."""
+
+ def __init__(self, interval_secs):
+ """Constructs a RateLimiter that permits a tick() every `interval_secs`."""
+    self._time = time  # Injected as an attribute so tests can fake the clock.
+ self._interval_secs = interval_secs
+ self._last_called_secs = 0
+
+ def tick(self):
+ """Blocks until it has been at least `interval_secs` since last tick()."""
+ wait_secs = self._last_called_secs + self._interval_secs - self._time.time()
+ if wait_secs > 0:
+ self._time.sleep(wait_secs)
+ self._last_called_secs = self._time.time()
+
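+
+# Example `RateLimiter` usage (a minimal sketch; `do_periodic_work` is a
+# hypothetical placeholder for the caller's own logic):
+#
+#   rate_limiter = RateLimiter(5)
+#   while True:
+#     rate_limiter.tick()  # Blocks so iterations start at least 5s apart.
+#     do_periodic_work()
+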
+
+def get_user_config_directory():
+ """Returns a platform-specific root directory for user config settings."""
+ # On Windows, prefer %LOCALAPPDATA%, then %APPDATA%, since we can expect the
+ # AppData directories to be ACLed to be visible only to the user and admin
+ # users (https://stackoverflow.com/a/7617601/1179226). If neither is set,
+ # return None instead of falling back to something that may be world-readable.
+ if os.name == "nt":
+ appdata = os.getenv("LOCALAPPDATA")
+ if appdata:
+ return appdata
+ appdata = os.getenv("APPDATA")
+ if appdata:
+ return appdata
+ return None
+ # On non-windows, use XDG_CONFIG_HOME if set, else default to ~/.config.
+ xdg_config_home = os.getenv("XDG_CONFIG_HOME")
+ if xdg_config_home:
+ return xdg_config_home
+ return os.path.join(os.path.expanduser("~"), ".config")
+
+
+def make_file_with_directories(path, private=False):
+ """Creates a file and its containing directories, if they don't already exist.
+
+ If `private` is True, the file will be made private (readable only by the
+ current user) and so will the leaf directory. Pre-existing contents of the
+ file are not modified.
+
+ Passing `private=True` is not supported on Windows because it doesn't support
+ the relevant parts of `os.chmod()`.
+
+ Args:
+ path: str, The path of the file to create.
+ private: boolean, Whether to make the file and leaf directory readable only
+ by the current user.
+
+ Raises:
+ RuntimeError: If called on Windows with `private` set to True.
+ """
+ if private and os.name == "nt":
+ raise RuntimeError("Creating private file not supported on Windows")
+ try:
+ path = os.path.realpath(path)
+ leaf_dir = os.path.dirname(path)
+ try:
+ os.makedirs(leaf_dir)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
+ if private:
+ os.chmod(leaf_dir, 0o700)
+ open(path, "a").close()
+ if private:
+ os.chmod(path, 0o600)
+ except EnvironmentError as e:
+ raise RuntimeError("Failed to create file %s: %s" % (path, e))
+
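+
+# Example usage on non-Windows platforms (a sketch; the path components shown
+# are hypothetical):
+#
+#   config_dir = get_user_config_directory()
+#   if config_dir:
+#     make_file_with_directories(
+#         os.path.join(config_dir, "some_tool", "credentials"), private=True)
+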
+
+def set_timestamp(pb, seconds_since_epoch):
+ """Sets a `Timestamp` proto message to a floating point UNIX time.
+
+ This is like `pb.FromNanoseconds(int(seconds_since_epoch * 1e9))` but
+ without introducing floating-point error.
+
+ Args:
+ pb: A `google.protobuf.Timestamp` message to mutate.
+ seconds_since_epoch: A `float`, as returned by `time.time`.
+ """
+ pb.seconds = int(seconds_since_epoch)
+ pb.nanos = int(round((seconds_since_epoch % 1) * 10**9))
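+
+
+# For example, `set_timestamp(pb, 1234567890.0078125)` yields
+# `pb.seconds == 1234567890` and `pb.nanos == 7812500`, whereas multiplying by
+# 1e9 in floating point first would pick up rounding error (see util_test.py).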
diff --git a/tensorboard/uploader/util_test.py b/tensorboard/uploader/util_test.py
new file mode 100644
index 0000000000..b0670d5315
--- /dev/null
+++ b/tensorboard/uploader/util_test.py
@@ -0,0 +1,199 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorboard.uploader.util."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import unittest
+
+
+try:
+ # python version >= 3.3
+ from unittest import mock # pylint: disable=g-import-not-at-top
+except ImportError:
+ import mock # pylint: disable=g-import-not-at-top,unused-import
+
+
+from google.protobuf import timestamp_pb2
+from tensorboard.uploader import test_util
+from tensorboard.uploader import util
+from tensorboard import test as tb_test
+
+
+class RateLimiterTest(tb_test.TestCase):
+
+ def test_rate_limiting(self):
+ rate_limiter = util.RateLimiter(10)
+ fake_time = test_util.FakeTime(current=1000)
+ with mock.patch.object(rate_limiter, "_time", fake_time):
+ self.assertEqual(1000, fake_time.time())
+ # No sleeping for initial tick.
+ rate_limiter.tick()
+ self.assertEqual(1000, fake_time.time())
+ # Second tick requires a full sleep.
+ rate_limiter.tick()
+ self.assertEqual(1010, fake_time.time())
+ # Third tick requires a sleep just to make up the remaining second.
+ fake_time.sleep(9)
+ self.assertEqual(1019, fake_time.time())
+ rate_limiter.tick()
+ self.assertEqual(1020, fake_time.time())
+ # Fourth tick requires no sleep since we have no remaining seconds.
+ fake_time.sleep(11)
+ self.assertEqual(1031, fake_time.time())
+ rate_limiter.tick()
+ self.assertEqual(1031, fake_time.time())
+
+
+class GetUserConfigDirectoryTest(tb_test.TestCase):
+
+ def test_windows(self):
+ with mock.patch.object(os, "name", "nt"):
+ with mock.patch.dict(os.environ, {
+ "LOCALAPPDATA": "C:\\Users\\Alice\\AppData\\Local",
+ "APPDATA": "C:\\Users\\Alice\\AppData\\Roaming",
+ }):
+ self.assertEqual(
+ "C:\\Users\\Alice\\AppData\\Local",
+ util.get_user_config_directory())
+ with mock.patch.dict(os.environ, {
+ "LOCALAPPDATA": "",
+ "APPDATA": "C:\\Users\\Alice\\AppData\\Roaming",
+ }):
+ self.assertEqual(
+ "C:\\Users\\Alice\\AppData\\Roaming",
+ util.get_user_config_directory())
+ with mock.patch.dict(os.environ, {
+ "LOCALAPPDATA": "",
+ "APPDATA": "",
+ }):
+ self.assertIsNone(util.get_user_config_directory())
+
+ def test_non_windows(self):
+ with mock.patch.dict(os.environ, {"HOME": "/home/alice"}):
+ self.assertEqual(
+ "/home/alice%s.config" % os.sep, util.get_user_config_directory())
+ with mock.patch.dict(
+ os.environ, {"XDG_CONFIG_HOME": "/home/alice/configz"}):
+ self.assertEqual(
+ "/home/alice/configz", util.get_user_config_directory())
+
+
+skip_if_windows = unittest.skipIf(os.name == "nt", "Unsupported on Windows")
+
+
+class MakeFileWithDirectoriesTest(tb_test.TestCase):
+
+ def test_windows_private(self):
+ with mock.patch.object(os, "name", "nt"):
+ with self.assertRaisesRegex(RuntimeError, "Windows"):
+ util.make_file_with_directories("/tmp/foo", private=True)
+
+ def test_existing_file(self):
+ root = self.get_temp_dir()
+ path = os.path.join(root, "foo", "bar", "qux.txt")
+ os.makedirs(os.path.dirname(path))
+ with open(path, mode="w") as f:
+ f.write("foobar")
+ util.make_file_with_directories(path)
+ with open(path, mode="r") as f:
+ self.assertEqual("foobar", f.read())
+
+ def test_existing_dir(self):
+ root = self.get_temp_dir()
+ path = os.path.join(root, "foo", "bar", "qux.txt")
+ os.makedirs(os.path.dirname(path))
+ util.make_file_with_directories(path)
+ self.assertEqual(0, os.path.getsize(path))
+
+ def test_nonexistent_leaf_dir(self):
+ root = self.get_temp_dir()
+ path = os.path.join(root, "foo", "bar", "qux.txt")
+ os.makedirs(os.path.dirname(os.path.dirname(path)))
+ util.make_file_with_directories(path)
+ self.assertEqual(0, os.path.getsize(path))
+
+ def test_nonexistent_multiple_dirs(self):
+ root = self.get_temp_dir()
+ path = os.path.join(root, "foo", "bar", "qux.txt")
+ util.make_file_with_directories(path)
+ self.assertEqual(0, os.path.getsize(path))
+
+ def assertMode(self, mode, path):
+ self.assertEqual(mode, os.stat(path).st_mode & 0o777)
+
+ @skip_if_windows
+ def test_private_existing_file(self):
+ root = self.get_temp_dir()
+ path = os.path.join(root, "foo", "bar", "qux.txt")
+ os.makedirs(os.path.dirname(path))
+ with open(path, mode="w") as f:
+ f.write("foobar")
+ os.chmod(os.path.dirname(path), 0o777)
+ os.chmod(path, 0o666)
+ util.make_file_with_directories(path, private=True)
+ self.assertMode(0o700, os.path.dirname(path))
+ self.assertMode(0o600, path)
+ with open(path, mode="r") as f:
+ self.assertEqual("foobar", f.read())
+
+ @skip_if_windows
+ def test_private_existing_dir(self):
+ root = self.get_temp_dir()
+ path = os.path.join(root, "foo", "bar", "qux.txt")
+ os.makedirs(os.path.dirname(path))
+ os.chmod(os.path.dirname(path), 0o777)
+ util.make_file_with_directories(path, private=True)
+ self.assertMode(0o700, os.path.dirname(path))
+ self.assertMode(0o600, path)
+ self.assertEqual(0, os.path.getsize(path))
+
+ @skip_if_windows
+ def test_private_nonexistent_leaf_dir(self):
+ root = self.get_temp_dir()
+ path = os.path.join(root, "foo", "bar", "qux.txt")
+ os.makedirs(os.path.dirname(os.path.dirname(path)))
+ util.make_file_with_directories(path, private=True)
+ self.assertMode(0o700, os.path.dirname(path))
+ self.assertMode(0o600, path)
+ self.assertEqual(0, os.path.getsize(path))
+
+ @skip_if_windows
+ def test_private_nonexistent_multiple_dirs(self):
+ root = self.get_temp_dir()
+ path = os.path.join(root, "foo", "bar", "qux.txt")
+ util.make_file_with_directories(path, private=True)
+ self.assertMode(0o700, os.path.dirname(path))
+ self.assertMode(0o600, path)
+ self.assertEqual(0, os.path.getsize(path))
+
+
+class SetTimestampTest(tb_test.TestCase):
+
+ def test_set_timestamp(self):
+ pb = timestamp_pb2.Timestamp()
+ t = 1234567890.007812500
+ # Note that just multiplying by 1e9 would lose precision:
+ self.assertEqual(int(t * 1e9) % int(1e9), 7812608)
+ util.set_timestamp(pb, t)
+ self.assertEqual(pb.seconds, 1234567890)
+ self.assertEqual(pb.nanos, 7812500)
+
+
+if __name__ == "__main__":
+ tb_test.main()
diff --git a/tensorboard/util/BUILD b/tensorboard/util/BUILD
index f4fbe92a27..3ec9e5c44d 100644
--- a/tensorboard/util/BUILD
+++ b/tensorboard/util/BUILD
@@ -1,9 +1,28 @@
package(default_visibility = ["//tensorboard:internal"])
+load("//tensorboard/defs:protos.bzl", "tb_proto_library")
+
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"]) # Needed for internal repo.
+py_library(
+ name = "argparse_util",
+ srcs = ["argparse_util.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "argparse_util_test",
+ size = "small",
+ srcs = ["argparse_util_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":argparse_util",
+ "//tensorboard:test",
+ ],
+)
+
py_library(
name = "encoder",
srcs = ["encoder.py"],
@@ -28,6 +47,43 @@ py_test(
],
)
+py_library(
+ name = "grpc_util",
+ srcs = ["grpc_util.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorboard:expect_grpc_installed",
+ "//tensorboard:version",
+ "//tensorboard/util:tb_logging",
+ ],
+)
+
+py_test(
+ name = "grpc_util_test",
+ size = "small",
+ srcs = ["grpc_util_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":grpc_util",
+ ":grpc_util_test_proto_py_pb2",
+ ":grpc_util_test_proto_py_pb2_grpc",
+ ":test_util",
+ "//tensorboard:expect_futures_installed",
+ "//tensorboard:expect_grpc_installed",
+ "//tensorboard:test",
+ "//tensorboard:version",
+ "@org_pythonhosted_mock",
+ "@org_pythonhosted_six",
+ ],
+)
+
+tb_proto_library(
+ name = "grpc_util_test_proto",
+ has_services = True,
+ srcs = ["grpc_util_test.proto"],
+ testonly = True,
+)
+
py_library(
name = "op_evaluator",
srcs = ["op_evaluator.py"],
diff --git a/tensorboard/util/argparse_util.py b/tensorboard/util/argparse_util.py
new file mode 100644
index 0000000000..27d08dd1d5
--- /dev/null
+++ b/tensorboard/util/argparse_util.py
@@ -0,0 +1,65 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for working with `argparse` in a portable way."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import contextlib
+import gettext
+
+
+@contextlib.contextmanager
+def allow_missing_subcommand():
+ """Make Python 2.7 behave like Python 3 w.r.t. default subcommands.
+
+ The behavior of argparse was changed [1] [2] in Python 3.3. When a
+ parser defines subcommands, it used to be an error for the user to
+ invoke the binary without specifying a subcommand. As of Python 3.3,
+ this is permitted. This monkey patch backports the new behavior to
+ earlier versions of Python.
+
+ This context manager need only be used around `parse_args`; parsers
+ may be constructed and configured outside of the context manager.
+
+ [1]: https://github.com/python/cpython/commit/f97c59aaba2d93e48cbc6d25f7ff9f9c87f8d0b2
+ [2]: https://bugs.python.org/issue16308
+ """
+
+ real_error = argparse.ArgumentParser.error
+
+ # This must exactly match the error message raised by Python 2.7's
+ # `argparse` when no subparser is given. This is `argparse.py:1954` at
+ # Git tag `v2.7.16`.
+ ignored_message = gettext.gettext("too few arguments")
+
+ def error(*args, **kwargs):
+ # Expected signature is `error(self, message)`, but we retain more
+ # flexibility to be forward-compatible with implementation changes.
+ if "message" not in kwargs and len(args) < 2:
+ return real_error(*args, **kwargs)
+ message = kwargs["message"] if "message" in kwargs else args[1]
+ if message == ignored_message:
+ return None
+ else:
+ return real_error(*args, **kwargs)
+
+ argparse.ArgumentParser.error = error
+ try:
+ yield
+ finally:
+ argparse.ArgumentParser.error = real_error
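+
+
+# Example usage (a minimal sketch; `frobnicate` is a hypothetical subcommand):
+#
+#   parser = argparse.ArgumentParser()
+#   subparsers = parser.add_subparsers()
+#   subparsers.add_parser("frobnicate")
+#   with allow_missing_subcommand():
+#     args = parser.parse_args([])  # Does not error out, even on Python 2.7.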
diff --git a/tensorboard/util/argparse_util_test.py b/tensorboard/util/argparse_util_test.py
new file mode 100644
index 0000000000..4e18d4a027
--- /dev/null
+++ b/tensorboard/util/argparse_util_test.py
@@ -0,0 +1,56 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `argparse_util`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+
+from tensorboard import test as tb_test
+from tensorboard.util import argparse_util
+
+
+class AllowMissingSubcommandTest(tb_test.TestCase):
+
+ def test_allows_missing_subcommands(self):
+ parser = argparse.ArgumentParser()
+ subparsers = parser.add_subparsers()
+ subparser = subparsers.add_parser("magic")
+ subparser.set_defaults(chosen="magic")
+ with argparse_util.allow_missing_subcommand():
+ args = parser.parse_args([])
+ self.assertEqual(args, argparse.Namespace())
+
+ def test_allows_provided_subcommands(self):
+ parser = argparse.ArgumentParser()
+ subparsers = parser.add_subparsers()
+ subparser = subparsers.add_parser("magic")
+ subparser.set_defaults(chosen="magic")
+ with argparse_util.allow_missing_subcommand():
+ args = parser.parse_args(["magic"])
+ self.assertEqual(args, argparse.Namespace(chosen="magic"))
+
+ def test_still_complains_on_missing_arguments(self):
+ parser = argparse.ArgumentParser()
+ parser.add_argument("please_provide_me")
+ with argparse_util.allow_missing_subcommand():
+ with self.assertRaises(SystemExit):
+ parser.parse_args([])
+
+
+if __name__ == "__main__":
+ tb_test.main()
diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py
new file mode 100644
index 0000000000..43c551dfeb
--- /dev/null
+++ b/tensorboard/util/grpc_util.py
@@ -0,0 +1,126 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for working with python gRPC stubs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import time
+
+import grpc
+
+from tensorboard import version
+from tensorboard.util import tb_logging
+
+logger = tb_logging.get_logger()
+
+# Default RPC timeout.
+_GRPC_DEFAULT_TIMEOUT_SECS = 30
+
+# Max number of times to attempt an RPC, retrying on transient failures.
+_GRPC_RETRY_MAX_ATTEMPTS = 5
+
+# Parameters to control the exponential backoff behavior.
+_GRPC_RETRY_EXPONENTIAL_BASE = 2
+_GRPC_RETRY_JITTER_FACTOR_MIN = 1.1
+_GRPC_RETRY_JITTER_FACTOR_MAX = 1.5
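+# With the defaults above, `call_with_retries` (below) sleeps roughly 2.2-3.0s
+# after the first failed attempt, 4.4-6.0s after the second, 8.8-12.0s after
+# the third, and 17.6-24.0s after the fourth, before the fifth and final try.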
+
+# Status codes from gRPC for which it's reasonable to retry the RPC.
+_GRPC_RETRYABLE_STATUS_CODES = frozenset([
+ grpc.StatusCode.ABORTED,
+ grpc.StatusCode.DEADLINE_EXCEEDED,
+ grpc.StatusCode.RESOURCE_EXHAUSTED,
+ grpc.StatusCode.UNAVAILABLE,
+])
+
+# gRPC metadata key whose value contains the client version.
+_VERSION_METADATA_KEY = "tensorboard-version"
+
+
+def call_with_retries(api_method, request, clock=None):
+ """Call a gRPC stub API method, with automatic retry logic.
+
+ This only supports unary-unary RPCs: i.e., no streaming on either end.
+ Streamed RPCs will generally need application-level pagination support,
+ because after a gRPC error one must retry the entire request; there is no
+ "retry-resume" functionality.
+
+ Args:
+ api_method: Callable for the API method to invoke.
+ request: Request protocol buffer to pass to the API method.
+    clock: an object providing `time()` and `sleep()` methods with the same
+      semantics as the standard `time` module; defaults to the standard
+      `time` module if not provided.
+
+ Returns:
+ Response protocol buffer returned by the API method.
+
+ Raises:
+ grpc.RpcError: if a non-retryable error is returned, or if all retry
+ attempts have been exhausted.
+ """
+ if clock is None:
+ clock = time
+  # We can't use api_method.__name__ to get the RPC name: api_method is not a
+  # plain function but a gRPC callable object that doesn't expose the name of
+  # the underlying method, so derive the name from the request class instead.
+ rpc_name = request.__class__.__name__.replace("Request", "")
+ logger.debug("RPC call %s with request: %r", rpc_name, request)
+ num_attempts = 0
+ while True:
+ num_attempts += 1
+ try:
+ return api_method(
+ request,
+ timeout=_GRPC_DEFAULT_TIMEOUT_SECS,
+ metadata=version_metadata())
+ except grpc.RpcError as e:
+ logger.info("RPC call %s got error %s", rpc_name, e)
+ if e.code() not in _GRPC_RETRYABLE_STATUS_CODES:
+ raise
+ if num_attempts >= _GRPC_RETRY_MAX_ATTEMPTS:
+ raise
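+      # Exponential backoff: sleep for base**attempts seconds, scaled by a
+      # random jitter factor so concurrent clients don't retry in lockstep.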
+ jitter_factor = random.uniform(
+ _GRPC_RETRY_JITTER_FACTOR_MIN, _GRPC_RETRY_JITTER_FACTOR_MAX)
+ backoff_secs = (_GRPC_RETRY_EXPONENTIAL_BASE**num_attempts) * jitter_factor
+ logger.info(
+ "RPC call %s attempted %d times, retrying in %.1f seconds",
+ rpc_name, num_attempts, backoff_secs)
+ clock.sleep(backoff_secs)
+
+
+def version_metadata():
+ """Creates gRPC invocation metadata encoding the TensorBoard version.
+
+ Usage: `stub.MyRpc(request, metadata=version_metadata())`.
+
+ Returns:
+ A tuple of key-value pairs (themselves 2-tuples) to be passed as the
+ `metadata` kwarg to gRPC stub API methods.
+ """
+ return ((_VERSION_METADATA_KEY, version.VERSION),)
+
+
+def extract_version(metadata):
+ """Extracts version from invocation metadata.
+
+  The argument should be the result of a prior call to `version_metadata()`,
+  or of combining such a result with other invocation metadata.
+
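+  For example, `extract_version(version_metadata())` returns the version
+  string of the running client.
+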
+ Returns:
+ The TensorBoard version listed in this metadata, or `None` if none
+ is listed.
+ """
+ return dict(metadata).get(_VERSION_METADATA_KEY)
diff --git a/tensorboard/util/grpc_util_test.proto b/tensorboard/util/grpc_util_test.proto
new file mode 100644
index 0000000000..3f4bb5c976
--- /dev/null
+++ b/tensorboard/util/grpc_util_test.proto
@@ -0,0 +1,18 @@
+// Minimal example RPC service definition. See grpc_util_test.py for usage.
+syntax = "proto3";
+
+package tensorboard.util;
+
+// Test service for grpc_util_test.py.
+service TestService {
+ // Test RPC.
+ rpc TestRpc(TestRpcRequest) returns (TestRpcResponse);
+}
+
+message TestRpcRequest {
+ int32 nonce = 1;
+}
+
+message TestRpcResponse {
+ int32 nonce = 1;
+}
diff --git a/tensorboard/util/grpc_util_test.py b/tensorboard/util/grpc_util_test.py
new file mode 100644
index 0000000000..4aeabc8914
--- /dev/null
+++ b/tensorboard/util/grpc_util_test.py
@@ -0,0 +1,164 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tensorboard.util.grpc_util`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import hashlib
+import threading
+
+from concurrent import futures
+import grpc
+import six
+
+from tensorboard.util import grpc_util
+from tensorboard.util import grpc_util_test_pb2
+from tensorboard.util import grpc_util_test_pb2_grpc
+from tensorboard.util import test_util
+from tensorboard import test as tb_test
+from tensorboard import version
+
+
+def make_request(nonce):
+ return grpc_util_test_pb2.TestRpcRequest(nonce=nonce)
+
+
+def make_response(nonce):
+ return grpc_util_test_pb2.TestRpcResponse(nonce=nonce)
+
+
+class TestGrpcServer(grpc_util_test_pb2_grpc.TestServiceServicer):
+ """Helper for testing gRPC client logic with a dummy gRPC server."""
+
+ def __init__(self, handler):
+ super(TestGrpcServer, self).__init__()
+ self._handler = handler
+
+ def TestRpc(self, request, context):
+ return self._handler(request, context)
+
+ @contextlib.contextmanager
+ def run(self):
+ """Context manager to run the gRPC server and yield a client for it."""
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
+ grpc_util_test_pb2_grpc.add_TestServiceServicer_to_server(self, server)
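+    # gRPC local credentials allow the test client to reach the test server
+    # over localhost without configuring TLS.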
+ port = server.add_secure_port(
+ "localhost:0", grpc.local_server_credentials())
+ def launch_server():
+ server.start()
+ server.wait_for_termination()
+ thread = threading.Thread(target=launch_server, name="TestGrpcServer")
+ thread.daemon = True
+ thread.start()
+ with grpc.secure_channel(
+ "localhost:%d" % port, grpc.local_channel_credentials()) as channel:
+ yield grpc_util_test_pb2_grpc.TestServiceStub(channel)
+ server.stop(grace=None)
+ thread.join()
+
+
+class CallWithRetriesTest(tb_test.TestCase):
+
+ def test_call_with_retries_succeeds(self):
+ def handler(request, _):
+ return make_response(request.nonce)
+ server = TestGrpcServer(handler)
+ with server.run() as client:
+ response = grpc_util.call_with_retries(client.TestRpc, make_request(42))
+ self.assertEqual(make_response(42), response)
+
+ def test_call_with_retries_fails_immediately_on_permanent_error(self):
+ def handler(_, context):
+ context.abort(grpc.StatusCode.INTERNAL, "foo")
+ server = TestGrpcServer(handler)
+ with server.run() as client:
+ with self.assertRaises(grpc.RpcError) as raised:
+ grpc_util.call_with_retries(client.TestRpc, make_request(42))
+ self.assertEqual(grpc.StatusCode.INTERNAL, raised.exception.code())
+ self.assertEqual("foo", raised.exception.details())
+
+ def test_call_with_retries_fails_after_backoff_on_nonpermanent_error(self):
+ attempt_times = []
+ fake_time = test_util.FakeTime()
+ def handler(_, context):
+ attempt_times.append(fake_time.time())
+ context.abort(grpc.StatusCode.UNAVAILABLE, "foo")
+ server = TestGrpcServer(handler)
+ with server.run() as client:
+ with self.assertRaises(grpc.RpcError) as raised:
+ grpc_util.call_with_retries(client.TestRpc, make_request(42), fake_time)
+ self.assertEqual(grpc.StatusCode.UNAVAILABLE, raised.exception.code())
+ self.assertEqual("foo", raised.exception.details())
+ self.assertLen(attempt_times, 5)
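+    # Each retry waits 2**attempt seconds scaled by a jitter factor in
+    # [1.1, 1.5], so the k-th gap should fall within [2**k, 2**(k + 1)].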
+ self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4)
+ self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8)
+ self.assertBetween(attempt_times[3] - attempt_times[2], 8, 16)
+ self.assertBetween(attempt_times[4] - attempt_times[3], 16, 32)
+
+ def test_call_with_retries_succeeds_after_backoff_on_transient_error(self):
+ attempt_times = []
+ fake_time = test_util.FakeTime()
+ def handler(request, context):
+ attempt_times.append(fake_time.time())
+ if len(attempt_times) < 3:
+ context.abort(grpc.StatusCode.UNAVAILABLE, "foo")
+ return make_response(request.nonce)
+ server = TestGrpcServer(handler)
+ with server.run() as client:
+ response = grpc_util.call_with_retries(
+ client.TestRpc, make_request(42), fake_time)
+ self.assertEqual(make_response(42), response)
+ self.assertLen(attempt_times, 3)
+ self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4)
+ self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8)
+
+ def test_call_with_retries_includes_version_metadata(self):
+ def digest(s):
+ """Hashes a string into a 32-bit integer."""
+ return int(hashlib.sha256(s.encode("utf-8")).hexdigest(), 16) & 0xffffffff
+ def handler(request, context):
+ metadata = context.invocation_metadata()
+ client_version = grpc_util.extract_version(metadata)
+ return make_response(digest(client_version))
+ server = TestGrpcServer(handler)
+ with server.run() as client:
+ response = grpc_util.call_with_retries(client.TestRpc, make_request(0))
+ expected_nonce = digest(
+ grpc_util.extract_version(grpc_util.version_metadata()))
+ self.assertEqual(make_response(expected_nonce), response)
+
+
+class VersionMetadataTest(tb_test.TestCase):
+
+ def test_structure(self):
+ result = grpc_util.version_metadata()
+ self.assertIsInstance(result, tuple)
+ for kv in result:
+ self.assertIsInstance(kv, tuple)
+ self.assertLen(kv, 2)
+ (k, v) = kv
+ self.assertIsInstance(k, str)
+ self.assertIsInstance(v, six.string_types)
+
+ def test_roundtrip(self):
+ result = grpc_util.extract_version(grpc_util.version_metadata())
+ self.assertEqual(result, version.VERSION)
+
+
+if __name__ == "__main__":
+ tb_test.main()
diff --git a/tensorboard/util/test_util.py b/tensorboard/util/test_util.py
index c970dcf472..44d56fa164 100644
--- a/tensorboard/util/test_util.py
+++ b/tensorboard/util/test_util.py
@@ -128,6 +128,22 @@ def get(logdir):
return FileWriterCache._cache[logdir]
+class FakeTime(object):
+ """Thread-safe fake replacement for the `time` module."""
+
+ def __init__(self, current=0.0):
+ self._time = float(current)
+ self._lock = threading.Lock()
+
+ def time(self):
+ with self._lock:
+ return self._time
+
+ def sleep(self, secs):
+ with self._lock:
+ self._time += secs
+
+
def ensure_tb_summary_proto(summary):
"""Ensures summary is TensorBoard Summary proto.
diff --git a/tensorboard/version.py b/tensorboard/version.py
index 9d80d2e29a..09fea74ee6 100644
--- a/tensorboard/version.py
+++ b/tensorboard/version.py
@@ -15,4 +15,4 @@
"""Contains the version string."""
-VERSION = '2.0.0'
+VERSION = '2.0.1'
diff --git a/third_party/mock_call_assertions.patch b/third_party/mock_call_assertions.patch
new file mode 100644
index 0000000000..bc28e1b633
--- /dev/null
+++ b/third_party/mock_call_assertions.patch
@@ -0,0 +1,59 @@
+--- mock.py 2012-10-07 18:00:10.000000000 +0100
++++ mock.py 2019-10-24 22:19:25.657417082 -0700
+@@ -286,6 +286,12 @@
+ if not _is_instance_mock(mock):
+ return
+
++ def assert_called(*args, **kwargs):
++ return mock.assert_called(*args, **kwargs)
++ def assert_not_called(*args, **kwargs):
++ return mock.assert_not_called(*args, **kwargs)
++ def assert_called_once(*args, **kwargs):
++ return mock.assert_called_once(*args, **kwargs)
+ def assert_called_with(*args, **kwargs):
+ return mock.assert_called_with(*args, **kwargs)
+ def assert_called_once_with(*args, **kwargs):
+@@ -318,6 +324,9 @@
+ funcopy.assert_has_calls = assert_has_calls
+ funcopy.assert_any_call = assert_any_call
+ funcopy.reset_mock = reset_mock
++ funcopy.assert_called = assert_called
++ funcopy.assert_not_called = assert_not_called
++ funcopy.assert_called_once = assert_called_once
+
+ mock._mock_delegate = funcopy
+
+@@ -809,6 +818,33 @@
+ return message % (expected_string, actual_string)
+
+
++ def assert_not_called(_mock_self):
++ """assert that the mock was never called.
++ """
++ self = _mock_self
++ if self.call_count != 0:
++ msg = ("Expected '%s' to not have been called. Called %s times." %
++ (self._mock_name or 'mock', self.call_count))
++ raise AssertionError(msg)
++
++ def assert_called(_mock_self):
++ """assert that the mock was called at least once
++ """
++ self = _mock_self
++ if self.call_count == 0:
++ msg = ("Expected '%s' to have been called." %
++               (self._mock_name or 'mock'))
++ raise AssertionError(msg)
++
++ def assert_called_once(_mock_self):
++ """assert that the mock was called only once.
++ """
++ self = _mock_self
++ if not self.call_count == 1:
++ msg = ("Expected '%s' to have been called once. Called %s times." %
++ (self._mock_name or 'mock', self.call_count))
++ raise AssertionError(msg)
++
+ def assert_called_with(_mock_self, *args, **kwargs):
+ """assert that the mock was called with the specified arguments.
+
diff --git a/third_party/python.bzl b/third_party/python.bzl
index 02652257af..b3c41e2c8a 100644
--- a/third_party/python.bzl
+++ b/third_party/python.bzl
@@ -114,6 +114,14 @@ def tensorboard_python_workspace():
sha256 = "2d9fbe67001d2e8f02692075257f3c11e1b0194bd838c8ce3f49b31fc6c3f033",
strip_prefix = "mock-1.0.0",
build_file = str(Label("//third_party:mock.BUILD")),
+ patches = [
+ # `mock==1.0.0` lacks some assertion methods present in
+ # later versions of `mock` (see comment above for why we pin
+ # to this version). Patch created by diffing the pinned
+ # `mock.py` with GitHub head and identifying all the bits
+ # that looked related to the methods in question.
+ "//third_party:mock_call_assertions.patch",
+ ],
)
http_archive(