diff --git a/tensorboard/backend/application.py b/tensorboard/backend/application.py index a9aa08c280..bc352bc559 100644 --- a/tensorboard/backend/application.py +++ b/tensorboard/backend/application.py @@ -89,6 +89,8 @@ PLUGINS_LISTING_ROUTE = "/plugins_listing" PLUGIN_ENTRY_ROUTE = "/plugin_entry.html" +EXPERIMENTAL_PLUGINS_QUERY_PARAM = "experimentalPlugin" + # Slashes in a plugin name could throw the router for a loop. An empty # name would be confusing, too. To be safe, let's restrict the valid # names as follows. @@ -114,13 +116,19 @@ def _apply_tensor_size_guidance(sampling_hints): return tensor_size_guidance -def standard_tensorboard_wsgi(flags, plugin_loaders, assets_zip_provider): +def standard_tensorboard_wsgi( + flags, plugin_loaders, assets_zip_provider, experimental_plugins=None +): """Construct a TensorBoardWSGIApp with standard plugins and multiplexer. Args: flags: An argparse.Namespace containing TensorBoard CLI flags. plugin_loaders: A list of TBLoader instances. assets_zip_provider: See TBContext documentation for more information. + experimental_plugins: A list of plugin names that are only provided + experimentally. The corresponding plugins will only be activated for + a user if the user has specified the plugin with the expplugin query + parameter in the URL. Returns: The new TensorBoard WSGI application. @@ -177,8 +185,14 @@ def standard_tensorboard_wsgi(flags, plugin_loaders, assets_zip_provider): start_reloading_multiplexer( multiplexer, path_to_run, reload_interval, flags.reload_task ) + return TensorBoardWSGIApp( - flags, plugin_loaders, data_provider, assets_zip_provider, multiplexer + flags, + plugin_loaders, + data_provider, + assets_zip_provider, + multiplexer, + experimental_plugins, ) @@ -205,6 +219,7 @@ def TensorBoardWSGIApp( data_provider=None, assets_zip_provider=None, deprecated_multiplexer=None, + experimental_plugins=None, ): """Constructs a TensorBoard WSGI app from plugins and data providers. @@ -218,6 +233,10 @@ def TensorBoardWSGIApp( deprecated_multiplexer: Optional `plugin_event_multiplexer.EventMultiplexer` to use for any plugins not yet enabled for the DataProvider API. Required if the data_provider argument is not passed. + experimental_plugins: A list of plugin names that are only provided + experimentally. The corresponding plugins will only be activated for + a user if the user has specified the plugin with the expplugin query + parameter in the URL. Returns: A WSGI application that implements the TensorBoard backend. @@ -253,13 +272,21 @@ def TensorBoardWSGIApp( continue tbplugins.append(plugin) plugin_name_to_instance[plugin.plugin_name] = plugin - return TensorBoardWSGI(tbplugins, flags.path_prefix, data_provider) + return TensorBoardWSGI( + tbplugins, flags.path_prefix, data_provider, experimental_plugins + ) class TensorBoardWSGI(object): """The TensorBoard WSGI app that delegates to a set of TBPlugin.""" - def __init__(self, plugins, path_prefix="", data_provider=None): + def __init__( + self, + plugins, + path_prefix="", + data_provider=None, + experimental_plugins=None, + ): """Constructs TensorBoardWSGI instance. Args: @@ -268,6 +295,10 @@ def __init__(self, plugins, path_prefix="", data_provider=None): data_provider: `tensorboard.data.provider.DataProvider` or `None`; if present, will inform the "active" state of `/plugins_listing`. + experimental_plugins: A list of plugin names that are only provided + experimentally. The corresponding plugins will only be activated for + a user if the user has specified the plugin with the expplugin query + parameter in the URL. Returns: A WSGI application for the set of all TBPlugin instances. @@ -285,6 +316,9 @@ def __init__(self, plugins, path_prefix="", data_provider=None): self._plugins = plugins self._path_prefix = path_prefix self._data_provider = data_provider + self._experimental_plugins = frozenset( + experimental_plugins if experimental_plugins is not None else [] + ) if self._path_prefix.endswith("/"): # Should have been fixed by `fix_flags`. raise ValueError( @@ -467,7 +501,13 @@ def _serve_plugins_listing(self, request): if self._data_provider is not None else frozenset() ) + plugins_to_skip = self._experimental_plugins - frozenset( + request.args.getlist(EXPERIMENTAL_PLUGINS_QUERY_PARAM) + ) for plugin in self._plugins: + if plugin.plugin_name in plugins_to_skip: + continue + if ( type(plugin) is core_plugin.CorePlugin ): # pylint: disable=unidiomatic-typecheck diff --git a/tensorboard/backend/application_test.py b/tensorboard/backend/application_test.py index b6da5ff591..346454f9e2 100644 --- a/tensorboard/backend/application_test.py +++ b/tensorboard/backend/application_test.py @@ -373,6 +373,64 @@ def fake_is_active(self): self.assertEqual(parsed_object["foo"]["enabled"], False) self.assertEqual(parsed_object["baz"]["enabled"], True) + def testPluginsListingWithExperimentalPlugin(self): + plugins = [ + FakePlugin(plugin_name="bar"), + FakePlugin(plugin_name="foo"), + FakePlugin(plugin_name="bazz"), + ] + app = application.TensorBoardWSGI(plugins, experimental_plugins=["foo"]) + self._install_server(app) + + plugins_without_flag = self._get_json("/data/plugins_listing") + self.assertIsNotNone(plugins_without_flag.get("bar")) + self.assertIsNone(plugins_without_flag.get("foo")) + self.assertIsNotNone(plugins_without_flag.get("bazz")) + + plugins_with_flag = self._get_json( + "/data/plugins_listing?experimentalPlugin=foo" + ) + self.assertIsNotNone(plugins_with_flag.get("bar")) + self.assertIsNotNone(plugins_with_flag.get("foo")) + self.assertIsNotNone(plugins_with_flag.get("bazz")) + + plugins_with_useless_flag = self._get_json( + "/data/plugins_listing?experimentalPlugin=bar" + ) + self.assertIsNotNone(plugins_with_useless_flag.get("bar")) + self.assertIsNone(plugins_with_useless_flag.get("foo")) + self.assertIsNotNone(plugins_with_useless_flag.get("bazz")) + + def testPluginsListingWithMultipleExperimentalPlugins(self): + plugins = [ + FakePlugin(plugin_name="bar"), + FakePlugin(plugin_name="foo"), + FakePlugin(plugin_name="bazz"), + ] + app = application.TensorBoardWSGI( + plugins, experimental_plugins=["bar", "bazz"] + ) + self._install_server(app) + + plugins_without_flag = self._get_json("/data/plugins_listing") + self.assertIsNone(plugins_without_flag.get("bar")) + self.assertIsNotNone(plugins_without_flag.get("foo")) + self.assertIsNone(plugins_without_flag.get("bazz")) + + plugins_with_one_flag = self._get_json( + "/data/plugins_listing?experimentalPlugin=bar" + ) + self.assertIsNotNone(plugins_with_one_flag.get("bar")) + self.assertIsNotNone(plugins_with_one_flag.get("foo")) + self.assertIsNone(plugins_with_one_flag.get("bazz")) + + plugins_with_multiple_flags = self._get_json( + "/data/plugins_listing?experimentalPlugin=bar&experimentalPlugin=bazz" + ) + self.assertIsNotNone(plugins_with_multiple_flags.get("bar")) + self.assertIsNotNone(plugins_with_multiple_flags.get("foo")) + self.assertIsNotNone(plugins_with_multiple_flags.get("bazz")) + def testPluginEntry(self): """Test the data/plugin_entry.html endpoint.""" response = self.server.get("/data/plugin_entry.html?name=baz") diff --git a/tensorboard/components/tf_backend/router.ts b/tensorboard/components/tf_backend/router.ts index e9ba029ca9..5c8b283a4e 100644 --- a/tensorboard/components/tf_backend/router.ts +++ b/tensorboard/components/tf_backend/router.ts @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ namespace tf_backend { + const EXPERIMENTAL_PLUGINS_QUERY_PARAM = 'experimentalPlugin'; + export interface Router { environment: () => string; experiments: () => string; @@ -34,7 +36,10 @@ namespace tf_backend { * * @param dataDir {string=} The base prefix for data endpoints. */ - export function createRouter(dataDir = 'data'): Router { + export function createRouter( + dataDir = 'data', + urlSearchParams = new URLSearchParams(window.location.search) + ): Router { if (dataDir[dataDir.length - 1] === '/') { dataDir = dataDir.slice(0, dataDir.length - 1); } @@ -52,7 +57,16 @@ namespace tf_backend { params ); }, - pluginsListing: () => createDataPath(dataDir, '/plugins_listing'), + pluginsListing: () => + createDataPath( + dataDir, + '/plugins_listing', + createSearchParam({ + [EXPERIMENTAL_PLUGINS_QUERY_PARAM]: urlSearchParams.getAll( + EXPERIMENTAL_PLUGINS_QUERY_PARAM + ), + }) + ), runs: () => createDataPath(dataDir, '/runs'), runsForExperiment: (id) => { return createDataPath( diff --git a/tensorboard/components/tf_backend/test/backendTests.ts b/tensorboard/components/tf_backend/test/backendTests.ts index 4225e37540..d0e64e500c 100644 --- a/tensorboard/components/tf_backend/test/backendTests.ts +++ b/tensorboard/components/tf_backend/test/backendTests.ts @@ -172,10 +172,6 @@ namespace tf_backend { }); }); - it('returns correct value for #pluginsListing', () => { - assert.equal(router.pluginsListing(), 'data/plugins_listing'); - }); - it('returns correct value for #runs', () => { assert.equal(router.runs(), 'data/runs'); }); @@ -187,6 +183,30 @@ namespace tf_backend { ); }); }); + + describe('#pluginsListing', () => { + it('returns /plugins_listing with no query params', () => { + const router = createRouter('data', new URLSearchParams('')); + assert.equal(router.pluginsListing(), 'data/plugins_listing'); + }); + + it('returns /plugins_listing with experimentalPlugin query params', () => { + const router = createRouter( + 'data', + new URLSearchParams( + 'experimentalPlugin=plugin1&' + + 'to_ignore=ignoreme&' + + 'experimentalPlugin=plugin2' + ) + ); + assert.equal( + router.pluginsListing(), + 'data/plugins_listing?' + + 'experimentalPlugin=plugin1&' + + 'experimentalPlugin=plugin2' + ); + }); + }); }); }); } // namespace tf_backend