Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorboard/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ py_test(
":program",
":test",
"//tensorboard/backend:application",
"//tensorboard/plugins:base_plugin",
"//tensorboard/plugins/core:core_plugin",
"@org_pocoo_werkzeug",
],
Expand Down
25 changes: 23 additions & 2 deletions tensorboard/backend/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ def TensorBoardWSGIApp(

Args:
flags: An argparse.Namespace containing TensorBoard CLI flags.
plugins: A list of TBLoader subclasses for the plugins to load.
plugins: A list of plugins, which can be provided as TBPlugin subclasses
or TBLoader instances or subclasses.
assets_zip_provider: See TBContext documentation for more information.
data_provider: Instance of `tensorboard.data.provider.DataProvider`. May
be `None` if `flags.generic_data` is set to `"false"` in which case
Expand Down Expand Up @@ -191,7 +192,8 @@ def TensorBoardWSGIApp(
plugin_name_to_instance=plugin_name_to_instance,
window_title=flags.window_title)
tbplugins = []
for loader in plugins:
for plugin_spec in plugins:
loader = make_plugin_loader(plugin_spec)
plugin = loader.load(context)
if plugin is None:
continue
Expand Down Expand Up @@ -601,3 +603,22 @@ def AddRunsFromDirectory(self, path, name=None):
def Reload(self):
"""Unsupported."""
raise NotImplementedError()


def make_plugin_loader(plugin_spec):
"""Returns a plugin loader for the given plugin.

Args:
plugin_spec: A TBPlugin subclass, or a TBLoader instance or subclass.

Returns:
A TBLoader for the given plugin.
"""
if isinstance(plugin_spec, base_plugin.TBLoader):
return plugin_spec
if isinstance(plugin_spec, type):
if issubclass(plugin_spec, base_plugin.TBLoader):
return plugin_spec()
if issubclass(plugin_spec, base_plugin.TBPlugin):
return base_plugin.BasicLoader(plugin_spec)
raise TypeError("Not a TBLoader or TBPlugin subclass: %r" % (plugin_spec,))
62 changes: 45 additions & 17 deletions tensorboard/backend/application_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,16 @@ def frontend_metadata(self):
)


class FakePluginLoader(base_plugin.TBLoader):
"""Pass-through loader for FakePlugin with arbitrary arguments."""

def __init__(self, **kwargs):
self._kwargs = kwargs

def load(self, context):
return FakePlugin(context, **self._kwargs)


