diff --git a/tensorboard/BUILD b/tensorboard/BUILD index a09a7a5ac2..e89b1206d4 100644 --- a/tensorboard/BUILD +++ b/tensorboard/BUILD @@ -419,6 +419,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ + "//tensorboard/backend:experiment_id", "@org_mozilla_bleach", "@org_pythonhosted_markdown", "@org_pythonhosted_six", @@ -434,6 +435,7 @@ py_test( deps = [ ":plugin_util", ":test", + "//tensorboard/backend:experiment_id", "@org_pythonhosted_six", ], ) diff --git a/tensorboard/backend/BUILD b/tensorboard/backend/BUILD index 0bcc8bc12f..fb5eaa83fb 100644 --- a/tensorboard/backend/BUILD +++ b/tensorboard/backend/BUILD @@ -63,6 +63,7 @@ py_library( visibility = ["//visibility:public"], deps = [ ":empty_path_redirect", + ":experiment_id", ":http_util", ":path_prefix", "//tensorboard:errors", @@ -91,6 +92,7 @@ py_test( deps = [ ":application", "//tensorboard:errors", + "//tensorboard:plugin_util", "//tensorboard:test", "//tensorboard/backend/event_processing:event_multiplexer", "//tensorboard/plugins:base_plugin", @@ -118,6 +120,23 @@ py_test( ], ) +py_library( + name = "experiment_id", + srcs = ["experiment_id.py"], + srcs_version = "PY2AND3", +) + +py_test( + name = "experiment_id_test", + srcs = ["experiment_id_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":experiment_id", + "//tensorboard:test", + "@org_pocoo_werkzeug", + ], +) + py_library( name = "path_prefix", srcs = ["path_prefix.py"], diff --git a/tensorboard/backend/application.py b/tensorboard/backend/application.py index 0c35a780d8..c9c17d9b41 100644 --- a/tensorboard/backend/application.py +++ b/tensorboard/backend/application.py @@ -40,6 +40,7 @@ from tensorboard import errors from tensorboard.backend import empty_path_redirect +from tensorboard.backend import experiment_id from tensorboard.backend import http_util from tensorboard.backend import path_prefix from tensorboard.backend.event_processing import db_import_multiplexer @@ -331,6 +332,7 @@ def _create_wsgi_app(self): """Apply middleware to create the final WSGI app.""" app = self._route_request app = empty_path_redirect.EmptyPathRedirectMiddleware(app) + app = experiment_id.ExperimentIdMiddleware(app) app = path_prefix.PathPrefixMiddleware(app, self._path_prefix) app = _handling_errors(app) return app @@ -429,6 +431,7 @@ def _route_request(self, environ, start_response): environ: See WSGI spec (PEP 3333). start_response: See WSGI spec (PEP 3333). """ + request = wrappers.Request(environ) parsed_url = urlparse.urlparse(request.path) clean_path = _clean_path(parsed_url.path) diff --git a/tensorboard/backend/application_test.py b/tensorboard/backend/application_test.py index d1a2b740cf..e34d96603a 100644 --- a/tensorboard/backend/application_test.py +++ b/tensorboard/backend/application_test.py @@ -41,6 +41,7 @@ from werkzeug import wrappers from tensorboard import errors +from tensorboard import plugin_util from tensorboard import test as tb_test from tensorboard.backend import application from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer # pylint: disable=line-too-long @@ -621,6 +622,11 @@ def setUp(self): '/wildcard/special/exact': self._foo_handler, }, construction_callback=self._construction_callback), + FakePluginLoader( + plugin_name='whoami', + routes_mapping={ + '/eid': self._eid_handler, + }), ], dummy_assets_zip_provider) @@ -641,6 +647,12 @@ def _foo_handler(self, request): def _bar_handler(self): pass + @wrappers.Request.application + def _eid_handler(self, request): + eid = plugin_util.experiment_id(request.environ) + body = json.dumps({'experiment_id': eid}) + return wrappers.Response(body, 200, content_type='application/json') + @wrappers.Request.application def _wildcard_handler(self, request): if request.path == '/data/plugin/bar/wildcard/ok': @@ -664,11 +676,12 @@ def testPluginsAdded(self): self.assertLessEqual(expected_routes, frozenset(self.app.exact_routes)) def testNameToPluginMapping(self): - # The mapping from plugin name to instance should include both plugins. + # The mapping from plugin name to instance should include all plugins. mapping = self.context.plugin_name_to_instance - self.assertItemsEqual(['foo', 'bar'], list(mapping.keys())) + self.assertItemsEqual(['foo', 'bar', 'whoami'], list(mapping.keys())) self.assertEqual('foo', mapping['foo'].plugin_name) self.assertEqual('bar', mapping['bar'].plugin_name) + self.assertEqual('whoami', mapping['whoami'].plugin_name) def testNormalRoute(self): self._test_route('/data/plugin/foo/foo_route', 200) @@ -679,6 +692,18 @@ def testNormalRouteIsNotWildcard(self): def testMissingRoute(self): self._test_route('/data/plugin/foo/bogus', 404) + def testExperimentIdIntegration_withNoExperimentId(self): + response = self.server.get('/data/plugin/whoami/eid') + self.assertEqual(response.status_code, 200) + data = json.loads(response.get_data().decode('utf-8')) + self.assertEqual(data, {'experiment_id': ''}) + + def testExperimentIdIntegration_withExperimentId(self): + response = self.server.get('/experiment/123/data/plugin/whoami/eid') + self.assertEqual(response.status_code, 200) + data = json.loads(response.get_data().decode('utf-8')) + self.assertEqual(data, {'experiment_id': '123'}) + def testEmptyRoute(self): self._test_route('', 301) diff --git a/tensorboard/backend/experiment_id.py b/tensorboard/backend/experiment_id.py new file mode 100644 index 0000000000..7c427abe32 --- /dev/null +++ b/tensorboard/backend/experiment_id.py @@ -0,0 +1,71 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Application-level experiment ID support.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + + +# Value of the first path component that signals that the second path +# component represents an experiment ID. +_EXPERIMENT_PATH_COMPONENT = "experiment" + +# Key into the WSGI environment used for the experiment ID. +WSGI_ENVIRON_KEY = "HTTP_TENSORBOARD_EXPERIMENT_ID" + + +class ExperimentIdMiddleware(object): + """WSGI middleware extracting experiment IDs from URL to environment. + + Any request whose path matches `/experiment/SOME_EID[/...]` will have + its first two path components stripped, and its experiment ID stored + onto the WSGI environment with key taken from the `WSGI_ENVIRON_KEY` + constant. All other requests will have paths unchanged and the + experiment ID set to the empty string. + + Instances of this class are WSGI applications (see PEP 3333). + """ + + def __init__(self, application): + """Initializes an `ExperimentIdMiddleware`. + + Args: + application: The WSGI application to wrap (see PEP 3333). + """ + self._application = application + # Regular expression that matches the whole `/experiment/EID` prefix + # (without any trailing slash) and captures the experiment ID. + self._pat = re.compile( + r"/%s/([^/]*)" % re.escape(_EXPERIMENT_PATH_COMPONENT) + ) + + def __call__(self, environ, start_response): + path = environ.get("PATH_INFO", "") + m = self._pat.match(path) + if m: + eid = m.group(1) + new_path = path[m.end(0):] + root = m.group(0) + else: + eid = "" + new_path = path + root = "" + environ[WSGI_ENVIRON_KEY] = eid + environ["PATH_INFO"] = new_path + environ["SCRIPT_NAME"] = environ.get("SCRIPT_NAME", "") + root + return self._application(environ, start_response) diff --git a/tensorboard/backend/experiment_id_test.py b/tensorboard/backend/experiment_id_test.py new file mode 100644 index 0000000000..4e8a85a129 --- /dev/null +++ b/tensorboard/backend/experiment_id_test.py @@ -0,0 +1,80 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `tensorboard.backend.experiment_id`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +import werkzeug + +from tensorboard import test as tb_test +from tensorboard.backend import experiment_id + + +class ExperimentIdMiddlewareTest(tb_test.TestCase): + """Tests for `ExperimentIdMiddleware`.""" + + def setUp(self): + super(ExperimentIdMiddlewareTest, self).setUp() + self.app = experiment_id.ExperimentIdMiddleware(self._echo_app) + self.server = werkzeug.test.Client(self.app, werkzeug.BaseResponse) + + def _echo_app(self, environ, start_response): + # https://www.python.org/dev/peps/pep-0333/#environ-variables + data = { + "eid": environ[experiment_id.WSGI_ENVIRON_KEY], + "path": environ.get("PATH_INFO", ""), + "script": environ.get("SCRIPT_NAME", ""), + } + body = json.dumps(data, sort_keys=True) + start_response("200 OK", [("Content-Type", "application/json")]) + return [body] + + def _assert_ok(self, response, eid, path, script): + self.assertEqual(response.status_code, 200) + actual = json.loads(response.get_data()) + expected = dict(eid=eid, path=path, script=script) + self.assertEqual(actual, expected) + + def test_no_experiment_empty_path(self): + response = self.server.get("") + self._assert_ok(response, eid="", path="", script="") + + def test_no_experiment_root_path(self): + response = self.server.get("/") + self._assert_ok(response, eid="", path="/", script="") + + def test_no_experiment_sub_path(self): + response = self.server.get("/x/y") + self._assert_ok(response, eid="", path="/x/y", script="") + + def test_with_experiment_empty_path(self): + response = self.server.get("/experiment/123") + self._assert_ok(response, eid="123", path="", script="/experiment/123") + + def test_with_experiment_root_path(self): + response = self.server.get("/experiment/123/") + self._assert_ok(response, eid="123", path="/", script="/experiment/123") + + def test_with_experiment_sub_path(self): + response = self.server.get("/experiment/123/x/y") + self._assert_ok(response, eid="123", path="/x/y", script="/experiment/123") + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/plugin_util.py b/tensorboard/plugin_util.py index d1dfef5ccf..c259f001d2 100644 --- a/tensorboard/plugin_util.py +++ b/tensorboard/plugin_util.py @@ -24,6 +24,9 @@ import markdown import six +from tensorboard.backend import experiment_id as _experiment_id + + _ALLOWED_ATTRIBUTES = { 'a': ['href', 'title'], 'img': ['src', 'title', 'alt'], @@ -85,3 +88,20 @@ def markdown_to_safe_html(markdown_string): string_sanitized = bleach.clean( string_html, tags=_ALLOWED_TAGS, attributes=_ALLOWED_ATTRIBUTES) return warning + string_sanitized + + +def experiment_id(environ): + """Determine the experiment ID associated with a WSGI request. + + Each request to TensorBoard has an associated experiment ID, which is + always a string and may be empty. This experiment ID should be passed + to data providers. + + Args: + environ: A WSGI environment `dict`. For a Werkzeug request, this is + `request.environ`. + + Returns: + A experiment ID, as a possibly-empty `str`. + """ + return environ.get(_experiment_id.WSGI_ENVIRON_KEY, "") diff --git a/tensorboard/plugin_util_test.py b/tensorboard/plugin_util_test.py index 320c78348f..ae029ccb67 100644 --- a/tensorboard/plugin_util_test.py +++ b/tensorboard/plugin_util_test.py @@ -22,6 +22,7 @@ from tensorboard import plugin_util from tensorboard import test as tb_test +from tensorboard.backend import experiment_id class MarkdownToSafeHTMLTest(tb_test.TestCase): @@ -123,5 +124,19 @@ def test_null_bytes_stripped_before_markdown_processing(self): '
un_der_score
') +class ExperimentIdTest(tb_test.TestCase): + """Tests for `plugin_util.experiment_id`.""" + + def test_default(self): + # This shouldn't happen; the `ExperimentIdMiddleware` always set an + # experiment ID. In case something goes wrong, degrade gracefully. + environ = {} + self.assertEqual(plugin_util.experiment_id(environ), "") + + def test_present(self): + environ = {experiment_id.WSGI_ENVIRON_KEY: "123"} + self.assertEqual(plugin_util.experiment_id(environ), "123") + + if __name__ == '__main__': tb_test.main()