Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tensorboard/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ py_library(
"//tensorboard:expect_absl_logging_installed",
"//tensorboard/backend:application",
"//tensorboard/backend/event_processing:event_file_inspector",
"//tensorboard/util:argparse_util",
"@org_pocoo_werkzeug",
"@org_pythonhosted_six",
],
Expand All @@ -170,6 +171,7 @@ py_test(
"//tensorboard/plugins:base_plugin",
"//tensorboard/plugins/core:core_plugin",
"@org_pocoo_werkzeug",
"@org_pythonhosted_mock",
],
)

Expand Down
137 changes: 122 additions & 15 deletions tensorboard/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,17 @@
from tensorboard.backend.event_processing import event_file_inspector as efi
from tensorboard.plugins import base_plugin
from tensorboard.plugins.core import core_plugin
from tensorboard.util import argparse_util
from tensorboard.util import tb_logging


logger = tb_logging.get_logger()

# Default subcommand name. This is a user-facing CLI and should not change.
_SERVE_SUBCOMMAND_NAME = 'serve'
# Internal flag name used to store which subcommand was invoked.
_SUBCOMMAND_FLAG = '__tensorboard_subcommand'


def setup_environment():
"""Makes recommended modifications to the environment.
Expand Down Expand Up @@ -106,10 +112,13 @@ class TensorBoard(object):
cache_key: As `manager.cache_key`; set by the configure() method.
"""