class ApplicationTest(tb_test.TestCase):
def setUp(self):
plugins = [
Expand Down Expand Up @@ -357,6 +367,26 @@ def testSlashlessRoute(self):
application.TensorBoardWSGI([self._make_plugin('runaway')])


class MakePluginLoaderTest(tb_test.TestCase):

def testMakePluginLoader_pluginClass(self):
loader = application.make_plugin_loader(FakePlugin)
self.assertIsInstance(loader, base_plugin.BasicLoader)
self.assertIs(loader.plugin_class, FakePlugin)

def testMakePluginLoader_pluginLoaderClass(self):
loader = application.make_plugin_loader(FakePluginLoader)
self.assertIsInstance(loader, FakePluginLoader)

def testMakePluginLoader_pluginLoader(self):
loader = FakePluginLoader()
self.assertIs(loader, application.make_plugin_loader(loader))

def testMakePluginLoader_invalidType(self):
with six.assertRaisesRegex(self, TypeError, 'FakePlugin'):
application.make_plugin_loader(FakePlugin())


class GetEventFileActiveFilterTest(tb_test.TestCase):

def testDisabled(self):
Expand Down Expand Up @@ -519,23 +549,21 @@ def setUp(self):
self.app = application.standard_tensorboard_wsgi(
FakeFlags(logdir=self.get_temp_dir()),
[
base_plugin.BasicLoader(functools.partial(
FakePlugin,
plugin_name='foo',
is_active_value=True,
routes_mapping={'/foo_route': self._foo_handler},
construction_callback=self._construction_callback)),
base_plugin.BasicLoader(functools.partial(
FakePlugin,
plugin_name='bar',
is_active_value=True,
routes_mapping={
'/bar_route': self._bar_handler,
'/wildcard/*': self._wildcard_handler,
'/wildcard/special/*': self._wildcard_special_handler,
'/wildcard/special/exact': self._foo_handler,
},
construction_callback=self._construction_callback)),
FakePluginLoader(
plugin_name='foo',
is_active_value=True,
routes_mapping={'/foo_route': self._foo_handler},
construction_callback=self._construction_callback),
FakePluginLoader(
plugin_name='bar',
is_active_value=True,
routes_mapping={
'/bar_route': self._bar_handler,
'/wildcard/*': self._wildcard_handler,
'/wildcard/special/*': self._wildcard_special_handler,
'/wildcard/special/exact': self._foo_handler,
},
construction_callback=self._construction_callback),
],
dummy_assets_zip_provider)

Expand Down
10 changes: 5 additions & 5 deletions tensorboard/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,20 @@
# Ordering matters. The order in which these lines appear determines the
# ordering of tabs in TensorBoard's GUI.
_PLUGINS = [
core_plugin.CorePluginLoader(),
core_plugin.CorePluginLoader,
scalars_plugin.ScalarsPlugin,
custom_scalars_plugin.CustomScalarsPlugin,
images_plugin.ImagesPlugin,
audio_plugin.AudioPlugin,
debugger_plugin_loader.DebuggerPluginLoader(),
debugger_plugin_loader.DebuggerPluginLoader,
graphs_plugin.GraphsPlugin,
distributions_plugin.DistributionsPlugin,
histograms_plugin.HistogramsPlugin,
text_plugin.TextPlugin,
pr_curves_plugin.PrCurvesPlugin,
profile_plugin_loader.ProfilePluginLoader(),
beholder_plugin_loader.BeholderPluginLoader(),
interactive_inference_plugin_loader.InteractiveInferencePluginLoader(),
profile_plugin_loader.ProfilePluginLoader,
beholder_plugin_loader.BeholderPluginLoader,
interactive_inference_plugin_loader.InteractiveInferencePluginLoader,
hparams_plugin.HParamsPlugin,
mesh_plugin.MeshPlugin,
]
Expand Down
4 changes: 2 additions & 2 deletions tensorboard/plugins/base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,10 @@ def __init__(self, plugin_class):

:param plugin_class: :class:`TBPlugin`
"""
self._plugin_class = plugin_class
self.plugin_class = plugin_class

def load(self, context):
return self._plugin_class(context)
return self.plugin_class(context)


class FlagsError(ValueError):
Expand Down
17 changes: 4 additions & 13 deletions tensorboard/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,13 @@ def __init__(self,
"""Creates new instance.

Args:
plugins: A list of TensorBoard plugins to load, as TBLoader instances or
TBPlugin classes. If not specified, defaults to first-party plugins.
plugin: A list of TensorBoard plugins to load, as TBPlugin classes or
TBLoader instances or classes. If not specified, defaults to first-party
plugins.
assets_zip_provider: Delegates to TBContext or uses default if None.
server_class: An optional factory for a `TensorBoardServer` to use
for serving the TensorBoard WSGI app. If provided, its callable
signature should match that of `TensorBoardServer.__init__`.

:type plugins: list[Union[base_plugin.TBLoader, Type[base_plugin.TBPlugin]]]
:type assets_zip_provider: () -> file
:type server_class: class
"""
if plugins is None:
from tensorboard import default
Expand All @@ -136,13 +133,7 @@ def __init__(self,
assets_zip_provider = get_default_assets_zip_provider()
if server_class is None:
server_class = create_port_scanning_werkzeug_server
def make_loader(plugin):
if isinstance(plugin, base_plugin.TBLoader):
return plugin
if issubclass(plugin, base_plugin.TBPlugin):
return base_plugin.BasicLoader(plugin)
raise ValueError("Not a TBLoader or TBPlugin subclass: %s" % plugin)
self.plugin_loaders = [make_loader(p) for p in plugins]
self.plugin_loaders = [application.make_plugin_loader(p) for p in plugins]
self.assets_zip_provider = assets_zip_provider
self.server_class = server_class
self.flags = None
Expand Down
27 changes: 24 additions & 3 deletions tensorboard/program_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,39 @@

from tensorboard import program
from tensorboard import test as tb_test
from tensorboard.plugins import base_plugin
from tensorboard.plugins.core import core_plugin


class TensorBoardTest(tb_test.TestCase):
"""Tests the TensorBoard program."""

def testPlugins_pluginClass(self):
tb = program.TensorBoard(plugins=[core_plugin.CorePlugin])
self.assertIsInstance(tb.plugin_loaders[0], base_plugin.BasicLoader)
self.assertIs(tb.plugin_loaders[0].plugin_class, core_plugin.CorePlugin)

def testPlugins_pluginLoaderClass(self):
tb = program.TensorBoard(plugins=[core_plugin.CorePluginLoader])
self.assertIsInstance(tb.plugin_loaders[0], core_plugin.CorePluginLoader)

def testPlugins_pluginLoader(self):
loader = core_plugin.CorePluginLoader()
tb = program.TensorBoard(plugins=[loader])
self.assertIs(tb.plugin_loaders[0], loader)

def testPlugins_invalidType(self):
plugin_instance = core_plugin.CorePlugin(base_plugin.TBContext())
with six.assertRaisesRegex(self, TypeError, 'CorePlugin'):
tb = program.TensorBoard(plugins=[plugin_instance])

def testConfigure(self):
# Many useful flags are defined under the core plugin.
tb = program.TensorBoard(plugins=[core_plugin.CorePluginLoader()])
tb = program.TensorBoard(plugins=[core_plugin.CorePluginLoader])
tb.configure(logdir='foo')
self.assertStartsWith(tb.flags.logdir, 'foo')
self.assertEqual(tb.flags.logdir, 'foo')

def testConfigure_unknownFlag(self):
tb = program.TensorBoard(plugins=[core_plugin.CorePlugin])
with six.assertRaisesRegex(self, ValueError, 'Unknown TensorBoard flag'):
tb.configure(foo='bar')

Expand Down