diff --git a/.travis.yml b/.travis.yml
index 2015746981..1f55e9d1e7 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -53,7 +53,10 @@ install:
   - pip install yamllint==1.17.0
   # TensorBoard deps.
   - pip install futures==3.1.1
-  - pip install grpcio==1.6.3
+  - pip install grpcio==1.24.3
+  - pip install grpcio-testing==1.24.3
+  - pip install 'google-auth >= 1.6.3, < 2'
+  - pip install 'google-auth-oauthlib >= 0.4.1, < 0.5'
   - yarn install --ignore-engines
   # Uninstall older Travis numpy to avoid upgrade-in-place issues.
   - pip uninstall -y numpy
@@ -66,13 +69,6 @@ install:
       pip install "absl-py>=0.7.0" \
         && pip install "numpy<2.0,>=1.14.5"
     fi
-  - |
-    # On TF 1.x nightlies, downgrade estimator temporarily.
-    case "${TF_VERSION_ID}" in
-      tf-nightly|tf-nightly==*)
-        pip install tf-estimator-nightly==1.14.0.dev2019091701
-        ;;
-    esac
   # Deps for gfile S3 test.
   - pip install boto3==1.9.86
   - pip install moto==1.3.7
diff --git a/RELEASE.md b/RELEASE.md
index 45497d83b8..ff521e13a5 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,9 @@
+# Release 2.0.1
+
+## Features
+- Preview of TensorBoard.dev uploader! Check out https://tensorboard.dev for
+  information and usage instructions.
+
 # Release 2.0.0
 
 The 2.0 minor series tracks TensorFlow 2.0.
diff --git a/tensorboard/BUILD b/tensorboard/BUILD
index ca9522823a..7d6df761f5 100644
--- a/tensorboard/BUILD
+++ b/tensorboard/BUILD
@@ -28,6 +28,7 @@ py_binary(
         ":program",
         "//tensorboard:expect_tensorflow_installed",
         "//tensorboard/plugins:base_plugin",
+        "//tensorboard/uploader:uploader_main_lib",
         "//tensorboard/util:tb_logging",
     ],
 )
@@ -150,6 +151,7 @@ py_library(
         "//tensorboard:expect_absl_logging_installed",
         "//tensorboard/backend:application",
         "//tensorboard/backend/event_processing:event_file_inspector",
+        "//tensorboard/util:argparse_util",
         "@org_pocoo_werkzeug",
         "@org_pythonhosted_six",
     ],
 )
@@ -169,6 +171,7 @@ py_test(
         "//tensorboard/plugins:base_plugin",
         "//tensorboard/plugins/core:core_plugin",
         "@org_pocoo_werkzeug",
+        "@org_pythonhosted_mock",
     ],
 )
@@ -274,6 +277,22 @@ py_library(
     visibility = ["//visibility:public"],
 )
 
+py_library(
+    name = "expect_grpc_installed",
+    # This is a dummy rule used as a grpc dependency in open-source.
+    # We expect grpc to already be installed on the system, e.g. via
+    # `pip install grpcio`
+    visibility = ["//visibility:public"],
+)
+
+py_library(
+    name = "expect_grpc_testing_installed",
+    # This is a dummy rule used as a grpc_testing dependency in open-source.
+    # We expect grpc_testing to already be installed on the system, e.g. via
+    # `pip install grpcio_testing`
+    visibility = ["//visibility:public"],
+)
+
 py_library(
     name = "expect_sqlite3_installed",
     # This is a dummy rule used as a sqlite3 dependency in open-source.
@@ -306,6 +325,14 @@ py_library(
     visibility = ["//visibility:public"],
 )
 
+py_library(
+    name = "expect_absl_flags_argparse_flags_installed",
+    # This is a dummy rule used as a absl-py dependency in open-source.
+    # We expect absl-py to already be installed on the system, e.g. via
+    # `pip install absl-py`
+    visibility = ["//visibility:public"],
+)
+
 py_library(
     name = "expect_absl_logging_installed",
     # This is a dummy rule used as a absl-py dependency in open-source.
@@ -322,6 +349,22 @@ py_library(
     visibility = ["//visibility:public"],
 )
 
+py_library(
+    name = "expect_google_auth_installed",
+    # This is a dummy rule used as a google_auth dependency in open-source.
+    # We expect google_auth to already be installed on the system, e.g., via
+    # `pip install google-auth`.
+ visibility = ["//visibility:public"], +) + +py_library( + name = "expect_google_auth_oauthlib_installed", + # This is a dummy rule used as a google_auth oauthlib_dependency in open-source. + # We expect google_auth_oauthlib to already be installed on the system, e.g., via + # `pip install google-auth-oauthlib`. + visibility = ["//visibility:public"], +) + py_library( name = "expect_pkg_resources_installed", # This is a dummy rule used as a pkg-resources dependency in open-source. diff --git a/tensorboard/defs/BUILD b/tensorboard/defs/BUILD index f219f54c9c..368b2908ef 100644 --- a/tensorboard/defs/BUILD +++ b/tensorboard/defs/BUILD @@ -1,5 +1,7 @@ package(default_visibility = ["//tensorboard:internal"]) +load("//tensorboard/defs:protos.bzl", "tb_proto_library") + licenses(["notice"]) # Apache 2.0 filegroup( @@ -13,4 +15,33 @@ filegroup( visibility = ["//visibility:public"], ) +py_test( + name = "tb_proto_library_test", + srcs = ["tb_proto_library_test.py"], + deps = [ + ":test_base_py_pb2", + ":test_base_py_pb2_grpc", + ":test_downstream_py_pb2", + ":test_downstream_py_pb2_grpc", + "//tensorboard:test", + ], +) + +tb_proto_library( + name = "test_base", + srcs = ["test_base.proto"], + has_services = True, + testonly = True, +) + +tb_proto_library( + name = "test_downstream", + srcs = ["test_downstream.proto"], + deps = [ + ":test_base", + ], + has_services = True, + testonly = True, +) + exports_files(["web_test_python_stub.template.py"]) diff --git a/tensorboard/defs/protos.bzl b/tensorboard/defs/protos.bzl index 58e8f27c15..3c18bc1000 100644 --- a/tensorboard/defs/protos.bzl +++ b/tensorboard/defs/protos.bzl @@ -12,16 +12,56 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") +load("@com_google_protobuf//:protobuf.bzl", "proto_gen") -def tb_proto_library(name, srcs=None, visibility=None, testonly=None): - py_proto_library( - name = name + "_py_pb2", - srcs = srcs, - srcs_version = "PY2AND3", - deps = ["@com_google_protobuf//:protobuf_python"], - protoc = "@com_google_protobuf//:protoc", - visibility = visibility, - default_runtime = "@com_google_protobuf//:protobuf_python", - testonly = testonly, - ) +def tb_proto_library( + name, + srcs = None, + deps = [], + visibility = None, + testonly = None, + has_services = False): + outs_proto = _PyOuts(srcs, grpc = False) + outs_grpc = _PyOuts(srcs, grpc = True) if has_services else [] + outs_all = outs_proto + outs_grpc + + runtime = "@com_google_protobuf//:protobuf_python" + + proto_gen( + name = name + "_genproto", + srcs = srcs, + deps = [s + "_genproto" for s in deps] + [runtime + "_genproto"], + includes = [], + protoc = "@com_google_protobuf//:protoc", + gen_py = True, + outs = outs_all, + visibility = ["//visibility:public"], + plugin = "//external:grpc_python_plugin" if has_services else None, + plugin_language = "grpc", + ) + + py_deps = [s + "_py_pb2" for s in deps] + [runtime] + native.py_library( + name = name + "_py_pb2", + srcs = outs_proto, + imports = [], + srcs_version = "PY2AND3", + deps = py_deps, + testonly = testonly, + visibility = visibility, + ) + if has_services: + native.py_library( + name = name + "_py_pb2_grpc", + srcs = outs_grpc, + imports = [], + srcs_version = "PY2AND3", + deps = [name + "_py_pb2"] + py_deps, + testonly = testonly, + visibility = visibility, + ) + +def _PyOuts(srcs, grpc): + # Adapted from @com_google_protobuf//:protobuf.bzl. 
+ ext = "_pb2.py" if not grpc else "_pb2_grpc.py" + return [s[:-len(".proto")] + ext for s in srcs] diff --git a/tensorboard/defs/tb_proto_library_test.py b/tensorboard/defs/tb_proto_library_test.py new file mode 100644 index 0000000000..425e1d06f9 --- /dev/null +++ b/tensorboard/defs/tb_proto_library_test.py @@ -0,0 +1,44 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the `tb_proto_library` build macro.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorboard import test as tb_test +from tensorboard.defs import test_base_pb2 +from tensorboard.defs import test_base_pb2_grpc +from tensorboard.defs import test_downstream_pb2 +from tensorboard.defs import test_downstream_pb2_grpc + + +class TbProtoLibraryTest(tb_test.TestCase): + """Tests for `tb_proto_library`.""" + + def tests_with_deps(self): + foo = test_base_pb2.Foo() + foo.foo = 1 + bar = test_downstream_pb2.Bar() + bar.foo.foo = 1 + self.assertEqual(foo, bar.foo) + + def test_service_deps(self): + self.assertIsInstance(test_base_pb2_grpc.FooServiceServicer, type) + self.assertIsInstance(test_downstream_pb2_grpc.BarServiceServicer, type) + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/defs/test_base.proto b/tensorboard/defs/test_base.proto new file mode 100644 index 0000000000..6795667916 --- /dev/null +++ b/tensorboard/defs/test_base.proto @@ -0,0 +1,33 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +message Foo { + int32 foo = 1; +} + +service FooService { + // Loads some objects. + rpc GetFoo(GetFooRequest) returns (GetFooResponse); +} + +message GetFooRequest { + int32 count = 1; +} + +message GetFooResponse { + repeated Foo foo = 1; +} diff --git a/tensorboard/defs/test_downstream.proto b/tensorboard/defs/test_downstream.proto new file mode 100644 index 0000000000..0ff23a0b82 --- /dev/null +++ b/tensorboard/defs/test_downstream.proto @@ -0,0 +1,36 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +import "tensorboard/defs/test_base.proto"; + +message Bar { + Foo foo = 1; + int32 bar = 2; +} + +service BarService { + // Loads some objects. + rpc GetBar(GetBarRequest) returns (GetBarResponse); +} + +message GetBarRequest { + int32 count = 1; +} + +message GetBarResponse { + repeated Bar bar = 1; +} diff --git a/tensorboard/main.py b/tensorboard/main.py index da0521a852..5ce413cfe1 100644 --- a/tensorboard/main.py +++ b/tensorboard/main.py @@ -41,6 +41,7 @@ from tensorboard import program from tensorboard.compat import tf from tensorboard.plugins import base_plugin +from tensorboard.uploader import uploader_main from tensorboard.util import tb_logging @@ -56,7 +57,8 @@ def run_main(): tensorboard = program.TensorBoard( default.get_plugins() + default.get_dynamic_plugins(), - program.get_default_assets_zip_provider()) + program.get_default_assets_zip_provider(), + subcommands=[uploader_main.UploaderSubcommand()]) try: from absl import app # Import this to check that app.run() will accept the flags_parser argument. diff --git a/tensorboard/pip_package/setup.py b/tensorboard/pip_package/setup.py index 22141d92c8..a33543f872 100644 --- a/tensorboard/pip_package/setup.py +++ b/tensorboard/pip_package/setup.py @@ -26,7 +26,9 @@ 'absl-py >= 0.4', # futures is a backport of the python 3.2+ concurrent.futures module 'futures >= 3.1.1; python_version < "3"', - 'grpcio >= 1.6.3', + 'grpcio >= 1.24.3', + 'google-auth >= 1.6.3, < 2', + 'google-auth-oauthlib >= 0.4.1, < 0.5', 'markdown >= 2.6.8', 'numpy >= 1.12.0', 'protobuf >= 3.6.0', diff --git a/tensorboard/program.py b/tensorboard/program.py index 2bb34581bb..d6069802f6 100644 --- a/tensorboard/program.py +++ b/tensorboard/program.py @@ -56,6 +56,7 @@ from tensorboard.backend.event_processing import event_file_inspector as efi from tensorboard.plugins import base_plugin from tensorboard.plugins.core import core_plugin +from tensorboard.util import argparse_util from tensorboard.util import tb_logging try: @@ -68,6 +69,11 @@ logger = tb_logging.get_logger() +# Default subcommand name. This is a user-facing CLI and should not change. +_SERVE_SUBCOMMAND_NAME = 'serve' +# Internal flag name used to store which subcommand was invoked. +_SUBCOMMAND_FLAG = '__tensorboard_subcommand' + def setup_environment(): """Makes recommended modifications to the environment. @@ -111,10 +117,13 @@ class TensorBoard(object): cache_key: As `manager.cache_key`; set by the configure() method. """ - def __init__(self, - plugins=None, - assets_zip_provider=None, - server_class=None): + def __init__( + self, + plugins=None, + assets_zip_provider=None, + server_class=None, + subcommands=None, + ): """Creates new instance. 
Args: @@ -133,9 +142,17 @@ def __init__(self, assets_zip_provider = get_default_assets_zip_provider() if server_class is None: server_class = create_port_scanning_werkzeug_server + if subcommands is None: + subcommands = [] self.plugin_loaders = [application.make_plugin_loader(p) for p in plugins] self.assets_zip_provider = assets_zip_provider self.server_class = server_class + self.subcommands = {} + for subcommand in subcommands: + name = subcommand.name() + if name in self.subcommands or name == _SERVE_SUBCOMMAND_NAME: + raise ValueError("Duplicate subcommand name: %r" % name) + self.subcommands[name] = subcommand self.flags = None def configure(self, argv=('',), **kwargs): @@ -159,15 +176,48 @@ def configure(self, argv=('',), **kwargs): Raises: ValueError: If flag values are invalid. """ - parser = argparse_flags.ArgumentParser( + + base_parser = argparse_flags.ArgumentParser( prog='tensorboard', description=('TensorBoard is a suite of web applications for ' 'inspecting and understanding your TensorFlow runs ' 'and graphs. https://github.com/tensorflow/tensorboard ')) + subparsers = base_parser.add_subparsers( + help="TensorBoard subcommand (defaults to %r)" % _SERVE_SUBCOMMAND_NAME) + + serve_subparser = subparsers.add_parser( + _SERVE_SUBCOMMAND_NAME, + help='start local TensorBoard server (default subcommand)') + serve_subparser.set_defaults(**{_SUBCOMMAND_FLAG: _SERVE_SUBCOMMAND_NAME}) + + if len(argv) < 2 or argv[1].startswith('-'): + # This invocation, if valid, must not use any subcommands: we + # don't permit flags before the subcommand name. + serve_parser = base_parser + else: + # This invocation, if valid, must use a subcommand: we don't take + # any positional arguments to `serve`. + serve_parser = serve_subparser + + for (name, subcommand) in six.iteritems(self.subcommands): + subparser = subparsers.add_parser( + name, help=subcommand.help(), description=subcommand.description()) + subparser.set_defaults(**{_SUBCOMMAND_FLAG: name}) + subcommand.define_flags(subparser) + for loader in self.plugin_loaders: - loader.define_flags(parser) + loader.define_flags(serve_parser) + arg0 = argv[0] if argv else '' - flags = parser.parse_args(argv[1:]) # Strip binary name from argv. + + with argparse_util.allow_missing_subcommand(): + flags = base_parser.parse_args(argv[1:]) # Strip binary name from argv. 
+ if getattr(flags, _SUBCOMMAND_FLAG, None) is None: + # Manually assign default value rather than using `set_defaults` + # on the base parser to work around Python bug #9351 on old + # versions of `argparse`: + setattr(flags, _SUBCOMMAND_FLAG, _SERVE_SUBCOMMAND_NAME) + self.cache_key = manager.cache_key( working_directory=os.getcwd(), arguments=argv[1:], @@ -185,8 +235,9 @@ def configure(self, argv=('',), **kwargs): if not hasattr(flags, k): raise ValueError('Unknown TensorBoard flag: %s' % k) setattr(flags, k, v) - for loader in self.plugin_loaders: - loader.fix_flags(flags) + if getattr(flags, _SUBCOMMAND_FLAG) == _SERVE_SUBCOMMAND_NAME: + for loader in self.plugin_loaders: + loader.fix_flags(flags) self.flags = flags return [arg0] @@ -208,14 +259,24 @@ def main(self, ignored_argv=('',)): :rtype: int """ self._install_signal_handler(signal.SIGTERM, "SIGTERM") - if self.flags.inspect: - logger.info('Not bringing up TensorBoard, but inspecting event files.') - event_file = os.path.expanduser(self.flags.event_file) - efi.inspect(self.flags.logdir, event_file, self.flags.tag) - return 0 - if self.flags.version_tb: + subcommand_name = getattr(self.flags, _SUBCOMMAND_FLAG) + if subcommand_name == _SERVE_SUBCOMMAND_NAME: + runner = self._run_serve_subcommand + else: + runner = self.subcommands[subcommand_name].run + return runner(self.flags) or 0 + + def _run_serve_subcommand(self, flags): + # TODO(#2801): Make `--version` a flag on only the base parser, not `serve`. + if flags.version_tb: print(version.VERSION) return 0 + if flags.inspect: + # TODO(@wchargin): Convert `inspect` to a normal subcommand? + logger.info('Not bringing up TensorBoard, but inspecting event files.') + event_file = os.path.expanduser(flags.event_file) + efi.inspect(flags.logdir, event_file, flags.tag) + return 0 try: server = self._make_server() server.print_serving_message() @@ -300,6 +361,56 @@ def _make_server(self): return self.server_class(app, self.flags) +@six.add_metaclass(ABCMeta) +class TensorBoardSubcommand(object): + """Experimental private API for defining subcommands to tensorboard(1).""" + + @abstractmethod + def name(self): + """Name of this subcommand, as specified on the command line. + + This must be unique across all subcommands. + + Returns: + A string. + """ + pass + + @abstractmethod + def define_flags(self, parser): + """Configure an argument parser for this subcommand. + + Flags whose names start with two underscores (e.g., `__foo`) are + reserved for use by the runtime and must not be defined by + subcommands. + + Args: + parser: An `argparse.ArgumentParser` scoped to this subcommand, + which this function should mutate. + """ + pass + + @abstractmethod + def run(self, flags): + """Execute this subcommand with user-provided flags. + + Args: + flags: An `argparse.Namespace` object with all defined flags. + + Returns: + An `int` exit code, or `None` as an alias for `0`. 
+ """ + pass + + def help(self): + """Short, one-line help text to display on `tensorboard --help`.""" + return None + + def description(self): + """Description to display on `tensorboard SUBCOMMAND --help`.""" + return None + + @six.add_metaclass(ABCMeta) class TensorBoardServer(object): """Class for customizing TensorBoard WSGI app serving.""" diff --git a/tensorboard/program_test.py b/tensorboard/program_test.py index 919fa3d493..7a4d919072 100644 --- a/tensorboard/program_test.py +++ b/tensorboard/program_test.py @@ -19,9 +19,17 @@ from __future__ import print_function import argparse +import contextlib +import sys import six +try: + # python version >= 3.3 + from unittest import mock # pylint: disable=g-import-not-at-top +except ImportError: + import mock # pylint: disable=g-import-not-at-top,unused-import + from tensorboard import program from tensorboard import test as tb_test from tensorboard.plugins import base_plugin @@ -120,5 +128,135 @@ def testSpecifiedHost(self): self.assertTrue(one_passed) # We expect either IPv4 or IPv6 to be supported +class SubcommandTest(tb_test.TestCase): + + def setUp(self): + super(SubcommandTest, self).setUp() + self.stderr = six.StringIO() + patchers = [ + mock.patch.object(program.TensorBoard, '_install_signal_handler'), + mock.patch.object(program.TensorBoard, '_run_serve_subcommand'), + mock.patch.object(_TestSubcommand, 'run'), + mock.patch.object(sys, 'stderr', self.stderr), + ] + for p in patchers: + p.start() + self.addCleanup(p.stop) + _TestSubcommand.run.return_value = None + + def tearDown(self): + stderr = self.stderr.getvalue() + if stderr: + # In case of failing tests, let there be debug info. + print('Stderr:\n%s' % stderr) + + def testImplicitServe(self): + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand(lambda parser: None)], + ) + tb.configure(('tb', '--logdir', 'logs', '--path_prefix', '/x/')) + tb.main() + program.TensorBoard._run_serve_subcommand.assert_called_once() + flags = program.TensorBoard._run_serve_subcommand.call_args[0][0] + self.assertEqual(flags.logdir, 'logs') + self.assertEqual(flags.path_prefix, '/x') # fixed by core_plugin + + def testExplicitServe(self): + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand()], + ) + tb.configure(('tb', 'serve', '--logdir', 'logs', '--path_prefix', '/x/')) + tb.main() + program.TensorBoard._run_serve_subcommand.assert_called_once() + flags = program.TensorBoard._run_serve_subcommand.call_args[0][0] + self.assertEqual(flags.logdir, 'logs') + self.assertEqual(flags.path_prefix, '/x') # fixed by core_plugin + + def testSubcommand(self): + def define_flags(parser): + parser.add_argument('--hello') + + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand(define_flags=define_flags)], + ) + tb.configure(('tb', 'test', '--hello', 'world')) + self.assertEqual(tb.main(), 0) + _TestSubcommand.run.assert_called_once() + flags = _TestSubcommand.run.call_args[0][0] + self.assertEqual(flags.hello, 'world') + + def testSubcommand_ExitCode(self): + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand()], + ) + _TestSubcommand.run.return_value = 77 + tb.configure(('tb', 'test')) + self.assertEqual(tb.main(), 77) + + def testSubcommand_DoesNotInheritBaseArgs(self): + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand()], + ) + with self.assertRaises(SystemExit): + 
tb.configure(('tb', 'test', '--logdir', 'logs')) + self.assertIn( + 'unrecognized arguments: --logdir logs', self.stderr.getvalue()) + self.stderr.truncate(0) + + def testSubcommand_MayRequirePositionals(self): + def define_flags(parser): + parser.add_argument('payload') + + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand(define_flags=define_flags)], + ) + with self.assertRaises(SystemExit): + tb.configure(('tb', 'test')) + self.assertIn('required', self.stderr.getvalue()) + self.assertIn('payload', self.stderr.getvalue()) + self.stderr.truncate(0) + + def testConflictingNames_AmongSubcommands(self): + with self.assertRaises(ValueError) as cm: + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand(), _TestSubcommand()], + ) + self.assertIn('Duplicate subcommand name:', str(cm.exception)) + self.assertIn('test', str(cm.exception)) + + def testConflictingNames_WithServe(self): + with self.assertRaises(ValueError) as cm: + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand(name='serve')], + ) + self.assertIn('Duplicate subcommand name:', str(cm.exception)) + self.assertIn('serve', str(cm.exception)) + + +class _TestSubcommand(program.TensorBoardSubcommand): + + def __init__(self, name=None, define_flags=None): + self._name = name + self._define_flags = define_flags + + def name(self): + return self._name or 'test' + + def define_flags(self, parser): + if self._define_flags: + self._define_flags(parser) + + def run(self, flags): + pass + + if __name__ == '__main__': tb_test.main() diff --git a/tensorboard/tools/whitespace_hygiene_test.py b/tensorboard/tools/whitespace_hygiene_test.py index 7a46a300a9..3e127811af 100755 --- a/tensorboard/tools/whitespace_hygiene_test.py +++ b/tensorboard/tools/whitespace_hygiene_test.py @@ -28,8 +28,10 @@ import sys -# Remove files from this list as whitespace errors are fixed. exceptions = frozenset([ + # End-of-line whitespace is semantic in patch files when a line + # contains a single space. 
+ "third_party/mock_call_assertions.patch", ]) diff --git a/tensorboard/uploader/BUILD b/tensorboard/uploader/BUILD new file mode 100644 index 0000000000..a0b25720ba --- /dev/null +++ b/tensorboard/uploader/BUILD @@ -0,0 +1,203 @@ +# Description: +# Uploader for TensorBoard.dev + +package(default_visibility = ["//tensorboard:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "dev_creds", + srcs = ["dev_creds.py"], +) + +py_library( + name = "exporter_lib", + srcs = ["exporter.py"], + deps = [ + ":util", + "//tensorboard:expect_grpc_installed", + "//tensorboard/uploader/proto:protos_all_py_pb2", + "//tensorboard/util:grpc_util", + ], +) + +py_test( + name = "exporter_test", + srcs = ["exporter_test.py"], + deps = [ + ":exporter_lib", + ":test_util", + "//tensorboard:expect_grpc_installed", + "//tensorboard:expect_grpc_testing_installed", + "//tensorboard:test", + "//tensorboard/compat/proto:protos_all_py_pb2", + "//tensorboard/uploader/proto:protos_all_py_pb2", + "//tensorboard/uploader/proto:protos_all_py_pb2_grpc", + "//tensorboard/util:grpc_util", + "@org_pythonhosted_mock", + ], +) + +py_binary( + name = "uploader", + srcs = ["uploader_main.py"], + main = "uploader_main.py", + python_version = "PY2", + deps = [":uploader_main_lib"], +) + +py_library( + name = "uploader_main_lib", + srcs = ["uploader_main.py"], + visibility = ["//tensorboard:internal"], + deps = [ + ":auth", + ":dev_creds", + ":exporter_lib", + ":uploader_lib", + "//tensorboard:expect_absl_app_installed", + "//tensorboard:expect_absl_flags_argparse_flags_installed", + "//tensorboard:expect_absl_flags_installed", + "//tensorboard:expect_absl_logging_installed", + "//tensorboard:expect_grpc_installed", + "//tensorboard:program", + "//tensorboard/plugins:base_plugin", + "//tensorboard/uploader/proto:protos_all_py_pb2_grpc", + "@org_pythonhosted_six", + ], +) + +py_library( + name = "uploader_lib", + srcs = ["uploader.py"], + deps = [ + ":logdir_loader", + ":peekable_iterator", + ":util", + "//tensorboard:data_compat", + "//tensorboard:expect_grpc_installed", + "//tensorboard/backend/event_processing:directory_loader", + "//tensorboard/backend/event_processing:event_file_loader", + "//tensorboard/backend/event_processing:io_wrapper", + "//tensorboard/plugins/scalar:metadata", + "//tensorboard/uploader/proto:protos_all_py_pb2", + "//tensorboard/util:grpc_util", + "//tensorboard/util:tb_logging", + "//tensorboard/util:tensor_util", + "@org_pythonhosted_six", + ], +) + +py_test( + name = "uploader_test", + srcs = ["uploader_test.py"], + deps = [ + ":test_util", + ":uploader_lib", + ":util", + "//tensorboard:expect_grpc_installed", + "//tensorboard:expect_grpc_testing_installed", + "//tensorboard:expect_tensorflow_installed", + "//tensorboard/compat/proto:protos_all_py_pb2", + "//tensorboard/plugins/histogram:summary_v2", + "//tensorboard/plugins/scalar:summary_v2", + "//tensorboard/summary:summary_v1", + "//tensorboard/uploader/proto:protos_all_py_pb2", + "//tensorboard/uploader/proto:protos_all_py_pb2_grpc", + "//tensorboard/util:test_util", + "@org_pythonhosted_mock", + ], +) + +py_library( + name = "auth", + srcs = ["auth.py"], + deps = [ + ":util", + "//tensorboard:expect_google_auth_installed", + "//tensorboard:expect_google_auth_oauthlib_installed", + "//tensorboard:expect_grpc_installed", + "//tensorboard/util:tb_logging", + ], +) + +py_test( + name = "auth_test", + srcs = ["auth_test.py"], + deps = [ + ":auth", + "//tensorboard:expect_google_auth_installed", + 
"//tensorboard:test", + "@org_pythonhosted_mock", + ], +) + +py_library( + name = "logdir_loader", + srcs = ["logdir_loader.py"], + deps = [ + "//tensorboard/backend/event_processing:directory_watcher", + "//tensorboard/backend/event_processing:io_wrapper", + "//tensorboard/util:tb_logging", + ], +) + +py_test( + name = "logdir_loader_test", + srcs = ["logdir_loader_test.py"], + deps = [ + ":logdir_loader", + "//tensorboard:test", + "//tensorboard/backend/event_processing:directory_loader", + "//tensorboard/backend/event_processing:event_file_loader", + "//tensorboard/backend/event_processing:io_wrapper", + "//tensorboard/util:test_util", + "@org_pythonhosted_six", + ], +) + +py_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.py"], + deps = [ + "//tensorboard:expect_grpc_installed", + "//tensorboard/compat/proto:protos_all_py_pb2", + "@com_google_protobuf//:protobuf_python", + ], +) + +py_library( + name = "util", + srcs = ["util.py"], +) + +py_test( + name = "util_test", + srcs = ["util_test.py"], + deps = [ + ":test_util", + ":util", + "//tensorboard:test", + "@com_google_protobuf//:protobuf_python", + "@org_pythonhosted_mock", + ], +) + +py_library( + name = "peekable_iterator", + srcs = ["peekable_iterator.py"], + srcs_version = "PY2AND3", +) + +py_test( + name = "peekable_iterator_test", + size = "small", + srcs = ["peekable_iterator_test.py"], + deps = [ + ":peekable_iterator", + "//tensorboard:test", + ], +) diff --git a/tensorboard/uploader/auth.py b/tensorboard/uploader/auth.py new file mode 100644 index 0000000000..a435c6ff5a --- /dev/null +++ b/tensorboard/uploader/auth.py @@ -0,0 +1,227 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Lint as: python3 +"""Provides authentication support for TensorBoardUploader.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import errno +import json +import os +import sys +import webbrowser + +import google_auth_oauthlib.flow +import grpc +import google.auth +import google.auth.transport.requests +import google.oauth2.credentials + +from tensorboard.uploader import util +from tensorboard.util import tb_logging + + +logger = tb_logging.get_logger() + + +# OAuth2 scopes used for OpenID Connect: +# https://developers.google.com/identity/protocols/OpenIDConnect#scope-param +OPENID_CONNECT_SCOPES = ( + "openid", + "https://www.googleapis.com/auth/userinfo.email", +) + + +# The client "secret" is public by design for installed apps. 
See +# https://developers.google.com/identity/protocols/OAuth2?csw=1#installed +OAUTH_CLIENT_CONFIG = u""" +{ + "installed": { + "client_id": "373649185512-8v619h5kft38l4456nm2dj4ubeqsrvh6.apps.googleusercontent.com", + "project_id": "hosted-tensorboard-prod", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_secret": "pOyAuU2yq2arsM98Bw5hwYtr", + "redirect_uris": [ + "urn:ietf:wg:oauth:2.0:oob", + "http://localhost" + ] + } +} +""" + + +# Components of the relative path (within the user settings directory) at which +# to store TensorBoard uploader credentials. +TENSORBOARD_CREDENTIALS_FILEPATH_PARTS = [ + "tensorboard", "credentials", "uploader-creds.json"] + + +class CredentialsStore(object): + """Private file store for a `google.oauth2.credentials.Credentials`.""" + + _DEFAULT_CONFIG_DIRECTORY = object() # Sentinel value. + + def __init__(self, user_config_directory=_DEFAULT_CONFIG_DIRECTORY): + """Creates a CredentialsStore. + + Args: + user_config_directory: Optional absolute path to the root directory for + storing user configs, under which to store the credentials file. If not + set, defaults to a platform-specific location. If set to None, the + store is disabled (reads return None; write and clear are no-ops). + """ + if user_config_directory is CredentialsStore._DEFAULT_CONFIG_DIRECTORY: + user_config_directory = util.get_user_config_directory() + if user_config_directory is None: + logger.warning( + "Credentials caching disabled - no private config directory found") + if user_config_directory is None: + self._credentials_filepath = None + else: + self._credentials_filepath = os.path.join( + user_config_directory, *TENSORBOARD_CREDENTIALS_FILEPATH_PARTS) + + def read_credentials(self): + """Returns the current `google.oauth2.credentials.Credentials`, or None.""" + if self._credentials_filepath is None: + return None + if os.path.exists(self._credentials_filepath): + return google.oauth2.credentials.Credentials.from_authorized_user_file( + self._credentials_filepath) + return None + + def write_credentials(self, credentials): + """Writes a `google.oauth2.credentials.Credentials` to the store.""" + if not isinstance(credentials, google.oauth2.credentials.Credentials): + raise TypeError("Cannot write credentials of type %s" % type(credentials)) + if self._credentials_filepath is None: + return + # Make the credential file private if not on Windows; on Windows we rely on + # the default user config settings directory being private since we don't + # have a straightforward way to make an individual file private. + private = os.name != "nt" + util.make_file_with_directories(self._credentials_filepath, private=private) + data = { + "refresh_token": credentials.refresh_token, + "token_uri": credentials.token_uri, + "client_id": credentials.client_id, + "client_secret": credentials.client_secret, + "scopes": credentials.scopes, + "type": "authorized_user", + } + with open(self._credentials_filepath, "w") as f: + json.dump(data, f) + + def clear(self): + """Clears the store of any persisted credentials information.""" + if self._credentials_filepath is None: + return + try: + os.remove(self._credentials_filepath) + except OSError as e: + if e.errno != errno.ENOENT: + raise + + +def build_installed_app_flow(client_config): + """Returns a `CustomInstalledAppFlow` for the given config. 
+ + Args: + client_config (Mapping[str, Any]): The client configuration in the Google + client secrets format. + + Returns: + CustomInstalledAppFlow: the constructed flow. + """ + return CustomInstalledAppFlow.from_client_config( + client_config, scopes=OPENID_CONNECT_SCOPES) + + +class CustomInstalledAppFlow(google_auth_oauthlib.flow.InstalledAppFlow): + """Customized version of the Installed App OAuth2 flow.""" + + def run(self, force_console=False): + """Run the flow using a local server if possible, otherwise the console.""" + # TODO(b/141721828): make auto-detection smarter, especially for macOS. + if not force_console and os.getenv("DISPLAY"): + try: + return self.run_local_server(port=0) + except webbrowser.Error: + sys.stderr.write("Falling back to console authentication flow...\n") + return self.run_console() + + +class IdTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin): + """A `gRPC AuthMetadataPlugin` that uses ID tokens. + + This works like the existing `google.auth.transport.grpc.AuthMetadataPlugin` + except that instead of always using access tokens, it preferentially uses the + `Credentials.id_token` property if available (and logs an error otherwise). + + See http://www.grpc.io/grpc/python/grpc.html#grpc.AuthMetadataPlugin + """ + + def __init__(self, credentials, request): + """Constructs an IdTokenAuthMetadataPlugin. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to + add to requests. + request (google.auth.transport.Request): A HTTP transport request object + used to refresh credentials as needed. + """ + super(IdTokenAuthMetadataPlugin, self).__init__() + if not isinstance(credentials, google.oauth2.credentials.Credentials): + raise TypeError( + "Cannot get ID tokens from credentials type %s" % type(credentials)) + self._credentials = credentials + self._request = request + + def __call__(self, context, callback): + """Passes authorization metadata into the given callback. + + Args: + context (grpc.AuthMetadataContext): The RPC context. + callback (grpc.AuthMetadataPluginCallback): The callback that will + be invoked to pass in the authorization metadata. + """ + headers = {} + self._credentials.before_request( + self._request, context.method_name, context.service_url, headers) + id_token = getattr(self._credentials, "id_token", None) + if id_token: + self._credentials.apply(headers, token=id_token) + else: + logger.error("Failed to find ID token credentials") + # Pass headers as key-value pairs to match CallCredentials metadata. + callback(list(headers.items()), None) + + +def id_token_call_credentials(credentials): + """Constructs `grpc.CallCredentials` using `google.auth.Credentials.id_token`. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to use. + + Returns: + grpc.CallCredentials: The call credentials. + """ + request = google.auth.transport.requests.Request() + return grpc.metadata_call_credentials( + IdTokenAuthMetadataPlugin(credentials, request)) diff --git a/tensorboard/uploader/auth_test.py b/tensorboard/uploader/auth_test.py new file mode 100644 index 0000000000..587109f625 --- /dev/null +++ b/tensorboard/uploader/auth_test.py @@ -0,0 +1,140 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Lint as: python3 +"""Tests for tensorboard.uploader.auth.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os + +import google.auth.credentials +import google.oauth2.credentials + +from tensorboard.uploader import auth +from tensorboard import test as tb_test + + +class CredentialsStoreTest(tb_test.TestCase): + + def test_no_config_dir(self): + store = auth.CredentialsStore(user_config_directory=None) + self.assertIsNone(store.read_credentials()) + creds = google.oauth2.credentials.Credentials(token=None) + store.write_credentials(creds) + store.clear() + + def test_clear_existent_file(self): + root = self.get_temp_dir() + path = os.path.join( + root, "tensorboard", "credentials", "uploader-creds.json") + os.makedirs(os.path.dirname(path)) + open(path, mode="w").close() + self.assertTrue(os.path.exists(path)) + auth.CredentialsStore(user_config_directory=root).clear() + self.assertFalse(os.path.exists(path)) + + def test_clear_nonexistent_file(self): + root = self.get_temp_dir() + path = os.path.join( + root, "tensorboard", "credentials", "uploader-creds.json") + self.assertFalse(os.path.exists(path)) + auth.CredentialsStore(user_config_directory=root).clear() + self.assertFalse(os.path.exists(path)) + + def test_write_wrong_type(self): + creds = google.auth.credentials.AnonymousCredentials() + with self.assertRaisesRegex(TypeError, "google.auth.credentials"): + auth.CredentialsStore(user_config_directory=None).write_credentials(creds) + + def test_write_creates_private_file(self): + root = self.get_temp_dir() + auth.CredentialsStore(user_config_directory=root).write_credentials( + google.oauth2.credentials.Credentials( + token=None, refresh_token="12345")) + path = os.path.join( + root, "tensorboard", "credentials", "uploader-creds.json") + self.assertTrue(os.path.exists(path)) + # Skip permissions check on Windows. + if os.name != "nt": + self.assertEqual(0o600, os.stat(path).st_mode & 0o777) + with open(path) as f: + contents = json.load(f) + self.assertEqual("12345", contents["refresh_token"]) + + def test_write_overwrites_file(self): + root = self.get_temp_dir() + store = auth.CredentialsStore(user_config_directory=root) + # Write twice to ensure that we're overwriting correctly. 
+ store.write_credentials(google.oauth2.credentials.Credentials( + token=None, refresh_token="12345")) + store.write_credentials(google.oauth2.credentials.Credentials( + token=None, refresh_token="67890")) + path = os.path.join( + root, "tensorboard", "credentials", "uploader-creds.json") + self.assertTrue(os.path.exists(path)) + with open(path) as f: + contents = json.load(f) + self.assertEqual("67890", contents["refresh_token"]) + + def test_write_and_read_roundtrip(self): + orig_creds = google.oauth2.credentials.Credentials( + token="12345", + refresh_token="67890", + token_uri="https://oauth2.googleapis.com/token", + client_id="my-client", + client_secret="123abc456xyz", + scopes=["userinfo", "email"]) + root = self.get_temp_dir() + store = auth.CredentialsStore(user_config_directory=root) + store.write_credentials(orig_creds) + creds = store.read_credentials() + self.assertEqual(orig_creds.refresh_token, creds.refresh_token) + self.assertEqual(orig_creds.token_uri, creds.token_uri) + self.assertEqual(orig_creds.client_id, creds.client_id) + self.assertEqual(orig_creds.client_secret, creds.client_secret) + + def test_read_nonexistent_file(self): + root = self.get_temp_dir() + store = auth.CredentialsStore(user_config_directory=root) + self.assertIsNone(store.read_credentials()) + + def test_read_non_json_file(self): + root = self.get_temp_dir() + store = auth.CredentialsStore(user_config_directory=root) + path = os.path.join( + root, "tensorboard", "credentials", "uploader-creds.json") + os.makedirs(os.path.dirname(path)) + with open(path, mode="w") as f: + f.write("foobar") + with self.assertRaises(ValueError): + store.read_credentials() + + def test_read_invalid_json_file(self): + root = self.get_temp_dir() + store = auth.CredentialsStore(user_config_directory=root) + path = os.path.join( + root, "tensorboard", "credentials", "uploader-creds.json") + os.makedirs(os.path.dirname(path)) + with open(path, mode="w") as f: + f.write("{}") + with self.assertRaises(ValueError): + store.read_credentials() + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/uploader/dev_creds.py b/tensorboard/uploader/dev_creds.py new file mode 100644 index 0000000000..fe4617e3ec --- /dev/null +++ b/tensorboard/uploader/dev_creds.py @@ -0,0 +1,29 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module providing access to dev GRPC SSL credentials.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +DEV_SSL_CERT = u""" +""" + +DEV_SSL_CERT_KEY = u""" +""" + +DEV_OAUTH_CLIENT_CONFIG = u""" +""" diff --git a/tensorboard/uploader/exporter.py b/tensorboard/uploader/exporter.py new file mode 100644 index 0000000000..e7421f9a24 --- /dev/null +++ b/tensorboard/uploader/exporter.py @@ -0,0 +1,211 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Downloads experiment data from TensorBoard.dev.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import base64 +import errno +import grpc +import json +import os +import string +import time + +from tensorboard.uploader.proto import export_service_pb2 +from tensorboard.uploader import util +from tensorboard.util import grpc_util + +# Characters that are assumed to be safe in filenames. Note that the +# server's experiment IDs are base64 encodings of 16-byte blobs, so they +# can theoretically collide on case-insensitive filesystems. Each +# character has a ~3% chance of colliding, and so two random IDs have +# about a ~10^-33 chance of colliding. As a precaution, we'll still +# detect collision and fail fast rather than overwriting data. +_FILENAME_SAFE_CHARS = frozenset(string.ascii_letters + string.digits + "-_") + +# Maximum value of a signed 64-bit integer. +_MAX_INT64 = 2**63 - 1 + +class TensorBoardExporter(object): + """Exports all of the user's experiment data from TensorBoard.dev. + + Data is exported into a directory, with one file per experiment. Each + experiment file is a sequence of time series, represented as a stream + of JSON objects, one per line. Each JSON object includes a run name, + tag name, `tensorboard.compat.proto.summary_pb2.SummaryMetadata` proto + (base64-encoded, standard RFC 4648 alphabet), and set of points. + Points are stored in three equal-length lists of steps, wall times (as + seconds since epoch), and scalar values, for storage efficiency. + + Such streams of JSON objects may be conveniently processed with tools + like jq(1). + + For example one line of an experiment file might read (when + pretty-printed): + + { + "points": { + "steps": [0, 5], + "values": [4.8935227394104, 2.5438034534454346], + "wall_times": [1563406522.669238, 1563406523.0268838] + }, + "run": "lr_1E-04,conv=1,fc=2", + "summary_metadata": "CgkKB3NjYWxhcnMSC3hlbnQveGVudF8x", + "tag": "xent/xent_1" + } + + This is a time series with two points, both logged on 2019-07-17, one + about 0.36 seconds after the other. + """ + + def __init__(self, reader_service_client, output_directory): + """Constructs a TensorBoardExporter. + + Args: + reader_service_client: A TensorBoardExporterService stub instance. + output_directory: Path to a directory into which to write data. The + directory must not exist, to avoid stomping existing or concurrent + output. Its ancestors will be created if needed. + """ + self._api = reader_service_client + self._outdir = output_directory + parent_dir = os.path.dirname(self._outdir) + if parent_dir: + _mkdir_p(parent_dir) + try: + os.mkdir(self._outdir) + except OSError as e: + if e.errno == errno.EEXIST: + # Bail to avoid stomping existing output. + raise OutputDirectoryExistsError() + + def export(self, read_time=None): + """Executes the export flow. 
+ + Args: + read_time: A fixed timestamp from which to export data, as float seconds + since epoch (like `time.time()`). Optional; defaults to the current + time. + + Yields: + After each experiment is successfully downloaded, the ID of that + experiment, as a string. + """ + if read_time is None: + read_time = time.time() + for experiment_id in self._request_experiment_ids(read_time): + filepath = _scalars_filepath(self._outdir, experiment_id) + try: + with _open_excl(filepath) as outfile: + data = self._request_scalar_data(experiment_id, read_time) + for block in data: + json.dump(block, outfile, sort_keys=True) + outfile.write("\n") + outfile.flush() + yield experiment_id + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.CANCELLED: + raise GrpcTimeoutException(experiment_id) + else: + raise + + def _request_experiment_ids(self, read_time): + """Yields all of the calling user's experiment IDs, as strings.""" + request = export_service_pb2.StreamExperimentsRequest(limit=_MAX_INT64) + util.set_timestamp(request.read_timestamp, read_time) + stream = self._api.StreamExperiments( + request, metadata=grpc_util.version_metadata()) + for response in stream: + for experiment_id in response.experiment_ids: + yield experiment_id + + def _request_scalar_data(self, experiment_id, read_time): + """Yields JSON-serializable blocks of scalar data.""" + request = export_service_pb2.StreamExperimentDataRequest() + request.experiment_id = experiment_id + util.set_timestamp(request.read_timestamp, read_time) + # No special error handling as we don't expect any errors from these + # calls: all experiments should exist (read consistency timestamp) + # and be owned by the calling user (only queried for own experiment + # IDs). Any non-transient errors would be internal, and we have no + # way to efficiently resume from transient errors because the server + # does not support pagination. + stream = self._api.StreamExperimentData( + request, metadata=grpc_util.version_metadata()) + for response in stream: + metadata = base64.b64encode( + response.tag_metadata.SerializeToString()).decode("ascii") + wall_times = [t.ToNanoseconds() / 1e9 for t in response.points.wall_times] + yield { + u"run": response.run_name, + u"tag": response.tag_name, + u"summary_metadata": metadata, + u"points": { + u"steps": list(response.points.steps), + u"wall_times": wall_times, + u"values": list(response.points.values), + }, + } + + +class OutputDirectoryExistsError(ValueError): + pass + + +class OutputFileExistsError(ValueError): + # Like Python 3's `__builtins__.FileExistsError`. + pass + +class GrpcTimeoutException(Exception): + def __init__(self, experiment_id): + super(GrpcTimeoutException, self).__init__(experiment_id) + self.experiment_id = experiment_id + +def _scalars_filepath(base_dir, experiment_id): + """Gets file path in which to store scalars for the given experiment.""" + # Experiment IDs from the server should be filename-safe; verify + # this before creating any files. 
+ bad_chars = frozenset(experiment_id) - _FILENAME_SAFE_CHARS + if bad_chars: + raise RuntimeError( + "Unexpected characters ({bad_chars!r}) in experiment ID {eid!r}".format( + bad_chars=sorted(bad_chars), eid=experiment_id)) + return os.path.join(base_dir, "scalars_%s.json" % experiment_id) + + +def _mkdir_p(path): + """Like `os.makedirs(path, exist_ok=True)`, but Python 2-compatible.""" + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST or not os.path.isdir(path): + raise + + +def _open_excl(path): + """Like `open(path, "x")`, but Python 2-compatible.""" + try: + # `os.O_EXCL` works on Windows as well as POSIX-compliant systems. + # See: + fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_EXCL) + except OSError as e: + if e.errno == errno.EEXIST: + raise OutputFileExistsError(path) + else: + raise + return os.fdopen(fd, "w") diff --git a/tensorboard/uploader/exporter_test.py b/tensorboard/uploader/exporter_test.py new file mode 100644 index 0000000000..f4e5213a86 --- /dev/null +++ b/tensorboard/uploader/exporter_test.py @@ -0,0 +1,388 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorboard.uploader.exporter.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import base64 +import errno +import json +import os + +import grpc +import grpc_testing + +try: + # python version >= 3.3 + from unittest import mock # pylint: disable=g-import-not-at-top +except ImportError: + import mock # pylint: disable=g-import-not-at-top,unused-import + + +from tensorboard.uploader.proto import export_service_pb2 +from tensorboard.uploader.proto import export_service_pb2_grpc +from tensorboard.uploader import exporter as exporter_lib +from tensorboard.uploader import test_util +from tensorboard.util import grpc_util +from tensorboard import test as tb_test +from tensorboard.compat.proto import summary_pb2 + + +class TensorBoardExporterTest(tb_test.TestCase): + + def _create_mock_api_client(self): + # Create a stub instance (using a test channel) in order to derive a mock + # from it with autospec enabled. Mocking TensorBoardExporterServiceStub + # itself doesn't work with autospec because grpc constructs stubs via + # metaclassing. 
+ test_channel = grpc_testing.channel( + service_descriptors=[], time=grpc_testing.strict_real_time()) + stub = export_service_pb2_grpc.TensorBoardExporterServiceStub(test_channel) + mock_api_client = mock.create_autospec(stub) + return mock_api_client + + def _make_experiments_response(self, eids): + return export_service_pb2.StreamExperimentsResponse(experiment_ids=eids) + + def test_e2e_success_case(self): + mock_api_client = self._create_mock_api_client() + mock_api_client.StreamExperiments.return_value = iter([ + export_service_pb2.StreamExperimentsResponse(experiment_ids=["789"]), + ]) + + def stream_experiments(request, **kwargs): + del request # unused + self.assertEqual(kwargs["metadata"], grpc_util.version_metadata()) + yield export_service_pb2.StreamExperimentsResponse( + experiment_ids=["123", "456"]) + yield export_service_pb2.StreamExperimentsResponse(experiment_ids=["789"]) + + def stream_experiment_data(request, **kwargs): + self.assertEqual(kwargs["metadata"], grpc_util.version_metadata()) + for run in ("train", "test"): + for tag in ("accuracy", "loss"): + response = export_service_pb2.StreamExperimentDataResponse() + response.run_name = run + response.tag_name = tag + display_name = "%s:%s" % (request.experiment_id, tag) + response.tag_metadata.CopyFrom( + test_util.scalar_metadata(display_name)) + for step in range(10): + response.points.steps.append(step) + response.points.values.append(2.0 * step) + response.points.wall_times.add( + seconds=1571084520 + step, nanos=862939144) + yield response + + mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments) + mock_api_client.StreamExperimentData = mock.Mock( + wraps=stream_experiment_data) + + outdir = os.path.join(self.get_temp_dir(), "outdir") + exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir) + start_time = 1571084846.25 + start_time_pb = test_util.timestamp_pb(1571084846250000000) + + generator = exporter.export(read_time=start_time) + expected_files = [] + self.assertTrue(os.path.isdir(outdir)) + self.assertCountEqual(expected_files, os.listdir(outdir)) + mock_api_client.StreamExperiments.assert_not_called() + mock_api_client.StreamExperimentData.assert_not_called() + + # The first iteration should request the list of experiments and + # data for one of them. + self.assertEqual(next(generator), "123") + expected_files.append("scalars_123.json") + self.assertCountEqual(expected_files, os.listdir(outdir)) + + expected_eids_request = export_service_pb2.StreamExperimentsRequest() + expected_eids_request.read_timestamp.CopyFrom(start_time_pb) + expected_eids_request.limit = 2**63 - 1 + mock_api_client.StreamExperiments.assert_called_once_with( + expected_eids_request, metadata=grpc_util.version_metadata()) + + expected_data_request = export_service_pb2.StreamExperimentDataRequest() + expected_data_request.experiment_id = "123" + expected_data_request.read_timestamp.CopyFrom(start_time_pb) + mock_api_client.StreamExperimentData.assert_called_once_with( + expected_data_request, metadata=grpc_util.version_metadata()) + + # The next iteration should just request data for the next experiment. 
+ mock_api_client.StreamExperiments.reset_mock() + mock_api_client.StreamExperimentData.reset_mock() + self.assertEqual(next(generator), "456") + + expected_files.append("scalars_456.json") + self.assertCountEqual(expected_files, os.listdir(outdir)) + mock_api_client.StreamExperiments.assert_not_called() + expected_data_request.experiment_id = "456" + mock_api_client.StreamExperimentData.assert_called_once_with( + expected_data_request, metadata=grpc_util.version_metadata()) + + # Again, request data for the next experiment; this experiment ID + # was in the second response batch in the list of IDs. + expected_files.append("scalars_789.json") + mock_api_client.StreamExperiments.reset_mock() + mock_api_client.StreamExperimentData.reset_mock() + self.assertEqual(next(generator), "789") + + self.assertCountEqual(expected_files, os.listdir(outdir)) + mock_api_client.StreamExperiments.assert_not_called() + expected_data_request.experiment_id = "789" + mock_api_client.StreamExperimentData.assert_called_once_with( + expected_data_request, metadata=grpc_util.version_metadata()) + + # The final continuation shouldn't need to send any RPCs. + mock_api_client.StreamExperiments.reset_mock() + mock_api_client.StreamExperimentData.reset_mock() + self.assertEqual(list(generator), []) + + self.assertCountEqual(expected_files, os.listdir(outdir)) + mock_api_client.StreamExperiments.assert_not_called() + mock_api_client.StreamExperimentData.assert_not_called() + + # Spot-check one of the files. + with open(os.path.join(outdir, "scalars_456.json")) as infile: + jsons = [json.loads(line) for line in infile] + self.assertLen(jsons, 4) + datum = jsons[2] + self.assertEqual(datum.pop("run"), "test") + self.assertEqual(datum.pop("tag"), "accuracy") + summary_metadata = summary_pb2.SummaryMetadata.FromString( + base64.b64decode(datum.pop("summary_metadata"))) + expected_summary_metadata = test_util.scalar_metadata("456:accuracy") + self.assertEqual(summary_metadata, expected_summary_metadata) + points = datum.pop("points") + expected_steps = [x for x in range(10)] + expected_values = [2.0 * x for x in range(10)] + expected_wall_times = [1571084520.862939144 + x for x in range(10)] + self.assertEqual(points.pop("steps"), expected_steps) + self.assertEqual(points.pop("values"), expected_values) + self.assertEqual(points.pop("wall_times"), expected_wall_times) + self.assertEqual(points, {}) + self.assertEqual(datum, {}) + + def test_rejects_dangerous_experiment_ids(self): + mock_api_client = self._create_mock_api_client() + + def stream_experiments(request, **kwargs): + del request # unused + yield export_service_pb2.StreamExperimentsResponse( + experiment_ids=["../authorized_keys"]) + + mock_api_client.StreamExperiments = stream_experiments + + outdir = os.path.join(self.get_temp_dir(), "outdir") + exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir) + generator = exporter.export() + + with self.assertRaises(RuntimeError) as cm: + next(generator) + + msg = str(cm.exception) + self.assertIn("Unexpected characters", msg) + self.assertIn(repr(sorted([u".", u"/"])), msg) + self.assertIn("../authorized_keys", msg) + mock_api_client.StreamExperimentData.assert_not_called() + + def test_fails_nicely_on_stream_experiment_data_timeout(self): + # Setup: Client where: + # 1. stream_experiments will say there is one experiment_id. + # 2. stream_experiment_data will raise a grpc CANCELLED, as per + # a timeout. 
+    mock_api_client = self._create_mock_api_client()
+    experiment_id = "123"
+
+    def stream_experiments(request, **kwargs):
+      del request  # unused
+      yield export_service_pb2.StreamExperimentsResponse(
+          experiment_ids=[experiment_id])
+
+    def stream_experiment_data(request, **kwargs):
+      raise test_util.grpc_error(grpc.StatusCode.CANCELLED, "details string")
+
+    mock_api_client.StreamExperiments = stream_experiments
+    mock_api_client.StreamExperimentData = stream_experiment_data
+
+    outdir = os.path.join(self.get_temp_dir(), "outdir")
+    # Execute: exporter.export()
+    exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
+    generator = exporter.export()
+    # Expect: A nice exception of the right type and carrying the right
+    # experiment_id.
+    with self.assertRaises(exporter_lib.GrpcTimeoutException) as cm:
+      next(generator)
+    self.assertEqual(cm.exception.experiment_id, experiment_id)
+
+  def test_stream_experiment_data_passes_through_unexpected_exception(self):
+    # Setup: Client where:
+    #   1. stream_experiments will say there is one experiment_id.
+    #   2. stream_experiment_data will throw an internal error.
+    mock_api_client = self._create_mock_api_client()
+    experiment_id = "123"
+
+    def stream_experiments(request, **kwargs):
+      del request  # unused
+      yield export_service_pb2.StreamExperimentsResponse(
+          experiment_ids=[experiment_id])
+
+    def stream_experiment_data(request, **kwargs):
+      del request  # unused
+      raise test_util.grpc_error(grpc.StatusCode.INTERNAL, "details string")
+
+    mock_api_client.StreamExperiments = stream_experiments
+    mock_api_client.StreamExperimentData = stream_experiment_data
+
+    outdir = os.path.join(self.get_temp_dir(), "outdir")
+    # Execute: exporter.export().
+    exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
+    generator = exporter.export()
+    # Expect: The internal error is passed through.
+    with self.assertRaises(grpc.RpcError) as cm:
+      next(generator)
+    self.assertEqual(cm.exception.details(), "details string")
+
+  def test_handles_outdir_with_no_slash(self):
+    oldcwd = os.getcwd()
+    try:
+      os.chdir(self.get_temp_dir())
+      mock_api_client = self._create_mock_api_client()
+      mock_api_client.StreamExperiments.return_value = iter([
+          export_service_pb2.StreamExperimentsResponse(experiment_ids=["123"]),
+      ])
+      mock_api_client.StreamExperimentData.return_value = iter([
+          export_service_pb2.StreamExperimentDataResponse()
+      ])
+
+      exporter = exporter_lib.TensorBoardExporter(mock_api_client, "outdir")
+      generator = exporter.export()
+      self.assertEqual(list(generator), ["123"])
+      self.assertTrue(os.path.isdir("outdir"))
+    finally:
+      os.chdir(oldcwd)
+
+  def test_rejects_existing_directory(self):
+    mock_api_client = self._create_mock_api_client()
+    outdir = os.path.join(self.get_temp_dir(), "outdir")
+    os.mkdir(outdir)
+    with open(os.path.join(outdir, "scalars_999.json"), "w"):
+      pass
+
+    with self.assertRaises(exporter_lib.OutputDirectoryExistsError):
+      exporter_lib.TensorBoardExporter(mock_api_client, outdir)
+
+    mock_api_client.StreamExperiments.assert_not_called()
+    mock_api_client.StreamExperimentData.assert_not_called()
+
+  def test_rejects_existing_file(self):
+    mock_api_client = self._create_mock_api_client()
+
+    def stream_experiments(request, **kwargs):
+      del request  # unused
+      yield export_service_pb2.StreamExperimentsResponse(experiment_ids=["123"])
+
+    mock_api_client.StreamExperiments = stream_experiments
+
+    outdir = os.path.join(self.get_temp_dir(), "outdir")
+    exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir)
+    generator = exporter.export()
+
+    with open(os.path.join(outdir, "scalars_123.json"), "w"):
+      pass
+
+    with self.assertRaises(exporter_lib.OutputFileExistsError):
+      next(generator)
+
+    mock_api_client.StreamExperimentData.assert_not_called()
+
+  def test_propagates_mkdir_errors(self):
+    mock_api_client = self._create_mock_api_client()
+    outdir = os.path.join(self.get_temp_dir(), "some_file", "outdir")
+    with open(os.path.join(self.get_temp_dir(), "some_file"), "w"):
+      pass
+
+    with self.assertRaises(OSError):
+      exporter_lib.TensorBoardExporter(mock_api_client, outdir)
+
+    mock_api_client.StreamExperiments.assert_not_called()
+    mock_api_client.StreamExperimentData.assert_not_called()
+
+
+class MkdirPTest(tb_test.TestCase):
+
+  def test_makes_full_chain(self):
+    path = os.path.join(self.get_temp_dir(), "a", "b", "c")
+    exporter_lib._mkdir_p(path)
+    self.assertTrue(os.path.isdir(path))
+
+  def test_makes_leaf(self):
+    base = os.path.join(self.get_temp_dir(), "a", "b")
+    exporter_lib._mkdir_p(base)
+    leaf = os.path.join(self.get_temp_dir(), "a", "b", "c")
+    exporter_lib._mkdir_p(leaf)
+    self.assertTrue(os.path.isdir(leaf))
+
+  def test_fails_when_path_is_a_normal_file(self):
+    path = os.path.join(self.get_temp_dir(), "somefile")
+    with open(path, "w"):
+      pass
+    with self.assertRaises(OSError) as cm:
+      exporter_lib._mkdir_p(path)
+    self.assertEqual(cm.exception.errno, errno.EEXIST)
+
+  def test_propagates_other_errors(self):
+    base = os.path.join(self.get_temp_dir(), "somefile")
+    with open(base, "w"):
+      pass
+    leaf = os.path.join(self.get_temp_dir(), "somefile", "somedir")
+    with self.assertRaises(OSError) as cm:
+      exporter_lib._mkdir_p(leaf)
+    self.assertNotEqual(cm.exception.errno, errno.EEXIST)
+    if os.name == "nt":
+      expected_errno = errno.ENOENT
+    else:
+      expected_errno = errno.ENOTDIR
+    self.assertEqual(cm.exception.errno, expected_errno)
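+
+
+# Note on `OpenExclTest` below: `_open_excl` opens output files with
+# `os.O_EXCL`, so an export never silently overwrites an existing file; on
+# collision it raises `OutputFileExistsError` instead.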
+ + +class OpenExclTest(tb_test.TestCase): + + def test_success(self): + path = os.path.join(self.get_temp_dir(), "test.txt") + with exporter_lib._open_excl(path) as outfile: + outfile.write("hello\n") + with open(path) as infile: + self.assertEqual(infile.read(), "hello\n") + + def test_fails_when_file_exists(self): + path = os.path.join(self.get_temp_dir(), "test.txt") + with open(path, "w"): + pass + with self.assertRaises(exporter_lib.OutputFileExistsError) as cm: + exporter_lib._open_excl(path) + self.assertEqual(str(cm.exception), path) + + def test_propagates_other_errors(self): + path = os.path.join(self.get_temp_dir(), "enoent", "test.txt") + with self.assertRaises(OSError) as cm: + exporter_lib._open_excl(path) + self.assertEqual(cm.exception.errno, errno.ENOENT) + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/uploader/logdir_loader.py b/tensorboard/uploader/logdir_loader.py new file mode 100644 index 0000000000..fc83d2428a --- /dev/null +++ b/tensorboard/uploader/logdir_loader.py @@ -0,0 +1,107 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Loader for event file data for an entire TensorBoard log directory.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os + +from tensorboard.backend.event_processing import directory_watcher +from tensorboard.backend.event_processing import io_wrapper +from tensorboard.util import tb_logging + + +logger = tb_logging.get_logger() + + +class LogdirLoader(object): + """Loader for a root log directory, maintaining multiple DirectoryLoaders. + + This class takes a root log directory and a factory for DirectoryLoaders, and + maintains one DirectoryLoader per "logdir subdirectory" of the root logdir. + + Note that this class is not thread-safe. + """ + + def __init__(self, logdir, directory_loader_factory): + """Constructs a new LogdirLoader. + + Args: + logdir: The root log directory to load from. + directory_loader_factory: A factory for creating DirectoryLoaders. The + factory should take a path and return a DirectoryLoader. + + Raises: + ValueError: If logdir or directory_loader_factory are None. + """ + if logdir is None: + raise ValueError('A logdir is required') + if directory_loader_factory is None: + raise ValueError('A directory loader factory is required') + self._logdir = logdir + self._directory_loader_factory = directory_loader_factory + # Maps run names to corresponding DirectoryLoader instances. + self._directory_loaders = {} + + def synchronize_runs(self): + """Finds new runs within `logdir` and makes `DirectoryLoaders` for them. + + In addition, any existing `DirectoryLoader` whose run directory no longer + exists will be deleted. 
+ """ + logger.info('Starting logdir traversal of %s', self._logdir) + runs_seen = set() + for subdir in io_wrapper.GetLogdirSubdirectories(self._logdir): + run = os.path.relpath(subdir, self._logdir) + runs_seen.add(run) + if run not in self._directory_loaders: + logger.info('- Adding run for relative directory %s', run) + self._directory_loaders[run] = self._directory_loader_factory(subdir) + stale_runs = set(self._directory_loaders) - runs_seen + if stale_runs: + for run in stale_runs: + logger.info('- Removing run for relative directory %s', run) + del self._directory_loaders[run] + logger.info('Ending logdir traversal of %s', self._logdir) + + def get_run_events(self): + """Returns tf.Event generators for each run's `DirectoryLoader`. + + Warning: the generators are stateful and consuming them will affect the + results of any other existing generators for that run; calling code should + ensure it takes events from only a single generator per run at a time. + + Returns: + Dictionary containing an entry for each run, mapping the run name to a + generator yielding tf.Event protobuf objects loaded from that run. + """ + runs = list(self._directory_loaders) + logger.info('Creating event loading generators for %d runs', len(runs)) + run_to_loader = collections.OrderedDict() + for run_name in sorted(runs): + loader = self._directory_loaders[run_name] + run_to_loader[run_name] = self._wrap_loader_generator(loader.Load()) + return run_to_loader + + def _wrap_loader_generator(self, loader_generator): + """Wraps `DirectoryLoader` generator to swallow `DirectoryDeletedError`.""" + try: + for item in loader_generator: + yield item + except directory_watcher.DirectoryDeletedError: + return diff --git a/tensorboard/uploader/logdir_loader_test.py b/tensorboard/uploader/logdir_loader_test.py new file mode 100644 index 0000000000..c3d8e09f3c --- /dev/null +++ b/tensorboard/uploader/logdir_loader_test.py @@ -0,0 +1,154 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorboard.uploader.logdir_loader.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path +import shutil +import six + +from tensorboard.uploader import logdir_loader +from tensorboard import test as tb_test +from tensorboard.backend.event_processing import directory_loader +from tensorboard.backend.event_processing import event_file_loader +from tensorboard.backend.event_processing import io_wrapper +from tensorboard.util import test_util + + +class LogdirLoaderTest(tb_test.TestCase): + + def _create_logdir_loader(self, logdir): + def directory_loader_factory(path): + return directory_loader.DirectoryLoader( + path, + event_file_loader.TimestampedEventFileLoader, + path_filter=io_wrapper.IsTensorFlowEventsFile) + return logdir_loader.LogdirLoader(logdir, directory_loader_factory) + + def _extract_tags(self, event_generator): + """Converts a generator of tf.Events into a list of event tags.""" + return [ + event.summary.value[0].tag for event in event_generator + if not event.file_version + ] + + def _extract_run_to_tags(self, run_to_events): + """Returns run-to-tags dict from run-to-event-generator dict.""" + run_to_tags = {} + for run_name, event_generator in six.iteritems(run_to_events): + # There should be no duplicate runs. + self.assertNotIn(run_name, run_to_tags) + run_to_tags[run_name] = self._extract_tags(event_generator) + return run_to_tags + + def test_empty_logdir(self): + logdir = self.get_temp_dir() + loader = self._create_logdir_loader(logdir) + # Default state is empty. + self.assertEmpty(list(loader.get_run_events())) + loader.synchronize_runs() + # Still empty, since there's no data. + self.assertEmpty(list(loader.get_run_events())) + + def test_single_event_logdir(self): + logdir = self.get_temp_dir() + with test_util.FileWriter(logdir) as writer: + writer.add_test_summary("foo") + loader = self._create_logdir_loader(logdir) + loader.synchronize_runs() + self.assertEqual( + self._extract_run_to_tags(loader.get_run_events()), {".": ["foo"]}) + # A second load should indicate no new data for the run. + self.assertEqual( + self._extract_run_to_tags(loader.get_run_events()), {".": []}) + + def test_multiple_writes_to_logdir(self): + logdir = self.get_temp_dir() + with test_util.FileWriter(os.path.join(logdir, "a")) as writer: + writer.add_test_summary("tag_a") + with test_util.FileWriter(os.path.join(logdir, "b")) as writer: + writer.add_test_summary("tag_b") + with test_util.FileWriter(os.path.join(logdir, "b", "x")) as writer: + writer.add_test_summary("tag_b_x") + writer_c = test_util.FileWriter(os.path.join(logdir, "c")) + writer_c.add_test_summary("tag_c") + writer_c.flush() + loader = self._create_logdir_loader(logdir) + loader.synchronize_runs() + self.assertEqual( + self._extract_run_to_tags(loader.get_run_events()), + {"a": ["tag_a"], "b": ["tag_b"], "b/x": ["tag_b_x"], "c": ["tag_c"]}) + # A second load should indicate no new data. + self.assertEqual( + self._extract_run_to_tags(loader.get_run_events()), + {"a": [], "b": [], "b/x": [], "c": []}) + # Write some new data to both new and pre-existing event files. 
+ with test_util.FileWriter( + os.path.join(logdir, "a"), filename_suffix=".other") as writer: + writer.add_test_summary("tag_a_2") + writer.add_test_summary("tag_a_3") + writer.add_test_summary("tag_a_4") + with test_util.FileWriter( + os.path.join(logdir, "b", "x"), filename_suffix=".other") as writer: + writer.add_test_summary("tag_b_x_2") + with writer_c as writer: + writer.add_test_summary("tag_c_2") + # New data should appear on the next load. + self.assertEqual( + self._extract_run_to_tags(loader.get_run_events()), { + "a": ["tag_a_2", "tag_a_3", "tag_a_4"], + "b": [], + "b/x": ["tag_b_x_2"], + "c": ["tag_c_2"] + }) + + def test_directory_deletion(self): + logdir = self.get_temp_dir() + with test_util.FileWriter(os.path.join(logdir, "a")) as writer: + writer.add_test_summary("tag_a") + with test_util.FileWriter(os.path.join(logdir, "b")) as writer: + writer.add_test_summary("tag_b") + with test_util.FileWriter(os.path.join(logdir, "c")) as writer: + writer.add_test_summary("tag_c") + loader = self._create_logdir_loader(logdir) + loader.synchronize_runs() + self.assertEqual(list(loader.get_run_events().keys()), ["a", "b", "c"]) + shutil.rmtree(os.path.join(logdir, "b")) + loader.synchronize_runs() + self.assertEqual(list(loader.get_run_events().keys()), ["a", "c"]) + shutil.rmtree(logdir) + loader.synchronize_runs() + self.assertEmpty(loader.get_run_events()) + + def test_directory_deletion_during_event_loading(self): + logdir = self.get_temp_dir() + with test_util.FileWriter(logdir) as writer: + writer.add_test_summary("foo") + loader = self._create_logdir_loader(logdir) + loader.synchronize_runs() + self.assertEqual( + self._extract_run_to_tags(loader.get_run_events()), {".": ["foo"]}) + shutil.rmtree(logdir) + runs_to_events = loader.get_run_events() + self.assertEqual(list(runs_to_events.keys()), ["."]) + events = runs_to_events["."] + self.assertEqual(self._extract_tags(events), []) + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/uploader/peekable_iterator.py b/tensorboard/uploader/peekable_iterator.py new file mode 100644 index 0000000000..59dea6a3ae --- /dev/null +++ b/tensorboard/uploader/peekable_iterator.py @@ -0,0 +1,87 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Iterator adapter that supports peeking ahead.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class PeekableIterator(object): + """Iterator adapter that supports peeking ahead. + + As with most Python iterators, this is also iterable; its `__iter__` + returns itself. + + This class is not thread-safe. Use external synchronization if + iterating concurrently. + """ + + def __init__(self, iterable): + """Initializes a peeking iterator wrapping the provided iterable. + + Args: + iterable: An iterable to wrap. 
+ """ + self._iterator = iter(iterable) + self._has_peeked = False + self._peeked_element = None + + def has_next(self): + """Checks whether there are any more items in this iterator. + + The next call to `next` or `peek` will raise `StopIteration` if and + only if this method returns `False`. + + Returns: + `True` if there are any more items in this iterator, else `False`. + """ + try: + self.peek() + return True + except StopIteration: + return False + + def peek(self): + """Gets the next item in the iterator without consuming it. + + Multiple consecutive calls will return the same element. + + Returns: + The value that would be returned by `next`. + + Raises: + StopIteration: If there are no more items in the iterator. + """ + if not self._has_peeked: + self._peeked_element = next(self._iterator) + self._has_peeked = True + return self._peeked_element + + def __iter__(self): + return self + + def __next__(self): + if self._has_peeked: + self._has_peeked = False + result = self._peeked_element + self._peeked_element = None # allow GC + return result + else: + return next(self._iterator) + + def next(self): + # (Like `__next__`, but Python 2.) + return self.__next__() diff --git a/tensorboard/uploader/peekable_iterator_test.py b/tensorboard/uploader/peekable_iterator_test.py new file mode 100644 index 0000000000..e2c52505d0 --- /dev/null +++ b/tensorboard/uploader/peekable_iterator_test.py @@ -0,0 +1,68 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorboard.uploader.peekable_iterator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorboard.uploader import peekable_iterator +from tensorboard import test as tb_test + + +class PeekableIteratorTest(tb_test.TestCase): + """Tests for `PeekableIterator`.""" + + def test_empty_iteration(self): + it = peekable_iterator.PeekableIterator([]) + self.assertEqual(list(it), []) + + def test_normal_iteration(self): + it = peekable_iterator.PeekableIterator([1, 2, 3]) + self.assertEqual(list(it), [1, 2, 3]) + + def test_simple_peek(self): + it = peekable_iterator.PeekableIterator([1, 2, 3]) + self.assertEqual(it.peek(), 1) + self.assertEqual(it.peek(), 1) + self.assertEqual(next(it), 1) + self.assertEqual(it.peek(), 2) + self.assertEqual(next(it), 2) + self.assertEqual(next(it), 3) + self.assertEqual(list(it), []) + + def test_simple_has_next(self): + it = peekable_iterator.PeekableIterator([1, 2]) + self.assertTrue(it.has_next()) + self.assertEqual(it.peek(), 1) + self.assertTrue(it.has_next()) + self.assertEqual(next(it), 1) + self.assertEqual(it.peek(), 2) + self.assertTrue(it.has_next()) + self.assertEqual(next(it), 2) + self.assertFalse(it.has_next()) + self.assertFalse(it.has_next()) + + def test_peek_after_end(self): + it = peekable_iterator.PeekableIterator([1, 2, 3]) + self.assertEqual(list(it), [1, 2, 3]) + with self.assertRaises(StopIteration): + it.peek() + with self.assertRaises(StopIteration): + it.peek() + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/uploader/proto/BUILD b/tensorboard/uploader/proto/BUILD new file mode 100644 index 0000000000..9353b897e6 --- /dev/null +++ b/tensorboard/uploader/proto/BUILD @@ -0,0 +1,20 @@ +package(default_visibility = ["//tensorboard:internal"]) + +load("//tensorboard/defs:protos.bzl", "tb_proto_library") + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +tb_proto_library( + name = "protos_all", + srcs = [ + "export_service.proto", + "scalar.proto", + "write_service.proto", + ], + has_services = True, + deps = [ + "//tensorboard/compat/proto:protos_all", + ], +) diff --git a/tensorboard/uploader/proto/export_service.proto b/tensorboard/uploader/proto/export_service.proto new file mode 100644 index 0000000000..54ad5ab751 --- /dev/null +++ b/tensorboard/uploader/proto/export_service.proto @@ -0,0 +1,87 @@ +syntax = "proto3"; + +package tensorboard.service; + +import "google/protobuf/timestamp.proto"; +import "tensorboard/compat/proto/summary.proto"; + +// Service for exporting data from TensorBoard.dev. +service TensorBoardExporterService { + // Stream the experiment_id of all the experiments owned by the caller. + rpc StreamExperiments(StreamExperimentsRequest) + returns (stream StreamExperimentsResponse) {} + // Stream scalars for all the runs and tags in an experiment. + rpc StreamExperimentData(StreamExperimentDataRequest) + returns (stream StreamExperimentDataResponse) {} +} + +// Request to stream the experiment_id of all the experiments owned by the +// caller from TensorBoard.dev. +message StreamExperimentsRequest { + // Timestamp to get a consistent snapshot of the data in the database. + // This is useful when making multiple read RPCs and needing the data to be + // consistent across the read calls. 
+  google.protobuf.Timestamp read_timestamp = 1;
+  // User ID defaults to the caller, but may be set to a different user for
+  // internal Takeout processes operating on behalf of a user.
+  string user_id = 2;
+  // Limits the number of experiment IDs returned. This is useful to check if
+  // the user might have any data by setting limit=1. Also useful to preview
+  // the list of experiments.
+  int64 limit = 3;
+  // TODO(@karthikv2k): Support pagination.
+}
+
+// Streams experiment IDs returned from TensorBoard.dev.
+message StreamExperimentsResponse {
+  // List of experiment IDs for the experiments owned by the user. The entire
+  // list of experiments owned by the user is streamed in batches and each batch
+  // contains a list of experiment IDs. A consumer of this stream needs to
+  // concatenate all these lists to get the full response. The order of
+  // experiment IDs in the stream is not defined.
+  repeated string experiment_ids = 1;
+}
+
+// Request to stream scalars from all the runs and tags in an experiment.
+message StreamExperimentDataRequest {
+  // The permanent ID of the experiment whose data needs to be streamed.
+  string experiment_id = 1;
+  // Timestamp to get a consistent snapshot of the data in the database.
+  // This is useful when making multiple read RPCs and needing the data to be
+  // consistent across the read calls. Should be the same as the read timestamp
+  // used for the corresponding `StreamExperimentsRequest` for consistency.
+  google.protobuf.Timestamp read_timestamp = 2;
+}
+
+// Streams scalars from all the runs and tags in an experiment. Each stream
+// result only contains data for a single tag from a single run. For example,
+// if there are five runs and each run has two tags, the RPC will return a
+// stream of at least ten `StreamExperimentDataResponse`s, each one having the
+// scalars for one tag. The values from a single tag may be split among
+// multiple responses. Users need to aggregate information from the entire
+// stream to get data for the entire experiment. Empty experiments will have
+// zero stream results. Empty runs that don't have any tags need not be
+// supported by a hosted service.
+message StreamExperimentDataResponse {
+  // Name of the tag whose data is contained in this response.
+  string tag_name = 1;
+  // Name of the run that contains the tag `tag_name`.
+  string run_name = 2;
+  // The metadata of the tag `tag_name`.
+  .tensorboard.SummaryMetadata tag_metadata = 3;
+  // Data to store for the tag `tag_name`.
+  ScalarPoints points = 4;
+
+  // Data for the scalars is stored in a columnar fashion to optimize it for
+  // exporting the data into textual formats like JSON.
+  // The data for the ith scalar is { steps[i], wall_times[i], values[i] }.
+  // The data here is sorted by step values in ascending order.
+  message ScalarPoints {
+    // Step index within the run.
+    repeated int64 steps = 1;
+    // Timestamp of the creation of this point.
+    repeated google.protobuf.Timestamp wall_times = 2;
+    // Value of the point at this step / timestamp.
+    repeated double values = 3;
+  }
+}
diff --git a/tensorboard/uploader/proto/scalar.proto b/tensorboard/uploader/proto/scalar.proto
new file mode 100644
index 0000000000..462cf4df5c
--- /dev/null
+++ b/tensorboard/uploader/proto/scalar.proto
@@ -0,0 +1,28 @@
+syntax = "proto3";
+
+package tensorboard.service;
+
+import "google/protobuf/timestamp.proto";
+import "tensorboard/compat/proto/summary.proto";
+
+// One point viewable on a scalar metric plot.
+message ScalarPoint {
+  // Step index within the run.
+  int64 step = 1;
+  // Timestamp of the creation of this point.
+  google.protobuf.Timestamp wall_time = 2;
+  // Value of the point at this step / timestamp.
+  double value = 3;
+}
+
+// Metadata for the ScalarPoints stored for one (Experiment, Run, Tag).
+message ScalarPointMetadata {
+  // Maximum step recorded for the tag.
+  int64 max_step = 1;
+  // Timestamp corresponding to the max step.
+  google.protobuf.Timestamp max_wall_time = 2;
+  // Information about the plugin which created this scalar data.
+  // Note: The period is a required part of the type here due to the
+  // package name resolution logic.
+  .tensorboard.SummaryMetadata summary_metadata = 3;
+}
diff --git a/tensorboard/uploader/proto/write_service.proto b/tensorboard/uploader/proto/write_service.proto
new file mode 100644
index 0000000000..4d63f22f69
--- /dev/null
+++ b/tensorboard/uploader/proto/write_service.proto
@@ -0,0 +1,130 @@
+syntax = "proto3";
+
+package tensorboard.service;
+
+import "tensorboard/uploader/proto/scalar.proto";
+import "tensorboard/compat/proto/summary.proto";
+
+// Service for writing data to TensorBoard.dev.
+service TensorBoardWriterService {
+  // Request for a new location to write TensorBoard readable events.
+  rpc CreateExperiment(CreateExperimentRequest)
+      returns (CreateExperimentResponse) {}
+  // Request that an experiment be deleted, along with all tags and scalars
+  // that it contains. This call may only be made by the original owner of the
+  // experiment.
+  rpc DeleteExperiment(DeleteExperimentRequest)
+      returns (DeleteExperimentResponse) {}
+  // Request that unreachable data be purged. Used only for testing;
+  // disabled in production.
+  rpc PurgeData(PurgeDataRequest) returns (PurgeDataResponse) {}
+  // Request that additional data be stored in TensorBoard.dev.
+  rpc WriteScalar(WriteScalarRequest) returns (WriteScalarResponse) {}
+  // Request that the calling user and all their data be permanently deleted.
+  // Used for testing purposes.
+  rpc DeleteOwnUser(DeleteOwnUserRequest) returns (DeleteOwnUserResponse) {}
+}
+
+// This is currently empty on purpose. No information is necessary
+// to request a URL, except authorization, of course, which doesn't
+// come within the proto.
+message CreateExperimentRequest {
+  // This is empty on purpose.
+}
+
+// Carries all information necessary to:
+//   1. Inform the user where to navigate to see their TensorBoard.
+//   2. Subsequently load (Scalars, Tensors, etc.) to the specified location.
+message CreateExperimentResponse {
+  // Service-wide unique identifier of an uploaded log dir.
+  // eg: "1r9d0kQkh2laODSZcQXWP"
+  string experiment_id = 1;
+  // URL the user should navigate to in order to see their TensorBoard
+  // eg: "https://example.com/public/1r9d0kQkh2laODSZcQXWP"
+  string url = 2;
+}
+
+message DeleteExperimentRequest {
+  // Service-wide unique identifier of an uploaded log dir.
+  // eg: "1r9d0kQkh2laODSZcQXWP"
+  string experiment_id = 1;
+}
+
+message DeleteExperimentResponse {
+  // This is empty on purpose.
+}
+
+// Only used for testing; corresponding RPC is disabled in prod.
+message PurgeDataRequest {
+  // Maximum number of entities of a given kind to purge at once (e.g.,
+  // maximum number of tags to purge). Required; must be positive.
+  int32 batch_limit = 1;
+}
+
+// Only used for testing; corresponding RPC is disabled in prod.
+message PurgeDataResponse {
+  // Stats about how many elements were purged. Compare to the batch
+  // limit specified in the request to estimate whether the backlog has
+  // any more items.
+ PurgeStats purge_stats = 1; +} + +// Details about what actions were taken as a result of a purge request. +// These values are upper bounds; they may exceed the true values. +message PurgeStats { + // Number of tags deleted as a result of this request. + int32 tags = 1; + // Number of experiments marked as purged as a result of this request. + int32 experiments = 2; + // Number of users deleted as a result of this request. + int32 users = 3; +} + +// Carries all that is needed to add additional run data to the hosted service. +message WriteScalarRequest { + // All the data to store for one Run. This data will be stored under the + // corresponding run in the hosted storage. WriteScalarRequest is merged into + // the data store for the keyed run. The tags and included scalars will be + // the union of the data sent across all WriteScalarRequests. Metadata by + // default uses a 'first write wins' approach. + message Run { + // The name of this run. For example "/some/path/mnist_experiments/run1/" + string name = 1; + // Data to store for this Run/Tag combination. + repeated Tag tags = 2; + } + + // All the data to store for one Tag of one Run. This data will be stored + // under the corresponding run/tag in the hosted storage. A tag corresponds to + // a single time series. + message Tag { + // The name of this tag. For example "loss" + string name = 1; + // Data to store for this Run/Tag combination. + repeated ScalarPoint points = 2; + // The metadata of this tag. + .tensorboard.SummaryMetadata metadata = 3; + } + + // Which experiment to write to - corresponding to one hosted TensorBoard URL. + // The requester must have authorization to write to this location. + string experiment_id = 1; + // Data to append to the existing storage at the experiment_id. + repeated Run runs = 2; +} + +// Everything the caller needs to know about how the writing went. +// (Currently empty) +message WriteScalarResponse { + // This is empty on purpose. +} + +// Requests that the calling user and all their data be permanently deleted. +message DeleteOwnUserRequest { + // This is empty on purpose. +} + +// Everything the caller needs to know about how the deletion went. +message DeleteOwnUserResponse { + // This is empty on purpose. +} diff --git a/tensorboard/uploader/test_util.py b/tensorboard/uploader/test_util.py new file mode 100644 index 0000000000..d6c5cca9d1 --- /dev/null +++ b/tensorboard/uploader/test_util.py @@ -0,0 +1,63 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utilities for testing.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +import grpc + +from google.protobuf import timestamp_pb2 +from tensorboard.compat.proto import summary_pb2 + + +class FakeTime(object): + """Thread-safe fake replacement for the `time` module.""" + + def __init__(self, current=0.0): + self._time = float(current) + self._lock = threading.Lock() + + def time(self): + with self._lock: + return self._time + + def sleep(self, secs): + with self._lock: + self._time += secs + + +def scalar_metadata(display_name): + """Makes a scalar metadata proto, for constructing expected requests.""" + metadata = summary_pb2.SummaryMetadata(display_name=display_name) + metadata.plugin_data.plugin_name = "scalars" + return metadata + + +def grpc_error(code, details): + # Monkey patch insertion for the methods a real grpc.RpcError would have. + error = grpc.RpcError("RPC error %r: %s" % (code, details)) + error.code = lambda: code + error.details = lambda: details + return error + + +def timestamp_pb(nanos): + result = timestamp_pb2.Timestamp() + result.FromNanoseconds(nanos) + return result diff --git a/tensorboard/uploader/uploader.py b/tensorboard/uploader/uploader.py new file mode 100644 index 0000000000..222a23977c --- /dev/null +++ b/tensorboard/uploader/uploader.py @@ -0,0 +1,413 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Uploads a TensorBoard logdir to TensorBoard.dev.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import time + +import grpc +import six + +from tensorboard.uploader.proto import write_service_pb2 +from tensorboard.uploader import logdir_loader +from tensorboard.uploader import peekable_iterator +from tensorboard.uploader import util +from tensorboard import data_compat +from tensorboard.backend.event_processing import directory_loader +from tensorboard.backend.event_processing import event_file_loader +from tensorboard.backend.event_processing import io_wrapper +from tensorboard.plugins.scalar import metadata as scalar_metadata +from tensorboard.util import grpc_util +from tensorboard.util import tb_logging +from tensorboard.util import tensor_util + +# Minimum length of an upload cycle in seconds; shorter cycles will sleep to +# use up the rest of the time to avoid sending write RPCs too quickly. +_MIN_UPLOAD_CYCLE_DURATION_SECS = 5 + +# Age in seconds of last write after which an event file is considered inactive. +# TODO(@nfelt): consolidate with TensorBoard --reload_multifile default logic. 
+_EVENT_FILE_INACTIVE_SECS = 4000 + +# Maximum length of a base-128 varint as used to encode a 64-bit value +# (without the "msb of last byte is bit 63" optimization, to be +# compatible with protobuf and golang varints). +_MAX_VARINT64_LENGTH_BYTES = 10 + +# Maximum outgoing request size. The server-side limit is 4 MiB [1]; we +# should pad a bit to mitigate any errors in our bookkeeping. Currently, +# we pad a lot, because using higher request sizes causes occasional +# Deadline Exceeded errors in the RPC server. +# +# [1]: https://github.com/grpc/grpc/blob/e70d8582b4b0eedc45e3d25a57b58a08b94a9f4a/include/grpc/impl/codegen/grpc_types.h#L447 # pylint: disable=line-too-long +_MAX_REQUEST_LENGTH_BYTES = 1024 * 128 + +logger = tb_logging.get_logger() + + +class TensorBoardUploader(object): + """Uploads a TensorBoard logdir to TensorBoard.dev.""" + + def __init__(self, writer_client, logdir, rate_limiter=None): + """Constructs a TensorBoardUploader. + + Args: + writer_client: a TensorBoardWriterService stub instance + logdir: path of the log directory to upload + rate_limiter: a `RateLimiter` to use to limit upload cycle frequency + """ + self._api = writer_client + self._logdir = logdir + self._request_builder = None + if rate_limiter is None: + self._rate_limiter = util.RateLimiter(_MIN_UPLOAD_CYCLE_DURATION_SECS) + else: + self._rate_limiter = rate_limiter + active_filter = lambda secs: secs + _EVENT_FILE_INACTIVE_SECS >= time.time() + directory_loader_factory = functools.partial( + directory_loader.DirectoryLoader, + loader_factory=event_file_loader.TimestampedEventFileLoader, + path_filter=io_wrapper.IsTensorFlowEventsFile, + active_filter=active_filter) + self._logdir_loader = logdir_loader.LogdirLoader( + self._logdir, directory_loader_factory) + + def create_experiment(self): + """Creates an Experiment for this upload session and returns the URL.""" + logger.info("Creating experiment") + request = write_service_pb2.CreateExperimentRequest() + response = grpc_util.call_with_retries(self._api.CreateExperiment, request) + self._request_builder = _RequestBuilder(response.experiment_id) + return response.url + + def start_uploading(self): + """Blocks forever to continuously upload data from the logdir. + + Raises: + RuntimeError: If `create_experiment` has not yet been called. + ExperimentNotFoundError: If the experiment is deleted during the + course of the upload. 
+ """ + if self._request_builder is None: + raise RuntimeError( + "Must call create_experiment() before start_uploading()") + while True: + self._upload_once() + + def _upload_once(self): + """Runs one upload cycle, sending zero or more RPCs.""" + logger.info("Starting an upload cycle") + self._rate_limiter.tick() + + sync_start_time = time.time() + self._logdir_loader.synchronize_runs() + sync_duration_secs = time.time() - sync_start_time + logger.info("Logdir sync took %.3f seconds", sync_duration_secs) + + run_to_events = self._logdir_loader.get_run_events() + first_request = True + for request in self._request_builder.build_requests(run_to_events): + if not first_request: + self._rate_limiter.tick() + first_request = False + upload_start_time = time.time() + request_bytes = request.ByteSize() + logger.info("Trying request of %d bytes", request_bytes) + self._upload(request) + upload_duration_secs = time.time() - upload_start_time + logger.info( + "Upload for %d runs (%d bytes) took %.3f seconds", + len(request.runs), + request_bytes, + upload_duration_secs) + + def _upload(self, request): + try: + # TODO(@nfelt): execute this RPC asynchronously. + grpc_util.call_with_retries(self._api.WriteScalar, request) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + raise ExperimentNotFoundError() + logger.error("Upload call failed with error %s", e) + + +def delete_experiment(writer_client, experiment_id): + """Permanently deletes an experiment and all of its contents. + + Args: + writer_client: a TensorBoardWriterService stub instance + experiment_id: string ID of the experiment to delete + + Raises: + ExperimentNotFoundError: If no such experiment exists. + PermissionDeniedError: If the user is not authorized to delete this + experiment. + RuntimeError: On unexpected failure. + """ + logger.info("Deleting experiment %r", experiment_id) + request = write_service_pb2.DeleteExperimentRequest() + request.experiment_id = experiment_id + try: + grpc_util.call_with_retries(writer_client.DeleteExperiment, request) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + raise ExperimentNotFoundError() + if e.code() == grpc.StatusCode.PERMISSION_DENIED: + raise PermissionDeniedError() + raise + + +class ExperimentNotFoundError(RuntimeError): + pass + + +class PermissionDeniedError(RuntimeError): + pass + + +class _OutOfSpaceError(Exception): + """Action could not proceed without overflowing request budget. + + This is a signaling exception (like `StopIteration`) used internally + by `_RequestBuilder`; it does not mean that anything has gone wrong. + """ + pass + + +class _RequestBuilder(object): + """Helper class for building requests that fit under a size limit. + + This class is not threadsafe. Use external synchronization if calling + its methods concurrently. + """ + + _NON_SCALAR_TIME_SERIES = object() # sentinel + + def __init__(self, experiment_id): + self._experiment_id = experiment_id + # The request currently being populated. + self._request = None # type: write_service_pb2.WriteScalarRequest + # A lower bound on the number of bytes that we may yet add to the + # request. + self._byte_budget = None # type: int + # Map from `(run_name, tag_name)` to `SummaryMetadata` if the time + # series is a scalar time series, else to `_NON_SCALAR_TIME_SERIES`. 
+ self._tag_metadata = {} + + def _new_request(self): + """Allocates a new request and refreshes the budget.""" + self._request = write_service_pb2.WriteScalarRequest() + self._byte_budget = _MAX_REQUEST_LENGTH_BYTES + self._request.experiment_id = self._experiment_id + self._byte_budget -= self._request.ByteSize() + if self._byte_budget < 0: + raise RuntimeError("Byte budget too small for experiment ID") + + def build_requests(self, run_to_events): + """Converts a stream of TF events to a stream of outgoing requests. + + Each yielded request will be at most `_MAX_REQUEST_LENGTH_BYTES` + bytes long. + + Args: + run_to_events: Mapping from run name to generator of `tf.Event` + values, as returned by `LogdirLoader.get_run_events`. + + Yields: + A finite stream of `WriteScalarRequest` objects. + + Raises: + RuntimeError: If no progress can be made because even a single + point is too large (say, due to a gigabyte-long tag name). + """ + + self._new_request() + runs = {} # cache: map from run name to `Run` proto in request + tags = {} # cache: map from `(run, tag)` to `Tag` proto in run in request + work_items = peekable_iterator.PeekableIterator( + self._run_values(run_to_events)) + + while work_items.has_next(): + (run_name, event, orig_value) = work_items.peek() + value = data_compat.migrate_value(orig_value) + time_series_key = (run_name, value.tag) + + metadata = self._tag_metadata.get(time_series_key) + if metadata is None: + plugin_name = value.metadata.plugin_data.plugin_name + if plugin_name == scalar_metadata.PLUGIN_NAME: + metadata = value.metadata + else: + metadata = _RequestBuilder._NON_SCALAR_TIME_SERIES + self._tag_metadata[time_series_key] = metadata + if metadata is _RequestBuilder._NON_SCALAR_TIME_SERIES: + next(work_items) + continue + try: + run_proto = runs.get(run_name) + if run_proto is None: + run_proto = self._create_run(run_name) + runs[run_name] = run_proto + tag_proto = tags.get((run_name, value.tag)) + if tag_proto is None: + tag_proto = self._create_tag(run_proto, value.tag, metadata) + tags[(run_name, value.tag)] = tag_proto + self._create_point(tag_proto, event, value) + next(work_items) + except _OutOfSpaceError: + # Flush request and start a new one. + request_to_emit = self._prune_request() + if request_to_emit is None: + raise RuntimeError("Could not make progress uploading data") + self._new_request() + runs.clear() + tags.clear() + yield request_to_emit + + final_request = self._prune_request() + if final_request is not None: + yield final_request + + def _run_values(self, run_to_events): + """Helper generator to create a single stream of work items.""" + # Note that each of these joins in principle has deletion anomalies: + # if the input stream contains runs with no events, or events with + # no values, we'll lose that information. This is not a problem: we + # would need to prune such data from the request anyway. + for (run_name, events) in six.iteritems(run_to_events): + for event in events: + for value in event.summary.value: + yield (run_name, event, value) + + def _prune_request(self): + """Removes empty runs and tags from the active request. + + This does not refund `self._byte_budget`; it is assumed that the + request will be emitted immediately, anyway. + + Returns: + The active request, or `None` if after pruning the request + contains no data. 
+ """ + request = self._request + for (run_idx, run) in reversed(list(enumerate(request.runs))): + for (tag_idx, tag) in reversed(list(enumerate(run.tags))): + if not tag.points: + del run.tags[tag_idx] + if not run.tags: + del self._request.runs[run_idx] + if not request.runs: + request = None + return request + + def _create_run(self, run_name): + """Adds a run to the live request, if there's space. + + Args: + run_name: String name of the run to add. + + Returns: + The `WriteScalarRequest.Run` that was added to `request.runs`. + + Raises: + _OutOfSpaceError: If adding the run would exceed the remaining + request budget. + """ + run_proto = self._request.runs.add(name=run_name) + # We can't calculate the proto key cost exactly ahead of time, as + # it depends on the total size of all tags. Be conservative. + cost = run_proto.ByteSize() + _MAX_VARINT64_LENGTH_BYTES + 1 + if cost > self._byte_budget: + raise _OutOfSpaceError() + self._byte_budget -= cost + return run_proto + + def _create_tag(self, run_proto, tag_name, metadata): + """Adds a tag for the given value, if there's space. + + Args: + run_proto: `WriteScalarRequest.Run` proto to which to add a tag. + tag_name: String name of the tag to add (as `value.tag`). + metadata: TensorBoard `SummaryMetadata` proto from the first + occurrence of this time series. + + Returns: + The `WriteScalarRequest.Tag` that was added to `run_proto.tags`. + + Raises: + _OutOfSpaceError: If adding the tag would exceed the remaining + request budget. + """ + tag_proto = run_proto.tags.add(name=tag_name) + tag_proto.metadata.CopyFrom(metadata) + submessage_cost = tag_proto.ByteSize() + # We can't calculate the proto key cost exactly ahead of time, as + # it depends on the number of points. Be conservative. + cost = submessage_cost + _MAX_VARINT64_LENGTH_BYTES + 1 + if cost > self._byte_budget: + raise _OutOfSpaceError() + self._byte_budget -= cost + return tag_proto + + def _create_point(self, tag_proto, event, value): + """Adds a scalar point to the given tag, if there's space. + + Args: + tag_proto: `WriteScalarRequest.Tag` proto to which to add a point. + event: Enclosing `Event` proto with the step and wall time data. + value: Scalar `Summary.Value` proto with the actual scalar data. + + Returns: + The `ScalarPoint` that was added to `tag_proto.points`. + + Raises: + _OutOfSpaceError: If adding the point would exceed the remaining + request budget. + """ + point = tag_proto.points.add() + point.step = event.step + # TODO(@nfelt): skip tensor roundtrip for Value with simple_value set + point.value = tensor_util.make_ndarray(value.tensor).item() + util.set_timestamp(point.wall_time, event.wall_time) + submessage_cost = point.ByteSize() + cost = submessage_cost + _varint_cost(submessage_cost) + 1 # proto key + if cost > self._byte_budget: + tag_proto.points.pop() + raise _OutOfSpaceError() + self._byte_budget -= cost + return point + + +def _varint_cost(n): + """Computes the size of `n` encoded as an unsigned base-128 varint. + + This should be consistent with the proto wire format: + + + Args: + n: A non-negative integer. + + Returns: + An integer number of bytes. + """ + result = 1 + while n >= 128: + result += 1 + n >>= 7 + return result diff --git a/tensorboard/uploader/uploader_main.py b/tensorboard/uploader/uploader_main.py new file mode 100644 index 0000000000..d9fc953847 --- /dev/null +++ b/tensorboard/uploader/uploader_main.py @@ -0,0 +1,474 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Main program for the TensorBoard.dev uploader.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import json +import os +import sys +import textwrap + +from absl import app +from absl import logging +from absl.flags import argparse_flags +import grpc +import six + +from tensorboard.uploader import dev_creds +from tensorboard.uploader.proto import export_service_pb2_grpc +from tensorboard.uploader.proto import write_service_pb2_grpc +from tensorboard.uploader import auth +from tensorboard.uploader import exporter as exporter_lib +from tensorboard.uploader import uploader as uploader_lib +from tensorboard import program +from tensorboard.plugins import base_plugin + + +# Temporary integration point for absl compatibility; will go away once +# migrated to TensorBoard subcommand. +_FLAGS = None + + +_MESSAGE_TOS = u"""\ +Your use of this service is subject to Google's Terms of Service + and Privacy Policy +, and TensorBoard.dev's Terms of Service +. + +This notice will not be shown again while you are logged into the uploader. +To log out, run `tensorboard dev auth revoke`. +""" + + +_SUBCOMMAND_FLAG = '_uploader__subcommand' +_SUBCOMMAND_KEY_UPLOAD = 'UPLOAD' +_SUBCOMMAND_KEY_DELETE = 'DELETE' +_SUBCOMMAND_KEY_EXPORT = 'EXPORT' +_SUBCOMMAND_KEY_AUTH = 'AUTH' +_AUTH_SUBCOMMAND_FLAG = '_uploader__subcommand_auth' +_AUTH_SUBCOMMAND_KEY_REVOKE = 'REVOKE' + + +def _prompt_for_user_ack(intent): + """Prompts for user consent, exiting the program if they decline.""" + body = intent.get_ack_message_body() + header = '\n***** TensorBoard Uploader *****\n' + user_ack_message = '\n'.join((header, body, _MESSAGE_TOS)) + sys.stderr.write(user_ack_message) + sys.stderr.write('\n') + response = six.moves.input('Continue? (yes/NO) ') + if response.lower() not in ('y', 'yes'): + sys.exit(0) + sys.stderr.write('\n') + + +def _define_flags(parser): + """Configures flags on the provided argument parser. + + Integration point for `tensorboard.program`'s subcommand system. + + Args: + parser: An `argparse.ArgumentParser` to be mutated. 
+ """ + + subparsers = parser.add_subparsers() + + parser.add_argument( + '--endpoint', + type=str, + default='api.tensorboard.dev:443', + help='URL for the API server accepting write requests.') + + parser.add_argument( + '--grpc_creds_type', + type=str, + default='ssl', + choices=('local', 'ssl', 'ssl_dev'), + help='The type of credentials to use for the gRPC client') + + parser.add_argument( + '--auth_force_console', + action='store_true', + help='Set to true to force authentication flow to use the ' + '--console rather than a browser redirect to localhost.') + + upload = subparsers.add_parser( + 'upload', help='upload an experiment to TensorBoard.dev') + upload.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_UPLOAD}) + upload.add_argument( + '--logdir', + metavar='PATH', + type=str, + default=None, + help='Directory containing the logs to process') + + delete = subparsers.add_parser( + 'delete', + help='permanently delete an experiment', + inherited_absl_flags=None) + delete.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_DELETE}) + # We would really like to call this next flag `--experiment` rather + # than `--experiment_id`, but this is broken inside Google due to a + # long-standing Python bug: + # (Some Google-internal dependencies define `--experimental_*` flags.) + # This isn't exactly a principled fix, but it gets the job done. + delete.add_argument( + '--experiment_id', + metavar='EXPERIMENT_ID', + type=str, + default=None, + help='ID of an experiment to delete permanently') + + export = subparsers.add_parser( + 'export', help='download all your experiment data') + export.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_EXPORT}) + export.add_argument( + '--outdir', + metavar='OUTPUT_PATH', + type=str, + default=None, + help='Directory into which to download all experiment data; ' + 'must not yet exist') + + auth_parser = subparsers.add_parser('auth', help='log in, log out') + auth_parser.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_AUTH}) + auth_subparsers = auth_parser.add_subparsers() + + auth_revoke = auth_subparsers.add_parser( + 'revoke', help='revoke all existing credentials and log out') + auth_revoke.set_defaults( + **{_AUTH_SUBCOMMAND_FLAG: _AUTH_SUBCOMMAND_KEY_REVOKE}) + + +def _parse_flags(argv=('',)): + """Integration point for `absl.app`. + + Exits if flag values are invalid. + + Args: + argv: CLI arguments, as with `sys.argv`, where the first argument is taken + to be the name of the program being executed. + + Returns: + Either argv[:1] if argv was non-empty, or [''] otherwise, as a mechanism + for absl.app.run() compatibility. + """ + parser = argparse_flags.ArgumentParser( + prog='uploader', + description=('Upload your TensorBoard experiments to TensorBoard.dev')) + _define_flags(parser) + arg0 = argv[0] if argv else '' + global _FLAGS + _FLAGS = parser.parse_args(argv[1:]) + return [arg0] + + +def _run(flags): + """Runs the main uploader program given parsed flags. + + Args: + flags: An `argparse.Namespace`. + """ + + logging.set_stderrthreshold(logging.WARNING) + intent = _get_intent(flags) + + store = auth.CredentialsStore() + if isinstance(intent, _AuthRevokeIntent): + store.clear() + sys.stderr.write('Logged out of uploader.\n') + sys.stderr.flush() + return + # TODO(b/141723268): maybe reconfirm Google Account prior to reuse. 
+ credentials = store.read_credentials() + if not credentials: + _prompt_for_user_ack(intent) + client_config = json.loads(auth.OAUTH_CLIENT_CONFIG) + flow = auth.build_installed_app_flow(client_config) + credentials = flow.run(force_console=flags.auth_force_console) + sys.stderr.write('\n') # Extra newline after auth flow messages. + store.write_credentials(credentials) + + channel_options = None + if flags.grpc_creds_type == 'local': + channel_creds = grpc.local_channel_credentials() + elif flags.grpc_creds_type == 'ssl': + channel_creds = grpc.ssl_channel_credentials() + elif flags.grpc_creds_type == 'ssl_dev': + channel_creds = grpc.ssl_channel_credentials(dev_creds.DEV_SSL_CERT) + channel_options = [('grpc.ssl_target_name_override', 'localhost')] + else: + msg = 'Invalid --grpc_creds_type %s' % flags.grpc_creds_type + raise base_plugin.FlagsError(msg) + + composite_channel_creds = grpc.composite_channel_credentials( + channel_creds, auth.id_token_call_credentials(credentials)) + + # TODO(@nfelt): In the `_UploadIntent` case, consider waiting until + # logdir exists to open channel. + channel = grpc.secure_channel( + flags.endpoint, composite_channel_creds, options=channel_options) + with channel: + intent.execute(channel) + + +@six.add_metaclass(abc.ABCMeta) +class _Intent(object): + """A description of the user's intent in invoking this program. + + Each valid set of CLI flags corresponds to one intent: e.g., "upload + data from this logdir", or "delete the experiment with that ID". + """ + + @abc.abstractmethod + def get_ack_message_body(self): + """Gets the message to show when executing this intent at first login. + + This need not include the header (program name) or Terms of Service + notice. + + Returns: + A Unicode string, potentially spanning multiple lines. + """ + pass + + @abc.abstractmethod + def execute(self, channel): + """Carries out this intent with the specified gRPC channel. + + Args: + channel: A connected gRPC channel whose server provides the TensorBoard + reader and writer services. + """ + pass + + +class _AuthRevokeIntent(_Intent): + """The user intends to revoke credentials.""" + + def get_ack_message_body(self): + """Must not be called.""" + raise AssertionError('No user ack needed to revoke credentials') + + def execute(self, channel): + """Execute handled specially by `main`. Must not be called.""" + raise AssertionError('_AuthRevokeIntent should not be directly executed') + + +class _DeleteExperimentIntent(_Intent): + """The user intends to delete an experiment.""" + + _MESSAGE_TEMPLATE = textwrap.dedent(u"""\ + This will delete the experiment on https://tensorboard.dev with the + following experiment ID: + + {experiment_id} + + You have chosen to delete an experiment. All experiments uploaded + to TensorBoard.dev are publicly visible. Do not upload sensitive + data. + """) + + def __init__(self, experiment_id): + self.experiment_id = experiment_id + + def get_ack_message_body(self): + return self._MESSAGE_TEMPLATE.format(experiment_id=self.experiment_id) + + def execute(self, channel): + api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub(channel) + experiment_id = self.experiment_id + if not experiment_id: + raise base_plugin.FlagsError( + 'Must specify a non-empty experiment ID to delete.') + try: + uploader_lib.delete_experiment(api_client, experiment_id) + except uploader_lib.ExperimentNotFoundError: + _die( + 'No such experiment %s. Either it never existed or it has ' + 'already been deleted.' 
% experiment_id) + except uploader_lib.PermissionDeniedError: + _die( + 'Cannot delete experiment %s because it is owned by a ' + 'different user.' % experiment_id) + except grpc.RpcError as e: + _die('Internal error deleting experiment: %s' % e) + print('Deleted experiment %s.' % experiment_id) + + +class _UploadIntent(_Intent): + """The user intends to upload an experiment from the given logdir.""" + + _MESSAGE_TEMPLATE = textwrap.dedent(u"""\ + This will upload your TensorBoard logs to https://tensorboard.dev/ from + the following directory: + + {logdir} + + This TensorBoard will be visible to everyone. Do not upload sensitive + data. + """) + + def __init__(self, logdir): + self.logdir = logdir + + def get_ack_message_body(self): + return self._MESSAGE_TEMPLATE.format(logdir=self.logdir) + + def execute(self, channel): + api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub(channel) + uploader = uploader_lib.TensorBoardUploader(api_client, self.logdir) + url = uploader.create_experiment() + print("Upload started and will continue reading any new data as it's added") + print("to the logdir. To stop uploading, press Ctrl-C.") + print("View your TensorBoard live at: %s" % url) + try: + uploader.start_uploading() + except uploader_lib.ExperimentNotFoundError: + print('Experiment was deleted; uploading has been cancelled') + return + except KeyboardInterrupt: + print() + print('Upload stopped. View your TensorBoard at %s' % url) + return + # TODO(@nfelt): make it possible for the upload cycle to end once we + # detect that no more runs are active, so this code can be reached. + print('Done! View your TensorBoard at %s' % url) + + +class _ExportIntent(_Intent): + """The user intends to download all their experiment data.""" + + _MESSAGE_TEMPLATE = textwrap.dedent(u"""\ + This will download all your experiment data from https://tensorboard.dev + and save it to the following directory: + + {output_dir} + + Downloading your experiment data does not delete it from the + service. All experiments uploaded to TensorBoard.dev are publicly + visible. Do not upload sensitive data. + """) + + def __init__(self, output_dir): + self.output_dir = output_dir + + def get_ack_message_body(self): + return self._MESSAGE_TEMPLATE.format(output_dir=self.output_dir) + + def execute(self, channel): + api_client = export_service_pb2_grpc.TensorBoardExporterServiceStub(channel) + outdir = self.output_dir + try: + exporter = exporter_lib.TensorBoardExporter(api_client, outdir) + except exporter_lib.OutputDirectoryExistsError: + msg = 'Output directory already exists: %r' % outdir + raise base_plugin.FlagsError(msg) + num_experiments = 0 + try: + for experiment_id in exporter.export(): + num_experiments += 1 + print('Downloaded experiment %s' % experiment_id) + except exporter_lib.GrpcTimeoutException as e: + print( + '\nUploader has failed because of a timeout error. Please reach ' + 'out via e-mail to tensorboard.dev-support@google.com to get help ' + 'completing your export of experiment %s.' % e.experiment_id) + print('Done. Downloaded %d experiments to: %s' % (num_experiments, outdir)) + + +def _get_intent(flags): + """Determines what the program should do (upload, delete, ...). + + Args: + flags: An `argparse.Namespace` with the parsed flags. + + Returns: + An `_Intent` instance. + + Raises: + base_plugin.FlagsError: If the command-line `flags` do not correctly + specify an intent. 
+ """ + cmd = getattr(flags, _SUBCOMMAND_FLAG, None) + if cmd is None: + raise base_plugin.FlagsError('Must specify subcommand (try --help).') + if cmd == _SUBCOMMAND_KEY_UPLOAD: + if flags.logdir: + return _UploadIntent(os.path.expanduser(flags.logdir)) + else: + raise base_plugin.FlagsError( + 'Must specify directory to upload via `--logdir`.') + elif cmd == _SUBCOMMAND_KEY_DELETE: + if flags.experiment_id: + return _DeleteExperimentIntent(flags.experiment_id) + else: + raise base_plugin.FlagsError( + 'Must specify experiment to delete via `--experiment_id`.') + elif cmd == _SUBCOMMAND_KEY_EXPORT: + if flags.outdir: + return _ExportIntent(flags.outdir) + else: + raise base_plugin.FlagsError( + 'Must specify output directory via `--outdir`.') + elif cmd == _SUBCOMMAND_KEY_AUTH: + auth_cmd = getattr(flags, _AUTH_SUBCOMMAND_FLAG, None) + if auth_cmd is None: + raise base_plugin.FlagsError('Must specify a subcommand to `auth`.') + if auth_cmd == _AUTH_SUBCOMMAND_KEY_REVOKE: + return _AuthRevokeIntent() + else: + raise AssertionError('Unknown auth subcommand %r' % (auth_cmd,)) + else: + raise AssertionError('Unknown subcommand %r' % (cmd,)) + + +def _die(message): + sys.stderr.write('%s\n' % (message,)) + sys.stderr.flush() + sys.exit(1) + + +def main(unused_argv): + global _FLAGS + flags = _FLAGS + # Prevent accidental use of `_FLAGS` until migration to TensorBoard + # subcommand is complete, at which point `_FLAGS` goes away. + del _FLAGS + return _run(flags) + + +class UploaderSubcommand(program.TensorBoardSubcommand): + """Integration point with `tensorboard` CLI.""" + + def name(self): + return 'dev' + + def define_flags(self, parser): + _define_flags(parser) + + def run(self, flags): + return _run(flags) + + def help(self): + return 'upload data to TensorBoard.dev' + + +if __name__ == '__main__': + app.run(main, flags_parser=_parse_flags) diff --git a/tensorboard/uploader/uploader_test.py b/tensorboard/uploader/uploader_test.py new file mode 100644 index 0000000000..427ce45ac8 --- /dev/null +++ b/tensorboard/uploader/uploader_test.py @@ -0,0 +1,662 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
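For reference, the channel setup in `_run` above layers two kinds of credentials: channel-level transport security and call-level credentials that attach the user's ID token to every RPC. A rough standalone sketch of that wiring, using a static bearer token as a stand-in for `auth.id_token_call_credentials(credentials)` (the token string is a placeholder):

    import grpc

    from tensorboard.uploader.proto import write_service_pb2_grpc

    # Transport security, as in the default `--grpc_creds_type=ssl` case.
    channel_creds = grpc.ssl_channel_credentials()
    # Per-call credentials; placeholder for the OAuth/ID-token plugin used above.
    call_creds = grpc.access_token_call_credentials('placeholder-token')
    composite_creds = grpc.composite_channel_credentials(channel_creds, call_creds)

    # The `ssl_dev` case uses a self-signed certificate issued for "localhost",
    # which is why `_run` adds the `grpc.ssl_target_name_override` channel option
    # in that branch.
    with grpc.secure_channel('api.tensorboard.dev:443', composite_creds) as channel:
      stub = write_service_pb2_grpc.TensorBoardWriterServiceStub(channel)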
+# ============================================================================== +"""Tests for tensorboard.uploader.uploader.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os + +import grpc +import grpc_testing + +try: + # python version >= 3.3 + from unittest import mock # pylint: disable=g-import-not-at-top +except ImportError: + import mock # pylint: disable=g-import-not-at-top,unused-import + +import tensorflow as tf + +from tensorboard.uploader.proto import scalar_pb2 +from tensorboard.uploader.proto import write_service_pb2 +from tensorboard.uploader.proto import write_service_pb2_grpc +from tensorboard.uploader import test_util +from tensorboard.uploader import uploader as uploader_lib +from tensorboard.uploader import util +from tensorboard.compat.proto import event_pb2 +from tensorboard.compat.proto import summary_pb2 +from tensorboard.plugins.histogram import summary_v2 as histogram_v2 +from tensorboard.plugins.scalar import summary_v2 as scalar_v2 +from tensorboard.summary import v1 as summary_v1 +from tensorboard.util import test_util as tb_test_util + + +class AbortUploadError(Exception): + """Exception used in testing to abort the upload process.""" + + +class TensorboardUploaderTest(tf.test.TestCase): + + def _create_mock_client(self): + # Create a stub instance (using a test channel) in order to derive a mock + # from it with autospec enabled. Mocking TensorBoardWriterServiceStub itself + # doesn't work with autospec because grpc constructs stubs via metaclassing. + test_channel = grpc_testing.channel( + service_descriptors=[], time=grpc_testing.strict_real_time()) + stub = write_service_pb2_grpc.TensorBoardWriterServiceStub(test_channel) + mock_client = mock.create_autospec(stub) + fake_exp_response = write_service_pb2.CreateExperimentResponse( + experiment_id="123", url="https://example.com/123") + mock_client.CreateExperiment.return_value = fake_exp_response + return mock_client + + def test_create_experiment(self): + logdir = "/logs/foo" + mock_client = self._create_mock_client() + uploader = uploader_lib.TensorBoardUploader(mock_client, logdir) + url = uploader.create_experiment() + self.assertEqual(url, "https://example.com/123") + + def test_start_uploading_without_create_experiment_fails(self): + mock_client = self._create_mock_client() + uploader = uploader_lib.TensorBoardUploader(mock_client, "/logs/foo") + with self.assertRaisesRegex(RuntimeError, "call create_experiment()"): + uploader.start_uploading() + + def test_start_uploading(self): + mock_client = self._create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + uploader = uploader_lib.TensorBoardUploader( + mock_client, "/logs/foo", mock_rate_limiter) + uploader.create_experiment() + mock_builder = mock.create_autospec(uploader_lib._RequestBuilder) + request = write_service_pb2.WriteScalarRequest() + mock_builder.build_requests.side_effect = [ + iter([request, request]), + iter([request, request, request, request, request]), + AbortUploadError, + ] + # pylint: disable=g-backslash-continuation + with mock.patch.object(uploader, "_upload") as mock_upload, \ + mock.patch.object(uploader, "_request_builder", mock_builder), \ + self.assertRaises(AbortUploadError): + uploader.start_uploading() + # pylint: enable=g-backslash-continuation + self.assertEqual(7, mock_upload.call_count) + self.assertEqual(2 + 5 + 1, mock_rate_limiter.tick.call_count) + + def test_upload_empty_logdir(self): + 
logdir = self.get_temp_dir() + mock_client = self._create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + uploader = uploader_lib.TensorBoardUploader( + mock_client, logdir, mock_rate_limiter) + uploader.create_experiment() + uploader._upload_once() + mock_client.WriteScalar.assert_not_called() + + def test_upload_swallows_rpc_failure(self): + logdir = self.get_temp_dir() + with tb_test_util.FileWriter(logdir) as writer: + writer.add_test_summary("foo") + mock_client = self._create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + uploader = uploader_lib.TensorBoardUploader( + mock_client, logdir, mock_rate_limiter) + uploader.create_experiment() + error = test_util.grpc_error(grpc.StatusCode.INTERNAL, "Failure") + mock_client.WriteScalar.side_effect = error + uploader._upload_once() + mock_client.WriteScalar.assert_called_once() + + def test_upload_propagates_experiment_deletion(self): + logdir = self.get_temp_dir() + with tb_test_util.FileWriter(logdir) as writer: + writer.add_test_summary("foo") + mock_client = self._create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + uploader = uploader_lib.TensorBoardUploader( + mock_client, logdir, mock_rate_limiter) + uploader.create_experiment() + error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope") + mock_client.WriteScalar.side_effect = error + with self.assertRaises(uploader_lib.ExperimentNotFoundError): + uploader._upload_once() + + def test_upload_preserves_wall_time(self): + logdir = self.get_temp_dir() + with tb_test_util.FileWriter(logdir) as writer: + # Add a raw event so we can specify the wall_time value deterministically. + writer.add_event( + event_pb2.Event( + step=1, + wall_time=123.123123123, + summary=scalar_v2.scalar_pb("foo", 5.0))) + mock_client = self._create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + uploader = uploader_lib.TensorBoardUploader( + mock_client, logdir, mock_rate_limiter) + uploader.create_experiment() + uploader._upload_once() + mock_client.WriteScalar.assert_called_once() + request = mock_client.WriteScalar.call_args[0][0] + # Just check the wall_time value; everything else is covered in the full + # logdir test below. + self.assertEqual( + 123123123123, + request.runs[0].tags[0].points[0].wall_time.ToNanoseconds()) + + def test_upload_full_logdir(self): + logdir = self.get_temp_dir() + mock_client = self._create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + uploader = uploader_lib.TensorBoardUploader( + mock_client, logdir, mock_rate_limiter) + uploader.create_experiment() + + # Convenience helpers for constructing expected requests. 
+ run = write_service_pb2.WriteScalarRequest.Run + tag = write_service_pb2.WriteScalarRequest.Tag + point = scalar_pb2.ScalarPoint + + # First round + writer = tb_test_util.FileWriter(logdir) + writer.add_test_summary("foo", simple_value=5.0, step=1) + writer.add_test_summary("foo", simple_value=6.0, step=2) + writer.add_test_summary("foo", simple_value=7.0, step=3) + writer.add_test_summary("bar", simple_value=8.0, step=3) + writer.flush() + writer_a = tb_test_util.FileWriter(os.path.join(logdir, "a")) + writer_a.add_test_summary("qux", simple_value=9.0, step=2) + writer_a.flush() + uploader._upload_once() + self.assertEqual(1, mock_client.WriteScalar.call_count) + request1 = mock_client.WriteScalar.call_args[0][0] + _clear_wall_times(request1) + expected_request1 = write_service_pb2.WriteScalarRequest( + experiment_id="123", + runs=[ + run(name=".", + tags=[ + tag(name="foo", + metadata=test_util.scalar_metadata("foo"), + points=[ + point(step=1, value=5.0), + point(step=2, value=6.0), + point(step=3, value=7.0), + ]), + tag(name="bar", + metadata=test_util.scalar_metadata("bar"), + points=[ + point(step=3, value=8.0), + ]), + ]), + run(name="a", + tags=[ + tag(name="qux", + metadata=test_util.scalar_metadata("qux"), + points=[ + point(step=2, value=9.0), + ]), + ]), + ]) + self.assertProtoEquals(expected_request1, request1) + mock_client.WriteScalar.reset_mock() + + # Second round + writer.add_test_summary("foo", simple_value=10.0, step=5) + writer.add_test_summary("baz", simple_value=11.0, step=1) + writer.flush() + writer_b = tb_test_util.FileWriter(os.path.join(logdir, "b")) + writer_b.add_test_summary("xyz", simple_value=12.0, step=1) + writer_b.flush() + uploader._upload_once() + self.assertEqual(1, mock_client.WriteScalar.call_count) + request2 = mock_client.WriteScalar.call_args[0][0] + _clear_wall_times(request2) + expected_request2 = write_service_pb2.WriteScalarRequest( + experiment_id="123", + runs=[ + run(name=".", + tags=[ + tag(name="foo", + metadata=test_util.scalar_metadata("foo"), + points=[ + point(step=5, value=10.0), + ]), + tag(name="baz", + metadata=test_util.scalar_metadata("baz"), + points=[ + point(step=1, value=11.0), + ]), + ]), + run(name="b", + tags=[ + tag(name="xyz", + metadata=test_util.scalar_metadata("xyz"), + points=[ + point(step=1, value=12.0), + ]), + ]), + ]) + self.assertProtoEquals(expected_request2, request2) + mock_client.WriteScalar.reset_mock() + + # Empty third round + uploader._upload_once() + mock_client.WriteScalar.assert_not_called() + + +class RequestBuilderTest(tf.test.TestCase): + + def _populate_run_from_events(self, run_proto, events): + builder = uploader_lib._RequestBuilder(experiment_id="123") + requests = builder.build_requests({"": events}) + request = next(requests, None) + if request is not None: + self.assertLen(request.runs, 1) + run_proto.MergeFrom(request.runs[0]) + self.assertIsNone(next(requests, None)) + + def test_empty_events(self): + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, []) + self.assertProtoEquals( + run_proto, write_service_pb2.WriteScalarRequest.Run()) + + def test_aggregation_by_tag(self): + def make_event(step, wall_time, tag, value): + return event_pb2.Event( + step=step, + wall_time=wall_time, + summary=scalar_v2.scalar_pb(tag, value)) + events = [ + make_event(1, 1.0, "one", 11.0), + make_event(1, 2.0, "two", 22.0), + make_event(2, 3.0, "one", 33.0), + make_event(2, 4.0, "two", 44.0), + make_event(1, 5.0, "one", 55.0), # Should preserve duplicate 
step=1. + make_event(1, 6.0, "three", 66.0), + ] + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, events) + tag_data = { + tag.name: [ + (p.step, p.wall_time.ToSeconds(), p.value) for p in tag.points] + for tag in run_proto.tags} + self.assertEqual( + tag_data, { + "one": [(1, 1.0, 11.0), (2, 3.0, 33.0), (1, 5.0, 55.0)], + "two": [(1, 2.0, 22.0), (2, 4.0, 44.0)], + "three": [(1, 6.0, 66.0)], + }) + + def test_skips_non_scalar_events(self): + events = [ + event_pb2.Event(file_version="brain.Event:2"), + event_pb2.Event(summary=scalar_v2.scalar_pb("scalar1", 5.0)), + event_pb2.Event(summary=scalar_v2.scalar_pb("scalar2", 5.0)), + event_pb2.Event( + summary=histogram_v2.histogram_pb("histogram", [5.0])) + ] + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, events) + tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags} + self.assertEqual(tag_counts, {"scalar1": 1, "scalar2": 1}) + + def test_skips_scalar_events_in_non_scalar_time_series(self): + events = [ + event_pb2.Event(file_version="brain.Event:2"), + event_pb2.Event(summary=scalar_v2.scalar_pb("scalar1", 5.0)), + event_pb2.Event(summary=scalar_v2.scalar_pb("scalar2", 5.0)), + event_pb2.Event( + summary=histogram_v2.histogram_pb("histogram", [5.0])), + event_pb2.Event(summary=scalar_v2.scalar_pb("histogram", 5.0)), + ] + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, events) + tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags} + self.assertEqual(tag_counts, {"scalar1": 1, "scalar2": 1}) + + def test_remembers_first_metadata_in_scalar_time_series(self): + scalar_1 = event_pb2.Event(summary=scalar_v2.scalar_pb("loss", 4.0)) + scalar_2 = event_pb2.Event(summary=scalar_v2.scalar_pb("loss", 3.0)) + scalar_2.summary.value[0].ClearField("metadata") + events = [ + event_pb2.Event(file_version="brain.Event:2"), + scalar_1, + scalar_2, + ] + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, events) + tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags} + self.assertEqual(tag_counts, {"loss": 2}) + + def test_v1_summary_single_value(self): + event = event_pb2.Event(step=1, wall_time=123.456) + event.summary.value.add(tag="foo", simple_value=5.0) + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, [event]) + expected_run_proto = write_service_pb2.WriteScalarRequest.Run() + foo_tag = expected_run_proto.tags.add() + foo_tag.name = "foo" + foo_tag.metadata.display_name = "foo" + foo_tag.metadata.plugin_data.plugin_name = "scalars" + foo_tag.points.add( + step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0) + self.assertProtoEquals(run_proto, expected_run_proto) + + def test_v1_summary_multiple_value(self): + event = event_pb2.Event(step=1, wall_time=123.456) + event.summary.value.add(tag="foo", simple_value=1.0) + event.summary.value.add(tag="foo", simple_value=2.0) + event.summary.value.add(tag="foo", simple_value=3.0) + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, [event]) + expected_run_proto = write_service_pb2.WriteScalarRequest.Run() + foo_tag = expected_run_proto.tags.add() + foo_tag.name = "foo" + foo_tag.metadata.display_name = "foo" + foo_tag.metadata.plugin_data.plugin_name = "scalars" + foo_tag.points.add( + step=1, wall_time=test_util.timestamp_pb(123456000000), value=1.0) + 
foo_tag.points.add( + step=1, wall_time=test_util.timestamp_pb(123456000000), value=2.0) + foo_tag.points.add( + step=1, wall_time=test_util.timestamp_pb(123456000000), value=3.0) + self.assertProtoEquals(run_proto, expected_run_proto) + + def test_v1_summary_tb_summary(self): + tf_summary = summary_v1.scalar_pb("foo", 5.0) + tb_summary = summary_pb2.Summary.FromString(tf_summary.SerializeToString()) + event = event_pb2.Event(step=1, wall_time=123.456, summary=tb_summary) + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, [event]) + expected_run_proto = write_service_pb2.WriteScalarRequest.Run() + foo_tag = expected_run_proto.tags.add() + foo_tag.name = "foo/scalar_summary" + foo_tag.metadata.display_name = "foo" + foo_tag.metadata.plugin_data.plugin_name = "scalars" + foo_tag.points.add( + step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0) + self.assertProtoEquals(run_proto, expected_run_proto) + + def test_v2_summary(self): + event = event_pb2.Event( + step=1, wall_time=123.456, summary=scalar_v2.scalar_pb("foo", 5.0)) + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, [event]) + expected_run_proto = write_service_pb2.WriteScalarRequest.Run() + foo_tag = expected_run_proto.tags.add() + foo_tag.name = "foo" + foo_tag.metadata.plugin_data.plugin_name = "scalars" + foo_tag.points.add( + step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0) + self.assertProtoEquals(run_proto, expected_run_proto) + + def test_no_budget_for_experiment_id(self): + event = event_pb2.Event(step=1, wall_time=123.456) + event.summary.value.add(tag="foo", simple_value=1.0) + run_to_events = {"run_name": [event]} + long_experiment_id = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES + with self.assertRaises(RuntimeError) as cm: + builder = uploader_lib._RequestBuilder(long_experiment_id) + list(builder.build_requests(run_to_events)) + self.assertEqual( + str(cm.exception), "Byte budget too small for experiment ID") + + def test_no_room_for_single_point(self): + event = event_pb2.Event(step=1, wall_time=123.456) + event.summary.value.add(tag="foo", simple_value=1.0) + long_run_name = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES + run_to_events = {long_run_name: [event]} + with self.assertRaises(RuntimeError) as cm: + builder = uploader_lib._RequestBuilder("123") + list(builder.build_requests(run_to_events)) + self.assertEqual( + str(cm.exception), "Could not make progress uploading data") + + @mock.patch.object(uploader_lib, "_MAX_REQUEST_LENGTH_BYTES", 1024) + def test_break_at_run_boundary(self): + # Choose run name sizes such that one run fits, but not two. 
+ long_run_1 = "A" * 768 + long_run_2 = "B" * 768 + event_1 = event_pb2.Event(step=1) + event_1.summary.value.add(tag="foo", simple_value=1.0) + event_2 = event_pb2.Event(step=2) + event_2.summary.value.add(tag="bar", simple_value=-2.0) + run_to_events = collections.OrderedDict([ + (long_run_1, [event_1]), + (long_run_2, [event_2]), + ]) + + builder = uploader_lib._RequestBuilder("123") + requests = list(builder.build_requests(run_to_events)) + for request in requests: + _clear_wall_times(request) + + expected = [ + write_service_pb2.WriteScalarRequest(experiment_id="123"), + write_service_pb2.WriteScalarRequest(experiment_id="123"), + ] + (expected[0].runs.add(name=long_run_1).tags.add( + name="foo", metadata=test_util.scalar_metadata("foo")).points.add( + step=1, value=1.0)) + (expected[1].runs.add(name=long_run_2).tags.add( + name="bar", metadata=test_util.scalar_metadata("bar")).points.add( + step=2, value=-2.0)) + self.assertEqual(requests, expected) + + @mock.patch.object(uploader_lib, "_MAX_REQUEST_LENGTH_BYTES", 1024) + def test_break_at_tag_boundary(self): + # Choose tag name sizes such that one tag fits, but not two. Note + # that tag names appear in both `Tag.name` and the summary metadata. + long_tag_1 = "a" * 384 + long_tag_2 = "b" * 384 + event = event_pb2.Event(step=1) + event.summary.value.add(tag=long_tag_1, simple_value=1.0) + event.summary.value.add(tag=long_tag_2, simple_value=2.0) + run_to_events = {"train": [event]} + + builder = uploader_lib._RequestBuilder("123") + requests = list(builder.build_requests(run_to_events)) + for request in requests: + _clear_wall_times(request) + + expected = [ + write_service_pb2.WriteScalarRequest(experiment_id="123"), + write_service_pb2.WriteScalarRequest(experiment_id="123"), + ] + (expected[0].runs.add(name="train").tags.add( + name=long_tag_1, + metadata=test_util.scalar_metadata(long_tag_1)).points.add( + step=1, value=1.0)) + (expected[1].runs.add(name="train").tags.add( + name=long_tag_2, + metadata=test_util.scalar_metadata(long_tag_2)).points.add( + step=1, value=2.0)) + self.assertEqual(requests, expected) + + @mock.patch.object(uploader_lib, "_MAX_REQUEST_LENGTH_BYTES", 1024) + def test_break_at_scalar_point_boundary(self): + point_count = 2000 # comfortably saturates a single 1024-byte request + events = [] + for step in range(point_count): + summary = scalar_v2.scalar_pb("loss", -2.0 * step) + if step > 0: + summary.value[0].ClearField("metadata") + events.append(event_pb2.Event(summary=summary, step=step)) + run_to_events = {"train": events} + + builder = uploader_lib._RequestBuilder("123") + requests = list(builder.build_requests(run_to_events)) + for request in requests: + _clear_wall_times(request) + + self.assertGreater(len(requests), 1) + self.assertLess(len(requests), point_count) + + total_points_in_result = 0 + for request in requests: + self.assertLen(request.runs, 1) + run = request.runs[0] + self.assertEqual(run.name, "train") + self.assertLen(run.tags, 1) + tag = run.tags[0] + self.assertEqual(tag.name, "loss") + for point in tag.points: + self.assertEqual(point.step, total_points_in_result) + self.assertEqual(point.value, -2.0 * point.step) + total_points_in_result += 1 + self.assertLessEqual( + request.ByteSize(), uploader_lib._MAX_REQUEST_LENGTH_BYTES) + self.assertEqual(total_points_in_result, point_count) + + def test_prunes_tags_and_runs(self): + event_1 = event_pb2.Event(step=1) + event_1.summary.value.add(tag="foo", simple_value=1.0) + event_2 = event_pb2.Event(step=2) + 
event_2.summary.value.add(tag="bar", simple_value=-2.0) + run_to_events = collections.OrderedDict([ + ("train", [event_1]), + ("test", [event_2]), + ]) + + real_create_point = uploader_lib._RequestBuilder._create_point + + create_point_call_count_box = [0] + + def mock_create_point(uploader_self, *args, **kwargs): + # Simulate out-of-space error the first time that we try to store + # the second point. + create_point_call_count_box[0] += 1 + if create_point_call_count_box[0] == 2: + raise uploader_lib._OutOfSpaceError() + return real_create_point(uploader_self, *args, **kwargs) + + with mock.patch.object( + uploader_lib._RequestBuilder, "_create_point", mock_create_point): + builder = uploader_lib._RequestBuilder("123") + requests = list(builder.build_requests(run_to_events)) + for request in requests: + _clear_wall_times(request) + + expected = [ + write_service_pb2.WriteScalarRequest(experiment_id="123"), + write_service_pb2.WriteScalarRequest(experiment_id="123"), + ] + (expected[0].runs.add(name="train").tags.add( + name="foo", metadata=test_util.scalar_metadata("foo")).points.add( + step=1, value=1.0)) + (expected[1].runs.add(name="test").tags.add( + name="bar", metadata=test_util.scalar_metadata("bar")).points.add( + step=2, value=-2.0)) + self.assertEqual(expected, requests) + + def test_wall_time_precision(self): + # Test a wall time that is exactly representable in float64 but has enough + # digits to incur error if converted to nanonseconds the naive way (* 1e9). + event1 = event_pb2.Event(step=1, wall_time=1567808404.765432119) + event1.summary.value.add(tag="foo", simple_value=1.0) + # Test a wall time where as a float64, the fractional part on its own will + # introduce error if truncated to 9 decimal places instead of rounded. + event2 = event_pb2.Event(step=2, wall_time=1.000000002) + event2.summary.value.add(tag="foo", simple_value=2.0) + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, [event1, event2]) + self.assertEqual( + test_util.timestamp_pb(1567808404765432119), + run_proto.tags[0].points[0].wall_time) + self.assertEqual( + test_util.timestamp_pb(1000000002), + run_proto.tags[0].points[1].wall_time) + + +class DeleteExperimentTest(tf.test.TestCase): + + def _create_mock_client(self): + # Create a stub instance (using a test channel) in order to derive a mock + # from it with autospec enabled. Mocking TensorBoardWriterServiceStub itself + # doesn't work with autospec because grpc constructs stubs via metaclassing. 
+ test_channel = grpc_testing.channel( + service_descriptors=[], time=grpc_testing.strict_real_time()) + stub = write_service_pb2_grpc.TensorBoardWriterServiceStub(test_channel) + mock_client = mock.create_autospec(stub) + return mock_client + + def test_success(self): + mock_client = self._create_mock_client() + response = write_service_pb2.DeleteExperimentResponse() + mock_client.DeleteExperiment.return_value = response + + uploader_lib.delete_experiment(mock_client, "123") + + expected_request = write_service_pb2.DeleteExperimentRequest() + expected_request.experiment_id = "123" + mock_client.DeleteExperiment.assert_called_once() + (args, _) = mock_client.DeleteExperiment.call_args + self.assertEqual(args[0], expected_request) + + def test_not_found(self): + mock_client = self._create_mock_client() + error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope") + mock_client.DeleteExperiment.side_effect = error + + with self.assertRaises(uploader_lib.ExperimentNotFoundError): + uploader_lib.delete_experiment(mock_client, "123") + + def test_unauthorized(self): + mock_client = self._create_mock_client() + error = test_util.grpc_error(grpc.StatusCode.PERMISSION_DENIED, "nope") + mock_client.DeleteExperiment.side_effect = error + + with self.assertRaises(uploader_lib.PermissionDeniedError): + uploader_lib.delete_experiment(mock_client, "123") + + def test_internal_error(self): + mock_client = self._create_mock_client() + error = test_util.grpc_error(grpc.StatusCode.INTERNAL, "travesty") + mock_client.DeleteExperiment.side_effect = error + + with self.assertRaises(grpc.RpcError) as cm: + uploader_lib.delete_experiment(mock_client, "123") + msg = str(cm.exception) + self.assertIn("travesty", msg) + + +class VarintCostTest(tf.test.TestCase): + + def test_varint_cost(self): + self.assertEqual(uploader_lib._varint_cost(0), 1) + self.assertEqual(uploader_lib._varint_cost(7), 1) + self.assertEqual(uploader_lib._varint_cost(127), 1) + self.assertEqual(uploader_lib._varint_cost(128), 2) + self.assertEqual(uploader_lib._varint_cost(128 * 128 - 1), 2) + self.assertEqual(uploader_lib._varint_cost(128 * 128), 3) + + +def _clear_wall_times(request): + """Clears the wall_time fields in a WriteScalarRequest to be deterministic.""" + for run in request.runs: + for tag in run.tags: + for point in tag.points: + point.ClearField("wall_time") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/uploader/util.py b/tensorboard/uploader/util.py new file mode 100644 index 0000000000..795758c4c1 --- /dev/null +++ b/tensorboard/uploader/util.py @@ -0,0 +1,114 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
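`VarintCostTest` above pins down the size model that `_RequestBuilder` relies on to keep each `WriteScalarRequest` under `_MAX_REQUEST_LENGTH_BYTES`: lengths and tag numbers are encoded as base-128 varints, so their byte cost grows by one for every additional 7 bits. A straightforward reference implementation that matches the expectations in that test (the actual helper in `uploader.py` may differ in detail):

    def varint_cost(n):
      """Returns the byte length of the base-128 varint encoding of n >= 0."""
      cost = 1
      while n >= 128:
        cost += 1
        n >>= 7
      return cost

    assert [varint_cost(n) for n in (0, 127, 128, 128 * 128 - 1, 128 * 128)] == [1, 1, 2, 2, 3]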
+# ============================================================================== +"""Utilities for use by the uploader command line tool.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import errno +import os +import os.path +import time + + +class RateLimiter(object): + """Helper class for rate-limiting using a fixed minimum interval.""" + + def __init__(self, interval_secs): + """Constructs a RateLimiter that permits a tick() every `interval_secs`.""" + self._time = time # Use property for ease of testing. + self._interval_secs = interval_secs + self._last_called_secs = 0 + + def tick(self): + """Blocks until it has been at least `interval_secs` since last tick().""" + wait_secs = self._last_called_secs + self._interval_secs - self._time.time() + if wait_secs > 0: + self._time.sleep(wait_secs) + self._last_called_secs = self._time.time() + + +def get_user_config_directory(): + """Returns a platform-specific root directory for user config settings.""" + # On Windows, prefer %LOCALAPPDATA%, then %APPDATA%, since we can expect the + # AppData directories to be ACLed to be visible only to the user and admin + # users (https://stackoverflow.com/a/7617601/1179226). If neither is set, + # return None instead of falling back to something that may be world-readable. + if os.name == "nt": + appdata = os.getenv("LOCALAPPDATA") + if appdata: + return appdata + appdata = os.getenv("APPDATA") + if appdata: + return appdata + return None + # On non-windows, use XDG_CONFIG_HOME if set, else default to ~/.config. + xdg_config_home = os.getenv("XDG_CONFIG_HOME") + if xdg_config_home: + return xdg_config_home + return os.path.join(os.path.expanduser("~"), ".config") + + +def make_file_with_directories(path, private=False): + """Creates a file and its containing directories, if they don't already exist. + + + If `private` is True, the file will be made private (readable only by the + current user) and so will the leaf directory. Pre-existing contents of the + file are not modified. + + Passing `private=True` is not supported on Windows because it doesn't support + the relevant parts of `os.chmod()`. + + Args: + path: str, The path of the file to create. + private: boolean, Whether to make the file and leaf directory readable only + by the current user. + + Raises: + RuntimeError: If called on Windows with `private` set to True. + """ + if private and os.name == "nt": + raise RuntimeError("Creating private file not supported on Windows") + try: + path = os.path.realpath(path) + leaf_dir = os.path.dirname(path) + try: + os.makedirs(leaf_dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise + if private: + os.chmod(leaf_dir, 0o700) + open(path, "a").close() + if private: + os.chmod(path, 0o600) + except EnvironmentError as e: + raise RuntimeError("Failed to create file %s: %s" % (path, e)) + + +def set_timestamp(pb, seconds_since_epoch): + """Sets a `Timestamp` proto message to a floating point UNIX time. + + This is like `pb.FromNanoseconds(int(seconds_since_epoch * 1e9))` but + without introducing floating-point error. + + Args: + pb: A `google.protobuf.Timestamp` message to mutate. + seconds_since_epoch: A `float`, as returned by `time.time`. 
+ """ + pb.seconds = int(seconds_since_epoch) + pb.nanos = int(round((seconds_since_epoch % 1) * 10**9)) diff --git a/tensorboard/uploader/util_test.py b/tensorboard/uploader/util_test.py new file mode 100644 index 0000000000..b0670d5315 --- /dev/null +++ b/tensorboard/uploader/util_test.py @@ -0,0 +1,199 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorboard.uploader.util.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import unittest + + +try: + # python version >= 3.3 + from unittest import mock # pylint: disable=g-import-not-at-top +except ImportError: + import mock # pylint: disable=g-import-not-at-top,unused-import + + +from google.protobuf import timestamp_pb2 +from tensorboard.uploader import test_util +from tensorboard.uploader import util +from tensorboard import test as tb_test + + +class RateLimiterTest(tb_test.TestCase): + + def test_rate_limiting(self): + rate_limiter = util.RateLimiter(10) + fake_time = test_util.FakeTime(current=1000) + with mock.patch.object(rate_limiter, "_time", fake_time): + self.assertEqual(1000, fake_time.time()) + # No sleeping for initial tick. + rate_limiter.tick() + self.assertEqual(1000, fake_time.time()) + # Second tick requires a full sleep. + rate_limiter.tick() + self.assertEqual(1010, fake_time.time()) + # Third tick requires a sleep just to make up the remaining second. + fake_time.sleep(9) + self.assertEqual(1019, fake_time.time()) + rate_limiter.tick() + self.assertEqual(1020, fake_time.time()) + # Fourth tick requires no sleep since we have no remaining seconds. 
+ fake_time.sleep(11) + self.assertEqual(1031, fake_time.time()) + rate_limiter.tick() + self.assertEqual(1031, fake_time.time()) + + +class GetUserConfigDirectoryTest(tb_test.TestCase): + + def test_windows(self): + with mock.patch.object(os, "name", "nt"): + with mock.patch.dict(os.environ, { + "LOCALAPPDATA": "C:\\Users\\Alice\\AppData\\Local", + "APPDATA": "C:\\Users\\Alice\\AppData\\Roaming", + }): + self.assertEqual( + "C:\\Users\\Alice\\AppData\\Local", + util.get_user_config_directory()) + with mock.patch.dict(os.environ, { + "LOCALAPPDATA": "", + "APPDATA": "C:\\Users\\Alice\\AppData\\Roaming", + }): + self.assertEqual( + "C:\\Users\\Alice\\AppData\\Roaming", + util.get_user_config_directory()) + with mock.patch.dict(os.environ, { + "LOCALAPPDATA": "", + "APPDATA": "", + }): + self.assertIsNone(util.get_user_config_directory()) + + def test_non_windows(self): + with mock.patch.dict(os.environ, {"HOME": "/home/alice"}): + self.assertEqual( + "/home/alice%s.config" % os.sep, util.get_user_config_directory()) + with mock.patch.dict( + os.environ, {"XDG_CONFIG_HOME": "/home/alice/configz"}): + self.assertEqual( + "/home/alice/configz", util.get_user_config_directory()) + + +skip_if_windows = unittest.skipIf(os.name == "nt", "Unsupported on Windows") + + +class MakeFileWithDirectoriesTest(tb_test.TestCase): + + def test_windows_private(self): + with mock.patch.object(os, "name", "nt"): + with self.assertRaisesRegex(RuntimeError, "Windows"): + util.make_file_with_directories("/tmp/foo", private=True) + + def test_existing_file(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + os.makedirs(os.path.dirname(path)) + with open(path, mode="w") as f: + f.write("foobar") + util.make_file_with_directories(path) + with open(path, mode="r") as f: + self.assertEqual("foobar", f.read()) + + def test_existing_dir(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + os.makedirs(os.path.dirname(path)) + util.make_file_with_directories(path) + self.assertEqual(0, os.path.getsize(path)) + + def test_nonexistent_leaf_dir(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + os.makedirs(os.path.dirname(os.path.dirname(path))) + util.make_file_with_directories(path) + self.assertEqual(0, os.path.getsize(path)) + + def test_nonexistent_multiple_dirs(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + util.make_file_with_directories(path) + self.assertEqual(0, os.path.getsize(path)) + + def assertMode(self, mode, path): + self.assertEqual(mode, os.stat(path).st_mode & 0o777) + + @skip_if_windows + def test_private_existing_file(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + os.makedirs(os.path.dirname(path)) + with open(path, mode="w") as f: + f.write("foobar") + os.chmod(os.path.dirname(path), 0o777) + os.chmod(path, 0o666) + util.make_file_with_directories(path, private=True) + self.assertMode(0o700, os.path.dirname(path)) + self.assertMode(0o600, path) + with open(path, mode="r") as f: + self.assertEqual("foobar", f.read()) + + @skip_if_windows + def test_private_existing_dir(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + os.makedirs(os.path.dirname(path)) + os.chmod(os.path.dirname(path), 0o777) + util.make_file_with_directories(path, private=True) + self.assertMode(0o700, os.path.dirname(path)) + self.assertMode(0o600, path) + self.assertEqual(0, 
os.path.getsize(path)) + + @skip_if_windows + def test_private_nonexistent_leaf_dir(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + os.makedirs(os.path.dirname(os.path.dirname(path))) + util.make_file_with_directories(path, private=True) + self.assertMode(0o700, os.path.dirname(path)) + self.assertMode(0o600, path) + self.assertEqual(0, os.path.getsize(path)) + + @skip_if_windows + def test_private_nonexistent_multiple_dirs(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + util.make_file_with_directories(path, private=True) + self.assertMode(0o700, os.path.dirname(path)) + self.assertMode(0o600, path) + self.assertEqual(0, os.path.getsize(path)) + + +class SetTimestampTest(tb_test.TestCase): + + def test_set_timestamp(self): + pb = timestamp_pb2.Timestamp() + t = 1234567890.007812500 + # Note that just multiplying by 1e9 would lose precision: + self.assertEqual(int(t * 1e9) % int(1e9), 7812608) + util.set_timestamp(pb, t) + self.assertEqual(pb.seconds, 1234567890) + self.assertEqual(pb.nanos, 7812500) + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/util/BUILD b/tensorboard/util/BUILD index f4fbe92a27..3ec9e5c44d 100644 --- a/tensorboard/util/BUILD +++ b/tensorboard/util/BUILD @@ -1,9 +1,28 @@ package(default_visibility = ["//tensorboard:internal"]) +load("//tensorboard/defs:protos.bzl", "tb_proto_library") + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) # Needed for internal repo. +py_library( + name = "argparse_util", + srcs = ["argparse_util.py"], + srcs_version = "PY2AND3", +) + +py_test( + name = "argparse_util_test", + size = "small", + srcs = ["argparse_util_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":argparse_util", + "//tensorboard:test", + ], +) + py_library( name = "encoder", srcs = ["encoder.py"], @@ -28,6 +47,43 @@ py_test( ], ) +py_library( + name = "grpc_util", + srcs = ["grpc_util.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorboard:expect_grpc_installed", + "//tensorboard:version", + "//tensorboard/util:tb_logging", + ], +) + +py_test( + name = "grpc_util_test", + size = "small", + srcs = ["grpc_util_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":grpc_util", + ":grpc_util_test_proto_py_pb2", + ":grpc_util_test_proto_py_pb2_grpc", + ":test_util", + "//tensorboard:expect_futures_installed", + "//tensorboard:expect_grpc_installed", + "//tensorboard:test", + "//tensorboard:version", + "@org_pythonhosted_mock", + "@org_pythonhosted_six", + ], +) + +tb_proto_library( + name = "grpc_util_test_proto", + has_services = True, + srcs = ["grpc_util_test.proto"], + testonly = True, +) + py_library( name = "op_evaluator", srcs = ["op_evaluator.py"], diff --git a/tensorboard/util/argparse_util.py b/tensorboard/util/argparse_util.py new file mode 100644 index 0000000000..27d08dd1d5 --- /dev/null +++ b/tensorboard/util/argparse_util.py @@ -0,0 +1,65 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
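The rounding in `util.set_timestamp`, pinned down by `SetTimestampTest` above, comes down to float64 precision: a UNIX time expressed in nanoseconds is on the order of 1e18, which exceeds the 53-bit mantissa of a double, while the whole seconds and the fractional part each fit exactly. A small illustration using the same test value:

    from google.protobuf import timestamp_pb2

    from tensorboard.uploader import util

    t = 1234567890.0078125  # Exactly 1/128 s past the second; representable as a double.
    pb = timestamp_pb2.Timestamp()
    util.set_timestamp(pb, t)
    assert (pb.seconds, pb.nanos) == (1234567890, 7812500)
    # The naive int(t * 1e9) lands at ...7812608 instead (as the test above
    # measures), because t * 1e9 cannot be represented exactly as a double.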
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for working with `argparse` in a portable way.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import contextlib +import gettext + + +@contextlib.contextmanager +def allow_missing_subcommand(): + """Make Python 2.7 behave like Python 3 w.r.t. default subcommands. + + The behavior of argparse was changed [1] [2] in Python 3.3. When a + parser defines subcommands, it used to be an error for the user to + invoke the binary without specifying a subcommand. As of Python 3.3, + this is permitted. This monkey patch backports the new behavior to + earlier versions of Python. + + This context manager need only be used around `parse_args`; parsers + may be constructed and configured outside of the context manager. + + [1]: https://github.com/python/cpython/commit/f97c59aaba2d93e48cbc6d25f7ff9f9c87f8d0b2 + [2]: https://bugs.python.org/issue16308 + """ + + real_error = argparse.ArgumentParser.error + + # This must exactly match the error message raised by Python 2.7's + # `argparse` when no subparser is given. This is `argparse.py:1954` at + # Git tag `v2.7.16`. + ignored_message = gettext.gettext("too few arguments") + + def error(*args, **kwargs): + # Expected signature is `error(self, message)`, but we retain more + # flexibility to be forward-compatible with implementation changes. + if "message" not in kwargs and len(args) < 2: + return real_error(*args, **kwargs) + message = kwargs["message"] if "message" in kwargs else args[1] + if message == ignored_message: + return None + else: + return real_error(*args, **kwargs) + + argparse.ArgumentParser.error = error + try: + yield + finally: + argparse.ArgumentParser.error = real_error diff --git a/tensorboard/util/argparse_util_test.py b/tensorboard/util/argparse_util_test.py new file mode 100644 index 0000000000..4e18d4a027 --- /dev/null +++ b/tensorboard/util/argparse_util_test.py @@ -0,0 +1,56 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
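`allow_missing_subcommand` patches `ArgumentParser.error` only for the duration of the `with` block, so a caller (presumably `tensorboard.program`, which now depends on `argparse_util` per the BUILD change earlier in this diff) wraps just the `parse_args` call. A sketch of the intended call-site shape; the tests that follow exercise the same behavior:

    import argparse

    from tensorboard.util import argparse_util

    parser = argparse.ArgumentParser(prog='tensorboard')
    subparsers = parser.add_subparsers()
    subparsers.add_parser('dev')  # e.g. the uploader subcommand

    # On Python 2.7 this would otherwise exit with "error: too few arguments";
    # inside the context manager it returns an empty Namespace, as on Python 3.
    with argparse_util.allow_missing_subcommand():
      flags = parser.parse_args([])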
+# ============================================================================== +"""Tests for `argparse_util`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse + +from tensorboard import test as tb_test +from tensorboard.util import argparse_util + + +class AllowMissingSubcommandTest(tb_test.TestCase): + + def test_allows_missing_subcommands(self): + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + subparser = subparsers.add_parser("magic") + subparser.set_defaults(chosen="magic") + with argparse_util.allow_missing_subcommand(): + args = parser.parse_args([]) + self.assertEqual(args, argparse.Namespace()) + + def test_allows_provided_subcommands(self): + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + subparser = subparsers.add_parser("magic") + subparser.set_defaults(chosen="magic") + with argparse_util.allow_missing_subcommand(): + args = parser.parse_args(["magic"]) + self.assertEqual(args, argparse.Namespace(chosen="magic")) + + def test_still_complains_on_missing_arguments(self): + parser = argparse.ArgumentParser() + parser.add_argument("please_provide_me") + with argparse_util.allow_missing_subcommand(): + with self.assertRaises(SystemExit): + parser.parse_args([]) + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py new file mode 100644 index 0000000000..43c551dfeb --- /dev/null +++ b/tensorboard/util/grpc_util.py @@ -0,0 +1,126 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for working with python gRPC stubs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import time + +import grpc + +from tensorboard import version +from tensorboard.util import tb_logging + +logger = tb_logging.get_logger() + +# Default RPC timeout. +_GRPC_DEFAULT_TIMEOUT_SECS = 30 + +# Max number of times to attempt an RPC, retrying on transient failures. +_GRPC_RETRY_MAX_ATTEMPTS = 5 + +# Parameters to control the exponential backoff behavior. +_GRPC_RETRY_EXPONENTIAL_BASE = 2 +_GRPC_RETRY_JITTER_FACTOR_MIN = 1.1 +_GRPC_RETRY_JITTER_FACTOR_MAX = 1.5 + +# Status codes from gRPC for which it's reasonable to retry the RPC. +_GRPC_RETRYABLE_STATUS_CODES = frozenset([ + grpc.StatusCode.ABORTED, + grpc.StatusCode.DEADLINE_EXCEEDED, + grpc.StatusCode.RESOURCE_EXHAUSTED, + grpc.StatusCode.UNAVAILABLE, +]) + +# gRPC metadata key whose value contains the client version. +_VERSION_METADATA_KEY = "tensorboard-version" + + +def call_with_retries(api_method, request, clock=None): + """Call a gRPC stub API method, with automatic retry logic. + + This only supports unary-unary RPCs: i.e., no streaming on either end. 
+ Streamed RPCs will generally need application-level pagination support, + because after a gRPC error one must retry the entire request; there is no + "retry-resume" functionality. + + Args: + api_method: Callable for the API method to invoke. + request: Request protocol buffer to pass to the API method. + clock: an interface object supporting `time()` and `sleep()` methods + like the standard `time` module; if not passed, uses the normal module. + + Returns: + Response protocol buffer returned by the API method. + + Raises: + grpc.RpcError: if a non-retryable error is returned, or if all retry + attempts have been exhausted. + """ + if clock is None: + clock = time + # We can't actually use api_method.__name__ because it's not a real method, + # it's a special gRPC callable instance that doesn't expose the method name. + rpc_name = request.__class__.__name__.replace("Request", "") + logger.debug("RPC call %s with request: %r", rpc_name, request) + num_attempts = 0 + while True: + num_attempts += 1 + try: + return api_method( + request, + timeout=_GRPC_DEFAULT_TIMEOUT_SECS, + metadata=version_metadata()) + except grpc.RpcError as e: + logger.info("RPC call %s got error %s", rpc_name, e) + if e.code() not in _GRPC_RETRYABLE_STATUS_CODES: + raise + if num_attempts >= _GRPC_RETRY_MAX_ATTEMPTS: + raise + jitter_factor = random.uniform( + _GRPC_RETRY_JITTER_FACTOR_MIN, _GRPC_RETRY_JITTER_FACTOR_MAX) + backoff_secs = (_GRPC_RETRY_EXPONENTIAL_BASE**num_attempts) * jitter_factor + logger.info( + "RPC call %s attempted %d times, retrying in %.1f seconds", + rpc_name, num_attempts, backoff_secs) + clock.sleep(backoff_secs) + + +def version_metadata(): + """Creates gRPC invocation metadata encoding the TensorBoard version. + + Usage: `stub.MyRpc(request, metadata=version_metadata())`. + + Returns: + A tuple of key-value pairs (themselves 2-tuples) to be passed as the + `metadata` kwarg to gRPC stub API methods. + """ + return ((_VERSION_METADATA_KEY, version.VERSION),) + + +def extract_version(metadata): + """Extracts version from invocation metadata. + + The argument should be the result of a prior call to `metadata` or the + result of combining such a result with other metadata. + + Returns: + The TensorBoard version listed in this metadata, or `None` if none + is listed. + """ + return dict(metadata).get(_VERSION_METADATA_KEY) diff --git a/tensorboard/util/grpc_util_test.proto b/tensorboard/util/grpc_util_test.proto new file mode 100644 index 0000000000..3f4bb5c976 --- /dev/null +++ b/tensorboard/util/grpc_util_test.proto @@ -0,0 +1,18 @@ +// Minimal example RPC service definition. See grpc_util_test.py for usage. +syntax = "proto3"; + +package tensorboard.util; + +// Test service for grpc_util_test.py. +service TestService { + // Test RPC. + rpc TestRpc(TestRpcRequest) returns (TestRpcResponse); +} + +message TestRpcRequest { + int32 nonce = 1; +} + +message TestRpcResponse { + int32 nonce = 1; +} diff --git a/tensorboard/util/grpc_util_test.py b/tensorboard/util/grpc_util_test.py new file mode 100644 index 0000000000..4aeabc8914 --- /dev/null +++ b/tensorboard/util/grpc_util_test.py @@ -0,0 +1,164 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
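With the constants defined alongside `call_with_retries` above (exponential base 2, jitter drawn from [1.1, 1.5], at most 5 attempts), the wait after the k-th failed attempt is `2**k * jitter` seconds, i.e. roughly 2.2-3.0 s, then 4.4-6.0 s, 8.8-12.0 s, and 17.6-24.0 s before the final attempt; the retry tests below assert looser bounds around this same schedule. A small sketch of the arithmetic (not shared code, just the same formula):

    import random

    def backoff_secs(num_attempts, base=2, jitter_min=1.1, jitter_max=1.5):
      # num_attempts counts the RPC attempts already made, as in call_with_retries.
      return (base ** num_attempts) * random.uniform(jitter_min, jitter_max)

    schedule = [backoff_secs(k) for k in range(1, 5)]  # Waits before attempts 2..5.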
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `tensorboard.util.grpc_util`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import hashlib +import threading + +from concurrent import futures +import grpc +import six + +from tensorboard.util import grpc_util +from tensorboard.util import grpc_util_test_pb2 +from tensorboard.util import grpc_util_test_pb2_grpc +from tensorboard.util import test_util +from tensorboard import test as tb_test +from tensorboard import version + + +def make_request(nonce): + return grpc_util_test_pb2.TestRpcRequest(nonce=nonce) + + +def make_response(nonce): + return grpc_util_test_pb2.TestRpcResponse(nonce=nonce) + + +class TestGrpcServer(grpc_util_test_pb2_grpc.TestServiceServicer): + """Helper for testing gRPC client logic with a dummy gRPC server.""" + + def __init__(self, handler): + super(TestGrpcServer, self).__init__() + self._handler = handler + + def TestRpc(self, request, context): + return self._handler(request, context) + + @contextlib.contextmanager + def run(self): + """Context manager to run the gRPC server and yield a client for it.""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) + grpc_util_test_pb2_grpc.add_TestServiceServicer_to_server(self, server) + port = server.add_secure_port( + "localhost:0", grpc.local_server_credentials()) + def launch_server(): + server.start() + server.wait_for_termination() + thread = threading.Thread(target=launch_server, name="TestGrpcServer") + thread.daemon = True + thread.start() + with grpc.secure_channel( + "localhost:%d" % port, grpc.local_channel_credentials()) as channel: + yield grpc_util_test_pb2_grpc.TestServiceStub(channel) + server.stop(grace=None) + thread.join() + + +class CallWithRetriesTest(tb_test.TestCase): + + def test_call_with_retries_succeeds(self): + def handler(request, _): + return make_response(request.nonce) + server = TestGrpcServer(handler) + with server.run() as client: + response = grpc_util.call_with_retries(client.TestRpc, make_request(42)) + self.assertEqual(make_response(42), response) + + def test_call_with_retries_fails_immediately_on_permanent_error(self): + def handler(_, context): + context.abort(grpc.StatusCode.INTERNAL, "foo") + server = TestGrpcServer(handler) + with server.run() as client: + with self.assertRaises(grpc.RpcError) as raised: + grpc_util.call_with_retries(client.TestRpc, make_request(42)) + self.assertEqual(grpc.StatusCode.INTERNAL, raised.exception.code()) + self.assertEqual("foo", raised.exception.details()) + + def test_call_with_retries_fails_after_backoff_on_nonpermanent_error(self): + attempt_times = [] + fake_time = test_util.FakeTime() + def handler(_, context): + attempt_times.append(fake_time.time()) + context.abort(grpc.StatusCode.UNAVAILABLE, "foo") + server = TestGrpcServer(handler) + with server.run() as client: + with self.assertRaises(grpc.RpcError) as raised: + grpc_util.call_with_retries(client.TestRpc, make_request(42), fake_time) + 
self.assertEqual(grpc.StatusCode.UNAVAILABLE, raised.exception.code()) + self.assertEqual("foo", raised.exception.details()) + self.assertLen(attempt_times, 5) + self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) + self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) + self.assertBetween(attempt_times[3] - attempt_times[2], 8, 16) + self.assertBetween(attempt_times[4] - attempt_times[3], 16, 32) + + def test_call_with_retries_succeeds_after_backoff_on_transient_error(self): + attempt_times = [] + fake_time = test_util.FakeTime() + def handler(request, context): + attempt_times.append(fake_time.time()) + if len(attempt_times) < 3: + context.abort(grpc.StatusCode.UNAVAILABLE, "foo") + return make_response(request.nonce) + server = TestGrpcServer(handler) + with server.run() as client: + response = grpc_util.call_with_retries( + client.TestRpc, make_request(42), fake_time) + self.assertEqual(make_response(42), response) + self.assertLen(attempt_times, 3) + self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) + self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) + + def test_call_with_retries_includes_version_metadata(self): + def digest(s): + """Hashes a string into a 32-bit integer.""" + return int(hashlib.sha256(s.encode("utf-8")).hexdigest(), 16) & 0xffffffff + def handler(request, context): + metadata = context.invocation_metadata() + client_version = grpc_util.extract_version(metadata) + return make_response(digest(client_version)) + server = TestGrpcServer(handler) + with server.run() as client: + response = grpc_util.call_with_retries(client.TestRpc, make_request(0)) + expected_nonce = digest( + grpc_util.extract_version(grpc_util.version_metadata())) + self.assertEqual(make_response(expected_nonce), response) + + +class VersionMetadataTest(tb_test.TestCase): + + def test_structure(self): + result = grpc_util.version_metadata() + self.assertIsInstance(result, tuple) + for kv in result: + self.assertIsInstance(kv, tuple) + self.assertLen(kv, 2) + (k, v) = kv + self.assertIsInstance(k, str) + self.assertIsInstance(v, six.string_types) + + def test_roundtrip(self): + result = grpc_util.extract_version(grpc_util.version_metadata()) + self.assertEqual(result, version.VERSION) + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/util/test_util.py b/tensorboard/util/test_util.py index c970dcf472..44d56fa164 100644 --- a/tensorboard/util/test_util.py +++ b/tensorboard/util/test_util.py @@ -128,6 +128,22 @@ def get(logdir): return FileWriterCache._cache[logdir] +class FakeTime(object): + """Thread-safe fake replacement for the `time` module.""" + + def __init__(self, current=0.0): + self._time = float(current) + self._lock = threading.Lock() + + def time(self): + with self._lock: + return self._time + + def sleep(self, secs): + with self._lock: + self._time += secs + + def ensure_tb_summary_proto(summary): """Ensures summary is TensorBoard Summary proto. 
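The `FakeTime` helper added above is what keeps the backoff assertions in `grpc_util_test.py` fast; a minimal sketch of its behavior (illustrative only, not part of the patch):

    from tensorboard.util import test_util

    clock = test_util.FakeTime(current=100.0)
    clock.sleep(2.5)  # advances the fake clock instead of blocking
    assert clock.time() == 102.5

    # Passed as the `clock` argument to grpc_util.call_with_retries, it makes
    # each retry's exponential backoff observable without real waiting.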
diff --git a/tensorboard/version.py b/tensorboard/version.py index 9d80d2e29a..09fea74ee6 100644 --- a/tensorboard/version.py +++ b/tensorboard/version.py @@ -15,4 +15,4 @@ """Contains the version string.""" -VERSION = '2.0.0' +VERSION = '2.0.1' diff --git a/third_party/mock_call_assertions.patch b/third_party/mock_call_assertions.patch new file mode 100644 index 0000000000..bc28e1b633 --- /dev/null +++ b/third_party/mock_call_assertions.patch @@ -0,0 +1,59 @@ +--- mock.py 2012-10-07 18:00:10.000000000 +0100 ++++ mock.py 2019-10-24 22:19:25.657417082 -0700 +@@ -286,6 +286,12 @@ + if not _is_instance_mock(mock): + return + ++ def assert_called(*args, **kwargs): ++ return mock.assert_called(*args, **kwargs) ++ def assert_not_called(*args, **kwargs): ++ return mock.assert_not_called(*args, **kwargs) ++ def assert_called_once(*args, **kwargs): ++ return mock.assert_called_once(*args, **kwargs) + def assert_called_with(*args, **kwargs): + return mock.assert_called_with(*args, **kwargs) + def assert_called_once_with(*args, **kwargs): +@@ -318,6 +324,9 @@ + funcopy.assert_has_calls = assert_has_calls + funcopy.assert_any_call = assert_any_call + funcopy.reset_mock = reset_mock ++ funcopy.assert_called = assert_called ++ funcopy.assert_not_called = assert_not_called ++ funcopy.assert_called_once = assert_called_once + + mock._mock_delegate = funcopy + +@@ -809,6 +818,33 @@ + return message % (expected_string, actual_string) + + ++ def assert_not_called(_mock_self): ++ """assert that the mock was never called. ++ """ ++ self = _mock_self ++ if self.call_count != 0: ++ msg = ("Expected '%s' to not have been called. Called %s times." % ++ (self._mock_name or 'mock', self.call_count)) ++ raise AssertionError(msg) ++ ++ def assert_called(_mock_self): ++ """assert that the mock was called at least once ++ """ ++ self = _mock_self ++ if self.call_count == 0: ++ msg = ("Expected '%s' to have been called." % ++ self._mock_name or 'mock') ++ raise AssertionError(msg) ++ ++ def assert_called_once(_mock_self): ++ """assert that the mock was called only once. ++ """ ++ self = _mock_self ++ if not self.call_count == 1: ++ msg = ("Expected '%s' to have been called once. Called %s times." % ++ (self._mock_name or 'mock', self.call_count)) ++ raise AssertionError(msg) ++ + def assert_called_with(_mock_self, *args, **kwargs): + """assert that the mock was called with the specified arguments. + diff --git a/third_party/python.bzl b/third_party/python.bzl index 02652257af..b3c41e2c8a 100644 --- a/third_party/python.bzl +++ b/third_party/python.bzl @@ -114,6 +114,14 @@ def tensorboard_python_workspace(): sha256 = "2d9fbe67001d2e8f02692075257f3c11e1b0194bd838c8ce3f49b31fc6c3f033", strip_prefix = "mock-1.0.0", build_file = str(Label("//third_party:mock.BUILD")), + patches = [ + # `mock==1.0.0` lacks some assertion methods present in + # later versions of `mock` (see comment above for why we pin + # to this version). Patch created by diffing the pinned + # `mock.py` with GitHub head and identifying all the bits + # that looked related to the methods in question. + "//third_party:mock_call_assertions.patch", + ], ) http_archive(
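For reference, the methods backported by `mock_call_assertions.patch` behave like their counterparts in newer `mock` releases; a small illustrative sketch, assuming the patched `mock==1.0.0` is on the path:

    import mock

    handler = mock.Mock(name="handler")
    handler.assert_not_called()   # passes: nothing has called it yet

    handler("event")
    handler.assert_called()       # passes: at least one call
    handler.assert_called_once()  # passes: exactly one call

    handler("event")
    # A second call now makes assert_called_once() raise AssertionError.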