def __init__(self,
plugins=None,
assets_zip_provider=None,
server_class=None):
def __init__(
self,
plugins=None,
assets_zip_provider=None,
server_class=None,
subcommands=None,
):
"""Creates new instance.

Args:
Expand All @@ -128,9 +137,17 @@ def __init__(self,
assets_zip_provider = get_default_assets_zip_provider()
if server_class is None:
server_class = create_port_scanning_werkzeug_server
if subcommands is None:
subcommands = []
self.plugin_loaders = [application.make_plugin_loader(p) for p in plugins]
self.assets_zip_provider = assets_zip_provider
self.server_class = server_class
self.subcommands = {}
for subcommand in subcommands:
name = subcommand.name()
if name in self.subcommands or name == _SERVE_SUBCOMMAND_NAME:
raise ValueError("Duplicate subcommand name: %r" % name)
self.subcommands[name] = subcommand
self.flags = None

def configure(self, argv=('',), **kwargs):
Expand All @@ -154,15 +171,44 @@ def configure(self, argv=('',), **kwargs):
Raises:
ValueError: If flag values are invalid.
"""
parser = argparse_flags.ArgumentParser(

base_parser = argparse_flags.ArgumentParser(
prog='tensorboard',
description=('TensorBoard is a suite of web applications for '
'inspecting and understanding your TensorFlow runs '
'and graphs. https://github.com/tensorflow/tensorboard '))
base_parser.set_defaults(**{_SUBCOMMAND_FLAG: _SERVE_SUBCOMMAND_NAME})
subparsers = base_parser.add_subparsers(
help="TensorBoard subcommand (defaults to %r)" % _SERVE_SUBCOMMAND_NAME)

serve_subparser = subparsers.add_parser(
_SERVE_SUBCOMMAND_NAME,
help='start local TensorBoard server (default subcommand)')
serve_subparser.set_defaults(**{_SUBCOMMAND_FLAG: _SERVE_SUBCOMMAND_NAME})

if len(argv) < 2 or argv[1].startswith('-'):
# This invocation, if valid, must not use any subcommands: we
# don't permit flags before the subcommand name.
serve_parser = base_parser
else:
# This invocation, if valid, must use a subcommand: we don't take
# any positional arguments to `serve`.
serve_parser = serve_subparser

for (name, subcommand) in six.iteritems(self.subcommands):
subparser = subparsers.add_parser(
name, help=subcommand.help(), description=subcommand.description())
subparser.set_defaults(**{_SUBCOMMAND_FLAG: name})
subcommand.define_flags(subparser)

for loader in self.plugin_loaders:
loader.define_flags(parser)
loader.define_flags(serve_parser)

arg0 = argv[0] if argv else ''
flags = parser.parse_args(argv[1:]) # Strip binary name from argv.

with argparse_util.allow_missing_subcommand():
flags = base_parser.parse_args(argv[1:]) # Strip binary name from argv.

self.cache_key = manager.cache_key(
working_directory=os.getcwd(),
arguments=argv[1:],
Expand All @@ -180,8 +226,9 @@ def configure(self, argv=('',), **kwargs):
if not hasattr(flags, k):
raise ValueError('Unknown TensorBoard flag: %s' % k)
setattr(flags, k, v)
for loader in self.plugin_loaders:
loader.fix_flags(flags)
if getattr(flags, _SUBCOMMAND_FLAG) == _SERVE_SUBCOMMAND_NAME:
for loader in self.plugin_loaders:
loader.fix_flags(flags)
self.flags = flags
return [arg0]

Expand All @@ -203,14 +250,24 @@ def main(self, ignored_argv=('',)):
:rtype: int
"""
self._install_signal_handler(signal.SIGTERM, "SIGTERM")
if self.flags.inspect:
logger.info('Not bringing up TensorBoard, but inspecting event files.')
event_file = os.path.expanduser(self.flags.event_file)
efi.inspect(self.flags.logdir, event_file, self.flags.tag)
return 0
if self.flags.version_tb:
subcommand_name = getattr(self.flags, _SUBCOMMAND_FLAG)
if subcommand_name == _SERVE_SUBCOMMAND_NAME:
runner = self._run_serve_subcommand
else:
runner = self.subcommands[subcommand_name].run
return runner(self.flags) or 0

def _run_serve_subcommand(self, flags):
# TODO(#2801): Make `--version` a flag on only the base parser, not `serve`.
if flags.version_tb:
print(version.VERSION)
return 0
if flags.inspect:
# TODO(@wchargin): Convert `inspect` to a normal subcommand?
logger.info('Not bringing up TensorBoard, but inspecting event files.')
event_file = os.path.expanduser(flags.event_file)
efi.inspect(flags.logdir, event_file, flags.tag)
return 0
try:
server = self._make_server()
server.print_serving_message()
Expand Down Expand Up @@ -295,6 +352,56 @@ def _make_server(self):
return self.server_class(app, self.flags)


@six.add_metaclass(ABCMeta)
class TensorBoardSubcommand(object):
"""Experimental private API for defining subcommands to tensorboard(1)."""

@abstractmethod
def name(self):
"""Name of this subcommand, as specified on the command line.

This must be unique across all subcommands.

Returns:
A string.
"""
pass

@abstractmethod
def define_flags(self, parser):
"""Configure an argument parser for this subcommand.

Flags whose names start with two underscores (e.g., `__foo`) are
reserved for use by the runtime and must not be defined by
subcommands.

Args:
parser: An `argparse.ArgumentParser` scoped to this subcommand,
which this function should mutate.
"""
pass

@abstractmethod
def run(self, flags):
"""Execute this subcommand with user-provided flags.

Args:
flags: An `argparse.Namespace` object with all defined flags.

Returns:
An `int` exit code, or `None` as an alias for `0`.
"""
pass

def help(self):
"""Short, one-line help text to display on `tensorboard --help`."""
return None

def description(self):
"""Description to display on `tensorboard SUBCOMMAND --help`."""
return None


@six.add_metaclass(ABCMeta)
class TensorBoardServer(object):
"""Class for customizing TensorBoard WSGI app serving."""
Expand Down
138 changes: 138 additions & 0 deletions tensorboard/program_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,17 @@
from __future__ import print_function

import argparse
import contextlib
import sys

import six

try:
# python version >= 3.3
from unittest import mock # pylint: disable=g-import-not-at-top
except ImportError:
import mock # pylint: disable=g-import-not-at-top,unused-import

from tensorboard import program
from tensorboard import test as tb_test
from tensorboard.plugins import base_plugin
Expand Down Expand Up @@ -121,5 +129,135 @@ def testSpecifiedHost(self):
self.assertTrue(one_passed) # We expect either IPv4 or IPv6 to be supported


class SubcommandTest(tb_test.TestCase):

def setUp(self):
super(SubcommandTest, self).setUp()
self.stderr = six.StringIO()
patchers = [
mock.patch.object(program.TensorBoard, '_install_signal_handler'),
mock.patch.object(program.TensorBoard, '_run_serve_subcommand'),
mock.patch.object(_TestSubcommand, 'run'),
mock.patch.object(sys, 'stderr', self.stderr),
]
for p in patchers:
p.start()
self.addCleanup(p.stop)
_TestSubcommand.run.return_value = None

def tearDown(self):
stderr = self.stderr.getvalue()
if stderr:
# In case of failing tests, let there be debug info.
print('Stderr:\n%s' % stderr)

def testImplicitServe(self):
tb = program.TensorBoard(
plugins=[core_plugin.CorePluginLoader],
subcommands=[_TestSubcommand(lambda parser: None)],
)
tb.configure(('tb', '--logdir', 'logs', '--path_prefix', '/x///'))
tb.main()
program.TensorBoard._run_serve_subcommand.assert_called_once()
flags = program.TensorBoard._run_serve_subcommand.call_args[0][0]
self.assertEqual(flags.logdir, 'logs')
self.assertEqual(flags.path_prefix, '/x') # fixed by core_plugin

def testExplicitServe(self):
tb = program.TensorBoard(
plugins=[core_plugin.CorePluginLoader],
subcommands=[_TestSubcommand()],
)
tb.configure(('tb', 'serve', '--logdir', 'logs', '--path_prefix', '/x///'))
tb.main()
program.TensorBoard._run_serve_subcommand.assert_called_once()
flags = program.TensorBoard._run_serve_subcommand.call_args[0][0]
self.assertEqual(flags.logdir, 'logs')
self.assertEqual(flags.path_prefix, '/x') # fixed by core_plugin

def testSubcommand(self):
def define_flags(parser):
parser.add_argument('--hello')

tb = program.TensorBoard(
plugins=[core_plugin.CorePluginLoader],
subcommands=[_TestSubcommand(define_flags=define_flags)],
)
tb.configure(('tb', 'test', '--hello', 'world'))
self.assertEqual(tb.main(), 0)
_TestSubcommand.run.assert_called_once()
flags = _TestSubcommand.run.call_args[0][0]
self.assertEqual(flags.hello, 'world')

def testSubcommand_ExitCode(self):
tb = program.TensorBoard(
plugins=[core_plugin.CorePluginLoader],
subcommands=[_TestSubcommand()],
)
_TestSubcommand.run.return_value = 77
tb.configure(('tb', 'test'))
self.assertEqual(tb.main(), 77)

def testSubcommand_DoesNotInheritBaseArgs(self):
tb = program.TensorBoard(
plugins=[core_plugin.CorePluginLoader],
subcommands=[_TestSubcommand()],
)
with self.assertRaises(SystemExit):
tb.configure(('tb', 'test', '--logdir', 'logs'))
self.assertIn(
'unrecognized arguments: --logdir logs', self.stderr.getvalue())
self.stderr.truncate(0)

def testSubcommand_MayRequirePositionals(self):
def define_flags(parser):
parser.add_argument('payload')

tb = program.TensorBoard(
plugins=[core_plugin.CorePluginLoader],
subcommands=[_TestSubcommand(define_flags=define_flags)],
)
with self.assertRaises(SystemExit):
tb.configure(('tb', 'test'))
self.assertIn('required', self.stderr.getvalue())
self.assertIn('payload', self.stderr.getvalue())
self.stderr.truncate(0)

def testConflictingNames_AmongSubcommands(self):
with self.assertRaises(ValueError) as cm:
tb = program.TensorBoard(
plugins=[core_plugin.CorePluginLoader],
subcommands=[_TestSubcommand(), _TestSubcommand()],
)
self.assertIn('Duplicate subcommand name:', str(cm.exception))
self.assertIn('test', str(cm.exception))

def testConflictingNames_WithServe(self):
with self.assertRaises(ValueError) as cm:
tb = program.TensorBoard(
plugins=[core_plugin.CorePluginLoader],
subcommands=[_TestSubcommand(name='serve')],
)
self.assertIn('Duplicate subcommand name:', str(cm.exception))
self.assertIn('serve', str(cm.exception))


class _TestSubcommand(program.TensorBoardSubcommand):

def __init__(self, name=None, define_flags=None):
self._name = name
self._define_flags = define_flags

def name(self):
return self._name or 'test'

def define_flags(self, parser):
if self._define_flags:
self._define_flags(parser)

def run(self, flags):
pass


if __name__ == '__main__':
tb_test.main()
17 changes: 17 additions & 0 deletions tensorboard/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,23 @@ licenses(["notice"]) # Apache 2.0

exports_files(["LICENSE"]) # Needed for internal repo.

py_library(
name = "argparse_util",
srcs = ["argparse_util.py"],
srcs_version = "PY2AND3",
)

py_test(
name = "argparse_util_test",
size = "small",
srcs = ["argparse_util_test.py"],
srcs_version = "PY2AND3",
deps = [
":argparse_util",
"//tensorboard:test",
],
)

py_library(
name = "encoder",
srcs = ["encoder.py"],
Expand Down
Loading