diff --git a/tensorboard/BUILD b/tensorboard/BUILD index 149a462260..9ee9eb74ed 100644 --- a/tensorboard/BUILD +++ b/tensorboard/BUILD @@ -149,6 +149,7 @@ py_test( ":program", ":test", "//tensorboard/backend:application", + "//tensorboard/plugins:base_plugin", "//tensorboard/plugins/core:core_plugin", "@org_pocoo_werkzeug", ], diff --git a/tensorboard/backend/application.py b/tensorboard/backend/application.py index 30513fea81..96c9288d21 100644 --- a/tensorboard/backend/application.py +++ b/tensorboard/backend/application.py @@ -160,7 +160,8 @@ def TensorBoardWSGIApp( Args: flags: An argparse.Namespace containing TensorBoard CLI flags. - plugins: A list of TBLoader subclasses for the plugins to load. + plugins: A list of plugins, which can be provided as TBPlugin subclasses + or TBLoader instances or subclasses. assets_zip_provider: See TBContext documentation for more information. data_provider: Instance of `tensorboard.data.provider.DataProvider`. May be `None` if `flags.generic_data` is set to `"false"` in which case @@ -191,7 +192,8 @@ def TensorBoardWSGIApp( plugin_name_to_instance=plugin_name_to_instance, window_title=flags.window_title) tbplugins = [] - for loader in plugins: + for plugin_spec in plugins: + loader = make_plugin_loader(plugin_spec) plugin = loader.load(context) if plugin is None: continue @@ -601,3 +603,22 @@ def AddRunsFromDirectory(self, path, name=None): def Reload(self): """Unsupported.""" raise NotImplementedError() + + +def make_plugin_loader(plugin_spec): + """Returns a plugin loader for the given plugin. + + Args: + plugin_spec: A TBPlugin subclass, or a TBLoader instance or subclass. + + Returns: + A TBLoader for the given plugin. + """ + if isinstance(plugin_spec, base_plugin.TBLoader): + return plugin_spec + if isinstance(plugin_spec, type): + if issubclass(plugin_spec, base_plugin.TBLoader): + return plugin_spec() + if issubclass(plugin_spec, base_plugin.TBPlugin): + return base_plugin.BasicLoader(plugin_spec) + raise TypeError("Not a TBLoader or TBPlugin subclass: %r" % (plugin_spec,)) diff --git a/tensorboard/backend/application_test.py b/tensorboard/backend/application_test.py index d05eacec11..be6bc0ff39 100644 --- a/tensorboard/backend/application_test.py +++ b/tensorboard/backend/application_test.py @@ -135,6 +135,16 @@ def frontend_metadata(self): ) +class FakePluginLoader(base_plugin.TBLoader): + """Pass-through loader for FakePlugin with arbitrary arguments.""" + + def __init__(self, **kwargs): + self._kwargs = kwargs + + def load(self, context): + return FakePlugin(context, **self._kwargs) + + class ApplicationTest(tb_test.TestCase): def setUp(self): plugins = [ @@ -357,6 +367,26 @@ def testSlashlessRoute(self): application.TensorBoardWSGI([self._make_plugin('runaway')]) +class MakePluginLoaderTest(tb_test.TestCase): + + def testMakePluginLoader_pluginClass(self): + loader = application.make_plugin_loader(FakePlugin) + self.assertIsInstance(loader, base_plugin.BasicLoader) + self.assertIs(loader.plugin_class, FakePlugin) + + def testMakePluginLoader_pluginLoaderClass(self): + loader = application.make_plugin_loader(FakePluginLoader) + self.assertIsInstance(loader, FakePluginLoader) + + def testMakePluginLoader_pluginLoader(self): + loader = FakePluginLoader() + self.assertIs(loader, application.make_plugin_loader(loader)) + + def testMakePluginLoader_invalidType(self): + with six.assertRaisesRegex(self, TypeError, 'FakePlugin'): + application.make_plugin_loader(FakePlugin()) + + class GetEventFileActiveFilterTest(tb_test.TestCase): def testDisabled(self): @@ -519,23 +549,21 @@ def setUp(self): self.app = application.standard_tensorboard_wsgi( FakeFlags(logdir=self.get_temp_dir()), [ - base_plugin.BasicLoader(functools.partial( - FakePlugin, - plugin_name='foo', - is_active_value=True, - routes_mapping={'/foo_route': self._foo_handler}, - construction_callback=self._construction_callback)), - base_plugin.BasicLoader(functools.partial( - FakePlugin, - plugin_name='bar', - is_active_value=True, - routes_mapping={ - '/bar_route': self._bar_handler, - '/wildcard/*': self._wildcard_handler, - '/wildcard/special/*': self._wildcard_special_handler, - '/wildcard/special/exact': self._foo_handler, - }, - construction_callback=self._construction_callback)), + FakePluginLoader( + plugin_name='foo', + is_active_value=True, + routes_mapping={'/foo_route': self._foo_handler}, + construction_callback=self._construction_callback), + FakePluginLoader( + plugin_name='bar', + is_active_value=True, + routes_mapping={ + '/bar_route': self._bar_handler, + '/wildcard/*': self._wildcard_handler, + '/wildcard/special/*': self._wildcard_special_handler, + '/wildcard/special/exact': self._foo_handler, + }, + construction_callback=self._construction_callback), ], dummy_assets_zip_provider) diff --git a/tensorboard/default.py b/tensorboard/default.py index 1a75c665e7..41f66546a1 100644 --- a/tensorboard/default.py +++ b/tensorboard/default.py @@ -60,20 +60,20 @@ # Ordering matters. The order in which these lines appear determines the # ordering of tabs in TensorBoard's GUI. _PLUGINS = [ - core_plugin.CorePluginLoader(), + core_plugin.CorePluginLoader, scalars_plugin.ScalarsPlugin, custom_scalars_plugin.CustomScalarsPlugin, images_plugin.ImagesPlugin, audio_plugin.AudioPlugin, - debugger_plugin_loader.DebuggerPluginLoader(), + debugger_plugin_loader.DebuggerPluginLoader, graphs_plugin.GraphsPlugin, distributions_plugin.DistributionsPlugin, histograms_plugin.HistogramsPlugin, text_plugin.TextPlugin, pr_curves_plugin.PrCurvesPlugin, - profile_plugin_loader.ProfilePluginLoader(), - beholder_plugin_loader.BeholderPluginLoader(), - interactive_inference_plugin_loader.InteractiveInferencePluginLoader(), + profile_plugin_loader.ProfilePluginLoader, + beholder_plugin_loader.BeholderPluginLoader, + interactive_inference_plugin_loader.InteractiveInferencePluginLoader, hparams_plugin.HParamsPlugin, mesh_plugin.MeshPlugin, ] diff --git a/tensorboard/plugins/base_plugin.py b/tensorboard/plugins/base_plugin.py index 48ce6ce75f..51c909221b 100644 --- a/tensorboard/plugins/base_plugin.py +++ b/tensorboard/plugins/base_plugin.py @@ -257,10 +257,10 @@ def __init__(self, plugin_class): :param plugin_class: :class:`TBPlugin` """ - self._plugin_class = plugin_class + self.plugin_class = plugin_class def load(self, context): - return self._plugin_class(context) + return self.plugin_class(context) class FlagsError(ValueError): diff --git a/tensorboard/program.py b/tensorboard/program.py index a0fa7147a6..f124af0a50 100644 --- a/tensorboard/program.py +++ b/tensorboard/program.py @@ -118,16 +118,13 @@ def __init__(self, """Creates new instance. Args: - plugins: A list of TensorBoard plugins to load, as TBLoader instances or - TBPlugin classes. If not specified, defaults to first-party plugins. + plugin: A list of TensorBoard plugins to load, as TBPlugin classes or + TBLoader instances or classes. If not specified, defaults to first-party + plugins. assets_zip_provider: Delegates to TBContext or uses default if None. server_class: An optional factory for a `TensorBoardServer` to use for serving the TensorBoard WSGI app. If provided, its callable signature should match that of `TensorBoardServer.__init__`. - - :type plugins: list[Union[base_plugin.TBLoader, Type[base_plugin.TBPlugin]]] - :type assets_zip_provider: () -> file - :type server_class: class """ if plugins is None: from tensorboard import default @@ -136,13 +133,7 @@ def __init__(self, assets_zip_provider = get_default_assets_zip_provider() if server_class is None: server_class = create_port_scanning_werkzeug_server - def make_loader(plugin): - if isinstance(plugin, base_plugin.TBLoader): - return plugin - if issubclass(plugin, base_plugin.TBPlugin): - return base_plugin.BasicLoader(plugin) - raise ValueError("Not a TBLoader or TBPlugin subclass: %s" % plugin) - self.plugin_loaders = [make_loader(p) for p in plugins] + self.plugin_loaders = [application.make_plugin_loader(p) for p in plugins] self.assets_zip_provider = assets_zip_provider self.server_class = server_class self.flags = None diff --git a/tensorboard/program_test.py b/tensorboard/program_test.py index 11305ba4bc..e63d8dbd3e 100644 --- a/tensorboard/program_test.py +++ b/tensorboard/program_test.py @@ -24,18 +24,39 @@ from tensorboard import program from tensorboard import test as tb_test +from tensorboard.plugins import base_plugin from tensorboard.plugins.core import core_plugin class TensorBoardTest(tb_test.TestCase): """Tests the TensorBoard program.""" + def testPlugins_pluginClass(self): + tb = program.TensorBoard(plugins=[core_plugin.CorePlugin]) + self.assertIsInstance(tb.plugin_loaders[0], base_plugin.BasicLoader) + self.assertIs(tb.plugin_loaders[0].plugin_class, core_plugin.CorePlugin) + + def testPlugins_pluginLoaderClass(self): + tb = program.TensorBoard(plugins=[core_plugin.CorePluginLoader]) + self.assertIsInstance(tb.plugin_loaders[0], core_plugin.CorePluginLoader) + + def testPlugins_pluginLoader(self): + loader = core_plugin.CorePluginLoader() + tb = program.TensorBoard(plugins=[loader]) + self.assertIs(tb.plugin_loaders[0], loader) + + def testPlugins_invalidType(self): + plugin_instance = core_plugin.CorePlugin(base_plugin.TBContext()) + with six.assertRaisesRegex(self, TypeError, 'CorePlugin'): + tb = program.TensorBoard(plugins=[plugin_instance]) + def testConfigure(self): - # Many useful flags are defined under the core plugin. - tb = program.TensorBoard(plugins=[core_plugin.CorePluginLoader()]) + tb = program.TensorBoard(plugins=[core_plugin.CorePluginLoader]) tb.configure(logdir='foo') - self.assertStartsWith(tb.flags.logdir, 'foo') + self.assertEqual(tb.flags.logdir, 'foo') + def testConfigure_unknownFlag(self): + tb = program.TensorBoard(plugins=[core_plugin.CorePlugin]) with six.assertRaisesRegex(self, ValueError, 'Unknown TensorBoard flag'): tb.configure(foo='bar')