diff --git a/.travis.yml b/.travis.yml index 2f8406347f..af1297cf03 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ dist: trusty language: python python: - - "2.7" - "3.6" + - "2.7" branches: only: @@ -46,12 +46,14 @@ before_install: install: - elapsed "install" + - "PY3=\"$(python -c 'if __import__(\"sys\").version_info[0] > 2: print(1)')\"" # Older versions of Pip sometimes resolve specifiers like `tf-nightly` # to versions other than the most recent(!). - pip install -U pip # Lint check deps. - pip install flake8==3.7.8 - pip install yamllint==1.17.0 + - if [ -n "${PY3}" ]; then pip install black==19.10b0; fi # TensorBoard deps. - pip install futures==3.1.1 - pip install grpcio==1.24.3 @@ -87,6 +89,8 @@ before_script: - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # Lint frontend code - yarn lint + # Lint backend code + - if [ -n "${PY3}" ]; then black --check .; fi # Lint .yaml docs files. Use '# yamllint disable-line rule:foo' to suppress. - yamllint -c docs/.yamllint docs docs/.yamllint # Make sure that IPython notebooks have valid Markdown. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..c73ad6175e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[tool.black] +line-length = 80 +target-version = ["py27", "py36", "py37", "py38"] diff --git a/tensorboard/__init__.py b/tensorboard/__init__.py index 6d172de7f1..510de45b0b 100644 --- a/tensorboard/__init__.py +++ b/tensorboard/__init__.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TensorBoard is a webapp for understanding TensorFlow runs and graphs. -""" +"""TensorBoard is a webapp for understanding TensorFlow runs and graphs.""" from __future__ import absolute_import from __future__ import division @@ -67,33 +66,36 @@ # additional discussion. 
-@lazy.lazy_load('tensorboard.notebook') +@lazy.lazy_load("tensorboard.notebook") def notebook(): - import importlib - return importlib.import_module('tensorboard.notebook') + import importlib + return importlib.import_module("tensorboard.notebook") -@lazy.lazy_load('tensorboard.program') + +@lazy.lazy_load("tensorboard.program") def program(): - import importlib - return importlib.import_module('tensorboard.program') + import importlib + + return importlib.import_module("tensorboard.program") -@lazy.lazy_load('tensorboard.summary') +@lazy.lazy_load("tensorboard.summary") def summary(): - import importlib - return importlib.import_module('tensorboard.summary') + import importlib + + return importlib.import_module("tensorboard.summary") def load_ipython_extension(ipython): - """IPython API entry point. + """IPython API entry point. - Only intended to be called by the IPython runtime. + Only intended to be called by the IPython runtime. - See: - https://ipython.readthedocs.io/en/stable/config/extensions/index.html - """ - notebook._load_ipython_extension(ipython) + See: + https://ipython.readthedocs.io/en/stable/config/extensions/index.html + """ + notebook._load_ipython_extension(ipython) __version__ = version.VERSION diff --git a/tensorboard/__main__.py b/tensorboard/__main__.py index 5ae362013a..51022d7afa 100644 --- a/tensorboard/__main__.py +++ b/tensorboard/__main__.py @@ -24,5 +24,5 @@ del _main -if __name__ == '__main__': - run_main() +if __name__ == "__main__": + run_main() diff --git a/tensorboard/backend/application.py b/tensorboard/backend/application.py index 5dd0f132dc..6ccde8d3f8 100644 --- a/tensorboard/backend/application.py +++ b/tensorboard/backend/application.py @@ -37,7 +37,9 @@ import time import six -from six.moves.urllib import parse as urlparse # pylint: disable=wrong-import-order +from six.moves.urllib import ( + parse as urlparse, +) # pylint: disable=wrong-import-order from werkzeug import wrappers @@ -48,9 +50,15 @@ from 
tensorboard.backend import path_prefix from tensorboard.backend import security_validator from tensorboard.backend.event_processing import db_import_multiplexer -from tensorboard.backend.event_processing import data_provider as event_data_provider -from tensorboard.backend.event_processing import plugin_event_accumulator as event_accumulator -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + data_provider as event_data_provider, +) +from tensorboard.backend.event_processing import ( + plugin_event_accumulator as event_accumulator, +) +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.audio import metadata as audio_metadata from tensorboard.plugins.core import core_plugin @@ -75,110 +83,115 @@ pr_curve_metadata.PLUGIN_NAME: 100, } -DATA_PREFIX = '/data' -PLUGIN_PREFIX = '/plugin' -PLUGINS_LISTING_ROUTE = '/plugins_listing' -PLUGIN_ENTRY_ROUTE = '/plugin_entry.html' +DATA_PREFIX = "/data" +PLUGIN_PREFIX = "/plugin" +PLUGINS_LISTING_ROUTE = "/plugins_listing" +PLUGIN_ENTRY_ROUTE = "/plugin_entry.html" # Slashes in a plugin name could throw the router for a loop. An empty # name would be confusing, too. To be safe, let's restrict the valid # names as follows. 
-_VALID_PLUGIN_RE = re.compile(r'^[A-Za-z0-9_.-]+$') +_VALID_PLUGIN_RE = re.compile(r"^[A-Za-z0-9_.-]+$") logger = tb_logging.get_logger() def tensor_size_guidance_from_flags(flags): - """Apply user per-summary size guidance overrides.""" + """Apply user per-summary size guidance overrides.""" - tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE) - if not flags or not flags.samples_per_plugin: - return tensor_size_guidance + tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE) + if not flags or not flags.samples_per_plugin: + return tensor_size_guidance - for token in flags.samples_per_plugin.split(','): - k, v = token.strip().split('=') - tensor_size_guidance[k] = int(v) + for token in flags.samples_per_plugin.split(","): + k, v = token.strip().split("=") + tensor_size_guidance[k] = int(v) - return tensor_size_guidance + return tensor_size_guidance def standard_tensorboard_wsgi(flags, plugin_loaders, assets_zip_provider): - """Construct a TensorBoardWSGIApp with standard plugins and multiplexer. - - Args: - flags: An argparse.Namespace containing TensorBoard CLI flags. - plugin_loaders: A list of TBLoader instances. - assets_zip_provider: See TBContext documentation for more information. - - Returns: - The new TensorBoard WSGI application. - - :type plugin_loaders: list[base_plugin.TBLoader] - :rtype: TensorBoardWSGI - """ - data_provider = None - multiplexer = None - reload_interval = flags.reload_interval - if flags.db_import: - # DB import mode. - db_uri = flags.db - # Create a temporary DB file if we weren't given one. 
- if not db_uri: - tmpdir = tempfile.mkdtemp(prefix='tbimport') - atexit.register(shutil.rmtree, tmpdir) - db_uri = 'sqlite:%s/tmp.sqlite' % tmpdir - db_connection_provider = create_sqlite_connection_provider(db_uri) - logger.info('Importing logdir into DB at %s', db_uri) - multiplexer = db_import_multiplexer.DbImportMultiplexer( - db_uri=db_uri, - db_connection_provider=db_connection_provider, - purge_orphaned_data=flags.purge_orphaned_data, - max_reload_threads=flags.max_reload_threads) - elif flags.db: - # DB read-only mode, never load event logs. - reload_interval = -1 - db_connection_provider = create_sqlite_connection_provider(flags.db) - multiplexer = _DbModeMultiplexer(flags.db, db_connection_provider) - else: - # Regular logdir loading mode. - multiplexer = event_multiplexer.EventMultiplexer( - size_guidance=DEFAULT_SIZE_GUIDANCE, - tensor_size_guidance=tensor_size_guidance_from_flags(flags), - purge_orphaned_data=flags.purge_orphaned_data, - max_reload_threads=flags.max_reload_threads, - event_file_active_filter=_get_event_file_active_filter(flags)) - if flags.generic_data != 'false': - data_provider = event_data_provider.MultiplexerDataProvider( - multiplexer, flags.logdir or flags.logdir_spec - ) - - if reload_interval >= 0: - # We either reload the multiplexer once when TensorBoard starts up, or we - # continuously reload the multiplexer. - if flags.logdir: - path_to_run = {os.path.expanduser(flags.logdir): None} + """Construct a TensorBoardWSGIApp with standard plugins and multiplexer. + + Args: + flags: An argparse.Namespace containing TensorBoard CLI flags. + plugin_loaders: A list of TBLoader instances. + assets_zip_provider: See TBContext documentation for more information. + + Returns: + The new TensorBoard WSGI application. + + :type plugin_loaders: list[base_plugin.TBLoader] + :rtype: TensorBoardWSGI + """ + data_provider = None + multiplexer = None + reload_interval = flags.reload_interval + if flags.db_import: + # DB import mode. 
+ db_uri = flags.db + # Create a temporary DB file if we weren't given one. + if not db_uri: + tmpdir = tempfile.mkdtemp(prefix="tbimport") + atexit.register(shutil.rmtree, tmpdir) + db_uri = "sqlite:%s/tmp.sqlite" % tmpdir + db_connection_provider = create_sqlite_connection_provider(db_uri) + logger.info("Importing logdir into DB at %s", db_uri) + multiplexer = db_import_multiplexer.DbImportMultiplexer( + db_uri=db_uri, + db_connection_provider=db_connection_provider, + purge_orphaned_data=flags.purge_orphaned_data, + max_reload_threads=flags.max_reload_threads, + ) + elif flags.db: + # DB read-only mode, never load event logs. + reload_interval = -1 + db_connection_provider = create_sqlite_connection_provider(flags.db) + multiplexer = _DbModeMultiplexer(flags.db, db_connection_provider) else: - path_to_run = parse_event_files_spec(flags.logdir_spec) - start_reloading_multiplexer( - multiplexer, path_to_run, reload_interval, flags.reload_task) - return TensorBoardWSGIApp( - flags, plugin_loaders, data_provider, assets_zip_provider, multiplexer) + # Regular logdir loading mode. + multiplexer = event_multiplexer.EventMultiplexer( + size_guidance=DEFAULT_SIZE_GUIDANCE, + tensor_size_guidance=tensor_size_guidance_from_flags(flags), + purge_orphaned_data=flags.purge_orphaned_data, + max_reload_threads=flags.max_reload_threads, + event_file_active_filter=_get_event_file_active_filter(flags), + ) + if flags.generic_data != "false": + data_provider = event_data_provider.MultiplexerDataProvider( + multiplexer, flags.logdir or flags.logdir_spec + ) + + if reload_interval >= 0: + # We either reload the multiplexer once when TensorBoard starts up, or we + # continuously reload the multiplexer. 
+ if flags.logdir: + path_to_run = {os.path.expanduser(flags.logdir): None} + else: + path_to_run = parse_event_files_spec(flags.logdir_spec) + start_reloading_multiplexer( + multiplexer, path_to_run, reload_interval, flags.reload_task + ) + return TensorBoardWSGIApp( + flags, plugin_loaders, data_provider, assets_zip_provider, multiplexer + ) def _handling_errors(wsgi_app): - def wrapper(*args): - (environ, start_response) = (args[-2], args[-1]) - try: - return wsgi_app(*args) - except errors.PublicError as e: - request = wrappers.Request(environ) - error_app = http_util.Respond( - request, str(e), "text/plain", code=e.http_code - ) - return error_app(environ, start_response) - # Let other exceptions be handled by the server, as an opaque - # internal server error. - return wrapper + def wrapper(*args): + (environ, start_response) = (args[-2], args[-1]) + try: + return wsgi_app(*args) + except errors.PublicError as e: + request = wrappers.Request(environ) + error_app = http_util.Respond( + request, str(e), "text/plain", code=e.http_code + ) + return error_app(environ, start_response) + # Let other exceptions be handled by the server, as an opaque + # internal server error. + + return wrapper def TensorBoardWSGIApp( @@ -186,568 +199,624 @@ def TensorBoardWSGIApp( plugins, data_provider=None, assets_zip_provider=None, - deprecated_multiplexer=None): - """Constructs a TensorBoard WSGI app from plugins and data providers. - - Args: - flags: An argparse.Namespace containing TensorBoard CLI flags. - plugins: A list of plugin loader instances. - assets_zip_provider: See TBContext documentation for more information. - data_provider: Instance of `tensorboard.data.provider.DataProvider`. May - be `None` if `flags.generic_data` is set to `"false"` in which case - `deprecated_multiplexer` must be passed instead. - deprecated_multiplexer: Optional `plugin_event_multiplexer.EventMultiplexer` - to use for any plugins not yet enabled for the DataProvider API. 
- Required if the data_provider argument is not passed. - - Returns: - A WSGI application that implements the TensorBoard backend. - - :type plugins: list[base_plugin.TBLoader] - """ - db_uri = None - db_connection_provider = None - if isinstance( - deprecated_multiplexer, - (db_import_multiplexer.DbImportMultiplexer, _DbModeMultiplexer)): - db_uri = deprecated_multiplexer.db_uri - db_connection_provider = deprecated_multiplexer.db_connection_provider - plugin_name_to_instance = {} - context = base_plugin.TBContext( - data_provider=data_provider, - db_connection_provider=db_connection_provider, - db_uri=db_uri, - flags=flags, - logdir=flags.logdir, - multiplexer=deprecated_multiplexer, - assets_zip_provider=assets_zip_provider, - plugin_name_to_instance=plugin_name_to_instance, - window_title=flags.window_title) - tbplugins = [] - for plugin_spec in plugins: - loader = make_plugin_loader(plugin_spec) - plugin = loader.load(context) - if plugin is None: - continue - tbplugins.append(plugin) - plugin_name_to_instance[plugin.plugin_name] = plugin - return TensorBoardWSGI(tbplugins, flags.path_prefix) - - -class TensorBoardWSGI(object): - """The TensorBoard WSGI app that delegates to a set of TBPlugin.""" - - def __init__(self, plugins, path_prefix=''): - """Constructs TensorBoardWSGI instance. + deprecated_multiplexer=None, +): + """Constructs a TensorBoard WSGI app from plugins and data providers. Args: - plugins: A list of base_plugin.TBPlugin subclass instances. flags: An argparse.Namespace containing TensorBoard CLI flags. + plugins: A list of plugin loader instances. + assets_zip_provider: See TBContext documentation for more information. + data_provider: Instance of `tensorboard.data.provider.DataProvider`. May + be `None` if `flags.generic_data` is set to `"false"` in which case + `deprecated_multiplexer` must be passed instead. 
+ deprecated_multiplexer: Optional `plugin_event_multiplexer.EventMultiplexer` + to use for any plugins not yet enabled for the DataProvider API. + Required if the data_provider argument is not passed. Returns: - A WSGI application for the set of all TBPlugin instances. + A WSGI application that implements the TensorBoard backend. - Raises: - ValueError: If some plugin has no plugin_name - ValueError: If some plugin has an invalid plugin_name (plugin - names must only contain [A-Za-z0-9_.-]) - ValueError: If two plugins have the same plugin_name - ValueError: If some plugin handles a route that does not start - with a slash - - :type plugins: list[base_plugin.TBPlugin] + :type plugins: list[base_plugin.TBLoader] """ - self._plugins = plugins - self._path_prefix = path_prefix - if self._path_prefix.endswith('/'): - # Should have been fixed by `fix_flags`. - raise ValueError('Trailing slash in path prefix: %r' % self._path_prefix) - - self.exact_routes = { - # TODO(@chihuahua): Delete this RPC once we have skylark rules that - # obviate the need for the frontend to determine which plugins are - # active. - DATA_PREFIX + PLUGINS_LISTING_ROUTE: self._serve_plugins_listing, - DATA_PREFIX + PLUGIN_ENTRY_ROUTE: self._serve_plugin_entry, - } - unordered_prefix_routes = {} - - # Serve the routes from the registered plugins using their name as the route - # prefix. For example if plugin z has two routes /a and /b, they will be - # served as /data/plugin/z/a and /data/plugin/z/b. 
- plugin_names_encountered = set() - for plugin in self._plugins: - if plugin.plugin_name is None: - raise ValueError('Plugin %s has no plugin_name' % plugin) - if not _VALID_PLUGIN_RE.match(plugin.plugin_name): - raise ValueError('Plugin %s has invalid name %r' % (plugin, - plugin.plugin_name)) - if plugin.plugin_name in plugin_names_encountered: - raise ValueError('Duplicate plugins for name %s' % plugin.plugin_name) - plugin_names_encountered.add(plugin.plugin_name) - - try: - plugin_apps = plugin.get_plugin_apps() - except Exception as e: # pylint: disable=broad-except - if type(plugin) is core_plugin.CorePlugin: # pylint: disable=unidiomatic-typecheck - raise - logger.warn('Plugin %s failed. Exception: %s', - plugin.plugin_name, str(e)) - continue - for route, app in plugin_apps.items(): - if not route.startswith('/'): - raise ValueError('Plugin named %r handles invalid route %r: ' - 'route does not start with a slash' % - (plugin.plugin_name, route)) - if type(plugin) is core_plugin.CorePlugin: # pylint: disable=unidiomatic-typecheck - path = route - else: - path = ( - DATA_PREFIX + PLUGIN_PREFIX + '/' + plugin.plugin_name + route - ) - - if path.endswith('/*'): - # Note we remove the '*' but leave the slash in place. - path = path[:-1] - if '*' in path: - # note we re-add the removed * in the format string - raise ValueError('Plugin %r handles invalid route \'%s*\': Only ' - 'trailing wildcards are supported ' - '(i.e., `/.../*`)' % - (plugin.plugin_name, path)) - unordered_prefix_routes[path] = app - else: - if '*' in path: - raise ValueError('Plugin %r handles invalid route %r: Only ' - 'trailing wildcards are supported ' - '(i.e., `/.../*`)' % - (plugin.plugin_name, path)) - self.exact_routes[path] = app - - # Wildcard routes will be checked in the given order, so we sort them - # longest to shortest so that a more specific route will take precedence - # over a more general one (e.g., a catchall route `/*` should come last). 
- self.prefix_routes = collections.OrderedDict( - sorted( - six.iteritems(unordered_prefix_routes), - key=lambda x: len(x[0]), - reverse=True)) - - self._app = self._create_wsgi_app() - - def _create_wsgi_app(self): - """Apply middleware to create the final WSGI app.""" - app = self._route_request - app = empty_path_redirect.EmptyPathRedirectMiddleware(app) - app = experiment_id.ExperimentIdMiddleware(app) - app = path_prefix.PathPrefixMiddleware(app, self._path_prefix) - app = security_validator.SecurityValidatorMiddleware(app) - app = _handling_errors(app) - return app - - @wrappers.Request.application - def _serve_plugin_entry(self, request): - """Serves a HTML for iframed plugin entry point. + db_uri = None + db_connection_provider = None + if isinstance( + deprecated_multiplexer, + (db_import_multiplexer.DbImportMultiplexer, _DbModeMultiplexer), + ): + db_uri = deprecated_multiplexer.db_uri + db_connection_provider = deprecated_multiplexer.db_connection_provider + plugin_name_to_instance = {} + context = base_plugin.TBContext( + data_provider=data_provider, + db_connection_provider=db_connection_provider, + db_uri=db_uri, + flags=flags, + logdir=flags.logdir, + multiplexer=deprecated_multiplexer, + assets_zip_provider=assets_zip_provider, + plugin_name_to_instance=plugin_name_to_instance, + window_title=flags.window_title, + ) + tbplugins = [] + for plugin_spec in plugins: + loader = make_plugin_loader(plugin_spec) + plugin = loader.load(context) + if plugin is None: + continue + tbplugins.append(plugin) + plugin_name_to_instance[plugin.plugin_name] = plugin + return TensorBoardWSGI(tbplugins, flags.path_prefix) - Args: - request: The werkzeug.Request object. - Returns: - A werkzeug.Response object. 
- """ - name = request.args.get('name') - plugins = [ - plugin for plugin in self._plugins if plugin.plugin_name == name] - - if not plugins: - raise errors.NotFoundError(name) - - if len(plugins) > 1: - # Technically is not possible as plugin names are unique and is checked - # by the check on __init__. - reason = ( - 'Plugin invariant error: multiple plugins with name ' - '{name} found: {list}' - ).format(name=name, list=plugins) - raise AssertionError(reason) - - plugin = plugins[0] - module_path = plugin.frontend_metadata().es_module_path - if not module_path: - return http_util.Respond( - request, 'Plugin is not module loadable', 'text/plain', code=400) - - # non-self origin is blocked by CSP but this is a good invariant checking. - if urlparse.urlparse(module_path).netloc: - raise ValueError('Expected es_module_path to be non-absolute path') - - module_json = json.dumps('.' + module_path) - script_content = 'import({}).then((m) => void m.render());'.format( - module_json) - digest = hashlib.sha256(script_content.encode('utf-8')).digest() - script_sha = base64.b64encode(digest).decode('ascii') - - html = textwrap.dedent(""" +class TensorBoardWSGI(object): + """The TensorBoard WSGI app that delegates to a set of TBPlugin.""" + + def __init__(self, plugins, path_prefix=""): + """Constructs TensorBoardWSGI instance. + + Args: + plugins: A list of base_plugin.TBPlugin subclass instances. + flags: An argparse.Namespace containing TensorBoard CLI flags. + + Returns: + A WSGI application for the set of all TBPlugin instances. 
+ + Raises: + ValueError: If some plugin has no plugin_name + ValueError: If some plugin has an invalid plugin_name (plugin + names must only contain [A-Za-z0-9_.-]) + ValueError: If two plugins have the same plugin_name + ValueError: If some plugin handles a route that does not start + with a slash + + :type plugins: list[base_plugin.TBPlugin] + """ + self._plugins = plugins + self._path_prefix = path_prefix + if self._path_prefix.endswith("/"): + # Should have been fixed by `fix_flags`. + raise ValueError( + "Trailing slash in path prefix: %r" % self._path_prefix + ) + + self.exact_routes = { + # TODO(@chihuahua): Delete this RPC once we have skylark rules that + # obviate the need for the frontend to determine which plugins are + # active. + DATA_PREFIX + PLUGINS_LISTING_ROUTE: self._serve_plugins_listing, + DATA_PREFIX + PLUGIN_ENTRY_ROUTE: self._serve_plugin_entry, + } + unordered_prefix_routes = {} + + # Serve the routes from the registered plugins using their name as the route + # prefix. For example if plugin z has two routes /a and /b, they will be + # served as /data/plugin/z/a and /data/plugin/z/b. + plugin_names_encountered = set() + for plugin in self._plugins: + if plugin.plugin_name is None: + raise ValueError("Plugin %s has no plugin_name" % plugin) + if not _VALID_PLUGIN_RE.match(plugin.plugin_name): + raise ValueError( + "Plugin %s has invalid name %r" + % (plugin, plugin.plugin_name) + ) + if plugin.plugin_name in plugin_names_encountered: + raise ValueError( + "Duplicate plugins for name %s" % plugin.plugin_name + ) + plugin_names_encountered.add(plugin.plugin_name) + + try: + plugin_apps = plugin.get_plugin_apps() + except Exception as e: # pylint: disable=broad-except + if ( + type(plugin) is core_plugin.CorePlugin + ): # pylint: disable=unidiomatic-typecheck + raise + logger.warn( + "Plugin %s failed. 
Exception: %s", + plugin.plugin_name, + str(e), + ) + continue + for route, app in plugin_apps.items(): + if not route.startswith("/"): + raise ValueError( + "Plugin named %r handles invalid route %r: " + "route does not start with a slash" + % (plugin.plugin_name, route) + ) + if ( + type(plugin) is core_plugin.CorePlugin + ): # pylint: disable=unidiomatic-typecheck + path = route + else: + path = ( + DATA_PREFIX + + PLUGIN_PREFIX + + "/" + + plugin.plugin_name + + route + ) + + if path.endswith("/*"): + # Note we remove the '*' but leave the slash in place. + path = path[:-1] + if "*" in path: + # note we re-add the removed * in the format string + raise ValueError( + "Plugin %r handles invalid route '%s*': Only " + "trailing wildcards are supported " + "(i.e., `/.../*`)" % (plugin.plugin_name, path) + ) + unordered_prefix_routes[path] = app + else: + if "*" in path: + raise ValueError( + "Plugin %r handles invalid route %r: Only " + "trailing wildcards are supported " + "(i.e., `/.../*`)" % (plugin.plugin_name, path) + ) + self.exact_routes[path] = app + + # Wildcard routes will be checked in the given order, so we sort them + # longest to shortest so that a more specific route will take precedence + # over a more general one (e.g., a catchall route `/*` should come last). 
+ self.prefix_routes = collections.OrderedDict( + sorted( + six.iteritems(unordered_prefix_routes), + key=lambda x: len(x[0]), + reverse=True, + ) + ) + + self._app = self._create_wsgi_app() + + def _create_wsgi_app(self): + """Apply middleware to create the final WSGI app.""" + app = self._route_request + app = empty_path_redirect.EmptyPathRedirectMiddleware(app) + app = experiment_id.ExperimentIdMiddleware(app) + app = path_prefix.PathPrefixMiddleware(app, self._path_prefix) + app = security_validator.SecurityValidatorMiddleware(app) + app = _handling_errors(app) + return app + + @wrappers.Request.application + def _serve_plugin_entry(self, request): + """Serves a HTML for iframed plugin entry point. + + Args: + request: The werkzeug.Request object. + + Returns: + A werkzeug.Response object. + """ + name = request.args.get("name") + plugins = [ + plugin for plugin in self._plugins if plugin.plugin_name == name + ] + + if not plugins: + raise errors.NotFoundError(name) + + if len(plugins) > 1: + # Technically is not possible as plugin names are unique and is checked + # by the check on __init__. + reason = ( + "Plugin invariant error: multiple plugins with name " + "{name} found: {list}" + ).format(name=name, list=plugins) + raise AssertionError(reason) + + plugin = plugins[0] + module_path = plugin.frontend_metadata().es_module_path + if not module_path: + return http_util.Respond( + request, "Plugin is not module loadable", "text/plain", code=400 + ) + + # non-self origin is blocked by CSP but this is a good invariant checking. + if urlparse.urlparse(module_path).netloc: + raise ValueError("Expected es_module_path to be non-absolute path") + + module_json = json.dumps("." 
+ module_path) + script_content = "import({}).then((m) => void m.render());".format( + module_json + ) + digest = hashlib.sha256(script_content.encode("utf-8")).digest() + script_sha = base64.b64encode(digest).decode("ascii") + + html = textwrap.dedent( + """ - """).format(name=name, script_content=script_content) - return http_util.Respond( - request, - html, - 'text/html', - csp_scripts_sha256s=[script_sha], - ) + """ + ).format(name=name, script_content=script_content) + return http_util.Respond( + request, html, "text/html", csp_scripts_sha256s=[script_sha], + ) - @wrappers.Request.application - def _serve_plugins_listing(self, request): - """Serves an object mapping plugin name to whether it is enabled. + @wrappers.Request.application + def _serve_plugins_listing(self, request): + """Serves an object mapping plugin name to whether it is enabled. + + Args: + request: The werkzeug.Request object. + + Returns: + A werkzeug.Response object. + """ + response = collections.OrderedDict() + for plugin in self._plugins: + if ( + type(plugin) is core_plugin.CorePlugin + ): # pylint: disable=unidiomatic-typecheck + # This plugin's existence is a backend implementation detail. 
+ continue + start = time.time() + is_active = plugin.is_active() + elapsed = time.time() - start + logger.info( + "Plugin listing: is_active() for %s took %0.3f seconds", + plugin.plugin_name, + elapsed, + ) + + plugin_metadata = plugin.frontend_metadata() + output_metadata = { + "disable_reload": plugin_metadata.disable_reload, + "enabled": is_active, + # loading_mechanism set below + "remove_dom": plugin_metadata.remove_dom, + # tab_name set below + } + + if plugin_metadata.tab_name is not None: + output_metadata["tab_name"] = plugin_metadata.tab_name + else: + output_metadata["tab_name"] = plugin.plugin_name + + es_module_handler = plugin_metadata.es_module_path + element_name = plugin_metadata.element_name + is_ng_component = plugin_metadata.is_ng_component + if is_ng_component: + if element_name is not None: + raise ValueError( + "Plugin %r declared as both Angular built-in and legacy" + % plugin.plugin_name + ) + if es_module_handler is not None: + raise ValueError( + "Plugin %r declared as both Angular built-in and iframed" + % plugin.plugin_name + ) + loading_mechanism = { + "type": "NG_COMPONENT", + } + elif element_name is not None and es_module_handler is not None: + logger.error( + "Plugin %r declared as both legacy and iframed; skipping", + plugin.plugin_name, + ) + continue + elif element_name is not None and es_module_handler is None: + loading_mechanism = { + "type": "CUSTOM_ELEMENT", + "element_name": element_name, + } + elif element_name is None and es_module_handler is not None: + loading_mechanism = { + "type": "IFRAME", + "module_path": "".join( + [ + request.script_root, + DATA_PREFIX, + PLUGIN_PREFIX, + "/", + plugin.plugin_name, + es_module_handler, + ] + ), + } + else: + # As a compatibility measure (for plugins that we don't + # control), we'll pull it from the frontend registry for now. 
+ loading_mechanism = { + "type": "NONE", + } + output_metadata["loading_mechanism"] = loading_mechanism + + response[plugin.plugin_name] = output_metadata + return http_util.Respond(request, response, "application/json") + + def __call__(self, environ, start_response): + """Central entry point for the TensorBoard application. + + This __call__ method conforms to the WSGI spec, so that instances of this + class are WSGI applications. + + Args: + environ: See WSGI spec (PEP 3333). + start_response: See WSGI spec (PEP 3333). + """ + return self._app(environ, start_response) + + def _route_request(self, environ, start_response): + """Delegate an incoming request to sub-applications. + + This method supports strict string matching and wildcard routes of a + single path component, such as `/foo/*`. Other routing patterns, + like regular expressions, are not supported. + + This is the main TensorBoard entry point before middleware is + applied. (See `_create_wsgi_app`.) + + Args: + environ: See WSGI spec (PEP 3333). + start_response: See WSGI spec (PEP 3333). + """ + + request = wrappers.Request(environ) + parsed_url = urlparse.urlparse(request.path) + clean_path = _clean_path(parsed_url.path) + + # pylint: disable=too-many-function-args + if clean_path in self.exact_routes: + return self.exact_routes[clean_path](environ, start_response) + else: + for path_prefix in self.prefix_routes: + if clean_path.startswith(path_prefix): + return self.prefix_routes[path_prefix]( + environ, start_response + ) - Args: - request: The werkzeug.Request object. + logger.warn("path %s not found, sending 404", clean_path) + return http_util.Respond( + request, "Not found", "text/plain", code=404 + )(environ, start_response) + # pylint: enable=too-many-function-args - Returns: - A werkzeug.Response object. 
- """ - response = collections.OrderedDict() - for plugin in self._plugins: - if type(plugin) is core_plugin.CorePlugin: # pylint: disable=unidiomatic-typecheck - # This plugin's existence is a backend implementation detail. - continue - start = time.time() - is_active = plugin.is_active() - elapsed = time.time() - start - logger.info( - 'Plugin listing: is_active() for %s took %0.3f seconds', - plugin.plugin_name, elapsed) - - plugin_metadata = plugin.frontend_metadata() - output_metadata = { - 'disable_reload': plugin_metadata.disable_reload, - 'enabled': is_active, - # loading_mechanism set below - 'remove_dom': plugin_metadata.remove_dom, - # tab_name set below - } - - if plugin_metadata.tab_name is not None: - output_metadata['tab_name'] = plugin_metadata.tab_name - else: - output_metadata['tab_name'] = plugin.plugin_name - - es_module_handler = plugin_metadata.es_module_path - element_name = plugin_metadata.element_name - is_ng_component = plugin_metadata.is_ng_component - if is_ng_component: - if element_name is not None: - raise ValueError( - 'Plugin %r declared as both Angular built-in and legacy' % - plugin.plugin_name) - if es_module_handler is not None: - raise ValueError( - 'Plugin %r declared as both Angular built-in and iframed' % - plugin.plugin_name) - loading_mechanism = { - 'type': 'NG_COMPONENT', - } - elif element_name is not None and es_module_handler is not None: - logger.error( - 'Plugin %r declared as both legacy and iframed; skipping', - plugin.plugin_name, - ) - continue - elif element_name is not None and es_module_handler is None: - loading_mechanism = { - 'type': 'CUSTOM_ELEMENT', - 'element_name': element_name, - } - elif element_name is None and es_module_handler is not None: - loading_mechanism = { - 'type': 'IFRAME', - 'module_path': ''.join([ - request.script_root, DATA_PREFIX, PLUGIN_PREFIX, '/', - plugin.plugin_name, es_module_handler, - ]), - } - else: - # As a compatibility measure (for plugins that we don't - # control), 
we'll pull it from the frontend registry for now. - loading_mechanism = { - 'type': 'NONE', - } - output_metadata['loading_mechanism'] = loading_mechanism - response[plugin.plugin_name] = output_metadata - return http_util.Respond(request, response, 'application/json') +def parse_event_files_spec(logdir_spec): + """Parses `logdir_spec` into a map from paths to run group names. - def __call__(self, environ, start_response): - """Central entry point for the TensorBoard application. + The `--logdir_spec` flag format is a comma-separated list of path + specifications. A path spec looks like 'group_name:/path/to/directory' or + '/path/to/directory'; in the latter case, the group is unnamed. Group names + cannot start with a forward slash: /foo:bar/baz will be interpreted as a spec + with no name and path '/foo:bar/baz'. - This __call__ method conforms to the WSGI spec, so that instances of this - class are WSGI applications. + Globs are not supported. Args: - environ: See WSGI spec (PEP 3333). - start_response: See WSGI spec (PEP 3333). + logdir: A comma-separated list of run specifications. + Returns: + A dict mapping directory paths to names like {'/path/to/directory': 'name'}. + Groups without an explicit name are named after their path. If logdir is + None, returns an empty dict, which is helpful for testing things that don't + require any valid runs. """ - return self._app(environ, start_response) + files = {} + if logdir_spec is None: + return files + # Make sure keeping consistent with ParseURI in core/lib/io/path.cc + uri_pattern = re.compile("[a-zA-Z][0-9a-zA-Z.]*://.*") + for specification in logdir_spec.split(","): + # Check if the spec contains group. A spec start with xyz:// is regarded as + # URI path spec instead of group spec. If the spec looks like /foo:bar/baz, + # then we assume it's a path with a colon. 
If the spec looks like + # [a-zA-z]:\foo then we assume its a Windows path and not a single letter + # group + if ( + uri_pattern.match(specification) is None + and ":" in specification + and specification[0] != "/" + and not os.path.splitdrive(specification)[0] + ): + # We split at most once so run_name:/path:with/a/colon will work. + run_name, _, path = specification.partition(":") + else: + run_name = None + path = specification + if uri_pattern.match(path) is None: + path = os.path.realpath(os.path.expanduser(path)) + files[path] = run_name + return files - def _route_request(self, environ, start_response): - """Delegate an incoming request to sub-applications. - This method supports strict string matching and wildcard routes of a - single path component, such as `/foo/*`. Other routing patterns, - like regular expressions, are not supported. +def start_reloading_multiplexer( + multiplexer, path_to_run, load_interval, reload_task +): + """Starts automatically reloading the given multiplexer. - This is the main TensorBoard entry point before middleware is - applied. (See `_create_wsgi_app`.) + If `load_interval` is positive, the thread will reload the multiplexer + by calling `ReloadMultiplexer` every `load_interval` seconds, starting + immediately. Otherwise, reloads the multiplexer once and never again. Args: - environ: See WSGI spec (PEP 3333). - start_response: See WSGI spec (PEP 3333). - """ - - request = wrappers.Request(environ) - parsed_url = urlparse.urlparse(request.path) - clean_path = _clean_path(parsed_url.path) + multiplexer: The `EventMultiplexer` to add runs to and reload. + path_to_run: A dict mapping from paths to run names, where `None` as the run + name is interpreted as a run name equal to the path. + load_interval: An integer greater than or equal to 0. If positive, how many + seconds to wait after one load before starting the next load. Otherwise, + reloads the multiplexer once and never again (no continuous reloading). 
+ reload_task: Indicates the type of background task to reload with. - # pylint: disable=too-many-function-args - if clean_path in self.exact_routes: - return self.exact_routes[clean_path](environ, start_response) + Raises: + ValueError: If `load_interval` is negative. + """ + if load_interval < 0: + raise ValueError("load_interval is negative: %d" % load_interval) + + def _reload(): + while True: + start = time.time() + logger.info("TensorBoard reload process beginning") + for path, name in six.iteritems(path_to_run): + multiplexer.AddRunsFromDirectory(path, name) + logger.info( + "TensorBoard reload process: Reload the whole Multiplexer" + ) + multiplexer.Reload() + duration = time.time() - start + logger.info( + "TensorBoard done reloading. Load took %0.3f secs", duration + ) + if load_interval == 0: + # Only load the multiplexer once. Do not continuously reload. + break + time.sleep(load_interval) + + if reload_task == "process": + logger.info("Launching reload in a child process") + import multiprocessing + + process = multiprocessing.Process(target=_reload, name="Reloader") + # Best-effort cleanup; on exit, the main TB parent process will attempt to + # kill all its daemonic children. + process.daemon = True + process.start() + elif reload_task in ("thread", "auto"): + logger.info("Launching reload in a daemon thread") + thread = threading.Thread(target=_reload, name="Reloader") + # Make this a daemon thread, which won't block TB from exiting. 
+ thread.daemon = True + thread.start() + elif reload_task == "blocking": + if load_interval != 0: + raise ValueError( + "blocking reload only allowed with load_interval=0" + ) + _reload() else: - for path_prefix in self.prefix_routes: - if clean_path.startswith(path_prefix): - return self.prefix_routes[path_prefix](environ, start_response) + raise ValueError("unrecognized reload_task: %s" % reload_task) - logger.warn('path %s not found, sending 404', clean_path) - return http_util.Respond(request, 'Not found', 'text/plain', code=404)( - environ, start_response) - # pylint: enable=too-many-function-args +def create_sqlite_connection_provider(db_uri): + """Returns function that returns SQLite Connection objects. -def parse_event_files_spec(logdir_spec): - """Parses `logdir_spec` into a map from paths to run group names. - - The `--logdir_spec` flag format is a comma-separated list of path - specifications. A path spec looks like 'group_name:/path/to/directory' or - '/path/to/directory'; in the latter case, the group is unnamed. Group names - cannot start with a forward slash: /foo:bar/baz will be interpreted as a spec - with no name and path '/foo:bar/baz'. - - Globs are not supported. - - Args: - logdir: A comma-separated list of run specifications. - Returns: - A dict mapping directory paths to names like {'/path/to/directory': 'name'}. - Groups without an explicit name are named after their path. If logdir is - None, returns an empty dict, which is helpful for testing things that don't - require any valid runs. - """ - files = {} - if logdir_spec is None: - return files - # Make sure keeping consistent with ParseURI in core/lib/io/path.cc - uri_pattern = re.compile('[a-zA-Z][0-9a-zA-Z.]*://.*') - for specification in logdir_spec.split(','): - # Check if the spec contains group. A spec start with xyz:// is regarded as - # URI path spec instead of group spec. If the spec looks like /foo:bar/baz, - # then we assume it's a path with a colon. 
If the spec looks like - # [a-zA-z]:\foo then we assume its a Windows path and not a single letter - # group - if (uri_pattern.match(specification) is None and ':' in specification and - specification[0] != '/' and not os.path.splitdrive(specification)[0]): - # We split at most once so run_name:/path:with/a/colon will work. - run_name, _, path = specification.partition(':') - else: - run_name = None - path = specification - if uri_pattern.match(path) is None: - path = os.path.realpath(os.path.expanduser(path)) - files[path] = run_name - return files - - -def start_reloading_multiplexer(multiplexer, path_to_run, load_interval, - reload_task): - """Starts automatically reloading the given multiplexer. - - If `load_interval` is positive, the thread will reload the multiplexer - by calling `ReloadMultiplexer` every `load_interval` seconds, starting - immediately. Otherwise, reloads the multiplexer once and never again. - - Args: - multiplexer: The `EventMultiplexer` to add runs to and reload. - path_to_run: A dict mapping from paths to run names, where `None` as the run - name is interpreted as a run name equal to the path. - load_interval: An integer greater than or equal to 0. If positive, how many - seconds to wait after one load before starting the next load. Otherwise, - reloads the multiplexer once and never again (no continuous reloading). - reload_task: Indicates the type of background task to reload with. - - Raises: - ValueError: If `load_interval` is negative. - """ - if load_interval < 0: - raise ValueError('load_interval is negative: %d' % load_interval) - - def _reload(): - while True: - start = time.time() - logger.info('TensorBoard reload process beginning') - for path, name in six.iteritems(path_to_run): - multiplexer.AddRunsFromDirectory(path, name) - logger.info('TensorBoard reload process: Reload the whole Multiplexer') - multiplexer.Reload() - duration = time.time() - start - logger.info('TensorBoard done reloading. 
Load took %0.3f secs', duration) - if load_interval == 0: - # Only load the multiplexer once. Do not continuously reload. - break - time.sleep(load_interval) - - if reload_task == 'process': - logger.info('Launching reload in a child process') - import multiprocessing - process = multiprocessing.Process(target=_reload, name='Reloader') - # Best-effort cleanup; on exit, the main TB parent process will attempt to - # kill all its daemonic children. - process.daemon = True - process.start() - elif reload_task in ('thread', 'auto'): - logger.info('Launching reload in a daemon thread') - thread = threading.Thread(target=_reload, name='Reloader') - # Make this a daemon thread, which won't block TB from exiting. - thread.daemon = True - thread.start() - elif reload_task == 'blocking': - if load_interval != 0: - raise ValueError('blocking reload only allowed with load_interval=0') - _reload() - else: - raise ValueError('unrecognized reload_task: %s' % reload_task) + Args: + db_uri: A string URI expressing the DB file, e.g. "sqlite:~/tb.db". + Returns: + A function that returns a new PEP-249 DB Connection, which must be closed, + each time it is called. -def create_sqlite_connection_provider(db_uri): - """Returns function that returns SQLite Connection objects. - - Args: - db_uri: A string URI expressing the DB file, e.g. "sqlite:~/tb.db". - - Returns: - A function that returns a new PEP-249 DB Connection, which must be closed, - each time it is called. - - Raises: - ValueError: If db_uri is not a valid sqlite file URI. - """ - uri = urlparse.urlparse(db_uri) - if uri.scheme != 'sqlite': - raise ValueError('Only sqlite DB URIs are supported: ' + db_uri) - if uri.netloc: - raise ValueError('Can not connect to SQLite over network: ' + db_uri) - if uri.path == ':memory:': - raise ValueError('Memory mode SQLite not supported: ' + db_uri) - path = os.path.expanduser(uri.path) - params = _get_connect_params(uri.query) - # TODO(@jart): Add thread-local pooling. 
- return lambda: sqlite3.connect(path, **params) + Raises: + ValueError: If db_uri is not a valid sqlite file URI. + """ + uri = urlparse.urlparse(db_uri) + if uri.scheme != "sqlite": + raise ValueError("Only sqlite DB URIs are supported: " + db_uri) + if uri.netloc: + raise ValueError("Can not connect to SQLite over network: " + db_uri) + if uri.path == ":memory:": + raise ValueError("Memory mode SQLite not supported: " + db_uri) + path = os.path.expanduser(uri.path) + params = _get_connect_params(uri.query) + # TODO(@jart): Add thread-local pooling. + return lambda: sqlite3.connect(path, **params) def _get_connect_params(query): - params = urlparse.parse_qs(query) - if any(len(v) > 2 for v in params.values()): - raise ValueError('DB URI params list has duplicate keys: ' + query) - return {k: json.loads(v[0]) for k, v in params.items()} + params = urlparse.parse_qs(query) + if any(len(v) > 2 for v in params.values()): + raise ValueError("DB URI params list has duplicate keys: " + query) + return {k: json.loads(v[0]) for k, v in params.items()} def _clean_path(path): - """Removes a trailing slash from a non-root path. + """Removes a trailing slash from a non-root path. - Arguments: - path: The path of a request. + Arguments: + path: The path of a request. - Returns: - The route to use to serve the request. - """ - if path != '/' and path.endswith('/'): - return path[:-1] - return path + Returns: + The route to use to serve the request. + """ + if path != "/" and path.endswith("/"): + return path[:-1] + return path def _get_event_file_active_filter(flags): - """Returns a predicate for whether an event file load timestamp is active. - - Returns: - A predicate function accepting a single UNIX timestamp float argument, or - None if multi-file loading is not enabled. 
- """ - if not flags.reload_multifile: - return None - inactive_secs = flags.reload_multifile_inactive_secs - if inactive_secs == 0: - return None - if inactive_secs < 0: - return lambda timestamp: True - return lambda timestamp: timestamp + inactive_secs >= time.time() + """Returns a predicate for whether an event file load timestamp is active. + Returns: + A predicate function accepting a single UNIX timestamp float argument, or + None if multi-file loading is not enabled. + """ + if not flags.reload_multifile: + return None + inactive_secs = flags.reload_multifile_inactive_secs + if inactive_secs == 0: + return None + if inactive_secs < 0: + return lambda timestamp: True + return lambda timestamp: timestamp + inactive_secs >= time.time() -class _DbModeMultiplexer(event_multiplexer.EventMultiplexer): - """Shim EventMultiplexer to use when in read-only DB mode. - In read-only DB mode, the EventMultiplexer is nonfunctional - there is no - logdir to reload, and the data is all exposed via SQL. This class represents - the do-nothing EventMultiplexer for that purpose, which serves only as a - conduit for DB-related parameters. +class _DbModeMultiplexer(event_multiplexer.EventMultiplexer): + """Shim EventMultiplexer to use when in read-only DB mode. - The load APIs raise exceptions if called, and the read APIs always - return empty results. - """ - def __init__(self, db_uri, db_connection_provider): - """Constructor for `_DbModeMultiplexer`. + In read-only DB mode, the EventMultiplexer is nonfunctional - there is no + logdir to reload, and the data is all exposed via SQL. This class represents + the do-nothing EventMultiplexer for that purpose, which serves only as a + conduit for DB-related parameters. - Args: - db_uri: A URI to the database file in use. - db_connection_provider: Provider function for creating a DB connection. + The load APIs raise exceptions if called, and the read APIs always + return empty results. 
""" - logger.info('_DbModeMultiplexer initializing for %s', db_uri) - super(_DbModeMultiplexer, self).__init__() - self.db_uri = db_uri - self.db_connection_provider = db_connection_provider - logger.info('_DbModeMultiplexer done initializing') - def AddRun(self, path, name=None): - """Unsupported.""" - raise NotImplementedError() + def __init__(self, db_uri, db_connection_provider): + """Constructor for `_DbModeMultiplexer`. + + Args: + db_uri: A URI to the database file in use. + db_connection_provider: Provider function for creating a DB connection. + """ + logger.info("_DbModeMultiplexer initializing for %s", db_uri) + super(_DbModeMultiplexer, self).__init__() + self.db_uri = db_uri + self.db_connection_provider = db_connection_provider + logger.info("_DbModeMultiplexer done initializing") + + def AddRun(self, path, name=None): + """Unsupported.""" + raise NotImplementedError() - def AddRunsFromDirectory(self, path, name=None): - """Unsupported.""" - raise NotImplementedError() + def AddRunsFromDirectory(self, path, name=None): + """Unsupported.""" + raise NotImplementedError() - def Reload(self): - """Unsupported.""" - raise NotImplementedError() + def Reload(self): + """Unsupported.""" + raise NotImplementedError() def make_plugin_loader(plugin_spec): - """Returns a plugin loader for the given plugin. - - Args: - plugin_spec: A TBPlugin subclass, or a TBLoader instance or subclass. - - Returns: - A TBLoader for the given plugin. 
- - :type plugin_spec: - Type[base_plugin.TBPlugin] | Type[base_plugin.TBLoader] | - base_plugin.TBLoader - :rtype: base_plugin.TBLoader - """ - if isinstance(plugin_spec, base_plugin.TBLoader): - return plugin_spec - if isinstance(plugin_spec, type): - if issubclass(plugin_spec, base_plugin.TBLoader): - return plugin_spec() - if issubclass(plugin_spec, base_plugin.TBPlugin): - return base_plugin.BasicLoader(plugin_spec) - raise TypeError("Not a TBLoader or TBPlugin subclass: %r" % (plugin_spec,)) + """Returns a plugin loader for the given plugin. + + Args: + plugin_spec: A TBPlugin subclass, or a TBLoader instance or subclass. + + Returns: + A TBLoader for the given plugin. + + :type plugin_spec: + Type[base_plugin.TBPlugin] | Type[base_plugin.TBLoader] | + base_plugin.TBLoader + :rtype: base_plugin.TBLoader + """ + if isinstance(plugin_spec, base_plugin.TBLoader): + return plugin_spec + if isinstance(plugin_spec, type): + if issubclass(plugin_spec, base_plugin.TBLoader): + return plugin_spec() + if issubclass(plugin_spec, base_plugin.TBPlugin): + return base_plugin.BasicLoader(plugin_spec) + raise TypeError("Not a TBLoader or TBPlugin subclass: %r" % (plugin_spec,)) diff --git a/tensorboard/backend/application_test.py b/tensorboard/backend/application_test.py index 111d135a54..36cf656495 100644 --- a/tensorboard/backend/application_test.py +++ b/tensorboard/backend/application_test.py @@ -32,10 +32,10 @@ import six try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import from werkzeug import test as werkzeug_test from werkzeug import wrappers @@ -44,865 +44,920 @@ from tensorboard import plugin_util from tensorboard import test as tb_test from tensorboard.backend import application -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from 
tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin class FakeFlags(object): - def __init__( - self, - logdir, - logdir_spec='', - purge_orphaned_data=True, - reload_interval=60, - samples_per_plugin='', - max_reload_threads=1, - reload_task='auto', - db='', - db_import=False, - window_title='', - path_prefix='', - reload_multifile=False, - reload_multifile_inactive_secs=4000, - generic_data='auto'): - self.logdir = logdir - self.logdir_spec = logdir_spec - self.purge_orphaned_data = purge_orphaned_data - self.reload_interval = reload_interval - self.samples_per_plugin = samples_per_plugin - self.max_reload_threads = max_reload_threads - self.reload_task = reload_task - self.db = db - self.db_import = db_import - self.window_title = window_title - self.path_prefix = path_prefix - self.reload_multifile = reload_multifile - self.reload_multifile_inactive_secs = reload_multifile_inactive_secs - self.generic_data = generic_data + def __init__( + self, + logdir, + logdir_spec="", + purge_orphaned_data=True, + reload_interval=60, + samples_per_plugin="", + max_reload_threads=1, + reload_task="auto", + db="", + db_import=False, + window_title="", + path_prefix="", + reload_multifile=False, + reload_multifile_inactive_secs=4000, + generic_data="auto", + ): + self.logdir = logdir + self.logdir_spec = logdir_spec + self.purge_orphaned_data = purge_orphaned_data + self.reload_interval = reload_interval + self.samples_per_plugin = samples_per_plugin + self.max_reload_threads = max_reload_threads + self.reload_task = reload_task + self.db = db + self.db_import = db_import + self.window_title = window_title + self.path_prefix = path_prefix + self.reload_multifile = reload_multifile + self.reload_multifile_inactive_secs = reload_multifile_inactive_secs + self.generic_data = generic_data class FakePlugin(base_plugin.TBPlugin): - """A plugin with no functionality.""" - - def __init__(self, - 
context=None, - plugin_name='foo', - is_active_value=True, - routes_mapping={}, - element_name_value=None, - es_module_path_value=None, - is_ng_component=False, - construction_callback=None): - """Constructs a fake plugin. - - Args: - context: The TBContext magic container. Contains properties that are - potentially useful to this plugin. - plugin_name: The name of this plugin. - is_active_value: Whether the plugin is active. - routes_mapping: A dictionary mapping from route (string URL path) to the - method called when a user issues a request to that route. - es_module_path_value: An optional string value that indicates a frontend - module entry to the plugin. Must be one of the keys of routes_mapping. - is_ng_component: Whether this plugin is of the built-in Angular-based - type. - construction_callback: An optional callback called when the plugin is - constructed. The callback is passed the TBContext. - """ - self.plugin_name = plugin_name - self._is_active_value = is_active_value - self._routes_mapping = routes_mapping - self._element_name_value = element_name_value - self._es_module_path_value = es_module_path_value - self._is_ng_component = is_ng_component - - if construction_callback: - construction_callback(context) - - def get_plugin_apps(self): - """Returns a mapping from routes to handlers offered by this plugin. - - Returns: - A dictionary mapping from routes to handlers offered by this plugin. - """ - return self._routes_mapping - - def is_active(self): - """Returns whether this plugin is active. - - Returns: - A boolean. Whether this plugin is active. 
- """ - return self._is_active_value - - def frontend_metadata(self): - return base_plugin.FrontendMetadata( - element_name=self._element_name_value, - es_module_path=self._es_module_path_value, - is_ng_component=self._is_ng_component, - ) + """A plugin with no functionality.""" + + def __init__( + self, + context=None, + plugin_name="foo", + is_active_value=True, + routes_mapping={}, + element_name_value=None, + es_module_path_value=None, + is_ng_component=False, + construction_callback=None, + ): + """Constructs a fake plugin. + + Args: + context: The TBContext magic container. Contains properties that are + potentially useful to this plugin. + plugin_name: The name of this plugin. + is_active_value: Whether the plugin is active. + routes_mapping: A dictionary mapping from route (string URL path) to the + method called when a user issues a request to that route. + es_module_path_value: An optional string value that indicates a frontend + module entry to the plugin. Must be one of the keys of routes_mapping. + is_ng_component: Whether this plugin is of the built-in Angular-based + type. + construction_callback: An optional callback called when the plugin is + constructed. The callback is passed the TBContext. + """ + self.plugin_name = plugin_name + self._is_active_value = is_active_value + self._routes_mapping = routes_mapping + self._element_name_value = element_name_value + self._es_module_path_value = es_module_path_value + self._is_ng_component = is_ng_component + + if construction_callback: + construction_callback(context) + + def get_plugin_apps(self): + """Returns a mapping from routes to handlers offered by this plugin. + + Returns: + A dictionary mapping from routes to handlers offered by this plugin. + """ + return self._routes_mapping + + def is_active(self): + """Returns whether this plugin is active. + + Returns: + A boolean. Whether this plugin is active. 
+ """ + return self._is_active_value + + def frontend_metadata(self): + return base_plugin.FrontendMetadata( + element_name=self._element_name_value, + es_module_path=self._es_module_path_value, + is_ng_component=self._is_ng_component, + ) class FakePluginLoader(base_plugin.TBLoader): - """Pass-through loader for FakePlugin with arbitrary arguments.""" + """Pass-through loader for FakePlugin with arbitrary arguments.""" - def __init__(self, **kwargs): - self._kwargs = kwargs + def __init__(self, **kwargs): + self._kwargs = kwargs - def load(self, context): - return FakePlugin(context, **self._kwargs) + def load(self, context): + return FakePlugin(context, **self._kwargs) class HandlingErrorsTest(tb_test.TestCase): - - def test_successful_response_passes_through(self): - @application._handling_errors - @wrappers.Request.application - def app(request): - return wrappers.Response('All is well', 200, content_type='text/html') - - server = werkzeug_test.Client(app, wrappers.BaseResponse) - response = server.get('/') - self.assertEqual(response.get_data(), b'All is well') - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers.get('Content-Type'), 'text/html') - - def test_public_errors_serve_response(self): - @application._handling_errors - @wrappers.Request.application - def app(request): - raise errors.NotFoundError('no scalar data for run=foo, tag=bar') - - server = werkzeug_test.Client(app, wrappers.BaseResponse) - response = server.get('/') - self.assertEqual( - response.get_data(), - b'Not found: no scalar data for run=foo, tag=bar', - ) - self.assertEqual(response.status_code, 404) - self.assertStartsWith(response.headers.get('Content-Type'), 'text/plain') - - def test_internal_errors_propagate(self): - @application._handling_errors - @wrappers.Request.application - def app(request): - raise ValueError('something borked internally') - - server = werkzeug_test.Client(app, wrappers.BaseResponse) - with self.assertRaises(ValueError) as cm: 
- response = server.get('/') - self.assertEqual(str(cm.exception), 'something borked internally') - - def test_passes_through_non_wsgi_args(self): - class C(object): - @application._handling_errors - def __call__(self, environ, start_response): - start_response('200 OK', [('Content-Type', 'text/html')]) - yield b'All is well' - - app = C() - server = werkzeug_test.Client(app, wrappers.BaseResponse) - response = server.get('/') - self.assertEqual(response.get_data(), b'All is well') - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers.get('Content-Type'), 'text/html') + def test_successful_response_passes_through(self): + @application._handling_errors + @wrappers.Request.application + def app(request): + return wrappers.Response( + "All is well", 200, content_type="text/html" + ) + + server = werkzeug_test.Client(app, wrappers.BaseResponse) + response = server.get("/") + self.assertEqual(response.get_data(), b"All is well") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers.get("Content-Type"), "text/html") + + def test_public_errors_serve_response(self): + @application._handling_errors + @wrappers.Request.application + def app(request): + raise errors.NotFoundError("no scalar data for run=foo, tag=bar") + + server = werkzeug_test.Client(app, wrappers.BaseResponse) + response = server.get("/") + self.assertEqual( + response.get_data(), + b"Not found: no scalar data for run=foo, tag=bar", + ) + self.assertEqual(response.status_code, 404) + self.assertStartsWith( + response.headers.get("Content-Type"), "text/plain" + ) + + def test_internal_errors_propagate(self): + @application._handling_errors + @wrappers.Request.application + def app(request): + raise ValueError("something borked internally") + + server = werkzeug_test.Client(app, wrappers.BaseResponse) + with self.assertRaises(ValueError) as cm: + response = server.get("/") + self.assertEqual(str(cm.exception), "something borked internally") + + def 
test_passes_through_non_wsgi_args(self): + class C(object): + @application._handling_errors + def __call__(self, environ, start_response): + start_response("200 OK", [("Content-Type", "text/html")]) + yield b"All is well" + + app = C() + server = werkzeug_test.Client(app, wrappers.BaseResponse) + response = server.get("/") + self.assertEqual(response.get_data(), b"All is well") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers.get("Content-Type"), "text/html") class ApplicationTest(tb_test.TestCase): - def setUp(self): - plugins = [ - FakePlugin(plugin_name='foo'), - FakePlugin( - plugin_name='bar', - is_active_value=False, - element_name_value='tf-bar-dashboard', - ), - FakePlugin( - plugin_name='baz', - routes_mapping={ - '/esmodule': lambda req: None, - }, - es_module_path_value='/esmodule' - ), - FakePlugin( - plugin_name='qux', - is_active_value=False, - is_ng_component=True, - ), - ] - app = application.TensorBoardWSGI(plugins) - self.server = werkzeug_test.Client(app, wrappers.BaseResponse) - - def _get_json(self, path): - response = self.server.get(path) - self.assertEqual(200, response.status_code) - self.assertEqual('application/json', response.headers.get('Content-Type')) - return json.loads(response.get_data().decode('utf-8')) - - def testBasicStartup(self): - """Start the server up and then shut it down immediately.""" - pass - - def testRequestNonexistentPage(self): - """Request a page that doesn't exist; it should 404.""" - response = self.server.get('/asdf') - self.assertEqual(404, response.status_code) - - def testPluginsListing(self): - """Test the format of the data/plugins_listing endpoint.""" - parsed_object = self._get_json('/data/plugins_listing') - self.assertEqual( - parsed_object, - { - 'foo': { - 'enabled': True, - 'loading_mechanism': {'type': 'NONE'}, - 'remove_dom': False, - 'tab_name': 'foo', - 'disable_reload': False, - }, - 'bar': { - 'enabled': False, - 'loading_mechanism': { - 'type': 
'CUSTOM_ELEMENT', - 'element_name': 'tf-bar-dashboard', + def setUp(self): + plugins = [ + FakePlugin(plugin_name="foo"), + FakePlugin( + plugin_name="bar", + is_active_value=False, + element_name_value="tf-bar-dashboard", + ), + FakePlugin( + plugin_name="baz", + routes_mapping={"/esmodule": lambda req: None,}, + es_module_path_value="/esmodule", + ), + FakePlugin( + plugin_name="qux", is_active_value=False, is_ng_component=True, + ), + ] + app = application.TensorBoardWSGI(plugins) + self.server = werkzeug_test.Client(app, wrappers.BaseResponse) + + def _get_json(self, path): + response = self.server.get(path) + self.assertEqual(200, response.status_code) + self.assertEqual( + "application/json", response.headers.get("Content-Type") + ) + return json.loads(response.get_data().decode("utf-8")) + + def testBasicStartup(self): + """Start the server up and then shut it down immediately.""" + pass + + def testRequestNonexistentPage(self): + """Request a page that doesn't exist; it should 404.""" + response = self.server.get("/asdf") + self.assertEqual(404, response.status_code) + + def testPluginsListing(self): + """Test the format of the data/plugins_listing endpoint.""" + parsed_object = self._get_json("/data/plugins_listing") + self.assertEqual( + parsed_object, + { + "foo": { + "enabled": True, + "loading_mechanism": {"type": "NONE"}, + "remove_dom": False, + "tab_name": "foo", + "disable_reload": False, }, - 'tab_name': 'bar', - 'remove_dom': False, - 'disable_reload': False, - }, - 'baz': { - 'enabled': True, - 'loading_mechanism': { - 'type': 'IFRAME', - 'module_path': '/data/plugin/baz/esmodule', + "bar": { + "enabled": False, + "loading_mechanism": { + "type": "CUSTOM_ELEMENT", + "element_name": "tf-bar-dashboard", + }, + "tab_name": "bar", + "remove_dom": False, + "disable_reload": False, }, - 'tab_name': 'baz', - 'remove_dom': False, - 'disable_reload': False, - }, - 'qux': { - 'enabled': False, - 'loading_mechanism': { - 'type': 'NG_COMPONENT', + "baz": { 
+ "enabled": True, + "loading_mechanism": { + "type": "IFRAME", + "module_path": "/data/plugin/baz/esmodule", + }, + "tab_name": "baz", + "remove_dom": False, + "disable_reload": False, + }, + "qux": { + "enabled": False, + "loading_mechanism": {"type": "NG_COMPONENT",}, + "tab_name": "qux", + "remove_dom": False, + "disable_reload": False, }, - 'tab_name': 'qux', - 'remove_dom': False, - 'disable_reload': False, }, - - } - ) - - def testPluginEntry(self): - """Test the data/plugin_entry.html endpoint.""" - response = self.server.get('/data/plugin_entry.html?name=baz') - self.assertEqual(200, response.status_code) - self.assertEqual( - 'text/html; charset=utf-8', response.headers.get('Content-Type')) - - document = response.get_data().decode('utf-8') - self.assertIn('', document) - self.assertIn( - 'import("./esmodule").then((m) => void m.render());', document) - # base64 sha256 of above script - self.assertIn( - "'sha256-3KGOnqHhLsX2RmjH/K2DurN9N2qtApZk5zHdSPg4LcA='", - response.headers.get('Content-Security-Policy'), - ) - - for name in ['bazz', 'baz ']: - response = self.server.get('/data/plugin_entry.html?name=%s' % name) - self.assertEqual(404, response.status_code) - - for name in ['foo', 'bar']: - response = self.server.get('/data/plugin_entry.html?name=%s' % name) - self.assertEqual(400, response.status_code) - self.assertEqual( - response.get_data().decode('utf-8'), - 'Plugin is not module loadable', - ) - - def testPluginEntryBadModulePath(self): - plugins = [ - FakePlugin( - plugin_name='mallory', - es_module_path_value='//pwn.tb/somepath' - ), - ] - app = application.TensorBoardWSGI(plugins) - server = werkzeug_test.Client(app, wrappers.BaseResponse) - with six.assertRaisesRegex( - self, ValueError, 'Expected es_module_path to be non-absolute path'): - server.get('/data/plugin_entry.html?name=mallory') - - def testNgComponentPluginWithIncompatibleSetElementName(self): - plugins = [ - FakePlugin( - plugin_name='quux', - is_ng_component=True, - 
element_name_value='incompatible', - ), - ] - app = application.TensorBoardWSGI(plugins) - server = werkzeug_test.Client(app, wrappers.BaseResponse) - with six.assertRaisesRegex( - self, ValueError, 'quux.*declared.*both Angular.*legacy'): - server.get('/data/plugins_listing') - - def testNgComponentPluginWithIncompatiblEsModulePath(self): - plugins = [ - FakePlugin( - plugin_name='quux', - is_ng_component=True, - es_module_path_value='//incompatible', - ), - ] - app = application.TensorBoardWSGI(plugins) - server = werkzeug_test.Client(app, wrappers.BaseResponse) - with six.assertRaisesRegex( - self, ValueError, 'quux.*declared.*both Angular.*iframed'): - server.get('/data/plugins_listing') + ) + + def testPluginEntry(self): + """Test the data/plugin_entry.html endpoint.""" + response = self.server.get("/data/plugin_entry.html?name=baz") + self.assertEqual(200, response.status_code) + self.assertEqual( + "text/html; charset=utf-8", response.headers.get("Content-Type") + ) + + document = response.get_data().decode("utf-8") + self.assertIn('', document) + self.assertIn( + 'import("./esmodule").then((m) => void m.render());', document + ) + # base64 sha256 of above script + self.assertIn( + "'sha256-3KGOnqHhLsX2RmjH/K2DurN9N2qtApZk5zHdSPg4LcA='", + response.headers.get("Content-Security-Policy"), + ) + + for name in ["bazz", "baz "]: + response = self.server.get("/data/plugin_entry.html?name=%s" % name) + self.assertEqual(404, response.status_code) + + for name in ["foo", "bar"]: + response = self.server.get("/data/plugin_entry.html?name=%s" % name) + self.assertEqual(400, response.status_code) + self.assertEqual( + response.get_data().decode("utf-8"), + "Plugin is not module loadable", + ) + + def testPluginEntryBadModulePath(self): + plugins = [ + FakePlugin( + plugin_name="mallory", es_module_path_value="//pwn.tb/somepath" + ), + ] + app = application.TensorBoardWSGI(plugins) + server = werkzeug_test.Client(app, wrappers.BaseResponse) + with six.assertRaisesRegex( 
+ self, ValueError, "Expected es_module_path to be non-absolute path" + ): + server.get("/data/plugin_entry.html?name=mallory") + + def testNgComponentPluginWithIncompatibleSetElementName(self): + plugins = [ + FakePlugin( + plugin_name="quux", + is_ng_component=True, + element_name_value="incompatible", + ), + ] + app = application.TensorBoardWSGI(plugins) + server = werkzeug_test.Client(app, wrappers.BaseResponse) + with six.assertRaisesRegex( + self, ValueError, "quux.*declared.*both Angular.*legacy" + ): + server.get("/data/plugins_listing") + + def testNgComponentPluginWithIncompatiblEsModulePath(self): + plugins = [ + FakePlugin( + plugin_name="quux", + is_ng_component=True, + es_module_path_value="//incompatible", + ), + ] + app = application.TensorBoardWSGI(plugins) + server = werkzeug_test.Client(app, wrappers.BaseResponse) + with six.assertRaisesRegex( + self, ValueError, "quux.*declared.*both Angular.*iframed" + ): + server.get("/data/plugins_listing") class ApplicationBaseUrlTest(tb_test.TestCase): - path_prefix = '/test' - def setUp(self): - plugins = [ - FakePlugin(plugin_name='foo'), - FakePlugin( - plugin_name='bar', - is_active_value=False, - element_name_value='tf-bar-dashboard', - ), - FakePlugin( - plugin_name='baz', - routes_mapping={ - '/esmodule': lambda req: None, - }, - es_module_path_value='/esmodule' - ), - ] - app = application.TensorBoardWSGI(plugins, path_prefix=self.path_prefix) - self.server = werkzeug_test.Client(app, wrappers.BaseResponse) - - def _get_json(self, path): - response = self.server.get(path) - self.assertEqual(200, response.status_code) - self.assertEqual('application/json', response.headers.get('Content-Type')) - return json.loads(response.get_data().decode('utf-8')) - - def testBaseUrlRequest(self): - """Base URL should redirect to "/" for proper relative URLs.""" - response = self.server.get(self.path_prefix) - self.assertEqual(301, response.status_code) - - def testBaseUrlRequestNonexistentPage(self): - """Request 
a page that doesn't exist; it should 404.""" - response = self.server.get(self.path_prefix + '/asdf') - self.assertEqual(404, response.status_code) - - def testBaseUrlNonexistentPluginsListing(self): - """Test the format of the data/plugins_listing endpoint.""" - response = self.server.get('/non_existent_prefix/data/plugins_listing') - self.assertEqual(404, response.status_code) - - def testPluginsListing(self): - """Test the format of the data/plugins_listing endpoint.""" - parsed_object = self._get_json(self.path_prefix + '/data/plugins_listing') - self.assertEqual( - parsed_object, - { - 'foo': { - 'enabled': True, - 'loading_mechanism': {'type': 'NONE'}, - 'remove_dom': False, - 'tab_name': 'foo', - 'disable_reload': False, - }, - 'bar': { - 'enabled': False, - 'loading_mechanism': { - 'type': 'CUSTOM_ELEMENT', - 'element_name': 'tf-bar-dashboard', + path_prefix = "/test" + + def setUp(self): + plugins = [ + FakePlugin(plugin_name="foo"), + FakePlugin( + plugin_name="bar", + is_active_value=False, + element_name_value="tf-bar-dashboard", + ), + FakePlugin( + plugin_name="baz", + routes_mapping={"/esmodule": lambda req: None,}, + es_module_path_value="/esmodule", + ), + ] + app = application.TensorBoardWSGI(plugins, path_prefix=self.path_prefix) + self.server = werkzeug_test.Client(app, wrappers.BaseResponse) + + def _get_json(self, path): + response = self.server.get(path) + self.assertEqual(200, response.status_code) + self.assertEqual( + "application/json", response.headers.get("Content-Type") + ) + return json.loads(response.get_data().decode("utf-8")) + + def testBaseUrlRequest(self): + """Base URL should redirect to "/" for proper relative URLs.""" + response = self.server.get(self.path_prefix) + self.assertEqual(301, response.status_code) + + def testBaseUrlRequestNonexistentPage(self): + """Request a page that doesn't exist; it should 404.""" + response = self.server.get(self.path_prefix + "/asdf") + self.assertEqual(404, response.status_code) + + def 
testBaseUrlNonexistentPluginsListing(self): + """Test the format of the data/plugins_listing endpoint.""" + response = self.server.get("/non_existent_prefix/data/plugins_listing") + self.assertEqual(404, response.status_code) + + def testPluginsListing(self): + """Test the format of the data/plugins_listing endpoint.""" + parsed_object = self._get_json( + self.path_prefix + "/data/plugins_listing" + ) + self.assertEqual( + parsed_object, + { + "foo": { + "enabled": True, + "loading_mechanism": {"type": "NONE"}, + "remove_dom": False, + "tab_name": "foo", + "disable_reload": False, }, - 'tab_name': 'bar', - 'remove_dom': False, - 'disable_reload': False, - }, - 'baz': { - 'enabled': True, - 'loading_mechanism': { - 'type': 'IFRAME', - 'module_path': '/test/data/plugin/baz/esmodule', + "bar": { + "enabled": False, + "loading_mechanism": { + "type": "CUSTOM_ELEMENT", + "element_name": "tf-bar-dashboard", + }, + "tab_name": "bar", + "remove_dom": False, + "disable_reload": False, + }, + "baz": { + "enabled": True, + "loading_mechanism": { + "type": "IFRAME", + "module_path": "/test/data/plugin/baz/esmodule", + }, + "tab_name": "baz", + "remove_dom": False, + "disable_reload": False, }, - 'tab_name': 'baz', - 'remove_dom': False, - 'disable_reload': False, }, - } - ) + ) class ApplicationPluginNameTest(tb_test.TestCase): - - def testSimpleName(self): - application.TensorBoardWSGI( - plugins=[FakePlugin(plugin_name='scalars')]) - - def testComprehensiveName(self): - application.TensorBoardWSGI( - plugins=[FakePlugin(plugin_name='Scalar-Dashboard_3000.1')]) - - def testNameIsNone(self): - with six.assertRaisesRegex(self, ValueError, r'no plugin_name'): - application.TensorBoardWSGI( - plugins=[FakePlugin(plugin_name=None)]) - - def testEmptyName(self): - with six.assertRaisesRegex(self, ValueError, r'invalid name'): - application.TensorBoardWSGI( - plugins=[FakePlugin(plugin_name='')]) - - def testNameWithSlashes(self): - with six.assertRaisesRegex(self, ValueError, 
r'invalid name'): - application.TensorBoardWSGI( - plugins=[FakePlugin(plugin_name='scalars/data')]) - - def testNameWithSpaces(self): - with six.assertRaisesRegex(self, ValueError, r'invalid name'): - application.TensorBoardWSGI( - plugins=[FakePlugin(plugin_name='my favorite plugin')]) - - def testDuplicateName(self): - with six.assertRaisesRegex(self, ValueError, r'Duplicate'): - application.TensorBoardWSGI( - plugins=[FakePlugin(plugin_name='scalars'), - FakePlugin(plugin_name='scalars')]) + def testSimpleName(self): + application.TensorBoardWSGI(plugins=[FakePlugin(plugin_name="scalars")]) + + def testComprehensiveName(self): + application.TensorBoardWSGI( + plugins=[FakePlugin(plugin_name="Scalar-Dashboard_3000.1")] + ) + + def testNameIsNone(self): + with six.assertRaisesRegex(self, ValueError, r"no plugin_name"): + application.TensorBoardWSGI(plugins=[FakePlugin(plugin_name=None)]) + + def testEmptyName(self): + with six.assertRaisesRegex(self, ValueError, r"invalid name"): + application.TensorBoardWSGI(plugins=[FakePlugin(plugin_name="")]) + + def testNameWithSlashes(self): + with six.assertRaisesRegex(self, ValueError, r"invalid name"): + application.TensorBoardWSGI( + plugins=[FakePlugin(plugin_name="scalars/data")] + ) + + def testNameWithSpaces(self): + with six.assertRaisesRegex(self, ValueError, r"invalid name"): + application.TensorBoardWSGI( + plugins=[FakePlugin(plugin_name="my favorite plugin")] + ) + + def testDuplicateName(self): + with six.assertRaisesRegex(self, ValueError, r"Duplicate"): + application.TensorBoardWSGI( + plugins=[ + FakePlugin(plugin_name="scalars"), + FakePlugin(plugin_name="scalars"), + ] + ) class ApplicationPluginRouteTest(tb_test.TestCase): + def _make_plugin(self, route): + return FakePlugin( + plugin_name="foo", + routes_mapping={route: lambda environ, start_response: None}, + ) - def _make_plugin(self, route): - return FakePlugin( - plugin_name='foo', - routes_mapping={route: lambda environ, start_response: None}) - - 
def testNormalRoute(self): - application.TensorBoardWSGI([self._make_plugin('/runs')]) + def testNormalRoute(self): + application.TensorBoardWSGI([self._make_plugin("/runs")]) - def testWildcardRoute(self): - application.TensorBoardWSGI([self._make_plugin('/foo/*')]) + def testWildcardRoute(self): + application.TensorBoardWSGI([self._make_plugin("/foo/*")]) - def testNonPathComponentWildcardRoute(self): - with six.assertRaisesRegex(self, ValueError, r'invalid route'): - application.TensorBoardWSGI([self._make_plugin('/foo*')]) + def testNonPathComponentWildcardRoute(self): + with six.assertRaisesRegex(self, ValueError, r"invalid route"): + application.TensorBoardWSGI([self._make_plugin("/foo*")]) - def testMultiWildcardRoute(self): - with six.assertRaisesRegex(self, ValueError, r'invalid route'): - application.TensorBoardWSGI([self._make_plugin('/foo/*/bar/*')]) + def testMultiWildcardRoute(self): + with six.assertRaisesRegex(self, ValueError, r"invalid route"): + application.TensorBoardWSGI([self._make_plugin("/foo/*/bar/*")]) - def testInternalWildcardRoute(self): - with six.assertRaisesRegex(self, ValueError, r'invalid route'): - application.TensorBoardWSGI([self._make_plugin('/foo/*/bar')]) + def testInternalWildcardRoute(self): + with six.assertRaisesRegex(self, ValueError, r"invalid route"): + application.TensorBoardWSGI([self._make_plugin("/foo/*/bar")]) - def testEmptyRoute(self): - with six.assertRaisesRegex(self, ValueError, r'invalid route'): - application.TensorBoardWSGI([self._make_plugin('')]) + def testEmptyRoute(self): + with six.assertRaisesRegex(self, ValueError, r"invalid route"): + application.TensorBoardWSGI([self._make_plugin("")]) - def testSlashlessRoute(self): - with six.assertRaisesRegex(self, ValueError, r'invalid route'): - application.TensorBoardWSGI([self._make_plugin('runaway')]) + def testSlashlessRoute(self): + with six.assertRaisesRegex(self, ValueError, r"invalid route"): + 
application.TensorBoardWSGI([self._make_plugin("runaway")]) class MakePluginLoaderTest(tb_test.TestCase): + def testMakePluginLoader_pluginClass(self): + loader = application.make_plugin_loader(FakePlugin) + self.assertIsInstance(loader, base_plugin.BasicLoader) + self.assertIs(loader.plugin_class, FakePlugin) - def testMakePluginLoader_pluginClass(self): - loader = application.make_plugin_loader(FakePlugin) - self.assertIsInstance(loader, base_plugin.BasicLoader) - self.assertIs(loader.plugin_class, FakePlugin) - - def testMakePluginLoader_pluginLoaderClass(self): - loader = application.make_plugin_loader(FakePluginLoader) - self.assertIsInstance(loader, FakePluginLoader) + def testMakePluginLoader_pluginLoaderClass(self): + loader = application.make_plugin_loader(FakePluginLoader) + self.assertIsInstance(loader, FakePluginLoader) - def testMakePluginLoader_pluginLoader(self): - loader = FakePluginLoader() - self.assertIs(loader, application.make_plugin_loader(loader)) + def testMakePluginLoader_pluginLoader(self): + loader = FakePluginLoader() + self.assertIs(loader, application.make_plugin_loader(loader)) - def testMakePluginLoader_invalidType(self): - with six.assertRaisesRegex(self, TypeError, 'FakePlugin'): - application.make_plugin_loader(FakePlugin()) + def testMakePluginLoader_invalidType(self): + with six.assertRaisesRegex(self, TypeError, "FakePlugin"): + application.make_plugin_loader(FakePlugin()) class GetEventFileActiveFilterTest(tb_test.TestCase): - - def testDisabled(self): - flags = FakeFlags('logdir', reload_multifile=False) - self.assertIsNone(application._get_event_file_active_filter(flags)) - - def testInactiveSecsZero(self): - flags = FakeFlags('logdir', reload_multifile=True, - reload_multifile_inactive_secs=0) - self.assertIsNone(application._get_event_file_active_filter(flags)) - - def testInactiveSecsNegative(self): - flags = FakeFlags('logdir', reload_multifile=True, - reload_multifile_inactive_secs=-1) - filter_fn = 
application._get_event_file_active_filter(flags) - self.assertTrue(filter_fn(0)) - self.assertTrue(filter_fn(time.time())) - self.assertTrue(filter_fn(float("inf"))) - - def testInactiveSecs(self): - flags = FakeFlags('logdir', reload_multifile=True, - reload_multifile_inactive_secs=10) - filter_fn = application._get_event_file_active_filter(flags) - with mock.patch.object(time, 'time') as mock_time: - mock_time.return_value = 100 - self.assertFalse(filter_fn(0)) - self.assertFalse(filter_fn(time.time() - 11)) - self.assertTrue(filter_fn(time.time() - 10)) - self.assertTrue(filter_fn(time.time())) - self.assertTrue(filter_fn(float("inf"))) + def testDisabled(self): + flags = FakeFlags("logdir", reload_multifile=False) + self.assertIsNone(application._get_event_file_active_filter(flags)) + + def testInactiveSecsZero(self): + flags = FakeFlags( + "logdir", reload_multifile=True, reload_multifile_inactive_secs=0 + ) + self.assertIsNone(application._get_event_file_active_filter(flags)) + + def testInactiveSecsNegative(self): + flags = FakeFlags( + "logdir", reload_multifile=True, reload_multifile_inactive_secs=-1 + ) + filter_fn = application._get_event_file_active_filter(flags) + self.assertTrue(filter_fn(0)) + self.assertTrue(filter_fn(time.time())) + self.assertTrue(filter_fn(float("inf"))) + + def testInactiveSecs(self): + flags = FakeFlags( + "logdir", reload_multifile=True, reload_multifile_inactive_secs=10 + ) + filter_fn = application._get_event_file_active_filter(flags) + with mock.patch.object(time, "time") as mock_time: + mock_time.return_value = 100 + self.assertFalse(filter_fn(0)) + self.assertFalse(filter_fn(time.time() - 11)) + self.assertTrue(filter_fn(time.time() - 10)) + self.assertTrue(filter_fn(time.time())) + self.assertTrue(filter_fn(float("inf"))) class ParseEventFilesSpecTest(tb_test.TestCase): - - def assertPlatformSpecificLogdirParsing(self, pathObj, logdir, expected): - """ - A custom assertion to test :func:`parse_event_files_spec` under 
various - systems. - - Args: - pathObj: a custom replacement object for `os.path`, typically - `posixpath` or `ntpath` - logdir: the string to be parsed by - :func:`~application.parse_event_files_spec` - expected: the expected dictionary as returned by - :func:`~application.parse_event_files_spec` - - """ - - with mock.patch('os.path', pathObj): - self.assertEqual(application.parse_event_files_spec(logdir), expected) - - def testBasic(self): - self.assertPlatformSpecificLogdirParsing( - posixpath, '/lol/cat', {'/lol/cat': None}) - self.assertPlatformSpecificLogdirParsing( - ntpath, r'C:\lol\cat', {r'C:\lol\cat': None}) - - def testRunName(self): - self.assertPlatformSpecificLogdirParsing( - posixpath, 'lol:/cat', {'/cat': 'lol'}) - self.assertPlatformSpecificLogdirParsing( - ntpath, 'lol:C:\\cat', {'C:\\cat': 'lol'}) - - def testPathWithColonThatComesAfterASlash_isNotConsideredARunName(self): - self.assertPlatformSpecificLogdirParsing( - posixpath, '/lol:/cat', {'/lol:/cat': None}) - - def testExpandsUser(self): - oldhome = os.environ.get('HOME', None) - try: - os.environ['HOME'] = '/usr/eliza' - self.assertPlatformSpecificLogdirParsing( - posixpath, '~/lol/cat~dog', {'/usr/eliza/lol/cat~dog': None}) - os.environ['HOME'] = r'C:\Users\eliza' - self.assertPlatformSpecificLogdirParsing( - ntpath, r'~\lol\cat~dog', {r'C:\Users\eliza\lol\cat~dog': None}) - finally: - if oldhome is not None: - os.environ['HOME'] = oldhome - - def testExpandsUserForMultipleDirectories(self): - oldhome = os.environ.get('HOME', None) - try: - os.environ['HOME'] = '/usr/eliza' - self.assertPlatformSpecificLogdirParsing( - posixpath, 'a:~/lol,b:~/cat', - {'/usr/eliza/lol': 'a', '/usr/eliza/cat': 'b'}) - os.environ['HOME'] = r'C:\Users\eliza' - self.assertPlatformSpecificLogdirParsing( - ntpath, r'aa:~\lol,bb:~\cat', - {r'C:\Users\eliza\lol': 'aa', r'C:\Users\eliza\cat': 'bb'}) - finally: - if oldhome is not None: - os.environ['HOME'] = oldhome - - def testMultipleDirectories(self): - 
self.assertPlatformSpecificLogdirParsing( - posixpath, '/a,/b', {'/a': None, '/b': None}) - self.assertPlatformSpecificLogdirParsing( - ntpath, 'C:\\a,C:\\b', {'C:\\a': None, 'C:\\b': None}) - - def testNormalizesPaths(self): - self.assertPlatformSpecificLogdirParsing( - posixpath, '/lol/.//cat/../cat', {'/lol/cat': None}) - self.assertPlatformSpecificLogdirParsing( - ntpath, 'C:\\lol\\.\\\\cat\\..\\cat', {'C:\\lol\\cat': None}) - - def testAbsolutifies(self): - self.assertPlatformSpecificLogdirParsing( - posixpath, 'lol/cat', {posixpath.realpath('lol/cat'): None}) - self.assertPlatformSpecificLogdirParsing( - ntpath, 'lol\\cat', {ntpath.realpath('lol\\cat'): None}) - - def testRespectsGCSPath(self): - self.assertPlatformSpecificLogdirParsing( - posixpath, 'gs://foo/path', {'gs://foo/path': None}) - self.assertPlatformSpecificLogdirParsing( - ntpath, 'gs://foo/path', {'gs://foo/path': None}) - - def testRespectsHDFSPath(self): - self.assertPlatformSpecificLogdirParsing( - posixpath, 'hdfs://foo/path', {'hdfs://foo/path': None}) - self.assertPlatformSpecificLogdirParsing( - ntpath, 'hdfs://foo/path', {'hdfs://foo/path': None}) - - def testDoesNotExpandUserInGCSPath(self): - self.assertPlatformSpecificLogdirParsing( - posixpath, 'gs://~/foo/path', {'gs://~/foo/path': None}) - self.assertPlatformSpecificLogdirParsing( - ntpath, 'gs://~/foo/path', {'gs://~/foo/path': None}) - - def testDoesNotNormalizeGCSPath(self): - self.assertPlatformSpecificLogdirParsing( - posixpath, 'gs://foo/./path//..', {'gs://foo/./path//..': None}) - self.assertPlatformSpecificLogdirParsing( - ntpath, 'gs://foo/./path//..', {'gs://foo/./path//..': None}) - - def testRunNameWithGCSPath(self): - self.assertPlatformSpecificLogdirParsing( - posixpath, 'lol:gs://foo/path', {'gs://foo/path': 'lol'}) - self.assertPlatformSpecificLogdirParsing( - ntpath, 'lol:gs://foo/path', {'gs://foo/path': 'lol'}) - - def testSingleLetterGroup(self): - self.assertPlatformSpecificLogdirParsing( - posixpath, 
'A:/foo/path', {'/foo/path': 'A'}) - # single letter groups are not supported on Windows - with self.assertRaises(AssertionError): - self.assertPlatformSpecificLogdirParsing( - ntpath, 'A:C:\\foo\\path', {'C:\\foo\\path': 'A'}) + def assertPlatformSpecificLogdirParsing(self, pathObj, logdir, expected): + """A custom assertion to test :func:`parse_event_files_spec` under + various systems. + + Args: + pathObj: a custom replacement object for `os.path`, typically + `posixpath` or `ntpath` + logdir: the string to be parsed by + :func:`~application.parse_event_files_spec` + expected: the expected dictionary as returned by + :func:`~application.parse_event_files_spec` + """ + + with mock.patch("os.path", pathObj): + self.assertEqual( + application.parse_event_files_spec(logdir), expected + ) + + def testBasic(self): + self.assertPlatformSpecificLogdirParsing( + posixpath, "/lol/cat", {"/lol/cat": None} + ) + self.assertPlatformSpecificLogdirParsing( + ntpath, r"C:\lol\cat", {r"C:\lol\cat": None} + ) + + def testRunName(self): + self.assertPlatformSpecificLogdirParsing( + posixpath, "lol:/cat", {"/cat": "lol"} + ) + self.assertPlatformSpecificLogdirParsing( + ntpath, "lol:C:\\cat", {"C:\\cat": "lol"} + ) + + def testPathWithColonThatComesAfterASlash_isNotConsideredARunName(self): + self.assertPlatformSpecificLogdirParsing( + posixpath, "/lol:/cat", {"/lol:/cat": None} + ) + + def testExpandsUser(self): + oldhome = os.environ.get("HOME", None) + try: + os.environ["HOME"] = "/usr/eliza" + self.assertPlatformSpecificLogdirParsing( + posixpath, "~/lol/cat~dog", {"/usr/eliza/lol/cat~dog": None} + ) + os.environ["HOME"] = r"C:\Users\eliza" + self.assertPlatformSpecificLogdirParsing( + ntpath, r"~\lol\cat~dog", {r"C:\Users\eliza\lol\cat~dog": None} + ) + finally: + if oldhome is not None: + os.environ["HOME"] = oldhome + + def testExpandsUserForMultipleDirectories(self): + oldhome = os.environ.get("HOME", None) + try: + os.environ["HOME"] = "/usr/eliza" + 
self.assertPlatformSpecificLogdirParsing( + posixpath, + "a:~/lol,b:~/cat", + {"/usr/eliza/lol": "a", "/usr/eliza/cat": "b"}, + ) + os.environ["HOME"] = r"C:\Users\eliza" + self.assertPlatformSpecificLogdirParsing( + ntpath, + r"aa:~\lol,bb:~\cat", + {r"C:\Users\eliza\lol": "aa", r"C:\Users\eliza\cat": "bb"}, + ) + finally: + if oldhome is not None: + os.environ["HOME"] = oldhome + + def testMultipleDirectories(self): + self.assertPlatformSpecificLogdirParsing( + posixpath, "/a,/b", {"/a": None, "/b": None} + ) + self.assertPlatformSpecificLogdirParsing( + ntpath, "C:\\a,C:\\b", {"C:\\a": None, "C:\\b": None} + ) + + def testNormalizesPaths(self): + self.assertPlatformSpecificLogdirParsing( + posixpath, "/lol/.//cat/../cat", {"/lol/cat": None} + ) + self.assertPlatformSpecificLogdirParsing( + ntpath, "C:\\lol\\.\\\\cat\\..\\cat", {"C:\\lol\\cat": None} + ) + + def testAbsolutifies(self): + self.assertPlatformSpecificLogdirParsing( + posixpath, "lol/cat", {posixpath.realpath("lol/cat"): None} + ) + self.assertPlatformSpecificLogdirParsing( + ntpath, "lol\\cat", {ntpath.realpath("lol\\cat"): None} + ) + + def testRespectsGCSPath(self): + self.assertPlatformSpecificLogdirParsing( + posixpath, "gs://foo/path", {"gs://foo/path": None} + ) + self.assertPlatformSpecificLogdirParsing( + ntpath, "gs://foo/path", {"gs://foo/path": None} + ) + + def testRespectsHDFSPath(self): + self.assertPlatformSpecificLogdirParsing( + posixpath, "hdfs://foo/path", {"hdfs://foo/path": None} + ) + self.assertPlatformSpecificLogdirParsing( + ntpath, "hdfs://foo/path", {"hdfs://foo/path": None} + ) + + def testDoesNotExpandUserInGCSPath(self): + self.assertPlatformSpecificLogdirParsing( + posixpath, "gs://~/foo/path", {"gs://~/foo/path": None} + ) + self.assertPlatformSpecificLogdirParsing( + ntpath, "gs://~/foo/path", {"gs://~/foo/path": None} + ) + + def testDoesNotNormalizeGCSPath(self): + self.assertPlatformSpecificLogdirParsing( + posixpath, "gs://foo/./path//..", {"gs://foo/./path//..": 
None} + ) + self.assertPlatformSpecificLogdirParsing( + ntpath, "gs://foo/./path//..", {"gs://foo/./path//..": None} + ) + + def testRunNameWithGCSPath(self): + self.assertPlatformSpecificLogdirParsing( + posixpath, "lol:gs://foo/path", {"gs://foo/path": "lol"} + ) + self.assertPlatformSpecificLogdirParsing( + ntpath, "lol:gs://foo/path", {"gs://foo/path": "lol"} + ) + + def testSingleLetterGroup(self): + self.assertPlatformSpecificLogdirParsing( + posixpath, "A:/foo/path", {"/foo/path": "A"} + ) + # single letter groups are not supported on Windows + with self.assertRaises(AssertionError): + self.assertPlatformSpecificLogdirParsing( + ntpath, "A:C:\\foo\\path", {"C:\\foo\\path": "A"} + ) class TensorBoardPluginsTest(tb_test.TestCase): + def setUp(self): + self.context = None + dummy_assets_zip_provider = lambda: None + # The application should have added routes for both plugins. + self.app = application.standard_tensorboard_wsgi( + FakeFlags(logdir=self.get_temp_dir()), + [ + FakePluginLoader( + plugin_name="foo", + is_active_value=True, + routes_mapping={"/foo_route": self._foo_handler}, + construction_callback=self._construction_callback, + ), + FakePluginLoader( + plugin_name="bar", + is_active_value=True, + routes_mapping={ + "/bar_route": self._bar_handler, + "/wildcard/*": self._wildcard_handler, + "/wildcard/special/*": self._wildcard_special_handler, + "/wildcard/special/exact": self._foo_handler, + }, + construction_callback=self._construction_callback, + ), + FakePluginLoader( + plugin_name="whoami", + routes_mapping={"/eid": self._eid_handler,}, + ), + ], + dummy_assets_zip_provider, + ) + + self.server = werkzeug_test.Client(self.app, wrappers.BaseResponse) + + def _construction_callback(self, context): + """Called when a plugin is constructed.""" + self.context = context + + def _test_route(self, route, expected_status_code): + response = self.server.get(route) + self.assertEqual(response.status_code, expected_status_code) - def setUp(self): - 
self.context = None - dummy_assets_zip_provider = lambda: None - # The application should have added routes for both plugins. - self.app = application.standard_tensorboard_wsgi( - FakeFlags(logdir=self.get_temp_dir()), - [ - FakePluginLoader( - plugin_name='foo', - is_active_value=True, - routes_mapping={'/foo_route': self._foo_handler}, - construction_callback=self._construction_callback), - FakePluginLoader( - plugin_name='bar', - is_active_value=True, - routes_mapping={ - '/bar_route': self._bar_handler, - '/wildcard/*': self._wildcard_handler, - '/wildcard/special/*': self._wildcard_special_handler, - '/wildcard/special/exact': self._foo_handler, - }, - construction_callback=self._construction_callback), - FakePluginLoader( - plugin_name='whoami', - routes_mapping={ - '/eid': self._eid_handler, - }), - ], - dummy_assets_zip_provider) - - self.server = werkzeug_test.Client(self.app, wrappers.BaseResponse) - - def _construction_callback(self, context): - """Called when a plugin is constructed.""" - self.context = context - - def _test_route(self, route, expected_status_code): - response = self.server.get(route) - self.assertEqual(response.status_code, expected_status_code) - - @wrappers.Request.application - def _foo_handler(self, request): - return wrappers.Response(response='hello world', status=200) - - def _bar_handler(self): - pass - - @wrappers.Request.application - def _eid_handler(self, request): - eid = plugin_util.experiment_id(request.environ) - body = json.dumps({'experiment_id': eid}) - return wrappers.Response(body, 200, content_type='application/json') - - @wrappers.Request.application - def _wildcard_handler(self, request): - if request.path == '/data/plugin/bar/wildcard/ok': - return wrappers.Response(response='hello world', status=200) - elif request.path == '/data/plugin/bar/wildcard/': - # this route cannot actually be hit; see testEmptyWildcardRouteWithSlash. 
- return wrappers.Response(response='hello world', status=200) - else: - return wrappers.Response(status=401) - - @wrappers.Request.application - def _wildcard_special_handler(self, request): - return wrappers.Response(status=300) - - def testPluginsAdded(self): - # The routes are prefixed with /data/plugin/[plugin name]. - expected_routes = frozenset(( - '/data/plugin/foo/foo_route', - '/data/plugin/bar/bar_route', - )) - self.assertLessEqual(expected_routes, frozenset(self.app.exact_routes)) - - def testNameToPluginMapping(self): - # The mapping from plugin name to instance should include all plugins. - mapping = self.context.plugin_name_to_instance - self.assertItemsEqual(['foo', 'bar', 'whoami'], list(mapping.keys())) - self.assertEqual('foo', mapping['foo'].plugin_name) - self.assertEqual('bar', mapping['bar'].plugin_name) - self.assertEqual('whoami', mapping['whoami'].plugin_name) - - def testNormalRoute(self): - self._test_route('/data/plugin/foo/foo_route', 200) - - def testNormalRouteIsNotWildcard(self): - self._test_route('/data/plugin/foo/foo_route/bogus', 404) - - def testMissingRoute(self): - self._test_route('/data/plugin/foo/bogus', 404) - - def testExperimentIdIntegration_withNoExperimentId(self): - response = self.server.get('/data/plugin/whoami/eid') - self.assertEqual(response.status_code, 200) - data = json.loads(response.get_data().decode('utf-8')) - self.assertEqual(data, {'experiment_id': ''}) - - def testExperimentIdIntegration_withExperimentId(self): - response = self.server.get('/experiment/123/data/plugin/whoami/eid') - self.assertEqual(response.status_code, 200) - data = json.loads(response.get_data().decode('utf-8')) - self.assertEqual(data, {'experiment_id': '123'}) - - def testEmptyRoute(self): - self._test_route('', 301) - - def testSlashlessRoute(self): - self._test_route('runaway', 404) - - def testWildcardAcceptedRoute(self): - self._test_route('/data/plugin/bar/wildcard/ok', 200) - - def 
testLongerWildcardRouteTakesPrecedence(self): - # This tests that the longer 'special' wildcard takes precedence over - # the shorter one. - self._test_route('/data/plugin/bar/wildcard/special/blah', 300) - - def testExactRouteTakesPrecedence(self): - # This tests that an exact match takes precedence over a wildcard. - self._test_route('/data/plugin/bar/wildcard/special/exact', 200) - - def testWildcardRejectedRoute(self): - # A plugin may reject a request passed to it via a wildcard route. - # Note our test plugin returns 401 in this case, to distinguish this - # response from a 404 passed if the route were not found. - self._test_route('/data/plugin/bar/wildcard/bogus', 401) - - def testWildcardRouteWithoutSlash(self): - # A wildcard route requires a slash before the '*'. - # Lacking one, no route is matched. - self._test_route('/data/plugin/bar/wildcard', 404) - - def testEmptyWildcardRouteWithSlash(self): - # A wildcard route requires a slash before the '*'. Here we include the - # slash, so we might expect the route to match. - # - # However: Trailing slashes are automatically removed from incoming requests - # in _clean_path(). Consequently, this request does not match the wildcard - # route after all. - # - # Note the test plugin specifically accepts this route (returning 200), so - # the fact that 404 is returned demonstrates that the plugin was not - # consulted. 
- self._test_route('/data/plugin/bar/wildcard/', 404) + @wrappers.Request.application + def _foo_handler(self, request): + return wrappers.Response(response="hello world", status=200) + def _bar_handler(self): + pass -class DbTest(tb_test.TestCase): + @wrappers.Request.application + def _eid_handler(self, request): + eid = plugin_util.experiment_id(request.environ) + body = json.dumps({"experiment_id": eid}) + return wrappers.Response(body, 200, content_type="application/json") - def testSqliteDb(self): - db_uri = 'sqlite:' + os.path.join(self.get_temp_dir(), 'db') - db_connection_provider = application.create_sqlite_connection_provider( - db_uri) - with contextlib.closing(db_connection_provider()) as conn: - with conn: - with contextlib.closing(conn.cursor()) as c: - c.execute('create table peeps (name text)') - c.execute('insert into peeps (name) values (?)', ('justine',)) - db_connection_provider = application.create_sqlite_connection_provider( - db_uri) - with contextlib.closing(db_connection_provider()) as conn: - with contextlib.closing(conn.cursor()) as c: - c.execute('select name from peeps') - self.assertEqual(('justine',), c.fetchone()) - - def testTransactionRollback(self): - db_uri = 'sqlite:' + os.path.join(self.get_temp_dir(), 'db') - db_connection_provider = application.create_sqlite_connection_provider( - db_uri) - with contextlib.closing(db_connection_provider()) as conn: - with conn: - with contextlib.closing(conn.cursor()) as c: - c.execute('create table peeps (name text)') - try: - with conn: - with contextlib.closing(conn.cursor()) as c: - c.execute('insert into peeps (name) values (?)', ('justine',)) - raise IOError('hi') - except IOError: - pass - with contextlib.closing(conn.cursor()) as c: - c.execute('select name from peeps') - self.assertIsNone(c.fetchone()) - - def testTransactionRollback_doesntDoAnythingIfIsolationLevelIsNone(self): - # NOTE: This is a terrible idea. Don't do this. 
- db_uri = ('sqlite:' + os.path.join(self.get_temp_dir(), 'db') + - '?isolation_level=null') - db_connection_provider = application.create_sqlite_connection_provider( - db_uri) - with contextlib.closing(db_connection_provider()) as conn: - with conn: - with contextlib.closing(conn.cursor()) as c: - c.execute('create table peeps (name text)') - try: - with conn: - with contextlib.closing(conn.cursor()) as c: - c.execute('insert into peeps (name) values (?)', ('justine',)) - raise IOError('hi') - except IOError: - pass - with contextlib.closing(conn.cursor()) as c: - c.execute('select name from peeps') - self.assertEqual(('justine',), c.fetchone()) + @wrappers.Request.application + def _wildcard_handler(self, request): + if request.path == "/data/plugin/bar/wildcard/ok": + return wrappers.Response(response="hello world", status=200) + elif request.path == "/data/plugin/bar/wildcard/": + # this route cannot actually be hit; see testEmptyWildcardRouteWithSlash. + return wrappers.Response(response="hello world", status=200) + else: + return wrappers.Response(status=401) - def testSqliteUriErrors(self): - with self.assertRaises(ValueError): - application.create_sqlite_connection_provider("lol:cat") - with self.assertRaises(ValueError): - application.create_sqlite_connection_provider("sqlite::memory:") - with self.assertRaises(ValueError): - application.create_sqlite_connection_provider("sqlite://foo.example/bar") + @wrappers.Request.application + def _wildcard_special_handler(self, request): + return wrappers.Response(status=300) + + def testPluginsAdded(self): + # The routes are prefixed with /data/plugin/[plugin name]. + expected_routes = frozenset( + ("/data/plugin/foo/foo_route", "/data/plugin/bar/bar_route",) + ) + self.assertLessEqual(expected_routes, frozenset(self.app.exact_routes)) + + def testNameToPluginMapping(self): + # The mapping from plugin name to instance should include all plugins. 
+ mapping = self.context.plugin_name_to_instance + self.assertItemsEqual(["foo", "bar", "whoami"], list(mapping.keys())) + self.assertEqual("foo", mapping["foo"].plugin_name) + self.assertEqual("bar", mapping["bar"].plugin_name) + self.assertEqual("whoami", mapping["whoami"].plugin_name) + + def testNormalRoute(self): + self._test_route("/data/plugin/foo/foo_route", 200) + + def testNormalRouteIsNotWildcard(self): + self._test_route("/data/plugin/foo/foo_route/bogus", 404) + + def testMissingRoute(self): + self._test_route("/data/plugin/foo/bogus", 404) + + def testExperimentIdIntegration_withNoExperimentId(self): + response = self.server.get("/data/plugin/whoami/eid") + self.assertEqual(response.status_code, 200) + data = json.loads(response.get_data().decode("utf-8")) + self.assertEqual(data, {"experiment_id": ""}) + + def testExperimentIdIntegration_withExperimentId(self): + response = self.server.get("/experiment/123/data/plugin/whoami/eid") + self.assertEqual(response.status_code, 200) + data = json.loads(response.get_data().decode("utf-8")) + self.assertEqual(data, {"experiment_id": "123"}) + + def testEmptyRoute(self): + self._test_route("", 301) + + def testSlashlessRoute(self): + self._test_route("runaway", 404) + + def testWildcardAcceptedRoute(self): + self._test_route("/data/plugin/bar/wildcard/ok", 200) + + def testLongerWildcardRouteTakesPrecedence(self): + # This tests that the longer 'special' wildcard takes precedence over + # the shorter one. + self._test_route("/data/plugin/bar/wildcard/special/blah", 300) + + def testExactRouteTakesPrecedence(self): + # This tests that an exact match takes precedence over a wildcard. + self._test_route("/data/plugin/bar/wildcard/special/exact", 200) + + def testWildcardRejectedRoute(self): + # A plugin may reject a request passed to it via a wildcard route. + # Note our test plugin returns 401 in this case, to distinguish this + # response from a 404 passed if the route were not found. 
+ self._test_route("/data/plugin/bar/wildcard/bogus", 401) + + def testWildcardRouteWithoutSlash(self): + # A wildcard route requires a slash before the '*'. + # Lacking one, no route is matched. + self._test_route("/data/plugin/bar/wildcard", 404) + + def testEmptyWildcardRouteWithSlash(self): + # A wildcard route requires a slash before the '*'. Here we include the + # slash, so we might expect the route to match. + # + # However: Trailing slashes are automatically removed from incoming requests + # in _clean_path(). Consequently, this request does not match the wildcard + # route after all. + # + # Note the test plugin specifically accepts this route (returning 200), so + # the fact that 404 is returned demonstrates that the plugin was not + # consulted. + self._test_route("/data/plugin/bar/wildcard/", 404) -if __name__ == '__main__': - tb_test.main() +class DbTest(tb_test.TestCase): + def testSqliteDb(self): + db_uri = "sqlite:" + os.path.join(self.get_temp_dir(), "db") + db_connection_provider = application.create_sqlite_connection_provider( + db_uri + ) + with contextlib.closing(db_connection_provider()) as conn: + with conn: + with contextlib.closing(conn.cursor()) as c: + c.execute("create table peeps (name text)") + c.execute( + "insert into peeps (name) values (?)", ("justine",) + ) + db_connection_provider = application.create_sqlite_connection_provider( + db_uri + ) + with contextlib.closing(db_connection_provider()) as conn: + with contextlib.closing(conn.cursor()) as c: + c.execute("select name from peeps") + self.assertEqual(("justine",), c.fetchone()) + + def testTransactionRollback(self): + db_uri = "sqlite:" + os.path.join(self.get_temp_dir(), "db") + db_connection_provider = application.create_sqlite_connection_provider( + db_uri + ) + with contextlib.closing(db_connection_provider()) as conn: + with conn: + with contextlib.closing(conn.cursor()) as c: + c.execute("create table peeps (name text)") + try: + with conn: + with 
contextlib.closing(conn.cursor()) as c: + c.execute( + "insert into peeps (name) values (?)", ("justine",) + ) + raise IOError("hi") + except IOError: + pass + with contextlib.closing(conn.cursor()) as c: + c.execute("select name from peeps") + self.assertIsNone(c.fetchone()) + + def testTransactionRollback_doesntDoAnythingIfIsolationLevelIsNone(self): + # NOTE: This is a terrible idea. Don't do this. + db_uri = ( + "sqlite:" + + os.path.join(self.get_temp_dir(), "db") + + "?isolation_level=null" + ) + db_connection_provider = application.create_sqlite_connection_provider( + db_uri + ) + with contextlib.closing(db_connection_provider()) as conn: + with conn: + with contextlib.closing(conn.cursor()) as c: + c.execute("create table peeps (name text)") + try: + with conn: + with contextlib.closing(conn.cursor()) as c: + c.execute( + "insert into peeps (name) values (?)", ("justine",) + ) + raise IOError("hi") + except IOError: + pass + with contextlib.closing(conn.cursor()) as c: + c.execute("select name from peeps") + self.assertEqual(("justine",), c.fetchone()) + + def testSqliteUriErrors(self): + with self.assertRaises(ValueError): + application.create_sqlite_connection_provider("lol:cat") + with self.assertRaises(ValueError): + application.create_sqlite_connection_provider("sqlite::memory:") + with self.assertRaises(ValueError): + application.create_sqlite_connection_provider( + "sqlite://foo.example/bar" + ) + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/backend/empty_path_redirect.py b/tensorboard/backend/empty_path_redirect.py index d2d5be4fe0..69b7cae5fc 100644 --- a/tensorboard/backend/empty_path_redirect.py +++ b/tensorboard/backend/empty_path_redirect.py @@ -31,20 +31,20 @@ class EmptyPathRedirectMiddleware(object): - """WSGI middleware to redirect from "" to "/".""" - - def __init__(self, application): - """Initializes this middleware. - - Args: - application: The WSGI application to wrap (see PEP 3333). 
- """ - self._application = application - - def __call__(self, environ, start_response): - path = environ.get("PATH_INFO", "") - if path: - return self._application(environ, start_response) - location = environ.get("SCRIPT_NAME", "") + "/" - start_response("301 Moved Permanently", [("Location", location)]) - return [] + """WSGI middleware to redirect from "" to "/".""" + + def __init__(self, application): + """Initializes this middleware. + + Args: + application: The WSGI application to wrap (see PEP 3333). + """ + self._application = application + + def __call__(self, environ, start_response): + path = environ.get("PATH_INFO", "") + if path: + return self._application(environ, start_response) + location = environ.get("SCRIPT_NAME", "") + "/" + start_response("301 Moved Permanently", [("Location", location)]) + return [] diff --git a/tensorboard/backend/empty_path_redirect_test.py b/tensorboard/backend/empty_path_redirect_test.py index 4aac642fda..7a643e14c3 100644 --- a/tensorboard/backend/empty_path_redirect_test.py +++ b/tensorboard/backend/empty_path_redirect_test.py @@ -28,46 +28,48 @@ class EmptyPathRedirectMiddlewareTest(tb_test.TestCase): - """Tests for `EmptyPathRedirectMiddleware`.""" + """Tests for `EmptyPathRedirectMiddleware`.""" - def setUp(self): - super(EmptyPathRedirectMiddlewareTest, self).setUp() - app = werkzeug.Request.application(lambda req: werkzeug.Response(req.path)) - app = empty_path_redirect.EmptyPathRedirectMiddleware(app) - app = self._lax_strip_foo_middleware(app) - self.app = app - self.server = werkzeug_test.Client(self.app, werkzeug.BaseResponse) + def setUp(self): + super(EmptyPathRedirectMiddlewareTest, self).setUp() + app = werkzeug.Request.application( + lambda req: werkzeug.Response(req.path) + ) + app = empty_path_redirect.EmptyPathRedirectMiddleware(app) + app = self._lax_strip_foo_middleware(app) + self.app = app + self.server = werkzeug_test.Client(self.app, werkzeug.BaseResponse) - def _lax_strip_foo_middleware(self, 
app): - """Strips a `/foo` prefix if it exists; no-op otherwise.""" + def _lax_strip_foo_middleware(self, app): + """Strips a `/foo` prefix if it exists; no-op otherwise.""" - def wrapper(environ, start_response): - path = environ.get("PATH_INFO", "") - if path.startswith("/foo"): - environ["PATH_INFO"] = path[len("/foo") :] - environ["SCRIPT_NAME"] = "/foo" - return app(environ, start_response) + def wrapper(environ, start_response): + path = environ.get("PATH_INFO", "") + if path.startswith("/foo"): + environ["PATH_INFO"] = path[len("/foo") :] + environ["SCRIPT_NAME"] = "/foo" + return app(environ, start_response) - return wrapper + return wrapper - def test_normal_route_not_redirected(self): - response = self.server.get("/foo/bar") - self.assertEqual(response.status_code, 200) + def test_normal_route_not_redirected(self): + response = self.server.get("/foo/bar") + self.assertEqual(response.status_code, 200) - def test_slash_not_redirected(self): - response = self.server.get("/foo/") - self.assertEqual(response.status_code, 200) + def test_slash_not_redirected(self): + response = self.server.get("/foo/") + self.assertEqual(response.status_code, 200) - def test_empty_redirected_with_script_name(self): - response = self.server.get("/foo") - self.assertEqual(response.status_code, 301) - self.assertEqual(response.headers["Location"], "/foo/") + def test_empty_redirected_with_script_name(self): + response = self.server.get("/foo") + self.assertEqual(response.status_code, 301) + self.assertEqual(response.headers["Location"], "/foo/") - def test_empty_redirected_with_blank_script_name(self): - response = self.server.get("") - self.assertEqual(response.status_code, 301) - self.assertEqual(response.headers["Location"], "/") + def test_empty_redirected_with_blank_script_name(self): + response = self.server.get("") + self.assertEqual(response.status_code, 301) + self.assertEqual(response.headers["Location"], "/") if __name__ == "__main__": - tb_test.main() + tb_test.main() 
diff --git a/tensorboard/backend/event_processing/data_provider.py b/tensorboard/backend/event_processing/data_provider.py index 859ed7be73..aca177ef51 100644 --- a/tensorboard/backend/event_processing/data_provider.py +++ b/tensorboard/backend/event_processing/data_provider.py @@ -35,285 +35,292 @@ class MultiplexerDataProvider(provider.DataProvider): - def __init__(self, multiplexer, logdir): - """Trivial initializer. - - Args: - multiplexer: A `plugin_event_multiplexer.EventMultiplexer` (note: - not a boring old `event_multiplexer.EventMultiplexer`). - logdir: The log directory from which data is being read. Only used - cosmetically. Should be a `str`. - """ - self._multiplexer = multiplexer - self._logdir = logdir - - def _validate_experiment_id(self, experiment_id): - # This data provider doesn't consume the experiment ID at all, but - # as a courtesy to callers we require that it be a valid string, to - # help catch usage errors. - if not isinstance(experiment_id, str): - raise TypeError( - "experiment_id must be %r, but got %r: %r" - % (str, type(experiment_id), experiment_id) - ) - - def _test_run_tag(self, run_tag_filter, run, tag): - runs = run_tag_filter.runs - if runs is not None and run not in runs: - return False - tags = run_tag_filter.tags - if tags is not None and tag not in tags: - return False - return True - - def _get_first_event_timestamp(self, run_name): - try: - return self._multiplexer.FirstEventTimestamp(run_name) - except ValueError as e: - return None - - def data_location(self, experiment_id): - self._validate_experiment_id(experiment_id) - return str(self._logdir) - - def list_runs(self, experiment_id): - self._validate_experiment_id(experiment_id) - return [ - provider.Run( - run_id=run, # use names as IDs - run_name=run, - start_time=self._get_first_event_timestamp(run), + def __init__(self, multiplexer, logdir): + """Trivial initializer. 
+ + Args: + multiplexer: A `plugin_event_multiplexer.EventMultiplexer` (note: + not a boring old `event_multiplexer.EventMultiplexer`). + logdir: The log directory from which data is being read. Only used + cosmetically. Should be a `str`. + """ + self._multiplexer = multiplexer + self._logdir = logdir + + def _validate_experiment_id(self, experiment_id): + # This data provider doesn't consume the experiment ID at all, but + # as a courtesy to callers we require that it be a valid string, to + # help catch usage errors. + if not isinstance(experiment_id, str): + raise TypeError( + "experiment_id must be %r, but got %r: %r" + % (str, type(experiment_id), experiment_id) + ) + + def _test_run_tag(self, run_tag_filter, run, tag): + runs = run_tag_filter.runs + if runs is not None and run not in runs: + return False + tags = run_tag_filter.tags + if tags is not None and tag not in tags: + return False + return True + + def _get_first_event_timestamp(self, run_name): + try: + return self._multiplexer.FirstEventTimestamp(run_name) + except ValueError as e: + return None + + def data_location(self, experiment_id): + self._validate_experiment_id(experiment_id) + return str(self._logdir) + + def list_runs(self, experiment_id): + self._validate_experiment_id(experiment_id) + return [ + provider.Run( + run_id=run, # use names as IDs + run_name=run, + start_time=self._get_first_event_timestamp(run), + ) + for run in self._multiplexer.Runs() + ] + + def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None): + self._validate_experiment_id(experiment_id) + run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name) + return self._list( + provider.ScalarTimeSeries, run_tag_content, run_tag_filter ) - for run in self._multiplexer.Runs() - ] - - def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None): - self._validate_experiment_id(experiment_id) - run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name) - return self._list( 
- provider.ScalarTimeSeries, run_tag_content, run_tag_filter - ) - def read_scalars( - self, experiment_id, plugin_name, downsample=None, run_tag_filter=None - ): - # TODO(@wchargin): Downsampling not implemented, as the multiplexer - # is already downsampled. We could downsample on top of the existing - # sampling, which would be nice for testing. - del downsample # ignored for now - index = self.list_scalars( - experiment_id, plugin_name, run_tag_filter=run_tag_filter - ) - return self._read(_convert_scalar_event, index) + def read_scalars( + self, experiment_id, plugin_name, downsample=None, run_tag_filter=None + ): + # TODO(@wchargin): Downsampling not implemented, as the multiplexer + # is already downsampled. We could downsample on top of the existing + # sampling, which would be nice for testing. + del downsample # ignored for now + index = self.list_scalars( + experiment_id, plugin_name, run_tag_filter=run_tag_filter + ) + return self._read(_convert_scalar_event, index) - def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None): - self._validate_experiment_id(experiment_id) - run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name) - return self._list( - provider.TensorTimeSeries, run_tag_content, run_tag_filter - ) + def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None): + self._validate_experiment_id(experiment_id) + run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name) + return self._list( + provider.TensorTimeSeries, run_tag_content, run_tag_filter + ) - def read_tensors( - self, experiment_id, plugin_name, downsample=None, run_tag_filter=None - ): - # TODO(@wchargin): Downsampling not implemented, as the multiplexer - # is already downsampled. We could downsample on top of the existing - # sampling, which would be nice for testing. 
- del downsample # ignored for now - index = self.list_tensors( - experiment_id, plugin_name, run_tag_filter=run_tag_filter - ) - return self._read(_convert_tensor_event, index) + def read_tensors( + self, experiment_id, plugin_name, downsample=None, run_tag_filter=None + ): + # TODO(@wchargin): Downsampling not implemented, as the multiplexer + # is already downsampled. We could downsample on top of the existing + # sampling, which would be nice for testing. + del downsample # ignored for now + index = self.list_tensors( + experiment_id, plugin_name, run_tag_filter=run_tag_filter + ) + return self._read(_convert_tensor_event, index) + + def _list(self, construct_time_series, run_tag_content, run_tag_filter): + """Helper to list scalar or tensor time series. + + Args: + construct_time_series: `ScalarTimeSeries` or `TensorTimeSeries`. + run_tag_content: Result of `_multiplexer.PluginRunToTagToContent(...)`. + run_tag_filter: As given by the client; may be `None`. + + Returns: + A list of objects of type given by `construct_time_series`, + suitable to be returned from `list_scalars` or `list_tensors`. 
+ """ + result = {} + if run_tag_filter is None: + run_tag_filter = provider.RunTagFilter(runs=None, tags=None) + for (run, tag_to_content) in six.iteritems(run_tag_content): + result_for_run = {} + for tag in tag_to_content: + if not self._test_run_tag(run_tag_filter, run, tag): + continue + result[run] = result_for_run + max_step = None + max_wall_time = None + for event in self._multiplexer.Tensors(run, tag): + if max_step is None or max_step < event.step: + max_step = event.step + if max_wall_time is None or max_wall_time < event.wall_time: + max_wall_time = event.wall_time + summary_metadata = self._multiplexer.SummaryMetadata(run, tag) + result_for_run[tag] = construct_time_series( + max_step=max_step, + max_wall_time=max_wall_time, + plugin_content=summary_metadata.plugin_data.content, + description=summary_metadata.summary_description, + display_name=summary_metadata.display_name, + ) + return result + + def _read(self, convert_event, index): + """Helper to read scalar or tensor data from the multiplexer. + + Args: + convert_event: Takes `plugin_event_accumulator.TensorEvent` to + either `provider.ScalarDatum` or `provider.TensorDatum`. + index: The result of `list_scalars` or `list_tensors`. + + Returns: + A dict of dicts of values returned by `convert_event` calls, + suitable to be returned from `read_scalars` or `read_tensors`. + """ + result = {} + for (run, tags_for_run) in six.iteritems(index): + result_for_run = {} + result[run] = result_for_run + for (tag, metadata) in six.iteritems(tags_for_run): + events = self._multiplexer.Tensors(run, tag) + result_for_run[tag] = [convert_event(e) for e in events] + return result + + def list_blob_sequences( + self, experiment_id, plugin_name, run_tag_filter=None + ): + self._validate_experiment_id(experiment_id) + if run_tag_filter is None: + run_tag_filter = provider.RunTagFilter(runs=None, tags=None) + + # TODO(davidsoergel, wchargin): consider images, etc. 
+ # Note this plugin_name can really just be 'graphs' for now; the + # v2 cases are not handled yet. + if plugin_name != graphs_metadata.PLUGIN_NAME: + logger.warn("Directory has no blob data for plugin %r", plugin_name) + return {} + + result = collections.defaultdict(lambda: {}) + for (run, run_info) in six.iteritems(self._multiplexer.Runs()): + tag = None + if not self._test_run_tag(run_tag_filter, run, tag): + continue + if not run_info[plugin_event_accumulator.GRAPH]: + continue + result[run][tag] = provider.BlobSequenceTimeSeries( + max_step=0, + max_wall_time=0, + latest_max_index=0, # Graphs are always one blob at a time + plugin_content=None, + description=None, + display_name=None, + ) + return result + + def read_blob_sequences( + self, experiment_id, plugin_name, downsample=None, run_tag_filter=None + ): + self._validate_experiment_id(experiment_id) + # TODO(davidsoergel, wchargin): consider images, etc. + # Note this plugin_name can really just be 'graphs' for now; the + # v2 cases are not handled yet. + if plugin_name != graphs_metadata.PLUGIN_NAME: + logger.warn("Directory has no blob data for plugin %r", plugin_name) + return {} + + result = collections.defaultdict( + lambda: collections.defaultdict(lambda: []) + ) + for (run, run_info) in six.iteritems(self._multiplexer.Runs()): + tag = None + if not self._test_run_tag(run_tag_filter, run, tag): + continue + if not run_info[plugin_event_accumulator.GRAPH]: + continue + + time_series = result[run][tag] + + wall_time = 0.0 # dummy value for graph + step = 0 # dummy value for graph + index = 0 # dummy value for graph + + # In some situations these blobs may have directly accessible URLs. + # But, for now, we assume they don't. 
+ graph_url = None + graph_blob_key = _encode_blob_key( + experiment_id, plugin_name, run, tag, step, index + ) + blob_ref = provider.BlobReference(graph_blob_key, graph_url) + + datum = provider.BlobSequenceDatum( + wall_time=wall_time, step=step, values=(blob_ref,), + ) + time_series.append(datum) + return result + + def read_blob(self, blob_key): + # note: ignoring nearly all key elements: there is only one graph per run. + ( + unused_experiment_id, + plugin_name, + run, + unused_tag, + unused_step, + unused_index, + ) = _decode_blob_key(blob_key) + + # TODO(davidsoergel, wchargin): consider images, etc. + if plugin_name != graphs_metadata.PLUGIN_NAME: + logger.warn("Directory has no blob data for plugin %r", plugin_name) + raise errors.NotFoundError() + + serialized_graph = self._multiplexer.SerializedGraph(run) + + # TODO(davidsoergel): graph_defs have no step attribute so we don't filter + # on it. Other blob types might, though. + + if serialized_graph is None: + logger.warn("No blob found for key %r", blob_key) + raise errors.NotFoundError() + + # TODO(davidsoergel): consider internal structure of non-graphdef blobs. + # In particular, note we ignore the requested index, since it's always 0. + return serialized_graph - def _list(self, construct_time_series, run_tag_content, run_tag_filter): - """Helper to list scalar or tensor time series. - Args: - construct_time_series: `ScalarTimeSeries` or `TensorTimeSeries`. - run_tag_content: Result of `_multiplexer.PluginRunToTagToContent(...)`. - run_tag_filter: As given by the client; may be `None`. +# TODO(davidsoergel): deduplicate with other implementations +def _encode_blob_key(experiment_id, plugin_name, run, tag, step, index): + """Generate a blob key: a short, URL-safe string identifying a blob. - Returns: - A list of objects of type given by `construct_time_series`, - suitable to be returned from `list_scalars` or `list_tensors`. 
- """ - result = {} - if run_tag_filter is None: - run_tag_filter = provider.RunTagFilter(runs=None, tags=None) - for (run, tag_to_content) in six.iteritems(run_tag_content): - result_for_run = {} - for tag in tag_to_content: - if not self._test_run_tag(run_tag_filter, run, tag): - continue - result[run] = result_for_run - max_step = None - max_wall_time = None - for event in self._multiplexer.Tensors(run, tag): - if max_step is None or max_step < event.step: - max_step = event.step - if max_wall_time is None or max_wall_time < event.wall_time: - max_wall_time = event.wall_time - summary_metadata = self._multiplexer.SummaryMetadata(run, tag) - result_for_run[tag] = construct_time_series( - max_step=max_step, - max_wall_time=max_wall_time, - plugin_content=summary_metadata.plugin_data.content, - description=summary_metadata.summary_description, - display_name=summary_metadata.display_name, - ) - return result + A blob can be located using a set of integer and string fields; here we + serialize these to allow passing the data through a URL. Specifically, we + 1) construct a tuple of the arguments in order; 2) represent that as an + ascii-encoded JSON string (without whitespace); and 3) take the URL-safe + base64 encoding of that, with no padding. For example: - def _read(self, convert_event, index): - """Helper to read scalar or tensor data from the multiplexer. + 1) Tuple: ("some_id", "graphs", "train", "graph_def", 2, 0) + 2) JSON: ["some_id","graphs","train","graph_def",2,0] + 3) base64: WyJzb21lX2lkIiwiZ3JhcGhzIiwidHJhaW4iLCJncmFwaF9kZWYiLDIsMF0K Args: - convert_event: Takes `plugin_event_accumulator.TensorEvent` to - either `provider.ScalarDatum` or `provider.TensorDatum`. - index: The result of `list_scalars` or `list_tensors`. + experiment_id: a string ID identifying an experiment. 
+ plugin_name: string + run: string + tag: string + step: int + index: int Returns: - A dict of dicts of values returned by `convert_event` calls, - suitable to be returned from `read_scalars` or `read_tensors`. + A URL-safe base64-encoded string representing the provided arguments. """ - result = {} - for (run, tags_for_run) in six.iteritems(index): - result_for_run = {} - result[run] = result_for_run - for (tag, metadata) in six.iteritems(tags_for_run): - events = self._multiplexer.Tensors(run, tag) - result_for_run[tag] = [convert_event(e) for e in events] - return result - - def list_blob_sequences( - self, experiment_id, plugin_name, run_tag_filter=None - ): - self._validate_experiment_id(experiment_id) - if run_tag_filter is None: - run_tag_filter = provider.RunTagFilter(runs=None, tags=None) - - # TODO(davidsoergel, wchargin): consider images, etc. - # Note this plugin_name can really just be 'graphs' for now; the - # v2 cases are not handled yet. - if plugin_name != graphs_metadata.PLUGIN_NAME: - logger.warn("Directory has no blob data for plugin %r", plugin_name) - return {} - - result = collections.defaultdict(lambda: {}) - for (run, run_info) in six.iteritems(self._multiplexer.Runs()): - tag = None - if not self._test_run_tag(run_tag_filter, run, tag): - continue - if not run_info[plugin_event_accumulator.GRAPH]: - continue - result[run][tag] = provider.BlobSequenceTimeSeries( - max_step=0, - max_wall_time=0, - latest_max_index=0, # Graphs are always one blob at a time - plugin_content=None, - description=None, - display_name=None, - ) - return result - - def read_blob_sequences( - self, experiment_id, plugin_name, downsample=None, run_tag_filter=None - ): - self._validate_experiment_id(experiment_id) - # TODO(davidsoergel, wchargin): consider images, etc. - # Note this plugin_name can really just be 'graphs' for now; the - # v2 cases are not handled yet. 
- if plugin_name != graphs_metadata.PLUGIN_NAME: - logger.warn("Directory has no blob data for plugin %r", plugin_name) - return {} - - result = collections.defaultdict( - lambda: collections.defaultdict(lambda: [])) - for (run, run_info) in six.iteritems(self._multiplexer.Runs()): - tag = None - if not self._test_run_tag(run_tag_filter, run, tag): - continue - if not run_info[plugin_event_accumulator.GRAPH]: - continue - - time_series = result[run][tag] - - wall_time = 0. # dummy value for graph - step = 0 # dummy value for graph - index = 0 # dummy value for graph - - # In some situations these blobs may have directly accessible URLs. - # But, for now, we assume they don't. - graph_url = None - graph_blob_key = _encode_blob_key( - experiment_id, plugin_name, run, tag, step, index) - blob_ref = provider.BlobReference(graph_blob_key, graph_url) - - datum = provider.BlobSequenceDatum( - wall_time=wall_time, - step=step, - values=(blob_ref,), - ) - time_series.append(datum) - return result - - def read_blob(self, blob_key): - # note: ignoring nearly all key elements: there is only one graph per run. - (unused_experiment_id, plugin_name, run, unused_tag, unused_step, - unused_index) = _decode_blob_key(blob_key) - - # TODO(davidsoergel, wchargin): consider images, etc. - if plugin_name != graphs_metadata.PLUGIN_NAME: - logger.warn("Directory has no blob data for plugin %r", plugin_name) - raise errors.NotFoundError() - - serialized_graph = self._multiplexer.SerializedGraph(run) - - # TODO(davidsoergel): graph_defs have no step attribute so we don't filter - # on it. Other blob types might, though. - - if serialized_graph is None: - logger.warn("No blob found for key %r", blob_key) - raise errors.NotFoundError() - - # TODO(davidsoergel): consider internal structure of non-graphdef blobs. - # In particular, note we ignore the requested index, since it's always 0. 
- return serialized_graph - - -# TODO(davidsoergel): deduplicate with other implementations -def _encode_blob_key(experiment_id, plugin_name, run, tag, step, index): - """Generate a blob key: a short, URL-safe string identifying a blob. - - A blob can be located using a set of integer and string fields; here we - serialize these to allow passing the data through a URL. Specifically, we - 1) construct a tuple of the arguments in order; 2) represent that as an - ascii-encoded JSON string (without whitespace); and 3) take the URL-safe - base64 encoding of that, with no padding. For example: - - 1) Tuple: ("some_id", "graphs", "train", "graph_def", 2, 0) - 2) JSON: ["some_id","graphs","train","graph_def",2,0] - 3) base64: WyJzb21lX2lkIiwiZ3JhcGhzIiwidHJhaW4iLCJncmFwaF9kZWYiLDIsMF0K - - Args: - experiment_id: a string ID identifying an experiment. - plugin_name: string - run: string - tag: string - step: int - index: int - - Returns: - A URL-safe base64-encoded string representing the provided arguments. - """ - # Encodes the blob key as a URL-safe string, as required by the - # `BlobReference` API in `tensorboard/data/provider.py`, because these keys - # may be used to construct URLs for retrieving blobs. - stringified = json.dumps( - (experiment_id, plugin_name, run, tag, step, index), - separators=(",", ":")) - bytesified = stringified.encode("ascii") - encoded = base64.urlsafe_b64encode(bytesified) - return six.ensure_str(encoded).rstrip("=") + # Encodes the blob key as a URL-safe string, as required by the + # `BlobReference` API in `tensorboard/data/provider.py`, because these keys + # may be used to construct URLs for retrieving blobs. 
+ stringified = json.dumps( + (experiment_id, plugin_name, run, tag, step, index), + separators=(",", ":"), + ) + bytesified = stringified.encode("ascii") + encoded = base64.urlsafe_b64encode(bytesified) + return six.ensure_str(encoded).rstrip("=") # Any changes to this function need not be backward-compatible, even though @@ -322,34 +329,36 @@ def _encode_blob_key(experiment_id, plugin_name, run, tag, step, index): # within the context of the session that created them (via the matching # `_encode_blob_key` function above). def _decode_blob_key(key): - """Decode a blob key produced by `_encode_blob_key` into component fields. + """Decode a blob key produced by `_encode_blob_key` into component fields. - Args: - key: a blob key, as generated by `_encode_blob_key`. + Args: + key: a blob key, as generated by `_encode_blob_key`. - Returns: - A tuple of `(experiment_id, plugin_name, run, tag, step, index)`, with types - matching the arguments of `_encode_blob_key`. - """ - decoded = base64.urlsafe_b64decode(key + "==") # pad past a multiple of 4. - stringified = decoded.decode("ascii") - (experiment_id, plugin_name, run, tag, step, index) = json.loads(stringified) - return (experiment_id, plugin_name, run, tag, step, index) + Returns: + A tuple of `(experiment_id, plugin_name, run, tag, step, index)`, with types + matching the arguments of `_encode_blob_key`. + """ + decoded = base64.urlsafe_b64decode(key + "==") # pad past a multiple of 4. 
+ stringified = decoded.decode("ascii") + (experiment_id, plugin_name, run, tag, step, index) = json.loads( + stringified + ) + return (experiment_id, plugin_name, run, tag, step, index) def _convert_scalar_event(event): - """Helper for `read_scalars`.""" - return provider.ScalarDatum( - step=event.step, - wall_time=event.wall_time, - value=tensor_util.make_ndarray(event.tensor_proto).item(), - ) + """Helper for `read_scalars`.""" + return provider.ScalarDatum( + step=event.step, + wall_time=event.wall_time, + value=tensor_util.make_ndarray(event.tensor_proto).item(), + ) def _convert_tensor_event(event): - """Helper for `read_tensors`.""" - return provider.TensorDatum( - step=event.step, - wall_time=event.wall_time, - numpy=tensor_util.make_ndarray(event.tensor_proto), - ) + """Helper for `read_tensors`.""" + return provider.TensorDatum( + step=event.step, + wall_time=event.wall_time, + numpy=tensor_util.make_ndarray(event.tensor_proto), + ) diff --git a/tensorboard/backend/event_processing/data_provider_test.py b/tensorboard/backend/event_processing/data_provider_test.py index 53256c303f..e13ede243d 100644 --- a/tensorboard/backend/event_processing/data_provider_test.py +++ b/tensorboard/backend/event_processing/data_provider_test.py @@ -43,239 +43,267 @@ class MultiplexerDataProviderTest(tf.test.TestCase): - def setUp(self): - super(MultiplexerDataProviderTest, self).setUp() - self.logdir = self.get_temp_dir() - - logdir = os.path.join(self.logdir, "polynomials") - with tf.summary.create_file_writer(logdir).as_default(): - for i in xrange(10): - scalar_summary.scalar("square", i ** 2, step=2 * i, description="boxen") - scalar_summary.scalar("cube", i ** 3, step=3 * i) - - logdir = os.path.join(self.logdir, "waves") - with tf.summary.create_file_writer(logdir).as_default(): - for i in xrange(10): - scalar_summary.scalar("sine", tf.sin(float(i)), step=i) - scalar_summary.scalar("square", tf.sign(tf.sin(float(i))), step=i) - # Summary with rank-0 data but not owned 
by the scalars plugin. - metadata = summary_pb2.SummaryMetadata() - metadata.plugin_data.plugin_name = "marigraphs" - tf.summary.write("high_tide", tensor=i, step=i, metadata=metadata) - - logdir = os.path.join(self.logdir, "pictures") - with tf.summary.create_file_writer(logdir).as_default(): - colors = [ - ("`#F0F`", (255, 0, 255), "purple"), - ("`#0F0`", (255, 0, 255), "green"), - ] - for (description, rgb, name) in colors: - pixel = tf.constant([[list(rgb)]], dtype=tf.uint8) - for i in xrange(1, 11): - pixels = [tf.tile(pixel, [i, i, 1])] - image_summary.image(name, pixels, step=i, description=description) - - def create_multiplexer(self): - multiplexer = event_multiplexer.EventMultiplexer() - multiplexer.AddRunsFromDirectory(self.logdir) - multiplexer.Reload() - return multiplexer - - def create_provider(self): - multiplexer = self.create_multiplexer() - return data_provider.MultiplexerDataProvider(multiplexer, self.logdir) - - def test_data_location(self): - provider = self.create_provider() - result = provider.data_location(experiment_id="unused") - self.assertEqual(result, self.logdir) - - def test_list_runs(self): - # We can't control the timestamps of events written to disk (without - # manually reading the tfrecords, modifying the data, and writing - # them back out), so we provide a fake multiplexer instead. 
- start_times = { - "second_2": 2.0, - "first": 1.5, - "no_time": None, - "second_1": 2.0, - } - class FakeMultiplexer(object): - def Runs(multiplexer): - result = ["second_2", "first", "no_time", "second_1"] - self.assertItemsEqual(result, start_times) - return result - - def FirstEventTimestamp(multiplexer, run): - self.assertIn(run, start_times) - result = start_times[run] - if result is None: - raise ValueError("No event timestep could be found") - else: - return result - - multiplexer = FakeMultiplexer() - provider = data_provider.MultiplexerDataProvider(multiplexer, "fake_logdir") - result = provider.list_runs(experiment_id="unused") - self.assertItemsEqual(result, [ - base_provider.Run(run_id=run, run_name=run, start_time=start_time) - for (run, start_time) in six.iteritems(start_times) - ]) - - def test_list_scalars_all(self): - provider = self.create_provider() - result = provider.list_scalars( - experiment_id="unused", - plugin_name=scalar_metadata.PLUGIN_NAME, - run_tag_filter=None, - ) - self.assertItemsEqual(result.keys(), ["polynomials", "waves"]) - self.assertItemsEqual(result["polynomials"].keys(), ["square", "cube"]) - self.assertItemsEqual(result["waves"].keys(), ["square", "sine"]) - sample = result["polynomials"]["square"] - self.assertIsInstance(sample, base_provider.ScalarTimeSeries) - self.assertEqual(sample.max_step, 18) - # nothing to test for wall time, as it can't be mocked out - self.assertEqual(sample.plugin_content, b"") - self.assertEqual(sample.display_name, "") # not written by V2 summary ops - self.assertEqual(sample.description, "boxen") - - def test_list_scalars_filters(self): - provider = self.create_provider() - - result = provider.list_scalars( - experiment_id="unused", - plugin_name=scalar_metadata.PLUGIN_NAME, - run_tag_filter=base_provider.RunTagFilter(["waves"], ["square"]), - ) - self.assertItemsEqual(result.keys(), ["waves"]) - self.assertItemsEqual(result["waves"].keys(), ["square"]) - - result = provider.list_scalars( 
- experiment_id="unused", - plugin_name=scalar_metadata.PLUGIN_NAME, - run_tag_filter=base_provider.RunTagFilter(tags=["square", "quartic"]), - ) - self.assertItemsEqual(result.keys(), ["polynomials", "waves"]) - self.assertItemsEqual(result["polynomials"].keys(), ["square"]) - self.assertItemsEqual(result["waves"].keys(), ["square"]) - - result = provider.list_scalars( - experiment_id="unused", - plugin_name=scalar_metadata.PLUGIN_NAME, - run_tag_filter=base_provider.RunTagFilter(runs=["waves", "hugs"]), - ) - self.assertItemsEqual(result.keys(), ["waves"]) - self.assertItemsEqual(result["waves"].keys(), ["sine", "square"]) - - result = provider.list_scalars( - experiment_id="unused", - plugin_name=scalar_metadata.PLUGIN_NAME, - run_tag_filter=base_provider.RunTagFilter(["un"], ["likely"]), - ) - self.assertEqual(result, {}) - - def test_read_scalars(self): - multiplexer = self.create_multiplexer() - provider = data_provider.MultiplexerDataProvider(multiplexer, self.logdir) - - run_tag_filter = base_provider.RunTagFilter( - runs=["waves", "polynomials", "unicorns"], - tags=["sine", "square", "cube", "iridescence"], - ) - result = provider.read_scalars( - experiment_id="unused", - plugin_name=scalar_metadata.PLUGIN_NAME, - run_tag_filter=run_tag_filter, - downsample=None, # not yet implemented - ) - - self.assertItemsEqual(result.keys(), ["polynomials", "waves"]) - self.assertItemsEqual(result["polynomials"].keys(), ["square", "cube"]) - self.assertItemsEqual(result["waves"].keys(), ["square", "sine"]) - for run in result: - for tag in result[run]: - tensor_events = multiplexer.Tensors(run, tag) - self.assertLen(result[run][tag], len(tensor_events)) - for (datum, event) in zip(result[run][tag], tensor_events): - self.assertEqual(datum.step, event.step) - self.assertEqual(datum.wall_time, event.wall_time) - self.assertEqual( - datum.value, tensor_util.make_ndarray(event.tensor_proto).item() - ) - - def test_read_scalars_but_not_rank_0(self): - provider = 
self.create_provider() - run_tag_filter = base_provider.RunTagFilter(["pictures"], ["purple"]) - # No explicit checks yet. - with six.assertRaisesRegex( - self, - ValueError, - "can only convert an array of size 1 to a Python scalar"): - provider.read_scalars( - experiment_id="unused", - plugin_name=image_metadata.PLUGIN_NAME, - run_tag_filter=run_tag_filter, - ) - - def test_list_tensors_all(self): - provider = self.create_provider() - result = provider.list_tensors( - experiment_id="unused", - plugin_name=image_metadata.PLUGIN_NAME, - run_tag_filter=None, - ) - self.assertItemsEqual(result.keys(), ["pictures"]) - self.assertItemsEqual(result["pictures"].keys(), ["purple", "green"]) - sample = result["pictures"]["purple"] - self.assertIsInstance(sample, base_provider.TensorTimeSeries) - self.assertEqual(sample.max_step, 10) - # nothing to test for wall time, as it can't be mocked out - self.assertEqual(sample.plugin_content, b"") - self.assertEqual(sample.display_name, "") # not written by V2 summary ops - self.assertEqual(sample.description, "`#F0F`") - - def test_list_tensors_filters(self): - provider = self.create_provider() - - # Quick check only, as scalars and tensors use the same underlying - # filtering implementation. 
- result = provider.list_tensors( - experiment_id="unused", - plugin_name=image_metadata.PLUGIN_NAME, - run_tag_filter=base_provider.RunTagFilter(["pictures"], ["green"]), - ) - self.assertItemsEqual(result.keys(), ["pictures"]) - self.assertItemsEqual(result["pictures"].keys(), ["green"]) - - def test_read_tensors(self): - multiplexer = self.create_multiplexer() - provider = data_provider.MultiplexerDataProvider(multiplexer, self.logdir) - - run_tag_filter = base_provider.RunTagFilter( - runs=["pictures"], - tags=["purple", "green"], - ) - result = provider.read_tensors( - experiment_id="unused", - plugin_name=image_metadata.PLUGIN_NAME, - run_tag_filter=run_tag_filter, - downsample=None, # not yet implemented - ) - - self.assertItemsEqual(result.keys(), ["pictures"]) - self.assertItemsEqual(result["pictures"].keys(), ["purple", "green"]) - for run in result: - for tag in result[run]: - tensor_events = multiplexer.Tensors(run, tag) - self.assertLen(result[run][tag], len(tensor_events)) - for (datum, event) in zip(result[run][tag], tensor_events): - self.assertEqual(datum.step, event.step) - self.assertEqual(datum.wall_time, event.wall_time) - np.testing.assert_equal( - datum.numpy, tensor_util.make_ndarray(event.tensor_proto) - ) + def setUp(self): + super(MultiplexerDataProviderTest, self).setUp() + self.logdir = self.get_temp_dir() + + logdir = os.path.join(self.logdir, "polynomials") + with tf.summary.create_file_writer(logdir).as_default(): + for i in xrange(10): + scalar_summary.scalar( + "square", i ** 2, step=2 * i, description="boxen" + ) + scalar_summary.scalar("cube", i ** 3, step=3 * i) + + logdir = os.path.join(self.logdir, "waves") + with tf.summary.create_file_writer(logdir).as_default(): + for i in xrange(10): + scalar_summary.scalar("sine", tf.sin(float(i)), step=i) + scalar_summary.scalar( + "square", tf.sign(tf.sin(float(i))), step=i + ) + # Summary with rank-0 data but not owned by the scalars plugin. 
+ metadata = summary_pb2.SummaryMetadata() + metadata.plugin_data.plugin_name = "marigraphs" + tf.summary.write( + "high_tide", tensor=i, step=i, metadata=metadata + ) + + logdir = os.path.join(self.logdir, "pictures") + with tf.summary.create_file_writer(logdir).as_default(): + colors = [ + ("`#F0F`", (255, 0, 255), "purple"), + ("`#0F0`", (255, 0, 255), "green"), + ] + for (description, rgb, name) in colors: + pixel = tf.constant([[list(rgb)]], dtype=tf.uint8) + for i in xrange(1, 11): + pixels = [tf.tile(pixel, [i, i, 1])] + image_summary.image( + name, pixels, step=i, description=description + ) + + def create_multiplexer(self): + multiplexer = event_multiplexer.EventMultiplexer() + multiplexer.AddRunsFromDirectory(self.logdir) + multiplexer.Reload() + return multiplexer + + def create_provider(self): + multiplexer = self.create_multiplexer() + return data_provider.MultiplexerDataProvider(multiplexer, self.logdir) + + def test_data_location(self): + provider = self.create_provider() + result = provider.data_location(experiment_id="unused") + self.assertEqual(result, self.logdir) + + def test_list_runs(self): + # We can't control the timestamps of events written to disk (without + # manually reading the tfrecords, modifying the data, and writing + # them back out), so we provide a fake multiplexer instead. 
+ start_times = { + "second_2": 2.0, + "first": 1.5, + "no_time": None, + "second_1": 2.0, + } + + class FakeMultiplexer(object): + def Runs(multiplexer): + result = ["second_2", "first", "no_time", "second_1"] + self.assertItemsEqual(result, start_times) + return result + + def FirstEventTimestamp(multiplexer, run): + self.assertIn(run, start_times) + result = start_times[run] + if result is None: + raise ValueError("No event timestep could be found") + else: + return result + + multiplexer = FakeMultiplexer() + provider = data_provider.MultiplexerDataProvider( + multiplexer, "fake_logdir" + ) + result = provider.list_runs(experiment_id="unused") + self.assertItemsEqual( + result, + [ + base_provider.Run( + run_id=run, run_name=run, start_time=start_time + ) + for (run, start_time) in six.iteritems(start_times) + ], + ) + + def test_list_scalars_all(self): + provider = self.create_provider() + result = provider.list_scalars( + experiment_id="unused", + plugin_name=scalar_metadata.PLUGIN_NAME, + run_tag_filter=None, + ) + self.assertItemsEqual(result.keys(), ["polynomials", "waves"]) + self.assertItemsEqual(result["polynomials"].keys(), ["square", "cube"]) + self.assertItemsEqual(result["waves"].keys(), ["square", "sine"]) + sample = result["polynomials"]["square"] + self.assertIsInstance(sample, base_provider.ScalarTimeSeries) + self.assertEqual(sample.max_step, 18) + # nothing to test for wall time, as it can't be mocked out + self.assertEqual(sample.plugin_content, b"") + self.assertEqual( + sample.display_name, "" + ) # not written by V2 summary ops + self.assertEqual(sample.description, "boxen") + + def test_list_scalars_filters(self): + provider = self.create_provider() + + result = provider.list_scalars( + experiment_id="unused", + plugin_name=scalar_metadata.PLUGIN_NAME, + run_tag_filter=base_provider.RunTagFilter(["waves"], ["square"]), + ) + self.assertItemsEqual(result.keys(), ["waves"]) + self.assertItemsEqual(result["waves"].keys(), ["square"]) + + 
result = provider.list_scalars( + experiment_id="unused", + plugin_name=scalar_metadata.PLUGIN_NAME, + run_tag_filter=base_provider.RunTagFilter( + tags=["square", "quartic"] + ), + ) + self.assertItemsEqual(result.keys(), ["polynomials", "waves"]) + self.assertItemsEqual(result["polynomials"].keys(), ["square"]) + self.assertItemsEqual(result["waves"].keys(), ["square"]) + + result = provider.list_scalars( + experiment_id="unused", + plugin_name=scalar_metadata.PLUGIN_NAME, + run_tag_filter=base_provider.RunTagFilter(runs=["waves", "hugs"]), + ) + self.assertItemsEqual(result.keys(), ["waves"]) + self.assertItemsEqual(result["waves"].keys(), ["sine", "square"]) + + result = provider.list_scalars( + experiment_id="unused", + plugin_name=scalar_metadata.PLUGIN_NAME, + run_tag_filter=base_provider.RunTagFilter(["un"], ["likely"]), + ) + self.assertEqual(result, {}) + + def test_read_scalars(self): + multiplexer = self.create_multiplexer() + provider = data_provider.MultiplexerDataProvider( + multiplexer, self.logdir + ) + + run_tag_filter = base_provider.RunTagFilter( + runs=["waves", "polynomials", "unicorns"], + tags=["sine", "square", "cube", "iridescence"], + ) + result = provider.read_scalars( + experiment_id="unused", + plugin_name=scalar_metadata.PLUGIN_NAME, + run_tag_filter=run_tag_filter, + downsample=None, # not yet implemented + ) + + self.assertItemsEqual(result.keys(), ["polynomials", "waves"]) + self.assertItemsEqual(result["polynomials"].keys(), ["square", "cube"]) + self.assertItemsEqual(result["waves"].keys(), ["square", "sine"]) + for run in result: + for tag in result[run]: + tensor_events = multiplexer.Tensors(run, tag) + self.assertLen(result[run][tag], len(tensor_events)) + for (datum, event) in zip(result[run][tag], tensor_events): + self.assertEqual(datum.step, event.step) + self.assertEqual(datum.wall_time, event.wall_time) + self.assertEqual( + datum.value, + tensor_util.make_ndarray(event.tensor_proto).item(), + ) + + def 
test_read_scalars_but_not_rank_0(self): + provider = self.create_provider() + run_tag_filter = base_provider.RunTagFilter(["pictures"], ["purple"]) + # No explicit checks yet. + with six.assertRaisesRegex( + self, + ValueError, + "can only convert an array of size 1 to a Python scalar", + ): + provider.read_scalars( + experiment_id="unused", + plugin_name=image_metadata.PLUGIN_NAME, + run_tag_filter=run_tag_filter, + ) + + def test_list_tensors_all(self): + provider = self.create_provider() + result = provider.list_tensors( + experiment_id="unused", + plugin_name=image_metadata.PLUGIN_NAME, + run_tag_filter=None, + ) + self.assertItemsEqual(result.keys(), ["pictures"]) + self.assertItemsEqual(result["pictures"].keys(), ["purple", "green"]) + sample = result["pictures"]["purple"] + self.assertIsInstance(sample, base_provider.TensorTimeSeries) + self.assertEqual(sample.max_step, 10) + # nothing to test for wall time, as it can't be mocked out + self.assertEqual(sample.plugin_content, b"") + self.assertEqual( + sample.display_name, "" + ) # not written by V2 summary ops + self.assertEqual(sample.description, "`#F0F`") + + def test_list_tensors_filters(self): + provider = self.create_provider() + + # Quick check only, as scalars and tensors use the same underlying + # filtering implementation. 
+ result = provider.list_tensors( + experiment_id="unused", + plugin_name=image_metadata.PLUGIN_NAME, + run_tag_filter=base_provider.RunTagFilter(["pictures"], ["green"]), + ) + self.assertItemsEqual(result.keys(), ["pictures"]) + self.assertItemsEqual(result["pictures"].keys(), ["green"]) + + def test_read_tensors(self): + multiplexer = self.create_multiplexer() + provider = data_provider.MultiplexerDataProvider( + multiplexer, self.logdir + ) + + run_tag_filter = base_provider.RunTagFilter( + runs=["pictures"], tags=["purple", "green"], + ) + result = provider.read_tensors( + experiment_id="unused", + plugin_name=image_metadata.PLUGIN_NAME, + run_tag_filter=run_tag_filter, + downsample=None, # not yet implemented + ) + + self.assertItemsEqual(result.keys(), ["pictures"]) + self.assertItemsEqual(result["pictures"].keys(), ["purple", "green"]) + for run in result: + for tag in result[run]: + tensor_events = multiplexer.Tensors(run, tag) + self.assertLen(result[run][tag], len(tensor_events)) + for (datum, event) in zip(result[run][tag], tensor_events): + self.assertEqual(datum.step, event.step) + self.assertEqual(datum.wall_time, event.wall_time) + np.testing.assert_equal( + datum.numpy, + tensor_util.make_ndarray(event.tensor_proto), + ) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/backend/event_processing/db_import_multiplexer.py b/tensorboard/backend/event_processing/db_import_multiplexer.py index 968a06ea0f..dd4d39700b 100644 --- a/tensorboard/backend/event_processing/db_import_multiplexer.py +++ b/tensorboard/backend/event_processing/db_import_multiplexer.py @@ -42,262 +42,289 @@ class DbImportMultiplexer(plugin_event_multiplexer.EventMultiplexer): - """A loading-only `EventMultiplexer` that populates a SQLite DB. - - This EventMultiplexer only loads data; the read APIs always return empty - results, since all data is accessed instead via SQL against the - db_connection_provider wrapped by this multiplexer. 
- """ - - def __init__(self, - db_uri, - db_connection_provider, - purge_orphaned_data, - max_reload_threads): - """Constructor for `DbImportMultiplexer`. - - Args: - db_uri: A URI to the database file in use. - db_connection_provider: Provider function for creating a DB connection. - purge_orphaned_data: Whether to discard any events that were "orphaned" by - a TensorFlow restart. - max_reload_threads: The max number of threads that TensorBoard can use - to reload runs. Each thread reloads one run at a time. If not provided, - reloads runs serially (one after another). - """ - logger.info('DbImportMultiplexer initializing for %s', db_uri) - super(DbImportMultiplexer, self).__init__() - self.db_uri = db_uri - self.db_connection_provider = db_connection_provider - self._purge_orphaned_data = purge_orphaned_data - self._max_reload_threads = max_reload_threads - self._event_sink = None - self._run_loaders = {} - - if self._purge_orphaned_data: - logger.warn( - '--db_import does not yet support purging orphaned data') - - conn = self.db_connection_provider() - # Set the DB in WAL mode so reads don't block writes. - conn.execute('PRAGMA journal_mode=wal') - conn.execute('PRAGMA synchronous=normal') # Recommended for WAL mode - sqlite_writer.initialize_schema(conn) - logger.info('DbImportMultiplexer done initializing') - - def AddRun(self, path, name=None): - """Unsupported; instead use AddRunsFromDirectory.""" - raise NotImplementedError("Unsupported; use AddRunsFromDirectory()") - - def AddRunsFromDirectory(self, path, name=None): - """Load runs from a directory; recursively walks subdirectories. - - If path doesn't exist, no-op. This ensures that it is safe to call - `AddRunsFromDirectory` multiple times, even before the directory is made. - - Args: - path: A string path to a directory to load runs from. - name: Optional, specifies a name for the experiment under which the - runs from this directory hierarchy will be imported. 
If omitted, the - path will be used as the name. - - Raises: - ValueError: If the path exists and isn't a directory. + """A loading-only `EventMultiplexer` that populates a SQLite DB. + + This EventMultiplexer only loads data; the read APIs always return + empty results, since all data is accessed instead via SQL against + the db_connection_provider wrapped by this multiplexer. """ - logger.info('Starting AddRunsFromDirectory: %s (as %s)', path, name) - for subdir in io_wrapper.GetLogdirSubdirectories(path): - logger.info('Processing directory %s', subdir) - if subdir not in self._run_loaders: - logger.info('Creating DB loader for directory %s', subdir) - names = self._get_exp_and_run_names(path, subdir, name) - experiment_name, run_name = names - self._run_loaders[subdir] = _RunLoader( - subdir=subdir, - experiment_name=experiment_name, - run_name=run_name) - logger.info('Done with AddRunsFromDirectory: %s', path) - - def Reload(self): - """Load events from every detected run.""" - logger.info('Beginning DbImportMultiplexer.Reload()') - # Defer event sink creation until needed; this ensures it will only exist in - # the thread that calls Reload(), since DB connections must be thread-local. - if not self._event_sink: - self._event_sink = _SqliteWriterEventSink(self.db_connection_provider) - # Use collections.deque() for speed when we don't need blocking since it - # also has thread-safe appends/pops. 
- loader_queue = collections.deque(six.itervalues(self._run_loaders)) - loader_delete_queue = collections.deque() - - def batch_generator(): - while True: - try: - loader = loader_queue.popleft() - except IndexError: - return - try: - for batch in loader.load_batches(): - yield batch - except directory_watcher.DirectoryDeletedError: - loader_delete_queue.append(loader) - except (OSError, IOError) as e: - logger.error('Unable to load run %r: %s', loader.subdir, e) - - num_threads = min(self._max_reload_threads, len(self._run_loaders)) - if num_threads <= 1: - logger.info('Importing runs serially on a single thread') - for batch in batch_generator(): - self._event_sink.write_batch(batch) - else: - output_queue = queue.Queue() - sentinel = object() - def producer(): - try: - for batch in batch_generator(): - output_queue.put(batch) - finally: - output_queue.put(sentinel) - logger.info('Starting %d threads to import runs', num_threads) - for i in xrange(num_threads): - thread = threading.Thread(target=producer, name='Loader %d' % i) - thread.daemon = True - thread.start() - num_live_threads = num_threads - while num_live_threads > 0: - output = output_queue.get() - if output == sentinel: - num_live_threads -= 1 - continue - self._event_sink.write_batch(output) - for loader in loader_delete_queue: - logger.warn('Deleting loader %r', loader.subdir) - del self._run_loaders[loader.subdir] - logger.info('Finished with DbImportMultiplexer.Reload()') - - def _get_exp_and_run_names(self, path, subdir, experiment_name_override=None): - if experiment_name_override is not None: - return (experiment_name_override, os.path.relpath(subdir, path)) - sep = io_wrapper.PathSeparator(path) - path_parts = os.path.relpath(subdir, path).split(sep, 1) - experiment_name = path_parts[0] - run_name = path_parts[1] if len(path_parts) == 2 else '.' 
- return (experiment_name, run_name) + + def __init__( + self, + db_uri, + db_connection_provider, + purge_orphaned_data, + max_reload_threads, + ): + """Constructor for `DbImportMultiplexer`. + + Args: + db_uri: A URI to the database file in use. + db_connection_provider: Provider function for creating a DB connection. + purge_orphaned_data: Whether to discard any events that were "orphaned" by + a TensorFlow restart. + max_reload_threads: The max number of threads that TensorBoard can use + to reload runs. Each thread reloads one run at a time. If not provided, + reloads runs serially (one after another). + """ + logger.info("DbImportMultiplexer initializing for %s", db_uri) + super(DbImportMultiplexer, self).__init__() + self.db_uri = db_uri + self.db_connection_provider = db_connection_provider + self._purge_orphaned_data = purge_orphaned_data + self._max_reload_threads = max_reload_threads + self._event_sink = None + self._run_loaders = {} + + if self._purge_orphaned_data: + logger.warn( + "--db_import does not yet support purging orphaned data" + ) + + conn = self.db_connection_provider() + # Set the DB in WAL mode so reads don't block writes. + conn.execute("PRAGMA journal_mode=wal") + conn.execute("PRAGMA synchronous=normal") # Recommended for WAL mode + sqlite_writer.initialize_schema(conn) + logger.info("DbImportMultiplexer done initializing") + + def AddRun(self, path, name=None): + """Unsupported; instead use AddRunsFromDirectory.""" + raise NotImplementedError("Unsupported; use AddRunsFromDirectory()") + + def AddRunsFromDirectory(self, path, name=None): + """Load runs from a directory; recursively walks subdirectories. + + If path doesn't exist, no-op. This ensures that it is safe to call + `AddRunsFromDirectory` multiple times, even before the directory is made. + + Args: + path: A string path to a directory to load runs from. + name: Optional, specifies a name for the experiment under which the + runs from this directory hierarchy will be imported. 
If omitted, the + path will be used as the name. + + Raises: + ValueError: If the path exists and isn't a directory. + """ + logger.info("Starting AddRunsFromDirectory: %s (as %s)", path, name) + for subdir in io_wrapper.GetLogdirSubdirectories(path): + logger.info("Processing directory %s", subdir) + if subdir not in self._run_loaders: + logger.info("Creating DB loader for directory %s", subdir) + names = self._get_exp_and_run_names(path, subdir, name) + experiment_name, run_name = names + self._run_loaders[subdir] = _RunLoader( + subdir=subdir, + experiment_name=experiment_name, + run_name=run_name, + ) + logger.info("Done with AddRunsFromDirectory: %s", path) + + def Reload(self): + """Load events from every detected run.""" + logger.info("Beginning DbImportMultiplexer.Reload()") + # Defer event sink creation until needed; this ensures it will only exist in + # the thread that calls Reload(), since DB connections must be thread-local. + if not self._event_sink: + self._event_sink = _SqliteWriterEventSink( + self.db_connection_provider + ) + # Use collections.deque() for speed when we don't need blocking since it + # also has thread-safe appends/pops. 
+ loader_queue = collections.deque(six.itervalues(self._run_loaders)) + loader_delete_queue = collections.deque() + + def batch_generator(): + while True: + try: + loader = loader_queue.popleft() + except IndexError: + return + try: + for batch in loader.load_batches(): + yield batch + except directory_watcher.DirectoryDeletedError: + loader_delete_queue.append(loader) + except (OSError, IOError) as e: + logger.error("Unable to load run %r: %s", loader.subdir, e) + + num_threads = min(self._max_reload_threads, len(self._run_loaders)) + if num_threads <= 1: + logger.info("Importing runs serially on a single thread") + for batch in batch_generator(): + self._event_sink.write_batch(batch) + else: + output_queue = queue.Queue() + sentinel = object() + + def producer(): + try: + for batch in batch_generator(): + output_queue.put(batch) + finally: + output_queue.put(sentinel) + + logger.info("Starting %d threads to import runs", num_threads) + for i in xrange(num_threads): + thread = threading.Thread(target=producer, name="Loader %d" % i) + thread.daemon = True + thread.start() + num_live_threads = num_threads + while num_live_threads > 0: + output = output_queue.get() + if output == sentinel: + num_live_threads -= 1 + continue + self._event_sink.write_batch(output) + for loader in loader_delete_queue: + logger.warn("Deleting loader %r", loader.subdir) + del self._run_loaders[loader.subdir] + logger.info("Finished with DbImportMultiplexer.Reload()") + + def _get_exp_and_run_names( + self, path, subdir, experiment_name_override=None + ): + if experiment_name_override is not None: + return (experiment_name_override, os.path.relpath(subdir, path)) + sep = io_wrapper.PathSeparator(path) + path_parts = os.path.relpath(subdir, path).split(sep, 1) + experiment_name = path_parts[0] + run_name = path_parts[1] if len(path_parts) == 2 else "." 
+ return (experiment_name, run_name) + # Struct holding a list of tf.Event serialized protos along with metadata about # the associated experiment and run. -_EventBatch = collections.namedtuple('EventBatch', - ['events', 'experiment_name', 'run_name']) +_EventBatch = collections.namedtuple( + "EventBatch", ["events", "experiment_name", "run_name"] +) class _RunLoader(object): - """Loads a single run directory in batches.""" - - _BATCH_COUNT = 5000 - _BATCH_BYTES = 2**20 # 1 MiB - - def __init__(self, subdir, experiment_name, run_name): - """Constructs a `_RunLoader`. - - Args: - subdir: string, filesystem path of the run directory - experiment_name: string, name of the run's experiment - run_name: string, name of the run - """ - self._subdir = subdir - self._experiment_name = experiment_name - self._run_name = run_name - self._directory_watcher = directory_watcher.DirectoryWatcher( - subdir, - event_file_loader.RawEventFileLoader, - io_wrapper.IsTensorFlowEventsFile) - - @property - def subdir(self): - return self._subdir - - def load_batches(self): - """Returns a batched event iterator over the run directory event files.""" - event_iterator = self._directory_watcher.Load() - while True: - events = [] - event_bytes = 0 - start = time.time() - for event_proto in event_iterator: - events.append(event_proto) - event_bytes += len(event_proto) - if len(events) >= self._BATCH_COUNT or event_bytes >= self._BATCH_BYTES: - break - elapsed = time.time() - start - logger.debug('RunLoader.load_batch() yielded in %0.3f sec for %s', - elapsed, self._subdir) - if not events: - return - yield _EventBatch( - events=events, - experiment_name=self._experiment_name, - run_name=self._run_name) + """Loads a single run directory in batches.""" + + _BATCH_COUNT = 5000 + _BATCH_BYTES = 2 ** 20 # 1 MiB + + def __init__(self, subdir, experiment_name, run_name): + """Constructs a `_RunLoader`. 
+ + Args: + subdir: string, filesystem path of the run directory + experiment_name: string, name of the run's experiment + run_name: string, name of the run + """ + self._subdir = subdir + self._experiment_name = experiment_name + self._run_name = run_name + self._directory_watcher = directory_watcher.DirectoryWatcher( + subdir, + event_file_loader.RawEventFileLoader, + io_wrapper.IsTensorFlowEventsFile, + ) + + @property + def subdir(self): + return self._subdir + + def load_batches(self): + """Returns a batched event iterator over the run directory event + files.""" + event_iterator = self._directory_watcher.Load() + while True: + events = [] + event_bytes = 0 + start = time.time() + for event_proto in event_iterator: + events.append(event_proto) + event_bytes += len(event_proto) + if ( + len(events) >= self._BATCH_COUNT + or event_bytes >= self._BATCH_BYTES + ): + break + elapsed = time.time() - start + logger.debug( + "RunLoader.load_batch() yielded in %0.3f sec for %s", + elapsed, + self._subdir, + ) + if not events: + return + yield _EventBatch( + events=events, + experiment_name=self._experiment_name, + run_name=self._run_name, + ) @six.add_metaclass(abc.ABCMeta) class _EventSink(object): - """Abstract sink for batches of serialized tf.Event data.""" + """Abstract sink for batches of serialized tf.Event data.""" - @abc.abstractmethod - def write_batch(self, event_batch): - """Writes the given event batch to the sink. + @abc.abstractmethod + def write_batch(self, event_batch): + """Writes the given event batch to the sink. - Args: - event_batch: an _EventBatch of event data. - """ - raise NotImplementedError() + Args: + event_batch: an _EventBatch of event data. + """ + raise NotImplementedError() class _SqliteWriterEventSink(_EventSink): - """Implementation of EventSink using SqliteWriter.""" - - def __init__(self, db_connection_provider): - """Constructs a SqliteWriterEventSink. 
- - Args: - db_connection_provider: Provider function for creating a DB connection. - """ - self._writer = sqlite_writer.SqliteWriter(db_connection_provider) - - def write_batch(self, event_batch): - start = time.time() - tagged_data = {} - for event_proto in event_batch.events: - event = event_pb2.Event.FromString(event_proto) - self._process_event(event, tagged_data) - if tagged_data: - self._writer.write_summaries( - tagged_data, - experiment_name=event_batch.experiment_name, - run_name=event_batch.run_name) - elapsed = time.time() - start - logger.debug( - 'SqliteWriterEventSink.WriteBatch() took %0.3f sec for %s events', - elapsed, len(event_batch.events)) - - def _process_event(self, event, tagged_data): - """Processes a single tf.Event and records it in tagged_data.""" - event_type = event.WhichOneof('what') - # Handle the most common case first. - if event_type == 'summary': - for value in event.summary.value: - value = data_compat.migrate_value(value) - tag, metadata, values = tagged_data.get(value.tag, (None, None, [])) - values.append((event.step, event.wall_time, value.tensor)) - if tag is None: - # Store metadata only from the first event. - tagged_data[value.tag] = sqlite_writer.TagData( - value.tag, value.metadata, values) - elif event_type == 'file_version': - pass # TODO: reject file version < 2 (at loader level) - elif event_type == 'session_log': - if event.session_log.status == event_pb2.SessionLog.START: - pass # TODO: implement purging via sqlite writer truncation method - elif event_type in ('graph_def', 'meta_graph_def'): - pass # TODO: support graphs - elif event_type == 'tagged_run_metadata': - pass # TODO: support run metadata + """Implementation of EventSink using SqliteWriter.""" + + def __init__(self, db_connection_provider): + """Constructs a SqliteWriterEventSink. + + Args: + db_connection_provider: Provider function for creating a DB connection. 
+ """ + self._writer = sqlite_writer.SqliteWriter(db_connection_provider) + + def write_batch(self, event_batch): + start = time.time() + tagged_data = {} + for event_proto in event_batch.events: + event = event_pb2.Event.FromString(event_proto) + self._process_event(event, tagged_data) + if tagged_data: + self._writer.write_summaries( + tagged_data, + experiment_name=event_batch.experiment_name, + run_name=event_batch.run_name, + ) + elapsed = time.time() - start + logger.debug( + "SqliteWriterEventSink.WriteBatch() took %0.3f sec for %s events", + elapsed, + len(event_batch.events), + ) + + def _process_event(self, event, tagged_data): + """Processes a single tf.Event and records it in tagged_data.""" + event_type = event.WhichOneof("what") + # Handle the most common case first. + if event_type == "summary": + for value in event.summary.value: + value = data_compat.migrate_value(value) + tag, metadata, values = tagged_data.get( + value.tag, (None, None, []) + ) + values.append((event.step, event.wall_time, value.tensor)) + if tag is None: + # Store metadata only from the first event. 
+ tagged_data[value.tag] = sqlite_writer.TagData( + value.tag, value.metadata, values + ) + elif event_type == "file_version": + pass # TODO: reject file version < 2 (at loader level) + elif event_type == "session_log": + if event.session_log.status == event_pb2.SessionLog.START: + pass # TODO: implement purging via sqlite writer truncation method + elif event_type in ("graph_def", "meta_graph_def"): + pass # TODO: support graphs + elif event_type == "tagged_run_metadata": + pass # TODO: support run metadata diff --git a/tensorboard/backend/event_processing/db_import_multiplexer_test.py b/tensorboard/backend/event_processing/db_import_multiplexer_test.py index 4a43dd3fb8..89ee99f2d3 100644 --- a/tensorboard/backend/event_processing/db_import_multiplexer_test.py +++ b/tensorboard/backend/event_processing/db_import_multiplexer_test.py @@ -31,134 +31,151 @@ def add_event(path): - with test_util.FileWriterCache.get(path) as writer: - event = event_pb2.Event() - event.summary.value.add(tag='tag', tensor=tensor_util.make_tensor_proto(1)) - writer.add_event(event) + with test_util.FileWriterCache.get(path) as writer: + event = event_pb2.Event() + event.summary.value.add( + tag="tag", tensor=tensor_util.make_tensor_proto(1) + ) + writer.add_event(event) class DbImportMultiplexerTest(tf.test.TestCase): - - def setUp(self): - super(DbImportMultiplexerTest, self).setUp() - - db_file_name = os.path.join(self.get_temp_dir(), 'db') - self.db_connection_provider = lambda: sqlite3.connect(db_file_name) - self.multiplexer = db_import_multiplexer.DbImportMultiplexer( - db_uri='sqlite:' + db_file_name, - db_connection_provider=self.db_connection_provider, - purge_orphaned_data=False, - max_reload_threads=1) - - def _get_runs(self): - db = self.db_connection_provider() - cursor = db.execute(''' + def setUp(self): + super(DbImportMultiplexerTest, self).setUp() + + db_file_name = os.path.join(self.get_temp_dir(), "db") + self.db_connection_provider = lambda: 
sqlite3.connect(db_file_name) + self.multiplexer = db_import_multiplexer.DbImportMultiplexer( + db_uri="sqlite:" + db_file_name, + db_connection_provider=self.db_connection_provider, + purge_orphaned_data=False, + max_reload_threads=1, + ) + + def _get_runs(self): + db = self.db_connection_provider() + cursor = db.execute( + """ SELECT Runs.run_name FROM Runs ORDER BY Runs.run_name - ''') - return [row[0] for row in cursor] - - def _get_experiments(self): - db = self.db_connection_provider() - cursor = db.execute(''' + """ + ) + return [row[0] for row in cursor] + + def _get_experiments(self): + db = self.db_connection_provider() + cursor = db.execute( + """ SELECT Experiments.experiment_name FROM Experiments ORDER BY Experiments.experiment_name - ''') - return [row[0] for row in cursor] - - def test_init(self): - """Tests that DB schema is created when creating DbImportMultiplexer.""" - # Reading DB before schema initialization raises. - self.assertEqual(self._get_experiments(), []) - self.assertEqual(self._get_runs(), []) - - def test_empty_folder(self): - fake_dir = os.path.join(self.get_temp_dir(), 'fake_dir') - self.multiplexer.AddRunsFromDirectory(fake_dir) - self.assertEqual(self._get_experiments(), []) - self.assertEqual(self._get_runs(), []) - - def test_flat(self): - path = self.get_temp_dir() - add_event(path) - self.multiplexer.AddRunsFromDirectory(path) - self.multiplexer.Reload() - # Because we added runs from `path`, there is no folder to infer experiment - # and run names from. - self.assertEqual(self._get_experiments(), [u'.']) - self.assertEqual(self._get_runs(), [u'.']) - - def test_single_level(self): - path = self.get_temp_dir() - add_event(os.path.join(path, 'exp1')) - add_event(os.path.join(path, 'exp2')) - self.multiplexer.AddRunsFromDirectory(path) - self.multiplexer.Reload() - self.assertEqual(self._get_experiments(), [u'exp1', u'exp2']) - # Run names are '.'. because we already used the directory name for - # inferring experiment name. 
There are two items with the same name but - # with different ids. - self.assertEqual(self._get_runs(), [u'.', u'.']) - - def test_double_level(self): - path = self.get_temp_dir() - add_event(os.path.join(path, 'exp1', 'test')) - add_event(os.path.join(path, 'exp1', 'train')) - add_event(os.path.join(path, 'exp2', 'test')) - self.multiplexer.AddRunsFromDirectory(path) - self.multiplexer.Reload() - self.assertEqual(self._get_experiments(), [u'exp1', u'exp2']) - # There are two items with the same name but with different ids. - self.assertEqual(self._get_runs(), [u'test', u'test', u'train']) - - def test_mixed_levels(self): - # Mixture of root and single levels. - path = self.get_temp_dir() - # Train is in the root directory. - add_event(os.path.join(path)) - add_event(os.path.join(path, 'eval')) - self.multiplexer.AddRunsFromDirectory(path) - self.multiplexer.Reload() - self.assertEqual(self._get_experiments(), [u'.', u'eval']) - self.assertEqual(self._get_runs(), [u'.', u'.']) - - def test_deep(self): - path = self.get_temp_dir() - add_event(os.path.join(path, 'exp1', 'run1', 'bar', 'train')) - add_event(os.path.join(path, 'exp2', 'run1', 'baz', 'train')) - self.multiplexer.AddRunsFromDirectory(path) - self.multiplexer.Reload() - self.assertEqual(self._get_experiments(), [u'exp1', u'exp2']) - self.assertEqual(self._get_runs(), [os.path.join('run1', 'bar', 'train'), - os.path.join('run1', 'baz', 'train')]) - - def test_manual_name(self): - path1 = os.path.join(self.get_temp_dir(), 'foo') - path2 = os.path.join(self.get_temp_dir(), 'bar') - add_event(os.path.join(path1, 'some', 'nested', 'name')) - add_event(os.path.join(path2, 'some', 'nested', 'name')) - self.multiplexer.AddRunsFromDirectory(path1, 'name1') - self.multiplexer.AddRunsFromDirectory(path2, 'name2') - self.multiplexer.Reload() - self.assertEqual(self._get_experiments(), [u'name1', u'name2']) - # Run name ignored 'foo' and 'bar' on 'foo/some/nested/name' and - # 'bar/some/nested/name', respectively. 
- # There are two items with the same name but with different ids. - self.assertEqual(self._get_runs(), [os.path.join('some', 'nested', 'name'), - os.path.join('some', 'nested', 'name')]) - - def test_empty_read_apis(self): - path = self.get_temp_dir() - add_event(path) - self.assertEmpty(self.multiplexer.Runs()) - self.multiplexer.AddRunsFromDirectory(path) - self.multiplexer.Reload() - self.assertEmpty(self.multiplexer.Runs()) - - -if __name__ == '__main__': - tf.test.main() + """ + ) + return [row[0] for row in cursor] + + def test_init(self): + """Tests that DB schema is created when creating + DbImportMultiplexer.""" + # Reading DB before schema initialization raises. + self.assertEqual(self._get_experiments(), []) + self.assertEqual(self._get_runs(), []) + + def test_empty_folder(self): + fake_dir = os.path.join(self.get_temp_dir(), "fake_dir") + self.multiplexer.AddRunsFromDirectory(fake_dir) + self.assertEqual(self._get_experiments(), []) + self.assertEqual(self._get_runs(), []) + + def test_flat(self): + path = self.get_temp_dir() + add_event(path) + self.multiplexer.AddRunsFromDirectory(path) + self.multiplexer.Reload() + # Because we added runs from `path`, there is no folder to infer experiment + # and run names from. + self.assertEqual(self._get_experiments(), [u"."]) + self.assertEqual(self._get_runs(), [u"."]) + + def test_single_level(self): + path = self.get_temp_dir() + add_event(os.path.join(path, "exp1")) + add_event(os.path.join(path, "exp2")) + self.multiplexer.AddRunsFromDirectory(path) + self.multiplexer.Reload() + self.assertEqual(self._get_experiments(), [u"exp1", u"exp2"]) + # Run names are '.'. because we already used the directory name for + # inferring experiment name. There are two items with the same name but + # with different ids. 
+ self.assertEqual(self._get_runs(), [u".", u"."]) + + def test_double_level(self): + path = self.get_temp_dir() + add_event(os.path.join(path, "exp1", "test")) + add_event(os.path.join(path, "exp1", "train")) + add_event(os.path.join(path, "exp2", "test")) + self.multiplexer.AddRunsFromDirectory(path) + self.multiplexer.Reload() + self.assertEqual(self._get_experiments(), [u"exp1", u"exp2"]) + # There are two items with the same name but with different ids. + self.assertEqual(self._get_runs(), [u"test", u"test", u"train"]) + + def test_mixed_levels(self): + # Mixture of root and single levels. + path = self.get_temp_dir() + # Train is in the root directory. + add_event(os.path.join(path)) + add_event(os.path.join(path, "eval")) + self.multiplexer.AddRunsFromDirectory(path) + self.multiplexer.Reload() + self.assertEqual(self._get_experiments(), [u".", u"eval"]) + self.assertEqual(self._get_runs(), [u".", u"."]) + + def test_deep(self): + path = self.get_temp_dir() + add_event(os.path.join(path, "exp1", "run1", "bar", "train")) + add_event(os.path.join(path, "exp2", "run1", "baz", "train")) + self.multiplexer.AddRunsFromDirectory(path) + self.multiplexer.Reload() + self.assertEqual(self._get_experiments(), [u"exp1", u"exp2"]) + self.assertEqual( + self._get_runs(), + [ + os.path.join("run1", "bar", "train"), + os.path.join("run1", "baz", "train"), + ], + ) + + def test_manual_name(self): + path1 = os.path.join(self.get_temp_dir(), "foo") + path2 = os.path.join(self.get_temp_dir(), "bar") + add_event(os.path.join(path1, "some", "nested", "name")) + add_event(os.path.join(path2, "some", "nested", "name")) + self.multiplexer.AddRunsFromDirectory(path1, "name1") + self.multiplexer.AddRunsFromDirectory(path2, "name2") + self.multiplexer.Reload() + self.assertEqual(self._get_experiments(), [u"name1", u"name2"]) + # Run name ignored 'foo' and 'bar' on 'foo/some/nested/name' and + # 'bar/some/nested/name', respectively. 
+ # There are two items with the same name but with different ids. + self.assertEqual( + self._get_runs(), + [ + os.path.join("some", "nested", "name"), + os.path.join("some", "nested", "name"), + ], + ) + + def test_empty_read_apis(self): + path = self.get_temp_dir() + add_event(path) + self.assertEmpty(self.multiplexer.Runs()) + self.multiplexer.AddRunsFromDirectory(path) + self.multiplexer.Reload() + self.assertEmpty(self.multiplexer.Runs()) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/backend/event_processing/directory_loader.py b/tensorboard/backend/event_processing/directory_loader.py index 4182c3f38d..27979842e6 100644 --- a/tensorboard/backend/event_processing/directory_loader.py +++ b/tensorboard/backend/event_processing/directory_loader.py @@ -33,104 +33,115 @@ class DirectoryLoader(object): - """Loader for an entire directory, maintaining multiple active file loaders. - - This class takes a directory, a factory for loaders, and optionally a - path filter and watches all the paths inside that directory for new data. - Each file loader created by the factory must read a path and produce an - iterator of (timestamp, value) pairs. - - Unlike DirectoryWatcher, this class does not assume that only one file - receives new data at a time; there can be arbitrarily many active files. - However, any file whose maximum load timestamp fails an "active" predicate - will be marked as inactive and no longer checked for new data. - """ - - def __init__(self, directory, loader_factory, path_filter=lambda x: True, - active_filter=lambda timestamp: True): - """Constructs a new MultiFileDirectoryLoader. - - Args: - directory: The directory to load files from. - loader_factory: A factory for creating loaders. The factory should take a - path and return an object that has a Load method returning an iterator - yielding (unix timestamp as float, value) pairs for any new data - path_filter: If specified, only paths matching this filter are loaded. 
- active_filter: If specified, any loader whose maximum load timestamp does - not pass this filter will be marked as inactive and no longer read. - - Raises: - ValueError: If directory or loader_factory are None. + """Loader for an entire directory, maintaining multiple active file + loaders. + + This class takes a directory, a factory for loaders, and optionally a + path filter and watches all the paths inside that directory for new data. + Each file loader created by the factory must read a path and produce an + iterator of (timestamp, value) pairs. + + Unlike DirectoryWatcher, this class does not assume that only one file + receives new data at a time; there can be arbitrarily many active files. + However, any file whose maximum load timestamp fails an "active" predicate + will be marked as inactive and no longer checked for new data. """ - if directory is None: - raise ValueError('A directory is required') - if loader_factory is None: - raise ValueError('A loader factory is required') - self._directory = directory - self._loader_factory = loader_factory - self._path_filter = path_filter - self._active_filter = active_filter - self._loaders = {} - self._max_timestamps = {} - - def Load(self): - """Loads new values from all active files. - - Yields: - All values that have not been yielded yet. - - Raises: - DirectoryDeletedError: If the directory has been permanently deleted - (as opposed to being temporarily unavailable). 
- """ - try: - all_paths = io_wrapper.ListDirectoryAbsolute(self._directory) - paths = sorted(p for p in all_paths if self._path_filter(p)) - for path in paths: - for value in self._LoadPath(path): - yield value - except tf.errors.OpError as e: - if not tf.io.gfile.exists(self._directory): - raise directory_watcher.DirectoryDeletedError( - 'Directory %s has been permanently deleted' % self._directory) - else: - logger.info('Ignoring error during file loading: %s' % e) - - def _LoadPath(self, path): - """Generator for values from a single path's loader. - - Args: - path: the path to load from - - Yields: - All values from this path's loader that have not been yielded yet. - """ - max_timestamp = self._max_timestamps.get(path, None) - if max_timestamp is _INACTIVE or self._MarkIfInactive(path, max_timestamp): - logger.debug('Skipping inactive path %s', path) - return - loader = self._loaders.get(path, None) - if loader is None: - try: - loader = self._loader_factory(path) - except tf.errors.NotFoundError: - # Happens if a file was removed after we listed the directory. 
- logger.debug('Skipping nonexistent path %s', path) - return - self._loaders[path] = loader - logger.info('Loading data from path %s', path) - for timestamp, value in loader.Load(): - if max_timestamp is None or timestamp > max_timestamp: - max_timestamp = timestamp - yield value - if not self._MarkIfInactive(path, max_timestamp): - self._max_timestamps[path] = max_timestamp - - def _MarkIfInactive(self, path, max_timestamp): - """If max_timestamp is inactive, returns True and marks the path as such.""" - logger.debug('Checking active status of %s at %s', path, max_timestamp) - if max_timestamp is not None and not self._active_filter(max_timestamp): - self._max_timestamps[path] = _INACTIVE - del self._loaders[path] - return True - return False + + def __init__( + self, + directory, + loader_factory, + path_filter=lambda x: True, + active_filter=lambda timestamp: True, + ): + """Constructs a new MultiFileDirectoryLoader. + + Args: + directory: The directory to load files from. + loader_factory: A factory for creating loaders. The factory should take a + path and return an object that has a Load method returning an iterator + yielding (unix timestamp as float, value) pairs for any new data + path_filter: If specified, only paths matching this filter are loaded. + active_filter: If specified, any loader whose maximum load timestamp does + not pass this filter will be marked as inactive and no longer read. + + Raises: + ValueError: If directory or loader_factory are None. + """ + if directory is None: + raise ValueError("A directory is required") + if loader_factory is None: + raise ValueError("A loader factory is required") + self._directory = directory + self._loader_factory = loader_factory + self._path_filter = path_filter + self._active_filter = active_filter + self._loaders = {} + self._max_timestamps = {} + + def Load(self): + """Loads new values from all active files. + + Yields: + All values that have not been yielded yet. 
+ + Raises: + DirectoryDeletedError: If the directory has been permanently deleted + (as opposed to being temporarily unavailable). + """ + try: + all_paths = io_wrapper.ListDirectoryAbsolute(self._directory) + paths = sorted(p for p in all_paths if self._path_filter(p)) + for path in paths: + for value in self._LoadPath(path): + yield value + except tf.errors.OpError as e: + if not tf.io.gfile.exists(self._directory): + raise directory_watcher.DirectoryDeletedError( + "Directory %s has been permanently deleted" + % self._directory + ) + else: + logger.info("Ignoring error during file loading: %s" % e) + + def _LoadPath(self, path): + """Generator for values from a single path's loader. + + Args: + path: the path to load from + + Yields: + All values from this path's loader that have not been yielded yet. + """ + max_timestamp = self._max_timestamps.get(path, None) + if max_timestamp is _INACTIVE or self._MarkIfInactive( + path, max_timestamp + ): + logger.debug("Skipping inactive path %s", path) + return + loader = self._loaders.get(path, None) + if loader is None: + try: + loader = self._loader_factory(path) + except tf.errors.NotFoundError: + # Happens if a file was removed after we listed the directory. 
+ logger.debug("Skipping nonexistent path %s", path) + return + self._loaders[path] = loader + logger.info("Loading data from path %s", path) + for timestamp, value in loader.Load(): + if max_timestamp is None or timestamp > max_timestamp: + max_timestamp = timestamp + yield value + if not self._MarkIfInactive(path, max_timestamp): + self._max_timestamps[path] = max_timestamp + + def _MarkIfInactive(self, path, max_timestamp): + """If max_timestamp is inactive, returns True and marks the path as + such.""" + logger.debug("Checking active status of %s at %s", path, max_timestamp) + if max_timestamp is not None and not self._active_filter(max_timestamp): + self._max_timestamps[path] = _INACTIVE + del self._loaders[path] + return True + return False diff --git a/tensorboard/backend/event_processing/directory_loader_test.py b/tensorboard/backend/event_processing/directory_loader_test.py index deaa1030fa..efaab24227 100644 --- a/tensorboard/backend/event_processing/directory_loader_test.py +++ b/tensorboard/backend/event_processing/directory_loader_test.py @@ -25,10 +25,10 @@ import shutil try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import import tensorflow as tf @@ -40,214 +40,239 @@ class _TimestampedByteLoader(object): - """A loader that loads timestamped bytes from a file.""" + """A loader that loads timestamped bytes from a file.""" - def __init__(self, path, registry=None): - self._path = path - self._registry = registry if registry is not None else [] - self._registry.append(path) - self._f = open(path) + def __init__(self, path, registry=None): + self._path = path + self._registry = registry if registry is not None else [] + self._registry.append(path) + self._f = open(path) - def __del__(self): - self._registry.remove(self._path) + def __del__(self): + self._registry.remove(self._path) - 
def Load(self): - while True: - line = self._f.readline() - if not line: - return - ts, value = line.rstrip('\n').split(':') - yield float(ts), value + def Load(self): + while True: + line = self._f.readline() + if not line: + return + ts, value = line.rstrip("\n").split(":") + yield float(ts), value class DirectoryLoaderTest(tf.test.TestCase): + def setUp(self): + # Put everything in a directory so it's easier to delete w/in tests. + self._directory = os.path.join(self.get_temp_dir(), "testdir") + os.mkdir(self._directory) + self._loader = directory_loader.DirectoryLoader( + self._directory, _TimestampedByteLoader + ) - def setUp(self): - # Put everything in a directory so it's easier to delete w/in tests. - self._directory = os.path.join(self.get_temp_dir(), 'testdir') - os.mkdir(self._directory) - self._loader = directory_loader.DirectoryLoader( - self._directory, _TimestampedByteLoader) - - def _WriteToFile(self, filename, data, timestamps=None): - if timestamps is None: - timestamps = range(len(data)) - self.assertEqual(len(data), len(timestamps)) - path = os.path.join(self._directory, filename) - with open(path, 'a') as f: - for byte, timestamp in zip(data, timestamps): - f.write('%f:%s\n' % (timestamp, byte)) - - def assertLoaderYields(self, values): - self.assertEqual(list(self._loader.Load()), values) - - def testRaisesWithBadArguments(self): - with self.assertRaises(ValueError): - directory_loader.DirectoryLoader(None, lambda x: None) - with self.assertRaises(ValueError): - directory_loader.DirectoryLoader('dir', None) - - def testEmptyDirectory(self): - self.assertLoaderYields([]) - - def testSingleFileLoading(self): - self._WriteToFile('a', 'abc') - self.assertLoaderYields(['a', 'b', 'c']) - self.assertLoaderYields([]) - self._WriteToFile('a', 'xyz') - self.assertLoaderYields(['x', 'y', 'z']) - self.assertLoaderYields([]) - - def testMultipleFileLoading(self): - self._WriteToFile('a', 'a') - self._WriteToFile('b', 'b') - self.assertLoaderYields(['a', 
'b']) - self.assertLoaderYields([]) - self._WriteToFile('a', 'A') - self._WriteToFile('b', 'B') - self._WriteToFile('c', 'c') - # The loader should read new data from all the files. - self.assertLoaderYields(['A', 'B', 'c']) - self.assertLoaderYields([]) - - def testMultipleFileLoading_intermediateEmptyFiles(self): - self._WriteToFile('a', 'a') - self._WriteToFile('b', '') - self._WriteToFile('c', 'c') - self.assertLoaderYields(['a', 'c']) - - def testPathFilter(self): - self._loader = directory_loader.DirectoryLoader( - self._directory, _TimestampedByteLoader, - lambda path: 'tfevents' in path) - self._WriteToFile('skipped', 'a') - self._WriteToFile('event.out.tfevents.foo.bar', 'b') - self._WriteToFile('tf.event', 'c') - self.assertLoaderYields(['b']) - - def testActiveFilter_staticFilterBehavior(self): - """Tests behavior of a static active_filter.""" - loader_registry = [] - loader_factory = functools.partial( - _TimestampedByteLoader, registry=loader_registry) - active_filter = lambda timestamp: timestamp >= 2 - self._loader = directory_loader.DirectoryLoader( - self._directory, loader_factory, active_filter=active_filter) - def assertLoadersForPaths(paths): - paths = [os.path.join(self._directory, path) for path in paths] - self.assertEqual(loader_registry, paths) - # a: normal-looking file. - # b: file without sufficiently active data (should be marked inactive). - # c: file with timestamps in reverse order (max computed correctly). - # d: empty file (should be considered active in absence of timestamps). 
- self._WriteToFile('a', ['A1', 'A2'], [1, 2]) - self._WriteToFile('b', ['B1'], [1]) - self._WriteToFile('c', ['C2', 'C1', 'C0'], [2, 1, 0]) - self._WriteToFile('d', [], []) - self.assertLoaderYields(['A1', 'A2', 'B1', 'C2', 'C1', 'C0']) - assertLoadersForPaths(['a', 'c', 'd']) - self._WriteToFile('a', ['A3'], [3]) - self._WriteToFile('b', ['B3'], [3]) - self._WriteToFile('c', ['C0'], [0]) - self._WriteToFile('d', ['D3'], [3]) - self.assertLoaderYields(['A3', 'C0', 'D3']) - assertLoadersForPaths(['a', 'c', 'd']) - # Check that a 0 timestamp in file C on the most recent load doesn't - # override the max timestamp of 2 seen in the earlier load. - self._WriteToFile('c', ['C4'], [4]) - self.assertLoaderYields(['C4']) - assertLoadersForPaths(['a', 'c', 'd']) - - def testActiveFilter_dynamicFilterBehavior(self): - """Tests behavior of a dynamic active_filter.""" - loader_registry = [] - loader_factory = functools.partial( - _TimestampedByteLoader, registry=loader_registry) - threshold = 0 - active_filter = lambda timestamp: timestamp >= threshold - self._loader = directory_loader.DirectoryLoader( - self._directory, loader_factory, active_filter=active_filter) - def assertLoadersForPaths(paths): - paths = [os.path.join(self._directory, path) for path in paths] - self.assertEqual(loader_registry, paths) - self._WriteToFile('a', ['A1', 'A2'], [1, 2]) - self._WriteToFile('b', ['B1', 'B2', 'B3'], [1, 2, 3]) - self._WriteToFile('c', ['C1'], [1]) - threshold = 2 - # First load pass should leave file C marked inactive. - self.assertLoaderYields(['A1', 'A2', 'B1', 'B2', 'B3', 'C1']) - assertLoadersForPaths(['a', 'b']) - self._WriteToFile('a', ['A4'], [4]) - self._WriteToFile('b', ['B4'], [4]) - self._WriteToFile('c', ['C4'], [4]) - threshold = 3 - # Second load pass should mark file A as inactive (due to newly - # increased threshold) and thus skip reading data from it. 
- self.assertLoaderYields(['B4']) - assertLoadersForPaths(['b']) - self._WriteToFile('b', ['B5', 'B6'], [5, 6]) - # Simulate a third pass in which the threshold increases while - # we're processing a file, so it's still active at the start of the - # load but should be marked inactive at the end. - load_generator = self._loader.Load() - self.assertEqual('B5', next(load_generator)) - threshold = 7 - self.assertEqual(['B6'], list(load_generator)) - assertLoadersForPaths([]) - # Confirm that all loaders are now inactive. - self._WriteToFile('b', ['B7'], [7]) - self.assertLoaderYields([]) - - def testDoesntCrashWhenCurrentFileIsDeleted(self): - # Use actual file loader so it emits the real error. - self._loader = directory_loader.DirectoryLoader( - self._directory, event_file_loader.TimestampedEventFileLoader) - with test_util.FileWriter(self._directory, filename_suffix='.a') as writer_a: - writer_a.add_test_summary('a') - events = list(self._loader.Load()) - events.pop(0) # Ignore the file_version event. - self.assertEqual(1, len(events)) - self.assertEqual('a', events[0].summary.value[0].tag) - os.remove(glob.glob(os.path.join(self._directory, '*.a'))[0]) - with test_util.FileWriter(self._directory, filename_suffix='.b') as writer_b: - writer_b.add_test_summary('b') - events = list(self._loader.Load()) - events.pop(0) # Ignore the file_version event. - self.assertEqual(1, len(events)) - self.assertEqual('b', events[0].summary.value[0].tag) - - def testDoesntCrashWhenUpcomingFileIsDeleted(self): - # Use actual file loader so it emits the real error. - self._loader = directory_loader.DirectoryLoader( - self._directory, event_file_loader.TimestampedEventFileLoader) - with test_util.FileWriter(self._directory, filename_suffix='.a') as writer_a: - writer_a.add_test_summary('a') - with test_util.FileWriter(self._directory, filename_suffix='.b') as writer_b: - writer_b.add_test_summary('b') - generator = self._loader.Load() - next(generator) # Ignore the file_version event. 
- event = next(generator) - self.assertEqual('a', event.summary.value[0].tag) - os.remove(glob.glob(os.path.join(self._directory, '*.b'))[0]) - self.assertEmpty(list(generator)) - - def testRaisesDirectoryDeletedError_whenDirectoryIsDeleted(self): - self._WriteToFile('a', 'a') - self.assertLoaderYields(['a']) - shutil.rmtree(self._directory) - with self.assertRaises(directory_watcher.DirectoryDeletedError): - next(self._loader.Load()) - - def testDoesntRaiseDirectoryDeletedError_forUnrecognizedException(self): - self._WriteToFile('a', 'a') - self.assertLoaderYields(['a']) - class MyException(Exception): - pass - with mock.patch.object(io_wrapper, 'ListDirectoryAbsolute') as mock_listdir: - mock_listdir.side_effect = MyException - with self.assertRaises(MyException): - next(self._loader.Load()) - self.assertLoaderYields([]) - -if __name__ == '__main__': - tf.test.main() + def _WriteToFile(self, filename, data, timestamps=None): + if timestamps is None: + timestamps = range(len(data)) + self.assertEqual(len(data), len(timestamps)) + path = os.path.join(self._directory, filename) + with open(path, "a") as f: + for byte, timestamp in zip(data, timestamps): + f.write("%f:%s\n" % (timestamp, byte)) + + def assertLoaderYields(self, values): + self.assertEqual(list(self._loader.Load()), values) + + def testRaisesWithBadArguments(self): + with self.assertRaises(ValueError): + directory_loader.DirectoryLoader(None, lambda x: None) + with self.assertRaises(ValueError): + directory_loader.DirectoryLoader("dir", None) + + def testEmptyDirectory(self): + self.assertLoaderYields([]) + + def testSingleFileLoading(self): + self._WriteToFile("a", "abc") + self.assertLoaderYields(["a", "b", "c"]) + self.assertLoaderYields([]) + self._WriteToFile("a", "xyz") + self.assertLoaderYields(["x", "y", "z"]) + self.assertLoaderYields([]) + + def testMultipleFileLoading(self): + self._WriteToFile("a", "a") + self._WriteToFile("b", "b") + self.assertLoaderYields(["a", "b"]) + 
self.assertLoaderYields([]) + self._WriteToFile("a", "A") + self._WriteToFile("b", "B") + self._WriteToFile("c", "c") + # The loader should read new data from all the files. + self.assertLoaderYields(["A", "B", "c"]) + self.assertLoaderYields([]) + + def testMultipleFileLoading_intermediateEmptyFiles(self): + self._WriteToFile("a", "a") + self._WriteToFile("b", "") + self._WriteToFile("c", "c") + self.assertLoaderYields(["a", "c"]) + + def testPathFilter(self): + self._loader = directory_loader.DirectoryLoader( + self._directory, + _TimestampedByteLoader, + lambda path: "tfevents" in path, + ) + self._WriteToFile("skipped", "a") + self._WriteToFile("event.out.tfevents.foo.bar", "b") + self._WriteToFile("tf.event", "c") + self.assertLoaderYields(["b"]) + + def testActiveFilter_staticFilterBehavior(self): + """Tests behavior of a static active_filter.""" + loader_registry = [] + loader_factory = functools.partial( + _TimestampedByteLoader, registry=loader_registry + ) + active_filter = lambda timestamp: timestamp >= 2 + self._loader = directory_loader.DirectoryLoader( + self._directory, loader_factory, active_filter=active_filter + ) + + def assertLoadersForPaths(paths): + paths = [os.path.join(self._directory, path) for path in paths] + self.assertEqual(loader_registry, paths) + + # a: normal-looking file. + # b: file without sufficiently active data (should be marked inactive). + # c: file with timestamps in reverse order (max computed correctly). + # d: empty file (should be considered active in absence of timestamps). 
+ self._WriteToFile("a", ["A1", "A2"], [1, 2]) + self._WriteToFile("b", ["B1"], [1]) + self._WriteToFile("c", ["C2", "C1", "C0"], [2, 1, 0]) + self._WriteToFile("d", [], []) + self.assertLoaderYields(["A1", "A2", "B1", "C2", "C1", "C0"]) + assertLoadersForPaths(["a", "c", "d"]) + self._WriteToFile("a", ["A3"], [3]) + self._WriteToFile("b", ["B3"], [3]) + self._WriteToFile("c", ["C0"], [0]) + self._WriteToFile("d", ["D3"], [3]) + self.assertLoaderYields(["A3", "C0", "D3"]) + assertLoadersForPaths(["a", "c", "d"]) + # Check that a 0 timestamp in file C on the most recent load doesn't + # override the max timestamp of 2 seen in the earlier load. + self._WriteToFile("c", ["C4"], [4]) + self.assertLoaderYields(["C4"]) + assertLoadersForPaths(["a", "c", "d"]) + + def testActiveFilter_dynamicFilterBehavior(self): + """Tests behavior of a dynamic active_filter.""" + loader_registry = [] + loader_factory = functools.partial( + _TimestampedByteLoader, registry=loader_registry + ) + threshold = 0 + active_filter = lambda timestamp: timestamp >= threshold + self._loader = directory_loader.DirectoryLoader( + self._directory, loader_factory, active_filter=active_filter + ) + + def assertLoadersForPaths(paths): + paths = [os.path.join(self._directory, path) for path in paths] + self.assertEqual(loader_registry, paths) + + self._WriteToFile("a", ["A1", "A2"], [1, 2]) + self._WriteToFile("b", ["B1", "B2", "B3"], [1, 2, 3]) + self._WriteToFile("c", ["C1"], [1]) + threshold = 2 + # First load pass should leave file C marked inactive. + self.assertLoaderYields(["A1", "A2", "B1", "B2", "B3", "C1"]) + assertLoadersForPaths(["a", "b"]) + self._WriteToFile("a", ["A4"], [4]) + self._WriteToFile("b", ["B4"], [4]) + self._WriteToFile("c", ["C4"], [4]) + threshold = 3 + # Second load pass should mark file A as inactive (due to newly + # increased threshold) and thus skip reading data from it. 
+ self.assertLoaderYields(["B4"]) + assertLoadersForPaths(["b"]) + self._WriteToFile("b", ["B5", "B6"], [5, 6]) + # Simulate a third pass in which the threshold increases while + # we're processing a file, so it's still active at the start of the + # load but should be marked inactive at the end. + load_generator = self._loader.Load() + self.assertEqual("B5", next(load_generator)) + threshold = 7 + self.assertEqual(["B6"], list(load_generator)) + assertLoadersForPaths([]) + # Confirm that all loaders are now inactive. + self._WriteToFile("b", ["B7"], [7]) + self.assertLoaderYields([]) + + def testDoesntCrashWhenCurrentFileIsDeleted(self): + # Use actual file loader so it emits the real error. + self._loader = directory_loader.DirectoryLoader( + self._directory, event_file_loader.TimestampedEventFileLoader + ) + with test_util.FileWriter( + self._directory, filename_suffix=".a" + ) as writer_a: + writer_a.add_test_summary("a") + events = list(self._loader.Load()) + events.pop(0) # Ignore the file_version event. + self.assertEqual(1, len(events)) + self.assertEqual("a", events[0].summary.value[0].tag) + os.remove(glob.glob(os.path.join(self._directory, "*.a"))[0]) + with test_util.FileWriter( + self._directory, filename_suffix=".b" + ) as writer_b: + writer_b.add_test_summary("b") + events = list(self._loader.Load()) + events.pop(0) # Ignore the file_version event. + self.assertEqual(1, len(events)) + self.assertEqual("b", events[0].summary.value[0].tag) + + def testDoesntCrashWhenUpcomingFileIsDeleted(self): + # Use actual file loader so it emits the real error. 
+ self._loader = directory_loader.DirectoryLoader( + self._directory, event_file_loader.TimestampedEventFileLoader + ) + with test_util.FileWriter( + self._directory, filename_suffix=".a" + ) as writer_a: + writer_a.add_test_summary("a") + with test_util.FileWriter( + self._directory, filename_suffix=".b" + ) as writer_b: + writer_b.add_test_summary("b") + generator = self._loader.Load() + next(generator) # Ignore the file_version event. + event = next(generator) + self.assertEqual("a", event.summary.value[0].tag) + os.remove(glob.glob(os.path.join(self._directory, "*.b"))[0]) + self.assertEmpty(list(generator)) + + def testRaisesDirectoryDeletedError_whenDirectoryIsDeleted(self): + self._WriteToFile("a", "a") + self.assertLoaderYields(["a"]) + shutil.rmtree(self._directory) + with self.assertRaises(directory_watcher.DirectoryDeletedError): + next(self._loader.Load()) + + def testDoesntRaiseDirectoryDeletedError_forUnrecognizedException(self): + self._WriteToFile("a", "a") + self.assertLoaderYields(["a"]) + + class MyException(Exception): + pass + + with mock.patch.object( + io_wrapper, "ListDirectoryAbsolute" + ) as mock_listdir: + mock_listdir.side_effect = MyException + with self.assertRaises(MyException): + next(self._loader.Load()) + self.assertLoaderYields([]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/backend/event_processing/directory_watcher.py b/tensorboard/backend/event_processing/directory_watcher.py index 6ca7df19bd..05dca4603b 100644 --- a/tensorboard/backend/event_processing/directory_watcher.py +++ b/tensorboard/backend/event_processing/directory_watcher.py @@ -27,231 +27,252 @@ logger = tb_logging.get_logger() -class DirectoryWatcher(object): - """A DirectoryWatcher wraps a loader to load from a sequence of paths. - - A loader reads a path and produces some kind of values as an iterator. 
A - DirectoryWatcher takes a directory, a factory for loaders, and optionally a - path filter and watches all the paths inside that directory. - - This class is only valid under the assumption that only one path will be - written to by the data source at a time and that once the source stops writing - to a path, it will start writing to a new path that's lexicographically - greater and never come back. It uses some heuristics to check whether this is - true based on tracking changes to the files' sizes, but the check can have - false negatives. However, it should have no false positives. - """ - - def __init__(self, directory, loader_factory, path_filter=lambda x: True): - """Constructs a new DirectoryWatcher. - - Args: - directory: The directory to load files from. - loader_factory: A factory for creating loaders. The factory should take a - path and return an object that has a Load method returning an - iterator that will yield all events that have not been yielded yet. - path_filter: If specified, only paths matching this filter are loaded. - - Raises: - ValueError: If path_provider or loader_factory are None. - """ - if directory is None: - raise ValueError('A directory is required') - if loader_factory is None: - raise ValueError('A loader factory is required') - self._directory = directory - self._path = None - self._loader_factory = loader_factory - self._loader = None - self._path_filter = path_filter - self._ooo_writes_detected = False - # The file size for each file at the time it was finalized. - self._finalized_sizes = {} - - def Load(self): - """Loads new values. - - The watcher will load from one path at a time; as soon as that path stops - yielding events, it will move on to the next path. We assume that old paths - are never modified after a newer path has been written. As a result, Load() - can be called multiple times in a row without losing events that have not - been yielded yet. 
In other words, we guarantee that every event will be - yielded exactly once. - - Yields: - All values that have not been yielded yet. - - Raises: - DirectoryDeletedError: If the directory has been permanently deleted - (as opposed to being temporarily unavailable). - """ - try: - for event in self._LoadInternal(): - yield event - except tf.errors.OpError: - if not tf.io.gfile.exists(self._directory): - raise DirectoryDeletedError( - 'Directory %s has been permanently deleted' % self._directory) - - def _LoadInternal(self): - """Internal implementation of Load(). - - The only difference between this and Load() is that the latter will throw - DirectoryDeletedError on I/O errors if it thinks that the directory has been - permanently deleted. - - Yields: - All values that have not been yielded yet. - """ - - # If the loader exists, check it for a value. - if not self._loader: - self._InitializeLoader() - - # If it still doesn't exist, there is no data - if not self._loader: - return - - while True: - # Yield all the new events in the path we're currently loading from. - for event in self._loader.Load(): - yield event - - next_path = self._GetNextPath() - if not next_path: - logger.info('No path found after %s', self._path) - # Current path is empty and there are no new paths, so we're done. - return - - # There's a new path, so check to make sure there weren't any events - # written between when we finished reading the current path and when we - # checked for the new one. The sequence of events might look something - # like this: - # - # 1. Event #1 written to path #1. - # 2. We check for events and yield event #1 from path #1 - # 3. We check for events and see that there are no more events in path #1. - # 4. Event #2 is written to path #1. - # 5. Event #3 is written to path #2. - # 6. We check for a new path and see that path #2 exists. - # - # Without this loop, we would miss event #2. 
We're also guaranteed by the - # loader contract that no more events will be written to path #1 after - # events start being written to path #2, so we don't have to worry about - # that. - for event in self._loader.Load(): - yield event - - logger.info('Directory watcher advancing from %s to %s', self._path, - next_path) - - # Advance to the next path and start over. - self._SetPath(next_path) - - # The number of paths before the current one to check for out of order writes. - _OOO_WRITE_CHECK_COUNT = 20 - - def OutOfOrderWritesDetected(self): - """Returns whether any out-of-order writes have been detected. - - Out-of-order writes are only checked as part of the Load() iterator. Once an - out-of-order write is detected, this function will always return true. - - Note that out-of-order write detection is not performed on GCS paths, so - this function will always return false. - - Returns: - Whether any out-of-order write has ever been detected by this watcher. +class DirectoryWatcher(object): + """A DirectoryWatcher wraps a loader to load from a sequence of paths. + + A loader reads a path and produces some kind of values as an iterator. A + DirectoryWatcher takes a directory, a factory for loaders, and optionally a + path filter and watches all the paths inside that directory. + + This class is only valid under the assumption that only one path will be + written to by the data source at a time and that once the source stops writing + to a path, it will start writing to a new path that's lexicographically + greater and never come back. It uses some heuristics to check whether this is + true based on tracking changes to the files' sizes, but the check can have + false negatives. However, it should have no false positives. """ - return self._ooo_writes_detected - def _InitializeLoader(self): - path = self._GetNextPath() - if path: - self._SetPath(path) + def __init__(self, directory, loader_factory, path_filter=lambda x: True): + """Constructs a new DirectoryWatcher. 
+ + Args: + directory: The directory to load files from. + loader_factory: A factory for creating loaders. The factory should take a + path and return an object that has a Load method returning an + iterator that will yield all events that have not been yielded yet. + path_filter: If specified, only paths matching this filter are loaded. + + Raises: + ValueError: If path_provider or loader_factory are None. + """ + if directory is None: + raise ValueError("A directory is required") + if loader_factory is None: + raise ValueError("A loader factory is required") + self._directory = directory + self._path = None + self._loader_factory = loader_factory + self._loader = None + self._path_filter = path_filter + self._ooo_writes_detected = False + # The file size for each file at the time it was finalized. + self._finalized_sizes = {} + + def Load(self): + """Loads new values. + + The watcher will load from one path at a time; as soon as that path stops + yielding events, it will move on to the next path. We assume that old paths + are never modified after a newer path has been written. As a result, Load() + can be called multiple times in a row without losing events that have not + been yielded yet. In other words, we guarantee that every event will be + yielded exactly once. + + Yields: + All values that have not been yielded yet. + + Raises: + DirectoryDeletedError: If the directory has been permanently deleted + (as opposed to being temporarily unavailable). + """ + try: + for event in self._LoadInternal(): + yield event + except tf.errors.OpError: + if not tf.io.gfile.exists(self._directory): + raise DirectoryDeletedError( + "Directory %s has been permanently deleted" + % self._directory + ) + + def _LoadInternal(self): + """Internal implementation of Load(). + + The only difference between this and Load() is that the latter will throw + DirectoryDeletedError on I/O errors if it thinks that the directory has been + permanently deleted. 
+ + Yields: + All values that have not been yielded yet. + """ + + # If the loader exists, check it for a value. + if not self._loader: + self._InitializeLoader() + + # If it still doesn't exist, there is no data + if not self._loader: + return + + while True: + # Yield all the new events in the path we're currently loading from. + for event in self._loader.Load(): + yield event + + next_path = self._GetNextPath() + if not next_path: + logger.info("No path found after %s", self._path) + # Current path is empty and there are no new paths, so we're done. + return + + # There's a new path, so check to make sure there weren't any events + # written between when we finished reading the current path and when we + # checked for the new one. The sequence of events might look something + # like this: + # + # 1. Event #1 written to path #1. + # 2. We check for events and yield event #1 from path #1 + # 3. We check for events and see that there are no more events in path #1. + # 4. Event #2 is written to path #1. + # 5. Event #3 is written to path #2. + # 6. We check for a new path and see that path #2 exists. + # + # Without this loop, we would miss event #2. We're also guaranteed by the + # loader contract that no more events will be written to path #1 after + # events start being written to path #2, so we don't have to worry about + # that. + for event in self._loader.Load(): + yield event + + logger.info( + "Directory watcher advancing from %s to %s", + self._path, + next_path, + ) + + # Advance to the next path and start over. + self._SetPath(next_path) + + # The number of paths before the current one to check for out of order writes. + _OOO_WRITE_CHECK_COUNT = 20 + + def OutOfOrderWritesDetected(self): + """Returns whether any out-of-order writes have been detected. + + Out-of-order writes are only checked as part of the Load() iterator. Once an + out-of-order write is detected, this function will always return true. 
+ + Note that out-of-order write detection is not performed on GCS paths, so + this function will always return false. + + Returns: + Whether any out-of-order write has ever been detected by this watcher. + """ + return self._ooo_writes_detected + + def _InitializeLoader(self): + path = self._GetNextPath() + if path: + self._SetPath(path) + + def _SetPath(self, path): + """Sets the current path to watch for new events. + + This also records the size of the old path, if any. If the size can't be + found, an error is logged. + + Args: + path: The full path of the file to watch. + """ + old_path = self._path + if old_path and not io_wrapper.IsCloudPath(old_path): + try: + # We're done with the path, so store its size. + size = tf.io.gfile.stat(old_path).length + logger.debug("Setting latest size of %s to %d", old_path, size) + self._finalized_sizes[old_path] = size + except tf.errors.OpError as e: + logger.error("Unable to get size of %s: %s", old_path, e) + + self._path = path + self._loader = self._loader_factory(path) + + def _GetNextPath(self): + """Gets the next path to load from. + + This function also does the checking for out-of-order writes as it iterates + through the paths. + + Returns: + The next path to load events from, or None if there are no more paths. + """ + paths = sorted( + path + for path in io_wrapper.ListDirectoryAbsolute(self._directory) + if self._path_filter(path) + ) + if not paths: + return None + + if self._path is None: + return paths[0] + + # Don't bother checking if the paths are GCS (which we can't check) or if + # we've already detected an OOO write. + if ( + not io_wrapper.IsCloudPath(paths[0]) + and not self._ooo_writes_detected + ): + # Check the previous _OOO_WRITE_CHECK_COUNT paths for out of order writes. 
+ current_path_index = bisect.bisect_left(paths, self._path) + ooo_check_start = max( + 0, current_path_index - self._OOO_WRITE_CHECK_COUNT + ) + for path in paths[ooo_check_start:current_path_index]: + if self._HasOOOWrite(path): + self._ooo_writes_detected = True + break + + next_paths = list( + path for path in paths if self._path is None or path > self._path + ) + if next_paths: + return min(next_paths) + else: + return None + + def _HasOOOWrite(self, path): + """Returns whether the path has had an out-of-order write.""" + # Check the sizes of each path before the current one. + size = tf.io.gfile.stat(path).length + old_size = self._finalized_sizes.get(path, None) + if size != old_size: + if old_size is None: + logger.error( + "File %s created after file %s even though it's " + "lexicographically earlier", + path, + self._path, + ) + else: + logger.error( + "File %s updated even though the current file is %s", + path, + self._path, + ) + return True + else: + return False - def _SetPath(self, path): - """Sets the current path to watch for new events. - This also records the size of the old path, if any. If the size can't be - found, an error is logged. +class DirectoryDeletedError(Exception): + """Thrown by Load() when the directory is *permanently* gone. - Args: - path: The full path of the file to watch. - """ - old_path = self._path - if old_path and not io_wrapper.IsCloudPath(old_path): - try: - # We're done with the path, so store its size. - size = tf.io.gfile.stat(old_path).length - logger.debug('Setting latest size of %s to %d', old_path, size) - self._finalized_sizes[old_path] = size - except tf.errors.OpError as e: - logger.error('Unable to get size of %s: %s', old_path, e) - - self._path = path - self._loader = self._loader_factory(path) - - def _GetNextPath(self): - """Gets the next path to load from. - - This function also does the checking for out-of-order writes as it iterates - through the paths. 
- - Returns: - The next path to load events from, or None if there are no more paths. + We distinguish this from temporary errors so that other code can + decide to drop all of our data only when a directory has been + intentionally deleted, as opposed to due to transient filesystem + errors. """ - paths = sorted(path - for path in io_wrapper.ListDirectoryAbsolute(self._directory) - if self._path_filter(path)) - if not paths: - return None - - if self._path is None: - return paths[0] - - # Don't bother checking if the paths are GCS (which we can't check) or if - # we've already detected an OOO write. - if not io_wrapper.IsCloudPath(paths[0]) and not self._ooo_writes_detected: - # Check the previous _OOO_WRITE_CHECK_COUNT paths for out of order writes. - current_path_index = bisect.bisect_left(paths, self._path) - ooo_check_start = max(0, current_path_index - self._OOO_WRITE_CHECK_COUNT) - for path in paths[ooo_check_start:current_path_index]: - if self._HasOOOWrite(path): - self._ooo_writes_detected = True - break - - next_paths = list(path - for path in paths - if self._path is None or path > self._path) - if next_paths: - return min(next_paths) - else: - return None - - def _HasOOOWrite(self, path): - """Returns whether the path has had an out-of-order write.""" - # Check the sizes of each path before the current one. - size = tf.io.gfile.stat(path).length - old_size = self._finalized_sizes.get(path, None) - if size != old_size: - if old_size is None: - logger.error('File %s created after file %s even though it\'s ' - 'lexicographically earlier', path, self._path) - else: - logger.error('File %s updated even though the current file is %s', - path, self._path) - return True - else: - return False - - -class DirectoryDeletedError(Exception): - """Thrown by Load() when the directory is *permanently* gone. 
- We distinguish this from temporary errors so that other code can decide to - drop all of our data only when a directory has been intentionally deleted, - as opposed to due to transient filesystem errors. - """ - pass + pass diff --git a/tensorboard/backend/event_processing/directory_watcher_test.py b/tensorboard/backend/event_processing/directory_watcher_test.py index b9a00007d4..619edb8fbb 100644 --- a/tensorboard/backend/event_processing/directory_watcher_test.py +++ b/tensorboard/backend/event_processing/directory_watcher_test.py @@ -29,185 +29,192 @@ class _ByteLoader(object): - """A loader that loads individual bytes from a file.""" + """A loader that loads individual bytes from a file.""" - def __init__(self, path): - self._f = open(path) - self.bytes_read = 0 + def __init__(self, path): + self._f = open(path) + self.bytes_read = 0 - def Load(self): - while True: - self._f.seek(self.bytes_read) - byte = self._f.read(1) - if byte: - self.bytes_read += 1 - yield byte - else: - return + def Load(self): + while True: + self._f.seek(self.bytes_read) + byte = self._f.read(1) + if byte: + self.bytes_read += 1 + yield byte + else: + return class DirectoryWatcherTest(tf.test.TestCase): - - def setUp(self): - # Put everything in a directory so it's easier to delete. - self._directory = os.path.join(self.get_temp_dir(), 'monitor_dir') - os.mkdir(self._directory) - self._watcher = directory_watcher.DirectoryWatcher(self._directory, - _ByteLoader) - self.stubs = tf.compat.v1.test.StubOutForTesting() - - def tearDown(self): - self.stubs.CleanUp() - try: - shutil.rmtree(self._directory) - except OSError: - # Some tests delete the directory. 
- pass - - def _WriteToFile(self, filename, data): - path = os.path.join(self._directory, filename) - with open(path, 'a') as f: - f.write(data) - - def _LoadAllEvents(self): - """Loads all events in the watcher.""" - for _ in self._watcher.Load(): - pass - - def assertWatcherYields(self, values): - self.assertEqual(list(self._watcher.Load()), values) - - def testRaisesWithBadArguments(self): - with self.assertRaises(ValueError): - directory_watcher.DirectoryWatcher(None, lambda x: None) - with self.assertRaises(ValueError): - directory_watcher.DirectoryWatcher('dir', None) - - def testEmptyDirectory(self): - self.assertWatcherYields([]) - - def testSingleWrite(self): - self._WriteToFile('a', 'abc') - self.assertWatcherYields(['a', 'b', 'c']) - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testMultipleWrites(self): - self._WriteToFile('a', 'abc') - self.assertWatcherYields(['a', 'b', 'c']) - self._WriteToFile('a', 'xyz') - self.assertWatcherYields(['x', 'y', 'z']) - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testMultipleLoads(self): - self._WriteToFile('a', 'a') - self._watcher.Load() - self._watcher.Load() - self.assertWatcherYields(['a']) - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testMultipleFilesAtOnce(self): - self._WriteToFile('b', 'b') - self._WriteToFile('a', 'a') - self.assertWatcherYields(['a', 'b']) - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testFinishesLoadingFileWhenSwitchingToNewFile(self): - self._WriteToFile('a', 'a') - # Empty the iterator. - self.assertEqual(['a'], list(self._watcher.Load())) - self._WriteToFile('a', 'b') - self._WriteToFile('b', 'c') - # The watcher should finish its current file before starting a new one. 
- self.assertWatcherYields(['b', 'c']) - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testIntermediateEmptyFiles(self): - self._WriteToFile('a', 'a') - self._WriteToFile('b', '') - self._WriteToFile('c', 'c') - self.assertWatcherYields(['a', 'c']) - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testPathFilter(self): - self._watcher = directory_watcher.DirectoryWatcher( - self._directory, _ByteLoader, - lambda path: 'do_not_watch_me' not in path) - - self._WriteToFile('a', 'a') - self._WriteToFile('do_not_watch_me', 'b') - self._WriteToFile('c', 'c') - self.assertWatcherYields(['a', 'c']) - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testDetectsNewOldFiles(self): - self._WriteToFile('b', 'a') - self._LoadAllEvents() - self._WriteToFile('a', 'a') - self._LoadAllEvents() - self.assertTrue(self._watcher.OutOfOrderWritesDetected()) - - def testIgnoresNewerFiles(self): - self._WriteToFile('a', 'a') - self._LoadAllEvents() - self._WriteToFile('q', 'a') - self._LoadAllEvents() - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testDetectsChangingOldFiles(self): - self._WriteToFile('a', 'a') - self._WriteToFile('b', 'a') - self._LoadAllEvents() - self._WriteToFile('a', 'c') - self._LoadAllEvents() - self.assertTrue(self._watcher.OutOfOrderWritesDetected()) - - def testDoesntCrashWhenFileIsDeleted(self): - self._WriteToFile('a', 'a') - self._LoadAllEvents() - os.remove(os.path.join(self._directory, 'a')) - self._WriteToFile('b', 'b') - self.assertWatcherYields(['b']) - - def testRaisesRightErrorWhenDirectoryIsDeleted(self): - self._WriteToFile('a', 'a') - self._LoadAllEvents() - shutil.rmtree(self._directory) - with self.assertRaises(directory_watcher.DirectoryDeletedError): - self._LoadAllEvents() - - def testDoesntRaiseDirectoryDeletedErrorIfOutageIsTransient(self): - self._WriteToFile('a', 'a') - self._LoadAllEvents() - shutil.rmtree(self._directory) - - # Fake a single transient I/O 
error. - def FakeFactory(original): - - def Fake(*args, **kwargs): - if FakeFactory.has_been_called: - original(*args, **kwargs) - else: - raise OSError('lp0 temporarily on fire') - - return Fake - - FakeFactory.has_been_called = False - - stub_names = [ - 'ListDirectoryAbsolute', - 'ListRecursivelyViaGlobbing', - 'ListRecursivelyViaWalking', - ] - for stub_name in stub_names: - self.stubs.Set(io_wrapper, stub_name, - FakeFactory(getattr(io_wrapper, stub_name))) - for stub_name in ['exists', 'stat']: - self.stubs.Set(tf.io.gfile, stub_name, - FakeFactory(getattr(tf.io.gfile, stub_name))) - - with self.assertRaises((IOError, OSError)): - self._LoadAllEvents() - - -if __name__ == '__main__': - tf.test.main() + def setUp(self): + # Put everything in a directory so it's easier to delete. + self._directory = os.path.join(self.get_temp_dir(), "monitor_dir") + os.mkdir(self._directory) + self._watcher = directory_watcher.DirectoryWatcher( + self._directory, _ByteLoader + ) + self.stubs = tf.compat.v1.test.StubOutForTesting() + + def tearDown(self): + self.stubs.CleanUp() + try: + shutil.rmtree(self._directory) + except OSError: + # Some tests delete the directory. 
+ pass + + def _WriteToFile(self, filename, data): + path = os.path.join(self._directory, filename) + with open(path, "a") as f: + f.write(data) + + def _LoadAllEvents(self): + """Loads all events in the watcher.""" + for _ in self._watcher.Load(): + pass + + def assertWatcherYields(self, values): + self.assertEqual(list(self._watcher.Load()), values) + + def testRaisesWithBadArguments(self): + with self.assertRaises(ValueError): + directory_watcher.DirectoryWatcher(None, lambda x: None) + with self.assertRaises(ValueError): + directory_watcher.DirectoryWatcher("dir", None) + + def testEmptyDirectory(self): + self.assertWatcherYields([]) + + def testSingleWrite(self): + self._WriteToFile("a", "abc") + self.assertWatcherYields(["a", "b", "c"]) + self.assertFalse(self._watcher.OutOfOrderWritesDetected()) + + def testMultipleWrites(self): + self._WriteToFile("a", "abc") + self.assertWatcherYields(["a", "b", "c"]) + self._WriteToFile("a", "xyz") + self.assertWatcherYields(["x", "y", "z"]) + self.assertFalse(self._watcher.OutOfOrderWritesDetected()) + + def testMultipleLoads(self): + self._WriteToFile("a", "a") + self._watcher.Load() + self._watcher.Load() + self.assertWatcherYields(["a"]) + self.assertFalse(self._watcher.OutOfOrderWritesDetected()) + + def testMultipleFilesAtOnce(self): + self._WriteToFile("b", "b") + self._WriteToFile("a", "a") + self.assertWatcherYields(["a", "b"]) + self.assertFalse(self._watcher.OutOfOrderWritesDetected()) + + def testFinishesLoadingFileWhenSwitchingToNewFile(self): + self._WriteToFile("a", "a") + # Empty the iterator. + self.assertEqual(["a"], list(self._watcher.Load())) + self._WriteToFile("a", "b") + self._WriteToFile("b", "c") + # The watcher should finish its current file before starting a new one. 
+ self.assertWatcherYields(["b", "c"]) + self.assertFalse(self._watcher.OutOfOrderWritesDetected()) + + def testIntermediateEmptyFiles(self): + self._WriteToFile("a", "a") + self._WriteToFile("b", "") + self._WriteToFile("c", "c") + self.assertWatcherYields(["a", "c"]) + self.assertFalse(self._watcher.OutOfOrderWritesDetected()) + + def testPathFilter(self): + self._watcher = directory_watcher.DirectoryWatcher( + self._directory, + _ByteLoader, + lambda path: "do_not_watch_me" not in path, + ) + + self._WriteToFile("a", "a") + self._WriteToFile("do_not_watch_me", "b") + self._WriteToFile("c", "c") + self.assertWatcherYields(["a", "c"]) + self.assertFalse(self._watcher.OutOfOrderWritesDetected()) + + def testDetectsNewOldFiles(self): + self._WriteToFile("b", "a") + self._LoadAllEvents() + self._WriteToFile("a", "a") + self._LoadAllEvents() + self.assertTrue(self._watcher.OutOfOrderWritesDetected()) + + def testIgnoresNewerFiles(self): + self._WriteToFile("a", "a") + self._LoadAllEvents() + self._WriteToFile("q", "a") + self._LoadAllEvents() + self.assertFalse(self._watcher.OutOfOrderWritesDetected()) + + def testDetectsChangingOldFiles(self): + self._WriteToFile("a", "a") + self._WriteToFile("b", "a") + self._LoadAllEvents() + self._WriteToFile("a", "c") + self._LoadAllEvents() + self.assertTrue(self._watcher.OutOfOrderWritesDetected()) + + def testDoesntCrashWhenFileIsDeleted(self): + self._WriteToFile("a", "a") + self._LoadAllEvents() + os.remove(os.path.join(self._directory, "a")) + self._WriteToFile("b", "b") + self.assertWatcherYields(["b"]) + + def testRaisesRightErrorWhenDirectoryIsDeleted(self): + self._WriteToFile("a", "a") + self._LoadAllEvents() + shutil.rmtree(self._directory) + with self.assertRaises(directory_watcher.DirectoryDeletedError): + self._LoadAllEvents() + + def testDoesntRaiseDirectoryDeletedErrorIfOutageIsTransient(self): + self._WriteToFile("a", "a") + self._LoadAllEvents() + shutil.rmtree(self._directory) + + # Fake a single transient I/O 
error. + def FakeFactory(original): + def Fake(*args, **kwargs): + if FakeFactory.has_been_called: + original(*args, **kwargs) + else: + raise OSError("lp0 temporarily on fire") + + return Fake + + FakeFactory.has_been_called = False + + stub_names = [ + "ListDirectoryAbsolute", + "ListRecursivelyViaGlobbing", + "ListRecursivelyViaWalking", + ] + for stub_name in stub_names: + self.stubs.Set( + io_wrapper, + stub_name, + FakeFactory(getattr(io_wrapper, stub_name)), + ) + for stub_name in ["exists", "stat"]: + self.stubs.Set( + tf.io.gfile, + stub_name, + FakeFactory(getattr(tf.io.gfile, stub_name)), + ) + + with self.assertRaises((IOError, OSError)): + self._LoadAllEvents() + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/backend/event_processing/event_accumulator.py b/tensorboard/backend/event_processing/event_accumulator.py index 83c42685b6..d616cf3006 100644 --- a/tensorboard/backend/event_processing/event_accumulator.py +++ b/tensorboard/backend/event_processing/event_accumulator.py @@ -37,49 +37,61 @@ logger = tb_logging.get_logger() namedtuple = collections.namedtuple -ScalarEvent = namedtuple('ScalarEvent', ['wall_time', 'step', 'value']) - -CompressedHistogramEvent = namedtuple('CompressedHistogramEvent', - ['wall_time', 'step', - 'compressed_histogram_values']) - -HistogramEvent = namedtuple('HistogramEvent', - ['wall_time', 'step', 'histogram_value']) - -HistogramValue = namedtuple('HistogramValue', ['min', 'max', 'num', 'sum', - 'sum_squares', 'bucket_limit', - 'bucket']) - -ImageEvent = namedtuple('ImageEvent', ['wall_time', 'step', - 'encoded_image_string', 'width', - 'height']) - -AudioEvent = namedtuple('AudioEvent', ['wall_time', 'step', - 'encoded_audio_string', 'content_type', - 'sample_rate', 'length_frames']) - -TensorEvent = namedtuple('TensorEvent', ['wall_time', 'step', 'tensor_proto']) +ScalarEvent = namedtuple("ScalarEvent", ["wall_time", "step", "value"]) + +CompressedHistogramEvent = namedtuple( + 
"CompressedHistogramEvent", + ["wall_time", "step", "compressed_histogram_values"], +) + +HistogramEvent = namedtuple( + "HistogramEvent", ["wall_time", "step", "histogram_value"] +) + +HistogramValue = namedtuple( + "HistogramValue", + ["min", "max", "num", "sum", "sum_squares", "bucket_limit", "bucket"], +) + +ImageEvent = namedtuple( + "ImageEvent", + ["wall_time", "step", "encoded_image_string", "width", "height"], +) + +AudioEvent = namedtuple( + "AudioEvent", + [ + "wall_time", + "step", + "encoded_audio_string", + "content_type", + "sample_rate", + "length_frames", + ], +) + +TensorEvent = namedtuple("TensorEvent", ["wall_time", "step", "tensor_proto"]) ## Different types of summary events handled by the event_accumulator SUMMARY_TYPES = { - 'simple_value': '_ProcessScalar', - 'histo': '_ProcessHistogram', - 'image': '_ProcessImage', - 'audio': '_ProcessAudio', - 'tensor': '_ProcessTensor', + "simple_value": "_ProcessScalar", + "histo": "_ProcessHistogram", + "image": "_ProcessImage", + "audio": "_ProcessAudio", + "tensor": "_ProcessTensor", } ## The tagTypes below are just arbitrary strings chosen to pass the type ## information of the tag from the backend to the frontend -COMPRESSED_HISTOGRAMS = 'distributions' -HISTOGRAMS = 'histograms' -IMAGES = 'images' -AUDIO = 'audio' -SCALARS = 'scalars' -TENSORS = 'tensors' -GRAPH = 'graph' -META_GRAPH = 'meta_graph' -RUN_METADATA = 'run_metadata' +COMPRESSED_HISTOGRAMS = "distributions" +HISTOGRAMS = "histograms" +IMAGES = "images" +AUDIO = "audio" +SCALARS = "scalars" +TENSORS = "tensors" +GRAPH = "graph" +META_GRAPH = "meta_graph" +RUN_METADATA = "run_metadata" ## Normal CDF for std_devs: (-Inf, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, Inf) ## naturally gives bands around median of width 1 std dev, 2 std dev, 3 std dev, @@ -106,662 +118,737 @@ class EventAccumulator(object): - """An `EventAccumulator` takes an event generator, and accumulates the values. 
- - The `EventAccumulator` is intended to provide a convenient Python interface - for loading Event data written during a TensorFlow run. TensorFlow writes out - `Event` protobuf objects, which have a timestamp and step number, and often - contain a `Summary`. Summaries can have different kinds of data like an image, - a scalar value, or a histogram. The Summaries also have a tag, which we use to - organize logically related data. The `EventAccumulator` supports retrieving - the `Event` and `Summary` data by its tag. - - Calling `Tags()` gets a map from `tagType` (e.g. `'images'`, - `'compressedHistograms'`, `'scalars'`, etc) to the associated tags for those - data types. Then, various functional endpoints (eg - `Accumulator.Scalars(tag)`) allow for the retrieval of all data - associated with that tag. - - The `Reload()` method synchronously loads all of the data written so far. - - Histograms, audio, and images are very large, so storing all of them is not - recommended. - - Fields: - audios: A reservoir.Reservoir of audio summaries. - compressed_histograms: A reservoir.Reservoir of compressed - histogram summaries. - histograms: A reservoir.Reservoir of histogram summaries. - images: A reservoir.Reservoir of image summaries. - most_recent_step: Step of last Event proto added. This should only - be accessed from the thread that calls Reload. This is -1 if - nothing has been loaded yet. - most_recent_wall_time: Timestamp of last Event proto added. This is - a float containing seconds from the UNIX epoch, or -1 if - nothing has been loaded yet. This should only be accessed from - the thread that calls Reload. - path: A file path to a directory containing tf events files, or a single - tf events file. The accumulator will load events from this path. - scalars: A reservoir.Reservoir of scalar summaries. - tensors: A reservoir.Reservoir of tensor summaries. 
- - @@Tensors - """ - - def __init__(self, - path, - size_guidance=None, - compression_bps=NORMAL_HISTOGRAM_BPS, - purge_orphaned_data=True): - """Construct the `EventAccumulator`. - - Args: + """An `EventAccumulator` takes an event generator, and accumulates the + values. + + The `EventAccumulator` is intended to provide a convenient Python interface + for loading Event data written during a TensorFlow run. TensorFlow writes out + `Event` protobuf objects, which have a timestamp and step number, and often + contain a `Summary`. Summaries can have different kinds of data like an image, + a scalar value, or a histogram. The Summaries also have a tag, which we use to + organize logically related data. The `EventAccumulator` supports retrieving + the `Event` and `Summary` data by its tag. + + Calling `Tags()` gets a map from `tagType` (e.g. `'images'`, + `'compressedHistograms'`, `'scalars'`, etc) to the associated tags for those + data types. Then, various functional endpoints (eg + `Accumulator.Scalars(tag)`) allow for the retrieval of all data + associated with that tag. + + The `Reload()` method synchronously loads all of the data written so far. + + Histograms, audio, and images are very large, so storing all of them is not + recommended. + + Fields: + audios: A reservoir.Reservoir of audio summaries. + compressed_histograms: A reservoir.Reservoir of compressed + histogram summaries. + histograms: A reservoir.Reservoir of histogram summaries. + images: A reservoir.Reservoir of image summaries. + most_recent_step: Step of last Event proto added. This should only + be accessed from the thread that calls Reload. This is -1 if + nothing has been loaded yet. + most_recent_wall_time: Timestamp of last Event proto added. This is + a float containing seconds from the UNIX epoch, or -1 if + nothing has been loaded yet. This should only be accessed from + the thread that calls Reload. 
path: A file path to a directory containing tf events files, or a single - tf events file. The accumulator will load events from this path. - size_guidance: Information on how much data the EventAccumulator should - store in memory. The DEFAULT_SIZE_GUIDANCE tries not to store too much - so as to avoid OOMing the client. The size_guidance should be a map - from a `tagType` string to an integer representing the number of - items to keep per tag for items of that `tagType`. If the size is 0, - all events are stored. - compression_bps: Information on how the `EventAccumulator` should compress - histogram data for the `CompressedHistograms` tag (for details see - `ProcessCompressedHistogram`). - purge_orphaned_data: Whether to discard any events that were "orphaned" by - a TensorFlow restart. - """ - size_guidance = size_guidance or DEFAULT_SIZE_GUIDANCE - sizes = {} - for key in DEFAULT_SIZE_GUIDANCE: - if key in size_guidance: - sizes[key] = size_guidance[key] - else: - sizes[key] = DEFAULT_SIZE_GUIDANCE[key] - - self._first_event_timestamp = None - self.scalars = reservoir.Reservoir(size=sizes[SCALARS]) - - self._graph = None - self._graph_from_metagraph = False - self._meta_graph = None - self._tagged_metadata = {} - self.summary_metadata = {} - self.histograms = reservoir.Reservoir(size=sizes[HISTOGRAMS]) - self.compressed_histograms = reservoir.Reservoir( - size=sizes[COMPRESSED_HISTOGRAMS], always_keep_last=False) - self.images = reservoir.Reservoir(size=sizes[IMAGES]) - self.audios = reservoir.Reservoir(size=sizes[AUDIO]) - self.tensors = reservoir.Reservoir(size=sizes[TENSORS]) - - # Keep a mapping from plugin name to a dict mapping from tag to plugin data - # content obtained from the SummaryMetadata (metadata field of Value) for - # that plugin (This is not the entire SummaryMetadata proto - only the - # content for that plugin). 
The SummaryWriter only keeps the content on the - # first event encountered per tag, so we must store that first instance of - # content for each tag. - self._plugin_to_tag_to_content = collections.defaultdict(dict) - - self._generator_mutex = threading.Lock() - self.path = path - self._generator = _GeneratorFromPath(path) - - self._compression_bps = compression_bps - self.purge_orphaned_data = purge_orphaned_data - - self.most_recent_step = -1 - self.most_recent_wall_time = -1 - self.file_version = None - - # The attributes that get built up by the accumulator - self.accumulated_attrs = ('scalars', 'histograms', - 'compressed_histograms', 'images', 'audios') - self._tensor_summaries = {} - - def Reload(self): - """Loads all events added since the last call to `Reload`. - - If `Reload` was never called, loads all events in the file. + tf events file. The accumulator will load events from this path. + scalars: A reservoir.Reservoir of scalar summaries. + tensors: A reservoir.Reservoir of tensor summaries. - Returns: - The `EventAccumulator`. - """ - with self._generator_mutex: - for event in self._generator.Load(): - self._ProcessEvent(event) - return self - - def PluginAssets(self, plugin_name): - """Return a list of all plugin assets for the given plugin. - - Args: - plugin_name: The string name of a plugin to retrieve assets for. - - Returns: - A list of string plugin asset names, or empty list if none are available. - If the plugin was not registered, an empty list is returned. + @@Tensors """ - return plugin_asset_util.ListAssets(self.path, plugin_name) - - def RetrievePluginAsset(self, plugin_name, asset_name): - """Return the contents of a given plugin asset. - - Args: - plugin_name: The string name of a plugin. - asset_name: The string name of an asset. - - Returns: - The string contents of the plugin asset. - - Raises: - KeyError: If the asset is not available. 
- """ - return plugin_asset_util.RetrieveAsset(self.path, plugin_name, asset_name) - - def FirstEventTimestamp(self): - """Returns the timestamp in seconds of the first event. - - If the first event has been loaded (either by this method or by `Reload`, - this returns immediately. Otherwise, it will load in the first event. Note - that this means that calling `Reload` will cause this to block until - `Reload` has finished. - - Returns: - The timestamp in seconds of the first event that was loaded. - - Raises: - ValueError: If no events have been loaded and there were no events found - on disk. - """ - if self._first_event_timestamp is not None: - return self._first_event_timestamp - with self._generator_mutex: - try: - event = next(self._generator.Load()) - self._ProcessEvent(event) - return self._first_event_timestamp - - except StopIteration: - raise ValueError('No event timestamp could be found') - - def PluginTagToContent(self, plugin_name): - """Returns a dict mapping tags to content specific to that plugin. - - Args: - plugin_name: The name of the plugin for which to fetch plugin-specific - content. - Raises: - KeyError: if the plugin name is not found. - - Returns: - A dict mapping tag names to bytestrings of plugin-specific content-- by - convention, in the form of binary serialized protos. - """ - if plugin_name not in self._plugin_to_tag_to_content: - raise KeyError('Plugin %r could not be found.' % plugin_name) - return self._plugin_to_tag_to_content[plugin_name] - - def SummaryMetadata(self, tag): - """Given a summary tag name, return the associated metadata object. - - Args: - tag: The name of a tag, as a string. - - Raises: - KeyError: If the tag is not found. - - Returns: - A `SummaryMetadata` protobuf. 
- """ - return self.summary_metadata[tag] - - def _ProcessEvent(self, event): - """Called whenever an event is loaded.""" - if self._first_event_timestamp is None: - self._first_event_timestamp = event.wall_time - - if event.HasField('file_version'): - new_file_version = _ParseFileVersion(event.file_version) - if self.file_version and self.file_version != new_file_version: - ## This should not happen. - logger.warn(('Found new file_version for event.proto. This will ' - 'affect purging logic for TensorFlow restarts. ' - 'Old: {0} New: {1}').format(self.file_version, - new_file_version)) - self.file_version = new_file_version - - self._MaybePurgeOrphanedData(event) - - ## Process the event. - # GraphDef and MetaGraphDef are handled in a special way: - # If no graph_def Event is available, but a meta_graph_def is, and it - # contains a graph_def, then use the meta_graph_def.graph_def as our graph. - # If a graph_def Event is available, always prefer it to the graph_def - # inside the meta_graph_def. - if event.HasField('graph_def'): - if self._graph is not None: - logger.warn( - ('Found more than one graph event per run, or there was ' - 'a metagraph containing a graph_def, as well as one or ' - 'more graph events. Overwriting the graph with the ' - 'newest event.')) - self._graph = event.graph_def - self._graph_from_metagraph = False - elif event.HasField('meta_graph_def'): - if self._meta_graph is not None: - logger.warn(('Found more than one metagraph event per run. ' - 'Overwriting the metagraph with the newest event.')) - self._meta_graph = event.meta_graph_def - if self._graph is None or self._graph_from_metagraph: - # We may have a graph_def in the metagraph. If so, and no - # graph_def is directly available, use this one instead. + def __init__( + self, + path, + size_guidance=None, + compression_bps=NORMAL_HISTOGRAM_BPS, + purge_orphaned_data=True, + ): + """Construct the `EventAccumulator`. 
+ + Args: + path: A file path to a directory containing tf events files, or a single + tf events file. The accumulator will load events from this path. + size_guidance: Information on how much data the EventAccumulator should + store in memory. The DEFAULT_SIZE_GUIDANCE tries not to store too much + so as to avoid OOMing the client. The size_guidance should be a map + from a `tagType` string to an integer representing the number of + items to keep per tag for items of that `tagType`. If the size is 0, + all events are stored. + compression_bps: Information on how the `EventAccumulator` should compress + histogram data for the `CompressedHistograms` tag (for details see + `ProcessCompressedHistogram`). + purge_orphaned_data: Whether to discard any events that were "orphaned" by + a TensorFlow restart. + """ + size_guidance = size_guidance or DEFAULT_SIZE_GUIDANCE + sizes = {} + for key in DEFAULT_SIZE_GUIDANCE: + if key in size_guidance: + sizes[key] = size_guidance[key] + else: + sizes[key] = DEFAULT_SIZE_GUIDANCE[key] + + self._first_event_timestamp = None + self.scalars = reservoir.Reservoir(size=sizes[SCALARS]) + + self._graph = None + self._graph_from_metagraph = False + self._meta_graph = None + self._tagged_metadata = {} + self.summary_metadata = {} + self.histograms = reservoir.Reservoir(size=sizes[HISTOGRAMS]) + self.compressed_histograms = reservoir.Reservoir( + size=sizes[COMPRESSED_HISTOGRAMS], always_keep_last=False + ) + self.images = reservoir.Reservoir(size=sizes[IMAGES]) + self.audios = reservoir.Reservoir(size=sizes[AUDIO]) + self.tensors = reservoir.Reservoir(size=sizes[TENSORS]) + + # Keep a mapping from plugin name to a dict mapping from tag to plugin data + # content obtained from the SummaryMetadata (metadata field of Value) for + # that plugin (This is not the entire SummaryMetadata proto - only the + # content for that plugin). 
The SummaryWriter only keeps the content on the + # first event encountered per tag, so we must store that first instance of + # content for each tag. + self._plugin_to_tag_to_content = collections.defaultdict(dict) + + self._generator_mutex = threading.Lock() + self.path = path + self._generator = _GeneratorFromPath(path) + + self._compression_bps = compression_bps + self.purge_orphaned_data = purge_orphaned_data + + self.most_recent_step = -1 + self.most_recent_wall_time = -1 + self.file_version = None + + # The attributes that get built up by the accumulator + self.accumulated_attrs = ( + "scalars", + "histograms", + "compressed_histograms", + "images", + "audios", + ) + self._tensor_summaries = {} + + def Reload(self): + """Loads all events added since the last call to `Reload`. + + If `Reload` was never called, loads all events in the file. + + Returns: + The `EventAccumulator`. + """ + with self._generator_mutex: + for event in self._generator.Load(): + self._ProcessEvent(event) + return self + + def PluginAssets(self, plugin_name): + """Return a list of all plugin assets for the given plugin. + + Args: + plugin_name: The string name of a plugin to retrieve assets for. + + Returns: + A list of string plugin asset names, or empty list if none are available. + If the plugin was not registered, an empty list is returned. + """ + return plugin_asset_util.ListAssets(self.path, plugin_name) + + def RetrievePluginAsset(self, plugin_name, asset_name): + """Return the contents of a given plugin asset. + + Args: + plugin_name: The string name of a plugin. + asset_name: The string name of an asset. + + Returns: + The string contents of the plugin asset. + + Raises: + KeyError: If the asset is not available. + """ + return plugin_asset_util.RetrieveAsset( + self.path, plugin_name, asset_name + ) + + def FirstEventTimestamp(self): + """Returns the timestamp in seconds of the first event. 
+ + If the first event has been loaded (either by this method or by `Reload`, + this returns immediately. Otherwise, it will load in the first event. Note + that this means that calling `Reload` will cause this to block until + `Reload` has finished. + + Returns: + The timestamp in seconds of the first event that was loaded. + + Raises: + ValueError: If no events have been loaded and there were no events found + on disk. + """ + if self._first_event_timestamp is not None: + return self._first_event_timestamp + with self._generator_mutex: + try: + event = next(self._generator.Load()) + self._ProcessEvent(event) + return self._first_event_timestamp + + except StopIteration: + raise ValueError("No event timestamp could be found") + + def PluginTagToContent(self, plugin_name): + """Returns a dict mapping tags to content specific to that plugin. + + Args: + plugin_name: The name of the plugin for which to fetch plugin-specific + content. + + Raises: + KeyError: if the plugin name is not found. + + Returns: + A dict mapping tag names to bytestrings of plugin-specific content-- by + convention, in the form of binary serialized protos. + """ + if plugin_name not in self._plugin_to_tag_to_content: + raise KeyError("Plugin %r could not be found." % plugin_name) + return self._plugin_to_tag_to_content[plugin_name] + + def SummaryMetadata(self, tag): + """Given a summary tag name, return the associated metadata object. + + Args: + tag: The name of a tag, as a string. + + Raises: + KeyError: If the tag is not found. + + Returns: + A `SummaryMetadata` protobuf. + """ + return self.summary_metadata[tag] + + def _ProcessEvent(self, event): + """Called whenever an event is loaded.""" + if self._first_event_timestamp is None: + self._first_event_timestamp = event.wall_time + + if event.HasField("file_version"): + new_file_version = _ParseFileVersion(event.file_version) + if self.file_version and self.file_version != new_file_version: + ## This should not happen. 
+ logger.warn( + ( + "Found new file_version for event.proto. This will " + "affect purging logic for TensorFlow restarts. " + "Old: {0} New: {1}" + ).format(self.file_version, new_file_version) + ) + self.file_version = new_file_version + + self._MaybePurgeOrphanedData(event) + + ## Process the event. + # GraphDef and MetaGraphDef are handled in a special way: + # If no graph_def Event is available, but a meta_graph_def is, and it + # contains a graph_def, then use the meta_graph_def.graph_def as our graph. + # If a graph_def Event is available, always prefer it to the graph_def + # inside the meta_graph_def. + if event.HasField("graph_def"): + if self._graph is not None: + logger.warn( + ( + "Found more than one graph event per run, or there was " + "a metagraph containing a graph_def, as well as one or " + "more graph events. Overwriting the graph with the " + "newest event." + ) + ) + self._graph = event.graph_def + self._graph_from_metagraph = False + elif event.HasField("meta_graph_def"): + if self._meta_graph is not None: + logger.warn( + ( + "Found more than one metagraph event per run. " + "Overwriting the metagraph with the newest event." + ) + ) + self._meta_graph = event.meta_graph_def + if self._graph is None or self._graph_from_metagraph: + # We may have a graph_def in the metagraph. If so, and no + # graph_def is directly available, use this one instead. + meta_graph = meta_graph_pb2.MetaGraphDef() + meta_graph.ParseFromString(self._meta_graph) + if meta_graph.graph_def: + if self._graph is not None: + logger.warn( + ( + "Found multiple metagraphs containing graph_defs," + "but did not find any graph events. Overwriting the " + "graph with the newest metagraph version." 
+ ) + ) + self._graph_from_metagraph = True + self._graph = meta_graph.graph_def.SerializeToString() + elif event.HasField("tagged_run_metadata"): + tag = event.tagged_run_metadata.tag + if tag in self._tagged_metadata: + logger.warn( + 'Found more than one "run metadata" event with tag ' + + tag + + ". Overwriting it with the newest event." + ) + self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata + elif event.HasField("summary"): + for value in event.summary.value: + if value.HasField("metadata"): + tag = value.tag + # We only store the first instance of the metadata. This check + # is important: the `FileWriter` does strip metadata from all + # values except the first one per each tag, but a new + # `FileWriter` is created every time a training job stops and + # restarts. Hence, we must also ignore non-initial metadata in + # this logic. + if tag not in self.summary_metadata: + self.summary_metadata[tag] = value.metadata + plugin_data = value.metadata.plugin_data + if plugin_data.plugin_name: + self._plugin_to_tag_to_content[ + plugin_data.plugin_name + ][tag] = plugin_data.content + else: + logger.warn( + ( + "This summary with tag %r is oddly not associated with a " + "plugin." + ), + tag, + ) + + for summary_type, summary_func in SUMMARY_TYPES.items(): + if value.HasField(summary_type): + datum = getattr(value, summary_type) + tag = value.tag + if summary_type == "tensor" and not tag: + # This tensor summary was created using the old method that used + # plugin assets. We must still continue to support it. + tag = value.node_name + getattr(self, summary_func)( + tag, event.wall_time, event.step, datum + ) + + def Tags(self): + """Return all tags found in the value stream. + + Returns: + A `{tagType: ['list', 'of', 'tags']}` dictionary. 
+ """ + return { + IMAGES: self.images.Keys(), + AUDIO: self.audios.Keys(), + HISTOGRAMS: self.histograms.Keys(), + SCALARS: self.scalars.Keys(), + COMPRESSED_HISTOGRAMS: self.compressed_histograms.Keys(), + TENSORS: self.tensors.Keys(), + # Use a heuristic: if the metagraph is available, but + # graph is not, then we assume the metagraph contains the graph. + GRAPH: self._graph is not None, + META_GRAPH: self._meta_graph is not None, + RUN_METADATA: list(self._tagged_metadata.keys()), + } + + def Scalars(self, tag): + """Given a summary tag, return all associated `ScalarEvent`s. + + Args: + tag: A string tag associated with the events. + + Raises: + KeyError: If the tag is not found. + + Returns: + An array of `ScalarEvent`s. + """ + return self.scalars.Items(tag) + + def Graph(self): + """Return the graph definition, if there is one. + + If the graph is stored directly, return that. If no graph is stored + directly but a metagraph is stored containing a graph, return that. + + Raises: + ValueError: If there is no graph for this run. + + Returns: + The `graph_def` proto. + """ + graph = graph_pb2.GraphDef() + if self._graph is not None: + graph.ParseFromString(self._graph) + return graph + raise ValueError("There is no graph in this EventAccumulator") + + def MetaGraph(self): + """Return the metagraph definition, if there is one. + + Raises: + ValueError: If there is no metagraph for this run. + + Returns: + The `meta_graph_def` proto. + """ + if self._meta_graph is None: + raise ValueError("There is no metagraph in this EventAccumulator") meta_graph = meta_graph_pb2.MetaGraphDef() meta_graph.ParseFromString(self._meta_graph) - if meta_graph.graph_def: - if self._graph is not None: - logger.warn( - ('Found multiple metagraphs containing graph_defs,' - 'but did not find any graph events. 
Overwriting the ' - 'graph with the newest metagraph version.')) - self._graph_from_metagraph = True - self._graph = meta_graph.graph_def.SerializeToString() - elif event.HasField('tagged_run_metadata'): - tag = event.tagged_run_metadata.tag - if tag in self._tagged_metadata: - logger.warn('Found more than one "run metadata" event with tag ' + - tag + '. Overwriting it with the newest event.') - self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata - elif event.HasField('summary'): - for value in event.summary.value: - if value.HasField('metadata'): - tag = value.tag - # We only store the first instance of the metadata. This check - # is important: the `FileWriter` does strip metadata from all - # values except the first one per each tag, but a new - # `FileWriter` is created every time a training job stops and - # restarts. Hence, we must also ignore non-initial metadata in - # this logic. - if tag not in self.summary_metadata: - self.summary_metadata[tag] = value.metadata - plugin_data = value.metadata.plugin_data - if plugin_data.plugin_name: - self._plugin_to_tag_to_content[plugin_data.plugin_name][tag] = ( - plugin_data.content) - else: - logger.warn( - ('This summary with tag %r is oddly not associated with a ' - 'plugin.'), tag) - - for summary_type, summary_func in SUMMARY_TYPES.items(): - if value.HasField(summary_type): - datum = getattr(value, summary_type) - tag = value.tag - if summary_type == 'tensor' and not tag: - # This tensor summary was created using the old method that used - # plugin assets. We must still continue to support it. - tag = value.node_name - getattr(self, summary_func)(tag, event.wall_time, event.step, datum) - + return meta_graph - def Tags(self): - """Return all tags found in the value stream. + def RunMetadata(self, tag): + """Given a tag, return the associated session.run() metadata. - Returns: - A `{tagType: ['list', 'of', 'tags']}` dictionary. 
- """ - return { - IMAGES: self.images.Keys(), - AUDIO: self.audios.Keys(), - HISTOGRAMS: self.histograms.Keys(), - SCALARS: self.scalars.Keys(), - COMPRESSED_HISTOGRAMS: self.compressed_histograms.Keys(), - TENSORS: self.tensors.Keys(), - # Use a heuristic: if the metagraph is available, but - # graph is not, then we assume the metagraph contains the graph. - GRAPH: self._graph is not None, - META_GRAPH: self._meta_graph is not None, - RUN_METADATA: list(self._tagged_metadata.keys()) - } - - def Scalars(self, tag): - """Given a summary tag, return all associated `ScalarEvent`s. - - Args: - tag: A string tag associated with the events. - - Raises: - KeyError: If the tag is not found. - - Returns: - An array of `ScalarEvent`s. - """ - return self.scalars.Items(tag) + Args: + tag: A string tag associated with the event. - def Graph(self): - """Return the graph definition, if there is one. + Raises: + ValueError: If the tag is not found. - If the graph is stored directly, return that. If no graph is stored - directly but a metagraph is stored containing a graph, return that. + Returns: + The metadata in form of `RunMetadata` proto. + """ + if tag not in self._tagged_metadata: + raise ValueError("There is no run metadata with this tag name") - Raises: - ValueError: If there is no graph for this run. + run_metadata = config_pb2.RunMetadata() + run_metadata.ParseFromString(self._tagged_metadata[tag]) + return run_metadata - Returns: - The `graph_def` proto. - """ - graph = graph_pb2.GraphDef() - if self._graph is not None: - graph.ParseFromString(self._graph) - return graph - raise ValueError('There is no graph in this EventAccumulator') + def Histograms(self, tag): + """Given a summary tag, return all associated histograms. - def MetaGraph(self): - """Return the metagraph definition, if there is one. + Args: + tag: A string tag associated with the events. - Raises: - ValueError: If there is no metagraph for this run. - - Returns: - The `meta_graph_def` proto. 
- """ - if self._meta_graph is None: - raise ValueError('There is no metagraph in this EventAccumulator') - meta_graph = meta_graph_pb2.MetaGraphDef() - meta_graph.ParseFromString(self._meta_graph) - return meta_graph + Raises: + KeyError: If the tag is not found. - def RunMetadata(self, tag): - """Given a tag, return the associated session.run() metadata. + Returns: + An array of `HistogramEvent`s. + """ + return self.histograms.Items(tag) - Args: - tag: A string tag associated with the event. + def CompressedHistograms(self, tag): + """Given a summary tag, return all associated compressed histograms. - Raises: - ValueError: If the tag is not found. + Args: + tag: A string tag associated with the events. + + Raises: + KeyError: If the tag is not found. + + Returns: + An array of `CompressedHistogramEvent`s. + """ + return self.compressed_histograms.Items(tag) + + def Images(self, tag): + """Given a summary tag, return all associated images. + + Args: + tag: A string tag associated with the events. + + Raises: + KeyError: If the tag is not found. + + Returns: + An array of `ImageEvent`s. + """ + return self.images.Items(tag) + + def Audio(self, tag): + """Given a summary tag, return all associated audio. + + Args: + tag: A string tag associated with the events. + + Raises: + KeyError: If the tag is not found. + + Returns: + An array of `AudioEvent`s. + """ + return self.audios.Items(tag) + + def Tensors(self, tag): + """Given a summary tag, return all associated tensors. + + Args: + tag: A string tag associated with the events. + + Raises: + KeyError: If the tag is not found. + + Returns: + An array of `TensorEvent`s. + """ + return self.tensors.Items(tag) + + def _MaybePurgeOrphanedData(self, event): + """Maybe purge orphaned data due to a TensorFlow crash. + + When TensorFlow crashes at step T+O and restarts at step T, any events + written after step T are now "orphaned" and will be at best misleading if + they are included in TensorBoard. 
+ + This logic attempts to determine if there is orphaned data, and purge it + if it is found. + + Args: + event: The event to use as a reference, to determine if a purge is needed. + """ + if not self.purge_orphaned_data: + return + ## Check if the event happened after a crash, and purge expired tags. + if self.file_version and self.file_version >= 2: + ## If the file_version is recent enough, use the SessionLog enum + ## to check for restarts. + self._CheckForRestartAndMaybePurge(event) + else: + ## If there is no file version, default to old logic of checking for + ## out of order steps. + self._CheckForOutOfOrderStepAndMaybePurge(event) + + def _CheckForRestartAndMaybePurge(self, event): + """Check and discard expired events using SessionLog.START. + + Check for a SessionLog.START event and purge all previously seen events + with larger steps, because they are out of date. Because of supervisor + threading, it is possible that this logic will cause the first few event + messages to be discarded since supervisor threading does not guarantee + that the START message is deterministically written first. + + This method is preferred over _CheckForOutOfOrderStepAndMaybePurge which + can inadvertently discard events due to supervisor threading. + + Args: + event: The event to use as reference. If the event is a START event, all + previously seen events with a greater event.step will be purged. + """ + if ( + event.HasField("session_log") + and event.session_log.status == event_pb2.SessionLog.START + ): + self._Purge(event, by_tags=False) + + def _CheckForOutOfOrderStepAndMaybePurge(self, event): + """Check for out-of-order event.step and discard expired events for + tags. + + Check if the event is out of order relative to the global most recent step. + If it is, purge outdated summaries for tags that the event contains. + + Args: + event: The event to use as reference. 
If the event is out-of-order, all + events with the same tags, but with a greater event.step will be purged. + """ + if event.step < self.most_recent_step and event.HasField("summary"): + self._Purge(event, by_tags=True) + else: + self.most_recent_step = event.step + self.most_recent_wall_time = event.wall_time + + def _ConvertHistogramProtoToTuple(self, histo): + return HistogramValue( + min=histo.min, + max=histo.max, + num=histo.num, + sum=histo.sum, + sum_squares=histo.sum_squares, + bucket_limit=list(histo.bucket_limit), + bucket=list(histo.bucket), + ) + + def _ProcessHistogram(self, tag, wall_time, step, histo): + """Processes a proto histogram by adding it to accumulated state.""" + histo = self._ConvertHistogramProtoToTuple(histo) + histo_ev = HistogramEvent(wall_time, step, histo) + self.histograms.AddItem(tag, histo_ev) + self.compressed_histograms.AddItem( + tag, histo_ev, self._CompressHistogram + ) + + def _CompressHistogram(self, histo_ev): + """Callback for _ProcessHistogram.""" + return CompressedHistogramEvent( + histo_ev.wall_time, + histo_ev.step, + compressor.compress_histogram_proto( + histo_ev.histogram_value, self._compression_bps + ), + ) + + def _ProcessImage(self, tag, wall_time, step, image): + """Processes an image by adding it to accumulated state.""" + event = ImageEvent( + wall_time=wall_time, + step=step, + encoded_image_string=image.encoded_image_string, + width=image.width, + height=image.height, + ) + self.images.AddItem(tag, event) + + def _ProcessAudio(self, tag, wall_time, step, audio): + """Processes a audio by adding it to accumulated state.""" + event = AudioEvent( + wall_time=wall_time, + step=step, + encoded_audio_string=audio.encoded_audio_string, + content_type=audio.content_type, + sample_rate=audio.sample_rate, + length_frames=audio.length_frames, + ) + self.audios.AddItem(tag, event) + + def _ProcessScalar(self, tag, wall_time, step, scalar): + """Processes a simple value by adding it to accumulated state.""" + sv = 
ScalarEvent(wall_time=wall_time, step=step, value=scalar) + self.scalars.AddItem(tag, sv) + + def _ProcessTensor(self, tag, wall_time, step, tensor): + tv = TensorEvent(wall_time=wall_time, step=step, tensor_proto=tensor) + self.tensors.AddItem(tag, tv) + + def _Purge(self, event, by_tags): + """Purge all events that have occurred after the given event.step. + + If by_tags is True, purge all events that occurred after the given + event.step, but only for the tags that the event has. Non-sequential + event.steps suggest that a TensorFlow restart occurred, and we discard + the out-of-order events to display a consistent view in TensorBoard. + + Discarding by tags is the safer method, when we are unsure whether a restart + has occurred, given that threading in supervisor can cause events of + different tags to arrive with unsynchronized step values. + + If by_tags is False, then purge all events with event.step greater than the + given event.step. This can be used when we are certain that a TensorFlow + restart has occurred and these events can be discarded. + + Args: + event: The event to use as reference for the purge. All events with + the same tags, but with a greater event.step will be purged. + by_tags: Bool to dictate whether to discard all out-of-order events or + only those that are associated with the given reference event. 
+ """ + ## Keep data in reservoirs that has a step less than event.step + _NotExpired = lambda x: x.step < event.step + + if by_tags: + + def _ExpiredPerTag(value): + return [ + getattr(self, x).FilterItems(_NotExpired, value.tag) + for x in self.accumulated_attrs + ] + + expired_per_tags = [ + _ExpiredPerTag(value) for value in event.summary.value + ] + expired_per_type = [sum(x) for x in zip(*expired_per_tags)] + else: + expired_per_type = [ + getattr(self, x).FilterItems(_NotExpired) + for x in self.accumulated_attrs + ] + + if sum(expired_per_type) > 0: + purge_msg = _GetPurgeMessage( + self.most_recent_step, + self.most_recent_wall_time, + event.step, + event.wall_time, + *expired_per_type + ) + logger.warn(purge_msg) + + +def _GetPurgeMessage( + most_recent_step, + most_recent_wall_time, + event_step, + event_wall_time, + num_expired_scalars, + num_expired_histos, + num_expired_comp_histos, + num_expired_images, + num_expired_audio, +): + """Return the string message associated with TensorBoard purges.""" + return ( + "Detected out of order event.step likely caused by " + "a TensorFlow restart. Purging expired events from Tensorboard" + " display between the previous step: {} (timestamp: {}) and " + "current step: {} (timestamp: {}). Removing {} scalars, {} " + "histograms, {} compressed histograms, {} images, " + "and {} audio." + ).format( + most_recent_step, + most_recent_wall_time, + event_step, + event_wall_time, + num_expired_scalars, + num_expired_histos, + num_expired_comp_histos, + num_expired_images, + num_expired_audio, + ) - Returns: - The metadata in form of `RunMetadata` proto. - """ - if tag not in self._tagged_metadata: - raise ValueError('There is no run metadata with this tag name') - run_metadata = config_pb2.RunMetadata() - run_metadata.ParseFromString(self._tagged_metadata[tag]) - return run_metadata - - def Histograms(self, tag): - """Given a summary tag, return all associated histograms. 
- - Args: - tag: A string tag associated with the events. - - Raises: - KeyError: If the tag is not found. - - Returns: - An array of `HistogramEvent`s. - """ - return self.histograms.Items(tag) - - def CompressedHistograms(self, tag): - """Given a summary tag, return all associated compressed histograms. - - Args: - tag: A string tag associated with the events. - - Raises: - KeyError: If the tag is not found. - - Returns: - An array of `CompressedHistogramEvent`s. - """ - return self.compressed_histograms.Items(tag) - - def Images(self, tag): - """Given a summary tag, return all associated images. - - Args: - tag: A string tag associated with the events. - - Raises: - KeyError: If the tag is not found. - - Returns: - An array of `ImageEvent`s. - """ - return self.images.Items(tag) - - def Audio(self, tag): - """Given a summary tag, return all associated audio. - - Args: - tag: A string tag associated with the events. - - Raises: - KeyError: If the tag is not found. - - Returns: - An array of `AudioEvent`s. - """ - return self.audios.Items(tag) - - def Tensors(self, tag): - """Given a summary tag, return all associated tensors. - - Args: - tag: A string tag associated with the events. - - Raises: - KeyError: If the tag is not found. - - Returns: - An array of `TensorEvent`s. - """ - return self.tensors.Items(tag) - - def _MaybePurgeOrphanedData(self, event): - """Maybe purge orphaned data due to a TensorFlow crash. - - When TensorFlow crashes at step T+O and restarts at step T, any events - written after step T are now "orphaned" and will be at best misleading if - they are included in TensorBoard. - - This logic attempts to determine if there is orphaned data, and purge it - if it is found. - - Args: - event: The event to use as a reference, to determine if a purge is needed. - """ - if not self.purge_orphaned_data: - return - ## Check if the event happened after a crash, and purge expired tags. 
- if self.file_version and self.file_version >= 2: - ## If the file_version is recent enough, use the SessionLog enum - ## to check for restarts. - self._CheckForRestartAndMaybePurge(event) +def _GeneratorFromPath(path): + """Create an event generator for file or directory at given path string.""" + if not path: + raise ValueError("path must be a valid string") + if io_wrapper.IsTensorFlowEventsFile(path): + return event_file_loader.EventFileLoader(path) else: - ## If there is no file version, default to old logic of checking for - ## out of order steps. - self._CheckForOutOfOrderStepAndMaybePurge(event) - - def _CheckForRestartAndMaybePurge(self, event): - """Check and discard expired events using SessionLog.START. - - Check for a SessionLog.START event and purge all previously seen events - with larger steps, because they are out of date. Because of supervisor - threading, it is possible that this logic will cause the first few event - messages to be discarded since supervisor threading does not guarantee - that the START message is deterministically written first. - - This method is preferred over _CheckForOutOfOrderStepAndMaybePurge which - can inadvertently discard events due to supervisor threading. - - Args: - event: The event to use as reference. If the event is a START event, all - previously seen events with a greater event.step will be purged. - """ - if event.HasField( - 'session_log') and event.session_log.status == event_pb2.SessionLog.START: - self._Purge(event, by_tags=False) + return directory_watcher.DirectoryWatcher( + path, + event_file_loader.EventFileLoader, + io_wrapper.IsTensorFlowEventsFile, + ) - def _CheckForOutOfOrderStepAndMaybePurge(self, event): - """Check for out-of-order event.step and discard expired events for tags. - Check if the event is out of order relative to the global most recent step. - If it is, purge outdated summaries for tags that the event contains. 
+def _ParseFileVersion(file_version): + """Convert the string file_version in event.proto into a float. Args: - event: The event to use as reference. If the event is out-of-order, all - events with the same tags, but with a greater event.step will be purged. - """ - if event.step < self.most_recent_step and event.HasField('summary'): - self._Purge(event, by_tags=True) - else: - self.most_recent_step = event.step - self.most_recent_wall_time = event.wall_time - - def _ConvertHistogramProtoToTuple(self, histo): - return HistogramValue(min=histo.min, - max=histo.max, - num=histo.num, - sum=histo.sum, - sum_squares=histo.sum_squares, - bucket_limit=list(histo.bucket_limit), - bucket=list(histo.bucket)) - - def _ProcessHistogram(self, tag, wall_time, step, histo): - """Processes a proto histogram by adding it to accumulated state.""" - histo = self._ConvertHistogramProtoToTuple(histo) - histo_ev = HistogramEvent(wall_time, step, histo) - self.histograms.AddItem(tag, histo_ev) - self.compressed_histograms.AddItem(tag, histo_ev, self._CompressHistogram) - - def _CompressHistogram(self, histo_ev): - """Callback for _ProcessHistogram.""" - return CompressedHistogramEvent( - histo_ev.wall_time, - histo_ev.step, - compressor.compress_histogram_proto( - histo_ev.histogram_value, self._compression_bps)) - - def _ProcessImage(self, tag, wall_time, step, image): - """Processes an image by adding it to accumulated state.""" - event = ImageEvent(wall_time=wall_time, - step=step, - encoded_image_string=image.encoded_image_string, - width=image.width, - height=image.height) - self.images.AddItem(tag, event) - - def _ProcessAudio(self, tag, wall_time, step, audio): - """Processes a audio by adding it to accumulated state.""" - event = AudioEvent(wall_time=wall_time, - step=step, - encoded_audio_string=audio.encoded_audio_string, - content_type=audio.content_type, - sample_rate=audio.sample_rate, - length_frames=audio.length_frames) - self.audios.AddItem(tag, event) - - def 
_ProcessScalar(self, tag, wall_time, step, scalar): - """Processes a simple value by adding it to accumulated state.""" - sv = ScalarEvent(wall_time=wall_time, step=step, value=scalar) - self.scalars.AddItem(tag, sv) - - def _ProcessTensor(self, tag, wall_time, step, tensor): - tv = TensorEvent(wall_time=wall_time, step=step, tensor_proto=tensor) - self.tensors.AddItem(tag, tv) - - def _Purge(self, event, by_tags): - """Purge all events that have occurred after the given event.step. - - If by_tags is True, purge all events that occurred after the given - event.step, but only for the tags that the event has. Non-sequential - event.steps suggest that a TensorFlow restart occurred, and we discard - the out-of-order events to display a consistent view in TensorBoard. - - Discarding by tags is the safer method, when we are unsure whether a restart - has occurred, given that threading in supervisor can cause events of - different tags to arrive with unsynchronized step values. - - If by_tags is False, then purge all events with event.step greater than the - given event.step. This can be used when we are certain that a TensorFlow - restart has occurred and these events can be discarded. + file_version: String file_version from event.proto - Args: - event: The event to use as reference for the purge. All events with - the same tags, but with a greater event.step will be purged. - by_tags: Bool to dictate whether to discard all out-of-order events or - only those that are associated with the given reference event. + Returns: + Version number as a float. 
""" - ## Keep data in reservoirs that has a step less than event.step - _NotExpired = lambda x: x.step < event.step - - if by_tags: - def _ExpiredPerTag(value): - return [getattr(self, x).FilterItems(_NotExpired, value.tag) - for x in self.accumulated_attrs] - - expired_per_tags = [_ExpiredPerTag(value) - for value in event.summary.value] - expired_per_type = [sum(x) for x in zip(*expired_per_tags)] - else: - expired_per_type = [getattr(self, x).FilterItems(_NotExpired) - for x in self.accumulated_attrs] - - if sum(expired_per_type) > 0: - purge_msg = _GetPurgeMessage(self.most_recent_step, - self.most_recent_wall_time, event.step, - event.wall_time, *expired_per_type) - logger.warn(purge_msg) - - -def _GetPurgeMessage(most_recent_step, most_recent_wall_time, event_step, - event_wall_time, num_expired_scalars, num_expired_histos, - num_expired_comp_histos, num_expired_images, - num_expired_audio): - """Return the string message associated with TensorBoard purges.""" - return ('Detected out of order event.step likely caused by ' - 'a TensorFlow restart. Purging expired events from Tensorboard' - ' display between the previous step: {} (timestamp: {}) and ' - 'current step: {} (timestamp: {}). 
Removing {} scalars, {} ' - 'histograms, {} compressed histograms, {} images, ' - 'and {} audio.').format(most_recent_step, most_recent_wall_time, - event_step, event_wall_time, - num_expired_scalars, num_expired_histos, - num_expired_comp_histos, num_expired_images, - num_expired_audio) - - -def _GeneratorFromPath(path): - """Create an event generator for file or directory at given path string.""" - if not path: - raise ValueError('path must be a valid string') - if io_wrapper.IsTensorFlowEventsFile(path): - return event_file_loader.EventFileLoader(path) - else: - return directory_watcher.DirectoryWatcher( - path, - event_file_loader.EventFileLoader, - io_wrapper.IsTensorFlowEventsFile) - - -def _ParseFileVersion(file_version): - """Convert the string file_version in event.proto into a float. - - Args: - file_version: String file_version from event.proto - - Returns: - Version number as a float. - """ - tokens = file_version.split('brain.Event:') - try: - return float(tokens[-1]) - except ValueError: - ## This should never happen according to the definition of file_version - ## specified in event.proto. - logger.warn( - ('Invalid event.proto file_version. Defaulting to use of ' - 'out-of-order event.step logic for purging expired events.')) - return -1 + tokens = file_version.split("brain.Event:") + try: + return float(tokens[-1]) + except ValueError: + ## This should never happen according to the definition of file_version + ## specified in event.proto. + logger.warn( + ( + "Invalid event.proto file_version. Defaulting to use of " + "out-of-order event.step logic for purging expired events." 
+ ) + ) + return -1 diff --git a/tensorboard/backend/event_processing/event_accumulator_test.py b/tensorboard/backend/event_processing/event_accumulator_test.py index 7a857e78e8..89c55cf076 100644 --- a/tensorboard/backend/event_processing/event_accumulator_test.py +++ b/tensorboard/backend/event_processing/event_accumulator_test.py @@ -40,918 +40,1020 @@ class _EventGenerator(object): - """Class that can add_events and then yield them back. - - Satisfies the EventGenerator API required for the EventAccumulator. - Satisfies the EventWriter API required to create a tf.summary.FileWriter. - - Has additional convenience methods for adding test events. - """ - - def __init__(self, testcase, zero_out_timestamps=False): - self._testcase = testcase - self.items = [] - self.zero_out_timestamps = zero_out_timestamps - - def Load(self): - while self.items: - yield self.items.pop(0) - - def AddScalar(self, tag, wall_time=0, step=0, value=0): - event = event_pb2.Event( - wall_time=wall_time, - step=step, - summary=summary_pb2.Summary( - value=[summary_pb2.Summary.Value(tag=tag, simple_value=value)])) - self.AddEvent(event) - - def AddHistogram(self, - tag, - wall_time=0, - step=0, - hmin=1, - hmax=2, - hnum=3, - hsum=4, - hsum_squares=5, - hbucket_limit=None, - hbucket=None): - histo = summary_pb2.HistogramProto( - min=hmin, - max=hmax, - num=hnum, - sum=hsum, - sum_squares=hsum_squares, - bucket_limit=hbucket_limit, - bucket=hbucket) - event = event_pb2.Event( - wall_time=wall_time, - step=step, - summary=summary_pb2.Summary( - value=[summary_pb2.Summary.Value(tag=tag, histo=histo)])) - self.AddEvent(event) - - def AddImage(self, - tag, - wall_time=0, - step=0, - encoded_image_string=b'imgstr', - width=150, - height=100): - image = summary_pb2.Summary.Image( - encoded_image_string=encoded_image_string, width=width, height=height) - event = event_pb2.Event( - wall_time=wall_time, - step=step, - summary=summary_pb2.Summary( - value=[summary_pb2.Summary.Value(tag=tag, 
image=image)])) - self.AddEvent(event) - - def AddAudio(self, - tag, - wall_time=0, - step=0, - encoded_audio_string=b'sndstr', - content_type='audio/wav', - sample_rate=44100, - length_frames=22050): - audio = summary_pb2.Summary.Audio( - encoded_audio_string=encoded_audio_string, - content_type=content_type, - sample_rate=sample_rate, - length_frames=length_frames) - event = event_pb2.Event( - wall_time=wall_time, - step=step, - summary=summary_pb2.Summary( - value=[summary_pb2.Summary.Value(tag=tag, audio=audio)])) - self.AddEvent(event) - - def AddEvent(self, event): - if self.zero_out_timestamps: - event.wall_time = 0 - self.items.append(event) - - def add_event(self, event): # pylint: disable=invalid-name - """Match the EventWriter API.""" - self.AddEvent(event) - - def get_logdir(self): # pylint: disable=invalid-name - """Return a temp directory for asset writing.""" - return self._testcase.get_temp_dir() - - def close(self): - """Closes the event writer""" - # noop + """Class that can add_events and then yield them back. + Satisfies the EventGenerator API required for the EventAccumulator. + Satisfies the EventWriter API required to create a tf.summary.FileWriter. -class EventAccumulatorTest(tf.test.TestCase): - - def assertTagsEqual(self, actual, expected): - """Utility method for checking the return value of the Tags() call. - - It fills out the `expected` arg with the default (empty) values for every - tag type, so that the author needs only specify the non-empty values they - are interested in testing. - - Args: - actual: The actual Accumulator tags response. - expected: The expected tags response (empty fields may be omitted) + Has additional convenience methods for adding test events. 
""" - empty_tags = { - ea.IMAGES: [], - ea.AUDIO: [], - ea.SCALARS: [], - ea.HISTOGRAMS: [], - ea.COMPRESSED_HISTOGRAMS: [], - ea.GRAPH: False, - ea.META_GRAPH: False, - ea.RUN_METADATA: [], - ea.TENSORS: [], - } - - # Verifies that there are no unexpected keys in the actual response. - # If this line fails, likely you added a new tag type, and need to update - # the empty_tags dictionary above. - self.assertItemsEqual(actual.keys(), empty_tags.keys()) - - for key in actual: - expected_value = expected.get(key, empty_tags[key]) - if isinstance(expected_value, list): - self.assertItemsEqual(actual[key], expected_value) - else: - self.assertEqual(actual[key], expected_value) - - -class MockingEventAccumulatorTest(EventAccumulatorTest): - - def setUp(self): - super(MockingEventAccumulatorTest, self).setUp() - self.stubs = tf.compat.v1.test.StubOutForTesting() - self._real_constructor = ea.EventAccumulator - self._real_generator = ea._GeneratorFromPath - - def _FakeAccumulatorConstructor(generator, *args, **kwargs): - ea._GeneratorFromPath = lambda x: generator - return self._real_constructor(generator, *args, **kwargs) - - ea.EventAccumulator = _FakeAccumulatorConstructor - - def tearDown(self): - self.stubs.CleanUp() - ea.EventAccumulator = self._real_constructor - ea._GeneratorFromPath = self._real_generator - - def testEmptyAccumulator(self): - gen = _EventGenerator(self) - x = ea.EventAccumulator(gen) - x.Reload() - self.assertTagsEqual(x.Tags(), {}) - - def testTags(self): - """Tags should be found in EventAccumulator after adding some events.""" - gen = _EventGenerator(self) - gen.AddScalar('s1') - gen.AddScalar('s2') - gen.AddHistogram('hst1') - gen.AddHistogram('hst2') - gen.AddImage('im1') - gen.AddImage('im2') - gen.AddAudio('snd1') - gen.AddAudio('snd2') - acc = ea.EventAccumulator(gen) - acc.Reload() - self.assertTagsEqual(acc.Tags(), { - ea.IMAGES: ['im1', 'im2'], - ea.AUDIO: ['snd1', 'snd2'], - ea.SCALARS: ['s1', 's2'], - ea.HISTOGRAMS: ['hst1', 'hst2'], 
- ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'], - }) - - def testReload(self): - """EventAccumulator contains suitable tags after calling Reload.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - acc.Reload() - self.assertTagsEqual(acc.Tags(), {}) - gen.AddScalar('s1') - gen.AddScalar('s2') - gen.AddHistogram('hst1') - gen.AddHistogram('hst2') - gen.AddImage('im1') - gen.AddImage('im2') - gen.AddAudio('snd1') - gen.AddAudio('snd2') - acc.Reload() - self.assertTagsEqual(acc.Tags(), { - ea.IMAGES: ['im1', 'im2'], - ea.AUDIO: ['snd1', 'snd2'], - ea.SCALARS: ['s1', 's2'], - ea.HISTOGRAMS: ['hst1', 'hst2'], - ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'], - }) - - def testScalars(self): - """Tests whether EventAccumulator contains scalars after adding them.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - s1 = ea.ScalarEvent(wall_time=1, step=10, value=32) - s2 = ea.ScalarEvent(wall_time=2, step=12, value=64) - gen.AddScalar('s1', wall_time=1, step=10, value=32) - gen.AddScalar('s2', wall_time=2, step=12, value=64) - acc.Reload() - self.assertEqual(acc.Scalars('s1'), [s1]) - self.assertEqual(acc.Scalars('s2'), [s2]) - - def testHistograms(self): - """Tests whether histograms are inserted into EventAccumulator.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - - val1 = ea.HistogramValue( - min=1, - max=2, - num=3, - sum=4, - sum_squares=5, - bucket_limit=[1, 2, 3], - bucket=[0, 3, 0]) - val2 = ea.HistogramValue( - min=-2, - max=3, - num=4, - sum=5, - sum_squares=6, - bucket_limit=[2, 3, 4], - bucket=[1, 3, 0]) - - hst1 = ea.HistogramEvent(wall_time=1, step=10, histogram_value=val1) - hst2 = ea.HistogramEvent(wall_time=2, step=12, histogram_value=val2) - gen.AddHistogram( - 'hst1', - wall_time=1, - step=10, + def __init__(self, testcase, zero_out_timestamps=False): + self._testcase = testcase + self.items = [] + self.zero_out_timestamps = zero_out_timestamps + + def Load(self): + while self.items: + yield self.items.pop(0) 
+ + def AddScalar(self, tag, wall_time=0, step=0, value=0): + event = event_pb2.Event( + wall_time=wall_time, + step=step, + summary=summary_pb2.Summary( + value=[summary_pb2.Summary.Value(tag=tag, simple_value=value)] + ), + ) + self.AddEvent(event) + + def AddHistogram( + self, + tag, + wall_time=0, + step=0, hmin=1, hmax=2, hnum=3, hsum=4, hsum_squares=5, - hbucket_limit=[1, 2, 3], - hbucket=[0, 3, 0]) - gen.AddHistogram( - 'hst2', - wall_time=2, - step=12, - hmin=-2, - hmax=3, - hnum=4, - hsum=5, - hsum_squares=6, - hbucket_limit=[2, 3, 4], - hbucket=[1, 3, 0]) - acc.Reload() - self.assertEqual(acc.Histograms('hst1'), [hst1]) - self.assertEqual(acc.Histograms('hst2'), [hst2]) - - def testCompressedHistograms(self): - """Tests compressed histograms inserted into EventAccumulator.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen, compression_bps=(0, 2500, 5000, 7500, 10000)) - - gen.AddHistogram( - 'hst1', - wall_time=1, - step=10, - hmin=1, - hmax=2, - hnum=3, - hsum=4, - hsum_squares=5, - hbucket_limit=[1, 2, 3], - hbucket=[0, 3, 0]) - gen.AddHistogram( - 'hst2', - wall_time=2, - step=12, - hmin=-2, - hmax=3, - hnum=4, - hsum=5, - hsum_squares=6, - hbucket_limit=[2, 3, 4], - hbucket=[1, 3, 0]) - acc.Reload() - - # Create the expected values after compressing hst1 - expected_vals1 = [ - compressor.CompressedHistogramValue(bp, val) - for bp, val in [(0, 1.0), (2500, 1.25), (5000, 1.5), (7500, 1.75 - ), (10000, 2.0)] - ] - expected_cmphst1 = ea.CompressedHistogramEvent( - wall_time=1, step=10, compressed_histogram_values=expected_vals1) - self.assertEqual(acc.CompressedHistograms('hst1'), [expected_cmphst1]) - - # Create the expected values after compressing hst2 - expected_vals2 = [ - compressor.CompressedHistogramValue(bp, val) - for bp, val in [(0, -2), - (2500, 2), - (5000, 2 + 1 / 3), - (7500, 2 + 2 / 3), - (10000, 3)] - ] - expected_cmphst2 = ea.CompressedHistogramEvent( - wall_time=2, step=12, compressed_histogram_values=expected_vals2) - 
self.assertEqual(acc.CompressedHistograms('hst2'), [expected_cmphst2]) - - def testImages(self): - """Tests 2 images inserted/accessed in EventAccumulator.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - im1 = ea.ImageEvent( - wall_time=1, - step=10, - encoded_image_string=b'big', - width=400, - height=300) - im2 = ea.ImageEvent( - wall_time=2, - step=12, - encoded_image_string=b'small', - width=40, - height=30) - gen.AddImage( - 'im1', - wall_time=1, - step=10, - encoded_image_string=b'big', - width=400, - height=300) - gen.AddImage( - 'im2', - wall_time=2, - step=12, - encoded_image_string=b'small', - width=40, - height=30) - acc.Reload() - self.assertEqual(acc.Images('im1'), [im1]) - self.assertEqual(acc.Images('im2'), [im2]) - - def testAudio(self): - """Tests 2 audio events inserted/accessed in EventAccumulator.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - snd1 = ea.AudioEvent( - wall_time=1, - step=10, - encoded_audio_string=b'big', - content_type='audio/wav', - sample_rate=44100, - length_frames=441000) - snd2 = ea.AudioEvent( - wall_time=2, - step=12, - encoded_audio_string=b'small', - content_type='audio/wav', + hbucket_limit=None, + hbucket=None, + ): + histo = summary_pb2.HistogramProto( + min=hmin, + max=hmax, + num=hnum, + sum=hsum, + sum_squares=hsum_squares, + bucket_limit=hbucket_limit, + bucket=hbucket, + ) + event = event_pb2.Event( + wall_time=wall_time, + step=step, + summary=summary_pb2.Summary( + value=[summary_pb2.Summary.Value(tag=tag, histo=histo)] + ), + ) + self.AddEvent(event) + + def AddImage( + self, + tag, + wall_time=0, + step=0, + encoded_image_string=b"imgstr", + width=150, + height=100, + ): + image = summary_pb2.Summary.Image( + encoded_image_string=encoded_image_string, + width=width, + height=height, + ) + event = event_pb2.Event( + wall_time=wall_time, + step=step, + summary=summary_pb2.Summary( + value=[summary_pb2.Summary.Value(tag=tag, image=image)] + ), + ) + self.AddEvent(event) 
+ + def AddAudio( + self, + tag, + wall_time=0, + step=0, + encoded_audio_string=b"sndstr", + content_type="audio/wav", sample_rate=44100, - length_frames=44100) - gen.AddAudio( - 'snd1', - wall_time=1, - step=10, - encoded_audio_string=b'big', - content_type='audio/wav', - sample_rate=44100, - length_frames=441000) - gen.AddAudio( - 'snd2', - wall_time=2, - step=12, - encoded_audio_string=b'small', - content_type='audio/wav', - sample_rate=44100, - length_frames=44100) - acc.Reload() - self.assertEqual(acc.Audio('snd1'), [snd1]) - self.assertEqual(acc.Audio('snd2'), [snd2]) - - def testKeyError(self): - """KeyError should be raised when accessing non-existing keys.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - acc.Reload() - with self.assertRaises(KeyError): - acc.Scalars('s1') - with self.assertRaises(KeyError): - acc.Scalars('hst1') - with self.assertRaises(KeyError): - acc.Scalars('im1') - with self.assertRaises(KeyError): - acc.Histograms('s1') - with self.assertRaises(KeyError): - acc.Histograms('im1') - with self.assertRaises(KeyError): - acc.Images('s1') - with self.assertRaises(KeyError): - acc.Images('hst1') - with self.assertRaises(KeyError): - acc.Audio('s1') - with self.assertRaises(KeyError): - acc.Audio('hst1') - - def testNonValueEvents(self): - """Non-value events in the generator don't cause early exits.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddScalar('s1', wall_time=1, step=10, value=20) - gen.AddEvent( - event_pb2.Event(wall_time=2, step=20, file_version='nots2')) - gen.AddScalar('s3', wall_time=3, step=100, value=1) - gen.AddHistogram('hst1') - gen.AddImage('im1') - gen.AddAudio('snd1') - - acc.Reload() - self.assertTagsEqual(acc.Tags(), { - ea.IMAGES: ['im1'], - ea.AUDIO: ['snd1'], - ea.SCALARS: ['s1', 's3'], - ea.HISTOGRAMS: ['hst1'], - ea.COMPRESSED_HISTOGRAMS: ['hst1'], - }) - - def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self): - """Tests that events are discarded 
after a restart is detected. - - If a step value is observed to be lower than what was previously seen, - this should force a discard of all previous items with the same tag - that are outdated. - - Only file versions < 2 use this out-of-order discard logic. Later versions - discard events based on the step value of SessionLog.START. - """ - warnings = [] - self.stubs.Set(logger, 'warn', warnings.append) - - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - - gen.AddEvent( - event_pb2.Event(wall_time=0, step=0, file_version='brain.Event:1')) - gen.AddScalar('s1', wall_time=1, step=100, value=20) - gen.AddScalar('s1', wall_time=1, step=200, value=20) - gen.AddScalar('s1', wall_time=1, step=300, value=20) - acc.Reload() - ## Check that number of items are what they should be - self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200, 300]) - - gen.AddScalar('s1', wall_time=1, step=101, value=20) - gen.AddScalar('s1', wall_time=1, step=201, value=20) - gen.AddScalar('s1', wall_time=1, step=301, value=20) - acc.Reload() - ## Check that we have discarded 200 and 300 from s1 - self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301]) - - def testOrphanedDataNotDiscardedIfFlagUnset(self): - """Tests that events are not discarded if purge_orphaned_data is false. 
- """ - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen, purge_orphaned_data=False) - - gen.AddEvent( - event_pb2.Event(wall_time=0, step=0, file_version='brain.Event:1')) - gen.AddScalar('s1', wall_time=1, step=100, value=20) - gen.AddScalar('s1', wall_time=1, step=200, value=20) - gen.AddScalar('s1', wall_time=1, step=300, value=20) - acc.Reload() - ## Check that number of items are what they should be - self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200, 300]) - - gen.AddScalar('s1', wall_time=1, step=101, value=20) - gen.AddScalar('s1', wall_time=1, step=201, value=20) - gen.AddScalar('s1', wall_time=1, step=301, value=20) - acc.Reload() - ## Check that we have NOT discarded 200 and 300 from s1 - self.assertEqual([x.step for x in acc.Scalars('s1')], - [100, 200, 300, 101, 201, 301]) - - def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self): - """Tests that event discards after restart, only affect the misordered tag. - - If a step value is observed to be lower than what was previously seen, - this should force a discard of all previous items that are outdated, but - only for the out of order tag. Other tags should remain unaffected. - - Only file versions < 2 use this out-of-order discard logic. Later versions - discard events based on the step value of SessionLog.START. 
- """ - warnings = [] - self.stubs.Set(logger, 'warn', warnings.append) - - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - - gen.AddEvent( - event_pb2.Event(wall_time=0, step=0, file_version='brain.Event:1')) - gen.AddScalar('s1', wall_time=1, step=100, value=20) - gen.AddScalar('s1', wall_time=1, step=200, value=20) - gen.AddScalar('s1', wall_time=1, step=300, value=20) - gen.AddScalar('s1', wall_time=1, step=101, value=20) - gen.AddScalar('s1', wall_time=1, step=201, value=20) - gen.AddScalar('s1', wall_time=1, step=301, value=20) - - gen.AddScalar('s2', wall_time=1, step=101, value=20) - gen.AddScalar('s2', wall_time=1, step=201, value=20) - gen.AddScalar('s2', wall_time=1, step=301, value=20) - - acc.Reload() - ## Check that we have discarded 200 and 300 - self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301]) - - ## Check that s1 discards do not affect s2 - ## i.e. check that only events from the out of order tag are discarded - self.assertEqual([x.step for x in acc.Scalars('s2')], [101, 201, 301]) - - def testOnlySummaryEventsTriggerDiscards(self): - """Test that file version event does not trigger data purge.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddScalar('s1', wall_time=1, step=100, value=20) - ev1 = event_pb2.Event(wall_time=2, step=0, file_version='brain.Event:1') - graph_bytes = tf.compat.v1.GraphDef().SerializeToString() - ev2 = event_pb2.Event(wall_time=3, step=0, graph_def=graph_bytes) - gen.AddEvent(ev1) - gen.AddEvent(ev2) - acc.Reload() - self.assertEqual([x.step for x in acc.Scalars('s1')], [100]) - - def testSessionLogStartMessageDiscardsExpiredEvents(self): - """Test that SessionLog.START message discards expired events. - - This discard logic is preferred over the out-of-order step discard logic, - but this logic can only be used for event protos which have the SessionLog - enum, which was introduced to event.proto for file_version >= brain.Event:2. 
- """ - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddEvent( - event_pb2.Event(wall_time=0, step=1, file_version='brain.Event:2')) - - gen.AddScalar('s1', wall_time=1, step=100, value=20) - gen.AddScalar('s1', wall_time=1, step=200, value=20) - gen.AddScalar('s1', wall_time=1, step=300, value=20) - gen.AddScalar('s1', wall_time=1, step=400, value=20) - - gen.AddScalar('s2', wall_time=1, step=202, value=20) - gen.AddScalar('s2', wall_time=1, step=203, value=20) - - slog = event_pb2.SessionLog(status=event_pb2.SessionLog.START) - gen.AddEvent( - event_pb2.Event(wall_time=2, step=201, session_log=slog)) - acc.Reload() - self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200]) - self.assertEqual([x.step for x in acc.Scalars('s2')], []) - - def testFirstEventTimestamp(self): - """Test that FirstEventTimestamp() returns wall_time of the first event.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddEvent( - event_pb2.Event(wall_time=10, step=20, file_version='brain.Event:2')) - gen.AddScalar('s1', wall_time=30, step=40, value=20) - self.assertEqual(acc.FirstEventTimestamp(), 10) - - def testReloadPopulatesFirstEventTimestamp(self): - """Test that Reload() means FirstEventTimestamp() won't load events.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddEvent( - event_pb2.Event(wall_time=1, step=2, file_version='brain.Event:2')) - - acc.Reload() - - def _Die(*args, **kwargs): # pylint: disable=unused-argument - raise RuntimeError('Load() should not be called') - - self.stubs.Set(gen, 'Load', _Die) - self.assertEqual(acc.FirstEventTimestamp(), 1) - - def testFirstEventTimestampLoadsEvent(self): - """Test that FirstEventTimestamp() doesn't discard the loaded event.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddEvent( - event_pb2.Event(wall_time=1, step=2, file_version='brain.Event:2')) - - self.assertEqual(acc.FirstEventTimestamp(), 1) - acc.Reload() - 
self.assertEqual(acc.file_version, 2.0) - - def testTFSummaryScalar(self): - """Verify processing of tf.summary.scalar.""" - event_sink = _EventGenerator(self, zero_out_timestamps=True) - with test_util.FileWriterCache.get(self.get_temp_dir()) as writer: - writer.event_writer = event_sink - with self.test_session() as sess: - ipt = tf.compat.v1.placeholder(tf.float32) - tf.compat.v1.summary.scalar('scalar1', ipt) - tf.compat.v1.summary.scalar('scalar2', ipt * ipt) - merged = tf.compat.v1.summary.merge_all() - writer.add_graph(sess.graph) - for i in xrange(10): - summ = sess.run(merged, feed_dict={ipt: i}) - writer.add_summary(summ, global_step=i) - - accumulator = ea.EventAccumulator(event_sink) - accumulator.Reload() - - seq1 = [ea.ScalarEvent(wall_time=0, step=i, value=i) for i in xrange(10)] - seq2 = [ - ea.ScalarEvent( - wall_time=0, step=i, value=i * i) for i in xrange(10) - ] - - self.assertTagsEqual(accumulator.Tags(), { - ea.SCALARS: ['scalar1', 'scalar2'], - ea.GRAPH: True, - ea.META_GRAPH: False, - }) - - self.assertEqual(accumulator.Scalars('scalar1'), seq1) - self.assertEqual(accumulator.Scalars('scalar2'), seq2) - first_value = accumulator.Scalars('scalar1')[0].value - self.assertTrue(isinstance(first_value, float)) - - def testTFSummaryImage(self): - """Verify processing of tf.summary.image.""" - event_sink = _EventGenerator(self, zero_out_timestamps=True) - with test_util.FileWriterCache.get(self.get_temp_dir()) as writer: - writer.event_writer = event_sink - with self.test_session() as sess: - ipt = tf.ones([10, 4, 4, 3], tf.uint8) - # This is an interesting example, because the old tf.image_summary op - # would throw an error here, because it would be tag reuse. - # Using the tf node name instead allows argument re-use to the image - # summary. 
- with tf.name_scope('1'): - tf.compat.v1.summary.image('images', ipt, max_outputs=1) - with tf.name_scope('2'): - tf.compat.v1.summary.image('images', ipt, max_outputs=2) - with tf.name_scope('3'): - tf.compat.v1.summary.image('images', ipt, max_outputs=3) - merged = tf.compat.v1.summary.merge_all() - writer.add_graph(sess.graph) - for i in xrange(10): - summ = sess.run(merged) - writer.add_summary(summ, global_step=i) - - accumulator = ea.EventAccumulator(event_sink) - accumulator.Reload() - - tags = [ - u'1/images/image', u'2/images/image/0', u'2/images/image/1', - u'3/images/image/0', u'3/images/image/1', u'3/images/image/2' - ] - - self.assertTagsEqual(accumulator.Tags(), { - ea.IMAGES: tags, - ea.GRAPH: True, - ea.META_GRAPH: False, - }) - - def testTFSummaryTensor(self): - """Verify processing of tf.summary.tensor.""" - event_sink = _EventGenerator(self, zero_out_timestamps=True) - with test_util.FileWriterCache.get(self.get_temp_dir()) as writer: - writer.event_writer = event_sink - with self.test_session() as sess: - tf.compat.v1.summary.tensor_summary('scalar', tf.constant(1.0)) - tf.compat.v1.summary.tensor_summary('vector', tf.constant([1.0, 2.0, 3.0])) - tf.compat.v1.summary.tensor_summary('string', tf.constant(six.b('foobar'))) - merged = tf.compat.v1.summary.merge_all() - summ = sess.run(merged) - writer.add_summary(summ, 0) - - accumulator = ea.EventAccumulator(event_sink) - accumulator.Reload() - - self.assertTagsEqual(accumulator.Tags(), { - ea.TENSORS: ['scalar', 'vector', 'string'], - }) - - scalar_proto = accumulator.Tensors('scalar')[0].tensor_proto - scalar = tf.compat.v1.make_ndarray(scalar_proto) - vector_proto = accumulator.Tensors('vector')[0].tensor_proto - vector = tf.compat.v1.make_ndarray(vector_proto) - string_proto = accumulator.Tensors('string')[0].tensor_proto - string = tf.compat.v1.make_ndarray(string_proto) - - self.assertTrue(np.array_equal(scalar, 1.0)) - self.assertTrue(np.array_equal(vector, [1.0, 2.0, 3.0])) - 
self.assertTrue(np.array_equal(string, six.b('foobar'))) + length_frames=22050, + ): + audio = summary_pb2.Summary.Audio( + encoded_audio_string=encoded_audio_string, + content_type=content_type, + sample_rate=sample_rate, + length_frames=length_frames, + ) + event = event_pb2.Event( + wall_time=wall_time, + step=step, + summary=summary_pb2.Summary( + value=[summary_pb2.Summary.Value(tag=tag, audio=audio)] + ), + ) + self.AddEvent(event) + + def AddEvent(self, event): + if self.zero_out_timestamps: + event.wall_time = 0 + self.items.append(event) + + def add_event(self, event): # pylint: disable=invalid-name + """Match the EventWriter API.""" + self.AddEvent(event) + + def get_logdir(self): # pylint: disable=invalid-name + """Return a temp directory for asset writing.""" + return self._testcase.get_temp_dir() + + def close(self): + """Closes the event writer.""" + # noop -class RealisticEventAccumulatorTest(EventAccumulatorTest): +class EventAccumulatorTest(tf.test.TestCase): + def assertTagsEqual(self, actual, expected): + """Utility method for checking the return value of the Tags() call. + + It fills out the `expected` arg with the default (empty) values for every + tag type, so that the author needs only specify the non-empty values they + are interested in testing. + + Args: + actual: The actual Accumulator tags response. + expected: The expected tags response (empty fields may be omitted) + """ + + empty_tags = { + ea.IMAGES: [], + ea.AUDIO: [], + ea.SCALARS: [], + ea.HISTOGRAMS: [], + ea.COMPRESSED_HISTOGRAMS: [], + ea.GRAPH: False, + ea.META_GRAPH: False, + ea.RUN_METADATA: [], + ea.TENSORS: [], + } + + # Verifies that there are no unexpected keys in the actual response. + # If this line fails, likely you added a new tag type, and need to update + # the empty_tags dictionary above. 
+ self.assertItemsEqual(actual.keys(), empty_tags.keys()) + + for key in actual: + expected_value = expected.get(key, empty_tags[key]) + if isinstance(expected_value, list): + self.assertItemsEqual(actual[key], expected_value) + else: + self.assertEqual(actual[key], expected_value) + + +class MockingEventAccumulatorTest(EventAccumulatorTest): + def setUp(self): + super(MockingEventAccumulatorTest, self).setUp() + self.stubs = tf.compat.v1.test.StubOutForTesting() + self._real_constructor = ea.EventAccumulator + self._real_generator = ea._GeneratorFromPath + + def _FakeAccumulatorConstructor(generator, *args, **kwargs): + ea._GeneratorFromPath = lambda x: generator + return self._real_constructor(generator, *args, **kwargs) + + ea.EventAccumulator = _FakeAccumulatorConstructor + + def tearDown(self): + self.stubs.CleanUp() + ea.EventAccumulator = self._real_constructor + ea._GeneratorFromPath = self._real_generator + + def testEmptyAccumulator(self): + gen = _EventGenerator(self) + x = ea.EventAccumulator(gen) + x.Reload() + self.assertTagsEqual(x.Tags(), {}) + + def testTags(self): + """Tags should be found in EventAccumulator after adding some + events.""" + gen = _EventGenerator(self) + gen.AddScalar("s1") + gen.AddScalar("s2") + gen.AddHistogram("hst1") + gen.AddHistogram("hst2") + gen.AddImage("im1") + gen.AddImage("im2") + gen.AddAudio("snd1") + gen.AddAudio("snd2") + acc = ea.EventAccumulator(gen) + acc.Reload() + self.assertTagsEqual( + acc.Tags(), + { + ea.IMAGES: ["im1", "im2"], + ea.AUDIO: ["snd1", "snd2"], + ea.SCALARS: ["s1", "s2"], + ea.HISTOGRAMS: ["hst1", "hst2"], + ea.COMPRESSED_HISTOGRAMS: ["hst1", "hst2"], + }, + ) + + def testReload(self): + """EventAccumulator contains suitable tags after calling Reload.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + acc.Reload() + self.assertTagsEqual(acc.Tags(), {}) + gen.AddScalar("s1") + gen.AddScalar("s2") + gen.AddHistogram("hst1") + gen.AddHistogram("hst2") + gen.AddImage("im1") + 
gen.AddImage("im2") + gen.AddAudio("snd1") + gen.AddAudio("snd2") + acc.Reload() + self.assertTagsEqual( + acc.Tags(), + { + ea.IMAGES: ["im1", "im2"], + ea.AUDIO: ["snd1", "snd2"], + ea.SCALARS: ["s1", "s2"], + ea.HISTOGRAMS: ["hst1", "hst2"], + ea.COMPRESSED_HISTOGRAMS: ["hst1", "hst2"], + }, + ) + + def testScalars(self): + """Tests whether EventAccumulator contains scalars after adding + them.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + s1 = ea.ScalarEvent(wall_time=1, step=10, value=32) + s2 = ea.ScalarEvent(wall_time=2, step=12, value=64) + gen.AddScalar("s1", wall_time=1, step=10, value=32) + gen.AddScalar("s2", wall_time=2, step=12, value=64) + acc.Reload() + self.assertEqual(acc.Scalars("s1"), [s1]) + self.assertEqual(acc.Scalars("s2"), [s2]) + + def testHistograms(self): + """Tests whether histograms are inserted into EventAccumulator.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + + val1 = ea.HistogramValue( + min=1, + max=2, + num=3, + sum=4, + sum_squares=5, + bucket_limit=[1, 2, 3], + bucket=[0, 3, 0], + ) + val2 = ea.HistogramValue( + min=-2, + max=3, + num=4, + sum=5, + sum_squares=6, + bucket_limit=[2, 3, 4], + bucket=[1, 3, 0], + ) + + hst1 = ea.HistogramEvent(wall_time=1, step=10, histogram_value=val1) + hst2 = ea.HistogramEvent(wall_time=2, step=12, histogram_value=val2) + gen.AddHistogram( + "hst1", + wall_time=1, + step=10, + hmin=1, + hmax=2, + hnum=3, + hsum=4, + hsum_squares=5, + hbucket_limit=[1, 2, 3], + hbucket=[0, 3, 0], + ) + gen.AddHistogram( + "hst2", + wall_time=2, + step=12, + hmin=-2, + hmax=3, + hnum=4, + hsum=5, + hsum_squares=6, + hbucket_limit=[2, 3, 4], + hbucket=[1, 3, 0], + ) + acc.Reload() + self.assertEqual(acc.Histograms("hst1"), [hst1]) + self.assertEqual(acc.Histograms("hst2"), [hst2]) + + def testCompressedHistograms(self): + """Tests compressed histograms inserted into EventAccumulator.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator( + gen, 
compression_bps=(0, 2500, 5000, 7500, 10000) + ) + + gen.AddHistogram( + "hst1", + wall_time=1, + step=10, + hmin=1, + hmax=2, + hnum=3, + hsum=4, + hsum_squares=5, + hbucket_limit=[1, 2, 3], + hbucket=[0, 3, 0], + ) + gen.AddHistogram( + "hst2", + wall_time=2, + step=12, + hmin=-2, + hmax=3, + hnum=4, + hsum=5, + hsum_squares=6, + hbucket_limit=[2, 3, 4], + hbucket=[1, 3, 0], + ) + acc.Reload() + + # Create the expected values after compressing hst1 + expected_vals1 = [ + compressor.CompressedHistogramValue(bp, val) + for bp, val in [ + (0, 1.0), + (2500, 1.25), + (5000, 1.5), + (7500, 1.75), + (10000, 2.0), + ] + ] + expected_cmphst1 = ea.CompressedHistogramEvent( + wall_time=1, step=10, compressed_histogram_values=expected_vals1 + ) + self.assertEqual(acc.CompressedHistograms("hst1"), [expected_cmphst1]) + + # Create the expected values after compressing hst2 + expected_vals2 = [ + compressor.CompressedHistogramValue(bp, val) + for bp, val in [ + (0, -2), + (2500, 2), + (5000, 2 + 1 / 3), + (7500, 2 + 2 / 3), + (10000, 3), + ] + ] + expected_cmphst2 = ea.CompressedHistogramEvent( + wall_time=2, step=12, compressed_histogram_values=expected_vals2 + ) + self.assertEqual(acc.CompressedHistograms("hst2"), [expected_cmphst2]) + + def testImages(self): + """Tests 2 images inserted/accessed in EventAccumulator.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + im1 = ea.ImageEvent( + wall_time=1, + step=10, + encoded_image_string=b"big", + width=400, + height=300, + ) + im2 = ea.ImageEvent( + wall_time=2, + step=12, + encoded_image_string=b"small", + width=40, + height=30, + ) + gen.AddImage( + "im1", + wall_time=1, + step=10, + encoded_image_string=b"big", + width=400, + height=300, + ) + gen.AddImage( + "im2", + wall_time=2, + step=12, + encoded_image_string=b"small", + width=40, + height=30, + ) + acc.Reload() + self.assertEqual(acc.Images("im1"), [im1]) + self.assertEqual(acc.Images("im2"), [im2]) + + def testAudio(self): + """Tests 2 audio events 
inserted/accessed in EventAccumulator.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + snd1 = ea.AudioEvent( + wall_time=1, + step=10, + encoded_audio_string=b"big", + content_type="audio/wav", + sample_rate=44100, + length_frames=441000, + ) + snd2 = ea.AudioEvent( + wall_time=2, + step=12, + encoded_audio_string=b"small", + content_type="audio/wav", + sample_rate=44100, + length_frames=44100, + ) + gen.AddAudio( + "snd1", + wall_time=1, + step=10, + encoded_audio_string=b"big", + content_type="audio/wav", + sample_rate=44100, + length_frames=441000, + ) + gen.AddAudio( + "snd2", + wall_time=2, + step=12, + encoded_audio_string=b"small", + content_type="audio/wav", + sample_rate=44100, + length_frames=44100, + ) + acc.Reload() + self.assertEqual(acc.Audio("snd1"), [snd1]) + self.assertEqual(acc.Audio("snd2"), [snd2]) + + def testKeyError(self): + """KeyError should be raised when accessing non-existing keys.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + acc.Reload() + with self.assertRaises(KeyError): + acc.Scalars("s1") + with self.assertRaises(KeyError): + acc.Scalars("hst1") + with self.assertRaises(KeyError): + acc.Scalars("im1") + with self.assertRaises(KeyError): + acc.Histograms("s1") + with self.assertRaises(KeyError): + acc.Histograms("im1") + with self.assertRaises(KeyError): + acc.Images("s1") + with self.assertRaises(KeyError): + acc.Images("hst1") + with self.assertRaises(KeyError): + acc.Audio("s1") + with self.assertRaises(KeyError): + acc.Audio("hst1") + + def testNonValueEvents(self): + """Non-value events in the generator don't cause early exits.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + gen.AddScalar("s1", wall_time=1, step=10, value=20) + gen.AddEvent( + event_pb2.Event(wall_time=2, step=20, file_version="nots2") + ) + gen.AddScalar("s3", wall_time=3, step=100, value=1) + gen.AddHistogram("hst1") + gen.AddImage("im1") + gen.AddAudio("snd1") + + acc.Reload() + 
self.assertTagsEqual( + acc.Tags(), + { + ea.IMAGES: ["im1"], + ea.AUDIO: ["snd1"], + ea.SCALARS: ["s1", "s3"], + ea.HISTOGRAMS: ["hst1"], + ea.COMPRESSED_HISTOGRAMS: ["hst1"], + }, + ) + + def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self): + """Tests that events are discarded after a restart is detected. + + If a step value is observed to be lower than what was previously seen, + this should force a discard of all previous items with the same tag + that are outdated. + + Only file versions < 2 use this out-of-order discard logic. Later versions + discard events based on the step value of SessionLog.START. + """ + warnings = [] + self.stubs.Set(logger, "warn", warnings.append) + + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + + gen.AddEvent( + event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1") + ) + gen.AddScalar("s1", wall_time=1, step=100, value=20) + gen.AddScalar("s1", wall_time=1, step=200, value=20) + gen.AddScalar("s1", wall_time=1, step=300, value=20) + acc.Reload() + ## Check that number of items are what they should be + self.assertEqual([x.step for x in acc.Scalars("s1")], [100, 200, 300]) + + gen.AddScalar("s1", wall_time=1, step=101, value=20) + gen.AddScalar("s1", wall_time=1, step=201, value=20) + gen.AddScalar("s1", wall_time=1, step=301, value=20) + acc.Reload() + ## Check that we have discarded 200 and 300 from s1 + self.assertEqual( + [x.step for x in acc.Scalars("s1")], [100, 101, 201, 301] + ) + + def testOrphanedDataNotDiscardedIfFlagUnset(self): + """Tests that events are not discarded if purge_orphaned_data is + false.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen, purge_orphaned_data=False) + + gen.AddEvent( + event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1") + ) + gen.AddScalar("s1", wall_time=1, step=100, value=20) + gen.AddScalar("s1", wall_time=1, step=200, value=20) + gen.AddScalar("s1", wall_time=1, step=300, value=20) + acc.Reload() + ## Check 
that number of items are what they should be + self.assertEqual([x.step for x in acc.Scalars("s1")], [100, 200, 300]) + + gen.AddScalar("s1", wall_time=1, step=101, value=20) + gen.AddScalar("s1", wall_time=1, step=201, value=20) + gen.AddScalar("s1", wall_time=1, step=301, value=20) + acc.Reload() + ## Check that we have NOT discarded 200 and 300 from s1 + self.assertEqual( + [x.step for x in acc.Scalars("s1")], [100, 200, 300, 101, 201, 301] + ) + + def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self): + """Tests that event discards after restart, only affect the misordered + tag. + + If a step value is observed to be lower than what was previously seen, + this should force a discard of all previous items that are outdated, but + only for the out of order tag. Other tags should remain unaffected. + + Only file versions < 2 use this out-of-order discard logic. Later versions + discard events based on the step value of SessionLog.START. + """ + warnings = [] + self.stubs.Set(logger, "warn", warnings.append) + + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + + gen.AddEvent( + event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1") + ) + gen.AddScalar("s1", wall_time=1, step=100, value=20) + gen.AddScalar("s1", wall_time=1, step=200, value=20) + gen.AddScalar("s1", wall_time=1, step=300, value=20) + gen.AddScalar("s1", wall_time=1, step=101, value=20) + gen.AddScalar("s1", wall_time=1, step=201, value=20) + gen.AddScalar("s1", wall_time=1, step=301, value=20) + + gen.AddScalar("s2", wall_time=1, step=101, value=20) + gen.AddScalar("s2", wall_time=1, step=201, value=20) + gen.AddScalar("s2", wall_time=1, step=301, value=20) + + acc.Reload() + ## Check that we have discarded 200 and 300 + self.assertEqual( + [x.step for x in acc.Scalars("s1")], [100, 101, 201, 301] + ) + + ## Check that s1 discards do not affect s2 + ## i.e. 
check that only events from the out of order tag are discarded + self.assertEqual([x.step for x in acc.Scalars("s2")], [101, 201, 301]) + + def testOnlySummaryEventsTriggerDiscards(self): + """Test that file version event does not trigger data purge.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + gen.AddScalar("s1", wall_time=1, step=100, value=20) + ev1 = event_pb2.Event(wall_time=2, step=0, file_version="brain.Event:1") + graph_bytes = tf.compat.v1.GraphDef().SerializeToString() + ev2 = event_pb2.Event(wall_time=3, step=0, graph_def=graph_bytes) + gen.AddEvent(ev1) + gen.AddEvent(ev2) + acc.Reload() + self.assertEqual([x.step for x in acc.Scalars("s1")], [100]) + + def testSessionLogStartMessageDiscardsExpiredEvents(self): + """Test that SessionLog.START message discards expired events. + + This discard logic is preferred over the out-of-order step + discard logic, but this logic can only be used for event protos + which have the SessionLog enum, which was introduced to + event.proto for file_version >= brain.Event:2. 
+ """ + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + gen.AddEvent( + event_pb2.Event(wall_time=0, step=1, file_version="brain.Event:2") + ) + + gen.AddScalar("s1", wall_time=1, step=100, value=20) + gen.AddScalar("s1", wall_time=1, step=200, value=20) + gen.AddScalar("s1", wall_time=1, step=300, value=20) + gen.AddScalar("s1", wall_time=1, step=400, value=20) + + gen.AddScalar("s2", wall_time=1, step=202, value=20) + gen.AddScalar("s2", wall_time=1, step=203, value=20) + + slog = event_pb2.SessionLog(status=event_pb2.SessionLog.START) + gen.AddEvent(event_pb2.Event(wall_time=2, step=201, session_log=slog)) + acc.Reload() + self.assertEqual([x.step for x in acc.Scalars("s1")], [100, 200]) + self.assertEqual([x.step for x in acc.Scalars("s2")], []) + + def testFirstEventTimestamp(self): + """Test that FirstEventTimestamp() returns wall_time of the first + event.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + gen.AddEvent( + event_pb2.Event(wall_time=10, step=20, file_version="brain.Event:2") + ) + gen.AddScalar("s1", wall_time=30, step=40, value=20) + self.assertEqual(acc.FirstEventTimestamp(), 10) + + def testReloadPopulatesFirstEventTimestamp(self): + """Test that Reload() means FirstEventTimestamp() won't load events.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + gen.AddEvent( + event_pb2.Event(wall_time=1, step=2, file_version="brain.Event:2") + ) + + acc.Reload() + + def _Die(*args, **kwargs): # pylint: disable=unused-argument + raise RuntimeError("Load() should not be called") + + self.stubs.Set(gen, "Load", _Die) + self.assertEqual(acc.FirstEventTimestamp(), 1) + + def testFirstEventTimestampLoadsEvent(self): + """Test that FirstEventTimestamp() doesn't discard the loaded event.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + gen.AddEvent( + event_pb2.Event(wall_time=1, step=2, file_version="brain.Event:2") + ) + + self.assertEqual(acc.FirstEventTimestamp(), 1) + acc.Reload() + 
self.assertEqual(acc.file_version, 2.0) + + def testTFSummaryScalar(self): + """Verify processing of tf.summary.scalar.""" + event_sink = _EventGenerator(self, zero_out_timestamps=True) + with test_util.FileWriterCache.get(self.get_temp_dir()) as writer: + writer.event_writer = event_sink + with self.test_session() as sess: + ipt = tf.compat.v1.placeholder(tf.float32) + tf.compat.v1.summary.scalar("scalar1", ipt) + tf.compat.v1.summary.scalar("scalar2", ipt * ipt) + merged = tf.compat.v1.summary.merge_all() + writer.add_graph(sess.graph) + for i in xrange(10): + summ = sess.run(merged, feed_dict={ipt: i}) + writer.add_summary(summ, global_step=i) + + accumulator = ea.EventAccumulator(event_sink) + accumulator.Reload() + + seq1 = [ + ea.ScalarEvent(wall_time=0, step=i, value=i) for i in xrange(10) + ] + seq2 = [ + ea.ScalarEvent(wall_time=0, step=i, value=i * i) for i in xrange(10) + ] + + self.assertTagsEqual( + accumulator.Tags(), + { + ea.SCALARS: ["scalar1", "scalar2"], + ea.GRAPH: True, + ea.META_GRAPH: False, + }, + ) + + self.assertEqual(accumulator.Scalars("scalar1"), seq1) + self.assertEqual(accumulator.Scalars("scalar2"), seq2) + first_value = accumulator.Scalars("scalar1")[0].value + self.assertTrue(isinstance(first_value, float)) + + def testTFSummaryImage(self): + """Verify processing of tf.summary.image.""" + event_sink = _EventGenerator(self, zero_out_timestamps=True) + with test_util.FileWriterCache.get(self.get_temp_dir()) as writer: + writer.event_writer = event_sink + with self.test_session() as sess: + ipt = tf.ones([10, 4, 4, 3], tf.uint8) + # This is an interesting example, because the old tf.image_summary op + # would throw an error here, because it would be tag reuse. + # Using the tf node name instead allows argument re-use to the image + # summary. 
+ with tf.name_scope("1"): + tf.compat.v1.summary.image("images", ipt, max_outputs=1) + with tf.name_scope("2"): + tf.compat.v1.summary.image("images", ipt, max_outputs=2) + with tf.name_scope("3"): + tf.compat.v1.summary.image("images", ipt, max_outputs=3) + merged = tf.compat.v1.summary.merge_all() + writer.add_graph(sess.graph) + for i in xrange(10): + summ = sess.run(merged) + writer.add_summary(summ, global_step=i) + + accumulator = ea.EventAccumulator(event_sink) + accumulator.Reload() + + tags = [ + u"1/images/image", + u"2/images/image/0", + u"2/images/image/1", + u"3/images/image/0", + u"3/images/image/1", + u"3/images/image/2", + ] + + self.assertTagsEqual( + accumulator.Tags(), + {ea.IMAGES: tags, ea.GRAPH: True, ea.META_GRAPH: False,}, + ) + + def testTFSummaryTensor(self): + """Verify processing of tf.summary.tensor.""" + event_sink = _EventGenerator(self, zero_out_timestamps=True) + with test_util.FileWriterCache.get(self.get_temp_dir()) as writer: + writer.event_writer = event_sink + with self.test_session() as sess: + tf.compat.v1.summary.tensor_summary("scalar", tf.constant(1.0)) + tf.compat.v1.summary.tensor_summary( + "vector", tf.constant([1.0, 2.0, 3.0]) + ) + tf.compat.v1.summary.tensor_summary( + "string", tf.constant(six.b("foobar")) + ) + merged = tf.compat.v1.summary.merge_all() + summ = sess.run(merged) + writer.add_summary(summ, 0) + + accumulator = ea.EventAccumulator(event_sink) + accumulator.Reload() + + self.assertTagsEqual( + accumulator.Tags(), {ea.TENSORS: ["scalar", "vector", "string"],} + ) + + scalar_proto = accumulator.Tensors("scalar")[0].tensor_proto + scalar = tf.compat.v1.make_ndarray(scalar_proto) + vector_proto = accumulator.Tensors("vector")[0].tensor_proto + vector = tf.compat.v1.make_ndarray(vector_proto) + string_proto = accumulator.Tensors("string")[0].tensor_proto + string = tf.compat.v1.make_ndarray(string_proto) + + self.assertTrue(np.array_equal(scalar, 1.0)) + self.assertTrue(np.array_equal(vector, [1.0, 2.0, 
3.0])) + self.assertTrue(np.array_equal(string, six.b("foobar"))) - def testScalarsRealistically(self): - """Test accumulator by writing values and then reading them.""" - - def FakeScalarSummary(tag, value): - value = summary_pb2.Summary.Value(tag=tag, simple_value=value) - summary = summary_pb2.Summary(value=[value]) - return summary - - directory = os.path.join(self.get_temp_dir(), 'values_dir') - if tf.io.gfile.isdir(directory): - tf.io.gfile.rmtree(directory) - tf.io.gfile.mkdir(directory) - - writer = test_util.FileWriter(directory, max_queue=100) - - with tf.Graph().as_default() as graph: - _ = tf.constant([2.0, 1.0]) - # Add a graph to the summary writer. - writer.add_graph(graph) - meta_graph_def = tf.compat.v1.train.export_meta_graph(graph_def=graph.as_graph_def( - add_shapes=True)) - writer.add_meta_graph(meta_graph_def) - - run_metadata = config_pb2.RunMetadata() - device_stats = run_metadata.step_stats.dev_stats.add() - device_stats.device = 'test device' - writer.add_run_metadata(run_metadata, 'test run') - - # Write a bunch of events using the writer. 
- for i in xrange(30): - summ_id = FakeScalarSummary('id', i) - summ_sq = FakeScalarSummary('sq', i * i) - writer.add_summary(summ_id, i * 5) - writer.add_summary(summ_sq, i * 5) - writer.flush() - - # Verify that we can load those events properly - acc = ea.EventAccumulator(directory) - acc.Reload() - self.assertTagsEqual(acc.Tags(), { - ea.SCALARS: ['id', 'sq'], - ea.GRAPH: True, - ea.META_GRAPH: True, - ea.RUN_METADATA: ['test run'], - }) - id_events = acc.Scalars('id') - sq_events = acc.Scalars('sq') - self.assertEqual(30, len(id_events)) - self.assertEqual(30, len(sq_events)) - for i in xrange(30): - self.assertEqual(i * 5, id_events[i].step) - self.assertEqual(i * 5, sq_events[i].step) - self.assertEqual(i, id_events[i].value) - self.assertEqual(i * i, sq_events[i].value) - - # Write a few more events to test incremental reloading - for i in xrange(30, 40): - summ_id = FakeScalarSummary('id', i) - summ_sq = FakeScalarSummary('sq', i * i) - writer.add_summary(summ_id, i * 5) - writer.add_summary(summ_sq, i * 5) - writer.flush() - - # Verify we can now see all of the data - acc.Reload() - id_events = acc.Scalars('id') - sq_events = acc.Scalars('sq') - self.assertEqual(40, len(id_events)) - self.assertEqual(40, len(sq_events)) - for i in xrange(40): - self.assertEqual(i * 5, id_events[i].step) - self.assertEqual(i * 5, sq_events[i].step) - self.assertEqual(i, id_events[i].value) - self.assertEqual(i * i, sq_events[i].value) - - expected_graph_def = graph_pb2.GraphDef.FromString( - graph.as_graph_def(add_shapes=True).SerializeToString()) - self.assertProtoEquals(expected_graph_def, acc.Graph()) - - expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString( - meta_graph_def.SerializeToString()) - self.assertProtoEquals(expected_meta_graph, acc.MetaGraph()) - - def testGraphFromMetaGraphBecomesAvailable(self): - """Test accumulator by writing values and then reading them.""" - - directory = os.path.join(self.get_temp_dir(), 'metagraph_test_values_dir') - if 
tf.io.gfile.isdir(directory): - tf.io.gfile.rmtree(directory) - tf.io.gfile.mkdir(directory) - - writer = test_util.FileWriter(directory, max_queue=100) - - with tf.Graph().as_default() as graph: - _ = tf.constant([2.0, 1.0]) - # Add a graph to the summary writer. - meta_graph_def = tf.compat.v1.train.export_meta_graph(graph_def=graph.as_graph_def( - add_shapes=True)) - writer.add_meta_graph(meta_graph_def) - - writer.flush() - - # Verify that we can load those events properly - acc = ea.EventAccumulator(directory) - acc.Reload() - self.assertTagsEqual(acc.Tags(), { - ea.GRAPH: True, - ea.META_GRAPH: True, - }) - - expected_graph_def = graph_pb2.GraphDef.FromString( - graph.as_graph_def(add_shapes=True).SerializeToString()) - self.assertProtoEquals(expected_graph_def, acc.Graph()) - - expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString( - meta_graph_def.SerializeToString()) - self.assertProtoEquals(expected_meta_graph, acc.MetaGraph()) - - def _writeMetadata(self, logdir, summary_metadata, nonce=''): - """Write to disk a summary with the given metadata. 
- - Arguments: - logdir: a string - summary_metadata: a `SummaryMetadata` protobuf object - nonce: optional; will be added to the end of the event file name - to guarantee that multiple calls to this function do not stomp the - same file - """ - summary = summary_pb2.Summary() - summary.value.add( - tensor=tensor_util.make_tensor_proto(['po', 'ta', 'to'], dtype=tf.string), - tag='you_are_it', - metadata=summary_metadata) - writer = test_util.FileWriter(logdir, filename_suffix=nonce) - writer.add_summary(summary.SerializeToString()) - writer.close() - - def testSummaryMetadata(self): - logdir = self.get_temp_dir() - summary_metadata = summary_pb2.SummaryMetadata( - display_name='current tagee', summary_description='no') - summary_metadata.plugin_data.plugin_name = 'outlet' - self._writeMetadata(logdir, summary_metadata) - acc = ea.EventAccumulator(logdir) - acc.Reload() - self.assertProtoEquals(summary_metadata, - acc.SummaryMetadata('you_are_it')) - - def testSummaryMetadata_FirstMetadataWins(self): - logdir = self.get_temp_dir() - summary_metadata_1 = summary_pb2.SummaryMetadata( - display_name='current tagee', - summary_description='no', - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name='outlet', content=b'120v')) - self._writeMetadata(logdir, summary_metadata_1, nonce='1') - acc = ea.EventAccumulator(logdir) - acc.Reload() - summary_metadata_2 = summary_pb2.SummaryMetadata( - display_name='tagee of the future', - summary_description='definitely not', - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name='plug', content=b'110v')) - self._writeMetadata(logdir, summary_metadata_2, nonce='2') - acc.Reload() - - self.assertProtoEquals(summary_metadata_1, - acc.SummaryMetadata('you_are_it')) - - def testPluginTagToContent_PluginsCannotJumpOnTheBandwagon(self): - # If there are multiple `SummaryMetadata` for a given tag, and the - # set of plugins in the `plugin_data` of second is different from - # that of the first, then the second set 
should be ignored. - logdir = self.get_temp_dir() - summary_metadata_1 = summary_pb2.SummaryMetadata( - display_name='current tagee', - summary_description='no', - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name='outlet', content=b'120v')) - self._writeMetadata(logdir, summary_metadata_1, nonce='1') - acc = ea.EventAccumulator(logdir) - acc.Reload() - summary_metadata_2 = summary_pb2.SummaryMetadata( - display_name='tagee of the future', - summary_description='definitely not', - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name='plug', content=b'110v')) - self._writeMetadata(logdir, summary_metadata_2, nonce='2') - acc.Reload() - - self.assertEqual(acc.PluginTagToContent('outlet'), - {'you_are_it': b'120v'}) - with six.assertRaisesRegex(self, KeyError, 'plug'): - acc.PluginTagToContent('plug') - -if __name__ == '__main__': - tf.test.main() +class RealisticEventAccumulatorTest(EventAccumulatorTest): + def testScalarsRealistically(self): + """Test accumulator by writing values and then reading them.""" + + def FakeScalarSummary(tag, value): + value = summary_pb2.Summary.Value(tag=tag, simple_value=value) + summary = summary_pb2.Summary(value=[value]) + return summary + + directory = os.path.join(self.get_temp_dir(), "values_dir") + if tf.io.gfile.isdir(directory): + tf.io.gfile.rmtree(directory) + tf.io.gfile.mkdir(directory) + + writer = test_util.FileWriter(directory, max_queue=100) + + with tf.Graph().as_default() as graph: + _ = tf.constant([2.0, 1.0]) + # Add a graph to the summary writer. + writer.add_graph(graph) + meta_graph_def = tf.compat.v1.train.export_meta_graph( + graph_def=graph.as_graph_def(add_shapes=True) + ) + writer.add_meta_graph(meta_graph_def) + + run_metadata = config_pb2.RunMetadata() + device_stats = run_metadata.step_stats.dev_stats.add() + device_stats.device = "test device" + writer.add_run_metadata(run_metadata, "test run") + + # Write a bunch of events using the writer. 
+ for i in xrange(30): + summ_id = FakeScalarSummary("id", i) + summ_sq = FakeScalarSummary("sq", i * i) + writer.add_summary(summ_id, i * 5) + writer.add_summary(summ_sq, i * 5) + writer.flush() + + # Verify that we can load those events properly + acc = ea.EventAccumulator(directory) + acc.Reload() + self.assertTagsEqual( + acc.Tags(), + { + ea.SCALARS: ["id", "sq"], + ea.GRAPH: True, + ea.META_GRAPH: True, + ea.RUN_METADATA: ["test run"], + }, + ) + id_events = acc.Scalars("id") + sq_events = acc.Scalars("sq") + self.assertEqual(30, len(id_events)) + self.assertEqual(30, len(sq_events)) + for i in xrange(30): + self.assertEqual(i * 5, id_events[i].step) + self.assertEqual(i * 5, sq_events[i].step) + self.assertEqual(i, id_events[i].value) + self.assertEqual(i * i, sq_events[i].value) + + # Write a few more events to test incremental reloading + for i in xrange(30, 40): + summ_id = FakeScalarSummary("id", i) + summ_sq = FakeScalarSummary("sq", i * i) + writer.add_summary(summ_id, i * 5) + writer.add_summary(summ_sq, i * 5) + writer.flush() + + # Verify we can now see all of the data + acc.Reload() + id_events = acc.Scalars("id") + sq_events = acc.Scalars("sq") + self.assertEqual(40, len(id_events)) + self.assertEqual(40, len(sq_events)) + for i in xrange(40): + self.assertEqual(i * 5, id_events[i].step) + self.assertEqual(i * 5, sq_events[i].step) + self.assertEqual(i, id_events[i].value) + self.assertEqual(i * i, sq_events[i].value) + + expected_graph_def = graph_pb2.GraphDef.FromString( + graph.as_graph_def(add_shapes=True).SerializeToString() + ) + self.assertProtoEquals(expected_graph_def, acc.Graph()) + + expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString( + meta_graph_def.SerializeToString() + ) + self.assertProtoEquals(expected_meta_graph, acc.MetaGraph()) + + def testGraphFromMetaGraphBecomesAvailable(self): + """Test accumulator by writing values and then reading them.""" + + directory = os.path.join( + self.get_temp_dir(), 
"metagraph_test_values_dir" + ) + if tf.io.gfile.isdir(directory): + tf.io.gfile.rmtree(directory) + tf.io.gfile.mkdir(directory) + + writer = test_util.FileWriter(directory, max_queue=100) + + with tf.Graph().as_default() as graph: + _ = tf.constant([2.0, 1.0]) + # Add a graph to the summary writer. + meta_graph_def = tf.compat.v1.train.export_meta_graph( + graph_def=graph.as_graph_def(add_shapes=True) + ) + writer.add_meta_graph(meta_graph_def) + + writer.flush() + + # Verify that we can load those events properly + acc = ea.EventAccumulator(directory) + acc.Reload() + self.assertTagsEqual(acc.Tags(), {ea.GRAPH: True, ea.META_GRAPH: True,}) + + expected_graph_def = graph_pb2.GraphDef.FromString( + graph.as_graph_def(add_shapes=True).SerializeToString() + ) + self.assertProtoEquals(expected_graph_def, acc.Graph()) + + expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString( + meta_graph_def.SerializeToString() + ) + self.assertProtoEquals(expected_meta_graph, acc.MetaGraph()) + + def _writeMetadata(self, logdir, summary_metadata, nonce=""): + """Write to disk a summary with the given metadata. 
+ + Arguments: + logdir: a string + summary_metadata: a `SummaryMetadata` protobuf object + nonce: optional; will be added to the end of the event file name + to guarantee that multiple calls to this function do not stomp the + same file + """ + + summary = summary_pb2.Summary() + summary.value.add( + tensor=tensor_util.make_tensor_proto( + ["po", "ta", "to"], dtype=tf.string + ), + tag="you_are_it", + metadata=summary_metadata, + ) + writer = test_util.FileWriter(logdir, filename_suffix=nonce) + writer.add_summary(summary.SerializeToString()) + writer.close() + + def testSummaryMetadata(self): + logdir = self.get_temp_dir() + summary_metadata = summary_pb2.SummaryMetadata( + display_name="current tagee", summary_description="no" + ) + summary_metadata.plugin_data.plugin_name = "outlet" + self._writeMetadata(logdir, summary_metadata) + acc = ea.EventAccumulator(logdir) + acc.Reload() + self.assertProtoEquals( + summary_metadata, acc.SummaryMetadata("you_are_it") + ) + + def testSummaryMetadata_FirstMetadataWins(self): + logdir = self.get_temp_dir() + summary_metadata_1 = summary_pb2.SummaryMetadata( + display_name="current tagee", + summary_description="no", + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name="outlet", content=b"120v" + ), + ) + self._writeMetadata(logdir, summary_metadata_1, nonce="1") + acc = ea.EventAccumulator(logdir) + acc.Reload() + summary_metadata_2 = summary_pb2.SummaryMetadata( + display_name="tagee of the future", + summary_description="definitely not", + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name="plug", content=b"110v" + ), + ) + self._writeMetadata(logdir, summary_metadata_2, nonce="2") + acc.Reload() + + self.assertProtoEquals( + summary_metadata_1, acc.SummaryMetadata("you_are_it") + ) + + def testPluginTagToContent_PluginsCannotJumpOnTheBandwagon(self): + # If there are multiple `SummaryMetadata` for a given tag, and the + # set of plugins in the `plugin_data` of second is different from + # 
that of the first, then the second set should be ignored. + logdir = self.get_temp_dir() + summary_metadata_1 = summary_pb2.SummaryMetadata( + display_name="current tagee", + summary_description="no", + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name="outlet", content=b"120v" + ), + ) + self._writeMetadata(logdir, summary_metadata_1, nonce="1") + acc = ea.EventAccumulator(logdir) + acc.Reload() + summary_metadata_2 = summary_pb2.SummaryMetadata( + display_name="tagee of the future", + summary_description="definitely not", + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name="plug", content=b"110v" + ), + ) + self._writeMetadata(logdir, summary_metadata_2, nonce="2") + acc.Reload() + + self.assertEqual( + acc.PluginTagToContent("outlet"), {"you_are_it": b"120v"} + ) + with six.assertRaisesRegex(self, KeyError, "plug"): + acc.PluginTagToContent("plug") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/backend/event_processing/event_file_inspector.py b/tensorboard/backend/event_processing/event_file_inspector.py index 5504d671a5..284970944e 100644 --- a/tensorboard/backend/event_processing/event_file_inspector.py +++ b/tensorboard/backend/event_processing/event_file_inspector.py @@ -125,31 +125,34 @@ # Map of field names within summary.proto to the user-facing names that this # script outputs. -SUMMARY_TYPE_TO_FIELD = {'simple_value': 'scalars', - 'histo': 'histograms', - 'image': 'images', - 'audio': 'audio'} +SUMMARY_TYPE_TO_FIELD = { + "simple_value": "scalars", + "histo": "histograms", + "image": "images", + "audio": "audio", +} for summary_type in event_accumulator.SUMMARY_TYPES: - if summary_type not in SUMMARY_TYPE_TO_FIELD: - SUMMARY_TYPE_TO_FIELD[summary_type] = summary_type + if summary_type not in SUMMARY_TYPE_TO_FIELD: + SUMMARY_TYPE_TO_FIELD[summary_type] = summary_type # Types of summaries that we may want to query for by tag. 
TAG_FIELDS = list(SUMMARY_TYPE_TO_FIELD.values()) # Summaries that we want to see every instance of. -LONG_FIELDS = ['sessionlog:start', 'sessionlog:stop'] +LONG_FIELDS = ["sessionlog:start", "sessionlog:stop"] # Summaries that we only want an abridged digest of, since they would # take too much screen real estate otherwise. -SHORT_FIELDS = ['graph', 'sessionlog:checkpoint'] + TAG_FIELDS +SHORT_FIELDS = ["graph", "sessionlog:checkpoint"] + TAG_FIELDS # All summary types that we can inspect. TRACKED_FIELDS = SHORT_FIELDS + LONG_FIELDS # An `Observation` contains the data within each Event file that the inspector # cares about. The inspector accumulates Observations as it processes events. -Observation = collections.namedtuple('Observation', ['step', 'wall_time', - 'tag']) +Observation = collections.namedtuple( + "Observation", ["step", "wall_time", "tag"] +) # An InspectionUnit is created for each organizational structure in the event # files visible in the final terminal output. For instance, one InspectionUnit @@ -159,259 +162,286 @@ # The InspectionUnit contains the `name` of the organizational unit that will be # printed to console, a `generator` that yields `Event` protos, and a mapping # from string fields to `Observations` that the inspector creates. -InspectionUnit = collections.namedtuple('InspectionUnit', ['name', 'generator', - 'field_to_obs']) - -PRINT_SEPARATOR = '=' * 70 + '\n' - - -def get_field_to_observations_map(generator, query_for_tag=''): - """Return a field to `Observations` dict for the event generator. - - Args: - generator: A generator over event protos. - query_for_tag: A string that if specified, only create observations for - events with this tag name. - - Returns: - A dict mapping keys in `TRACKED_FIELDS` to an `Observation` list. 
- """ - - def increment(stat, event, tag=''): - assert stat in TRACKED_FIELDS - field_to_obs[stat].append(Observation(step=event.step, - wall_time=event.wall_time, - tag=tag)._asdict()) - - field_to_obs = dict([(t, []) for t in TRACKED_FIELDS]) - - for event in generator: - ## Process the event - if event.HasField('graph_def') and (not query_for_tag): - increment('graph', event) - if event.HasField('session_log') and (not query_for_tag): - status = event.session_log.status - if status == event_pb2.SessionLog.START: - increment('sessionlog:start', event) - elif status == event_pb2.SessionLog.STOP: - increment('sessionlog:stop', event) - elif status == event_pb2.SessionLog.CHECKPOINT: - increment('sessionlog:checkpoint', event) - elif event.HasField('summary'): - for value in event.summary.value: - if query_for_tag and value.tag != query_for_tag: - continue - - for proto_name, display_name in SUMMARY_TYPE_TO_FIELD.items(): - if value.HasField(proto_name): - increment(display_name, event, value.tag) - return field_to_obs +InspectionUnit = collections.namedtuple( + "InspectionUnit", ["name", "generator", "field_to_obs"] +) + +PRINT_SEPARATOR = "=" * 70 + "\n" + + +def get_field_to_observations_map(generator, query_for_tag=""): + """Return a field to `Observations` dict for the event generator. + + Args: + generator: A generator over event protos. + query_for_tag: A string that if specified, only create observations for + events with this tag name. + + Returns: + A dict mapping keys in `TRACKED_FIELDS` to an `Observation` list. 
+ """ + + def increment(stat, event, tag=""): + assert stat in TRACKED_FIELDS + field_to_obs[stat].append( + Observation( + step=event.step, wall_time=event.wall_time, tag=tag + )._asdict() + ) + + field_to_obs = dict([(t, []) for t in TRACKED_FIELDS]) + + for event in generator: + ## Process the event + if event.HasField("graph_def") and (not query_for_tag): + increment("graph", event) + if event.HasField("session_log") and (not query_for_tag): + status = event.session_log.status + if status == event_pb2.SessionLog.START: + increment("sessionlog:start", event) + elif status == event_pb2.SessionLog.STOP: + increment("sessionlog:stop", event) + elif status == event_pb2.SessionLog.CHECKPOINT: + increment("sessionlog:checkpoint", event) + elif event.HasField("summary"): + for value in event.summary.value: + if query_for_tag and value.tag != query_for_tag: + continue + + for proto_name, display_name in SUMMARY_TYPE_TO_FIELD.items(): + if value.HasField(proto_name): + increment(display_name, event, value.tag) + return field_to_obs def get_unique_tags(field_to_obs): - """Returns a dictionary of tags that a user could query over. + """Returns a dictionary of tags that a user could query over. - Args: - field_to_obs: Dict that maps string field to `Observation` list. + Args: + field_to_obs: Dict that maps string field to `Observation` list. - Returns: - A dict that maps keys in `TAG_FIELDS` to a list of string tags present in - the event files. If the dict does not have any observations of the type, - maps to an empty list so that we can render this to console. - """ - return {field: sorted(set([x.get('tag', '') for x in observations])) - for field, observations in field_to_obs.items() - if field in TAG_FIELDS} + Returns: + A dict that maps keys in `TAG_FIELDS` to a list of string tags present in + the event files. If the dict does not have any observations of the type, + maps to an empty list so that we can render this to console. 
+ """ + return { + field: sorted(set([x.get("tag", "") for x in observations])) + for field, observations in field_to_obs.items() + if field in TAG_FIELDS + } def print_dict(d, show_missing=True): - """Prints a shallow dict to console. - - Args: - d: Dict to print. - show_missing: Whether to show keys with empty values. - """ - for k, v in sorted(d.items()): - if (not v) and show_missing: - # No instances of the key, so print missing symbol. - print('{} -'.format(k)) - elif isinstance(v, list): - # Value is a list, so print each item of the list. - print(k) - for item in v: - print(' {}'.format(item)) - elif isinstance(v, dict): - # Value is a dict, so print each (key, value) pair of the dict. - print(k) - for kk, vv in sorted(v.items()): - print(' {:<20} {}'.format(kk, vv)) + """Prints a shallow dict to console. + + Args: + d: Dict to print. + show_missing: Whether to show keys with empty values. + """ + for k, v in sorted(d.items()): + if (not v) and show_missing: + # No instances of the key, so print missing symbol. + print("{} -".format(k)) + elif isinstance(v, list): + # Value is a list, so print each item of the list. + print(k) + for item in v: + print(" {}".format(item)) + elif isinstance(v, dict): + # Value is a dict, so print each (key, value) pair of the dict. + print(k) + for kk, vv in sorted(v.items()): + print(" {:<20} {}".format(kk, vv)) def get_dict_to_print(field_to_obs): - """Transform the field-to-obs mapping into a printable dictionary. + """Transform the field-to-obs mapping into a printable dictionary. - Args: - field_to_obs: Dict that maps string field to `Observation` list. + Args: + field_to_obs: Dict that maps string field to `Observation` list. - Returns: - A dict with the keys and values to print to console. - """ + Returns: + A dict with the keys and values to print to console. 
+ """ - def compressed_steps(steps): - return {'num_steps': len(set(steps)), - 'min_step': min(steps), - 'max_step': max(steps), - 'last_step': steps[-1], - 'first_step': steps[0], - 'outoforder_steps': get_out_of_order(steps)} + def compressed_steps(steps): + return { + "num_steps": len(set(steps)), + "min_step": min(steps), + "max_step": max(steps), + "last_step": steps[-1], + "first_step": steps[0], + "outoforder_steps": get_out_of_order(steps), + } - def full_steps(steps): - return {'steps': steps, 'outoforder_steps': get_out_of_order(steps)} + def full_steps(steps): + return {"steps": steps, "outoforder_steps": get_out_of_order(steps)} - output = {} - for field, observations in field_to_obs.items(): - if not observations: - output[field] = None - continue + output = {} + for field, observations in field_to_obs.items(): + if not observations: + output[field] = None + continue - steps = [x['step'] for x in observations] - if field in SHORT_FIELDS: - output[field] = compressed_steps(steps) - if field in LONG_FIELDS: - output[field] = full_steps(steps) + steps = [x["step"] for x in observations] + if field in SHORT_FIELDS: + output[field] = compressed_steps(steps) + if field in LONG_FIELDS: + output[field] = full_steps(steps) - return output + return output def get_out_of_order(list_of_numbers): - """Returns elements that break the monotonically non-decreasing trend. - - This is used to find instances of global step values that are "out-of-order", - which may trigger TensorBoard event discarding logic. - - Args: - list_of_numbers: A list of numbers. - - Returns: - A list of tuples in which each tuple are two elements are adjacent, but the - second element is lower than the first. - """ - # TODO: Consider changing this to only check for out-of-order - # steps within a particular tag. 
- result = [] - # pylint: disable=consider-using-enumerate - for i in range(len(list_of_numbers)): - if i == 0: - continue - if list_of_numbers[i] < list_of_numbers[i - 1]: - result.append((list_of_numbers[i - 1], list_of_numbers[i])) - return result + """Returns elements that break the monotonically non-decreasing trend. + + This is used to find instances of global step values that are "out-of-order", + which may trigger TensorBoard event discarding logic. + + Args: + list_of_numbers: A list of numbers. + + Returns: + A list of tuples in which each tuple are two elements are adjacent, but the + second element is lower than the first. + """ + # TODO: Consider changing this to only check for out-of-order + # steps within a particular tag. + result = [] + # pylint: disable=consider-using-enumerate + for i in range(len(list_of_numbers)): + if i == 0: + continue + if list_of_numbers[i] < list_of_numbers[i - 1]: + result.append((list_of_numbers[i - 1], list_of_numbers[i])) + return result def generators_from_logdir(logdir): - """Returns a list of event generators for subdirectories with event files. + """Returns a list of event generators for subdirectories with event files. - The number of generators returned should equal the number of directories - within logdir that contain event files. If only logdir contains event files, - returns a list of length one. + The number of generators returned should equal the number of directories + within logdir that contain event files. If only logdir contains event files, + returns a list of length one. - Args: - logdir: A log directory that contains event files. + Args: + logdir: A log directory that contains event files. - Returns: - List of event generators for each subdirectory with event files. 
- """ - subdirs = io_wrapper.GetLogdirSubdirectories(logdir) - generators = [ - itertools.chain(*[ - generator_from_event_file(os.path.join(subdir, f)) - for f in tf.io.gfile.listdir(subdir) - if io_wrapper.IsTensorFlowEventsFile(os.path.join(subdir, f)) - ]) for subdir in subdirs - ] - return generators + Returns: + List of event generators for each subdirectory with event files. + """ + subdirs = io_wrapper.GetLogdirSubdirectories(logdir) + generators = [ + itertools.chain( + *[ + generator_from_event_file(os.path.join(subdir, f)) + for f in tf.io.gfile.listdir(subdir) + if io_wrapper.IsTensorFlowEventsFile(os.path.join(subdir, f)) + ] + ) + for subdir in subdirs + ] + return generators def generator_from_event_file(event_file): - """Returns a generator that yields events from an event file.""" - return event_file_loader.EventFileLoader(event_file).Load() - - -def get_inspection_units(logdir='', event_file='', tag=''): - """Returns a list of InspectionUnit objects given either logdir or event_file. - - If logdir is given, the number of InspectionUnits should equal the - number of directories or subdirectories that contain event files. - - If event_file is given, the number of InspectionUnits should be 1. - - Args: - logdir: A log directory that contains event files. - event_file: Or, a particular event file path. - tag: An optional tag name to query for. - - Returns: - A list of InspectionUnit objects. 
- """ - if logdir: - subdirs = io_wrapper.GetLogdirSubdirectories(logdir) - inspection_units = [] - for subdir in subdirs: - generator = itertools.chain(*[ - generator_from_event_file(os.path.join(subdir, f)) - for f in tf.io.gfile.listdir(subdir) - if io_wrapper.IsTensorFlowEventsFile(os.path.join(subdir, f)) - ]) - inspection_units.append(InspectionUnit( - name=subdir, - generator=generator, - field_to_obs=get_field_to_observations_map(generator, tag))) - if inspection_units: - print('Found event files in:\n{}\n'.format('\n'.join( - [u.name for u in inspection_units]))) - elif io_wrapper.IsTensorFlowEventsFile(logdir): - print( - 'It seems that {} may be an event file instead of a logdir. If this ' - 'is the case, use --event_file instead of --logdir to pass ' - 'it in.'.format(logdir)) - else: - print('No event files found within logdir {}'.format(logdir)) - return inspection_units - elif event_file: - generator = generator_from_event_file(event_file) - return [InspectionUnit( - name=event_file, - generator=generator, - field_to_obs=get_field_to_observations_map(generator, tag))] - return [] - - -def inspect(logdir='', event_file='', tag=''): - """Main function for inspector that prints out a digest of event files. - - Args: - logdir: A log directory that contains event files. - event_file: Or, a particular event file path. - tag: An optional tag name to query for. - - Raises: - ValueError: If neither logdir and event_file are given, or both are given. - """ - print(PRINT_SEPARATOR + - 'Processing event files... (this can take a few minutes)\n' + - PRINT_SEPARATOR) - inspection_units = get_inspection_units(logdir, event_file, tag) - - for unit in inspection_units: - if tag: - print('Event statistics for tag {} in {}:'.format(tag, unit.name)) - else: - # If the user is not inspecting a particular tag, also print the list of - # all available tags that they can query. 
- print('These tags are in {}:'.format(unit.name)) - print_dict(get_unique_tags(unit.field_to_obs)) - print(PRINT_SEPARATOR) - print('Event statistics for {}:'.format(unit.name)) - - print_dict(get_dict_to_print(unit.field_to_obs), show_missing=(not tag)) - print(PRINT_SEPARATOR) + """Returns a generator that yields events from an event file.""" + return event_file_loader.EventFileLoader(event_file).Load() + + +def get_inspection_units(logdir="", event_file="", tag=""): + """Returns a list of InspectionUnit objects given either logdir or + event_file. + + If logdir is given, the number of InspectionUnits should equal the + number of directories or subdirectories that contain event files. + + If event_file is given, the number of InspectionUnits should be 1. + + Args: + logdir: A log directory that contains event files. + event_file: Or, a particular event file path. + tag: An optional tag name to query for. + + Returns: + A list of InspectionUnit objects. + """ + if logdir: + subdirs = io_wrapper.GetLogdirSubdirectories(logdir) + inspection_units = [] + for subdir in subdirs: + generator = itertools.chain( + *[ + generator_from_event_file(os.path.join(subdir, f)) + for f in tf.io.gfile.listdir(subdir) + if io_wrapper.IsTensorFlowEventsFile( + os.path.join(subdir, f) + ) + ] + ) + inspection_units.append( + InspectionUnit( + name=subdir, + generator=generator, + field_to_obs=get_field_to_observations_map(generator, tag), + ) + ) + if inspection_units: + print( + "Found event files in:\n{}\n".format( + "\n".join([u.name for u in inspection_units]) + ) + ) + elif io_wrapper.IsTensorFlowEventsFile(logdir): + print( + "It seems that {} may be an event file instead of a logdir. 
If this " + "is the case, use --event_file instead of --logdir to pass " + "it in.".format(logdir) + ) + else: + print("No event files found within logdir {}".format(logdir)) + return inspection_units + elif event_file: + generator = generator_from_event_file(event_file) + return [ + InspectionUnit( + name=event_file, + generator=generator, + field_to_obs=get_field_to_observations_map(generator, tag), + ) + ] + return [] + + +def inspect(logdir="", event_file="", tag=""): + """Main function for inspector that prints out a digest of event files. + + Args: + logdir: A log directory that contains event files. + event_file: Or, a particular event file path. + tag: An optional tag name to query for. + + Raises: + ValueError: If neither logdir and event_file are given, or both are given. + """ + print( + PRINT_SEPARATOR + + "Processing event files... (this can take a few minutes)\n" + + PRINT_SEPARATOR + ) + inspection_units = get_inspection_units(logdir, event_file, tag) + + for unit in inspection_units: + if tag: + print("Event statistics for tag {} in {}:".format(tag, unit.name)) + else: + # If the user is not inspecting a particular tag, also print the list of + # all available tags that they can query. 
+ print("These tags are in {}:".format(unit.name)) + print_dict(get_unique_tags(unit.field_to_obs)) + print(PRINT_SEPARATOR) + print("Event statistics for {}:".format(unit.name)) + + print_dict(get_dict_to_print(unit.field_to_obs), show_missing=(not tag)) + print(PRINT_SEPARATOR) diff --git a/tensorboard/backend/event_processing/event_file_inspector_test.py b/tensorboard/backend/event_processing/event_file_inspector_test.py index 2a12a8fe9b..7648d3e3f4 100644 --- a/tensorboard/backend/event_processing/event_file_inspector_test.py +++ b/tensorboard/backend/event_processing/event_file_inspector_test.py @@ -31,171 +31,192 @@ class EventFileInspectorTest(tf.test.TestCase): - - def setUp(self): - self.logdir = os.path.join(self.get_temp_dir(), 'tfevents') - self._MakeDirectoryIfNotExists(self.logdir) - - def tearDown(self): - shutil.rmtree(self.logdir) - - def _MakeDirectoryIfNotExists(self, path): - if not os.path.exists(path): - os.mkdir(path) - - def _WriteScalarSummaries(self, data, subdirs=('',)): - # Writes data to a tempfile in subdirs, and returns generator for the data. - # If subdirs is given, writes data identically to all subdirectories. 
- for subdir_ in subdirs: - subdir = os.path.join(self.logdir, subdir_) - self._MakeDirectoryIfNotExists(subdir) - - with test_util.FileWriterCache.get(subdir) as sw: - for datum in data: - summary = summary_pb2.Summary() - if 'simple_value' in datum: - summary.value.add(tag=datum['tag'], - simple_value=datum['simple_value']) - sw.add_summary(summary, global_step=datum['step']) - elif 'histo' in datum: - summary.value.add(tag=datum['tag'], - histo=summary_pb2.HistogramProto()) - sw.add_summary(summary, global_step=datum['step']) - elif 'session_log' in datum: - sw.add_session_log(datum['session_log'], global_step=datum['step']) - - def testEmptyLogdir(self): - # Nothing was written to logdir - units = efi.get_inspection_units(self.logdir) - self.assertEqual([], units) - - def testGetAvailableTags(self): - data = [{'tag': 'c', 'histo': 2, 'step': 10}, - {'tag': 'c', 'histo': 2, 'step': 11}, - {'tag': 'c', 'histo': 2, 'step': 9}, - {'tag': 'b', 'simple_value': 2, 'step': 20}, - {'tag': 'b', 'simple_value': 2, 'step': 15}, - {'tag': 'a', 'simple_value': 2, 'step': 3}] - self._WriteScalarSummaries(data) - units = efi.get_inspection_units(self.logdir) - tags = efi.get_unique_tags(units[0].field_to_obs) - self.assertEqual(['a', 'b'], tags['scalars']) - self.assertEqual(['c'], tags['histograms']) - - def testInspectAll(self): - data = [{'tag': 'c', 'histo': 2, 'step': 10}, - {'tag': 'c', 'histo': 2, 'step': 11}, - {'tag': 'c', 'histo': 2, 'step': 9}, - {'tag': 'b', 'simple_value': 2, 'step': 20}, - {'tag': 'b', 'simple_value': 2, 'step': 15}, - {'tag': 'a', 'simple_value': 2, 'step': 3}] - self._WriteScalarSummaries(data) - units = efi.get_inspection_units(self.logdir) - printable = efi.get_dict_to_print(units[0].field_to_obs) - self.assertEqual(printable['histograms']['max_step'], 11) - self.assertEqual(printable['histograms']['min_step'], 9) - self.assertEqual(printable['histograms']['num_steps'], 3) - self.assertEqual(printable['histograms']['last_step'], 9) - 
self.assertEqual(printable['histograms']['first_step'], 10) - self.assertEqual(printable['histograms']['outoforder_steps'], [(11, 9)]) - - self.assertEqual(printable['scalars']['max_step'], 20) - self.assertEqual(printable['scalars']['min_step'], 3) - self.assertEqual(printable['scalars']['num_steps'], 3) - self.assertEqual(printable['scalars']['last_step'], 3) - self.assertEqual(printable['scalars']['first_step'], 20) - self.assertEqual(printable['scalars']['outoforder_steps'], [(20, 15), - (15, 3)]) - - def testInspectTag(self): - data = [{'tag': 'c', 'histo': 2, 'step': 10}, - {'tag': 'c', 'histo': 2, 'step': 11}, - {'tag': 'c', 'histo': 2, 'step': 9}, - {'tag': 'b', 'histo': 2, 'step': 20}, - {'tag': 'b', 'simple_value': 2, 'step': 15}, - {'tag': 'a', 'simple_value': 2, 'step': 3}] - self._WriteScalarSummaries(data) - units = efi.get_inspection_units(self.logdir, tag='c') - printable = efi.get_dict_to_print(units[0].field_to_obs) - self.assertEqual(printable['histograms']['max_step'], 11) - self.assertEqual(printable['histograms']['min_step'], 9) - self.assertEqual(printable['histograms']['num_steps'], 3) - self.assertEqual(printable['histograms']['last_step'], 9) - self.assertEqual(printable['histograms']['first_step'], 10) - self.assertEqual(printable['histograms']['outoforder_steps'], [(11, 9)]) - self.assertEqual(printable['scalars'], None) - - def testSessionLogSummaries(self): - data = [ - { - 'session_log': event_pb2.SessionLog( - status=event_pb2.SessionLog.START), - 'step': 0 - }, - { - 'session_log': event_pb2.SessionLog( - status=event_pb2.SessionLog.CHECKPOINT), - 'step': 1 - }, - { - 'session_log': event_pb2.SessionLog( - status=event_pb2.SessionLog.CHECKPOINT), - 'step': 2 - }, - { - 'session_log': event_pb2.SessionLog( - status=event_pb2.SessionLog.CHECKPOINT), - 'step': 3 - }, - { - 'session_log': event_pb2.SessionLog( - status=event_pb2.SessionLog.STOP), - 'step': 4 - }, - { - 'session_log': event_pb2.SessionLog( - 
status=event_pb2.SessionLog.START), - 'step': 5 - }, - { - 'session_log': event_pb2.SessionLog( - status=event_pb2.SessionLog.STOP), - 'step': 6 - }, - ] - - self._WriteScalarSummaries(data) - units = efi.get_inspection_units(self.logdir) - self.assertEqual(1, len(units)) - printable = efi.get_dict_to_print(units[0].field_to_obs) - self.assertEqual(printable['sessionlog:start']['steps'], [0, 5]) - self.assertEqual(printable['sessionlog:stop']['steps'], [4, 6]) - self.assertEqual(printable['sessionlog:checkpoint']['num_steps'], 3) - - def testInspectAllWithNestedLogdirs(self): - data = [{'tag': 'c', 'simple_value': 2, 'step': 10}, - {'tag': 'c', 'simple_value': 2, 'step': 11}, - {'tag': 'c', 'simple_value': 2, 'step': 9}, - {'tag': 'b', 'simple_value': 2, 'step': 20}, - {'tag': 'b', 'simple_value': 2, 'step': 15}, - {'tag': 'a', 'simple_value': 2, 'step': 3}] - - subdirs = ['eval', 'train'] - self._WriteScalarSummaries(data, subdirs=subdirs) - units = efi.get_inspection_units(self.logdir) - self.assertEqual(2, len(units)) - directory_names = [os.path.join(self.logdir, name) for name in subdirs] - self.assertEqual(directory_names, sorted([unit.name for unit in units])) - - for unit in units: - printable = efi.get_dict_to_print(unit.field_to_obs)['scalars'] - self.assertEqual(printable['max_step'], 20) - self.assertEqual(printable['min_step'], 3) - self.assertEqual(printable['num_steps'], 6) - self.assertEqual(printable['last_step'], 3) - self.assertEqual(printable['first_step'], 10) - self.assertEqual(printable['outoforder_steps'], [(11, 9), (20, 15), - (15, 3)]) - -if __name__ == '__main__': - tf.test.main() + def setUp(self): + self.logdir = os.path.join(self.get_temp_dir(), "tfevents") + self._MakeDirectoryIfNotExists(self.logdir) + + def tearDown(self): + shutil.rmtree(self.logdir) + + def _MakeDirectoryIfNotExists(self, path): + if not os.path.exists(path): + os.mkdir(path) + + def _WriteScalarSummaries(self, data, subdirs=("",)): + # Writes data to a tempfile 
in subdirs, and returns generator for the data. + # If subdirs is given, writes data identically to all subdirectories. + for subdir_ in subdirs: + subdir = os.path.join(self.logdir, subdir_) + self._MakeDirectoryIfNotExists(subdir) + + with test_util.FileWriterCache.get(subdir) as sw: + for datum in data: + summary = summary_pb2.Summary() + if "simple_value" in datum: + summary.value.add( + tag=datum["tag"], simple_value=datum["simple_value"] + ) + sw.add_summary(summary, global_step=datum["step"]) + elif "histo" in datum: + summary.value.add( + tag=datum["tag"], histo=summary_pb2.HistogramProto() + ) + sw.add_summary(summary, global_step=datum["step"]) + elif "session_log" in datum: + sw.add_session_log( + datum["session_log"], global_step=datum["step"] + ) + + def testEmptyLogdir(self): + # Nothing was written to logdir + units = efi.get_inspection_units(self.logdir) + self.assertEqual([], units) + + def testGetAvailableTags(self): + data = [ + {"tag": "c", "histo": 2, "step": 10}, + {"tag": "c", "histo": 2, "step": 11}, + {"tag": "c", "histo": 2, "step": 9}, + {"tag": "b", "simple_value": 2, "step": 20}, + {"tag": "b", "simple_value": 2, "step": 15}, + {"tag": "a", "simple_value": 2, "step": 3}, + ] + self._WriteScalarSummaries(data) + units = efi.get_inspection_units(self.logdir) + tags = efi.get_unique_tags(units[0].field_to_obs) + self.assertEqual(["a", "b"], tags["scalars"]) + self.assertEqual(["c"], tags["histograms"]) + + def testInspectAll(self): + data = [ + {"tag": "c", "histo": 2, "step": 10}, + {"tag": "c", "histo": 2, "step": 11}, + {"tag": "c", "histo": 2, "step": 9}, + {"tag": "b", "simple_value": 2, "step": 20}, + {"tag": "b", "simple_value": 2, "step": 15}, + {"tag": "a", "simple_value": 2, "step": 3}, + ] + self._WriteScalarSummaries(data) + units = efi.get_inspection_units(self.logdir) + printable = efi.get_dict_to_print(units[0].field_to_obs) + self.assertEqual(printable["histograms"]["max_step"], 11) + 
self.assertEqual(printable["histograms"]["min_step"], 9) + self.assertEqual(printable["histograms"]["num_steps"], 3) + self.assertEqual(printable["histograms"]["last_step"], 9) + self.assertEqual(printable["histograms"]["first_step"], 10) + self.assertEqual(printable["histograms"]["outoforder_steps"], [(11, 9)]) + + self.assertEqual(printable["scalars"]["max_step"], 20) + self.assertEqual(printable["scalars"]["min_step"], 3) + self.assertEqual(printable["scalars"]["num_steps"], 3) + self.assertEqual(printable["scalars"]["last_step"], 3) + self.assertEqual(printable["scalars"]["first_step"], 20) + self.assertEqual( + printable["scalars"]["outoforder_steps"], [(20, 15), (15, 3)] + ) + + def testInspectTag(self): + data = [ + {"tag": "c", "histo": 2, "step": 10}, + {"tag": "c", "histo": 2, "step": 11}, + {"tag": "c", "histo": 2, "step": 9}, + {"tag": "b", "histo": 2, "step": 20}, + {"tag": "b", "simple_value": 2, "step": 15}, + {"tag": "a", "simple_value": 2, "step": 3}, + ] + self._WriteScalarSummaries(data) + units = efi.get_inspection_units(self.logdir, tag="c") + printable = efi.get_dict_to_print(units[0].field_to_obs) + self.assertEqual(printable["histograms"]["max_step"], 11) + self.assertEqual(printable["histograms"]["min_step"], 9) + self.assertEqual(printable["histograms"]["num_steps"], 3) + self.assertEqual(printable["histograms"]["last_step"], 9) + self.assertEqual(printable["histograms"]["first_step"], 10) + self.assertEqual(printable["histograms"]["outoforder_steps"], [(11, 9)]) + self.assertEqual(printable["scalars"], None) + + def testSessionLogSummaries(self): + data = [ + { + "session_log": event_pb2.SessionLog( + status=event_pb2.SessionLog.START + ), + "step": 0, + }, + { + "session_log": event_pb2.SessionLog( + status=event_pb2.SessionLog.CHECKPOINT + ), + "step": 1, + }, + { + "session_log": event_pb2.SessionLog( + status=event_pb2.SessionLog.CHECKPOINT + ), + "step": 2, + }, + { + "session_log": event_pb2.SessionLog( + 
status=event_pb2.SessionLog.CHECKPOINT + ), + "step": 3, + }, + { + "session_log": event_pb2.SessionLog( + status=event_pb2.SessionLog.STOP + ), + "step": 4, + }, + { + "session_log": event_pb2.SessionLog( + status=event_pb2.SessionLog.START + ), + "step": 5, + }, + { + "session_log": event_pb2.SessionLog( + status=event_pb2.SessionLog.STOP + ), + "step": 6, + }, + ] + + self._WriteScalarSummaries(data) + units = efi.get_inspection_units(self.logdir) + self.assertEqual(1, len(units)) + printable = efi.get_dict_to_print(units[0].field_to_obs) + self.assertEqual(printable["sessionlog:start"]["steps"], [0, 5]) + self.assertEqual(printable["sessionlog:stop"]["steps"], [4, 6]) + self.assertEqual(printable["sessionlog:checkpoint"]["num_steps"], 3) + + def testInspectAllWithNestedLogdirs(self): + data = [ + {"tag": "c", "simple_value": 2, "step": 10}, + {"tag": "c", "simple_value": 2, "step": 11}, + {"tag": "c", "simple_value": 2, "step": 9}, + {"tag": "b", "simple_value": 2, "step": 20}, + {"tag": "b", "simple_value": 2, "step": 15}, + {"tag": "a", "simple_value": 2, "step": 3}, + ] + + subdirs = ["eval", "train"] + self._WriteScalarSummaries(data, subdirs=subdirs) + units = efi.get_inspection_units(self.logdir) + self.assertEqual(2, len(units)) + directory_names = [os.path.join(self.logdir, name) for name in subdirs] + self.assertEqual(directory_names, sorted([unit.name for unit in units])) + + for unit in units: + printable = efi.get_dict_to_print(unit.field_to_obs)["scalars"] + self.assertEqual(printable["max_step"], 20) + self.assertEqual(printable["min_step"], 3) + self.assertEqual(printable["num_steps"], 6) + self.assertEqual(printable["last_step"], 3) + self.assertEqual(printable["first_step"], 10) + self.assertEqual( + printable["outoforder_steps"], [(11, 9), (20, 15), (15, 3)] + ) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/backend/event_processing/event_file_loader.py b/tensorboard/backend/event_processing/event_file_loader.py 
index 50615e257c..3a1d52bdf0 100644 --- a/tensorboard/backend/event_processing/event_file_loader.py +++ b/tensorboard/backend/event_processing/event_file_loader.py @@ -31,82 +31,87 @@ class RawEventFileLoader(object): - """An iterator that yields Event protos as serialized bytestrings.""" - - def __init__(self, file_path): - if file_path is None: - raise ValueError('A file path is required') - file_path = platform_util.readahead_file_path(file_path) - logger.debug('Opening a record reader pointing at %s', file_path) - with tf.compat.v1.errors.raise_exception_on_not_ok_status() as status: - self._reader = _pywrap_tensorflow.PyRecordReader_New( - tf.compat.as_bytes(file_path), 0, tf.compat.as_bytes(''), status) - # Store it for logging purposes. - self._file_path = file_path - if not self._reader: - raise IOError('Failed to open a record reader pointing to %s' % file_path) - - def Load(self): - """Loads all new events from disk as raw serialized proto bytestrings. - - Calling Load multiple times in a row will not 'drop' events as long as the - return value is not iterated over. - - Yields: - All event proto bytestrings in the file that have not been yielded yet. - """ - logger.debug('Loading events from %s', self._file_path) - - # GetNext() expects a status argument on TF <= 1.7. - get_next_args = inspect.getargspec(self._reader.GetNext).args # pylint: disable=deprecated-method - # First argument is self - legacy_get_next = (len(get_next_args) > 1) - - while True: - try: - if legacy_get_next: - with tf.compat.v1.errors.raise_exception_on_not_ok_status() as status: - self._reader.GetNext(status) - else: - self._reader.GetNext() - except (tf.errors.DataLossError, tf.errors.OutOfRangeError) as e: - logger.debug('Cannot read more events: %s', e) - # We ignore partial read exceptions, because a record may be truncated. - # PyRecordReader holds the offset prior to the failed read, so retrying - # will succeed. 
- break - yield self._reader.record() - logger.debug('No more events in %s', self._file_path) + """An iterator that yields Event protos as serialized bytestrings.""" + + def __init__(self, file_path): + if file_path is None: + raise ValueError("A file path is required") + file_path = platform_util.readahead_file_path(file_path) + logger.debug("Opening a record reader pointing at %s", file_path) + with tf.compat.v1.errors.raise_exception_on_not_ok_status() as status: + self._reader = _pywrap_tensorflow.PyRecordReader_New( + tf.compat.as_bytes(file_path), 0, tf.compat.as_bytes(""), status + ) + # Store it for logging purposes. + self._file_path = file_path + if not self._reader: + raise IOError( + "Failed to open a record reader pointing to %s" % file_path + ) + + def Load(self): + """Loads all new events from disk as raw serialized proto bytestrings. + + Calling Load multiple times in a row will not 'drop' events as long as the + return value is not iterated over. + + Yields: + All event proto bytestrings in the file that have not been yielded yet. + """ + logger.debug("Loading events from %s", self._file_path) + + # GetNext() expects a status argument on TF <= 1.7. + get_next_args = inspect.getargspec( + self._reader.GetNext + ).args # pylint: disable=deprecated-method + # First argument is self + legacy_get_next = len(get_next_args) > 1 + + while True: + try: + if legacy_get_next: + with tf.compat.v1.errors.raise_exception_on_not_ok_status() as status: + self._reader.GetNext(status) + else: + self._reader.GetNext() + except (tf.errors.DataLossError, tf.errors.OutOfRangeError) as e: + logger.debug("Cannot read more events: %s", e) + # We ignore partial read exceptions, because a record may be truncated. + # PyRecordReader holds the offset prior to the failed read, so retrying + # will succeed. 
+ break + yield self._reader.record() + logger.debug("No more events in %s", self._file_path) class EventFileLoader(RawEventFileLoader): - """An iterator that yields parsed Event protos.""" + """An iterator that yields parsed Event protos.""" - def Load(self): - """Loads all new events from disk. + def Load(self): + """Loads all new events from disk. - Calling Load multiple times in a row will not 'drop' events as long as the - return value is not iterated over. + Calling Load multiple times in a row will not 'drop' events as long as the + return value is not iterated over. - Yields: - All events in the file that have not been yielded yet. - """ - for record in super(EventFileLoader, self).Load(): - yield event_pb2.Event.FromString(record) + Yields: + All events in the file that have not been yielded yet. + """ + for record in super(EventFileLoader, self).Load(): + yield event_pb2.Event.FromString(record) class TimestampedEventFileLoader(EventFileLoader): - """An iterator that yields (UNIX timestamp float, Event proto) pairs.""" + """An iterator that yields (UNIX timestamp float, Event proto) pairs.""" - def Load(self): - """Loads all new events and their wall time values from disk. + def Load(self): + """Loads all new events and their wall time values from disk. - Calling Load multiple times in a row will not 'drop' events as long as the - return value is not iterated over. + Calling Load multiple times in a row will not 'drop' events as long as the + return value is not iterated over. - Yields: - Pairs of (UNIX timestamp float, Event proto) for all events in the file - that have not been yielded yet. - """ - for event in super(TimestampedEventFileLoader, self).Load(): - yield (event.wall_time, event) + Yields: + Pairs of (UNIX timestamp float, Event proto) for all events in the file + that have not been yielded yet. 
+ """ + for event in super(TimestampedEventFileLoader, self).Load(): + yield (event.wall_time, event) diff --git a/tensorboard/backend/event_processing/event_file_loader_test.py b/tensorboard/backend/event_processing/event_file_loader_test.py index 99ada11ec5..2f67b4b9bf 100644 --- a/tensorboard/backend/event_processing/event_file_loader_test.py +++ b/tensorboard/backend/event_processing/event_file_loader_test.py @@ -29,82 +29,85 @@ class EventFileLoaderTest(tf.test.TestCase): - # A record containing a simple event. - RECORD = (b'\x18\x00\x00\x00\x00\x00\x00\x00\xa3\x7fK"\t\x00\x00\xc0%\xddu' - b'\xd5A\x1a\rbrain.Event:1\xec\xf32\x8d') - - def _WriteToFile(self, filename, data): - with open(filename, 'ab') as f: - f.write(data) - - def _LoaderForTestFile(self, filename): - return event_file_loader.EventFileLoader( - os.path.join(self.get_temp_dir(), filename)) - - def testEmptyEventFile(self): - filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name - self._WriteToFile(filename, b'') - loader = self._LoaderForTestFile(filename) - self.assertEqual(len(list(loader.Load())), 0) - - def testSingleWrite(self): - filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - loader = self._LoaderForTestFile(filename) - events = list(loader.Load()) - self.assertEqual(len(events), 1) - self.assertEqual(events[0].wall_time, 1440183447.0) - self.assertEqual(len(list(loader.Load())), 0) - - def testMultipleWrites(self): - filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - loader = self._LoaderForTestFile(filename) - self.assertEqual(len(list(loader.Load())), 1) - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - self.assertEqual(len(list(loader.Load())), 1) - - def testMultipleLoads(self): - filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name - self._WriteToFile(filename, 
EventFileLoaderTest.RECORD) - loader = self._LoaderForTestFile(filename) - loader.Load() - loader.Load() - self.assertEqual(len(list(loader.Load())), 1) - - def testMultipleWritesAtOnce(self): - filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - loader = self._LoaderForTestFile(filename) - self.assertEqual(len(list(loader.Load())), 2) - - def testMultipleWritesWithBadWrite(self): - filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - # Test that we ignore partial record writes at the end of the file. - self._WriteToFile(filename, b'123') - loader = self._LoaderForTestFile(filename) - self.assertEqual(len(list(loader.Load())), 2) + # A record containing a simple event. + RECORD = ( + b'\x18\x00\x00\x00\x00\x00\x00\x00\xa3\x7fK"\t\x00\x00\xc0%\xddu' + b"\xd5A\x1a\rbrain.Event:1\xec\xf32\x8d" + ) + + def _WriteToFile(self, filename, data): + with open(filename, "ab") as f: + f.write(data) + + def _LoaderForTestFile(self, filename): + return event_file_loader.EventFileLoader( + os.path.join(self.get_temp_dir(), filename) + ) + + def testEmptyEventFile(self): + filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name + self._WriteToFile(filename, b"") + loader = self._LoaderForTestFile(filename) + self.assertEqual(len(list(loader.Load())), 0) + + def testSingleWrite(self): + filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name + self._WriteToFile(filename, EventFileLoaderTest.RECORD) + loader = self._LoaderForTestFile(filename) + events = list(loader.Load()) + self.assertEqual(len(events), 1) + self.assertEqual(events[0].wall_time, 1440183447.0) + self.assertEqual(len(list(loader.Load())), 0) + + def testMultipleWrites(self): + filename = 
tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name + self._WriteToFile(filename, EventFileLoaderTest.RECORD) + loader = self._LoaderForTestFile(filename) + self.assertEqual(len(list(loader.Load())), 1) + self._WriteToFile(filename, EventFileLoaderTest.RECORD) + self.assertEqual(len(list(loader.Load())), 1) + + def testMultipleLoads(self): + filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name + self._WriteToFile(filename, EventFileLoaderTest.RECORD) + loader = self._LoaderForTestFile(filename) + loader.Load() + loader.Load() + self.assertEqual(len(list(loader.Load())), 1) + + def testMultipleWritesAtOnce(self): + filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name + self._WriteToFile(filename, EventFileLoaderTest.RECORD) + self._WriteToFile(filename, EventFileLoaderTest.RECORD) + loader = self._LoaderForTestFile(filename) + self.assertEqual(len(list(loader.Load())), 2) + + def testMultipleWritesWithBadWrite(self): + filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name + self._WriteToFile(filename, EventFileLoaderTest.RECORD) + self._WriteToFile(filename, EventFileLoaderTest.RECORD) + # Test that we ignore partial record writes at the end of the file. + self._WriteToFile(filename, b"123") + loader = self._LoaderForTestFile(filename) + self.assertEqual(len(list(loader.Load())), 2) class RawEventFileLoaderTest(EventFileLoaderTest): - - def _LoaderForTestFile(self, filename): - return event_file_loader.RawEventFileLoader( - os.path.join(self.get_temp_dir(), filename)) - - def testSingleWrite(self): - filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - loader = self._LoaderForTestFile(filename) - event_protos = list(loader.Load()) - self.assertEqual(len(event_protos), 1) - # Record format has a 12 byte header and a 4 byte trailer. 
- expected_event_proto = EventFileLoaderTest.RECORD[12:-4] - self.assertEqual(event_protos[0], expected_event_proto) - - -if __name__ == '__main__': - tf.test.main() + def _LoaderForTestFile(self, filename): + return event_file_loader.RawEventFileLoader( + os.path.join(self.get_temp_dir(), filename) + ) + + def testSingleWrite(self): + filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name + self._WriteToFile(filename, EventFileLoaderTest.RECORD) + loader = self._LoaderForTestFile(filename) + event_protos = list(loader.Load()) + self.assertEqual(len(event_protos), 1) + # Record format has a 12 byte header and a 4 byte trailer. + expected_event_proto = EventFileLoaderTest.RECORD[12:-4] + self.assertEqual(event_protos[0], expected_event_proto) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/backend/event_processing/event_multiplexer.py b/tensorboard/backend/event_processing/event_multiplexer.py index 5823a6412d..b575018abe 100644 --- a/tensorboard/backend/event_processing/event_multiplexer.py +++ b/tensorboard/backend/event_processing/event_multiplexer.py @@ -31,470 +31,481 @@ logger = tb_logging.get_logger() -class EventMultiplexer(object): - """An `EventMultiplexer` manages access to multiple `EventAccumulator`s. - - Each `EventAccumulator` is associated with a `run`, which is a self-contained - TensorFlow execution. The `EventMultiplexer` provides methods for extracting - information about events from multiple `run`s. 
- - Example usage for loading specific runs from files: - - ```python - x = EventMultiplexer({'run1': 'path/to/run1', 'run2': 'path/to/run2'}) - x.Reload() - ``` - - Example usage for loading a directory where each subdirectory is a run - - ```python - (eg:) /parent/directory/path/ - /parent/directory/path/run1/ - /parent/directory/path/run1/events.out.tfevents.1001 - /parent/directory/path/run1/events.out.tfevents.1002 - - /parent/directory/path/run2/ - /parent/directory/path/run2/events.out.tfevents.9232 - - /parent/directory/path/run3/ - /parent/directory/path/run3/events.out.tfevents.9232 - x = EventMultiplexer().AddRunsFromDirectory('/parent/directory/path') - (which is equivalent to:) - x = EventMultiplexer({'run1': '/parent/directory/path/run1', 'run2':...} - ``` - - If you would like to watch `/parent/directory/path`, wait for it to be created - (if necessary) and then periodically pick up new runs, use - `AutoloadingMultiplexer` - @@Tensors - """ - - def __init__(self, - run_path_map=None, - size_guidance=None, - purge_orphaned_data=True): - """Constructor for the `EventMultiplexer`. - - Args: - run_path_map: Dict `{run: path}` which specifies the - name of a run, and the path to find the associated events. If it is - None, then the EventMultiplexer initializes without any runs. - size_guidance: A dictionary mapping from `tagType` to the number of items - to store for each tag of that type. See - `event_accumulator.EventAccumulator` for details. - purge_orphaned_data: Whether to discard any events that were "orphaned" by - a TensorFlow restart. 
- """ - logger.info('Event Multiplexer initializing.') - self._accumulators_mutex = threading.Lock() - self._accumulators = {} - self._paths = {} - self._reload_called = False - self._size_guidance = (size_guidance or - event_accumulator.DEFAULT_SIZE_GUIDANCE) - self.purge_orphaned_data = purge_orphaned_data - if run_path_map is not None: - logger.info('Event Multplexer doing initialization load for %s', - run_path_map) - for (run, path) in six.iteritems(run_path_map): - self.AddRun(path, run) - logger.info('Event Multiplexer done initializing') - - def AddRun(self, path, name=None): - """Add a run to the multiplexer. - - If the name is not specified, it is the same as the path. - - If a run by that name exists, and we are already watching the right path, - do nothing. If we are watching a different path, replace the event - accumulator. - - If `Reload` has been called, it will `Reload` the newly created - accumulators. - - Args: - path: Path to the event files (or event directory) for given run. - name: Name of the run to add. If not provided, is set to path. - - Returns: - The `EventMultiplexer`. 
- """ - name = name or path - accumulator = None - with self._accumulators_mutex: - if name not in self._accumulators or self._paths[name] != path: - if name in self._paths and self._paths[name] != path: - # TODO(@decentralion) - Make it impossible to overwrite an old path - # with a new path (just give the new path a distinct name) - logger.warn('Conflict for name %s: old path %s, new path %s', - name, self._paths[name], path) - logger.info('Constructing EventAccumulator for %s', path) - accumulator = event_accumulator.EventAccumulator( - path, - size_guidance=self._size_guidance, - purge_orphaned_data=self.purge_orphaned_data) - self._accumulators[name] = accumulator - self._paths[name] = path - if accumulator: - if self._reload_called: - accumulator.Reload() - return self - - def AddRunsFromDirectory(self, path, name=None): - """Load runs from a directory; recursively walks subdirectories. - - If path doesn't exist, no-op. This ensures that it is safe to call - `AddRunsFromDirectory` multiple times, even before the directory is made. - - If path is a directory, load event files in the directory (if any exist) and - recursively call AddRunsFromDirectory on any subdirectories. This mean you - can call AddRunsFromDirectory at the root of a tree of event logs and - TensorBoard will load them all. - - If the `EventMultiplexer` is already loaded this will cause - the newly created accumulators to `Reload()`. - Args: - path: A string path to a directory to load runs from. - name: Optionally, what name to apply to the runs. If name is provided - and the directory contains run subdirectories, the name of each subrun - is the concatenation of the parent name and the subdirectory name. If - name is provided and the directory contains event files, then a run - is added called "name" and with the events from the path. - - Raises: - ValueError: If the path exists and isn't a directory. - - Returns: - The `EventMultiplexer`. 
- """ - logger.info('Starting AddRunsFromDirectory: %s', path) - for subdir in io_wrapper.GetLogdirSubdirectories(path): - logger.info('Adding events from directory %s', subdir) - rpath = os.path.relpath(subdir, path) - subname = os.path.join(name, rpath) if name else rpath - self.AddRun(subdir, name=subname) - logger.info('Done with AddRunsFromDirectory: %s', path) - return self - - def Reload(self): - """Call `Reload` on every `EventAccumulator`.""" - logger.info('Beginning EventMultiplexer.Reload()') - self._reload_called = True - # Build a list so we're safe even if the list of accumulators is modified - # even while we're reloading. - with self._accumulators_mutex: - items = list(self._accumulators.items()) - - names_to_delete = set() - for name, accumulator in items: - try: - accumulator.Reload() - except (OSError, IOError) as e: - logger.error("Unable to reload accumulator '%s': %s", name, e) - except directory_watcher.DirectoryDeletedError: - names_to_delete.add(name) - - with self._accumulators_mutex: - for name in names_to_delete: - logger.warn("Deleting accumulator '%s'", name) - del self._accumulators[name] - logger.info('Finished with EventMultiplexer.Reload()') - return self - - def PluginAssets(self, plugin_name): - """Get index of runs and assets for a given plugin. - - Args: - plugin_name: Name of the plugin we are checking for. - - Returns: - A dictionary that maps from run_name to a list of plugin - assets for that run. - """ - with self._accumulators_mutex: - # To avoid nested locks, we construct a copy of the run-accumulator map - items = list(six.iteritems(self._accumulators)) - - return {run: accum.PluginAssets(plugin_name) for run, accum in items} - - def RetrievePluginAsset(self, run, plugin_name, asset_name): - """Return the contents for a specific plugin asset from a run. - - Args: - run: The string name of the run. - plugin_name: The string name of a plugin. - asset_name: The string name of an asset. 
- - Returns: - The string contents of the plugin asset. - - Raises: - KeyError: If the asset is not available. - """ - accumulator = self.GetAccumulator(run) - return accumulator.RetrievePluginAsset(plugin_name, asset_name) - - def FirstEventTimestamp(self, run): - """Return the timestamp of the first event of the given run. - - This may perform I/O if no events have been loaded yet for the run. - - Args: - run: A string name of the run for which the timestamp is retrieved. - - Returns: - The wall_time of the first event of the run, which will typically be - seconds since the epoch. - - Raises: - KeyError: If the run is not found. - ValueError: If the run has no events loaded and there are no events on - disk to load. - """ - accumulator = self.GetAccumulator(run) - return accumulator.FirstEventTimestamp() - - def Scalars(self, run, tag): - """Retrieve the scalar events associated with a run and tag. - - Args: - run: A string name of the run for which values are retrieved. - tag: A string name of the tag for which values are retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - An array of `event_accumulator.ScalarEvents`. - """ - accumulator = self.GetAccumulator(run) - return accumulator.Scalars(tag) - - def Graph(self, run): - """Retrieve the graph associated with the provided run. - - Args: - run: A string name of a run to load the graph for. - - Raises: - KeyError: If the run is not found. - ValueError: If the run does not have an associated graph. - - Returns: - The `GraphDef` protobuf data structure. - """ - accumulator = self.GetAccumulator(run) - return accumulator.Graph() - - def SerializedGraph(self, run): - """Retrieve the serialized graph associated with the provided run. - - Args: - run: A string name of a run to load the graph for. - - Raises: - KeyError: If the run is not found. - ValueError: If the run does not have an associated graph. 
- - Returns: - The serialized form of the `GraphDef` protobuf data structure. - """ - accumulator = self.GetAccumulator(run) - return accumulator.SerializedGraph() - - def MetaGraph(self, run): - """Retrieve the metagraph associated with the provided run. - - Args: - run: A string name of a run to load the graph for. - - Raises: - KeyError: If the run is not found. - ValueError: If the run does not have an associated graph. - - Returns: - The `MetaGraphDef` protobuf data structure. - """ - accumulator = self.GetAccumulator(run) - return accumulator.MetaGraph() - - def RunMetadata(self, run, tag): - """Get the session.run() metadata associated with a TensorFlow run and tag. - Args: - run: A string name of a TensorFlow run. - tag: A string name of the tag associated with a particular session.run(). - - Raises: - KeyError: If the run is not found, or the tag is not available for the - given run. - - Returns: - The metadata in the form of `RunMetadata` protobuf data structure. - """ - accumulator = self.GetAccumulator(run) - return accumulator.RunMetadata(tag) - - def Histograms(self, run, tag): - """Retrieve the histogram events associated with a run and tag. - - Args: - run: A string name of the run for which values are retrieved. - tag: A string name of the tag for which values are retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - An array of `event_accumulator.HistogramEvents`. - """ - accumulator = self.GetAccumulator(run) - return accumulator.Histograms(tag) - - def CompressedHistograms(self, run, tag): - """Retrieve the compressed histogram events associated with a run and tag. - - Args: - run: A string name of the run for which values are retrieved. - tag: A string name of the tag for which values are retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - An array of `event_accumulator.CompressedHistogramEvents`. 
- """ - accumulator = self.GetAccumulator(run) - return accumulator.CompressedHistograms(tag) - - def Images(self, run, tag): - """Retrieve the image events associated with a run and tag. - - Args: - run: A string name of the run for which values are retrieved. - tag: A string name of the tag for which values are retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - An array of `event_accumulator.ImageEvents`. - """ - accumulator = self.GetAccumulator(run) - return accumulator.Images(tag) - - def Audio(self, run, tag): - """Retrieve the audio events associated with a run and tag. - - Args: - run: A string name of the run for which values are retrieved. - tag: A string name of the tag for which values are retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - An array of `event_accumulator.AudioEvents`. - """ - accumulator = self.GetAccumulator(run) - return accumulator.Audio(tag) - - def Tensors(self, run, tag): - """Retrieve the tensor events associated with a run and tag. - - Args: - run: A string name of the run for which values are retrieved. - tag: A string name of the tag for which values are retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - An array of `event_accumulator.TensorEvent`s. - """ - accumulator = self.GetAccumulator(run) - return accumulator.Tensors(tag) - - def PluginRunToTagToContent(self, plugin_name): - """Returns a 2-layer dictionary of the form {run: {tag: content}}. - - The `content` referred above is the content field of the PluginData proto - for the specified plugin within a Summary.Value proto. - - Args: - plugin_name: The name of the plugin for which to fetch content. +class EventMultiplexer(object): + """An `EventMultiplexer` manages access to multiple `EventAccumulator`s. 
- Returns: - A dictionary of the form {run: {tag: content}}. - """ - mapping = {} - for run in self.Runs(): - try: - tag_to_content = self.GetAccumulator(run).PluginTagToContent( - plugin_name) - except KeyError: - # This run lacks content for the plugin. Try the next run. - continue - mapping[run] = tag_to_content - return mapping - - def SummaryMetadata(self, run, tag): - """Return the summary metadata for the given tag on the given run. - - Args: - run: A string name of the run for which summary metadata is to be - retrieved. - tag: A string name of the tag whose summary metadata is to be - retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - A `SummaryMetadata` protobuf. - """ - accumulator = self.GetAccumulator(run) - return accumulator.SummaryMetadata(tag) + Each `EventAccumulator` is associated with a `run`, which is a self-contained + TensorFlow execution. The `EventMultiplexer` provides methods for extracting + information about events from multiple `run`s. - def Runs(self): - """Return all the run names in the `EventMultiplexer`. + Example usage for loading specific runs from files: - Returns: + ```python + x = EventMultiplexer({'run1': 'path/to/run1', 'run2': 'path/to/run2'}) + x.Reload() ``` - {runName: { images: [tag1, tag2, tag3], - scalarValues: [tagA, tagB, tagC], - histograms: [tagX, tagY, tagZ], - compressedHistograms: [tagX, tagY, tagZ], - graph: true, meta_graph: true}} - ``` - """ - with self._accumulators_mutex: - # To avoid nested locks, we construct a copy of the run-accumulator map - items = list(six.iteritems(self._accumulators)) - return {run_name: accumulator.Tags() for run_name, accumulator in items} - def RunPaths(self): - """Returns a dict mapping run names to event file paths.""" - return self._paths + Example usage for loading a directory where each subdirectory is a run - def GetAccumulator(self, run): - """Returns EventAccumulator for a given run. 
+ ```python + (eg:) /parent/directory/path/ + /parent/directory/path/run1/ + /parent/directory/path/run1/events.out.tfevents.1001 + /parent/directory/path/run1/events.out.tfevents.1002 - Args: - run: String name of run. + /parent/directory/path/run2/ + /parent/directory/path/run2/events.out.tfevents.9232 - Returns: - An EventAccumulator object. + /parent/directory/path/run3/ + /parent/directory/path/run3/events.out.tfevents.9232 + x = EventMultiplexer().AddRunsFromDirectory('/parent/directory/path') + (which is equivalent to:) + x = EventMultiplexer({'run1': '/parent/directory/path/run1', 'run2':...} + ``` - Raises: - KeyError: If run does not exist. + If you would like to watch `/parent/directory/path`, wait for it to be created + (if necessary) and then periodically pick up new runs, use + `AutoloadingMultiplexer` + @@Tensors """ - with self._accumulators_mutex: - return self._accumulators[run] + + def __init__( + self, run_path_map=None, size_guidance=None, purge_orphaned_data=True + ): + """Constructor for the `EventMultiplexer`. + + Args: + run_path_map: Dict `{run: path}` which specifies the + name of a run, and the path to find the associated events. If it is + None, then the EventMultiplexer initializes without any runs. + size_guidance: A dictionary mapping from `tagType` to the number of items + to store for each tag of that type. See + `event_accumulator.EventAccumulator` for details. + purge_orphaned_data: Whether to discard any events that were "orphaned" by + a TensorFlow restart. 
+ """ + logger.info("Event Multiplexer initializing.") + self._accumulators_mutex = threading.Lock() + self._accumulators = {} + self._paths = {} + self._reload_called = False + self._size_guidance = ( + size_guidance or event_accumulator.DEFAULT_SIZE_GUIDANCE + ) + self.purge_orphaned_data = purge_orphaned_data + if run_path_map is not None: + logger.info( + "Event Multplexer doing initialization load for %s", + run_path_map, + ) + for (run, path) in six.iteritems(run_path_map): + self.AddRun(path, run) + logger.info("Event Multiplexer done initializing") + + def AddRun(self, path, name=None): + """Add a run to the multiplexer. + + If the name is not specified, it is the same as the path. + + If a run by that name exists, and we are already watching the right path, + do nothing. If we are watching a different path, replace the event + accumulator. + + If `Reload` has been called, it will `Reload` the newly created + accumulators. + + Args: + path: Path to the event files (or event directory) for given run. + name: Name of the run to add. If not provided, is set to path. + + Returns: + The `EventMultiplexer`. 
+ """ + name = name or path + accumulator = None + with self._accumulators_mutex: + if name not in self._accumulators or self._paths[name] != path: + if name in self._paths and self._paths[name] != path: + # TODO(@decentralion) - Make it impossible to overwrite an old path + # with a new path (just give the new path a distinct name) + logger.warn( + "Conflict for name %s: old path %s, new path %s", + name, + self._paths[name], + path, + ) + logger.info("Constructing EventAccumulator for %s", path) + accumulator = event_accumulator.EventAccumulator( + path, + size_guidance=self._size_guidance, + purge_orphaned_data=self.purge_orphaned_data, + ) + self._accumulators[name] = accumulator + self._paths[name] = path + if accumulator: + if self._reload_called: + accumulator.Reload() + return self + + def AddRunsFromDirectory(self, path, name=None): + """Load runs from a directory; recursively walks subdirectories. + + If path doesn't exist, no-op. This ensures that it is safe to call + `AddRunsFromDirectory` multiple times, even before the directory is made. + + If path is a directory, load event files in the directory (if any exist) and + recursively call AddRunsFromDirectory on any subdirectories. This mean you + can call AddRunsFromDirectory at the root of a tree of event logs and + TensorBoard will load them all. + + If the `EventMultiplexer` is already loaded this will cause + the newly created accumulators to `Reload()`. + Args: + path: A string path to a directory to load runs from. + name: Optionally, what name to apply to the runs. If name is provided + and the directory contains run subdirectories, the name of each subrun + is the concatenation of the parent name and the subdirectory name. If + name is provided and the directory contains event files, then a run + is added called "name" and with the events from the path. + + Raises: + ValueError: If the path exists and isn't a directory. + + Returns: + The `EventMultiplexer`. 
+ """ + logger.info("Starting AddRunsFromDirectory: %s", path) + for subdir in io_wrapper.GetLogdirSubdirectories(path): + logger.info("Adding events from directory %s", subdir) + rpath = os.path.relpath(subdir, path) + subname = os.path.join(name, rpath) if name else rpath + self.AddRun(subdir, name=subname) + logger.info("Done with AddRunsFromDirectory: %s", path) + return self + + def Reload(self): + """Call `Reload` on every `EventAccumulator`.""" + logger.info("Beginning EventMultiplexer.Reload()") + self._reload_called = True + # Build a list so we're safe even if the list of accumulators is modified + # even while we're reloading. + with self._accumulators_mutex: + items = list(self._accumulators.items()) + + names_to_delete = set() + for name, accumulator in items: + try: + accumulator.Reload() + except (OSError, IOError) as e: + logger.error("Unable to reload accumulator '%s': %s", name, e) + except directory_watcher.DirectoryDeletedError: + names_to_delete.add(name) + + with self._accumulators_mutex: + for name in names_to_delete: + logger.warn("Deleting accumulator '%s'", name) + del self._accumulators[name] + logger.info("Finished with EventMultiplexer.Reload()") + return self + + def PluginAssets(self, plugin_name): + """Get index of runs and assets for a given plugin. + + Args: + plugin_name: Name of the plugin we are checking for. + + Returns: + A dictionary that maps from run_name to a list of plugin + assets for that run. + """ + with self._accumulators_mutex: + # To avoid nested locks, we construct a copy of the run-accumulator map + items = list(six.iteritems(self._accumulators)) + + return {run: accum.PluginAssets(plugin_name) for run, accum in items} + + def RetrievePluginAsset(self, run, plugin_name, asset_name): + """Return the contents for a specific plugin asset from a run. + + Args: + run: The string name of the run. + plugin_name: The string name of a plugin. + asset_name: The string name of an asset. 
+ + Returns: + The string contents of the plugin asset. + + Raises: + KeyError: If the asset is not available. + """ + accumulator = self.GetAccumulator(run) + return accumulator.RetrievePluginAsset(plugin_name, asset_name) + + def FirstEventTimestamp(self, run): + """Return the timestamp of the first event of the given run. + + This may perform I/O if no events have been loaded yet for the run. + + Args: + run: A string name of the run for which the timestamp is retrieved. + + Returns: + The wall_time of the first event of the run, which will typically be + seconds since the epoch. + + Raises: + KeyError: If the run is not found. + ValueError: If the run has no events loaded and there are no events on + disk to load. + """ + accumulator = self.GetAccumulator(run) + return accumulator.FirstEventTimestamp() + + def Scalars(self, run, tag): + """Retrieve the scalar events associated with a run and tag. + + Args: + run: A string name of the run for which values are retrieved. + tag: A string name of the tag for which values are retrieved. + + Raises: + KeyError: If the run is not found, or the tag is not available for + the given run. + + Returns: + An array of `event_accumulator.ScalarEvents`. + """ + accumulator = self.GetAccumulator(run) + return accumulator.Scalars(tag) + + def Graph(self, run): + """Retrieve the graph associated with the provided run. + + Args: + run: A string name of a run to load the graph for. + + Raises: + KeyError: If the run is not found. + ValueError: If the run does not have an associated graph. + + Returns: + The `GraphDef` protobuf data structure. + """ + accumulator = self.GetAccumulator(run) + return accumulator.Graph() + + def SerializedGraph(self, run): + """Retrieve the serialized graph associated with the provided run. + + Args: + run: A string name of a run to load the graph for. + + Raises: + KeyError: If the run is not found. + ValueError: If the run does not have an associated graph. 
+ + Returns: + The serialized form of the `GraphDef` protobuf data structure. + """ + accumulator = self.GetAccumulator(run) + return accumulator.SerializedGraph() + + def MetaGraph(self, run): + """Retrieve the metagraph associated with the provided run. + + Args: + run: A string name of a run to load the graph for. + + Raises: + KeyError: If the run is not found. + ValueError: If the run does not have an associated graph. + + Returns: + The `MetaGraphDef` protobuf data structure. + """ + accumulator = self.GetAccumulator(run) + return accumulator.MetaGraph() + + def RunMetadata(self, run, tag): + """Get the session.run() metadata associated with a TensorFlow run and + tag. + + Args: + run: A string name of a TensorFlow run. + tag: A string name of the tag associated with a particular session.run(). + + Raises: + KeyError: If the run is not found, or the tag is not available for the + given run. + + Returns: + The metadata in the form of `RunMetadata` protobuf data structure. + """ + accumulator = self.GetAccumulator(run) + return accumulator.RunMetadata(tag) + + def Histograms(self, run, tag): + """Retrieve the histogram events associated with a run and tag. + + Args: + run: A string name of the run for which values are retrieved. + tag: A string name of the tag for which values are retrieved. + + Raises: + KeyError: If the run is not found, or the tag is not available for + the given run. + + Returns: + An array of `event_accumulator.HistogramEvents`. + """ + accumulator = self.GetAccumulator(run) + return accumulator.Histograms(tag) + + def CompressedHistograms(self, run, tag): + """Retrieve the compressed histogram events associated with a run and + tag. + + Args: + run: A string name of the run for which values are retrieved. + tag: A string name of the tag for which values are retrieved. + + Raises: + KeyError: If the run is not found, or the tag is not available for + the given run. + + Returns: + An array of `event_accumulator.CompressedHistogramEvents`. 
+ """ + accumulator = self.GetAccumulator(run) + return accumulator.CompressedHistograms(tag) + + def Images(self, run, tag): + """Retrieve the image events associated with a run and tag. + + Args: + run: A string name of the run for which values are retrieved. + tag: A string name of the tag for which values are retrieved. + + Raises: + KeyError: If the run is not found, or the tag is not available for + the given run. + + Returns: + An array of `event_accumulator.ImageEvents`. + """ + accumulator = self.GetAccumulator(run) + return accumulator.Images(tag) + + def Audio(self, run, tag): + """Retrieve the audio events associated with a run and tag. + + Args: + run: A string name of the run for which values are retrieved. + tag: A string name of the tag for which values are retrieved. + + Raises: + KeyError: If the run is not found, or the tag is not available for + the given run. + + Returns: + An array of `event_accumulator.AudioEvents`. + """ + accumulator = self.GetAccumulator(run) + return accumulator.Audio(tag) + + def Tensors(self, run, tag): + """Retrieve the tensor events associated with a run and tag. + + Args: + run: A string name of the run for which values are retrieved. + tag: A string name of the tag for which values are retrieved. + + Raises: + KeyError: If the run is not found, or the tag is not available for + the given run. + + Returns: + An array of `event_accumulator.TensorEvent`s. + """ + accumulator = self.GetAccumulator(run) + return accumulator.Tensors(tag) + + def PluginRunToTagToContent(self, plugin_name): + """Returns a 2-layer dictionary of the form {run: {tag: content}}. + + The `content` referred above is the content field of the PluginData proto + for the specified plugin within a Summary.Value proto. + + Args: + plugin_name: The name of the plugin for which to fetch content. + + Returns: + A dictionary of the form {run: {tag: content}}. 
+ """ + mapping = {} + for run in self.Runs(): + try: + tag_to_content = self.GetAccumulator(run).PluginTagToContent( + plugin_name + ) + except KeyError: + # This run lacks content for the plugin. Try the next run. + continue + mapping[run] = tag_to_content + return mapping + + def SummaryMetadata(self, run, tag): + """Return the summary metadata for the given tag on the given run. + + Args: + run: A string name of the run for which summary metadata is to be + retrieved. + tag: A string name of the tag whose summary metadata is to be + retrieved. + + Raises: + KeyError: If the run is not found, or the tag is not available for + the given run. + + Returns: + A `SummaryMetadata` protobuf. + """ + accumulator = self.GetAccumulator(run) + return accumulator.SummaryMetadata(tag) + + def Runs(self): + """Return all the run names in the `EventMultiplexer`. + + Returns: + ``` + {runName: { images: [tag1, tag2, tag3], + scalarValues: [tagA, tagB, tagC], + histograms: [tagX, tagY, tagZ], + compressedHistograms: [tagX, tagY, tagZ], + graph: true, meta_graph: true}} + ``` + """ + with self._accumulators_mutex: + # To avoid nested locks, we construct a copy of the run-accumulator map + items = list(six.iteritems(self._accumulators)) + return {run_name: accumulator.Tags() for run_name, accumulator in items} + + def RunPaths(self): + """Returns a dict mapping run names to event file paths.""" + return self._paths + + def GetAccumulator(self, run): + """Returns EventAccumulator for a given run. + + Args: + run: String name of run. + + Returns: + An EventAccumulator object. + + Raises: + KeyError: If run does not exist. 
+ """ + with self._accumulators_mutex: + return self._accumulators[run] diff --git a/tensorboard/backend/event_processing/event_multiplexer_test.py b/tensorboard/backend/event_processing/event_multiplexer_test.py index 6263a7df06..7c24c7eb52 100644 --- a/tensorboard/backend/event_processing/event_multiplexer_test.py +++ b/tensorboard/backend/event_processing/event_multiplexer_test.py @@ -28,326 +28,350 @@ def _AddEvents(path): - if not tf.io.gfile.isdir(path): - tf.io.gfile.makedirs(path) - fpath = os.path.join(path, 'hypothetical.tfevents.out') - with tf.io.gfile.GFile(fpath, 'w') as f: - f.write('') - return fpath + if not tf.io.gfile.isdir(path): + tf.io.gfile.makedirs(path) + fpath = os.path.join(path, "hypothetical.tfevents.out") + with tf.io.gfile.GFile(fpath, "w") as f: + f.write("") + return fpath def _CreateCleanDirectory(path): - if tf.io.gfile.isdir(path): - tf.io.gfile.rmtree(path) - tf.io.gfile.mkdir(path) + if tf.io.gfile.isdir(path): + tf.io.gfile.rmtree(path) + tf.io.gfile.mkdir(path) class _FakeAccumulator(object): - - def __init__(self, path): - """Constructs a fake accumulator with some fake events. - - Args: - path: The path for the run that this accumulator is for. - """ - self._path = path - self.reload_called = False - self._plugin_to_tag_to_content = { - 'baz_plugin': { - 'foo': 'foo_content', - 'bar': 'bar_content', + def __init__(self, path): + """Constructs a fake accumulator with some fake events. + + Args: + path: The path for the run that this accumulator is for. 
+ """ + self._path = path + self.reload_called = False + self._plugin_to_tag_to_content = { + "baz_plugin": {"foo": "foo_content", "bar": "bar_content",} } - } - def Tags(self): - return {event_accumulator.IMAGES: ['im1', 'im2'], - event_accumulator.AUDIO: ['snd1', 'snd2'], - event_accumulator.HISTOGRAMS: ['hst1', 'hst2'], - event_accumulator.COMPRESSED_HISTOGRAMS: ['cmphst1', 'cmphst2'], - event_accumulator.SCALARS: ['sv1', 'sv2']} + def Tags(self): + return { + event_accumulator.IMAGES: ["im1", "im2"], + event_accumulator.AUDIO: ["snd1", "snd2"], + event_accumulator.HISTOGRAMS: ["hst1", "hst2"], + event_accumulator.COMPRESSED_HISTOGRAMS: ["cmphst1", "cmphst2"], + event_accumulator.SCALARS: ["sv1", "sv2"], + } - def FirstEventTimestamp(self): - return 0 + def FirstEventTimestamp(self): + return 0 - def _TagHelper(self, tag_name, enum): - if tag_name not in self.Tags()[enum]: - raise KeyError - return ['%s/%s' % (self._path, tag_name)] + def _TagHelper(self, tag_name, enum): + if tag_name not in self.Tags()[enum]: + raise KeyError + return ["%s/%s" % (self._path, tag_name)] - def Scalars(self, tag_name): - return self._TagHelper(tag_name, event_accumulator.SCALARS) + def Scalars(self, tag_name): + return self._TagHelper(tag_name, event_accumulator.SCALARS) - def Histograms(self, tag_name): - return self._TagHelper(tag_name, event_accumulator.HISTOGRAMS) + def Histograms(self, tag_name): + return self._TagHelper(tag_name, event_accumulator.HISTOGRAMS) - def CompressedHistograms(self, tag_name): - return self._TagHelper(tag_name, event_accumulator.COMPRESSED_HISTOGRAMS) + def CompressedHistograms(self, tag_name): + return self._TagHelper( + tag_name, event_accumulator.COMPRESSED_HISTOGRAMS + ) - def Images(self, tag_name): - return self._TagHelper(tag_name, event_accumulator.IMAGES) + def Images(self, tag_name): + return self._TagHelper(tag_name, event_accumulator.IMAGES) - def Audio(self, tag_name): - return self._TagHelper(tag_name, event_accumulator.AUDIO) + def 
Audio(self, tag_name): + return self._TagHelper(tag_name, event_accumulator.AUDIO) - def Tensors(self, tag_name): - return self._TagHelper(tag_name, event_accumulator.TENSORS) + def Tensors(self, tag_name): + return self._TagHelper(tag_name, event_accumulator.TENSORS) - def PluginTagToContent(self, plugin_name): - # We pre-pend the runs with the path and '_' so that we can verify that the - # tags are associated with the correct runs. - return { - self._path + '_' + run: content_mapping - for (run, content_mapping - ) in self._plugin_to_tag_to_content[plugin_name].items() - } + def PluginTagToContent(self, plugin_name): + # We pre-pend the runs with the path and '_' so that we can verify that the + # tags are associated with the correct runs. + return { + self._path + "_" + run: content_mapping + for (run, content_mapping) in self._plugin_to_tag_to_content[ + plugin_name + ].items() + } - def Reload(self): - self.reload_called = True + def Reload(self): + self.reload_called = True -def _GetFakeAccumulator(path, - size_guidance=None, - compression_bps=None, - purge_orphaned_data=None): - del size_guidance, compression_bps, purge_orphaned_data # Unused. - return _FakeAccumulator(path) +def _GetFakeAccumulator( + path, size_guidance=None, compression_bps=None, purge_orphaned_data=None +): + del size_guidance, compression_bps, purge_orphaned_data # Unused. 
+ return _FakeAccumulator(path) class EventMultiplexerTest(tf.test.TestCase): - - def setUp(self): - super(EventMultiplexerTest, self).setUp() - self.stubs = tf.compat.v1.test.StubOutForTesting() - - self.stubs.Set(event_accumulator, 'EventAccumulator', _GetFakeAccumulator) - - def tearDown(self): - self.stubs.CleanUp() - - def testEmptyLoader(self): - """Tests empty EventMultiplexer creation.""" - x = event_multiplexer.EventMultiplexer() - self.assertEqual(x.Runs(), {}) - - def testRunNamesRespected(self): - """Tests two EventAccumulators inserted/accessed in EventMultiplexer.""" - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - self.assertItemsEqual(sorted(x.Runs().keys()), ['run1', 'run2']) - self.assertEqual(x.GetAccumulator('run1')._path, 'path1') - self.assertEqual(x.GetAccumulator('run2')._path, 'path2') - - def testReload(self): - """EventAccumulators should Reload after EventMultiplexer call it.""" - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - self.assertFalse(x.GetAccumulator('run1').reload_called) - self.assertFalse(x.GetAccumulator('run2').reload_called) - x.Reload() - self.assertTrue(x.GetAccumulator('run1').reload_called) - self.assertTrue(x.GetAccumulator('run2').reload_called) - - def testScalars(self): - """Tests Scalars function returns suitable values.""" - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - - run1_actual = x.Scalars('run1', 'sv1') - run1_expected = ['path1/sv1'] - - self.assertEqual(run1_expected, run1_actual) - - def testPluginRunToTagToContent(self): - """Tests the method that produces the run to tag to content mapping.""" - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - self.assertDictEqual({ - 'run1': { - 'path1_foo': 'foo_content', - 'path1_bar': 'bar_content', - }, - 'run2': { - 'path2_foo': 'foo_content', - 'path2_bar': 'bar_content', - } - }, x.PluginRunToTagToContent('baz_plugin')) - - def 
testExceptions(self): - """KeyError should be raised when accessing non-existing keys.""" - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - with self.assertRaises(KeyError): - x.Scalars('sv1', 'xxx') - - def testInitialization(self): - """Tests EventMultiplexer is created properly with its params.""" - x = event_multiplexer.EventMultiplexer() - self.assertEqual(x.Runs(), {}) - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - self.assertItemsEqual(x.Runs(), ['run1', 'run2']) - self.assertEqual(x.GetAccumulator('run1')._path, 'path1') - self.assertEqual(x.GetAccumulator('run2')._path, 'path2') - - def testAddRunsFromDirectory(self): - """Tests AddRunsFromDirectory function. - - Tests the following scenarios: - - When the directory does not exist. - - When the directory is empty. - - When the directory has empty subdirectory. - - Contains proper EventAccumulators after adding events. - """ - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - join = os.path.join - fakedir = join(tmpdir, 'fake_accumulator_directory') - realdir = join(tmpdir, 'real_accumulator_directory') - self.assertEqual(x.Runs(), {}) - x.AddRunsFromDirectory(fakedir) - self.assertEqual(x.Runs(), {}, 'loading fakedir had no effect') - - _CreateCleanDirectory(realdir) - x.AddRunsFromDirectory(realdir) - self.assertEqual(x.Runs(), {}, 'loading empty directory had no effect') - - path1 = join(realdir, 'path1') - tf.io.gfile.mkdir(path1) - x.AddRunsFromDirectory(realdir) - self.assertEqual(x.Runs(), {}, 'creating empty subdirectory had no effect') - - _AddEvents(path1) - x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['path1'], 'loaded run: path1') - loader1 = x.GetAccumulator('path1') - self.assertEqual(loader1._path, path1, 'has the correct path') - - path2 = join(realdir, 'path2') - _AddEvents(path2) - x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['path1', 'path2']) - 
self.assertEqual( - x.GetAccumulator('path1'), loader1, 'loader1 not regenerated') - - path2_2 = join(path2, 'path2') - _AddEvents(path2_2) - x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['path1', 'path2', 'path2/path2']) - self.assertEqual( - x.GetAccumulator('path2/path2')._path, path2_2, 'loader2 path correct') - - def testAddRunsFromDirectoryThatContainsEvents(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - join = os.path.join - realdir = join(tmpdir, 'event_containing_directory') - - _CreateCleanDirectory(realdir) - - self.assertEqual(x.Runs(), {}) - - _AddEvents(realdir) - x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['.']) - - subdir = join(realdir, 'subdir') - _AddEvents(subdir) - x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['.', 'subdir']) - - def testAddRunsFromDirectoryWithRunNames(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - join = os.path.join - realdir = join(tmpdir, 'event_containing_directory') - - _CreateCleanDirectory(realdir) - - self.assertEqual(x.Runs(), {}) - - _AddEvents(realdir) - x.AddRunsFromDirectory(realdir, 'foo') - self.assertItemsEqual(x.Runs(), ['foo/.']) - - subdir = join(realdir, 'subdir') - _AddEvents(subdir) - x.AddRunsFromDirectory(realdir, 'foo') - self.assertItemsEqual(x.Runs(), ['foo/.', 'foo/subdir']) - - def testAddRunsFromDirectoryWalksTree(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - join = os.path.join - realdir = join(tmpdir, 'event_containing_directory') - - _CreateCleanDirectory(realdir) - _AddEvents(realdir) - sub = join(realdir, 'subdirectory') - sub1 = join(sub, '1') - sub2 = join(sub, '2') - sub1_1 = join(sub1, '1') - _AddEvents(sub1) - _AddEvents(sub2) - _AddEvents(sub1_1) - x.AddRunsFromDirectory(realdir) - - self.assertItemsEqual(x.Runs(), ['.', 'subdirectory/1', 'subdirectory/2', - 'subdirectory/1/1']) - - def 
testAddRunsFromDirectoryThrowsException(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - - filepath = _AddEvents(tmpdir) - with self.assertRaises(ValueError): - x.AddRunsFromDirectory(filepath) - - def testAddRun(self): - x = event_multiplexer.EventMultiplexer() - x.AddRun('run1_path', 'run1') - run1 = x.GetAccumulator('run1') - self.assertEqual(sorted(x.Runs().keys()), ['run1']) - self.assertEqual(run1._path, 'run1_path') - - x.AddRun('run1_path', 'run1') - self.assertEqual(run1, x.GetAccumulator('run1'), 'loader not recreated') - - x.AddRun('run2_path', 'run1') - new_run1 = x.GetAccumulator('run1') - self.assertEqual(new_run1._path, 'run2_path') - self.assertNotEqual(run1, new_run1) - - x.AddRun('runName3') - self.assertItemsEqual(sorted(x.Runs().keys()), ['run1', 'runName3']) - self.assertEqual(x.GetAccumulator('runName3')._path, 'runName3') - - def testAddRunMaintainsLoading(self): - x = event_multiplexer.EventMultiplexer() - x.Reload() - x.AddRun('run1') - x.AddRun('run2') - self.assertTrue(x.GetAccumulator('run1').reload_called) - self.assertTrue(x.GetAccumulator('run2').reload_called) + def setUp(self): + super(EventMultiplexerTest, self).setUp() + self.stubs = tf.compat.v1.test.StubOutForTesting() + + self.stubs.Set( + event_accumulator, "EventAccumulator", _GetFakeAccumulator + ) + + def tearDown(self): + self.stubs.CleanUp() + + def testEmptyLoader(self): + """Tests empty EventMultiplexer creation.""" + x = event_multiplexer.EventMultiplexer() + self.assertEqual(x.Runs(), {}) + + def testRunNamesRespected(self): + """Tests two EventAccumulators inserted/accessed in + EventMultiplexer.""" + x = event_multiplexer.EventMultiplexer( + {"run1": "path1", "run2": "path2"} + ) + self.assertItemsEqual(sorted(x.Runs().keys()), ["run1", "run2"]) + self.assertEqual(x.GetAccumulator("run1")._path, "path1") + self.assertEqual(x.GetAccumulator("run2")._path, "path2") + + def testReload(self): + """EventAccumulators should Reload after 
EventMultiplexer call it.""" + x = event_multiplexer.EventMultiplexer( + {"run1": "path1", "run2": "path2"} + ) + self.assertFalse(x.GetAccumulator("run1").reload_called) + self.assertFalse(x.GetAccumulator("run2").reload_called) + x.Reload() + self.assertTrue(x.GetAccumulator("run1").reload_called) + self.assertTrue(x.GetAccumulator("run2").reload_called) + + def testScalars(self): + """Tests Scalars function returns suitable values.""" + x = event_multiplexer.EventMultiplexer( + {"run1": "path1", "run2": "path2"} + ) + + run1_actual = x.Scalars("run1", "sv1") + run1_expected = ["path1/sv1"] + + self.assertEqual(run1_expected, run1_actual) + + def testPluginRunToTagToContent(self): + """Tests the method that produces the run to tag to content mapping.""" + x = event_multiplexer.EventMultiplexer( + {"run1": "path1", "run2": "path2"} + ) + self.assertDictEqual( + { + "run1": { + "path1_foo": "foo_content", + "path1_bar": "bar_content", + }, + "run2": { + "path2_foo": "foo_content", + "path2_bar": "bar_content", + }, + }, + x.PluginRunToTagToContent("baz_plugin"), + ) + + def testExceptions(self): + """KeyError should be raised when accessing non-existing keys.""" + x = event_multiplexer.EventMultiplexer( + {"run1": "path1", "run2": "path2"} + ) + with self.assertRaises(KeyError): + x.Scalars("sv1", "xxx") + + def testInitialization(self): + """Tests EventMultiplexer is created properly with its params.""" + x = event_multiplexer.EventMultiplexer() + self.assertEqual(x.Runs(), {}) + x = event_multiplexer.EventMultiplexer( + {"run1": "path1", "run2": "path2"} + ) + self.assertItemsEqual(x.Runs(), ["run1", "run2"]) + self.assertEqual(x.GetAccumulator("run1")._path, "path1") + self.assertEqual(x.GetAccumulator("run2")._path, "path2") + + def testAddRunsFromDirectory(self): + """Tests AddRunsFromDirectory function. + + Tests the following scenarios: + - When the directory does not exist. + - When the directory is empty. + - When the directory has empty subdirectory. 
+ - Contains proper EventAccumulators after adding events. + """ + x = event_multiplexer.EventMultiplexer() + tmpdir = self.get_temp_dir() + join = os.path.join + fakedir = join(tmpdir, "fake_accumulator_directory") + realdir = join(tmpdir, "real_accumulator_directory") + self.assertEqual(x.Runs(), {}) + x.AddRunsFromDirectory(fakedir) + self.assertEqual(x.Runs(), {}, "loading fakedir had no effect") + + _CreateCleanDirectory(realdir) + x.AddRunsFromDirectory(realdir) + self.assertEqual(x.Runs(), {}, "loading empty directory had no effect") + + path1 = join(realdir, "path1") + tf.io.gfile.mkdir(path1) + x.AddRunsFromDirectory(realdir) + self.assertEqual( + x.Runs(), {}, "creating empty subdirectory had no effect" + ) + + _AddEvents(path1) + x.AddRunsFromDirectory(realdir) + self.assertItemsEqual(x.Runs(), ["path1"], "loaded run: path1") + loader1 = x.GetAccumulator("path1") + self.assertEqual(loader1._path, path1, "has the correct path") + + path2 = join(realdir, "path2") + _AddEvents(path2) + x.AddRunsFromDirectory(realdir) + self.assertItemsEqual(x.Runs(), ["path1", "path2"]) + self.assertEqual( + x.GetAccumulator("path1"), loader1, "loader1 not regenerated" + ) + + path2_2 = join(path2, "path2") + _AddEvents(path2_2) + x.AddRunsFromDirectory(realdir) + self.assertItemsEqual(x.Runs(), ["path1", "path2", "path2/path2"]) + self.assertEqual( + x.GetAccumulator("path2/path2")._path, + path2_2, + "loader2 path correct", + ) + + def testAddRunsFromDirectoryThatContainsEvents(self): + x = event_multiplexer.EventMultiplexer() + tmpdir = self.get_temp_dir() + join = os.path.join + realdir = join(tmpdir, "event_containing_directory") + + _CreateCleanDirectory(realdir) + + self.assertEqual(x.Runs(), {}) + + _AddEvents(realdir) + x.AddRunsFromDirectory(realdir) + self.assertItemsEqual(x.Runs(), ["."]) + + subdir = join(realdir, "subdir") + _AddEvents(subdir) + x.AddRunsFromDirectory(realdir) + self.assertItemsEqual(x.Runs(), [".", "subdir"]) + + def 
testAddRunsFromDirectoryWithRunNames(self): + x = event_multiplexer.EventMultiplexer() + tmpdir = self.get_temp_dir() + join = os.path.join + realdir = join(tmpdir, "event_containing_directory") + + _CreateCleanDirectory(realdir) + + self.assertEqual(x.Runs(), {}) + + _AddEvents(realdir) + x.AddRunsFromDirectory(realdir, "foo") + self.assertItemsEqual(x.Runs(), ["foo/."]) + + subdir = join(realdir, "subdir") + _AddEvents(subdir) + x.AddRunsFromDirectory(realdir, "foo") + self.assertItemsEqual(x.Runs(), ["foo/.", "foo/subdir"]) + + def testAddRunsFromDirectoryWalksTree(self): + x = event_multiplexer.EventMultiplexer() + tmpdir = self.get_temp_dir() + join = os.path.join + realdir = join(tmpdir, "event_containing_directory") + + _CreateCleanDirectory(realdir) + _AddEvents(realdir) + sub = join(realdir, "subdirectory") + sub1 = join(sub, "1") + sub2 = join(sub, "2") + sub1_1 = join(sub1, "1") + _AddEvents(sub1) + _AddEvents(sub2) + _AddEvents(sub1_1) + x.AddRunsFromDirectory(realdir) + + self.assertItemsEqual( + x.Runs(), + [".", "subdirectory/1", "subdirectory/2", "subdirectory/1/1"], + ) + + def testAddRunsFromDirectoryThrowsException(self): + x = event_multiplexer.EventMultiplexer() + tmpdir = self.get_temp_dir() + + filepath = _AddEvents(tmpdir) + with self.assertRaises(ValueError): + x.AddRunsFromDirectory(filepath) + + def testAddRun(self): + x = event_multiplexer.EventMultiplexer() + x.AddRun("run1_path", "run1") + run1 = x.GetAccumulator("run1") + self.assertEqual(sorted(x.Runs().keys()), ["run1"]) + self.assertEqual(run1._path, "run1_path") + + x.AddRun("run1_path", "run1") + self.assertEqual(run1, x.GetAccumulator("run1"), "loader not recreated") + + x.AddRun("run2_path", "run1") + new_run1 = x.GetAccumulator("run1") + self.assertEqual(new_run1._path, "run2_path") + self.assertNotEqual(run1, new_run1) + + x.AddRun("runName3") + self.assertItemsEqual(sorted(x.Runs().keys()), ["run1", "runName3"]) + self.assertEqual(x.GetAccumulator("runName3")._path, 
"runName3") + + def testAddRunMaintainsLoading(self): + x = event_multiplexer.EventMultiplexer() + x.Reload() + x.AddRun("run1") + x.AddRun("run2") + self.assertTrue(x.GetAccumulator("run1").reload_called) + self.assertTrue(x.GetAccumulator("run2").reload_called) class EventMultiplexerWithRealAccumulatorTest(tf.test.TestCase): + def testDeletingDirectoryRemovesRun(self): + x = event_multiplexer.EventMultiplexer() + tmpdir = self.get_temp_dir() + join = os.path.join + run1_dir = join(tmpdir, "run1") + run2_dir = join(tmpdir, "run2") + run3_dir = join(tmpdir, "run3") - def testDeletingDirectoryRemovesRun(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - join = os.path.join - run1_dir = join(tmpdir, 'run1') - run2_dir = join(tmpdir, 'run2') - run3_dir = join(tmpdir, 'run3') - - for dirname in [run1_dir, run2_dir, run3_dir]: - _AddEvents(dirname) + for dirname in [run1_dir, run2_dir, run3_dir]: + _AddEvents(dirname) - x.AddRun(run1_dir, 'run1') - x.AddRun(run2_dir, 'run2') - x.AddRun(run3_dir, 'run3') + x.AddRun(run1_dir, "run1") + x.AddRun(run2_dir, "run2") + x.AddRun(run3_dir, "run3") - x.Reload() + x.Reload() - # Delete the directory, then reload. - shutil.rmtree(run2_dir) - x.Reload() - self.assertNotIn('run2', x.Runs().keys()) + # Delete the directory, then reload. 
+ shutil.rmtree(run2_dir) + x.Reload() + self.assertNotIn("run2", x.Runs().keys()) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/backend/event_processing/io_wrapper.py b/tensorboard/backend/event_processing/io_wrapper.py index 7c78ff2bdc..eebd8c79db 100644 --- a/tensorboard/backend/event_processing/io_wrapper.py +++ b/tensorboard/backend/event_processing/io_wrapper.py @@ -29,175 +29,188 @@ logger = tb_logging.get_logger() -_ESCAPE_GLOB_CHARACTERS_REGEX = re.compile('([*?[])') +_ESCAPE_GLOB_CHARACTERS_REGEX = re.compile("([*?[])") def IsCloudPath(path): - return ( - path.startswith("gs://") or - path.startswith("s3://") or - path.startswith("/cns/") - ) + return ( + path.startswith("gs://") + or path.startswith("s3://") + or path.startswith("/cns/") + ) + def PathSeparator(path): - return '/' if IsCloudPath(path) else os.sep + return "/" if IsCloudPath(path) else os.sep + def IsTensorFlowEventsFile(path): - """Check the path name to see if it is probably a TF Events file. + """Check the path name to see if it is probably a TF Events file. - Args: - path: A file path to check if it is an event file. + Args: + path: A file path to check if it is an event file. - Raises: - ValueError: If the path is an empty string. + Raises: + ValueError: If the path is an empty string. - Returns: - If path is formatted like a TensorFlowEventsFile. - """ - if not path: - raise ValueError('Path must be a nonempty string') - return 'tfevents' in tf.compat.as_str_any(os.path.basename(path)) + Returns: + If path is formatted like a TensorFlowEventsFile. + """ + if not path: + raise ValueError("Path must be a nonempty string") + return "tfevents" in tf.compat.as_str_any(os.path.basename(path)) def ListDirectoryAbsolute(directory): - """Yields all files in the given directory. 
The paths are absolute.""" - return (os.path.join(directory, path) - for path in tf.io.gfile.listdir(directory)) + """Yields all files in the given directory. + + The paths are absolute. + """ + return ( + os.path.join(directory, path) for path in tf.io.gfile.listdir(directory) + ) def _EscapeGlobCharacters(path): - """Escapes the glob characters in a path. + """Escapes the glob characters in a path. - Python 3 has a glob.escape method, but python 2 lacks it, so we manually - implement this method. + Python 3 has a glob.escape method, but python 2 lacks it, so we manually + implement this method. - Args: - path: The absolute path to escape. + Args: + path: The absolute path to escape. - Returns: - The escaped path string. - """ - drive, path = os.path.splitdrive(path) - return '%s%s' % (drive, _ESCAPE_GLOB_CHARACTERS_REGEX.sub(r'[\1]', path)) + Returns: + The escaped path string. + """ + drive, path = os.path.splitdrive(path) + return "%s%s" % (drive, _ESCAPE_GLOB_CHARACTERS_REGEX.sub(r"[\1]", path)) def ListRecursivelyViaGlobbing(top): - """Recursively lists all files within the directory. - - This method does not list subdirectories (in addition to regular files), and - the file paths are all absolute. If the directory does not exist, this yields - nothing. - - This method does so by glob-ing deeper and deeper directories, ie - foo/*, foo/*/*, foo/*/*/* and so on until all files are listed. All file - paths are absolute, and this method lists subdirectories too. - - For certain file systems, globbing via this method may prove significantly - faster than recursively walking a directory. Specifically, TF file systems - that implement TensorFlow's FileSystem.GetMatchingPaths method could save - costly disk reads by using this method. However, for other file systems, this - method might prove slower because the file system performs a walk per call to - glob (in which case it might as well just perform 1 walk). - - Args: - top: A path to a directory. 
- - Yields: - A (dir_path, file_paths) tuple for each directory/subdirectory. - """ - current_glob_string = os.path.join(_EscapeGlobCharacters(top), '*') - level = 0 - - while True: - logger.info('GlobAndListFiles: Starting to glob level %d', level) - glob = tf.io.gfile.glob(current_glob_string) - logger.info( - 'GlobAndListFiles: %d files glob-ed at level %d', len(glob), level) - - if not glob: - # This subdirectory level lacks files. Terminate. - return - - # Map subdirectory to a list of files. - pairs = collections.defaultdict(list) - for file_path in glob: - pairs[os.path.dirname(file_path)].append(file_path) - for dir_name, file_paths in six.iteritems(pairs): - yield (dir_name, tuple(file_paths)) - - if len(pairs) == 1: - # If at any point the glob returns files that are all in a single - # directory, replace the current globbing path with that directory as the - # literal prefix. This should improve efficiency in cases where a single - # subdir is significantly deeper than the rest of the sudirs. - current_glob_string = os.path.join(list(pairs.keys())[0], '*') - - # Iterate to the next level of subdirectories. - current_glob_string = os.path.join(current_glob_string, '*') - level += 1 + """Recursively lists all files within the directory. + + This method does not list subdirectories (in addition to regular files), and + the file paths are all absolute. If the directory does not exist, this yields + nothing. + + This method does so by glob-ing deeper and deeper directories, ie + foo/*, foo/*/*, foo/*/*/* and so on until all files are listed. All file + paths are absolute, and this method lists subdirectories too. + + For certain file systems, globbing via this method may prove significantly + faster than recursively walking a directory. Specifically, TF file systems + that implement TensorFlow's FileSystem.GetMatchingPaths method could save + costly disk reads by using this method. 
However, for other file systems, this + method might prove slower because the file system performs a walk per call to + glob (in which case it might as well just perform 1 walk). + + Args: + top: A path to a directory. + + Yields: + A (dir_path, file_paths) tuple for each directory/subdirectory. + """ + current_glob_string = os.path.join(_EscapeGlobCharacters(top), "*") + level = 0 + + while True: + logger.info("GlobAndListFiles: Starting to glob level %d", level) + glob = tf.io.gfile.glob(current_glob_string) + logger.info( + "GlobAndListFiles: %d files glob-ed at level %d", len(glob), level + ) + + if not glob: + # This subdirectory level lacks files. Terminate. + return + + # Map subdirectory to a list of files. + pairs = collections.defaultdict(list) + for file_path in glob: + pairs[os.path.dirname(file_path)].append(file_path) + for dir_name, file_paths in six.iteritems(pairs): + yield (dir_name, tuple(file_paths)) + + if len(pairs) == 1: + # If at any point the glob returns files that are all in a single + # directory, replace the current globbing path with that directory as the + # literal prefix. This should improve efficiency in cases where a single + # subdir is significantly deeper than the rest of the sudirs. + current_glob_string = os.path.join(list(pairs.keys())[0], "*") + + # Iterate to the next level of subdirectories. + current_glob_string = os.path.join(current_glob_string, "*") + level += 1 def ListRecursivelyViaWalking(top): - """Walks a directory tree, yielding (dir_path, file_paths) tuples. + """Walks a directory tree, yielding (dir_path, file_paths) tuples. - For each of `top` and its subdirectories, yields a tuple containing the path - to the directory and the path to each of the contained files. Note that - unlike os.Walk()/tf.io.gfile.walk()/ListRecursivelyViaGlobbing, this does not - list subdirectories. The file paths are all absolute. If the directory does - not exist, this yields nothing. 
+ For each of `top` and its subdirectories, yields a tuple containing the path + to the directory and the path to each of the contained files. Note that + unlike os.Walk()/tf.io.gfile.walk()/ListRecursivelyViaGlobbing, this does not + list subdirectories. The file paths are all absolute. If the directory does + not exist, this yields nothing. - Walking may be incredibly slow on certain file systems. + Walking may be incredibly slow on certain file systems. - Args: - top: A path to a directory. + Args: + top: A path to a directory. - Yields: - A (dir_path, file_paths) tuple for each directory/subdirectory. - """ - for dir_path, _, filenames in tf.io.gfile.walk(top, topdown=True): - yield (dir_path, (os.path.join(dir_path, filename) - for filename in filenames)) + Yields: + A (dir_path, file_paths) tuple for each directory/subdirectory. + """ + for dir_path, _, filenames in tf.io.gfile.walk(top, topdown=True): + yield ( + dir_path, + (os.path.join(dir_path, filename) for filename in filenames), + ) def GetLogdirSubdirectories(path): - """Obtains all subdirectories with events files. - - The order of the subdirectories returned is unspecified. The internal logic - that determines order varies by scenario. - - Args: - path: The path to a directory under which to find subdirectories. - - Returns: - A tuple of absolute paths of all subdirectories each with at least 1 events - file directly within the subdirectory. - - Raises: - ValueError: If the path passed to the method exists and is not a directory. - """ - if not tf.io.gfile.exists(path): - # No directory to traverse. - return () - - if not tf.io.gfile.isdir(path): - raise ValueError('GetLogdirSubdirectories: path exists and is not a ' - 'directory, %s' % path) - - if IsCloudPath(path): - # Glob-ing for files can be significantly faster than recursively - # walking through directories for some file systems. 
- logger.info( - 'GetLogdirSubdirectories: Starting to list directories via glob-ing.') - traversal_method = ListRecursivelyViaGlobbing - else: - # For other file systems, the glob-ing based method might be slower because - # each call to glob could involve performing a recursive walk. - logger.info( - 'GetLogdirSubdirectories: Starting to list directories via walking.') - traversal_method = ListRecursivelyViaWalking - - return ( - subdir - for (subdir, files) in traversal_method(path) - if any(IsTensorFlowEventsFile(f) for f in files) - ) + """Obtains all subdirectories with events files. + + The order of the subdirectories returned is unspecified. The internal logic + that determines order varies by scenario. + + Args: + path: The path to a directory under which to find subdirectories. + + Returns: + A tuple of absolute paths of all subdirectories each with at least 1 events + file directly within the subdirectory. + + Raises: + ValueError: If the path passed to the method exists and is not a directory. + """ + if not tf.io.gfile.exists(path): + # No directory to traverse. + return () + + if not tf.io.gfile.isdir(path): + raise ValueError( + "GetLogdirSubdirectories: path exists and is not a " + "directory, %s" % path + ) + + if IsCloudPath(path): + # Glob-ing for files can be significantly faster than recursively + # walking through directories for some file systems. + logger.info( + "GetLogdirSubdirectories: Starting to list directories via glob-ing." + ) + traversal_method = ListRecursivelyViaGlobbing + else: + # For other file systems, the glob-ing based method might be slower because + # each call to glob could involve performing a recursive walk. + logger.info( + "GetLogdirSubdirectories: Starting to list directories via walking." 
+ ) + traversal_method = ListRecursivelyViaWalking + + return ( + subdir + for (subdir, files) in traversal_method(path) + if any(IsTensorFlowEventsFile(f) for f in files) + ) diff --git a/tensorboard/backend/event_processing/io_wrapper_test.py b/tensorboard/backend/event_processing/io_wrapper_test.py index 3642275a7f..ab2aed8097 100644 --- a/tensorboard/backend/event_processing/io_wrapper_test.py +++ b/tensorboard/backend/event_processing/io_wrapper_test.py @@ -27,303 +27,268 @@ class IoWrapperTest(tf.test.TestCase): - def setUp(self): - self.stubs = tf.compat.v1.test.StubOutForTesting() - - def tearDown(self): - self.stubs.CleanUp() - - def testIsCloudPathGcsIsTrue(self): - self.assertTrue(io_wrapper.IsCloudPath('gs://bucket/foo')) - - def testIsCloudPathS3IsTrue(self): - self.assertTrue(io_wrapper.IsCloudPath('s3://bucket/foo')) - - def testIsCloudPathCnsIsTrue(self): - self.assertTrue(io_wrapper.IsCloudPath('/cns/foo/bar')) - - def testIsCloudPathFileIsFalse(self): - self.assertFalse(io_wrapper.IsCloudPath('file:///tmp/foo')) - - def testIsCloudPathLocalIsFalse(self): - self.assertFalse(io_wrapper.IsCloudPath('/tmp/foo')) - - def testPathSeparator(self): - # In nix systems, path separator would be the same as that of CNS/GCS - # making it hard to tell if something went wrong. 
- self.stubs.Set(os, 'sep', '#') - - self.assertEqual(io_wrapper.PathSeparator('/tmp/foo'), '#') - self.assertEqual(io_wrapper.PathSeparator('tmp/foo'), '#') - self.assertEqual(io_wrapper.PathSeparator('/cns/tmp/foo'), '/') - self.assertEqual(io_wrapper.PathSeparator('gs://foo'), '/') - - def testIsIsTensorFlowEventsFileTrue(self): - self.assertTrue( - io_wrapper.IsTensorFlowEventsFile( - '/logdir/events.out.tfevents.1473720042.com')) - - def testIsIsTensorFlowEventsFileFalse(self): - self.assertFalse( - io_wrapper.IsTensorFlowEventsFile('/logdir/model.ckpt')) - - def testIsIsTensorFlowEventsFileWithEmptyInput(self): - with six.assertRaisesRegex(self, - ValueError, - r'Path must be a nonempty string'): - io_wrapper.IsTensorFlowEventsFile('') - - def testListDirectoryAbsolute(self): - temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir()) - self._CreateDeepDirectoryStructure(temp_dir) - - expected_files = ( - 'foo', - 'bar', - 'quuz', - 'a.tfevents.1', - 'model.ckpt', - 'waldo', - ) - self.assertItemsEqual( - (os.path.join(temp_dir, f) for f in expected_files), - io_wrapper.ListDirectoryAbsolute(temp_dir)) - - def testListRecursivelyViaGlobbing(self): - temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir()) - self._CreateDeepDirectoryStructure(temp_dir) - expected = [ - ['', [ - 'foo', - 'bar', - 'a.tfevents.1', - 'model.ckpt', - 'quuz', - 'waldo', - ]], - ['bar', [ - 'b.tfevents.1', - 'red_herring.txt', - 'baz', - 'quux', - ]], - ['bar/baz', [ - 'c.tfevents.1', - 'd.tfevents.1', - ]], - ['bar/quux', [ - 'some_flume_output.txt', - 'some_more_flume_output.txt', - ]], - ['quuz', [ - 'e.tfevents.1', - 'garply', - ]], - ['quuz/garply', [ - 'f.tfevents.1', - 'corge', - 'grault', - ]], - ['quuz/garply/corge', [ - 'g.tfevents.1' - ]], - ['quuz/garply/grault', [ - 'h.tfevents.1', - ]], - ['waldo', [ - 'fred', - ]], - ['waldo/fred', [ - 'i.tfevents.1', - ]], - ] - for pair in expected: - # If this is not the top-level directory, prepend the high-level - # directory. 
- pair[0] = os.path.join(temp_dir, pair[0]) if pair[0] else temp_dir - pair[1] = [os.path.join(pair[0], f) for f in pair[1]] - self._CompareFilesPerSubdirectory( - expected, io_wrapper.ListRecursivelyViaGlobbing(temp_dir)) - - def testListRecursivelyViaGlobbingForPathWithGlobCharacters(self): - temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir()) - directory_names = ( - 'ba*', - 'ba*/subdirectory', - 'bar', - ) - for directory_name in directory_names: - os.makedirs(os.path.join(temp_dir, directory_name)) - - file_names = ( - 'ba*/a.tfevents.1', - 'ba*/subdirectory/b.tfevents.1', - 'bar/c.tfevents.1', - ) - for file_name in file_names: - open(os.path.join(temp_dir, file_name), 'w').close() - - expected = [ - ['', [ - 'a.tfevents.1', - 'subdirectory', - ]], - ['subdirectory', [ - 'b.tfevents.1', - ]], - # The contents of the bar subdirectory should be excluded from - # this listing because the * character should have been escaped. - ] - top = os.path.join(temp_dir, 'ba*') - for pair in expected: - # If this is not the top-level directory, prepend the high-level - # directory. 
- pair[0] = os.path.join(top, pair[0]) if pair[0] else top - pair[1] = [os.path.join(pair[0], f) for f in pair[1]] - self._CompareFilesPerSubdirectory( - expected, io_wrapper.ListRecursivelyViaGlobbing(top)) - - def testListRecursivelyViaWalking(self): - temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir()) - self._CreateDeepDirectoryStructure(temp_dir) - expected = [ - ['', [ - 'a.tfevents.1', - 'model.ckpt', - ]], - ['foo', []], - ['bar', [ - 'b.tfevents.1', - 'red_herring.txt', - ]], - ['bar/baz', [ - 'c.tfevents.1', - 'd.tfevents.1', - ]], - ['bar/quux', [ - 'some_flume_output.txt', - 'some_more_flume_output.txt', - ]], - ['quuz', [ - 'e.tfevents.1', - ]], - ['quuz/garply', [ - 'f.tfevents.1', - ]], - ['quuz/garply/corge', [ - 'g.tfevents.1', - ]], - ['quuz/garply/grault', [ - 'h.tfevents.1', - ]], - ['waldo', []], - ['waldo/fred', [ - 'i.tfevents.1', - ]], - ] - for pair in expected: - # If this is not the top-level directory, prepend the high-level - # directory. - pair[0] = os.path.join(temp_dir, pair[0]) if pair[0] else temp_dir - pair[1] = [os.path.join(pair[0], f) for f in pair[1]] - self._CompareFilesPerSubdirectory( - expected, io_wrapper.ListRecursivelyViaWalking(temp_dir)) - - def testGetLogdirSubdirectories(self): - temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir()) - self._CreateDeepDirectoryStructure(temp_dir) - # Only subdirectories that immediately contains at least 1 events - # file should be listed. - expected = [ - '', - 'bar', - 'bar/baz', - 'quuz', - 'quuz/garply', - 'quuz/garply/corge', - 'quuz/garply/grault', - 'waldo/fred', - ] - self.assertItemsEqual( - [(os.path.join(temp_dir, subdir) if subdir else temp_dir) - for subdir in expected], - io_wrapper.GetLogdirSubdirectories(temp_dir)) - - def _CreateDeepDirectoryStructure(self, top_directory): - """Creates a reasonable deep structure of subdirectories with files. - - Args: - top_directory: The absolute path of the top level directory in - which to create the directory structure. 
- """ - # Add a few subdirectories. - directory_names = ( - # An empty directory. - 'foo', - # A directory with an events file (and a text file). - 'bar', - # A deeper directory with events files. - 'bar/baz', - # A non-empty subdirectory that lacks event files (should be ignored). - 'bar/quux', - # This 3-level deep set of subdirectories tests logic that replaces the - # full glob string with an absolute path prefix if there is only 1 - # subdirectory in the final mapping. - 'quuz/garply', - 'quuz/garply/corge', - 'quuz/garply/grault', - # A directory that lacks events files, but contains a subdirectory - # with events files (first level should be ignored, second level should - # be included). - 'waldo', - 'waldo/fred', - ) - for directory_name in directory_names: - os.makedirs(os.path.join(top_directory, directory_name)) - - # Add a few files to the directory. - file_names = ( - 'a.tfevents.1', - 'model.ckpt', - 'bar/b.tfevents.1', - 'bar/red_herring.txt', - 'bar/baz/c.tfevents.1', - 'bar/baz/d.tfevents.1', - 'bar/quux/some_flume_output.txt', - 'bar/quux/some_more_flume_output.txt', - 'quuz/e.tfevents.1', - 'quuz/garply/f.tfevents.1', - 'quuz/garply/corge/g.tfevents.1', - 'quuz/garply/grault/h.tfevents.1', - 'waldo/fred/i.tfevents.1', - ) - for file_name in file_names: - open(os.path.join(top_directory, file_name), 'w').close() - - def _CompareFilesPerSubdirectory(self, expected, gotten): - """Compares iterables of (subdirectory path, list of absolute paths) - - Args: - expected: The expected iterable of 2-tuples. - gotten: The gotten iterable of 2-tuples. 
- """ - expected_directory_to_listing = { - result[0]: list(result[1]) for result in expected} - gotten_directory_to_listing = { - result[0]: list(result[1]) for result in gotten} - self.assertItemsEqual( - expected_directory_to_listing.keys(), - gotten_directory_to_listing.keys()) - - for subdirectory, expected_listing in expected_directory_to_listing.items(): - gotten_listing = gotten_directory_to_listing[subdirectory] - self.assertItemsEqual( - expected_listing, - gotten_listing, - 'Files for subdirectory %r must match. Expected %r. Got %r.' % ( - subdirectory, expected_listing, gotten_listing)) - - - -if __name__ == '__main__': - tf.test.main() + def setUp(self): + self.stubs = tf.compat.v1.test.StubOutForTesting() + + def tearDown(self): + self.stubs.CleanUp() + + def testIsCloudPathGcsIsTrue(self): + self.assertTrue(io_wrapper.IsCloudPath("gs://bucket/foo")) + + def testIsCloudPathS3IsTrue(self): + self.assertTrue(io_wrapper.IsCloudPath("s3://bucket/foo")) + + def testIsCloudPathCnsIsTrue(self): + self.assertTrue(io_wrapper.IsCloudPath("/cns/foo/bar")) + + def testIsCloudPathFileIsFalse(self): + self.assertFalse(io_wrapper.IsCloudPath("file:///tmp/foo")) + + def testIsCloudPathLocalIsFalse(self): + self.assertFalse(io_wrapper.IsCloudPath("/tmp/foo")) + + def testPathSeparator(self): + # In nix systems, path separator would be the same as that of CNS/GCS + # making it hard to tell if something went wrong. 
+ self.stubs.Set(os, "sep", "#") + + self.assertEqual(io_wrapper.PathSeparator("/tmp/foo"), "#") + self.assertEqual(io_wrapper.PathSeparator("tmp/foo"), "#") + self.assertEqual(io_wrapper.PathSeparator("/cns/tmp/foo"), "/") + self.assertEqual(io_wrapper.PathSeparator("gs://foo"), "/") + + def testIsIsTensorFlowEventsFileTrue(self): + self.assertTrue( + io_wrapper.IsTensorFlowEventsFile( + "/logdir/events.out.tfevents.1473720042.com" + ) + ) + + def testIsIsTensorFlowEventsFileFalse(self): + self.assertFalse( + io_wrapper.IsTensorFlowEventsFile("/logdir/model.ckpt") + ) + + def testIsIsTensorFlowEventsFileWithEmptyInput(self): + with six.assertRaisesRegex( + self, ValueError, r"Path must be a nonempty string" + ): + io_wrapper.IsTensorFlowEventsFile("") + + def testListDirectoryAbsolute(self): + temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir()) + self._CreateDeepDirectoryStructure(temp_dir) + + expected_files = ( + "foo", + "bar", + "quuz", + "a.tfevents.1", + "model.ckpt", + "waldo", + ) + self.assertItemsEqual( + (os.path.join(temp_dir, f) for f in expected_files), + io_wrapper.ListDirectoryAbsolute(temp_dir), + ) + + def testListRecursivelyViaGlobbing(self): + temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir()) + self._CreateDeepDirectoryStructure(temp_dir) + expected = [ + [ + "", + ["foo", "bar", "a.tfevents.1", "model.ckpt", "quuz", "waldo",], + ], + ["bar", ["b.tfevents.1", "red_herring.txt", "baz", "quux",]], + ["bar/baz", ["c.tfevents.1", "d.tfevents.1",]], + [ + "bar/quux", + ["some_flume_output.txt", "some_more_flume_output.txt",], + ], + ["quuz", ["e.tfevents.1", "garply",]], + ["quuz/garply", ["f.tfevents.1", "corge", "grault",]], + ["quuz/garply/corge", ["g.tfevents.1"]], + ["quuz/garply/grault", ["h.tfevents.1",]], + ["waldo", ["fred",]], + ["waldo/fred", ["i.tfevents.1",]], + ] + for pair in expected: + # If this is not the top-level directory, prepend the high-level + # directory. 
+ pair[0] = os.path.join(temp_dir, pair[0]) if pair[0] else temp_dir + pair[1] = [os.path.join(pair[0], f) for f in pair[1]] + self._CompareFilesPerSubdirectory( + expected, io_wrapper.ListRecursivelyViaGlobbing(temp_dir) + ) + + def testListRecursivelyViaGlobbingForPathWithGlobCharacters(self): + temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir()) + directory_names = ( + "ba*", + "ba*/subdirectory", + "bar", + ) + for directory_name in directory_names: + os.makedirs(os.path.join(temp_dir, directory_name)) + + file_names = ( + "ba*/a.tfevents.1", + "ba*/subdirectory/b.tfevents.1", + "bar/c.tfevents.1", + ) + for file_name in file_names: + open(os.path.join(temp_dir, file_name), "w").close() + + expected = [ + ["", ["a.tfevents.1", "subdirectory",]], + ["subdirectory", ["b.tfevents.1",]], + # The contents of the bar subdirectory should be excluded from + # this listing because the * character should have been escaped. + ] + top = os.path.join(temp_dir, "ba*") + for pair in expected: + # If this is not the top-level directory, prepend the high-level + # directory. 
+ pair[0] = os.path.join(top, pair[0]) if pair[0] else top + pair[1] = [os.path.join(pair[0], f) for f in pair[1]] + self._CompareFilesPerSubdirectory( + expected, io_wrapper.ListRecursivelyViaGlobbing(top) + ) + + def testListRecursivelyViaWalking(self): + temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir()) + self._CreateDeepDirectoryStructure(temp_dir) + expected = [ + ["", ["a.tfevents.1", "model.ckpt",]], + ["foo", []], + ["bar", ["b.tfevents.1", "red_herring.txt",]], + ["bar/baz", ["c.tfevents.1", "d.tfevents.1",]], + [ + "bar/quux", + ["some_flume_output.txt", "some_more_flume_output.txt",], + ], + ["quuz", ["e.tfevents.1",]], + ["quuz/garply", ["f.tfevents.1",]], + ["quuz/garply/corge", ["g.tfevents.1",]], + ["quuz/garply/grault", ["h.tfevents.1",]], + ["waldo", []], + ["waldo/fred", ["i.tfevents.1",]], + ] + for pair in expected: + # If this is not the top-level directory, prepend the high-level + # directory. + pair[0] = os.path.join(temp_dir, pair[0]) if pair[0] else temp_dir + pair[1] = [os.path.join(pair[0], f) for f in pair[1]] + self._CompareFilesPerSubdirectory( + expected, io_wrapper.ListRecursivelyViaWalking(temp_dir) + ) + + def testGetLogdirSubdirectories(self): + temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir()) + self._CreateDeepDirectoryStructure(temp_dir) + # Only subdirectories that immediately contains at least 1 events + # file should be listed. + expected = [ + "", + "bar", + "bar/baz", + "quuz", + "quuz/garply", + "quuz/garply/corge", + "quuz/garply/grault", + "waldo/fred", + ] + self.assertItemsEqual( + [ + (os.path.join(temp_dir, subdir) if subdir else temp_dir) + for subdir in expected + ], + io_wrapper.GetLogdirSubdirectories(temp_dir), + ) + + def _CreateDeepDirectoryStructure(self, top_directory): + """Creates a reasonable deep structure of subdirectories with files. + + Args: + top_directory: The absolute path of the top level directory in + which to create the directory structure. + """ + # Add a few subdirectories. 
+ directory_names = ( + # An empty directory. + "foo", + # A directory with an events file (and a text file). + "bar", + # A deeper directory with events files. + "bar/baz", + # A non-empty subdirectory that lacks event files (should be ignored). + "bar/quux", + # This 3-level deep set of subdirectories tests logic that replaces the + # full glob string with an absolute path prefix if there is only 1 + # subdirectory in the final mapping. + "quuz/garply", + "quuz/garply/corge", + "quuz/garply/grault", + # A directory that lacks events files, but contains a subdirectory + # with events files (first level should be ignored, second level should + # be included). + "waldo", + "waldo/fred", + ) + for directory_name in directory_names: + os.makedirs(os.path.join(top_directory, directory_name)) + + # Add a few files to the directory. + file_names = ( + "a.tfevents.1", + "model.ckpt", + "bar/b.tfevents.1", + "bar/red_herring.txt", + "bar/baz/c.tfevents.1", + "bar/baz/d.tfevents.1", + "bar/quux/some_flume_output.txt", + "bar/quux/some_more_flume_output.txt", + "quuz/e.tfevents.1", + "quuz/garply/f.tfevents.1", + "quuz/garply/corge/g.tfevents.1", + "quuz/garply/grault/h.tfevents.1", + "waldo/fred/i.tfevents.1", + ) + for file_name in file_names: + open(os.path.join(top_directory, file_name), "w").close() + + def _CompareFilesPerSubdirectory(self, expected, gotten): + """Compares iterables of (subdirectory path, list of absolute paths) + + Args: + expected: The expected iterable of 2-tuples. + gotten: The gotten iterable of 2-tuples. 
+ """ + expected_directory_to_listing = { + result[0]: list(result[1]) for result in expected + } + gotten_directory_to_listing = { + result[0]: list(result[1]) for result in gotten + } + self.assertItemsEqual( + expected_directory_to_listing.keys(), + gotten_directory_to_listing.keys(), + ) + + for ( + subdirectory, + expected_listing, + ) in expected_directory_to_listing.items(): + gotten_listing = gotten_directory_to_listing[subdirectory] + self.assertItemsEqual( + expected_listing, + gotten_listing, + "Files for subdirectory %r must match. Expected %r. Got %r." + % (subdirectory, expected_listing, gotten_listing), + ) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/backend/event_processing/plugin_asset_util.py b/tensorboard/backend/event_processing/plugin_asset_util.py index 4adf4a8f80..b13385c62c 100644 --- a/tensorboard/backend/event_processing/plugin_asset_util.py +++ b/tensorboard/backend/event_processing/plugin_asset_util.py @@ -26,78 +26,83 @@ def _IsDirectory(parent, item): - """Helper that returns if parent/item is a directory.""" - return tf.io.gfile.isdir(os.path.join(parent, item)) + """Helper that returns if parent/item is a directory.""" + return tf.io.gfile.isdir(os.path.join(parent, item)) def PluginDirectory(logdir, plugin_name): - """Returns the plugin directory for plugin_name.""" - return os.path.join(logdir, _PLUGINS_DIR, plugin_name) + """Returns the plugin directory for plugin_name.""" + return os.path.join(logdir, _PLUGINS_DIR, plugin_name) def ListPlugins(logdir): - """List all the plugins that have registered assets in logdir. - - If the plugins_dir does not exist, it returns an empty list. This maintains - compatibility with old directories that have no plugins written. - - Args: - logdir: A directory that was created by a TensorFlow events writer. 
- - Returns: - a list of plugin names, as strings - """ - plugins_dir = os.path.join(logdir, _PLUGINS_DIR) - try: - entries = tf.io.gfile.listdir(plugins_dir) - except tf.errors.NotFoundError: - return [] - # Strip trailing slashes, which listdir() includes for some filesystems - # for subdirectories, after using them to bypass IsDirectory(). - return [x.rstrip('/') for x in entries - if x.endswith('/') or _IsDirectory(plugins_dir, x)] + """List all the plugins that have registered assets in logdir. + + If the plugins_dir does not exist, it returns an empty list. This maintains + compatibility with old directories that have no plugins written. + + Args: + logdir: A directory that was created by a TensorFlow events writer. + + Returns: + a list of plugin names, as strings + """ + plugins_dir = os.path.join(logdir, _PLUGINS_DIR) + try: + entries = tf.io.gfile.listdir(plugins_dir) + except tf.errors.NotFoundError: + return [] + # Strip trailing slashes, which listdir() includes for some filesystems + # for subdirectories, after using them to bypass IsDirectory(). + return [ + x.rstrip("/") + for x in entries + if x.endswith("/") or _IsDirectory(plugins_dir, x) + ] def ListAssets(logdir, plugin_name): - """List all the assets that are available for given plugin in a logdir. + """List all the assets that are available for given plugin in a logdir. - Args: - logdir: A directory that was created by a TensorFlow summary.FileWriter. - plugin_name: A string name of a plugin to list assets for. + Args: + logdir: A directory that was created by a TensorFlow summary.FileWriter. + plugin_name: A string name of a plugin to list assets for. - Returns: - A string list of available plugin assets. If the plugin subdirectory does - not exist (either because the logdir doesn't exist, or because the plugin - didn't register) an empty list is returned. 
- """ - plugin_dir = PluginDirectory(logdir, plugin_name) - try: - # Strip trailing slashes, which listdir() includes for some filesystems. - return [x.rstrip('/') for x in tf.io.gfile.listdir(plugin_dir)] - except tf.errors.NotFoundError: - return [] + Returns: + A string list of available plugin assets. If the plugin subdirectory does + not exist (either because the logdir doesn't exist, or because the plugin + didn't register) an empty list is returned. + """ + plugin_dir = PluginDirectory(logdir, plugin_name) + try: + # Strip trailing slashes, which listdir() includes for some filesystems. + return [x.rstrip("/") for x in tf.io.gfile.listdir(plugin_dir)] + except tf.errors.NotFoundError: + return [] def RetrieveAsset(logdir, plugin_name, asset_name): - """Retrieve a particular plugin asset from a logdir. - - Args: - logdir: A directory that was created by a TensorFlow summary.FileWriter. - plugin_name: The plugin we want an asset from. - asset_name: The name of the requested asset. - - Returns: - string contents of the plugin asset. - - Raises: - KeyError: if the asset does not exist. - """ - - asset_path = os.path.join(PluginDirectory(logdir, plugin_name), asset_name) - try: - with tf.io.gfile.GFile(asset_path, "r") as f: - return f.read() - except tf.errors.NotFoundError: - raise KeyError("Asset path %s not found" % asset_path) - except tf.errors.OpError as e: - raise KeyError("Couldn't read asset path: %s, OpError %s" % (asset_path, e)) + """Retrieve a particular plugin asset from a logdir. + + Args: + logdir: A directory that was created by a TensorFlow summary.FileWriter. + plugin_name: The plugin we want an asset from. + asset_name: The name of the requested asset. + + Returns: + string contents of the plugin asset. + + Raises: + KeyError: if the asset does not exist. 
+ """ + + asset_path = os.path.join(PluginDirectory(logdir, plugin_name), asset_name) + try: + with tf.io.gfile.GFile(asset_path, "r") as f: + return f.read() + except tf.errors.NotFoundError: + raise KeyError("Asset path %s not found" % asset_path) + except tf.errors.OpError as e: + raise KeyError( + "Couldn't read asset path: %s, OpError %s" % (asset_path, e) + ) diff --git a/tensorboard/backend/event_processing/plugin_event_accumulator.py b/tensorboard/backend/event_processing/plugin_event_accumulator.py index 8fbea6f53a..8a976fd8df 100644 --- a/tensorboard/backend/event_processing/plugin_event_accumulator.py +++ b/tensorboard/backend/event_processing/plugin_event_accumulator.py @@ -41,14 +41,14 @@ namedtuple = collections.namedtuple -TensorEvent = namedtuple('TensorEvent', ['wall_time', 'step', 'tensor_proto']) +TensorEvent = namedtuple("TensorEvent", ["wall_time", "step", "tensor_proto"]) ## The tagTypes below are just arbitrary strings chosen to pass the type ## information of the tag from the backend to the frontend -TENSORS = 'tensors' -GRAPH = 'graph' -META_GRAPH = 'meta_graph' -RUN_METADATA = 'run_metadata' +TENSORS = "tensors" +GRAPH = "graph" +META_GRAPH = "meta_graph" +RUN_METADATA = "run_metadata" DEFAULT_SIZE_GUIDANCE = { TENSORS: 500, @@ -62,550 +62,604 @@ class EventAccumulator(object): - """An `EventAccumulator` takes an event generator, and accumulates the values. - - The `EventAccumulator` is intended to provide a convenient Python - interface for loading Event data written during a TensorFlow run. - TensorFlow writes out `Event` protobuf objects, which have a timestamp - and step number, and often contain a `Summary`. Summaries can have - different kinds of data stored as arbitrary tensors. The Summaries - also have a tag, which we use to organize logically related data. The - `EventAccumulator` supports retrieving the `Event` and `Summary` data - by its tag. 
- - Calling `Tags()` gets a map from `tagType` (i.e., `tensors`) to the - associated tags for those data types. Then, the functional endpoint - (i.g., `Accumulator.Tensors(tag)`) allows for the retrieval of all - data associated with that tag. - - The `Reload()` method synchronously loads all of the data written so far. - - Fields: - most_recent_step: Step of last Event proto added. This should only - be accessed from the thread that calls Reload. This is -1 if - nothing has been loaded yet. - most_recent_wall_time: Timestamp of last Event proto added. This is - a float containing seconds from the UNIX epoch, or -1 if - nothing has been loaded yet. This should only be accessed from - the thread that calls Reload. - path: A file path to a directory containing tf events files, or a single - tf events file. The accumulator will load events from this path. - tensors_by_tag: A dictionary mapping each tag name to a - reservoir.Reservoir of tensor summaries. Each such reservoir will - only use a single key, given by `_TENSOR_RESERVOIR_KEY`. - - @@Tensors - """ - - def __init__(self, - path, - size_guidance=None, - tensor_size_guidance=None, - purge_orphaned_data=True, - event_file_active_filter=None): - """Construct the `EventAccumulator`. - - Args: + """An `EventAccumulator` takes an event generator, and accumulates the + values. + + The `EventAccumulator` is intended to provide a convenient Python + interface for loading Event data written during a TensorFlow run. + TensorFlow writes out `Event` protobuf objects, which have a timestamp + and step number, and often contain a `Summary`. Summaries can have + different kinds of data stored as arbitrary tensors. The Summaries + also have a tag, which we use to organize logically related data. The + `EventAccumulator` supports retrieving the `Event` and `Summary` data + by its tag. + + Calling `Tags()` gets a map from `tagType` (i.e., `tensors`) to the + associated tags for those data types. 
Then, the functional endpoint + (i.g., `Accumulator.Tensors(tag)`) allows for the retrieval of all + data associated with that tag. + + The `Reload()` method synchronously loads all of the data written so far. + + Fields: + most_recent_step: Step of last Event proto added. This should only + be accessed from the thread that calls Reload. This is -1 if + nothing has been loaded yet. + most_recent_wall_time: Timestamp of last Event proto added. This is + a float containing seconds from the UNIX epoch, or -1 if + nothing has been loaded yet. This should only be accessed from + the thread that calls Reload. path: A file path to a directory containing tf events files, or a single - tf events file. The accumulator will load events from this path. - size_guidance: Information on how much data the EventAccumulator should - store in memory. The DEFAULT_SIZE_GUIDANCE tries not to store too much - so as to avoid OOMing the client. The size_guidance should be a map - from a `tagType` string to an integer representing the number of - items to keep per tag for items of that `tagType`. If the size is 0, - all events are stored. - tensor_size_guidance: Like `size_guidance`, but allowing finer - granularity for tensor summaries. Should be a map from the - `plugin_name` field on the `PluginData` proto to an integer - representing the number of items to keep per tag. Plugins for - which there is no entry in this map will default to the value of - `size_guidance[event_accumulator.TENSORS]`. Defaults to `{}`. - purge_orphaned_data: Whether to discard any events that were "orphaned" by - a TensorFlow restart. - event_file_active_filter: Optional predicate for determining whether an - event file latest load timestamp should be considered active. If passed, - this will enable multifile directory loading. 
- """ - size_guidance = dict(size_guidance or DEFAULT_SIZE_GUIDANCE) - sizes = {} - for key in DEFAULT_SIZE_GUIDANCE: - if key in size_guidance: - sizes[key] = size_guidance[key] - else: - sizes[key] = DEFAULT_SIZE_GUIDANCE[key] - self._size_guidance = size_guidance - self._tensor_size_guidance = dict(tensor_size_guidance or {}) - - self._first_event_timestamp = None - - self._graph = None - self._graph_from_metagraph = False - self._meta_graph = None - self._tagged_metadata = {} - self.summary_metadata = {} - self.tensors_by_tag = {} - self._tensors_by_tag_lock = threading.Lock() - - # Keep a mapping from plugin name to a dict mapping from tag to plugin data - # content obtained from the SummaryMetadata (metadata field of Value) for - # that plugin (This is not the entire SummaryMetadata proto - only the - # content for that plugin). The SummaryWriter only keeps the content on the - # first event encountered per tag, so we must store that first instance of - # content for each tag. - self._plugin_to_tag_to_content = collections.defaultdict(dict) - self._plugin_tag_locks = collections.defaultdict(threading.Lock) - - self.path = path - self._generator = _GeneratorFromPath(path, event_file_active_filter) - self._generator_mutex = threading.Lock() - - self.purge_orphaned_data = purge_orphaned_data - - self.most_recent_step = -1 - self.most_recent_wall_time = -1 - self.file_version = None - - def Reload(self): - """Loads all events added since the last call to `Reload`. - - If `Reload` was never called, loads all events in the file. - - Returns: - The `EventAccumulator`. - """ - with self._generator_mutex: - for event in self._generator.Load(): - self._ProcessEvent(event) - return self - - def PluginAssets(self, plugin_name): - """Return a list of all plugin assets for the given plugin. - - Args: - plugin_name: The string name of a plugin to retrieve assets for. - - Returns: - A list of string plugin asset names, or empty list if none are available. 
- If the plugin was not registered, an empty list is returned. - """ - return plugin_asset_util.ListAssets(self.path, plugin_name) - - def RetrievePluginAsset(self, plugin_name, asset_name): - """Return the contents of a given plugin asset. - - Args: - plugin_name: The string name of a plugin. - asset_name: The string name of an asset. - - Returns: - The string contents of the plugin asset. - - Raises: - KeyError: If the asset is not available. - """ - return plugin_asset_util.RetrieveAsset(self.path, plugin_name, asset_name) - - def FirstEventTimestamp(self): - """Returns the timestamp in seconds of the first event. - - If the first event has been loaded (either by this method or by `Reload`, - this returns immediately. Otherwise, it will load in the first event. Note - that this means that calling `Reload` will cause this to block until - `Reload` has finished. - - Returns: - The timestamp in seconds of the first event that was loaded. + tf events file. The accumulator will load events from this path. + tensors_by_tag: A dictionary mapping each tag name to a + reservoir.Reservoir of tensor summaries. Each such reservoir will + only use a single key, given by `_TENSOR_RESERVOIR_KEY`. - Raises: - ValueError: If no events have been loaded and there were no events found - on disk. + @@Tensors """ - if self._first_event_timestamp is not None: - return self._first_event_timestamp - with self._generator_mutex: - try: - event = next(self._generator.Load()) - self._ProcessEvent(event) - return self._first_event_timestamp - except StopIteration: - raise ValueError('No event timestamp could be found') - - def PluginTagToContent(self, plugin_name): - """Returns a dict mapping tags to content specific to that plugin. - - Args: - plugin_name: The name of the plugin for which to fetch plugin-specific - content. - - Raises: - KeyError: if the plugin name is not found. 
- - Returns: - A dict mapping tag names to bytestrings of plugin-specific content-- by - convention, in the form of binary serialized protos. - """ - if plugin_name not in self._plugin_to_tag_to_content: - raise KeyError('Plugin %r could not be found.' % plugin_name) - with self._plugin_tag_locks[plugin_name]: - # Return a snapshot to avoid concurrent mutation and iteration issues. - return dict(self._plugin_to_tag_to_content[plugin_name]) - - def SummaryMetadata(self, tag): - """Given a summary tag name, return the associated metadata object. - - Args: - tag: The name of a tag, as a string. - - Raises: - KeyError: If the tag is not found. - - Returns: - A `SummaryMetadata` protobuf. - """ - return self.summary_metadata[tag] - - def _ProcessEvent(self, event): - """Called whenever an event is loaded.""" - if self._first_event_timestamp is None: - self._first_event_timestamp = event.wall_time - - if event.HasField('file_version'): - new_file_version = _ParseFileVersion(event.file_version) - if self.file_version and self.file_version != new_file_version: - ## This should not happen. - logger.warn(('Found new file_version for event.proto. This will ' - 'affect purging logic for TensorFlow restarts. ' - 'Old: {0} New: {1}').format(self.file_version, - new_file_version)) - self.file_version = new_file_version - - self._MaybePurgeOrphanedData(event) - - ## Process the event. - # GraphDef and MetaGraphDef are handled in a special way: - # If no graph_def Event is available, but a meta_graph_def is, and it - # contains a graph_def, then use the meta_graph_def.graph_def as our graph. - # If a graph_def Event is available, always prefer it to the graph_def - # inside the meta_graph_def. - if event.HasField('graph_def'): - if self._graph is not None: - logger.warn( - ('Found more than one graph event per run, or there was ' - 'a metagraph containing a graph_def, as well as one or ' - 'more graph events. 
Overwriting the graph with the ' - 'newest event.')) - self._graph = event.graph_def - self._graph_from_metagraph = False - elif event.HasField('meta_graph_def'): - if self._meta_graph is not None: - logger.warn(('Found more than one metagraph event per run. ' - 'Overwriting the metagraph with the newest event.')) - self._meta_graph = event.meta_graph_def - if self._graph is None or self._graph_from_metagraph: - # We may have a graph_def in the metagraph. If so, and no - # graph_def is directly available, use this one instead. + def __init__( + self, + path, + size_guidance=None, + tensor_size_guidance=None, + purge_orphaned_data=True, + event_file_active_filter=None, + ): + """Construct the `EventAccumulator`. + + Args: + path: A file path to a directory containing tf events files, or a single + tf events file. The accumulator will load events from this path. + size_guidance: Information on how much data the EventAccumulator should + store in memory. The DEFAULT_SIZE_GUIDANCE tries not to store too much + so as to avoid OOMing the client. The size_guidance should be a map + from a `tagType` string to an integer representing the number of + items to keep per tag for items of that `tagType`. If the size is 0, + all events are stored. + tensor_size_guidance: Like `size_guidance`, but allowing finer + granularity for tensor summaries. Should be a map from the + `plugin_name` field on the `PluginData` proto to an integer + representing the number of items to keep per tag. Plugins for + which there is no entry in this map will default to the value of + `size_guidance[event_accumulator.TENSORS]`. Defaults to `{}`. + purge_orphaned_data: Whether to discard any events that were "orphaned" by + a TensorFlow restart. + event_file_active_filter: Optional predicate for determining whether an + event file latest load timestamp should be considered active. If passed, + this will enable multifile directory loading. 
+ """ + size_guidance = dict(size_guidance or DEFAULT_SIZE_GUIDANCE) + sizes = {} + for key in DEFAULT_SIZE_GUIDANCE: + if key in size_guidance: + sizes[key] = size_guidance[key] + else: + sizes[key] = DEFAULT_SIZE_GUIDANCE[key] + self._size_guidance = size_guidance + self._tensor_size_guidance = dict(tensor_size_guidance or {}) + + self._first_event_timestamp = None + + self._graph = None + self._graph_from_metagraph = False + self._meta_graph = None + self._tagged_metadata = {} + self.summary_metadata = {} + self.tensors_by_tag = {} + self._tensors_by_tag_lock = threading.Lock() + + # Keep a mapping from plugin name to a dict mapping from tag to plugin data + # content obtained from the SummaryMetadata (metadata field of Value) for + # that plugin (This is not the entire SummaryMetadata proto - only the + # content for that plugin). The SummaryWriter only keeps the content on the + # first event encountered per tag, so we must store that first instance of + # content for each tag. + self._plugin_to_tag_to_content = collections.defaultdict(dict) + self._plugin_tag_locks = collections.defaultdict(threading.Lock) + + self.path = path + self._generator = _GeneratorFromPath(path, event_file_active_filter) + self._generator_mutex = threading.Lock() + + self.purge_orphaned_data = purge_orphaned_data + + self.most_recent_step = -1 + self.most_recent_wall_time = -1 + self.file_version = None + + def Reload(self): + """Loads all events added since the last call to `Reload`. + + If `Reload` was never called, loads all events in the file. + + Returns: + The `EventAccumulator`. + """ + with self._generator_mutex: + for event in self._generator.Load(): + self._ProcessEvent(event) + return self + + def PluginAssets(self, plugin_name): + """Return a list of all plugin assets for the given plugin. + + Args: + plugin_name: The string name of a plugin to retrieve assets for. + + Returns: + A list of string plugin asset names, or empty list if none are available. 
+ If the plugin was not registered, an empty list is returned. + """ + return plugin_asset_util.ListAssets(self.path, plugin_name) + + def RetrievePluginAsset(self, plugin_name, asset_name): + """Return the contents of a given plugin asset. + + Args: + plugin_name: The string name of a plugin. + asset_name: The string name of an asset. + + Returns: + The string contents of the plugin asset. + + Raises: + KeyError: If the asset is not available. + """ + return plugin_asset_util.RetrieveAsset( + self.path, plugin_name, asset_name + ) + + def FirstEventTimestamp(self): + """Returns the timestamp in seconds of the first event. + + If the first event has been loaded (either by this method or by `Reload`, + this returns immediately. Otherwise, it will load in the first event. Note + that this means that calling `Reload` will cause this to block until + `Reload` has finished. + + Returns: + The timestamp in seconds of the first event that was loaded. + + Raises: + ValueError: If no events have been loaded and there were no events found + on disk. + """ + if self._first_event_timestamp is not None: + return self._first_event_timestamp + with self._generator_mutex: + try: + event = next(self._generator.Load()) + self._ProcessEvent(event) + return self._first_event_timestamp + + except StopIteration: + raise ValueError("No event timestamp could be found") + + def PluginTagToContent(self, plugin_name): + """Returns a dict mapping tags to content specific to that plugin. + + Args: + plugin_name: The name of the plugin for which to fetch plugin-specific + content. + + Raises: + KeyError: if the plugin name is not found. + + Returns: + A dict mapping tag names to bytestrings of plugin-specific content-- by + convention, in the form of binary serialized protos. + """ + if plugin_name not in self._plugin_to_tag_to_content: + raise KeyError("Plugin %r could not be found." 
% plugin_name) + with self._plugin_tag_locks[plugin_name]: + # Return a snapshot to avoid concurrent mutation and iteration issues. + return dict(self._plugin_to_tag_to_content[plugin_name]) + + def SummaryMetadata(self, tag): + """Given a summary tag name, return the associated metadata object. + + Args: + tag: The name of a tag, as a string. + + Raises: + KeyError: If the tag is not found. + + Returns: + A `SummaryMetadata` protobuf. + """ + return self.summary_metadata[tag] + + def _ProcessEvent(self, event): + """Called whenever an event is loaded.""" + if self._first_event_timestamp is None: + self._first_event_timestamp = event.wall_time + + if event.HasField("file_version"): + new_file_version = _ParseFileVersion(event.file_version) + if self.file_version and self.file_version != new_file_version: + ## This should not happen. + logger.warn( + ( + "Found new file_version for event.proto. This will " + "affect purging logic for TensorFlow restarts. " + "Old: {0} New: {1}" + ).format(self.file_version, new_file_version) + ) + self.file_version = new_file_version + + self._MaybePurgeOrphanedData(event) + + ## Process the event. + # GraphDef and MetaGraphDef are handled in a special way: + # If no graph_def Event is available, but a meta_graph_def is, and it + # contains a graph_def, then use the meta_graph_def.graph_def as our graph. + # If a graph_def Event is available, always prefer it to the graph_def + # inside the meta_graph_def. + if event.HasField("graph_def"): + if self._graph is not None: + logger.warn( + ( + "Found more than one graph event per run, or there was " + "a metagraph containing a graph_def, as well as one or " + "more graph events. Overwriting the graph with the " + "newest event." + ) + ) + self._graph = event.graph_def + self._graph_from_metagraph = False + elif event.HasField("meta_graph_def"): + if self._meta_graph is not None: + logger.warn( + ( + "Found more than one metagraph event per run. 
" + "Overwriting the metagraph with the newest event." + ) + ) + self._meta_graph = event.meta_graph_def + if self._graph is None or self._graph_from_metagraph: + # We may have a graph_def in the metagraph. If so, and no + # graph_def is directly available, use this one instead. + meta_graph = meta_graph_pb2.MetaGraphDef() + meta_graph.ParseFromString(self._meta_graph) + if meta_graph.graph_def: + if self._graph is not None: + logger.warn( + ( + "Found multiple metagraphs containing graph_defs," + "but did not find any graph events. Overwriting the " + "graph with the newest metagraph version." + ) + ) + self._graph_from_metagraph = True + self._graph = meta_graph.graph_def.SerializeToString() + elif event.HasField("tagged_run_metadata"): + tag = event.tagged_run_metadata.tag + if tag in self._tagged_metadata: + logger.warn( + 'Found more than one "run metadata" event with tag ' + + tag + + ". Overwriting it with the newest event." + ) + self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata + elif event.HasField("summary"): + for value in event.summary.value: + value = data_compat.migrate_value(value) + + if value.HasField("metadata"): + tag = value.tag + # We only store the first instance of the metadata. This check + # is important: the `FileWriter` does strip metadata from all + # values except the first one per each tag, but a new + # `FileWriter` is created every time a training job stops and + # restarts. Hence, we must also ignore non-initial metadata in + # this logic. + if tag not in self.summary_metadata: + self.summary_metadata[tag] = value.metadata + plugin_data = value.metadata.plugin_data + if plugin_data.plugin_name: + with self._plugin_tag_locks[ + plugin_data.plugin_name + ]: + self._plugin_to_tag_to_content[ + plugin_data.plugin_name + ][tag] = plugin_data.content + else: + logger.warn( + ( + "This summary with tag %r is oddly not associated with a " + "plugin." 
+ ), + tag, + ) + + if value.HasField("tensor"): + datum = value.tensor + tag = value.tag + if not tag: + # This tensor summary was created using the old method that used + # plugin assets. We must still continue to support it. + tag = value.node_name + self._ProcessTensor(tag, event.wall_time, event.step, datum) + + def Tags(self): + """Return all tags found in the value stream. + + Returns: + A `{tagType: ['list', 'of', 'tags']}` dictionary. + """ + return { + TENSORS: list(self.tensors_by_tag.keys()), + # Use a heuristic: if the metagraph is available, but + # graph is not, then we assume the metagraph contains the graph. + GRAPH: self._graph is not None, + META_GRAPH: self._meta_graph is not None, + RUN_METADATA: list(self._tagged_metadata.keys()), + } + + def Graph(self): + """Return the graph definition, if there is one. + + If the graph is stored directly, return that. If no graph is stored + directly but a metagraph is stored containing a graph, return that. + + Raises: + ValueError: If there is no graph for this run. + + Returns: + The `graph_def` proto. + """ + graph = graph_pb2.GraphDef() + if self._graph is not None: + graph.ParseFromString(self._graph) + return graph + raise ValueError("There is no graph in this EventAccumulator") + + def SerializedGraph(self): + """Return the graph definition in serialized form, if there is one.""" + return self._graph + + def MetaGraph(self): + """Return the metagraph definition, if there is one. + + Raises: + ValueError: If there is no metagraph for this run. + + Returns: + The `meta_graph_def` proto. + """ + if self._meta_graph is None: + raise ValueError("There is no metagraph in this EventAccumulator") meta_graph = meta_graph_pb2.MetaGraphDef() meta_graph.ParseFromString(self._meta_graph) - if meta_graph.graph_def: - if self._graph is not None: - logger.warn( - ('Found multiple metagraphs containing graph_defs,' - 'but did not find any graph events. 
Overwriting the ' - 'graph with the newest metagraph version.')) - self._graph_from_metagraph = True - self._graph = meta_graph.graph_def.SerializeToString() - elif event.HasField('tagged_run_metadata'): - tag = event.tagged_run_metadata.tag - if tag in self._tagged_metadata: - logger.warn('Found more than one "run metadata" event with tag ' + - tag + '. Overwriting it with the newest event.') - self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata - elif event.HasField('summary'): - for value in event.summary.value: - value = data_compat.migrate_value(value) - - if value.HasField('metadata'): - tag = value.tag - # We only store the first instance of the metadata. This check - # is important: the `FileWriter` does strip metadata from all - # values except the first one per each tag, but a new - # `FileWriter` is created every time a training job stops and - # restarts. Hence, we must also ignore non-initial metadata in - # this logic. - if tag not in self.summary_metadata: - self.summary_metadata[tag] = value.metadata - plugin_data = value.metadata.plugin_data - if plugin_data.plugin_name: - with self._plugin_tag_locks[plugin_data.plugin_name]: - self._plugin_to_tag_to_content[plugin_data.plugin_name][tag] = ( - plugin_data.content) - else: - logger.warn( - ('This summary with tag %r is oddly not associated with a ' - 'plugin.'), tag) - - if value.HasField('tensor'): - datum = value.tensor - tag = value.tag - if not tag: - # This tensor summary was created using the old method that used - # plugin assets. We must still continue to support it. - tag = value.node_name - self._ProcessTensor(tag, event.wall_time, event.step, datum) - - def Tags(self): - """Return all tags found in the value stream. - - Returns: - A `{tagType: ['list', 'of', 'tags']}` dictionary. - """ - return { - TENSORS: list(self.tensors_by_tag.keys()), - # Use a heuristic: if the metagraph is available, but - # graph is not, then we assume the metagraph contains the graph. 
- GRAPH: self._graph is not None, - META_GRAPH: self._meta_graph is not None, - RUN_METADATA: list(self._tagged_metadata.keys()) - } - - def Graph(self): - """Return the graph definition, if there is one. - - If the graph is stored directly, return that. If no graph is stored - directly but a metagraph is stored containing a graph, return that. + return meta_graph + + def RunMetadata(self, tag): + """Given a tag, return the associated session.run() metadata. + + Args: + tag: A string tag associated with the event. + + Raises: + ValueError: If the tag is not found. + + Returns: + The metadata in form of `RunMetadata` proto. + """ + if tag not in self._tagged_metadata: + raise ValueError("There is no run metadata with this tag name") + + run_metadata = config_pb2.RunMetadata() + run_metadata.ParseFromString(self._tagged_metadata[tag]) + return run_metadata + + def Tensors(self, tag): + """Given a summary tag, return all associated tensors. + + Args: + tag: A string tag associated with the events. + + Raises: + KeyError: If the tag is not found. + + Returns: + An array of `TensorEvent`s. + """ + return self.tensors_by_tag[tag].Items(_TENSOR_RESERVOIR_KEY) + + def _MaybePurgeOrphanedData(self, event): + """Maybe purge orphaned data due to a TensorFlow crash. + + When TensorFlow crashes at step T+O and restarts at step T, any events + written after step T are now "orphaned" and will be at best misleading if + they are included in TensorBoard. + + This logic attempts to determine if there is orphaned data, and purge it + if it is found. + + Args: + event: The event to use as a reference, to determine if a purge is needed. + """ + if not self.purge_orphaned_data: + return + ## Check if the event happened after a crash, and purge expired tags. + if self.file_version and self.file_version >= 2: + ## If the file_version is recent enough, use the SessionLog enum + ## to check for restarts. 
+ self._CheckForRestartAndMaybePurge(event) + else: + ## If there is no file version, default to old logic of checking for + ## out of order steps. + self._CheckForOutOfOrderStepAndMaybePurge(event) + # After checking, update the most recent summary step and wall time. + if event.HasField("summary"): + self.most_recent_step = event.step + self.most_recent_wall_time = event.wall_time + + def _CheckForRestartAndMaybePurge(self, event): + """Check and discard expired events using SessionLog.START. + + Check for a SessionLog.START event and purge all previously seen events + with larger steps, because they are out of date. Because of supervisor + threading, it is possible that this logic will cause the first few event + messages to be discarded since supervisor threading does not guarantee + that the START message is deterministically written first. + + This method is preferred over _CheckForOutOfOrderStepAndMaybePurge which + can inadvertently discard events due to supervisor threading. + + Args: + event: The event to use as reference. If the event is a START event, all + previously seen events with a greater event.step will be purged. + """ + if ( + event.HasField("session_log") + and event.session_log.status == event_pb2.SessionLog.START + ): + self._Purge(event, by_tags=False) + + def _CheckForOutOfOrderStepAndMaybePurge(self, event): + """Check for out-of-order event.step and discard expired events for + tags. + + Check if the event is out of order relative to the global most recent step. + If it is, purge outdated summaries for tags that the event contains. + + Args: + event: The event to use as reference. If the event is out-of-order, all + events with the same tags, but with a greater event.step will be purged. 
+ """ + if event.step < self.most_recent_step and event.HasField("summary"): + self._Purge(event, by_tags=True) + + def _ProcessTensor(self, tag, wall_time, step, tensor): + tv = TensorEvent(wall_time=wall_time, step=step, tensor_proto=tensor) + with self._tensors_by_tag_lock: + if tag not in self.tensors_by_tag: + reservoir_size = self._GetTensorReservoirSize(tag) + self.tensors_by_tag[tag] = reservoir.Reservoir(reservoir_size) + self.tensors_by_tag[tag].AddItem(_TENSOR_RESERVOIR_KEY, tv) + + def _GetTensorReservoirSize(self, tag): + default = self._size_guidance[TENSORS] + summary_metadata = self.summary_metadata.get(tag) + if summary_metadata is None: + return default + return self._tensor_size_guidance.get( + summary_metadata.plugin_data.plugin_name, default + ) + + def _Purge(self, event, by_tags): + """Purge all events that have occurred after the given event.step. + + If by_tags is True, purge all events that occurred after the given + event.step, but only for the tags that the event has. Non-sequential + event.steps suggest that a TensorFlow restart occurred, and we discard + the out-of-order events to display a consistent view in TensorBoard. + + Discarding by tags is the safer method, when we are unsure whether a restart + has occurred, given that threading in supervisor can cause events of + different tags to arrive with unsynchronized step values. + + If by_tags is False, then purge all events with event.step greater than the + given event.step. This can be used when we are certain that a TensorFlow + restart has occurred and these events can be discarded. + + Args: + event: The event to use as reference for the purge. All events with + the same tags, but with a greater event.step will be purged. + by_tags: Bool to dictate whether to discard all out-of-order events or + only those that are associated with the given reference event. 
+ """ + ## Keep data in reservoirs that has a step less than event.step + _NotExpired = lambda x: x.step < event.step + + num_expired = 0 + if by_tags: + for value in event.summary.value: + if value.tag in self.tensors_by_tag: + tag_reservoir = self.tensors_by_tag[value.tag] + num_expired += tag_reservoir.FilterItems( + _NotExpired, _TENSOR_RESERVOIR_KEY + ) + else: + for tag_reservoir in six.itervalues(self.tensors_by_tag): + num_expired += tag_reservoir.FilterItems( + _NotExpired, _TENSOR_RESERVOIR_KEY + ) + if num_expired > 0: + purge_msg = _GetPurgeMessage( + self.most_recent_step, + self.most_recent_wall_time, + event.step, + event.wall_time, + num_expired, + ) + logger.warn(purge_msg) + + +def _GetPurgeMessage( + most_recent_step, + most_recent_wall_time, + event_step, + event_wall_time, + num_expired, +): + """Return the string message associated with TensorBoard purges.""" + return ( + "Detected out of order event.step likely caused by a TensorFlow " + "restart. Purging {} expired tensor events from Tensorboard display " + "between the previous step: {} (timestamp: {}) and current step: {} " + "(timestamp: {})." + ).format( + num_expired, + most_recent_step, + most_recent_wall_time, + event_step, + event_wall_time, + ) - Raises: - ValueError: If there is no graph for this run. - Returns: - The `graph_def` proto. - """ - graph = graph_pb2.GraphDef() - if self._graph is not None: - graph.ParseFromString(self._graph) - return graph - raise ValueError('There is no graph in this EventAccumulator') - - def SerializedGraph(self): - """Return the graph definition in serialized form, if there is one.""" - return self._graph - - def MetaGraph(self): - """Return the metagraph definition, if there is one. - - Raises: - ValueError: If there is no metagraph for this run. - - Returns: - The `meta_graph_def` proto. 
- """ - if self._meta_graph is None: - raise ValueError('There is no metagraph in this EventAccumulator') - meta_graph = meta_graph_pb2.MetaGraphDef() - meta_graph.ParseFromString(self._meta_graph) - return meta_graph - - def RunMetadata(self, tag): - """Given a tag, return the associated session.run() metadata. - - Args: - tag: A string tag associated with the event. - - Raises: - ValueError: If the tag is not found. - - Returns: - The metadata in form of `RunMetadata` proto. - """ - if tag not in self._tagged_metadata: - raise ValueError('There is no run metadata with this tag name') - - run_metadata = config_pb2.RunMetadata() - run_metadata.ParseFromString(self._tagged_metadata[tag]) - return run_metadata - - def Tensors(self, tag): - """Given a summary tag, return all associated tensors. - - Args: - tag: A string tag associated with the events. - - Raises: - KeyError: If the tag is not found. - - Returns: - An array of `TensorEvent`s. - """ - return self.tensors_by_tag[tag].Items(_TENSOR_RESERVOIR_KEY) - - def _MaybePurgeOrphanedData(self, event): - """Maybe purge orphaned data due to a TensorFlow crash. - - When TensorFlow crashes at step T+O and restarts at step T, any events - written after step T are now "orphaned" and will be at best misleading if - they are included in TensorBoard. - - This logic attempts to determine if there is orphaned data, and purge it - if it is found. - - Args: - event: The event to use as a reference, to determine if a purge is needed. - """ - if not self.purge_orphaned_data: - return - ## Check if the event happened after a crash, and purge expired tags. - if self.file_version and self.file_version >= 2: - ## If the file_version is recent enough, use the SessionLog enum - ## to check for restarts. 
- self._CheckForRestartAndMaybePurge(event) +def _GeneratorFromPath(path, event_file_active_filter=None): + """Create an event generator for file or directory at given path string.""" + if not path: + raise ValueError("path must be a valid string") + if io_wrapper.IsTensorFlowEventsFile(path): + return event_file_loader.EventFileLoader(path) + elif event_file_active_filter: + return directory_loader.DirectoryLoader( + path, + event_file_loader.TimestampedEventFileLoader, + path_filter=io_wrapper.IsTensorFlowEventsFile, + active_filter=event_file_active_filter, + ) else: - ## If there is no file version, default to old logic of checking for - ## out of order steps. - self._CheckForOutOfOrderStepAndMaybePurge(event) - # After checking, update the most recent summary step and wall time. - if event.HasField('summary'): - self.most_recent_step = event.step - self.most_recent_wall_time = event.wall_time - - def _CheckForRestartAndMaybePurge(self, event): - """Check and discard expired events using SessionLog.START. - - Check for a SessionLog.START event and purge all previously seen events - with larger steps, because they are out of date. Because of supervisor - threading, it is possible that this logic will cause the first few event - messages to be discarded since supervisor threading does not guarantee - that the START message is deterministically written first. - - This method is preferred over _CheckForOutOfOrderStepAndMaybePurge which - can inadvertently discard events due to supervisor threading. - - Args: - event: The event to use as reference. If the event is a START event, all - previously seen events with a greater event.step will be purged. 
- """ - if event.HasField( - 'session_log') and event.session_log.status == event_pb2.SessionLog.START: - self._Purge(event, by_tags=False) + return directory_watcher.DirectoryWatcher( + path, + event_file_loader.EventFileLoader, + io_wrapper.IsTensorFlowEventsFile, + ) - def _CheckForOutOfOrderStepAndMaybePurge(self, event): - """Check for out-of-order event.step and discard expired events for tags. - Check if the event is out of order relative to the global most recent step. - If it is, purge outdated summaries for tags that the event contains. +def _ParseFileVersion(file_version): + """Convert the string file_version in event.proto into a float. Args: - event: The event to use as reference. If the event is out-of-order, all - events with the same tags, but with a greater event.step will be purged. - """ - if event.step < self.most_recent_step and event.HasField('summary'): - self._Purge(event, by_tags=True) - - def _ProcessTensor(self, tag, wall_time, step, tensor): - tv = TensorEvent(wall_time=wall_time, step=step, tensor_proto=tensor) - with self._tensors_by_tag_lock: - if tag not in self.tensors_by_tag: - reservoir_size = self._GetTensorReservoirSize(tag) - self.tensors_by_tag[tag] = reservoir.Reservoir(reservoir_size) - self.tensors_by_tag[tag].AddItem(_TENSOR_RESERVOIR_KEY, tv) - - def _GetTensorReservoirSize(self, tag): - default = self._size_guidance[TENSORS] - summary_metadata = self.summary_metadata.get(tag) - if summary_metadata is None: - return default - return self._tensor_size_guidance.get( - summary_metadata.plugin_data.plugin_name, default) - - def _Purge(self, event, by_tags): - """Purge all events that have occurred after the given event.step. - - If by_tags is True, purge all events that occurred after the given - event.step, but only for the tags that the event has. Non-sequential - event.steps suggest that a TensorFlow restart occurred, and we discard - the out-of-order events to display a consistent view in TensorBoard. 
- - Discarding by tags is the safer method, when we are unsure whether a restart - has occurred, given that threading in supervisor can cause events of - different tags to arrive with unsynchronized step values. - - If by_tags is False, then purge all events with event.step greater than the - given event.step. This can be used when we are certain that a TensorFlow - restart has occurred and these events can be discarded. + file_version: String file_version from event.proto - Args: - event: The event to use as reference for the purge. All events with - the same tags, but with a greater event.step will be purged. - by_tags: Bool to dictate whether to discard all out-of-order events or - only those that are associated with the given reference event. + Returns: + Version number as a float. """ - ## Keep data in reservoirs that has a step less than event.step - _NotExpired = lambda x: x.step < event.step - - num_expired = 0 - if by_tags: - for value in event.summary.value: - if value.tag in self.tensors_by_tag: - tag_reservoir = self.tensors_by_tag[value.tag] - num_expired += tag_reservoir.FilterItems( - _NotExpired, _TENSOR_RESERVOIR_KEY) - else: - for tag_reservoir in six.itervalues(self.tensors_by_tag): - num_expired += tag_reservoir.FilterItems( - _NotExpired, _TENSOR_RESERVOIR_KEY) - if num_expired > 0: - purge_msg = _GetPurgeMessage(self.most_recent_step, - self.most_recent_wall_time, event.step, - event.wall_time, num_expired) - logger.warn(purge_msg) - - -def _GetPurgeMessage(most_recent_step, most_recent_wall_time, event_step, - event_wall_time, num_expired): - """Return the string message associated with TensorBoard purges.""" - return ('Detected out of order event.step likely caused by a TensorFlow ' - 'restart. Purging {} expired tensor events from Tensorboard display ' - 'between the previous step: {} (timestamp: {}) and current step: {} ' - '(timestamp: {}).' 
- ).format(num_expired, most_recent_step, most_recent_wall_time, - event_step, event_wall_time) - - -def _GeneratorFromPath(path, event_file_active_filter=None): - """Create an event generator for file or directory at given path string.""" - if not path: - raise ValueError('path must be a valid string') - if io_wrapper.IsTensorFlowEventsFile(path): - return event_file_loader.EventFileLoader(path) - elif event_file_active_filter: - return directory_loader.DirectoryLoader( - path, - event_file_loader.TimestampedEventFileLoader, - path_filter=io_wrapper.IsTensorFlowEventsFile, - active_filter=event_file_active_filter) - else: - return directory_watcher.DirectoryWatcher( - path, - event_file_loader.EventFileLoader, - io_wrapper.IsTensorFlowEventsFile) - - -def _ParseFileVersion(file_version): - """Convert the string file_version in event.proto into a float. - - Args: - file_version: String file_version from event.proto - - Returns: - Version number as a float. - """ - tokens = file_version.split('brain.Event:') - try: - return float(tokens[-1]) - except ValueError: - ## This should never happen according to the definition of file_version - ## specified in event.proto. - logger.warn( - ('Invalid event.proto file_version. Defaulting to use of ' - 'out-of-order event.step logic for purging expired events.')) - return -1 + tokens = file_version.split("brain.Event:") + try: + return float(tokens[-1]) + except ValueError: + ## This should never happen according to the definition of file_version + ## specified in event.proto. + logger.warn( + ( + "Invalid event.proto file_version. Defaulting to use of " + "out-of-order event.step logic for purging expired events." 
+ ) + ) + return -1 diff --git a/tensorboard/backend/event_processing/plugin_event_accumulator_test.py b/tensorboard/backend/event_processing/plugin_event_accumulator_test.py index c6861a9838..17782c5f3c 100644 --- a/tensorboard/backend/event_processing/plugin_event_accumulator_test.py +++ b/tensorboard/backend/event_processing/plugin_event_accumulator_test.py @@ -41,720 +41,788 @@ class _EventGenerator(object): - """Class that can add_events and then yield them back. + """Class that can add_events and then yield them back. - Satisfies the EventGenerator API required for the EventAccumulator. - Satisfies the EventWriter API required to create a tf.summary.FileWriter. + Satisfies the EventGenerator API required for the EventAccumulator. + Satisfies the EventWriter API required to create a tf.summary.FileWriter. - Has additional convenience methods for adding test events. - """ - - def __init__(self, testcase, zero_out_timestamps=False): - self._testcase = testcase - self.items = [] - self.zero_out_timestamps = zero_out_timestamps - - def Load(self): - while self.items: - yield self.items.pop(0) - - def AddScalarTensor(self, tag, wall_time=0, step=0, value=0): - """Add a rank-0 tensor event. - - Note: This is not related to the scalar plugin; it's just a - convenience function to add an event whose contents aren't - important. + Has additional convenience methods for adding test events. """ - tensor = tensor_util.make_tensor_proto(float(value)) - event = event_pb2.Event( - wall_time=wall_time, - step=step, - summary=summary_pb2.Summary( - value=[summary_pb2.Summary.Value(tag=tag, tensor=tensor)])) - self.AddEvent(event) - def AddEvent(self, event): - if self.zero_out_timestamps: - event.wall_time = 0. 
- self.items.append(event) - - def add_event(self, event): # pylint: disable=invalid-name - """Match the EventWriter API.""" - self.AddEvent(event) - - def get_logdir(self): # pylint: disable=invalid-name - """Return a temp directory for asset writing.""" - return self._testcase.get_temp_dir() + def __init__(self, testcase, zero_out_timestamps=False): + self._testcase = testcase + self.items = [] + self.zero_out_timestamps = zero_out_timestamps + + def Load(self): + while self.items: + yield self.items.pop(0) + + def AddScalarTensor(self, tag, wall_time=0, step=0, value=0): + """Add a rank-0 tensor event. + + Note: This is not related to the scalar plugin; it's just a + convenience function to add an event whose contents aren't + important. + """ + tensor = tensor_util.make_tensor_proto(float(value)) + event = event_pb2.Event( + wall_time=wall_time, + step=step, + summary=summary_pb2.Summary( + value=[summary_pb2.Summary.Value(tag=tag, tensor=tensor)] + ), + ) + self.AddEvent(event) + + def AddEvent(self, event): + if self.zero_out_timestamps: + event.wall_time = 0.0 + self.items.append(event) + + def add_event(self, event): # pylint: disable=invalid-name + """Match the EventWriter API.""" + self.AddEvent(event) + + def get_logdir(self): # pylint: disable=invalid-name + """Return a temp directory for asset writing.""" + return self._testcase.get_temp_dir() class EventAccumulatorTest(tf.test.TestCase): - - def assertTagsEqual(self, actual, expected): - """Utility method for checking the return value of the Tags() call. - - It fills out the `expected` arg with the default (empty) values for every - tag type, so that the author needs only specify the non-empty values they - are interested in testing. - - Args: - actual: The actual Accumulator tags response. 
- expected: The expected tags response (empty fields may be omitted) - """ - - empty_tags = { - ea.GRAPH: False, - ea.META_GRAPH: False, - ea.RUN_METADATA: [], - ea.TENSORS: [], - } - - # Verifies that there are no unexpected keys in the actual response. - # If this line fails, likely you added a new tag type, and need to update - # the empty_tags dictionary above. - self.assertItemsEqual(actual.keys(), empty_tags.keys()) - - for key in actual: - expected_value = expected.get(key, empty_tags[key]) - if isinstance(expected_value, list): - self.assertItemsEqual(actual[key], expected_value) - else: - self.assertEqual(actual[key], expected_value) + def assertTagsEqual(self, actual, expected): + """Utility method for checking the return value of the Tags() call. + + It fills out the `expected` arg with the default (empty) values for every + tag type, so that the author needs only specify the non-empty values they + are interested in testing. + + Args: + actual: The actual Accumulator tags response. + expected: The expected tags response (empty fields may be omitted) + """ + + empty_tags = { + ea.GRAPH: False, + ea.META_GRAPH: False, + ea.RUN_METADATA: [], + ea.TENSORS: [], + } + + # Verifies that there are no unexpected keys in the actual response. + # If this line fails, likely you added a new tag type, and need to update + # the empty_tags dictionary above. 
+ self.assertItemsEqual(actual.keys(), empty_tags.keys()) + + for key in actual: + expected_value = expected.get(key, empty_tags[key]) + if isinstance(expected_value, list): + self.assertItemsEqual(actual[key], expected_value) + else: + self.assertEqual(actual[key], expected_value) class MockingEventAccumulatorTest(EventAccumulatorTest): + def setUp(self): + super(MockingEventAccumulatorTest, self).setUp() + self.stubs = tf.compat.v1.test.StubOutForTesting() + self._real_constructor = ea.EventAccumulator + self._real_generator = ea._GeneratorFromPath + + def _FakeAccumulatorConstructor(generator, *args, **kwargs): + def _FakeGeneratorFromPath(path, event_file_active_filter=None): + return generator + + ea._GeneratorFromPath = _FakeGeneratorFromPath + return self._real_constructor(generator, *args, **kwargs) + + ea.EventAccumulator = _FakeAccumulatorConstructor + + def tearDown(self): + self.stubs.CleanUp() + ea.EventAccumulator = self._real_constructor + ea._GeneratorFromPath = self._real_generator + + def testEmptyAccumulator(self): + gen = _EventGenerator(self) + x = ea.EventAccumulator(gen) + x.Reload() + self.assertTagsEqual(x.Tags(), {}) + + def testReload(self): + """EventAccumulator contains suitable tags after calling Reload.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + acc.Reload() + self.assertTagsEqual(acc.Tags(), {}) + gen.AddScalarTensor("s1", wall_time=1, step=10, value=50) + gen.AddScalarTensor("s2", wall_time=1, step=10, value=80) + acc.Reload() + self.assertTagsEqual(acc.Tags(), {ea.TENSORS: ["s1", "s2"],}) + + def testKeyError(self): + """KeyError should be raised when accessing non-existing keys.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + acc.Reload() + with self.assertRaises(KeyError): + acc.Tensors("s1") + + def testNonValueEvents(self): + """Non-value events in the generator don't cause early exits.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + gen.AddScalarTensor("s1", 
wall_time=1, step=10, value=20) + gen.AddEvent( + event_pb2.Event(wall_time=2, step=20, file_version="nots2") + ) + gen.AddScalarTensor("s3", wall_time=3, step=100, value=1) + + acc.Reload() + self.assertTagsEqual(acc.Tags(), {ea.TENSORS: ["s1", "s3"],}) + + def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self): + """Tests that events are discarded after a restart is detected. + + If a step value is observed to be lower than what was previously seen, + this should force a discard of all previous items with the same tag + that are outdated. + + Only file versions < 2 use this out-of-order discard logic. Later versions + discard events based on the step value of SessionLog.START. + """ + warnings = [] + self.stubs.Set(logger, "warn", warnings.append) + + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + + gen.AddEvent( + event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1") + ) + gen.AddScalarTensor("s1", wall_time=1, step=100, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=200, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=300, value=20) + acc.Reload() + ## Check that number of items are what they should be + self.assertEqual([x.step for x in acc.Tensors("s1")], [100, 200, 300]) + + gen.AddScalarTensor("s1", wall_time=1, step=101, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=201, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=301, value=20) + acc.Reload() + ## Check that we have discarded 200 and 300 from s1 + self.assertEqual( + [x.step for x in acc.Tensors("s1")], [100, 101, 201, 301] + ) + + def testOrphanedDataNotDiscardedIfFlagUnset(self): + """Tests that events are not discarded if purge_orphaned_data is + false.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen, purge_orphaned_data=False) + + gen.AddEvent( + event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1") + ) + gen.AddScalarTensor("s1", wall_time=1, step=100, value=20) + gen.AddScalarTensor("s1", 
wall_time=1, step=200, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=300, value=20) + acc.Reload() + ## Check that number of items are what they should be + self.assertEqual([x.step for x in acc.Tensors("s1")], [100, 200, 300]) + + gen.AddScalarTensor("s1", wall_time=1, step=101, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=201, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=301, value=20) + acc.Reload() + ## Check that we have NOT discarded 200 and 300 from s1 + self.assertEqual( + [x.step for x in acc.Tensors("s1")], [100, 200, 300, 101, 201, 301] + ) + + def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self): + """Tests that event discards after restart, only affect the misordered + tag. + + If a step value is observed to be lower than what was previously seen, + this should force a discard of all previous items that are outdated, but + only for the out of order tag. Other tags should remain unaffected. + + Only file versions < 2 use this out-of-order discard logic. Later versions + discard events based on the step value of SessionLog.START. 
+ """ + warnings = [] + self.stubs.Set(logger, "warn", warnings.append) + + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + + gen.AddEvent( + event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1") + ) + gen.AddScalarTensor("s1", wall_time=1, step=100, value=20) + gen.AddScalarTensor("s2", wall_time=1, step=101, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=200, value=20) + gen.AddScalarTensor("s2", wall_time=1, step=201, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=300, value=20) + gen.AddScalarTensor("s2", wall_time=1, step=301, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=101, value=20) + gen.AddScalarTensor("s3", wall_time=1, step=101, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=201, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=301, value=20) + + acc.Reload() + ## Check that we have discarded 200 and 300 for s1 + self.assertEqual( + [x.step for x in acc.Tensors("s1")], [100, 101, 201, 301] + ) + + ## Check that s1 discards do not affect s2 (written before out-of-order) + ## or s3 (written after out-of-order). + ## i.e. 
check that only events from the out of order tag are discarded + self.assertEqual([x.step for x in acc.Tensors("s2")], [101, 201, 301]) + self.assertEqual([x.step for x in acc.Tensors("s3")], [101]) + + def testOnlySummaryEventsTriggerDiscards(self): + """Test that file version event does not trigger data purge.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + gen.AddScalarTensor("s1", wall_time=1, step=100, value=20) + ev1 = event_pb2.Event(wall_time=2, step=0, file_version="brain.Event:1") + graph_bytes = tf.compat.v1.GraphDef().SerializeToString() + ev2 = event_pb2.Event(wall_time=3, step=0, graph_def=graph_bytes) + gen.AddEvent(ev1) + gen.AddEvent(ev2) + acc.Reload() + self.assertEqual([x.step for x in acc.Tensors("s1")], [100]) + + def testSessionLogStartMessageDiscardsExpiredEvents(self): + """Test that SessionLog.START message discards expired events. + + This discard logic is preferred over the out-of-order step + discard logic, but this logic can only be used for event protos + which have the SessionLog enum, which was introduced to + event.proto for file_version >= brain.Event:2. 
+ """ + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + gen.AddEvent( + event_pb2.Event(wall_time=0, step=1, file_version="brain.Event:2") + ) + + gen.AddScalarTensor("s1", wall_time=1, step=100, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=200, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=300, value=20) + gen.AddScalarTensor("s1", wall_time=1, step=400, value=20) + + gen.AddScalarTensor("s2", wall_time=1, step=202, value=20) + gen.AddScalarTensor("s2", wall_time=1, step=203, value=20) + + slog = event_pb2.SessionLog(status=event_pb2.SessionLog.START) + gen.AddEvent(event_pb2.Event(wall_time=2, step=201, session_log=slog)) + acc.Reload() + self.assertEqual([x.step for x in acc.Tensors("s1")], [100, 200]) + self.assertEqual([x.step for x in acc.Tensors("s2")], []) + + def testFirstEventTimestamp(self): + """Test that FirstEventTimestamp() returns wall_time of the first + event.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + gen.AddEvent( + event_pb2.Event(wall_time=10, step=20, file_version="brain.Event:2") + ) + gen.AddScalarTensor("s1", wall_time=30, step=40, value=20) + self.assertEqual(acc.FirstEventTimestamp(), 10) + + def testReloadPopulatesFirstEventTimestamp(self): + """Test that Reload() means FirstEventTimestamp() won't load events.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + gen.AddEvent( + event_pb2.Event(wall_time=1, step=2, file_version="brain.Event:2") + ) + + acc.Reload() + + def _Die(*args, **kwargs): # pylint: disable=unused-argument + raise RuntimeError("Load() should not be called") + + self.stubs.Set(gen, "Load", _Die) + self.assertEqual(acc.FirstEventTimestamp(), 1) + + def testFirstEventTimestampLoadsEvent(self): + """Test that FirstEventTimestamp() doesn't discard the loaded event.""" + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + gen.AddEvent( + event_pb2.Event(wall_time=1, step=2, file_version="brain.Event:2") + ) + + 
self.assertEqual(acc.FirstEventTimestamp(), 1) + acc.Reload() + self.assertEqual(acc.file_version, 2.0) + + def testNewStyleScalarSummary(self): + """Verify processing of tensorboard.plugins.scalar.summary.""" + event_sink = _EventGenerator(self, zero_out_timestamps=True) + writer = test_util.FileWriter(self.get_temp_dir()) + writer.event_writer = event_sink + with tf.compat.v1.Graph().as_default(): + with self.test_session() as sess: + step = tf.compat.v1.placeholder(tf.float32, shape=[]) + scalar_summary.op( + "accuracy", 1.0 - 1.0 / (step + tf.constant(1.0)) + ) + scalar_summary.op("xent", 1.0 / (step + tf.constant(1.0))) + merged = tf.compat.v1.summary.merge_all() + writer.add_graph(sess.graph) + for i in xrange(10): + summ = sess.run(merged, feed_dict={step: float(i)}) + writer.add_summary(summ, global_step=i) + + accumulator = ea.EventAccumulator(event_sink) + accumulator.Reload() + + tags = [ + u"accuracy/scalar_summary", + u"xent/scalar_summary", + ] + + self.assertTagsEqual( + accumulator.Tags(), + {ea.TENSORS: tags, ea.GRAPH: True, ea.META_GRAPH: False,}, + ) + + def testNewStyleAudioSummary(self): + """Verify processing of tensorboard.plugins.audio.summary.""" + event_sink = _EventGenerator(self, zero_out_timestamps=True) + writer = test_util.FileWriter(self.get_temp_dir()) + writer.event_writer = event_sink + with tf.compat.v1.Graph().as_default(): + with self.test_session() as sess: + ipt = tf.random.normal(shape=[5, 441, 2]) + with tf.name_scope("1"): + audio_summary.op( + "one", ipt, sample_rate=44100, max_outputs=1 + ) + with tf.name_scope("2"): + audio_summary.op( + "two", ipt, sample_rate=44100, max_outputs=2 + ) + with tf.name_scope("3"): + audio_summary.op( + "three", ipt, sample_rate=44100, max_outputs=3 + ) + merged = tf.compat.v1.summary.merge_all() + writer.add_graph(sess.graph) + for i in xrange(10): + summ = sess.run(merged) + writer.add_summary(summ, global_step=i) + + accumulator = ea.EventAccumulator(event_sink) + accumulator.Reload() + 
+ tags = [ + u"1/one/audio_summary", + u"2/two/audio_summary", + u"3/three/audio_summary", + ] + + self.assertTagsEqual( + accumulator.Tags(), + {ea.TENSORS: tags, ea.GRAPH: True, ea.META_GRAPH: False,}, + ) + + def testNewStyleImageSummary(self): + """Verify processing of tensorboard.plugins.image.summary.""" + event_sink = _EventGenerator(self, zero_out_timestamps=True) + writer = test_util.FileWriter(self.get_temp_dir()) + writer.event_writer = event_sink + with tf.compat.v1.Graph().as_default(): + with self.test_session() as sess: + ipt = tf.ones([10, 4, 4, 3], tf.uint8) + # This is an interesting example, because the old tf.image_summary op + # would throw an error here, because it would be tag reuse. + # Using the tf node name instead allows argument re-use to the image + # summary. + with tf.name_scope("1"): + image_summary.op("images", ipt, max_outputs=1) + with tf.name_scope("2"): + image_summary.op("images", ipt, max_outputs=2) + with tf.name_scope("3"): + image_summary.op("images", ipt, max_outputs=3) + merged = tf.compat.v1.summary.merge_all() + writer.add_graph(sess.graph) + for i in xrange(10): + summ = sess.run(merged) + writer.add_summary(summ, global_step=i) + + accumulator = ea.EventAccumulator(event_sink) + accumulator.Reload() + + tags = [ + u"1/images/image_summary", + u"2/images/image_summary", + u"3/images/image_summary", + ] + + self.assertTagsEqual( + accumulator.Tags(), + {ea.TENSORS: tags, ea.GRAPH: True, ea.META_GRAPH: False,}, + ) + + def testTFSummaryTensor(self): + """Verify processing of tf.summary.tensor.""" + event_sink = _EventGenerator(self, zero_out_timestamps=True) + writer = test_util.FileWriter(self.get_temp_dir()) + writer.event_writer = event_sink + with tf.compat.v1.Graph().as_default(): + with self.test_session() as sess: + tensor_summary = tf.compat.v1.summary.tensor_summary + tensor_summary("scalar", tf.constant(1.0)) + tensor_summary("vector", tf.constant([1.0, 2.0, 3.0])) + tensor_summary("string", 
tf.constant(six.b("foobar"))) + merged = tf.compat.v1.summary.merge_all() + summ = sess.run(merged) + writer.add_summary(summ, 0) + + accumulator = ea.EventAccumulator(event_sink) + accumulator.Reload() + + self.assertTagsEqual( + accumulator.Tags(), {ea.TENSORS: ["scalar", "vector", "string"],} + ) + + scalar_proto = accumulator.Tensors("scalar")[0].tensor_proto + scalar = tensor_util.make_ndarray(scalar_proto) + vector_proto = accumulator.Tensors("vector")[0].tensor_proto + vector = tensor_util.make_ndarray(vector_proto) + string_proto = accumulator.Tensors("string")[0].tensor_proto + string = tensor_util.make_ndarray(string_proto) + + self.assertTrue(np.array_equal(scalar, 1.0)) + self.assertTrue(np.array_equal(vector, [1.0, 2.0, 3.0])) + self.assertTrue(np.array_equal(string, six.b("foobar"))) + + def _testTFSummaryTensor_SizeGuidance( + self, plugin_name, tensor_size_guidance, steps, expected_count + ): + event_sink = _EventGenerator(self, zero_out_timestamps=True) + writer = test_util.FileWriter(self.get_temp_dir()) + writer.event_writer = event_sink + with tf.compat.v1.Graph().as_default(): + with self.test_session() as sess: + summary_metadata = summary_pb2.SummaryMetadata( + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name=plugin_name, content=b"{}" + ) + ) + tf.compat.v1.summary.tensor_summary( + "scalar", + tf.constant(1.0), + summary_metadata=summary_metadata, + ) + merged = tf.compat.v1.summary.merge_all() + for step in xrange(steps): + writer.add_summary(sess.run(merged), global_step=step) + + accumulator = ea.EventAccumulator( + event_sink, tensor_size_guidance=tensor_size_guidance + ) + accumulator.Reload() + + tensors = accumulator.Tensors("scalar") + self.assertEqual(len(tensors), expected_count) + + def testTFSummaryTensor_SizeGuidance_DefaultToTensorGuidance(self): + self._testTFSummaryTensor_SizeGuidance( + plugin_name="jabberwocky", + tensor_size_guidance={}, + steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 1, + 
expected_count=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS], + ) + + def testTFSummaryTensor_SizeGuidance_UseSmallSingularPluginGuidance(self): + size = int(ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] / 2) + assert size < ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS], size + self._testTFSummaryTensor_SizeGuidance( + plugin_name="jabberwocky", + tensor_size_guidance={"jabberwocky": size}, + steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 1, + expected_count=size, + ) + + def testTFSummaryTensor_SizeGuidance_UseLargeSingularPluginGuidance(self): + size = ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 5 + self._testTFSummaryTensor_SizeGuidance( + plugin_name="jabberwocky", + tensor_size_guidance={"jabberwocky": size}, + steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 10, + expected_count=size, + ) + + def testTFSummaryTensor_SizeGuidance_IgnoreIrrelevantGuidances(self): + size_small = int(ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] / 3) + size_large = int(ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] / 2) + assert size_small < size_large < ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS], ( + size_small, + size_large, + ) + self._testTFSummaryTensor_SizeGuidance( + plugin_name="jabberwocky", + tensor_size_guidance={ + "jabberwocky": size_small, + "wnoorejbpxl": size_large, + }, + steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 1, + expected_count=size_small, + ) - def setUp(self): - super(MockingEventAccumulatorTest, self).setUp() - self.stubs = tf.compat.v1.test.StubOutForTesting() - self._real_constructor = ea.EventAccumulator - self._real_generator = ea._GeneratorFromPath - - def _FakeAccumulatorConstructor(generator, *args, **kwargs): - def _FakeGeneratorFromPath(path, event_file_active_filter=None): - return generator - ea._GeneratorFromPath = _FakeGeneratorFromPath - return self._real_constructor(generator, *args, **kwargs) - - ea.EventAccumulator = _FakeAccumulatorConstructor - - def tearDown(self): - self.stubs.CleanUp() - ea.EventAccumulator = self._real_constructor - ea._GeneratorFromPath = self._real_generator - - def 
testEmptyAccumulator(self): - gen = _EventGenerator(self) - x = ea.EventAccumulator(gen) - x.Reload() - self.assertTagsEqual(x.Tags(), {}) - - def testReload(self): - """EventAccumulator contains suitable tags after calling Reload.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - acc.Reload() - self.assertTagsEqual(acc.Tags(), {}) - gen.AddScalarTensor('s1', wall_time=1, step=10, value=50) - gen.AddScalarTensor('s2', wall_time=1, step=10, value=80) - acc.Reload() - self.assertTagsEqual(acc.Tags(), { - ea.TENSORS: ['s1', 's2'], - }) - - def testKeyError(self): - """KeyError should be raised when accessing non-existing keys.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - acc.Reload() - with self.assertRaises(KeyError): - acc.Tensors('s1') - - def testNonValueEvents(self): - """Non-value events in the generator don't cause early exits.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddScalarTensor('s1', wall_time=1, step=10, value=20) - gen.AddEvent( - event_pb2.Event(wall_time=2, step=20, file_version='nots2')) - gen.AddScalarTensor('s3', wall_time=3, step=100, value=1) - - acc.Reload() - self.assertTagsEqual(acc.Tags(), { - ea.TENSORS: ['s1', 's3'], - }) - - def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self): - """Tests that events are discarded after a restart is detected. - - If a step value is observed to be lower than what was previously seen, - this should force a discard of all previous items with the same tag - that are outdated. - - Only file versions < 2 use this out-of-order discard logic. Later versions - discard events based on the step value of SessionLog.START. 
- """ - warnings = [] - self.stubs.Set(logger, 'warn', warnings.append) - - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - - gen.AddEvent( - event_pb2.Event(wall_time=0, step=0, file_version='brain.Event:1')) - gen.AddScalarTensor('s1', wall_time=1, step=100, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=200, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=300, value=20) - acc.Reload() - ## Check that number of items are what they should be - self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 200, 300]) - - gen.AddScalarTensor('s1', wall_time=1, step=101, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=201, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=301, value=20) - acc.Reload() - ## Check that we have discarded 200 and 300 from s1 - self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 101, 201, 301]) - - def testOrphanedDataNotDiscardedIfFlagUnset(self): - """Tests that events are not discarded if purge_orphaned_data is false. 
- """ - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen, purge_orphaned_data=False) - - gen.AddEvent( - event_pb2.Event(wall_time=0, step=0, file_version='brain.Event:1')) - gen.AddScalarTensor('s1', wall_time=1, step=100, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=200, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=300, value=20) - acc.Reload() - ## Check that number of items are what they should be - self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 200, 300]) - - gen.AddScalarTensor('s1', wall_time=1, step=101, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=201, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=301, value=20) - acc.Reload() - ## Check that we have NOT discarded 200 and 300 from s1 - self.assertEqual([x.step for x in acc.Tensors('s1')], - [100, 200, 300, 101, 201, 301]) - - def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self): - """Tests that event discards after restart, only affect the misordered tag. - - If a step value is observed to be lower than what was previously seen, - this should force a discard of all previous items that are outdated, but - only for the out of order tag. Other tags should remain unaffected. - - Only file versions < 2 use this out-of-order discard logic. Later versions - discard events based on the step value of SessionLog.START. 
- """ - warnings = [] - self.stubs.Set(logger, 'warn', warnings.append) - - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - - gen.AddEvent( - event_pb2.Event(wall_time=0, step=0, file_version='brain.Event:1')) - gen.AddScalarTensor('s1', wall_time=1, step=100, value=20) - gen.AddScalarTensor('s2', wall_time=1, step=101, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=200, value=20) - gen.AddScalarTensor('s2', wall_time=1, step=201, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=300, value=20) - gen.AddScalarTensor('s2', wall_time=1, step=301, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=101, value=20) - gen.AddScalarTensor('s3', wall_time=1, step=101, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=201, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=301, value=20) - - acc.Reload() - ## Check that we have discarded 200 and 300 for s1 - self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 101, 201, 301]) - - ## Check that s1 discards do not affect s2 (written before out-of-order) - ## or s3 (written after out-of-order). - ## i.e. 
check that only events from the out of order tag are discarded - self.assertEqual([x.step for x in acc.Tensors('s2')], [101, 201, 301]) - self.assertEqual([x.step for x in acc.Tensors('s3')], [101]) - - def testOnlySummaryEventsTriggerDiscards(self): - """Test that file version event does not trigger data purge.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddScalarTensor('s1', wall_time=1, step=100, value=20) - ev1 = event_pb2.Event(wall_time=2, step=0, file_version='brain.Event:1') - graph_bytes = tf.compat.v1.GraphDef().SerializeToString() - ev2 = event_pb2.Event(wall_time=3, step=0, graph_def=graph_bytes) - gen.AddEvent(ev1) - gen.AddEvent(ev2) - acc.Reload() - self.assertEqual([x.step for x in acc.Tensors('s1')], [100]) - - def testSessionLogStartMessageDiscardsExpiredEvents(self): - """Test that SessionLog.START message discards expired events. - - This discard logic is preferred over the out-of-order step discard logic, - but this logic can only be used for event protos which have the SessionLog - enum, which was introduced to event.proto for file_version >= brain.Event:2. 
- """ - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddEvent( - event_pb2.Event(wall_time=0, step=1, file_version='brain.Event:2')) - - gen.AddScalarTensor('s1', wall_time=1, step=100, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=200, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=300, value=20) - gen.AddScalarTensor('s1', wall_time=1, step=400, value=20) - - gen.AddScalarTensor('s2', wall_time=1, step=202, value=20) - gen.AddScalarTensor('s2', wall_time=1, step=203, value=20) - - slog = event_pb2.SessionLog(status=event_pb2.SessionLog.START) - gen.AddEvent( - event_pb2.Event(wall_time=2, step=201, session_log=slog)) - acc.Reload() - self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 200]) - self.assertEqual([x.step for x in acc.Tensors('s2')], []) - - def testFirstEventTimestamp(self): - """Test that FirstEventTimestamp() returns wall_time of the first event.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddEvent( - event_pb2.Event(wall_time=10, step=20, file_version='brain.Event:2')) - gen.AddScalarTensor('s1', wall_time=30, step=40, value=20) - self.assertEqual(acc.FirstEventTimestamp(), 10) - - def testReloadPopulatesFirstEventTimestamp(self): - """Test that Reload() means FirstEventTimestamp() won't load events.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddEvent( - event_pb2.Event(wall_time=1, step=2, file_version='brain.Event:2')) - - acc.Reload() - - def _Die(*args, **kwargs): # pylint: disable=unused-argument - raise RuntimeError('Load() should not be called') - - self.stubs.Set(gen, 'Load', _Die) - self.assertEqual(acc.FirstEventTimestamp(), 1) - - def testFirstEventTimestampLoadsEvent(self): - """Test that FirstEventTimestamp() doesn't discard the loaded event.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddEvent( - event_pb2.Event(wall_time=1, step=2, file_version='brain.Event:2')) - - 
self.assertEqual(acc.FirstEventTimestamp(), 1) - acc.Reload() - self.assertEqual(acc.file_version, 2.0) - - def testNewStyleScalarSummary(self): - """Verify processing of tensorboard.plugins.scalar.summary.""" - event_sink = _EventGenerator(self, zero_out_timestamps=True) - writer = test_util.FileWriter(self.get_temp_dir()) - writer.event_writer = event_sink - with tf.compat.v1.Graph().as_default(): - with self.test_session() as sess: - step = tf.compat.v1.placeholder(tf.float32, shape=[]) - scalar_summary.op('accuracy', 1.0 - 1.0 / (step + tf.constant(1.0))) - scalar_summary.op('xent', 1.0 / (step + tf.constant(1.0))) - merged = tf.compat.v1.summary.merge_all() - writer.add_graph(sess.graph) - for i in xrange(10): - summ = sess.run(merged, feed_dict={step: float(i)}) - writer.add_summary(summ, global_step=i) - - accumulator = ea.EventAccumulator(event_sink) - accumulator.Reload() - - tags = [ - u'accuracy/scalar_summary', - u'xent/scalar_summary', - ] - - self.assertTagsEqual(accumulator.Tags(), { - ea.TENSORS: tags, - ea.GRAPH: True, - ea.META_GRAPH: False, - }) - - def testNewStyleAudioSummary(self): - """Verify processing of tensorboard.plugins.audio.summary.""" - event_sink = _EventGenerator(self, zero_out_timestamps=True) - writer = test_util.FileWriter(self.get_temp_dir()) - writer.event_writer = event_sink - with tf.compat.v1.Graph().as_default(): - with self.test_session() as sess: - ipt = tf.random.normal(shape=[5, 441, 2]) - with tf.name_scope('1'): - audio_summary.op('one', ipt, sample_rate=44100, max_outputs=1) - with tf.name_scope('2'): - audio_summary.op('two', ipt, sample_rate=44100, max_outputs=2) - with tf.name_scope('3'): - audio_summary.op('three', ipt, sample_rate=44100, max_outputs=3) - merged = tf.compat.v1.summary.merge_all() - writer.add_graph(sess.graph) - for i in xrange(10): - summ = sess.run(merged) - writer.add_summary(summ, global_step=i) - - accumulator = ea.EventAccumulator(event_sink) - accumulator.Reload() - - tags = [ - 
u'1/one/audio_summary', - u'2/two/audio_summary', - u'3/three/audio_summary', - ] - - self.assertTagsEqual(accumulator.Tags(), { - ea.TENSORS: tags, - ea.GRAPH: True, - ea.META_GRAPH: False, - }) - - def testNewStyleImageSummary(self): - """Verify processing of tensorboard.plugins.image.summary.""" - event_sink = _EventGenerator(self, zero_out_timestamps=True) - writer = test_util.FileWriter(self.get_temp_dir()) - writer.event_writer = event_sink - with tf.compat.v1.Graph().as_default(): - with self.test_session() as sess: - ipt = tf.ones([10, 4, 4, 3], tf.uint8) - # This is an interesting example, because the old tf.image_summary op - # would throw an error here, because it would be tag reuse. - # Using the tf node name instead allows argument re-use to the image - # summary. - with tf.name_scope('1'): - image_summary.op('images', ipt, max_outputs=1) - with tf.name_scope('2'): - image_summary.op('images', ipt, max_outputs=2) - with tf.name_scope('3'): - image_summary.op('images', ipt, max_outputs=3) - merged = tf.compat.v1.summary.merge_all() - writer.add_graph(sess.graph) - for i in xrange(10): - summ = sess.run(merged) - writer.add_summary(summ, global_step=i) - - accumulator = ea.EventAccumulator(event_sink) - accumulator.Reload() - - tags = [ - u'1/images/image_summary', - u'2/images/image_summary', - u'3/images/image_summary', - ] - - self.assertTagsEqual(accumulator.Tags(), { - ea.TENSORS: tags, - ea.GRAPH: True, - ea.META_GRAPH: False, - }) - - def testTFSummaryTensor(self): - """Verify processing of tf.summary.tensor.""" - event_sink = _EventGenerator(self, zero_out_timestamps=True) - writer = test_util.FileWriter(self.get_temp_dir()) - writer.event_writer = event_sink - with tf.compat.v1.Graph().as_default(): - with self.test_session() as sess: - tensor_summary = tf.compat.v1.summary.tensor_summary - tensor_summary('scalar', tf.constant(1.0)) - tensor_summary('vector', tf.constant([1.0, 2.0, 3.0])) - tensor_summary('string', tf.constant(six.b('foobar'))) 
- merged = tf.compat.v1.summary.merge_all() - summ = sess.run(merged) - writer.add_summary(summ, 0) - - accumulator = ea.EventAccumulator(event_sink) - accumulator.Reload() - - self.assertTagsEqual(accumulator.Tags(), { - ea.TENSORS: ['scalar', 'vector', 'string'], - }) - - scalar_proto = accumulator.Tensors('scalar')[0].tensor_proto - scalar = tensor_util.make_ndarray(scalar_proto) - vector_proto = accumulator.Tensors('vector')[0].tensor_proto - vector = tensor_util.make_ndarray(vector_proto) - string_proto = accumulator.Tensors('string')[0].tensor_proto - string = tensor_util.make_ndarray(string_proto) - - self.assertTrue(np.array_equal(scalar, 1.0)) - self.assertTrue(np.array_equal(vector, [1.0, 2.0, 3.0])) - self.assertTrue(np.array_equal(string, six.b('foobar'))) - - def _testTFSummaryTensor_SizeGuidance(self, - plugin_name, - tensor_size_guidance, - steps, - expected_count): - event_sink = _EventGenerator(self, zero_out_timestamps=True) - writer = test_util.FileWriter(self.get_temp_dir()) - writer.event_writer = event_sink - with tf.compat.v1.Graph().as_default(): - with self.test_session() as sess: + +class RealisticEventAccumulatorTest(EventAccumulatorTest): + def testTensorsRealistically(self): + """Test accumulator by writing values and then reading them.""" + + def FakeScalarSummary(tag, value): + value = summary_pb2.Summary.Value(tag=tag, simple_value=value) + summary = summary_pb2.Summary(value=[value]) + return summary + + directory = os.path.join(self.get_temp_dir(), "values_dir") + if tf.io.gfile.isdir(directory): + tf.io.gfile.rmtree(directory) + tf.io.gfile.mkdir(directory) + + writer = test_util.FileWriter(directory, max_queue=100) + + with tf.Graph().as_default() as graph: + _ = tf.constant([2.0, 1.0]) + # Add a graph to the summary writer. 
+ writer.add_graph(graph) + graph_def = graph.as_graph_def(add_shapes=True) + meta_graph_def = tf.compat.v1.train.export_meta_graph( + graph_def=graph_def + ) + writer.add_meta_graph(meta_graph_def) + + run_metadata = config_pb2.RunMetadata() + device_stats = run_metadata.step_stats.dev_stats.add() + device_stats.device = "test device" + writer.add_run_metadata(run_metadata, "test run") + + # Write a bunch of events using the writer. + for i in xrange(30): + summ_id = FakeScalarSummary("id", i) + summ_sq = FakeScalarSummary("sq", i * i) + writer.add_summary(summ_id, i * 5) + writer.add_summary(summ_sq, i * 5) + writer.flush() + + # Verify that we can load those events properly + acc = ea.EventAccumulator(directory) + acc.Reload() + self.assertTagsEqual( + acc.Tags(), + { + ea.TENSORS: ["id", "sq"], + ea.GRAPH: True, + ea.META_GRAPH: True, + ea.RUN_METADATA: ["test run"], + }, + ) + id_events = acc.Tensors("id") + sq_events = acc.Tensors("sq") + self.assertEqual(30, len(id_events)) + self.assertEqual(30, len(sq_events)) + for i in xrange(30): + self.assertEqual(i * 5, id_events[i].step) + self.assertEqual(i * 5, sq_events[i].step) + self.assertEqual( + i, tensor_util.make_ndarray(id_events[i].tensor_proto).item() + ) + self.assertEqual( + i * i, + tensor_util.make_ndarray(sq_events[i].tensor_proto).item(), + ) + + # Write a few more events to test incremental reloading + for i in xrange(30, 40): + summ_id = FakeScalarSummary("id", i) + summ_sq = FakeScalarSummary("sq", i * i) + writer.add_summary(summ_id, i * 5) + writer.add_summary(summ_sq, i * 5) + writer.flush() + + # Verify we can now see all of the data + acc.Reload() + id_events = acc.Tensors("id") + sq_events = acc.Tensors("sq") + self.assertEqual(40, len(id_events)) + self.assertEqual(40, len(sq_events)) + for i in xrange(40): + self.assertEqual(i * 5, id_events[i].step) + self.assertEqual(i * 5, sq_events[i].step) + self.assertEqual( + i, tensor_util.make_ndarray(id_events[i].tensor_proto).item() + ) + 
self.assertEqual( + i * i, + tensor_util.make_ndarray(sq_events[i].tensor_proto).item(), + ) + + expected_graph_def = graph_pb2.GraphDef.FromString( + graph.as_graph_def(add_shapes=True).SerializeToString() + ) + self.assertProtoEquals(expected_graph_def, acc.Graph()) + self.assertProtoEquals( + expected_graph_def, + graph_pb2.GraphDef.FromString(acc.SerializedGraph()), + ) + + expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString( + meta_graph_def.SerializeToString() + ) + self.assertProtoEquals(expected_meta_graph, acc.MetaGraph()) + + def testGraphFromMetaGraphBecomesAvailable(self): + """Test accumulator by writing values and then reading them.""" + + directory = os.path.join( + self.get_temp_dir(), "metagraph_test_values_dir" + ) + if tf.io.gfile.isdir(directory): + tf.io.gfile.rmtree(directory) + tf.io.gfile.mkdir(directory) + + writer = test_util.FileWriter(directory, max_queue=100) + + with tf.Graph().as_default() as graph: + _ = tf.constant([2.0, 1.0]) + # Add a graph to the summary writer. + graph_def = graph.as_graph_def(add_shapes=True) + meta_graph_def = tf.compat.v1.train.export_meta_graph( + graph_def=graph_def + ) + writer.add_meta_graph(meta_graph_def) + writer.flush() + + # Verify that we can load those events properly + acc = ea.EventAccumulator(directory) + acc.Reload() + self.assertTagsEqual(acc.Tags(), {ea.GRAPH: True, ea.META_GRAPH: True,}) + + expected_graph_def = graph_pb2.GraphDef.FromString( + graph.as_graph_def(add_shapes=True).SerializeToString() + ) + self.assertProtoEquals(expected_graph_def, acc.Graph()) + self.assertProtoEquals( + expected_graph_def, + graph_pb2.GraphDef.FromString(acc.SerializedGraph()), + ) + + expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString( + meta_graph_def.SerializeToString() + ) + self.assertProtoEquals(expected_meta_graph, acc.MetaGraph()) + + def _writeMetadata(self, logdir, summary_metadata, nonce=""): + """Write to disk a summary with the given metadata. 
+ + Arguments: + logdir: a string + summary_metadata: a `SummaryMetadata` protobuf object + nonce: optional; will be added to the end of the event file name + to guarantee that multiple calls to this function do not stomp the + same file + """ + + summary = summary_pb2.Summary() + summary.value.add( + tensor=tensor_util.make_tensor_proto( + ["po", "ta", "to"], dtype=tf.string + ), + tag="you_are_it", + metadata=summary_metadata, + ) + writer = test_util.FileWriter(logdir, filename_suffix=nonce) + writer.add_summary(summary.SerializeToString()) + writer.close() + + def testSummaryMetadata(self): + logdir = self.get_temp_dir() summary_metadata = summary_pb2.SummaryMetadata( + display_name="current tagee", + summary_description="no", plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name=plugin_name, content=b'{}')) - tf.compat.v1.summary.tensor_summary('scalar', tf.constant(1.0), - summary_metadata=summary_metadata) - merged = tf.compat.v1.summary.merge_all() - for step in xrange(steps): - writer.add_summary(sess.run(merged), global_step=step) - - - accumulator = ea.EventAccumulator( - event_sink, tensor_size_guidance=tensor_size_guidance) - accumulator.Reload() - - tensors = accumulator.Tensors('scalar') - self.assertEqual(len(tensors), expected_count) - - def testTFSummaryTensor_SizeGuidance_DefaultToTensorGuidance(self): - self._testTFSummaryTensor_SizeGuidance( - plugin_name='jabberwocky', - tensor_size_guidance={}, - steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 1, - expected_count=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS]) - - def testTFSummaryTensor_SizeGuidance_UseSmallSingularPluginGuidance(self): - size = int(ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] / 2) - assert size < ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS], size - self._testTFSummaryTensor_SizeGuidance( - plugin_name='jabberwocky', - tensor_size_guidance={'jabberwocky': size}, - steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 1, - expected_count=size) - - def 
testTFSummaryTensor_SizeGuidance_UseLargeSingularPluginGuidance(self): - size = ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 5 - self._testTFSummaryTensor_SizeGuidance( - plugin_name='jabberwocky', - tensor_size_guidance={'jabberwocky': size}, - steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 10, - expected_count=size) - - def testTFSummaryTensor_SizeGuidance_IgnoreIrrelevantGuidances(self): - size_small = int(ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] / 3) - size_large = int(ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] / 2) - assert size_small < size_large < ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS], ( - size_small, size_large) - self._testTFSummaryTensor_SizeGuidance( - plugin_name='jabberwocky', - tensor_size_guidance={'jabberwocky': size_small, - 'wnoorejbpxl': size_large}, - steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 1, - expected_count=size_small) - + plugin_name="outlet" + ), + ) + self._writeMetadata(logdir, summary_metadata) + acc = ea.EventAccumulator(logdir) + acc.Reload() + self.assertProtoEquals( + summary_metadata, acc.SummaryMetadata("you_are_it") + ) + + def testSummaryMetadata_FirstMetadataWins(self): + logdir = self.get_temp_dir() + summary_metadata_1 = summary_pb2.SummaryMetadata( + display_name="current tagee", + summary_description="no", + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name="outlet", content=b"120v" + ), + ) + self._writeMetadata(logdir, summary_metadata_1, nonce="1") + acc = ea.EventAccumulator(logdir) + acc.Reload() + summary_metadata_2 = summary_pb2.SummaryMetadata( + display_name="tagee of the future", + summary_description="definitely not", + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name="plug", content=b"110v" + ), + ) + self._writeMetadata(logdir, summary_metadata_2, nonce="2") + acc.Reload() + + self.assertProtoEquals( + summary_metadata_1, acc.SummaryMetadata("you_are_it") + ) + + def testPluginTagToContent_PluginsCannotJumpOnTheBandwagon(self): + # If there are multiple `SummaryMetadata` for a given tag, and 
the + # set of plugins in the `plugin_data` of second is different from + # that of the first, then the second set should be ignored. + logdir = self.get_temp_dir() + summary_metadata_1 = summary_pb2.SummaryMetadata( + display_name="current tagee", + summary_description="no", + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name="outlet", content=b"120v" + ), + ) + self._writeMetadata(logdir, summary_metadata_1, nonce="1") + acc = ea.EventAccumulator(logdir) + acc.Reload() + summary_metadata_2 = summary_pb2.SummaryMetadata( + display_name="tagee of the future", + summary_description="definitely not", + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name="plug", content=b"110v" + ), + ) + self._writeMetadata(logdir, summary_metadata_2, nonce="2") + acc.Reload() -class RealisticEventAccumulatorTest(EventAccumulatorTest): + self.assertEqual( + acc.PluginTagToContent("outlet"), {"you_are_it": b"120v"} + ) + with six.assertRaisesRegex(self, KeyError, "plug"): + acc.PluginTagToContent("plug") - def testTensorsRealistically(self): - """Test accumulator by writing values and then reading them.""" - - def FakeScalarSummary(tag, value): - value = summary_pb2.Summary.Value(tag=tag, simple_value=value) - summary = summary_pb2.Summary(value=[value]) - return summary - - directory = os.path.join(self.get_temp_dir(), 'values_dir') - if tf.io.gfile.isdir(directory): - tf.io.gfile.rmtree(directory) - tf.io.gfile.mkdir(directory) - - writer = test_util.FileWriter(directory, max_queue=100) - - with tf.Graph().as_default() as graph: - _ = tf.constant([2.0, 1.0]) - # Add a graph to the summary writer. 
- writer.add_graph(graph) - graph_def = graph.as_graph_def(add_shapes=True) - meta_graph_def = tf.compat.v1.train.export_meta_graph(graph_def=graph_def) - writer.add_meta_graph(meta_graph_def) - - run_metadata = config_pb2.RunMetadata() - device_stats = run_metadata.step_stats.dev_stats.add() - device_stats.device = 'test device' - writer.add_run_metadata(run_metadata, 'test run') - - # Write a bunch of events using the writer. - for i in xrange(30): - summ_id = FakeScalarSummary('id', i) - summ_sq = FakeScalarSummary('sq', i * i) - writer.add_summary(summ_id, i * 5) - writer.add_summary(summ_sq, i * 5) - writer.flush() - - # Verify that we can load those events properly - acc = ea.EventAccumulator(directory) - acc.Reload() - self.assertTagsEqual(acc.Tags(), { - ea.TENSORS: ['id', 'sq'], - ea.GRAPH: True, - ea.META_GRAPH: True, - ea.RUN_METADATA: ['test run'], - }) - id_events = acc.Tensors('id') - sq_events = acc.Tensors('sq') - self.assertEqual(30, len(id_events)) - self.assertEqual(30, len(sq_events)) - for i in xrange(30): - self.assertEqual(i * 5, id_events[i].step) - self.assertEqual(i * 5, sq_events[i].step) - self.assertEqual(i, tensor_util.make_ndarray(id_events[i].tensor_proto).item()) - self.assertEqual(i * i, tensor_util.make_ndarray(sq_events[i].tensor_proto).item()) - - # Write a few more events to test incremental reloading - for i in xrange(30, 40): - summ_id = FakeScalarSummary('id', i) - summ_sq = FakeScalarSummary('sq', i * i) - writer.add_summary(summ_id, i * 5) - writer.add_summary(summ_sq, i * 5) - writer.flush() - - # Verify we can now see all of the data - acc.Reload() - id_events = acc.Tensors('id') - sq_events = acc.Tensors('sq') - self.assertEqual(40, len(id_events)) - self.assertEqual(40, len(sq_events)) - for i in xrange(40): - self.assertEqual(i * 5, id_events[i].step) - self.assertEqual(i * 5, sq_events[i].step) - self.assertEqual(i, tensor_util.make_ndarray(id_events[i].tensor_proto).item()) - self.assertEqual(i * i, 
tensor_util.make_ndarray(sq_events[i].tensor_proto).item()) - - expected_graph_def = graph_pb2.GraphDef.FromString( - graph.as_graph_def(add_shapes=True).SerializeToString()) - self.assertProtoEquals(expected_graph_def, acc.Graph()) - self.assertProtoEquals(expected_graph_def, - graph_pb2.GraphDef.FromString(acc.SerializedGraph())) - - expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString( - meta_graph_def.SerializeToString()) - self.assertProtoEquals(expected_meta_graph, acc.MetaGraph()) - - def testGraphFromMetaGraphBecomesAvailable(self): - """Test accumulator by writing values and then reading them.""" - - directory = os.path.join(self.get_temp_dir(), 'metagraph_test_values_dir') - if tf.io.gfile.isdir(directory): - tf.io.gfile.rmtree(directory) - tf.io.gfile.mkdir(directory) - - writer = test_util.FileWriter(directory, max_queue=100) - - with tf.Graph().as_default() as graph: - _ = tf.constant([2.0, 1.0]) - # Add a graph to the summary writer. - graph_def = graph.as_graph_def(add_shapes=True) - meta_graph_def = tf.compat.v1.train.export_meta_graph(graph_def=graph_def) - writer.add_meta_graph(meta_graph_def) - writer.flush() - - # Verify that we can load those events properly - acc = ea.EventAccumulator(directory) - acc.Reload() - self.assertTagsEqual(acc.Tags(), { - ea.GRAPH: True, - ea.META_GRAPH: True, - }) - - expected_graph_def = graph_pb2.GraphDef.FromString( - graph.as_graph_def(add_shapes=True).SerializeToString()) - self.assertProtoEquals(expected_graph_def, acc.Graph()) - self.assertProtoEquals(expected_graph_def, - graph_pb2.GraphDef.FromString(acc.SerializedGraph())) - - expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString( - meta_graph_def.SerializeToString()) - self.assertProtoEquals(expected_meta_graph, acc.MetaGraph()) - - def _writeMetadata(self, logdir, summary_metadata, nonce=''): - """Write to disk a summary with the given metadata. 
- - Arguments: - logdir: a string - summary_metadata: a `SummaryMetadata` protobuf object - nonce: optional; will be added to the end of the event file name - to guarantee that multiple calls to this function do not stomp the - same file - """ - summary = summary_pb2.Summary() - summary.value.add( - tensor=tensor_util.make_tensor_proto(['po', 'ta', 'to'], dtype=tf.string), - tag='you_are_it', - metadata=summary_metadata) - writer = test_util.FileWriter(logdir, filename_suffix=nonce) - writer.add_summary(summary.SerializeToString()) - writer.close() - - def testSummaryMetadata(self): - logdir = self.get_temp_dir() - summary_metadata = summary_pb2.SummaryMetadata( - display_name='current tagee', - summary_description='no', - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name='outlet')) - self._writeMetadata(logdir, summary_metadata) - acc = ea.EventAccumulator(logdir) - acc.Reload() - self.assertProtoEquals(summary_metadata, - acc.SummaryMetadata('you_are_it')) - - def testSummaryMetadata_FirstMetadataWins(self): - logdir = self.get_temp_dir() - summary_metadata_1 = summary_pb2.SummaryMetadata( - display_name='current tagee', - summary_description='no', - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name='outlet', content=b'120v')) - self._writeMetadata(logdir, summary_metadata_1, nonce='1') - acc = ea.EventAccumulator(logdir) - acc.Reload() - summary_metadata_2 = summary_pb2.SummaryMetadata( - display_name='tagee of the future', - summary_description='definitely not', - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name='plug', content=b'110v')) - self._writeMetadata(logdir, summary_metadata_2, nonce='2') - acc.Reload() - - self.assertProtoEquals(summary_metadata_1, - acc.SummaryMetadata('you_are_it')) - - def testPluginTagToContent_PluginsCannotJumpOnTheBandwagon(self): - # If there are multiple `SummaryMetadata` for a given tag, and the - # set of plugins in the `plugin_data` of second is different from - # that of the 
first, then the second set should be ignored. - logdir = self.get_temp_dir() - summary_metadata_1 = summary_pb2.SummaryMetadata( - display_name='current tagee', - summary_description='no', - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name='outlet', content=b'120v')) - self._writeMetadata(logdir, summary_metadata_1, nonce='1') - acc = ea.EventAccumulator(logdir) - acc.Reload() - summary_metadata_2 = summary_pb2.SummaryMetadata( - display_name='tagee of the future', - summary_description='definitely not', - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name='plug', content=b'110v')) - self._writeMetadata(logdir, summary_metadata_2, nonce='2') - acc.Reload() - - self.assertEqual(acc.PluginTagToContent('outlet'), - {'you_are_it': b'120v'}) - with six.assertRaisesRegex(self, KeyError, 'plug'): - acc.PluginTagToContent('plug') - -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/backend/event_processing/plugin_event_multiplexer.py b/tensorboard/backend/event_processing/plugin_event_multiplexer.py index 566ab6f9bb..e5304f3e3e 100644 --- a/tensorboard/backend/event_processing/plugin_event_multiplexer.py +++ b/tensorboard/backend/event_processing/plugin_event_multiplexer.py @@ -25,439 +25,454 @@ from six.moves import queue, xrange # pylint: disable=redefined-builtin from tensorboard.backend.event_processing import directory_watcher -from tensorboard.backend.event_processing import plugin_event_accumulator as event_accumulator +from tensorboard.backend.event_processing import ( + plugin_event_accumulator as event_accumulator, +) from tensorboard.backend.event_processing import io_wrapper from tensorboard.util import tb_logging logger = tb_logging.get_logger() -class EventMultiplexer(object): - """An `EventMultiplexer` manages access to multiple `EventAccumulator`s. - - Each `EventAccumulator` is associated with a `run`, which is a self-contained - TensorFlow execution. 
The `EventMultiplexer` provides methods for extracting - information about events from multiple `run`s. - - Example usage for loading specific runs from files: - - ```python - x = EventMultiplexer({'run1': 'path/to/run1', 'run2': 'path/to/run2'}) - x.Reload() - ``` - - Example usage for loading a directory where each subdirectory is a run - - ```python - (eg:) /parent/directory/path/ - /parent/directory/path/run1/ - /parent/directory/path/run1/events.out.tfevents.1001 - /parent/directory/path/run1/events.out.tfevents.1002 - - /parent/directory/path/run2/ - /parent/directory/path/run2/events.out.tfevents.9232 - - /parent/directory/path/run3/ - /parent/directory/path/run3/events.out.tfevents.9232 - x = EventMultiplexer().AddRunsFromDirectory('/parent/directory/path') - (which is equivalent to:) - x = EventMultiplexer({'run1': '/parent/directory/path/run1', 'run2':...} - ``` - - If you would like to watch `/parent/directory/path`, wait for it to be created - (if necessary) and then periodically pick up new runs, use - `AutoloadingMultiplexer` - @@Tensors - """ - - def __init__(self, - run_path_map=None, - size_guidance=None, - tensor_size_guidance=None, - purge_orphaned_data=True, - max_reload_threads=None, - event_file_active_filter=None): - """Constructor for the `EventMultiplexer`. - - Args: - run_path_map: Dict `{run: path}` which specifies the - name of a run, and the path to find the associated events. If it is - None, then the EventMultiplexer initializes without any runs. - size_guidance: A dictionary mapping from `tagType` to the number of items - to store for each tag of that type. See - `event_accumulator.EventAccumulator` for details. - tensor_size_guidance: A dictionary mapping from `plugin_name` to - the number of items to store for each tag of that type. See - `event_accumulator.EventAccumulator` for details. - purge_orphaned_data: Whether to discard any events that were "orphaned" by - a TensorFlow restart. 
- max_reload_threads: The max number of threads that TensorBoard can use - to reload runs. Each thread reloads one run at a time. If not provided, - reloads runs serially (one after another). - event_file_active_filter: Optional predicate for determining whether an - event file latest load timestamp should be considered active. If passed, - this will enable multifile directory loading. - """ - logger.info('Event Multiplexer initializing.') - self._accumulators_mutex = threading.Lock() - self._accumulators = {} - self._paths = {} - self._reload_called = False - self._size_guidance = (size_guidance or - event_accumulator.DEFAULT_SIZE_GUIDANCE) - self._tensor_size_guidance = tensor_size_guidance - self.purge_orphaned_data = purge_orphaned_data - self._max_reload_threads = max_reload_threads or 1 - self._event_file_active_filter = event_file_active_filter - if run_path_map is not None: - logger.info('Event Multplexer doing initialization load for %s', - run_path_map) - for (run, path) in six.iteritems(run_path_map): - self.AddRun(path, run) - logger.info('Event Multiplexer done initializing') - - def AddRun(self, path, name=None): - """Add a run to the multiplexer. - - If the name is not specified, it is the same as the path. - - If a run by that name exists, and we are already watching the right path, - do nothing. If we are watching a different path, replace the event - accumulator. - - If `Reload` has been called, it will `Reload` the newly created - accumulators. - - Args: - path: Path to the event files (or event directory) for given run. - name: Name of the run to add. If not provided, is set to path. - - Returns: - The `EventMultiplexer`. 
- """ - name = name or path - accumulator = None - with self._accumulators_mutex: - if name not in self._accumulators or self._paths[name] != path: - if name in self._paths and self._paths[name] != path: - # TODO(@decentralion) - Make it impossible to overwrite an old path - # with a new path (just give the new path a distinct name) - logger.warn('Conflict for name %s: old path %s, new path %s', - name, self._paths[name], path) - logger.info('Constructing EventAccumulator for %s', path) - accumulator = event_accumulator.EventAccumulator( - path, - size_guidance=self._size_guidance, - tensor_size_guidance=self._tensor_size_guidance, - purge_orphaned_data=self.purge_orphaned_data, - event_file_active_filter=self._event_file_active_filter) - self._accumulators[name] = accumulator - self._paths[name] = path - if accumulator: - if self._reload_called: - accumulator.Reload() - return self - - def AddRunsFromDirectory(self, path, name=None): - """Load runs from a directory; recursively walks subdirectories. - - If path doesn't exist, no-op. This ensures that it is safe to call - `AddRunsFromDirectory` multiple times, even before the directory is made. - - If path is a directory, load event files in the directory (if any exist) and - recursively call AddRunsFromDirectory on any subdirectories. This mean you - can call AddRunsFromDirectory at the root of a tree of event logs and - TensorBoard will load them all. - - If the `EventMultiplexer` is already loaded this will cause - the newly created accumulators to `Reload()`. - Args: - path: A string path to a directory to load runs from. - name: Optionally, what name to apply to the runs. If name is provided - and the directory contains run subdirectories, the name of each subrun - is the concatenation of the parent name and the subdirectory name. If - name is provided and the directory contains event files, then a run - is added called "name" and with the events from the path. 
- - Raises: - ValueError: If the path exists and isn't a directory. - - Returns: - The `EventMultiplexer`. - """ - logger.info('Starting AddRunsFromDirectory: %s', path) - for subdir in io_wrapper.GetLogdirSubdirectories(path): - logger.info('Adding run from directory %s', subdir) - rpath = os.path.relpath(subdir, path) - subname = os.path.join(name, rpath) if name else rpath - self.AddRun(subdir, name=subname) - logger.info('Done with AddRunsFromDirectory: %s', path) - return self - - def Reload(self): - """Call `Reload` on every `EventAccumulator`.""" - logger.info('Beginning EventMultiplexer.Reload()') - self._reload_called = True - # Build a list so we're safe even if the list of accumulators is modified - # even while we're reloading. - with self._accumulators_mutex: - items = list(self._accumulators.items()) - items_queue = queue.Queue() - for item in items: - items_queue.put(item) - - # Methods of built-in python containers are thread-safe so long as the GIL - # for the thread exists, but we might as well be careful. - names_to_delete = set() - names_to_delete_mutex = threading.Lock() - - def Worker(): - """Keeps reloading accumulators til none are left.""" - while True: - try: - name, accumulator = items_queue.get(block=False) - except queue.Empty: - # No more runs to reload. 
- break - - try: - accumulator.Reload() - except (OSError, IOError) as e: - logger.error('Unable to reload accumulator %r: %s', name, e) - except directory_watcher.DirectoryDeletedError: - with names_to_delete_mutex: - names_to_delete.add(name) - finally: - items_queue.task_done() - - if self._max_reload_threads > 1: - num_threads = min( - self._max_reload_threads, len(items)) - logger.info('Starting %d threads to reload runs', num_threads) - for i in xrange(num_threads): - thread = threading.Thread(target=Worker, name='Reloader %d' % i) - thread.daemon = True - thread.start() - items_queue.join() - else: - logger.info( - 'Reloading runs serially (one after another) on the main ' - 'thread.') - Worker() - - with self._accumulators_mutex: - for name in names_to_delete: - logger.warn('Deleting accumulator %r', name) - del self._accumulators[name] - logger.info('Finished with EventMultiplexer.Reload()') - return self - - def PluginAssets(self, plugin_name): - """Get index of runs and assets for a given plugin. - - Args: - plugin_name: Name of the plugin we are checking for. - - Returns: - A dictionary that maps from run_name to a list of plugin - assets for that run. - """ - with self._accumulators_mutex: - # To avoid nested locks, we construct a copy of the run-accumulator map - items = list(six.iteritems(self._accumulators)) - - return {run: accum.PluginAssets(plugin_name) for run, accum in items} - - def RetrievePluginAsset(self, run, plugin_name, asset_name): - """Return the contents for a specific plugin asset from a run. - - Args: - run: The string name of the run. - plugin_name: The string name of a plugin. - asset_name: The string name of an asset. - - Returns: - The string contents of the plugin asset. - - Raises: - KeyError: If the asset is not available. 
- """ - accumulator = self.GetAccumulator(run) - return accumulator.RetrievePluginAsset(plugin_name, asset_name) - - def FirstEventTimestamp(self, run): - """Return the timestamp of the first event of the given run. - - This may perform I/O if no events have been loaded yet for the run. - Args: - run: A string name of the run for which the timestamp is retrieved. - - Returns: - The wall_time of the first event of the run, which will typically be - seconds since the epoch. - - Raises: - KeyError: If the run is not found. - ValueError: If the run has no events loaded and there are no events on - disk to load. - """ - accumulator = self.GetAccumulator(run) - return accumulator.FirstEventTimestamp() - - def Graph(self, run): - """Retrieve the graph associated with the provided run. - - Args: - run: A string name of a run to load the graph for. - - Raises: - KeyError: If the run is not found. - ValueError: If the run does not have an associated graph. - - Returns: - The `GraphDef` protobuf data structure. - """ - accumulator = self.GetAccumulator(run) - return accumulator.Graph() - - def SerializedGraph(self, run): - """Retrieve the serialized graph associated with the provided run. - - Args: - run: A string name of a run to load the graph for. - - Raises: - KeyError: If the run is not found. - ValueError: If the run does not have an associated graph. - - Returns: - The serialized form of the `GraphDef` protobuf data structure. - """ - accumulator = self.GetAccumulator(run) - return accumulator.SerializedGraph() - - def MetaGraph(self, run): - """Retrieve the metagraph associated with the provided run. - - Args: - run: A string name of a run to load the graph for. - - Raises: - KeyError: If the run is not found. - ValueError: If the run does not have an associated graph. - - Returns: - The `MetaGraphDef` protobuf data structure. 
- """ - accumulator = self.GetAccumulator(run) - return accumulator.MetaGraph() - - def RunMetadata(self, run, tag): - """Get the session.run() metadata associated with a TensorFlow run and tag. - - Args: - run: A string name of a TensorFlow run. - tag: A string name of the tag associated with a particular session.run(). - - Raises: - KeyError: If the run is not found, or the tag is not available for the - given run. - - Returns: - The metadata in the form of `RunMetadata` protobuf data structure. - """ - accumulator = self.GetAccumulator(run) - return accumulator.RunMetadata(tag) - - def Tensors(self, run, tag): - """Retrieve the tensor events associated with a run and tag. - - Args: - run: A string name of the run for which values are retrieved. - tag: A string name of the tag for which values are retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - An array of `event_accumulator.TensorEvent`s. - """ - accumulator = self.GetAccumulator(run) - return accumulator.Tensors(tag) - - def PluginRunToTagToContent(self, plugin_name): - """Returns a 2-layer dictionary of the form {run: {tag: content}}. - - The `content` referred above is the content field of the PluginData proto - for the specified plugin within a Summary.Value proto. - - Args: - plugin_name: The name of the plugin for which to fetch content. +class EventMultiplexer(object): + """An `EventMultiplexer` manages access to multiple `EventAccumulator`s. - Returns: - A dictionary of the form {run: {tag: content}}. - """ - mapping = {} - for run in self.Runs(): - try: - tag_to_content = self.GetAccumulator(run).PluginTagToContent( - plugin_name) - except KeyError: - # This run lacks content for the plugin. Try the next run. - continue - mapping[run] = tag_to_content - return mapping - - def SummaryMetadata(self, run, tag): - """Return the summary metadata for the given tag on the given run. 
- - Args: - run: A string name of the run for which summary metadata is to be - retrieved. - tag: A string name of the tag whose summary metadata is to be - retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - A `SummaryMetadata` protobuf. - """ - accumulator = self.GetAccumulator(run) - return accumulator.SummaryMetadata(tag) + Each `EventAccumulator` is associated with a `run`, which is a self-contained + TensorFlow execution. The `EventMultiplexer` provides methods for extracting + information about events from multiple `run`s. - def Runs(self): - """Return all the run names in the `EventMultiplexer`. + Example usage for loading specific runs from files: - Returns: - ``` - {runName: { scalarValues: [tagA, tagB, tagC], - graph: true, meta_graph: true}} + ```python + x = EventMultiplexer({'run1': 'path/to/run1', 'run2': 'path/to/run2'}) + x.Reload() ``` - """ - with self._accumulators_mutex: - # To avoid nested locks, we construct a copy of the run-accumulator map - items = list(six.iteritems(self._accumulators)) - return {run_name: accumulator.Tags() for run_name, accumulator in items} - def RunPaths(self): - """Returns a dict mapping run names to event file paths.""" - return self._paths + Example usage for loading a directory where each subdirectory is a run - def GetAccumulator(self, run): - """Returns EventAccumulator for a given run. + ```python + (eg:) /parent/directory/path/ + /parent/directory/path/run1/ + /parent/directory/path/run1/events.out.tfevents.1001 + /parent/directory/path/run1/events.out.tfevents.1002 - Args: - run: String name of run. + /parent/directory/path/run2/ + /parent/directory/path/run2/events.out.tfevents.9232 - Returns: - An EventAccumulator object. 
+ /parent/directory/path/run3/ + /parent/directory/path/run3/events.out.tfevents.9232 + x = EventMultiplexer().AddRunsFromDirectory('/parent/directory/path') + (which is equivalent to:) + x = EventMultiplexer({'run1': '/parent/directory/path/run1', 'run2':...} + ``` - Raises: - KeyError: If run does not exist. + If you would like to watch `/parent/directory/path`, wait for it to be created + (if necessary) and then periodically pick up new runs, use + `AutoloadingMultiplexer` + @@Tensors """ - with self._accumulators_mutex: - return self._accumulators[run] + + def __init__( + self, + run_path_map=None, + size_guidance=None, + tensor_size_guidance=None, + purge_orphaned_data=True, + max_reload_threads=None, + event_file_active_filter=None, + ): + """Constructor for the `EventMultiplexer`. + + Args: + run_path_map: Dict `{run: path}` which specifies the + name of a run, and the path to find the associated events. If it is + None, then the EventMultiplexer initializes without any runs. + size_guidance: A dictionary mapping from `tagType` to the number of items + to store for each tag of that type. See + `event_accumulator.EventAccumulator` for details. + tensor_size_guidance: A dictionary mapping from `plugin_name` to + the number of items to store for each tag of that type. See + `event_accumulator.EventAccumulator` for details. + purge_orphaned_data: Whether to discard any events that were "orphaned" by + a TensorFlow restart. + max_reload_threads: The max number of threads that TensorBoard can use + to reload runs. Each thread reloads one run at a time. If not provided, + reloads runs serially (one after another). + event_file_active_filter: Optional predicate for determining whether an + event file latest load timestamp should be considered active. If passed, + this will enable multifile directory loading. 
+ """ + logger.info("Event Multiplexer initializing.") + self._accumulators_mutex = threading.Lock() + self._accumulators = {} + self._paths = {} + self._reload_called = False + self._size_guidance = ( + size_guidance or event_accumulator.DEFAULT_SIZE_GUIDANCE + ) + self._tensor_size_guidance = tensor_size_guidance + self.purge_orphaned_data = purge_orphaned_data + self._max_reload_threads = max_reload_threads or 1 + self._event_file_active_filter = event_file_active_filter + if run_path_map is not None: + logger.info( + "Event Multplexer doing initialization load for %s", + run_path_map, + ) + for (run, path) in six.iteritems(run_path_map): + self.AddRun(path, run) + logger.info("Event Multiplexer done initializing") + + def AddRun(self, path, name=None): + """Add a run to the multiplexer. + + If the name is not specified, it is the same as the path. + + If a run by that name exists, and we are already watching the right path, + do nothing. If we are watching a different path, replace the event + accumulator. + + If `Reload` has been called, it will `Reload` the newly created + accumulators. + + Args: + path: Path to the event files (or event directory) for given run. + name: Name of the run to add. If not provided, is set to path. + + Returns: + The `EventMultiplexer`. 
+ """ + name = name or path + accumulator = None + with self._accumulators_mutex: + if name not in self._accumulators or self._paths[name] != path: + if name in self._paths and self._paths[name] != path: + # TODO(@decentralion) - Make it impossible to overwrite an old path + # with a new path (just give the new path a distinct name) + logger.warn( + "Conflict for name %s: old path %s, new path %s", + name, + self._paths[name], + path, + ) + logger.info("Constructing EventAccumulator for %s", path) + accumulator = event_accumulator.EventAccumulator( + path, + size_guidance=self._size_guidance, + tensor_size_guidance=self._tensor_size_guidance, + purge_orphaned_data=self.purge_orphaned_data, + event_file_active_filter=self._event_file_active_filter, + ) + self._accumulators[name] = accumulator + self._paths[name] = path + if accumulator: + if self._reload_called: + accumulator.Reload() + return self + + def AddRunsFromDirectory(self, path, name=None): + """Load runs from a directory; recursively walks subdirectories. + + If path doesn't exist, no-op. This ensures that it is safe to call + `AddRunsFromDirectory` multiple times, even before the directory is made. + + If path is a directory, load event files in the directory (if any exist) and + recursively call AddRunsFromDirectory on any subdirectories. This mean you + can call AddRunsFromDirectory at the root of a tree of event logs and + TensorBoard will load them all. + + If the `EventMultiplexer` is already loaded this will cause + the newly created accumulators to `Reload()`. + Args: + path: A string path to a directory to load runs from. + name: Optionally, what name to apply to the runs. If name is provided + and the directory contains run subdirectories, the name of each subrun + is the concatenation of the parent name and the subdirectory name. If + name is provided and the directory contains event files, then a run + is added called "name" and with the events from the path. 
+ + Raises: + ValueError: If the path exists and isn't a directory. + + Returns: + The `EventMultiplexer`. + """ + logger.info("Starting AddRunsFromDirectory: %s", path) + for subdir in io_wrapper.GetLogdirSubdirectories(path): + logger.info("Adding run from directory %s", subdir) + rpath = os.path.relpath(subdir, path) + subname = os.path.join(name, rpath) if name else rpath + self.AddRun(subdir, name=subname) + logger.info("Done with AddRunsFromDirectory: %s", path) + return self + + def Reload(self): + """Call `Reload` on every `EventAccumulator`.""" + logger.info("Beginning EventMultiplexer.Reload()") + self._reload_called = True + # Build a list so we're safe even if the list of accumulators is modified + # even while we're reloading. + with self._accumulators_mutex: + items = list(self._accumulators.items()) + items_queue = queue.Queue() + for item in items: + items_queue.put(item) + + # Methods of built-in python containers are thread-safe so long as the GIL + # for the thread exists, but we might as well be careful. + names_to_delete = set() + names_to_delete_mutex = threading.Lock() + + def Worker(): + """Keeps reloading accumulators til none are left.""" + while True: + try: + name, accumulator = items_queue.get(block=False) + except queue.Empty: + # No more runs to reload. 
+ break + + try: + accumulator.Reload() + except (OSError, IOError) as e: + logger.error("Unable to reload accumulator %r: %s", name, e) + except directory_watcher.DirectoryDeletedError: + with names_to_delete_mutex: + names_to_delete.add(name) + finally: + items_queue.task_done() + + if self._max_reload_threads > 1: + num_threads = min(self._max_reload_threads, len(items)) + logger.info("Starting %d threads to reload runs", num_threads) + for i in xrange(num_threads): + thread = threading.Thread(target=Worker, name="Reloader %d" % i) + thread.daemon = True + thread.start() + items_queue.join() + else: + logger.info( + "Reloading runs serially (one after another) on the main " + "thread." + ) + Worker() + + with self._accumulators_mutex: + for name in names_to_delete: + logger.warn("Deleting accumulator %r", name) + del self._accumulators[name] + logger.info("Finished with EventMultiplexer.Reload()") + return self + + def PluginAssets(self, plugin_name): + """Get index of runs and assets for a given plugin. + + Args: + plugin_name: Name of the plugin we are checking for. + + Returns: + A dictionary that maps from run_name to a list of plugin + assets for that run. + """ + with self._accumulators_mutex: + # To avoid nested locks, we construct a copy of the run-accumulator map + items = list(six.iteritems(self._accumulators)) + + return {run: accum.PluginAssets(plugin_name) for run, accum in items} + + def RetrievePluginAsset(self, run, plugin_name, asset_name): + """Return the contents for a specific plugin asset from a run. + + Args: + run: The string name of the run. + plugin_name: The string name of a plugin. + asset_name: The string name of an asset. + + Returns: + The string contents of the plugin asset. + + Raises: + KeyError: If the asset is not available. 
+ """ + accumulator = self.GetAccumulator(run) + return accumulator.RetrievePluginAsset(plugin_name, asset_name) + + def FirstEventTimestamp(self, run): + """Return the timestamp of the first event of the given run. + + This may perform I/O if no events have been loaded yet for the run. + + Args: + run: A string name of the run for which the timestamp is retrieved. + + Returns: + The wall_time of the first event of the run, which will typically be + seconds since the epoch. + + Raises: + KeyError: If the run is not found. + ValueError: If the run has no events loaded and there are no events on + disk to load. + """ + accumulator = self.GetAccumulator(run) + return accumulator.FirstEventTimestamp() + + def Graph(self, run): + """Retrieve the graph associated with the provided run. + + Args: + run: A string name of a run to load the graph for. + + Raises: + KeyError: If the run is not found. + ValueError: If the run does not have an associated graph. + + Returns: + The `GraphDef` protobuf data structure. + """ + accumulator = self.GetAccumulator(run) + return accumulator.Graph() + + def SerializedGraph(self, run): + """Retrieve the serialized graph associated with the provided run. + + Args: + run: A string name of a run to load the graph for. + + Raises: + KeyError: If the run is not found. + ValueError: If the run does not have an associated graph. + + Returns: + The serialized form of the `GraphDef` protobuf data structure. + """ + accumulator = self.GetAccumulator(run) + return accumulator.SerializedGraph() + + def MetaGraph(self, run): + """Retrieve the metagraph associated with the provided run. + + Args: + run: A string name of a run to load the graph for. + + Raises: + KeyError: If the run is not found. + ValueError: If the run does not have an associated graph. + + Returns: + The `MetaGraphDef` protobuf data structure. 
+ """ + accumulator = self.GetAccumulator(run) + return accumulator.MetaGraph() + + def RunMetadata(self, run, tag): + """Get the session.run() metadata associated with a TensorFlow run and + tag. + + Args: + run: A string name of a TensorFlow run. + tag: A string name of the tag associated with a particular session.run(). + + Raises: + KeyError: If the run is not found, or the tag is not available for the + given run. + + Returns: + The metadata in the form of `RunMetadata` protobuf data structure. + """ + accumulator = self.GetAccumulator(run) + return accumulator.RunMetadata(tag) + + def Tensors(self, run, tag): + """Retrieve the tensor events associated with a run and tag. + + Args: + run: A string name of the run for which values are retrieved. + tag: A string name of the tag for which values are retrieved. + + Raises: + KeyError: If the run is not found, or the tag is not available for + the given run. + + Returns: + An array of `event_accumulator.TensorEvent`s. + """ + accumulator = self.GetAccumulator(run) + return accumulator.Tensors(tag) + + def PluginRunToTagToContent(self, plugin_name): + """Returns a 2-layer dictionary of the form {run: {tag: content}}. + + The `content` referred above is the content field of the PluginData proto + for the specified plugin within a Summary.Value proto. + + Args: + plugin_name: The name of the plugin for which to fetch content. + + Returns: + A dictionary of the form {run: {tag: content}}. + """ + mapping = {} + for run in self.Runs(): + try: + tag_to_content = self.GetAccumulator(run).PluginTagToContent( + plugin_name + ) + except KeyError: + # This run lacks content for the plugin. Try the next run. + continue + mapping[run] = tag_to_content + return mapping + + def SummaryMetadata(self, run, tag): + """Return the summary metadata for the given tag on the given run. + + Args: + run: A string name of the run for which summary metadata is to be + retrieved. 
+ tag: A string name of the tag whose summary metadata is to be + retrieved. + + Raises: + KeyError: If the run is not found, or the tag is not available for + the given run. + + Returns: + A `SummaryMetadata` protobuf. + """ + accumulator = self.GetAccumulator(run) + return accumulator.SummaryMetadata(tag) + + def Runs(self): + """Return all the run names in the `EventMultiplexer`. + + Returns: + ``` + {runName: { scalarValues: [tagA, tagB, tagC], + graph: true, meta_graph: true}} + ``` + """ + with self._accumulators_mutex: + # To avoid nested locks, we construct a copy of the run-accumulator map + items = list(six.iteritems(self._accumulators)) + return {run_name: accumulator.Tags() for run_name, accumulator in items} + + def RunPaths(self): + """Returns a dict mapping run names to event file paths.""" + return self._paths + + def GetAccumulator(self, run): + """Returns EventAccumulator for a given run. + + Args: + run: String name of run. + + Returns: + An EventAccumulator object. + + Raises: + KeyError: If run does not exist. 
+ """ + with self._accumulators_mutex: + return self._accumulators[run] diff --git a/tensorboard/backend/event_processing/plugin_event_multiplexer_test.py b/tensorboard/backend/event_processing/plugin_event_multiplexer_test.py index 0614ed1895..62207e631d 100644 --- a/tensorboard/backend/event_processing/plugin_event_multiplexer_test.py +++ b/tensorboard/backend/event_processing/plugin_event_multiplexer_test.py @@ -23,402 +23,439 @@ import tensorflow as tf -from tensorboard.backend.event_processing import plugin_event_accumulator as event_accumulator -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_accumulator as event_accumulator, +) +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.util import test_util def _AddEvents(path): - if not tf.io.gfile.isdir(path): - tf.io.gfile.makedirs(path) - fpath = os.path.join(path, 'hypothetical.tfevents.out') - with tf.io.gfile.GFile(fpath, 'w') as f: - f.write('') - return fpath + if not tf.io.gfile.isdir(path): + tf.io.gfile.makedirs(path) + fpath = os.path.join(path, "hypothetical.tfevents.out") + with tf.io.gfile.GFile(fpath, "w") as f: + f.write("") + return fpath def _CreateCleanDirectory(path): - if tf.io.gfile.isdir(path): - tf.io.gfile.rmtree(path) - tf.io.gfile.mkdir(path) + if tf.io.gfile.isdir(path): + tf.io.gfile.rmtree(path) + tf.io.gfile.mkdir(path) class _FakeAccumulator(object): - - def __init__(self, path): - """Constructs a fake accumulator with some fake events. - - Args: - path: The path for the run that this accumulator is for. - """ - self._path = path - self.reload_called = False - self._plugin_to_tag_to_content = { - 'baz_plugin': { - 'foo': 'foo_content', - 'bar': 'bar_content', + def __init__(self, path): + """Constructs a fake accumulator with some fake events. 
+ + Args: + path: The path for the run that this accumulator is for. + """ + self._path = path + self.reload_called = False + self._plugin_to_tag_to_content = { + "baz_plugin": {"foo": "foo_content", "bar": "bar_content",} } - } - def Tags(self): - return {} + def Tags(self): + return {} - def FirstEventTimestamp(self): - return 0 + def FirstEventTimestamp(self): + return 0 - def _TagHelper(self, tag_name, enum): - if tag_name not in self.Tags()[enum]: - raise KeyError - return ['%s/%s' % (self._path, tag_name)] + def _TagHelper(self, tag_name, enum): + if tag_name not in self.Tags()[enum]: + raise KeyError + return ["%s/%s" % (self._path, tag_name)] - def Tensors(self, tag_name): - return self._TagHelper(tag_name, event_accumulator.TENSORS) + def Tensors(self, tag_name): + return self._TagHelper(tag_name, event_accumulator.TENSORS) - def PluginTagToContent(self, plugin_name): - # We pre-pend the runs with the path and '_' so that we can verify that the - # tags are associated with the correct runs. - return { - self._path + '_' + run: content_mapping - for (run, content_mapping - ) in self._plugin_to_tag_to_content[plugin_name].items() - } + def PluginTagToContent(self, plugin_name): + # We pre-pend the runs with the path and '_' so that we can verify that the + # tags are associated with the correct runs. + return { + self._path + "_" + run: content_mapping + for (run, content_mapping) in self._plugin_to_tag_to_content[ + plugin_name + ].items() + } - def Reload(self): - self.reload_called = True + def Reload(self): + self.reload_called = True -def _GetFakeAccumulator(path, - size_guidance=None, - tensor_size_guidance=None, - purge_orphaned_data=None, - event_file_active_filter=None): - del size_guidance, tensor_size_guidance, purge_orphaned_data # Unused. 
- del event_file_active_filter # unused - return _FakeAccumulator(path) +def _GetFakeAccumulator( + path, + size_guidance=None, + tensor_size_guidance=None, + purge_orphaned_data=None, + event_file_active_filter=None, +): + del size_guidance, tensor_size_guidance, purge_orphaned_data # Unused. + del event_file_active_filter # unused + return _FakeAccumulator(path) class EventMultiplexerTest(tf.test.TestCase): - - def setUp(self): - super(EventMultiplexerTest, self).setUp() - self.stubs = tf.compat.v1.test.StubOutForTesting() - - self.stubs.Set(event_accumulator, 'EventAccumulator', _GetFakeAccumulator) - - def tearDown(self): - self.stubs.CleanUp() - - def testEmptyLoader(self): - """Tests empty EventMultiplexer creation.""" - x = event_multiplexer.EventMultiplexer() - self.assertEqual(x.Runs(), {}) - - def testRunNamesRespected(self): - """Tests two EventAccumulators inserted/accessed in EventMultiplexer.""" - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - self.assertItemsEqual(sorted(x.Runs().keys()), ['run1', 'run2']) - self.assertEqual(x.GetAccumulator('run1')._path, 'path1') - self.assertEqual(x.GetAccumulator('run2')._path, 'path2') - - def testReload(self): - """EventAccumulators should Reload after EventMultiplexer call it.""" - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - self.assertFalse(x.GetAccumulator('run1').reload_called) - self.assertFalse(x.GetAccumulator('run2').reload_called) - x.Reload() - self.assertTrue(x.GetAccumulator('run1').reload_called) - self.assertTrue(x.GetAccumulator('run2').reload_called) - - def testPluginRunToTagToContent(self): - """Tests the method that produces the run to tag to content mapping.""" - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - self.assertDictEqual({ - 'run1': { - 'path1_foo': 'foo_content', - 'path1_bar': 'bar_content', - }, - 'run2': { - 'path2_foo': 'foo_content', - 'path2_bar': 'bar_content', - } - }, 
x.PluginRunToTagToContent('baz_plugin')) - - def testExceptions(self): - """KeyError should be raised when accessing non-existing keys.""" - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - with self.assertRaises(KeyError): - x.Tensors('sv1', 'xxx') - - def testInitialization(self): - """Tests EventMultiplexer is created properly with its params.""" - x = event_multiplexer.EventMultiplexer() - self.assertEqual(x.Runs(), {}) - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - self.assertItemsEqual(x.Runs(), ['run1', 'run2']) - self.assertEqual(x.GetAccumulator('run1')._path, 'path1') - self.assertEqual(x.GetAccumulator('run2')._path, 'path2') - - def testAddRunsFromDirectory(self): - """Tests AddRunsFromDirectory function. - - Tests the following scenarios: - - When the directory does not exist. - - When the directory is empty. - - When the directory has empty subdirectory. - - Contains proper EventAccumulators after adding events. - """ - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - join = os.path.join - fakedir = join(tmpdir, 'fake_accumulator_directory') - realdir = join(tmpdir, 'real_accumulator_directory') - self.assertEqual(x.Runs(), {}) - x.AddRunsFromDirectory(fakedir) - self.assertEqual(x.Runs(), {}, 'loading fakedir had no effect') - - _CreateCleanDirectory(realdir) - x.AddRunsFromDirectory(realdir) - self.assertEqual(x.Runs(), {}, 'loading empty directory had no effect') - - path1 = join(realdir, 'path1') - tf.io.gfile.mkdir(path1) - x.AddRunsFromDirectory(realdir) - self.assertEqual(x.Runs(), {}, 'creating empty subdirectory had no effect') - - _AddEvents(path1) - x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['path1'], 'loaded run: path1') - loader1 = x.GetAccumulator('path1') - self.assertEqual(loader1._path, path1, 'has the correct path') - - path2 = join(realdir, 'path2') - _AddEvents(path2) - x.AddRunsFromDirectory(realdir) - 
self.assertItemsEqual(x.Runs(), ['path1', 'path2']) - self.assertEqual( - x.GetAccumulator('path1'), loader1, 'loader1 not regenerated') - - path2_2 = join(path2, 'path2') - _AddEvents(path2_2) - x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['path1', 'path2', 'path2/path2']) - self.assertEqual( - x.GetAccumulator('path2/path2')._path, path2_2, 'loader2 path correct') - - def testAddRunsFromDirectoryThatContainsEvents(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - join = os.path.join - realdir = join(tmpdir, 'event_containing_directory') - - _CreateCleanDirectory(realdir) - - self.assertEqual(x.Runs(), {}) - - _AddEvents(realdir) - x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['.']) - - subdir = join(realdir, 'subdir') - _AddEvents(subdir) - x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['.', 'subdir']) - - def testAddRunsFromDirectoryWithRunNames(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - join = os.path.join - realdir = join(tmpdir, 'event_containing_directory') - - _CreateCleanDirectory(realdir) - - self.assertEqual(x.Runs(), {}) - - _AddEvents(realdir) - x.AddRunsFromDirectory(realdir, 'foo') - self.assertItemsEqual(x.Runs(), ['foo/.']) - - subdir = join(realdir, 'subdir') - _AddEvents(subdir) - x.AddRunsFromDirectory(realdir, 'foo') - self.assertItemsEqual(x.Runs(), ['foo/.', 'foo/subdir']) - - def testAddRunsFromDirectoryWalksTree(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - join = os.path.join - realdir = join(tmpdir, 'event_containing_directory') - - _CreateCleanDirectory(realdir) - _AddEvents(realdir) - sub = join(realdir, 'subdirectory') - sub1 = join(sub, '1') - sub2 = join(sub, '2') - sub1_1 = join(sub1, '1') - _AddEvents(sub1) - _AddEvents(sub2) - _AddEvents(sub1_1) - x.AddRunsFromDirectory(realdir) - - self.assertItemsEqual(x.Runs(), ['.', 'subdirectory/1', 'subdirectory/2', - 
'subdirectory/1/1']) - - def testAddRunsFromDirectoryThrowsException(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - - filepath = _AddEvents(tmpdir) - with self.assertRaises(ValueError): - x.AddRunsFromDirectory(filepath) - - def testAddRun(self): - x = event_multiplexer.EventMultiplexer() - x.AddRun('run1_path', 'run1') - run1 = x.GetAccumulator('run1') - self.assertEqual(sorted(x.Runs().keys()), ['run1']) - self.assertEqual(run1._path, 'run1_path') - - x.AddRun('run1_path', 'run1') - self.assertEqual(run1, x.GetAccumulator('run1'), 'loader not recreated') - - x.AddRun('run2_path', 'run1') - new_run1 = x.GetAccumulator('run1') - self.assertEqual(new_run1._path, 'run2_path') - self.assertNotEqual(run1, new_run1) - - x.AddRun('runName3') - self.assertItemsEqual(sorted(x.Runs().keys()), ['run1', 'runName3']) - self.assertEqual(x.GetAccumulator('runName3')._path, 'runName3') - - def testAddRunMaintainsLoading(self): - x = event_multiplexer.EventMultiplexer() - x.Reload() - x.AddRun('run1') - x.AddRun('run2') - self.assertTrue(x.GetAccumulator('run1').reload_called) - self.assertTrue(x.GetAccumulator('run2').reload_called) - - def testAddReloadWithMultipleThreads(self): - x = event_multiplexer.EventMultiplexer(max_reload_threads=2) - x.Reload() - x.AddRun('run1') - x.AddRun('run2') - x.AddRun('run3') - self.assertTrue(x.GetAccumulator('run1').reload_called) - self.assertTrue(x.GetAccumulator('run2').reload_called) - self.assertTrue(x.GetAccumulator('run3').reload_called) - - def testReloadWithMoreRunsThanThreads(self): - patcher = tf.compat.v1.test.mock.patch('threading.Thread.start', autospec=True) - start_mock = patcher.start() - self.addCleanup(patcher.stop) - patcher = tf.compat.v1.test.mock.patch( - 'six.moves.queue.Queue.join', autospec=True) - join_mock = patcher.start() - self.addCleanup(patcher.stop) - - x = event_multiplexer.EventMultiplexer(max_reload_threads=2) - x.AddRun('run1') - x.AddRun('run2') - x.AddRun('run3') - 
x.Reload() - - # 2 threads should have been started despite how there are 3 runs. - self.assertEqual(2, start_mock.call_count) - self.assertEqual(1, join_mock.call_count) - - def testReloadWithMoreThreadsThanRuns(self): - patcher = tf.compat.v1.test.mock.patch('threading.Thread.start', autospec=True) - start_mock = patcher.start() - self.addCleanup(patcher.stop) - patcher = tf.compat.v1.test.mock.patch( - 'six.moves.queue.Queue.join', autospec=True) - join_mock = patcher.start() - self.addCleanup(patcher.stop) - - x = event_multiplexer.EventMultiplexer(max_reload_threads=42) - x.AddRun('run1') - x.AddRun('run2') - x.AddRun('run3') - x.Reload() - - # 3 threads should have been started despite how the multiplexer - # could have started up to 42 threads. - self.assertEqual(3, start_mock.call_count) - self.assertEqual(1, join_mock.call_count) - - def testReloadWith1Thread(self): - patcher = tf.compat.v1.test.mock.patch('threading.Thread.start', autospec=True) - start_mock = patcher.start() - self.addCleanup(patcher.stop) - patcher = tf.compat.v1.test.mock.patch( - 'six.moves.queue.Queue.join', autospec=True) - join_mock = patcher.start() - self.addCleanup(patcher.stop) - - x = event_multiplexer.EventMultiplexer(max_reload_threads=1) - x.AddRun('run1') - x.AddRun('run2') - x.AddRun('run3') - x.Reload() - - # The multiplexer should have started no new threads. 
- self.assertEqual(0, start_mock.call_count) - self.assertEqual(0, join_mock.call_count) + def setUp(self): + super(EventMultiplexerTest, self).setUp() + self.stubs = tf.compat.v1.test.StubOutForTesting() + + self.stubs.Set( + event_accumulator, "EventAccumulator", _GetFakeAccumulator + ) + + def tearDown(self): + self.stubs.CleanUp() + + def testEmptyLoader(self): + """Tests empty EventMultiplexer creation.""" + x = event_multiplexer.EventMultiplexer() + self.assertEqual(x.Runs(), {}) + + def testRunNamesRespected(self): + """Tests two EventAccumulators inserted/accessed in + EventMultiplexer.""" + x = event_multiplexer.EventMultiplexer( + {"run1": "path1", "run2": "path2"} + ) + self.assertItemsEqual(sorted(x.Runs().keys()), ["run1", "run2"]) + self.assertEqual(x.GetAccumulator("run1")._path, "path1") + self.assertEqual(x.GetAccumulator("run2")._path, "path2") + + def testReload(self): + """EventAccumulators should Reload after EventMultiplexer call it.""" + x = event_multiplexer.EventMultiplexer( + {"run1": "path1", "run2": "path2"} + ) + self.assertFalse(x.GetAccumulator("run1").reload_called) + self.assertFalse(x.GetAccumulator("run2").reload_called) + x.Reload() + self.assertTrue(x.GetAccumulator("run1").reload_called) + self.assertTrue(x.GetAccumulator("run2").reload_called) + + def testPluginRunToTagToContent(self): + """Tests the method that produces the run to tag to content mapping.""" + x = event_multiplexer.EventMultiplexer( + {"run1": "path1", "run2": "path2"} + ) + self.assertDictEqual( + { + "run1": { + "path1_foo": "foo_content", + "path1_bar": "bar_content", + }, + "run2": { + "path2_foo": "foo_content", + "path2_bar": "bar_content", + }, + }, + x.PluginRunToTagToContent("baz_plugin"), + ) + + def testExceptions(self): + """KeyError should be raised when accessing non-existing keys.""" + x = event_multiplexer.EventMultiplexer( + {"run1": "path1", "run2": "path2"} + ) + with self.assertRaises(KeyError): + x.Tensors("sv1", "xxx") + + def 
testInitialization(self): + """Tests EventMultiplexer is created properly with its params.""" + x = event_multiplexer.EventMultiplexer() + self.assertEqual(x.Runs(), {}) + x = event_multiplexer.EventMultiplexer( + {"run1": "path1", "run2": "path2"} + ) + self.assertItemsEqual(x.Runs(), ["run1", "run2"]) + self.assertEqual(x.GetAccumulator("run1")._path, "path1") + self.assertEqual(x.GetAccumulator("run2")._path, "path2") + + def testAddRunsFromDirectory(self): + """Tests AddRunsFromDirectory function. + + Tests the following scenarios: + - When the directory does not exist. + - When the directory is empty. + - When the directory has empty subdirectory. + - Contains proper EventAccumulators after adding events. + """ + x = event_multiplexer.EventMultiplexer() + tmpdir = self.get_temp_dir() + join = os.path.join + fakedir = join(tmpdir, "fake_accumulator_directory") + realdir = join(tmpdir, "real_accumulator_directory") + self.assertEqual(x.Runs(), {}) + x.AddRunsFromDirectory(fakedir) + self.assertEqual(x.Runs(), {}, "loading fakedir had no effect") + + _CreateCleanDirectory(realdir) + x.AddRunsFromDirectory(realdir) + self.assertEqual(x.Runs(), {}, "loading empty directory had no effect") + + path1 = join(realdir, "path1") + tf.io.gfile.mkdir(path1) + x.AddRunsFromDirectory(realdir) + self.assertEqual( + x.Runs(), {}, "creating empty subdirectory had no effect" + ) + + _AddEvents(path1) + x.AddRunsFromDirectory(realdir) + self.assertItemsEqual(x.Runs(), ["path1"], "loaded run: path1") + loader1 = x.GetAccumulator("path1") + self.assertEqual(loader1._path, path1, "has the correct path") + + path2 = join(realdir, "path2") + _AddEvents(path2) + x.AddRunsFromDirectory(realdir) + self.assertItemsEqual(x.Runs(), ["path1", "path2"]) + self.assertEqual( + x.GetAccumulator("path1"), loader1, "loader1 not regenerated" + ) + + path2_2 = join(path2, "path2") + _AddEvents(path2_2) + x.AddRunsFromDirectory(realdir) + self.assertItemsEqual(x.Runs(), ["path1", "path2", 
"path2/path2"]) + self.assertEqual( + x.GetAccumulator("path2/path2")._path, + path2_2, + "loader2 path correct", + ) + + def testAddRunsFromDirectoryThatContainsEvents(self): + x = event_multiplexer.EventMultiplexer() + tmpdir = self.get_temp_dir() + join = os.path.join + realdir = join(tmpdir, "event_containing_directory") + + _CreateCleanDirectory(realdir) + + self.assertEqual(x.Runs(), {}) + + _AddEvents(realdir) + x.AddRunsFromDirectory(realdir) + self.assertItemsEqual(x.Runs(), ["."]) + + subdir = join(realdir, "subdir") + _AddEvents(subdir) + x.AddRunsFromDirectory(realdir) + self.assertItemsEqual(x.Runs(), [".", "subdir"]) + + def testAddRunsFromDirectoryWithRunNames(self): + x = event_multiplexer.EventMultiplexer() + tmpdir = self.get_temp_dir() + join = os.path.join + realdir = join(tmpdir, "event_containing_directory") + + _CreateCleanDirectory(realdir) + + self.assertEqual(x.Runs(), {}) + + _AddEvents(realdir) + x.AddRunsFromDirectory(realdir, "foo") + self.assertItemsEqual(x.Runs(), ["foo/."]) + + subdir = join(realdir, "subdir") + _AddEvents(subdir) + x.AddRunsFromDirectory(realdir, "foo") + self.assertItemsEqual(x.Runs(), ["foo/.", "foo/subdir"]) + + def testAddRunsFromDirectoryWalksTree(self): + x = event_multiplexer.EventMultiplexer() + tmpdir = self.get_temp_dir() + join = os.path.join + realdir = join(tmpdir, "event_containing_directory") + + _CreateCleanDirectory(realdir) + _AddEvents(realdir) + sub = join(realdir, "subdirectory") + sub1 = join(sub, "1") + sub2 = join(sub, "2") + sub1_1 = join(sub1, "1") + _AddEvents(sub1) + _AddEvents(sub2) + _AddEvents(sub1_1) + x.AddRunsFromDirectory(realdir) + + self.assertItemsEqual( + x.Runs(), + [".", "subdirectory/1", "subdirectory/2", "subdirectory/1/1"], + ) + + def testAddRunsFromDirectoryThrowsException(self): + x = event_multiplexer.EventMultiplexer() + tmpdir = self.get_temp_dir() + + filepath = _AddEvents(tmpdir) + with self.assertRaises(ValueError): + x.AddRunsFromDirectory(filepath) + + def 
testAddRun(self): + x = event_multiplexer.EventMultiplexer() + x.AddRun("run1_path", "run1") + run1 = x.GetAccumulator("run1") + self.assertEqual(sorted(x.Runs().keys()), ["run1"]) + self.assertEqual(run1._path, "run1_path") + + x.AddRun("run1_path", "run1") + self.assertEqual(run1, x.GetAccumulator("run1"), "loader not recreated") + + x.AddRun("run2_path", "run1") + new_run1 = x.GetAccumulator("run1") + self.assertEqual(new_run1._path, "run2_path") + self.assertNotEqual(run1, new_run1) + + x.AddRun("runName3") + self.assertItemsEqual(sorted(x.Runs().keys()), ["run1", "runName3"]) + self.assertEqual(x.GetAccumulator("runName3")._path, "runName3") + + def testAddRunMaintainsLoading(self): + x = event_multiplexer.EventMultiplexer() + x.Reload() + x.AddRun("run1") + x.AddRun("run2") + self.assertTrue(x.GetAccumulator("run1").reload_called) + self.assertTrue(x.GetAccumulator("run2").reload_called) + + def testAddReloadWithMultipleThreads(self): + x = event_multiplexer.EventMultiplexer(max_reload_threads=2) + x.Reload() + x.AddRun("run1") + x.AddRun("run2") + x.AddRun("run3") + self.assertTrue(x.GetAccumulator("run1").reload_called) + self.assertTrue(x.GetAccumulator("run2").reload_called) + self.assertTrue(x.GetAccumulator("run3").reload_called) + + def testReloadWithMoreRunsThanThreads(self): + patcher = tf.compat.v1.test.mock.patch( + "threading.Thread.start", autospec=True + ) + start_mock = patcher.start() + self.addCleanup(patcher.stop) + patcher = tf.compat.v1.test.mock.patch( + "six.moves.queue.Queue.join", autospec=True + ) + join_mock = patcher.start() + self.addCleanup(patcher.stop) + + x = event_multiplexer.EventMultiplexer(max_reload_threads=2) + x.AddRun("run1") + x.AddRun("run2") + x.AddRun("run3") + x.Reload() + + # 2 threads should have been started despite how there are 3 runs. 
+ self.assertEqual(2, start_mock.call_count) + self.assertEqual(1, join_mock.call_count) + + def testReloadWithMoreThreadsThanRuns(self): + patcher = tf.compat.v1.test.mock.patch( + "threading.Thread.start", autospec=True + ) + start_mock = patcher.start() + self.addCleanup(patcher.stop) + patcher = tf.compat.v1.test.mock.patch( + "six.moves.queue.Queue.join", autospec=True + ) + join_mock = patcher.start() + self.addCleanup(patcher.stop) + + x = event_multiplexer.EventMultiplexer(max_reload_threads=42) + x.AddRun("run1") + x.AddRun("run2") + x.AddRun("run3") + x.Reload() + + # 3 threads should have been started despite how the multiplexer + # could have started up to 42 threads. + self.assertEqual(3, start_mock.call_count) + self.assertEqual(1, join_mock.call_count) + + def testReloadWith1Thread(self): + patcher = tf.compat.v1.test.mock.patch( + "threading.Thread.start", autospec=True + ) + start_mock = patcher.start() + self.addCleanup(patcher.stop) + patcher = tf.compat.v1.test.mock.patch( + "six.moves.queue.Queue.join", autospec=True + ) + join_mock = patcher.start() + self.addCleanup(patcher.stop) + + x = event_multiplexer.EventMultiplexer(max_reload_threads=1) + x.AddRun("run1") + x.AddRun("run2") + x.AddRun("run3") + x.Reload() + + # The multiplexer should have started no new threads. + self.assertEqual(0, start_mock.call_count) + self.assertEqual(0, join_mock.call_count) class EventMultiplexerWithRealAccumulatorTest(tf.test.TestCase): - - def testMultifileReload(self): - multiplexer = event_multiplexer.EventMultiplexer( - event_file_active_filter=lambda timestamp: True) - logdir = self.get_temp_dir() - run_name = 'run1' - run_path = os.path.join(logdir, run_name) - # Create two separate event files, using filename suffix to ensure a - # deterministic sort order, and then simulate a write to file A, then - # to file B, then another write to file A (with reloads after each). 
- with test_util.FileWriter(run_path, filename_suffix='.a') as writer_a: - writer_a.add_test_summary('a1', step=1) - writer_a.flush() - multiplexer.AddRunsFromDirectory(logdir) - multiplexer.Reload() - with test_util.FileWriter(run_path, filename_suffix='.b') as writer_b: - writer_b.add_test_summary('b', step=1) - multiplexer.Reload() - writer_a.add_test_summary('a2', step=2) - writer_a.flush() - multiplexer.Reload() - # Both event files should be treated as active, so we should load the newly - # written data to the first file even though it's no longer the latest one. - self.assertEqual(1, len(multiplexer.Tensors(run_name, 'a1'))) - self.assertEqual(1, len(multiplexer.Tensors(run_name, 'b'))) - self.assertEqual(1, len(multiplexer.Tensors(run_name, 'a2'))) - - def testDeletingDirectoryRemovesRun(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - self._add3RunsToMultiplexer(tmpdir, x) - x.Reload() - - # Delete the directory, then reload. - shutil.rmtree(os.path.join(tmpdir, 'run2')) - x.Reload() - self.assertNotIn('run2', x.Runs().keys()) - - def _add3RunsToMultiplexer(self, logdir, multiplexer): - """Creates and adds 3 runs to the multiplexer.""" - run1_dir = os.path.join(logdir, 'run1') - run2_dir = os.path.join(logdir, 'run2') - run3_dir = os.path.join(logdir, 'run3') - - for dirname in [run1_dir, run2_dir, run3_dir]: - _AddEvents(dirname) - - multiplexer.AddRun(run1_dir, 'run1') - multiplexer.AddRun(run2_dir, 'run2') - multiplexer.AddRun(run3_dir, 'run3') - - -if __name__ == '__main__': - tf.test.main() + def testMultifileReload(self): + multiplexer = event_multiplexer.EventMultiplexer( + event_file_active_filter=lambda timestamp: True + ) + logdir = self.get_temp_dir() + run_name = "run1" + run_path = os.path.join(logdir, run_name) + # Create two separate event files, using filename suffix to ensure a + # deterministic sort order, and then simulate a write to file A, then + # to file B, then another write to file A (with reloads 
after each). + with test_util.FileWriter(run_path, filename_suffix=".a") as writer_a: + writer_a.add_test_summary("a1", step=1) + writer_a.flush() + multiplexer.AddRunsFromDirectory(logdir) + multiplexer.Reload() + with test_util.FileWriter( + run_path, filename_suffix=".b" + ) as writer_b: + writer_b.add_test_summary("b", step=1) + multiplexer.Reload() + writer_a.add_test_summary("a2", step=2) + writer_a.flush() + multiplexer.Reload() + # Both event files should be treated as active, so we should load the newly + # written data to the first file even though it's no longer the latest one. + self.assertEqual(1, len(multiplexer.Tensors(run_name, "a1"))) + self.assertEqual(1, len(multiplexer.Tensors(run_name, "b"))) + self.assertEqual(1, len(multiplexer.Tensors(run_name, "a2"))) + + def testDeletingDirectoryRemovesRun(self): + x = event_multiplexer.EventMultiplexer() + tmpdir = self.get_temp_dir() + self._add3RunsToMultiplexer(tmpdir, x) + x.Reload() + + # Delete the directory, then reload. + shutil.rmtree(os.path.join(tmpdir, "run2")) + x.Reload() + self.assertNotIn("run2", x.Runs().keys()) + + def _add3RunsToMultiplexer(self, logdir, multiplexer): + """Creates and adds 3 runs to the multiplexer.""" + run1_dir = os.path.join(logdir, "run1") + run2_dir = os.path.join(logdir, "run2") + run3_dir = os.path.join(logdir, "run3") + + for dirname in [run1_dir, run2_dir, run3_dir]: + _AddEvents(dirname) + + multiplexer.AddRun(run1_dir, "run1") + multiplexer.AddRun(run2_dir, "run2") + multiplexer.AddRun(run3_dir, "run3") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/backend/event_processing/reservoir.py b/tensorboard/backend/event_processing/reservoir.py index f0ec95689f..74648e2021 100644 --- a/tensorboard/backend/event_processing/reservoir.py +++ b/tensorboard/backend/event_processing/reservoir.py @@ -25,7 +25,7 @@ class Reservoir(object): - """A map-to-arrays container, with deterministic Reservoir Sampling. 
+ """A map-to-arrays container, with deterministic Reservoir Sampling. Items are added with an associated key. Items may be retrieved by key, and a list of keys can also be retrieved. If size is not zero, then it dictates @@ -57,203 +57,214 @@ class Reservoir(object): always_keep_last: Whether the latest seen sample is always at the end of the reservoir. Defaults to True. size: An integer of the maximum number of samples. - """ - - def __init__(self, size, seed=0, always_keep_last=True): - """Creates a new reservoir. - - Args: - size: The number of values to keep in the reservoir for each tag. If 0, - all values will be kept. - seed: The seed of the random number generator to use when sampling. - Different values for |seed| will produce different samples from the same - input items. - always_keep_last: Whether to always keep the latest seen item in the - end of the reservoir. Defaults to True. - - Raises: - ValueError: If size is negative or not an integer. """ - if size < 0 or size != round(size): - raise ValueError('size must be nonnegative integer, was %s' % size) - self._buckets = collections.defaultdict( - lambda: _ReservoirBucket(size, random.Random(seed), always_keep_last)) - # _mutex guards the keys - creating new keys, retrieving by key, etc - # the internal items are guarded by the ReservoirBuckets' internal mutexes - self._mutex = threading.Lock() - self.size = size - self.always_keep_last = always_keep_last - - def Keys(self): - """Return all the keys in the reservoir. - - Returns: - ['list', 'of', 'keys'] in the Reservoir. - """ - with self._mutex: - return list(self._buckets.keys()) - - def Items(self, key): - """Return items associated with given key. - - Args: - key: The key for which we are finding associated items. - - Raises: - KeyError: If the key is not found in the reservoir. - - Returns: - [list, of, items] associated with that key. 
- """ - with self._mutex: - if key not in self._buckets: - raise KeyError('Key %s was not found in Reservoir' % key) - bucket = self._buckets[key] - return bucket.Items() - def AddItem(self, key, item, f=lambda x: x): - """Add a new item to the Reservoir with the given tag. + def __init__(self, size, seed=0, always_keep_last=True): + """Creates a new reservoir. + + Args: + size: The number of values to keep in the reservoir for each tag. If 0, + all values will be kept. + seed: The seed of the random number generator to use when sampling. + Different values for |seed| will produce different samples from the same + input items. + always_keep_last: Whether to always keep the latest seen item in the + end of the reservoir. Defaults to True. + + Raises: + ValueError: If size is negative or not an integer. + """ + if size < 0 or size != round(size): + raise ValueError("size must be nonnegative integer, was %s" % size) + self._buckets = collections.defaultdict( + lambda: _ReservoirBucket( + size, random.Random(seed), always_keep_last + ) + ) + # _mutex guards the keys - creating new keys, retrieving by key, etc + # the internal items are guarded by the ReservoirBuckets' internal mutexes + self._mutex = threading.Lock() + self.size = size + self.always_keep_last = always_keep_last + + def Keys(self): + """Return all the keys in the reservoir. + + Returns: + ['list', 'of', 'keys'] in the Reservoir. + """ + with self._mutex: + return list(self._buckets.keys()) + + def Items(self, key): + """Return items associated with given key. + + Args: + key: The key for which we are finding associated items. + + Raises: + KeyError: If the key is not found in the reservoir. + + Returns: + [list, of, items] associated with that key. 
+ """ + with self._mutex: + if key not in self._buckets: + raise KeyError("Key %s was not found in Reservoir" % key) + bucket = self._buckets[key] + return bucket.Items() + + def AddItem(self, key, item, f=lambda x: x): + """Add a new item to the Reservoir with the given tag. + + If the reservoir has not yet reached full size, the new item is guaranteed + to be added. If the reservoir is full, then behavior depends on the + always_keep_last boolean. + + If always_keep_last was set to true, the new item is guaranteed to be added + to the reservoir, and either the previous last item will be replaced, or + (with low probability) an older item will be replaced. + + If always_keep_last was set to false, then the new item will replace an + old item with low probability. + + If f is provided, it will be applied to transform item (lazily, iff item is + going to be included in the reservoir). + + Args: + key: The key to store the item under. + item: The item to add to the reservoir. + f: An optional function to transform the item prior to addition. + """ + with self._mutex: + bucket = self._buckets[key] + bucket.AddItem(item, f) + + def FilterItems(self, filterFn, key=None): + """Filter items within a Reservoir, using a filtering function. + + Args: + filterFn: A function that returns True for the items to be kept. + key: An optional bucket key to filter. If not specified, will filter all + all buckets. + + Returns: + The number of items removed. + """ + with self._mutex: + if key: + if key in self._buckets: + return self._buckets[key].FilterItems(filterFn) + else: + return 0 + else: + return sum( + bucket.FilterItems(filterFn) + for bucket in self._buckets.values() + ) - If the reservoir has not yet reached full size, the new item is guaranteed - to be added. If the reservoir is full, then behavior depends on the - always_keep_last boolean. 
- If always_keep_last was set to true, the new item is guaranteed to be added - to the reservoir, and either the previous last item will be replaced, or - (with low probability) an older item will be replaced. - - If always_keep_last was set to false, then the new item will replace an - old item with low probability. - - If f is provided, it will be applied to transform item (lazily, iff item is - going to be included in the reservoir). +class _ReservoirBucket(object): + """A container for items from a stream, that implements reservoir sampling. - Args: - key: The key to store the item under. - item: The item to add to the reservoir. - f: An optional function to transform the item prior to addition. + It always stores the most recent item as its final item. """ - with self._mutex: - bucket = self._buckets[key] - bucket.AddItem(item, f) - - def FilterItems(self, filterFn, key=None): - """Filter items within a Reservoir, using a filtering function. - - Args: - filterFn: A function that returns True for the items to be kept. - key: An optional bucket key to filter. If not specified, will filter all - all buckets. - Returns: - The number of items removed. - """ - with self._mutex: - if key: - if key in self._buckets: - return self._buckets[key].FilterItems(filterFn) + def __init__(self, _max_size, _random=None, always_keep_last=True): + """Create the _ReservoirBucket. + + Args: + _max_size: The maximum size the reservoir bucket may grow to. If size is + zero, the bucket has unbounded size. + _random: The random number generator to use. If not specified, defaults to + random.Random(0). + always_keep_last: Whether the latest seen item should always be included + in the end of the bucket. + + Raises: + ValueError: if the size is not a nonnegative integer. 
+ """ + if _max_size < 0 or _max_size != round(_max_size): + raise ValueError( + "_max_size must be nonnegative int, was %s" % _max_size + ) + self.items = [] + # This mutex protects the internal items, ensuring that calls to Items and + # AddItem are thread-safe + self._mutex = threading.Lock() + self._max_size = _max_size + self._num_items_seen = 0 + if _random is not None: + self._random = _random else: - return 0 - else: - return sum(bucket.FilterItems(filterFn) - for bucket in self._buckets.values()) - - -class _ReservoirBucket(object): - """A container for items from a stream, that implements reservoir sampling. - - It always stores the most recent item as its final item. - """ - - def __init__(self, _max_size, _random=None, always_keep_last=True): - """Create the _ReservoirBucket. - - Args: - _max_size: The maximum size the reservoir bucket may grow to. If size is - zero, the bucket has unbounded size. - _random: The random number generator to use. If not specified, defaults to - random.Random(0). - always_keep_last: Whether the latest seen item should always be included - in the end of the bucket. - - Raises: - ValueError: if the size is not a nonnegative integer. - """ - if _max_size < 0 or _max_size != round(_max_size): - raise ValueError('_max_size must be nonnegative int, was %s' % _max_size) - self.items = [] - # This mutex protects the internal items, ensuring that calls to Items and - # AddItem are thread-safe - self._mutex = threading.Lock() - self._max_size = _max_size - self._num_items_seen = 0 - if _random is not None: - self._random = _random - else: - self._random = random.Random(0) - self.always_keep_last = always_keep_last - - def AddItem(self, item, f=lambda x: x): - """Add an item to the ReservoirBucket, replacing an old item if necessary. - - The new item is guaranteed to be added to the bucket, and to be the last - element in the bucket. If the bucket has reached capacity, then an old item - will be replaced. 
With probability (_max_size/_num_items_seen) a random item - in the bucket will be popped out and the new item will be appended - to the end. With probability (1 - _max_size/_num_items_seen) - the last item in the bucket will be replaced. - - Since the O(n) replacements occur with O(1/_num_items_seen) likelihood, - the amortized runtime is O(1). - - Args: - item: The item to add to the bucket. - f: A function to transform item before addition, if it will be kept in - the reservoir. - """ - with self._mutex: - if len(self.items) < self._max_size or self._max_size == 0: - self.items.append(f(item)) - else: - r = self._random.randint(0, self._num_items_seen) - if r < self._max_size: - self.items.pop(r) - self.items.append(f(item)) - elif self.always_keep_last: - self.items[-1] = f(item) - self._num_items_seen += 1 - - def FilterItems(self, filterFn): - """Filter items in a ReservoirBucket, using a filtering function. - - Filtering items from the reservoir bucket must update the - internal state variable self._num_items_seen, which is used for determining - the rate of replacement in reservoir sampling. Ideally, self._num_items_seen - would contain the exact number of items that have ever seen by the - ReservoirBucket and satisfy filterFn. However, the ReservoirBucket does not - have access to all items seen -- it only has access to the subset of items - that have survived sampling (self.items). Therefore, we estimate - self._num_items_seen by scaling it by the same ratio as the ratio of items - not removed from self.items. - - Args: - filterFn: A function that returns True for items to be kept. - - Returns: - The number of items removed from the bucket. 
- """ - with self._mutex: - size_before = len(self.items) - self.items = list(filter(filterFn, self.items)) - size_diff = size_before - len(self.items) - - # Estimate a correction the number of items seen - prop_remaining = len(self.items) / float( - size_before) if size_before > 0 else 0 - self._num_items_seen = int(round(self._num_items_seen * prop_remaining)) - return size_diff - - def Items(self): - """Get all the items in the bucket.""" - with self._mutex: - return list(self.items) + self._random = random.Random(0) + self.always_keep_last = always_keep_last + + def AddItem(self, item, f=lambda x: x): + """Add an item to the ReservoirBucket, replacing an old item if + necessary. + + The new item is guaranteed to be added to the bucket, and to be the last + element in the bucket. If the bucket has reached capacity, then an old item + will be replaced. With probability (_max_size/_num_items_seen) a random item + in the bucket will be popped out and the new item will be appended + to the end. With probability (1 - _max_size/_num_items_seen) + the last item in the bucket will be replaced. + + Since the O(n) replacements occur with O(1/_num_items_seen) likelihood, + the amortized runtime is O(1). + + Args: + item: The item to add to the bucket. + f: A function to transform item before addition, if it will be kept in + the reservoir. + """ + with self._mutex: + if len(self.items) < self._max_size or self._max_size == 0: + self.items.append(f(item)) + else: + r = self._random.randint(0, self._num_items_seen) + if r < self._max_size: + self.items.pop(r) + self.items.append(f(item)) + elif self.always_keep_last: + self.items[-1] = f(item) + self._num_items_seen += 1 + + def FilterItems(self, filterFn): + """Filter items in a ReservoirBucket, using a filtering function. + + Filtering items from the reservoir bucket must update the + internal state variable self._num_items_seen, which is used for determining + the rate of replacement in reservoir sampling. 
Ideally, self._num_items_seen + would contain the exact number of items that have ever seen by the + ReservoirBucket and satisfy filterFn. However, the ReservoirBucket does not + have access to all items seen -- it only has access to the subset of items + that have survived sampling (self.items). Therefore, we estimate + self._num_items_seen by scaling it by the same ratio as the ratio of items + not removed from self.items. + + Args: + filterFn: A function that returns True for items to be kept. + + Returns: + The number of items removed from the bucket. + """ + with self._mutex: + size_before = len(self.items) + self.items = list(filter(filterFn, self.items)) + size_diff = size_before - len(self.items) + + # Estimate a correction the number of items seen + prop_remaining = ( + len(self.items) / float(size_before) if size_before > 0 else 0 + ) + self._num_items_seen = int( + round(self._num_items_seen * prop_remaining) + ) + return size_diff + + def Items(self): + """Get all the items in the bucket.""" + with self._mutex: + return list(self.items) diff --git a/tensorboard/backend/event_processing/reservoir_test.py b/tensorboard/backend/event_processing/reservoir_test.py index 50b830903e..4096918482 100644 --- a/tensorboard/backend/event_processing/reservoir_test.py +++ b/tensorboard/backend/event_processing/reservoir_test.py @@ -24,256 +24,264 @@ class ReservoirTest(tf.test.TestCase): - - def testEmptyReservoir(self): - r = reservoir.Reservoir(1) - self.assertFalse(r.Keys()) - - def testRespectsSize(self): - r = reservoir.Reservoir(42) - self.assertEqual(r._buckets['meaning of life']._max_size, 42) - - def testItemsAndKeys(self): - r = reservoir.Reservoir(42) - r.AddItem('foo', 4) - r.AddItem('bar', 9) - r.AddItem('foo', 19) - self.assertItemsEqual(r.Keys(), ['foo', 'bar']) - self.assertEqual(r.Items('foo'), [4, 19]) - self.assertEqual(r.Items('bar'), [9]) - - def testExceptions(self): - with self.assertRaises(ValueError): - reservoir.Reservoir(-1) - with 
self.assertRaises(ValueError): - reservoir.Reservoir(13.3) - - r = reservoir.Reservoir(12) - with self.assertRaises(KeyError): - r.Items('missing key') - - def testDeterminism(self): - """Tests that the reservoir is deterministic.""" - key = 'key' - r1 = reservoir.Reservoir(10) - r2 = reservoir.Reservoir(10) - for i in xrange(100): - r1.AddItem('key', i) - r2.AddItem('key', i) - - self.assertEqual(r1.Items(key), r2.Items(key)) - - def testBucketDeterminism(self): - """Tests that reservoirs are deterministic at a bucket level. - - This means that only the order elements are added within a bucket matters. - """ - separate_reservoir = reservoir.Reservoir(10) - interleaved_reservoir = reservoir.Reservoir(10) - for i in xrange(100): - separate_reservoir.AddItem('key1', i) - for i in xrange(100): - separate_reservoir.AddItem('key2', i) - for i in xrange(100): - interleaved_reservoir.AddItem('key1', i) - interleaved_reservoir.AddItem('key2', i) - - for key in ['key1', 'key2']: - self.assertEqual( - separate_reservoir.Items(key), interleaved_reservoir.Items(key)) - - def testUsesSeed(self): - """Tests that reservoirs with different seeds keep different samples.""" - key = 'key' - r1 = reservoir.Reservoir(10, seed=0) - r2 = reservoir.Reservoir(10, seed=1) - for i in xrange(100): - r1.AddItem('key', i) - r2.AddItem('key', i) - self.assertNotEqual(r1.Items(key), r2.Items(key)) - - def testFilterItemsByKey(self): - r = reservoir.Reservoir(100, seed=0) - for i in xrange(10): - r.AddItem('key1', i) - r.AddItem('key2', i) - - self.assertEqual(len(r.Items('key1')), 10) - self.assertEqual(len(r.Items('key2')), 10) - - self.assertEqual(r.FilterItems(lambda x: x <= 7, 'key2'), 2) - self.assertEqual(len(r.Items('key2')), 8) - self.assertEqual(len(r.Items('key1')), 10) - - self.assertEqual(r.FilterItems(lambda x: x <= 3, 'key1'), 6) - self.assertEqual(len(r.Items('key1')), 4) - self.assertEqual(len(r.Items('key2')), 8) + def testEmptyReservoir(self): + r = reservoir.Reservoir(1) + 
self.assertFalse(r.Keys()) + + def testRespectsSize(self): + r = reservoir.Reservoir(42) + self.assertEqual(r._buckets["meaning of life"]._max_size, 42) + + def testItemsAndKeys(self): + r = reservoir.Reservoir(42) + r.AddItem("foo", 4) + r.AddItem("bar", 9) + r.AddItem("foo", 19) + self.assertItemsEqual(r.Keys(), ["foo", "bar"]) + self.assertEqual(r.Items("foo"), [4, 19]) + self.assertEqual(r.Items("bar"), [9]) + + def testExceptions(self): + with self.assertRaises(ValueError): + reservoir.Reservoir(-1) + with self.assertRaises(ValueError): + reservoir.Reservoir(13.3) + + r = reservoir.Reservoir(12) + with self.assertRaises(KeyError): + r.Items("missing key") + + def testDeterminism(self): + """Tests that the reservoir is deterministic.""" + key = "key" + r1 = reservoir.Reservoir(10) + r2 = reservoir.Reservoir(10) + for i in xrange(100): + r1.AddItem("key", i) + r2.AddItem("key", i) + + self.assertEqual(r1.Items(key), r2.Items(key)) + + def testBucketDeterminism(self): + """Tests that reservoirs are deterministic at a bucket level. + + This means that only the order elements are added within a + bucket matters. 
+ """ + separate_reservoir = reservoir.Reservoir(10) + interleaved_reservoir = reservoir.Reservoir(10) + for i in xrange(100): + separate_reservoir.AddItem("key1", i) + for i in xrange(100): + separate_reservoir.AddItem("key2", i) + for i in xrange(100): + interleaved_reservoir.AddItem("key1", i) + interleaved_reservoir.AddItem("key2", i) + + for key in ["key1", "key2"]: + self.assertEqual( + separate_reservoir.Items(key), interleaved_reservoir.Items(key) + ) + + def testUsesSeed(self): + """Tests that reservoirs with different seeds keep different + samples.""" + key = "key" + r1 = reservoir.Reservoir(10, seed=0) + r2 = reservoir.Reservoir(10, seed=1) + for i in xrange(100): + r1.AddItem("key", i) + r2.AddItem("key", i) + self.assertNotEqual(r1.Items(key), r2.Items(key)) + + def testFilterItemsByKey(self): + r = reservoir.Reservoir(100, seed=0) + for i in xrange(10): + r.AddItem("key1", i) + r.AddItem("key2", i) + + self.assertEqual(len(r.Items("key1")), 10) + self.assertEqual(len(r.Items("key2")), 10) + + self.assertEqual(r.FilterItems(lambda x: x <= 7, "key2"), 2) + self.assertEqual(len(r.Items("key2")), 8) + self.assertEqual(len(r.Items("key1")), 10) + + self.assertEqual(r.FilterItems(lambda x: x <= 3, "key1"), 6) + self.assertEqual(len(r.Items("key1")), 4) + self.assertEqual(len(r.Items("key2")), 8) class ReservoirBucketTest(tf.test.TestCase): - - def testEmptyBucket(self): - b = reservoir._ReservoirBucket(1) - self.assertFalse(b.Items()) - - def testFillToSize(self): - b = reservoir._ReservoirBucket(100) - for i in xrange(100): - b.AddItem(i) - self.assertEqual(b.Items(), list(xrange(100))) - self.assertEqual(b._num_items_seen, 100) - - def testDoesntOverfill(self): - b = reservoir._ReservoirBucket(10) - for i in xrange(1000): - b.AddItem(i) - self.assertEqual(len(b.Items()), 10) - self.assertEqual(b._num_items_seen, 1000) - - def testMaintainsOrder(self): - b = reservoir._ReservoirBucket(100) - for i in xrange(10000): - b.AddItem(i) - items = b.Items() - 
prev = -1 - for item in items: - self.assertTrue(item > prev) - prev = item - - def testKeepsLatestItem(self): - b = reservoir._ReservoirBucket(5) - for i in xrange(100): - b.AddItem(i) - last = b.Items()[-1] - self.assertEqual(last, i) - - def testSizeOneBucket(self): - b = reservoir._ReservoirBucket(1) - for i in xrange(20): - b.AddItem(i) - self.assertEqual(b.Items(), [i]) - self.assertEqual(b._num_items_seen, 20) - - def testSizeZeroBucket(self): - b = reservoir._ReservoirBucket(0) - for i in xrange(20): - b.AddItem(i) - self.assertEqual(b.Items(), list(range(i + 1))) - self.assertEqual(b._num_items_seen, 20) - - def testSizeRequirement(self): - with self.assertRaises(ValueError): - reservoir._ReservoirBucket(-1) - with self.assertRaises(ValueError): - reservoir._ReservoirBucket(10.3) - - def testRemovesItems(self): - b = reservoir._ReservoirBucket(100) - for i in xrange(10): - b.AddItem(i) - self.assertEqual(len(b.Items()), 10) - self.assertEqual(b._num_items_seen, 10) - self.assertEqual(b.FilterItems(lambda x: x <= 7), 2) - self.assertEqual(len(b.Items()), 8) - self.assertEqual(b._num_items_seen, 8) - - def testRemovesItemsWhenItemsAreReplaced(self): - b = reservoir._ReservoirBucket(100) - for i in xrange(10000): - b.AddItem(i) - self.assertEqual(b._num_items_seen, 10000) - - # Remove items - num_removed = b.FilterItems(lambda x: x <= 7) - self.assertGreater(num_removed, 92) - self.assertEqual([], [item for item in b.Items() if item > 7]) - self.assertEqual(b._num_items_seen, - int(round(10000 * (1 - float(num_removed) / 100)))) - - def testLazyFunctionEvaluationAndAlwaysKeepLast(self): - - class FakeRandom(object): - - def randint(self, a, b): # pylint:disable=unused-argument - return 999 - - class Incrementer(object): - - def __init__(self): - self.n = 0 - - def increment_and_double(self, x): - self.n += 1 - return x * 2 - - # We've mocked the randomness generator, so that once it is full, the last - # item will never get durable reservoir inclusion. 
Since always_keep_last is - # false, the function should only get invoked 100 times while filling up - # the reservoir. This laziness property is an essential performance - # optimization. - b = reservoir._ReservoirBucket(100, FakeRandom(), always_keep_last=False) - incrementer = Incrementer() - for i in xrange(1000): - b.AddItem(i, incrementer.increment_and_double) - self.assertEqual(incrementer.n, 100) - self.assertEqual(b.Items(), [x * 2 for x in xrange(100)]) - - # This time, we will always keep the last item, meaning that the function - # should get invoked once for every item we add. - b = reservoir._ReservoirBucket(100, FakeRandom(), always_keep_last=True) - incrementer = Incrementer() - - for i in xrange(1000): - b.AddItem(i, incrementer.increment_and_double) - self.assertEqual(incrementer.n, 1000) - self.assertEqual(b.Items(), [x * 2 for x in xrange(99)] + [999 * 2]) + def testEmptyBucket(self): + b = reservoir._ReservoirBucket(1) + self.assertFalse(b.Items()) + + def testFillToSize(self): + b = reservoir._ReservoirBucket(100) + for i in xrange(100): + b.AddItem(i) + self.assertEqual(b.Items(), list(xrange(100))) + self.assertEqual(b._num_items_seen, 100) + + def testDoesntOverfill(self): + b = reservoir._ReservoirBucket(10) + for i in xrange(1000): + b.AddItem(i) + self.assertEqual(len(b.Items()), 10) + self.assertEqual(b._num_items_seen, 1000) + + def testMaintainsOrder(self): + b = reservoir._ReservoirBucket(100) + for i in xrange(10000): + b.AddItem(i) + items = b.Items() + prev = -1 + for item in items: + self.assertTrue(item > prev) + prev = item + + def testKeepsLatestItem(self): + b = reservoir._ReservoirBucket(5) + for i in xrange(100): + b.AddItem(i) + last = b.Items()[-1] + self.assertEqual(last, i) + + def testSizeOneBucket(self): + b = reservoir._ReservoirBucket(1) + for i in xrange(20): + b.AddItem(i) + self.assertEqual(b.Items(), [i]) + self.assertEqual(b._num_items_seen, 20) + + def testSizeZeroBucket(self): + b = 
reservoir._ReservoirBucket(0) + for i in xrange(20): + b.AddItem(i) + self.assertEqual(b.Items(), list(range(i + 1))) + self.assertEqual(b._num_items_seen, 20) + + def testSizeRequirement(self): + with self.assertRaises(ValueError): + reservoir._ReservoirBucket(-1) + with self.assertRaises(ValueError): + reservoir._ReservoirBucket(10.3) + + def testRemovesItems(self): + b = reservoir._ReservoirBucket(100) + for i in xrange(10): + b.AddItem(i) + self.assertEqual(len(b.Items()), 10) + self.assertEqual(b._num_items_seen, 10) + self.assertEqual(b.FilterItems(lambda x: x <= 7), 2) + self.assertEqual(len(b.Items()), 8) + self.assertEqual(b._num_items_seen, 8) + + def testRemovesItemsWhenItemsAreReplaced(self): + b = reservoir._ReservoirBucket(100) + for i in xrange(10000): + b.AddItem(i) + self.assertEqual(b._num_items_seen, 10000) + + # Remove items + num_removed = b.FilterItems(lambda x: x <= 7) + self.assertGreater(num_removed, 92) + self.assertEqual([], [item for item in b.Items() if item > 7]) + self.assertEqual( + b._num_items_seen, + int(round(10000 * (1 - float(num_removed) / 100))), + ) + + def testLazyFunctionEvaluationAndAlwaysKeepLast(self): + class FakeRandom(object): + def randint(self, a, b): # pylint:disable=unused-argument + return 999 + + class Incrementer(object): + def __init__(self): + self.n = 0 + + def increment_and_double(self, x): + self.n += 1 + return x * 2 + + # We've mocked the randomness generator, so that once it is full, the last + # item will never get durable reservoir inclusion. Since always_keep_last is + # false, the function should only get invoked 100 times while filling up + # the reservoir. This laziness property is an essential performance + # optimization. 
+ b = reservoir._ReservoirBucket( + 100, FakeRandom(), always_keep_last=False + ) + incrementer = Incrementer() + for i in xrange(1000): + b.AddItem(i, incrementer.increment_and_double) + self.assertEqual(incrementer.n, 100) + self.assertEqual(b.Items(), [x * 2 for x in xrange(100)]) + + # This time, we will always keep the last item, meaning that the function + # should get invoked once for every item we add. + b = reservoir._ReservoirBucket(100, FakeRandom(), always_keep_last=True) + incrementer = Incrementer() + + for i in xrange(1000): + b.AddItem(i, incrementer.increment_and_double) + self.assertEqual(incrementer.n, 1000) + self.assertEqual(b.Items(), [x * 2 for x in xrange(99)] + [999 * 2]) class ReservoirBucketStatisticalDistributionTest(tf.test.TestCase): - - def setUp(self): - self.total = 1000000 - self.samples = 10000 - self.n_buckets = 100 - self.total_per_bucket = self.total // self.n_buckets - self.assertEqual(self.total % self.n_buckets, 0, 'total must be evenly ' - 'divisible by the number of buckets') - self.assertTrue(self.total > self.samples, 'need to have more items ' - 'than samples') - - def AssertBinomialQuantity(self, measured): - p = 1.0 * self.n_buckets / self.samples - mean = p * self.samples - variance = p * (1 - p) * self.samples - error = measured - mean - # Given that the buckets were actually binomially distributed, this - # fails with probability ~2E-9 - passed = error * error <= 36.0 * variance - self.assertTrue(passed, 'found a bucket with measured %d ' - 'too far from expected %d' % (measured, mean)) - - def testBucketReservoirSamplingViaStatisticalProperties(self): - # Not related to a 'ReservoirBucket', but instead number of buckets we put - # samples into for testing the shape of the distribution - b = reservoir._ReservoirBucket(_max_size=self.samples) - # add one extra item because we always keep the most recent item, which - # would skew the distribution; we can just slice it off the end instead. 
- for i in xrange(self.total + 1): - b.AddItem(i) - - divbins = [0] * self.n_buckets - modbins = [0] * self.n_buckets - # Slice off the last item when we iterate. - for item in b.Items()[0:-1]: - divbins[item // self.total_per_bucket] += 1 - modbins[item % self.n_buckets] += 1 - - for bucket_index in xrange(self.n_buckets): - divbin = divbins[bucket_index] - modbin = modbins[bucket_index] - self.AssertBinomialQuantity(divbin) - self.AssertBinomialQuantity(modbin) - - -if __name__ == '__main__': - tf.test.main() + def setUp(self): + self.total = 1000000 + self.samples = 10000 + self.n_buckets = 100 + self.total_per_bucket = self.total // self.n_buckets + self.assertEqual( + self.total % self.n_buckets, + 0, + "total must be evenly " "divisible by the number of buckets", + ) + self.assertTrue( + self.total > self.samples, "need to have more items " "than samples" + ) + + def AssertBinomialQuantity(self, measured): + p = 1.0 * self.n_buckets / self.samples + mean = p * self.samples + variance = p * (1 - p) * self.samples + error = measured - mean + # Given that the buckets were actually binomially distributed, this + # fails with probability ~2E-9 + passed = error * error <= 36.0 * variance + self.assertTrue( + passed, + "found a bucket with measured %d " + "too far from expected %d" % (measured, mean), + ) + + def testBucketReservoirSamplingViaStatisticalProperties(self): + # Not related to a 'ReservoirBucket', but instead number of buckets we put + # samples into for testing the shape of the distribution + b = reservoir._ReservoirBucket(_max_size=self.samples) + # add one extra item because we always keep the most recent item, which + # would skew the distribution; we can just slice it off the end instead. + for i in xrange(self.total + 1): + b.AddItem(i) + + divbins = [0] * self.n_buckets + modbins = [0] * self.n_buckets + # Slice off the last item when we iterate. 
+ for item in b.Items()[0:-1]: + divbins[item // self.total_per_bucket] += 1 + modbins[item % self.n_buckets] += 1 + + for bucket_index in xrange(self.n_buckets): + divbin = divbins[bucket_index] + modbin = modbins[bucket_index] + self.AssertBinomialQuantity(divbin) + self.AssertBinomialQuantity(modbin) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/backend/event_processing/sqlite_writer.py b/tensorboard/backend/event_processing/sqlite_writer.py index bed30f570a..7ac2f05c2f 100644 --- a/tensorboard/backend/event_processing/sqlite_writer.py +++ b/tensorboard/backend/event_processing/sqlite_writer.py @@ -34,184 +34,215 @@ # Struct bundling a tag with its SummaryMetadata and a list of values, each of # which are a tuple of step, wall time (as a float), and a TensorProto. -TagData = collections.namedtuple('TagData', ['tag', 'metadata', 'values']) +TagData = collections.namedtuple("TagData", ["tag", "metadata", "values"]) class SqliteWriter(object): - """Sends summary data to SQLite using python's sqlite3 module.""" + """Sends summary data to SQLite using python's sqlite3 module.""" - def __init__(self, db_connection_provider): - """Constructs a SqliteWriterEventSink. + def __init__(self, db_connection_provider): + """Constructs a SqliteWriterEventSink. - Args: - db_connection_provider: Provider function for creating a DB connection. - """ - self._db = db_connection_provider() + Args: + db_connection_provider: Provider function for creating a DB connection. + """ + self._db = db_connection_provider() - def _make_blob(self, bytestring): - """Helper to ensure SQLite treats the given data as a BLOB.""" - # Special-case python 2 pysqlite which uses buffers for BLOB. - if sys.version_info[0] == 2: - return buffer(bytestring) # noqa: F821 (undefined name) - return bytestring + def _make_blob(self, bytestring): + """Helper to ensure SQLite treats the given data as a BLOB.""" + # Special-case python 2 pysqlite which uses buffers for BLOB. 
+ if sys.version_info[0] == 2: + return buffer(bytestring) # noqa: F821 (undefined name) + return bytestring - def _create_id(self): - """Returns a freshly created DB-wide unique ID.""" - cursor = self._db.cursor() - cursor.execute('INSERT INTO Ids DEFAULT VALUES') - return cursor.lastrowid + def _create_id(self): + """Returns a freshly created DB-wide unique ID.""" + cursor = self._db.cursor() + cursor.execute("INSERT INTO Ids DEFAULT VALUES") + return cursor.lastrowid - def _maybe_init_user(self): - """Returns the ID for the current user, creating the row if needed.""" - user_name = os.environ.get('USER', '') or os.environ.get('USERNAME', '') - cursor = self._db.cursor() - cursor.execute('SELECT user_id FROM Users WHERE user_name = ?', - (user_name,)) - row = cursor.fetchone() - if row: - return row[0] - user_id = self._create_id() - cursor.execute( - """ + def _maybe_init_user(self): + """Returns the ID for the current user, creating the row if needed.""" + user_name = os.environ.get("USER", "") or os.environ.get("USERNAME", "") + cursor = self._db.cursor() + cursor.execute( + "SELECT user_id FROM Users WHERE user_name = ?", (user_name,) + ) + row = cursor.fetchone() + if row: + return row[0] + user_id = self._create_id() + cursor.execute( + """ INSERT INTO USERS (user_id, user_name, inserted_time) VALUES (?, ?, ?) """, - (user_id, user_name, time.time())) - return user_id + (user_id, user_name, time.time()), + ) + return user_id - def _maybe_init_experiment(self, experiment_name): - """Returns the ID for the given experiment, creating the row if needed. + def _maybe_init_experiment(self, experiment_name): + """Returns the ID for the given experiment, creating the row if needed. - Args: - experiment_name: name of experiment. - """ - user_id = self._maybe_init_user() - cursor = self._db.cursor() - cursor.execute( + Args: + experiment_name: name of experiment. 
""" + user_id = self._maybe_init_user() + cursor = self._db.cursor() + cursor.execute( + """ SELECT experiment_id FROM Experiments WHERE user_id = ? AND experiment_name = ? """, - (user_id, experiment_name)) - row = cursor.fetchone() - if row: - return row[0] - experiment_id = self._create_id() - # TODO: track computed time from run start times - computed_time = 0 - cursor.execute( - """ + (user_id, experiment_name), + ) + row = cursor.fetchone() + if row: + return row[0] + experiment_id = self._create_id() + # TODO: track computed time from run start times + computed_time = 0 + cursor.execute( + """ INSERT INTO Experiments ( user_id, experiment_id, experiment_name, inserted_time, started_time, is_watching ) VALUES (?, ?, ?, ?, ?, ?) """, - (user_id, experiment_id, experiment_name, time.time(), computed_time, - False)) - return experiment_id + ( + user_id, + experiment_id, + experiment_name, + time.time(), + computed_time, + False, + ), + ) + return experiment_id - def _maybe_init_run(self, experiment_name, run_name): - """Returns the ID for the given run, creating the row if needed. + def _maybe_init_run(self, experiment_name, run_name): + """Returns the ID for the given run, creating the row if needed. - Args: - experiment_name: name of experiment containing this run. - run_name: name of run. - """ - experiment_id = self._maybe_init_experiment(experiment_name) - cursor = self._db.cursor() - cursor.execute( + Args: + experiment_name: name of experiment containing this run. + run_name: name of run. """ + experiment_id = self._maybe_init_experiment(experiment_name) + cursor = self._db.cursor() + cursor.execute( + """ SELECT run_id FROM Runs WHERE experiment_id = ? AND run_name = ? 
""", - (experiment_id, run_name)) - row = cursor.fetchone() - if row: - return row[0] - run_id = self._create_id() - # TODO: track actual run start times - started_time = 0 - cursor.execute( - """ + (experiment_id, run_name), + ) + row = cursor.fetchone() + if row: + return row[0] + run_id = self._create_id() + # TODO: track actual run start times + started_time = 0 + cursor.execute( + """ INSERT INTO Runs ( experiment_id, run_id, run_name, inserted_time, started_time ) VALUES (?, ?, ?, ?, ?) """, - (experiment_id, run_id, run_name, time.time(), started_time)) - return run_id + (experiment_id, run_id, run_name, time.time(), started_time), + ) + return run_id - def _maybe_init_tags(self, run_id, tag_to_metadata): - """Returns a tag-to-ID map for the given tags, creating rows if needed. + def _maybe_init_tags(self, run_id, tag_to_metadata): + """Returns a tag-to-ID map for the given tags, creating rows if needed. - Args: - run_id: the ID of the run to which these tags belong. - tag_to_metadata: map of tag name to SummaryMetadata for the tag. - """ - cursor = self._db.cursor() - # TODO: for huge numbers of tags (e.g. 1000+), this is slower than just - # querying for the known tag names explicitly; find a better tradeoff. - cursor.execute('SELECT tag_name, tag_id FROM Tags WHERE run_id = ?', - (run_id,)) - tag_to_id = {row[0]: row[1] for row in cursor.fetchall() - if row[0] in tag_to_metadata} - new_tag_data = [] - for tag, metadata in six.iteritems(tag_to_metadata): - if tag not in tag_to_id: - tag_id = self._create_id() - tag_to_id[tag] = tag_id - new_tag_data.append((run_id, tag_id, tag, time.time(), - metadata.display_name, - metadata.plugin_data.plugin_name, - self._make_blob(metadata.plugin_data.content))) - cursor.executemany( + Args: + run_id: the ID of the run to which these tags belong. + tag_to_metadata: map of tag name to SummaryMetadata for the tag. """ + cursor = self._db.cursor() + # TODO: for huge numbers of tags (e.g. 
1000+), this is slower than just + # querying for the known tag names explicitly; find a better tradeoff. + cursor.execute( + "SELECT tag_name, tag_id FROM Tags WHERE run_id = ?", (run_id,) + ) + tag_to_id = { + row[0]: row[1] + for row in cursor.fetchall() + if row[0] in tag_to_metadata + } + new_tag_data = [] + for tag, metadata in six.iteritems(tag_to_metadata): + if tag not in tag_to_id: + tag_id = self._create_id() + tag_to_id[tag] = tag_id + new_tag_data.append( + ( + run_id, + tag_id, + tag, + time.time(), + metadata.display_name, + metadata.plugin_data.plugin_name, + self._make_blob(metadata.plugin_data.content), + ) + ) + cursor.executemany( + """ INSERT INTO Tags ( run_id, tag_id, tag_name, inserted_time, display_name, plugin_name, plugin_data ) VALUES (?, ?, ?, ?, ?, ?, ?) """, - new_tag_data) - return tag_to_id + new_tag_data, + ) + return tag_to_id - def write_summaries(self, tagged_data, experiment_name, run_name): - """Transactionally writes the given tagged summary data to the DB. + def write_summaries(self, tagged_data, experiment_name, run_name): + """Transactionally writes the given tagged summary data to the DB. - Args: - tagged_data: map from tag to TagData instances. - experiment_name: name of experiment. - run_name: name of run. - """ - logger.debug('Writing summaries for %s tags', len(tagged_data)) - # Connection used as context manager for auto commit/rollback on exit. - # We still need an explicit BEGIN, because it doesn't do one on enter, - # it waits until the first DML command - which is totally broken. 
- # See: https://stackoverflow.com/a/44448465/1179226 - with self._db: - self._db.execute('BEGIN TRANSACTION') - run_id = self._maybe_init_run(experiment_name, run_name) - tag_to_metadata = { - tag: tagdata.metadata for tag, tagdata in six.iteritems(tagged_data) - } - tag_to_id = self._maybe_init_tags(run_id, tag_to_metadata) - tensor_values = [] - for tag, tagdata in six.iteritems(tagged_data): - tag_id = tag_to_id[tag] - for step, wall_time, tensor_proto in tagdata.values: - dtype = tensor_proto.dtype - shape = ','.join(str(d.size) for d in tensor_proto.tensor_shape.dim) - # Use tensor_proto.tensor_content if it's set, to skip relatively - # expensive extraction into intermediate ndarray. - data = self._make_blob( - tensor_proto.tensor_content or - tensor_util.make_ndarray(tensor_proto).tobytes()) - tensor_values.append((tag_id, step, wall_time, dtype, shape, data)) - self._db.executemany( - """ + Args: + tagged_data: map from tag to TagData instances. + experiment_name: name of experiment. + run_name: name of run. + """ + logger.debug("Writing summaries for %s tags", len(tagged_data)) + # Connection used as context manager for auto commit/rollback on exit. + # We still need an explicit BEGIN, because it doesn't do one on enter, + # it waits until the first DML command - which is totally broken. 
+ # See: https://stackoverflow.com/a/44448465/1179226 + with self._db: + self._db.execute("BEGIN TRANSACTION") + run_id = self._maybe_init_run(experiment_name, run_name) + tag_to_metadata = { + tag: tagdata.metadata + for tag, tagdata in six.iteritems(tagged_data) + } + tag_to_id = self._maybe_init_tags(run_id, tag_to_metadata) + tensor_values = [] + for tag, tagdata in six.iteritems(tagged_data): + tag_id = tag_to_id[tag] + for step, wall_time, tensor_proto in tagdata.values: + dtype = tensor_proto.dtype + shape = ",".join( + str(d.size) for d in tensor_proto.tensor_shape.dim + ) + # Use tensor_proto.tensor_content if it's set, to skip relatively + # expensive extraction into intermediate ndarray. + data = self._make_blob( + tensor_proto.tensor_content + or tensor_util.make_ndarray(tensor_proto).tobytes() + ) + tensor_values.append( + (tag_id, step, wall_time, dtype, shape, data) + ) + self._db.executemany( + """ INSERT OR REPLACE INTO Tensors ( series, step, computed_time, dtype, shape, data ) VALUES (?, ?, ?, ?, ?, ?) """, - tensor_values) + tensor_values, + ) # See tensorflow/contrib/tensorboard/db/schema.cc for documentation. @@ -414,17 +445,19 @@ def write_summaries(self, tagged_data, experiment_name, run_name): def initialize_schema(connection): - """Initializes the TensorBoard sqlite schema using the given connection. + """Initializes the TensorBoard sqlite schema using the given connection. - Args: - connection: A sqlite DB connection. - """ - cursor = connection.cursor() - cursor.execute("PRAGMA application_id={}".format(_TENSORBOARD_APPLICATION_ID)) - cursor.execute("PRAGMA user_version={}".format(_TENSORBOARD_USER_VERSION)) - with connection: - for statement in _SCHEMA_STATEMENTS: - lines = statement.strip('\n').split('\n') - message = lines[0] + ('...' if len(lines) > 1 else '') - logger.debug('Running DB init statement: %s', message) - cursor.execute(statement) + Args: + connection: A sqlite DB connection. 
+ """ + cursor = connection.cursor() + cursor.execute( + "PRAGMA application_id={}".format(_TENSORBOARD_APPLICATION_ID) + ) + cursor.execute("PRAGMA user_version={}".format(_TENSORBOARD_USER_VERSION)) + with connection: + for statement in _SCHEMA_STATEMENTS: + lines = statement.strip("\n").split("\n") + message = lines[0] + ("..." if len(lines) > 1 else "") + logger.debug("Running DB init statement: %s", message) + cursor.execute(statement) diff --git a/tensorboard/backend/experiment_id.py b/tensorboard/backend/experiment_id.py index 7c427abe32..e695f26638 100644 --- a/tensorboard/backend/experiment_id.py +++ b/tensorboard/backend/experiment_id.py @@ -30,42 +30,42 @@ class ExperimentIdMiddleware(object): - """WSGI middleware extracting experiment IDs from URL to environment. + """WSGI middleware extracting experiment IDs from URL to environment. - Any request whose path matches `/experiment/SOME_EID[/...]` will have - its first two path components stripped, and its experiment ID stored - onto the WSGI environment with key taken from the `WSGI_ENVIRON_KEY` - constant. All other requests will have paths unchanged and the - experiment ID set to the empty string. + Any request whose path matches `/experiment/SOME_EID[/...]` will have + its first two path components stripped, and its experiment ID stored + onto the WSGI environment with key taken from the `WSGI_ENVIRON_KEY` + constant. All other requests will have paths unchanged and the + experiment ID set to the empty string. - Instances of this class are WSGI applications (see PEP 3333). - """ + Instances of this class are WSGI applications (see PEP 3333). + """ - def __init__(self, application): - """Initializes an `ExperimentIdMiddleware`. + def __init__(self, application): + """Initializes an `ExperimentIdMiddleware`. - Args: - application: The WSGI application to wrap (see PEP 3333). 
- """ - self._application = application - # Regular expression that matches the whole `/experiment/EID` prefix - # (without any trailing slash) and captures the experiment ID. - self._pat = re.compile( - r"/%s/([^/]*)" % re.escape(_EXPERIMENT_PATH_COMPONENT) - ) + Args: + application: The WSGI application to wrap (see PEP 3333). + """ + self._application = application + # Regular expression that matches the whole `/experiment/EID` prefix + # (without any trailing slash) and captures the experiment ID. + self._pat = re.compile( + r"/%s/([^/]*)" % re.escape(_EXPERIMENT_PATH_COMPONENT) + ) - def __call__(self, environ, start_response): - path = environ.get("PATH_INFO", "") - m = self._pat.match(path) - if m: - eid = m.group(1) - new_path = path[m.end(0):] - root = m.group(0) - else: - eid = "" - new_path = path - root = "" - environ[WSGI_ENVIRON_KEY] = eid - environ["PATH_INFO"] = new_path - environ["SCRIPT_NAME"] = environ.get("SCRIPT_NAME", "") + root - return self._application(environ, start_response) + def __call__(self, environ, start_response): + path = environ.get("PATH_INFO", "") + m = self._pat.match(path) + if m: + eid = m.group(1) + new_path = path[m.end(0) :] + root = m.group(0) + else: + eid = "" + new_path = path + root = "" + environ[WSGI_ENVIRON_KEY] = eid + environ["PATH_INFO"] = new_path + environ["SCRIPT_NAME"] = environ.get("SCRIPT_NAME", "") + root + return self._application(environ, start_response) diff --git a/tensorboard/backend/experiment_id_test.py b/tensorboard/backend/experiment_id_test.py index 698d6d4559..5ebc7c9091 100644 --- a/tensorboard/backend/experiment_id_test.py +++ b/tensorboard/backend/experiment_id_test.py @@ -28,66 +28,68 @@ class ExperimentIdMiddlewareTest(tb_test.TestCase): - """Tests for `ExperimentIdMiddleware`.""" - - def setUp(self): - super(ExperimentIdMiddlewareTest, self).setUp() - self.app = experiment_id.ExperimentIdMiddleware(self._echo_app) - self.server = werkzeug_test.Client(self.app, werkzeug.BaseResponse) - - 
def _echo_app(self, environ, start_response): - # https://www.python.org/dev/peps/pep-0333/#environ-variables - data = { - "eid": environ[experiment_id.WSGI_ENVIRON_KEY], - "path": environ.get("PATH_INFO", ""), - "script": environ.get("SCRIPT_NAME", ""), - } - body = json.dumps(data, sort_keys=True) - start_response("200 OK", [("Content-Type", "application/json")]) - return [body] - - def _assert_ok(self, response, eid, path, script): - self.assertEqual(response.status_code, 200) - actual = json.loads(response.get_data()) - expected = dict(eid=eid, path=path, script=script) - self.assertEqual(actual, expected) - - def test_no_experiment_empty_path(self): - response = self.server.get("") - self._assert_ok(response, eid="", path="", script="") - - def test_no_experiment_root_path(self): - response = self.server.get("/") - self._assert_ok(response, eid="", path="/", script="") - - def test_no_experiment_sub_path(self): - response = self.server.get("/x/y") - self._assert_ok(response, eid="", path="/x/y", script="") - - def test_with_experiment_empty_path(self): - response = self.server.get("/experiment/123") - self._assert_ok(response, eid="123", path="", script="/experiment/123") - - def test_with_experiment_root_path(self): - response = self.server.get("/experiment/123/") - self._assert_ok(response, eid="123", path="/", script="/experiment/123") - - def test_with_experiment_sub_path(self): - response = self.server.get("/experiment/123/x/y") - self._assert_ok(response, eid="123", path="/x/y", script="/experiment/123") - - def test_with_empty_experiment_empty_path(self): - response = self.server.get("/experiment/") - self._assert_ok(response, eid="", path="", script="/experiment/") - - def test_with_empty_experiment_root_path(self): - response = self.server.get("/experiment//") - self._assert_ok(response, eid="", path="/", script="/experiment/") - - def test_with_empty_experiment_sub_path(self): - response = self.server.get("/experiment//x/y") - 
self._assert_ok(response, eid="", path="/x/y", script="/experiment/") + """Tests for `ExperimentIdMiddleware`.""" + + def setUp(self): + super(ExperimentIdMiddlewareTest, self).setUp() + self.app = experiment_id.ExperimentIdMiddleware(self._echo_app) + self.server = werkzeug_test.Client(self.app, werkzeug.BaseResponse) + + def _echo_app(self, environ, start_response): + # https://www.python.org/dev/peps/pep-0333/#environ-variables + data = { + "eid": environ[experiment_id.WSGI_ENVIRON_KEY], + "path": environ.get("PATH_INFO", ""), + "script": environ.get("SCRIPT_NAME", ""), + } + body = json.dumps(data, sort_keys=True) + start_response("200 OK", [("Content-Type", "application/json")]) + return [body] + + def _assert_ok(self, response, eid, path, script): + self.assertEqual(response.status_code, 200) + actual = json.loads(response.get_data()) + expected = dict(eid=eid, path=path, script=script) + self.assertEqual(actual, expected) + + def test_no_experiment_empty_path(self): + response = self.server.get("") + self._assert_ok(response, eid="", path="", script="") + + def test_no_experiment_root_path(self): + response = self.server.get("/") + self._assert_ok(response, eid="", path="/", script="") + + def test_no_experiment_sub_path(self): + response = self.server.get("/x/y") + self._assert_ok(response, eid="", path="/x/y", script="") + + def test_with_experiment_empty_path(self): + response = self.server.get("/experiment/123") + self._assert_ok(response, eid="123", path="", script="/experiment/123") + + def test_with_experiment_root_path(self): + response = self.server.get("/experiment/123/") + self._assert_ok(response, eid="123", path="/", script="/experiment/123") + + def test_with_experiment_sub_path(self): + response = self.server.get("/experiment/123/x/y") + self._assert_ok( + response, eid="123", path="/x/y", script="/experiment/123" + ) + + def test_with_empty_experiment_empty_path(self): + response = self.server.get("/experiment/") + self._assert_ok(response, 
eid="", path="", script="/experiment/") + + def test_with_empty_experiment_root_path(self): + response = self.server.get("/experiment//") + self._assert_ok(response, eid="", path="/", script="/experiment/") + + def test_with_empty_experiment_sub_path(self): + response = self.server.get("/experiment//x/y") + self._assert_ok(response, eid="", path="/x/y", script="/experiment/") if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/backend/http_util.py b/tensorboard/backend/http_util.py index ce131b3f07..1fb9d2c675 100644 --- a/tensorboard/backend/http_util.py +++ b/tensorboard/backend/http_util.py @@ -27,14 +27,16 @@ import wsgiref.handlers import six -from six.moves.urllib import parse as urlparse # pylint: disable=wrong-import-order +from six.moves.urllib import ( + parse as urlparse, +) # pylint: disable=wrong-import-order import werkzeug from tensorboard.backend import json_util from tensorboard.compat import tf -_DISALLOWED_CHAR_IN_DOMAIN = re.compile(r'\s') +_DISALLOWED_CHAR_IN_DOMAIN = re.compile(r"\s") # TODO(stephanwlee): Refactor this to not use the module variable but # instead use a configurable via some kind of assets provider which would @@ -48,204 +50,229 @@ _CSP_SCRIPT_UNSAFE_EVAL = True _CSP_STYLE_DOMAINS_WHITELIST = [] -_EXTRACT_MIMETYPE_PATTERN = re.compile(r'^[^;\s]*') -_EXTRACT_CHARSET_PATTERN = re.compile(r'charset=([-_0-9A-Za-z]+)') +_EXTRACT_MIMETYPE_PATTERN = re.compile(r"^[^;\s]*") +_EXTRACT_CHARSET_PATTERN = re.compile(r"charset=([-_0-9A-Za-z]+)") # Allows *, gzip or x-gzip, but forbid gzip;q=0 # https://tools.ietf.org/html/rfc7231#section-5.3.4 _ALLOWS_GZIP_PATTERN = re.compile( - r'(?:^|,|\s)(?:(?:x-)?gzip|\*)(?!;q=0)(?:\s|,|$)') - -_TEXTUAL_MIMETYPES = set([ - 'application/javascript', - 'application/json', - 'application/json+protobuf', - 'image/svg+xml', - 'text/css', - 'text/csv', - 'text/html', - 'text/plain', - 'text/tab-separated-values', - 'text/x-protobuf', -]) - -_JSON_MIMETYPES = set([ - 
'application/json', - 'application/json+protobuf', -]) + r"(?:^|,|\s)(?:(?:x-)?gzip|\*)(?!;q=0)(?:\s|,|$)" +) -# Do not support xhtml for now. -_HTML_MIMETYPE = 'text/html' - -def Respond(request, - content, - content_type, - code=200, - expires=0, - content_encoding=None, - encoding='utf-8', - csp_scripts_sha256s=None): - """Construct a werkzeug Response. - - Responses are transmitted to the browser with compression if: a) the browser - supports it; b) it's sane to compress the content_type in question; and c) - the content isn't already compressed, as indicated by the content_encoding - parameter. - - Browser and proxy caching is completely disabled by default. If the expires - parameter is greater than zero then the response will be able to be cached by - the browser for that many seconds; however, proxies are still forbidden from - caching so that developers can bypass the cache with Ctrl+Shift+R. - - For textual content that isn't JSON, the encoding parameter is used as the - transmission charset which is automatically appended to the Content-Type - header. That is unless of course the content_type parameter contains a - charset parameter. If the two disagree, the characters in content will be - transcoded to the latter. - - If content_type declares a JSON media type, then content MAY be a dict, list, - tuple, or set, in which case this function has an implicit composition with - json_util.Cleanse and json.dumps. The encoding parameter is used to decode - byte strings within the JSON object; therefore transmitting binary data - within JSON is not permitted. JSON is transmitted as ASCII unless the - content_type parameter explicitly defines a charset parameter, in which case - the serialized JSON bytes will use that instead of escape sequences. - - Args: - request: A werkzeug Request object. Used mostly to check the - Accept-Encoding header. - content: Payload data as byte string, unicode string, or maybe JSON. 
- content_type: Media type and optionally an output charset. - code: Numeric HTTP status code to use. - expires: Second duration for browser caching. - content_encoding: Encoding if content is already encoded, e.g. 'gzip'. - encoding: Input charset if content parameter has byte strings. - csp_scripts_sha256s: List of base64 serialized sha256 of whitelisted script - elements for script-src of the Content-Security-Policy. If it is None, the - HTML will disallow any script to execute. It is only be used when the - content_type is text/html. - - Returns: - A werkzeug Response object (a WSGI application). - """ - - mimetype = _EXTRACT_MIMETYPE_PATTERN.search(content_type).group(0) - charset_match = _EXTRACT_CHARSET_PATTERN.search(content_type) - charset = charset_match.group(1) if charset_match else encoding - textual = charset_match or mimetype in _TEXTUAL_MIMETYPES - if (mimetype in _JSON_MIMETYPES and - isinstance(content, (dict, list, set, tuple))): - content = json.dumps(json_util.Cleanse(content, encoding), - ensure_ascii=not charset_match) - if charset != encoding: - content = tf.compat.as_text(content, encoding) - content = tf.compat.as_bytes(content, charset) - if textual and not charset_match and mimetype not in _JSON_MIMETYPES: - content_type += '; charset=' + charset - gzip_accepted = _ALLOWS_GZIP_PATTERN.search( - request.headers.get('Accept-Encoding', '')) - # Automatically gzip uncompressed text data if accepted. - if textual and not content_encoding and gzip_accepted: - out = six.BytesIO() - # Set mtime to zero to make payload for a given input deterministic. - with gzip.GzipFile(fileobj=out, mode='wb', compresslevel=3, mtime=0) as f: - f.write(content) - content = out.getvalue() - content_encoding = 'gzip' - - content_length = len(content) - direct_passthrough = False - # Automatically streamwise-gunzip precompressed data if not accepted. 
- if content_encoding == 'gzip' and not gzip_accepted: - gzip_file = gzip.GzipFile(fileobj=six.BytesIO(content), mode='rb') - # Last 4 bytes of gzip formatted data (little-endian) store the original - # content length mod 2^32; we just assume it's the content length. That - # means we can't streamwise-gunzip >4 GB precompressed file; this is ok. - content_length = struct.unpack(' 0: - e = wsgiref.handlers.format_date_time(time.time() + float(expires)) - headers.append(('Expires', e)) - headers.append(('Cache-Control', 'private, max-age=%d' % expires)) - else: - headers.append(('Expires', '0')) - headers.append(('Cache-Control', 'no-cache, must-revalidate')) - if mimetype == _HTML_MIMETYPE: - _validate_global_whitelist(_CSP_IMG_DOMAINS_WHITELIST) - _validate_global_whitelist(_CSP_STYLE_DOMAINS_WHITELIST) - _validate_global_whitelist(_CSP_FONT_DOMAINS_WHITELIST) - _validate_global_whitelist(_CSP_FRAME_DOMAINS_WHITELIST) - _validate_global_whitelist(_CSP_SCRIPT_DOMAINS_WHITELIST) - - frags = _CSP_SCRIPT_DOMAINS_WHITELIST + [ - "'self'" if _CSP_SCRIPT_SELF else '', - "'unsafe-eval'" if _CSP_SCRIPT_UNSAFE_EVAL else '', - ] + [ - "'sha256-{}'".format(sha256) for sha256 in (csp_scripts_sha256s or []) +_TEXTUAL_MIMETYPES = set( + [ + "application/javascript", + "application/json", + "application/json+protobuf", + "image/svg+xml", + "text/css", + "text/csv", + "text/html", + "text/plain", + "text/tab-separated-values", + "text/x-protobuf", ] - script_srcs = _create_csp_string(*frags) - - csp_string = ';'.join([ - "default-src 'self'", - 'font-src %s' % _create_csp_string( - "'self'", - *_CSP_FONT_DOMAINS_WHITELIST - ), - 'frame-ancestors *', - # Dynamic plugins are rendered inside an iframe. - 'frame-src %s' % _create_csp_string( - "'self'", - *_CSP_FRAME_DOMAINS_WHITELIST - ), - 'img-src %s' % _create_csp_string( - "'self'", - # used by favicon - 'data:', - # used by What-If tool for image sprites. 
- 'blob:', - *_CSP_IMG_DOMAINS_WHITELIST - ), - "object-src 'none'", - 'style-src %s' % _create_csp_string( - "'self'", - # used by google-chart - 'https://www.gstatic.com', - 'data:', - # inline styles: Polymer templates + d3 uses inline styles. - "'unsafe-inline'", - *_CSP_STYLE_DOMAINS_WHITELIST - ), - "script-src %s" % script_srcs, - ]) - - headers.append(('Content-Security-Policy', csp_string)) - - if request.method == 'HEAD': - content = None - - return werkzeug.wrappers.Response( - response=content, status=code, headers=headers, content_type=content_type, - direct_passthrough=direct_passthrough) +) + +_JSON_MIMETYPES = set(["application/json", "application/json+protobuf",]) + +# Do not support xhtml for now. +_HTML_MIMETYPE = "text/html" + + +def Respond( + request, + content, + content_type, + code=200, + expires=0, + content_encoding=None, + encoding="utf-8", + csp_scripts_sha256s=None, +): + """Construct a werkzeug Response. + + Responses are transmitted to the browser with compression if: a) the browser + supports it; b) it's sane to compress the content_type in question; and c) + the content isn't already compressed, as indicated by the content_encoding + parameter. + + Browser and proxy caching is completely disabled by default. If the expires + parameter is greater than zero then the response will be able to be cached by + the browser for that many seconds; however, proxies are still forbidden from + caching so that developers can bypass the cache with Ctrl+Shift+R. + + For textual content that isn't JSON, the encoding parameter is used as the + transmission charset which is automatically appended to the Content-Type + header. That is unless of course the content_type parameter contains a + charset parameter. If the two disagree, the characters in content will be + transcoded to the latter. 
+ + If content_type declares a JSON media type, then content MAY be a dict, list, + tuple, or set, in which case this function has an implicit composition with + json_util.Cleanse and json.dumps. The encoding parameter is used to decode + byte strings within the JSON object; therefore transmitting binary data + within JSON is not permitted. JSON is transmitted as ASCII unless the + content_type parameter explicitly defines a charset parameter, in which case + the serialized JSON bytes will use that instead of escape sequences. + + Args: + request: A werkzeug Request object. Used mostly to check the + Accept-Encoding header. + content: Payload data as byte string, unicode string, or maybe JSON. + content_type: Media type and optionally an output charset. + code: Numeric HTTP status code to use. + expires: Second duration for browser caching. + content_encoding: Encoding if content is already encoded, e.g. 'gzip'. + encoding: Input charset if content parameter has byte strings. + csp_scripts_sha256s: List of base64 serialized sha256 of whitelisted script + elements for script-src of the Content-Security-Policy. If it is None, the + HTML will disallow any script to execute. It is only be used when the + content_type is text/html. + + Returns: + A werkzeug Response object (a WSGI application). 
+ """ + + mimetype = _EXTRACT_MIMETYPE_PATTERN.search(content_type).group(0) + charset_match = _EXTRACT_CHARSET_PATTERN.search(content_type) + charset = charset_match.group(1) if charset_match else encoding + textual = charset_match or mimetype in _TEXTUAL_MIMETYPES + if mimetype in _JSON_MIMETYPES and isinstance( + content, (dict, list, set, tuple) + ): + content = json.dumps( + json_util.Cleanse(content, encoding), ensure_ascii=not charset_match + ) + if charset != encoding: + content = tf.compat.as_text(content, encoding) + content = tf.compat.as_bytes(content, charset) + if textual and not charset_match and mimetype not in _JSON_MIMETYPES: + content_type += "; charset=" + charset + gzip_accepted = _ALLOWS_GZIP_PATTERN.search( + request.headers.get("Accept-Encoding", "") + ) + # Automatically gzip uncompressed text data if accepted. + if textual and not content_encoding and gzip_accepted: + out = six.BytesIO() + # Set mtime to zero to make payload for a given input deterministic. + with gzip.GzipFile( + fileobj=out, mode="wb", compresslevel=3, mtime=0 + ) as f: + f.write(content) + content = out.getvalue() + content_encoding = "gzip" + + content_length = len(content) + direct_passthrough = False + # Automatically streamwise-gunzip precompressed data if not accepted. + if content_encoding == "gzip" and not gzip_accepted: + gzip_file = gzip.GzipFile(fileobj=six.BytesIO(content), mode="rb") + # Last 4 bytes of gzip formatted data (little-endian) store the original + # content length mod 2^32; we just assume it's the content length. That + # means we can't streamwise-gunzip >4 GB precompressed file; this is ok. 
+ content_length = struct.unpack(" 0: + e = wsgiref.handlers.format_date_time(time.time() + float(expires)) + headers.append(("Expires", e)) + headers.append(("Cache-Control", "private, max-age=%d" % expires)) + else: + headers.append(("Expires", "0")) + headers.append(("Cache-Control", "no-cache, must-revalidate")) + if mimetype == _HTML_MIMETYPE: + _validate_global_whitelist(_CSP_IMG_DOMAINS_WHITELIST) + _validate_global_whitelist(_CSP_STYLE_DOMAINS_WHITELIST) + _validate_global_whitelist(_CSP_FONT_DOMAINS_WHITELIST) + _validate_global_whitelist(_CSP_FRAME_DOMAINS_WHITELIST) + _validate_global_whitelist(_CSP_SCRIPT_DOMAINS_WHITELIST) + + frags = ( + _CSP_SCRIPT_DOMAINS_WHITELIST + + [ + "'self'" if _CSP_SCRIPT_SELF else "", + "'unsafe-eval'" if _CSP_SCRIPT_UNSAFE_EVAL else "", + ] + + [ + "'sha256-{}'".format(sha256) + for sha256 in (csp_scripts_sha256s or []) + ] + ) + script_srcs = _create_csp_string(*frags) + + csp_string = ";".join( + [ + "default-src 'self'", + "font-src %s" + % _create_csp_string("'self'", *_CSP_FONT_DOMAINS_WHITELIST), + "frame-ancestors *", + # Dynamic plugins are rendered inside an iframe. + "frame-src %s" + % _create_csp_string("'self'", *_CSP_FRAME_DOMAINS_WHITELIST), + "img-src %s" + % _create_csp_string( + "'self'", + # used by favicon + "data:", + # used by What-If tool for image sprites. + "blob:", + *_CSP_IMG_DOMAINS_WHITELIST + ), + "object-src 'none'", + "style-src %s" + % _create_csp_string( + "'self'", + # used by google-chart + "https://www.gstatic.com", + "data:", + # inline styles: Polymer templates + d3 uses inline styles. 
+ "'unsafe-inline'", + *_CSP_STYLE_DOMAINS_WHITELIST + ), + "script-src %s" % script_srcs, + ] + ) + + headers.append(("Content-Security-Policy", csp_string)) + + if request.method == "HEAD": + content = None + + return werkzeug.wrappers.Response( + response=content, + status=code, + headers=headers, + content_type=content_type, + direct_passthrough=direct_passthrough, + ) + def _validate_global_whitelist(whitelists): - for domain in whitelists: - url = urlparse.urlparse(domain) - if not url.scheme == 'https' or not url.netloc: - raise ValueError('Expected all whitelist to be a https URL: %r' % domain) - if ';' in domain: - raise ValueError('Expected whitelist domain to not contain ";": %r' % domain) - if _DISALLOWED_CHAR_IN_DOMAIN.search(domain): - raise ValueError( - 'Expected whitelist domain to not contain a whitespace: %r' % domain) + for domain in whitelists: + url = urlparse.urlparse(domain) + if not url.scheme == "https" or not url.netloc: + raise ValueError( + "Expected all whitelist to be a https URL: %r" % domain + ) + if ";" in domain: + raise ValueError( + 'Expected whitelist domain to not contain ";": %r' % domain + ) + if _DISALLOWED_CHAR_IN_DOMAIN.search(domain): + raise ValueError( + "Expected whitelist domain to not contain a whitespace: %r" + % domain + ) + def _create_csp_string(*csp_fragments): - csp_string = ' '.join([frag for frag in csp_fragments if frag]) - return csp_string if csp_string else "'none'" + csp_string = " ".join([frag for frag in csp_fragments if frag]) + return csp_string if csp_string else "'none'" diff --git a/tensorboard/backend/http_util_test.py b/tensorboard/backend/http_util_test.py index d165450196..de4e31b611 100644 --- a/tensorboard/backend/http_util_test.py +++ b/tensorboard/backend/http_util_test.py @@ -26,10 +26,10 @@ import six try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import 
mock # pylint: disable=unused-import from werkzeug import test as wtest from werkzeug import wrappers @@ -39,286 +39,375 @@ class RespondTest(tb_test.TestCase): - - def testHelloWorld(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, 'hello world', 'text/html') - self.assertEqual(r.status_code, 200) - self.assertEqual(r.response, [six.b('hello world')]) - self.assertEqual(r.headers.get('Content-Length'), '18') - - def testHeadRequest_doesNotWrite(self): - builder = wtest.EnvironBuilder(method='HEAD') - env = builder.get_environ() - request = wrappers.Request(env) - r = http_util.Respond(request, 'hello world', 'text/html') - self.assertEqual(r.status_code, 200) - self.assertEqual(r.response, []) - self.assertEqual(r.headers.get('Content-Length'), '18') - - def testPlainText_appendsUtf8ToContentType(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, 'hello', 'text/plain') - h = r.headers - self.assertEqual(h.get('Content-Type'), 'text/plain; charset=utf-8') - - def testContentLength_isInBytes(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, '爱', 'text/plain') - self.assertEqual(r.headers.get('Content-Length'), '3') - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, '爱'.encode('utf-8'), 'text/plain') - self.assertEqual(r.headers.get('Content-Length'), '3') - - def testResponseCharsetTranscoding(self): - bean = '要依法治国是赞美那些谁是公义的和惩罚恶人。 - 韩非' - - # input is unicode string, output is gbk string - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, bean, 'text/plain; charset=gbk') - self.assertEqual(r.response, [bean.encode('gbk')]) - - # input is utf-8 string, output is gbk string - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, bean.encode('utf-8'), 'text/plain; charset=gbk') - self.assertEqual(r.response, [bean.encode('gbk')]) - 
- # input is object with unicode strings, output is gbk json - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, {'red': bean}, 'application/json; charset=gbk') - self.assertEqual(r.response, [b'{"red": "' + bean.encode('gbk') + b'"}']) - - # input is object with utf-8 strings, output is gbk json - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond( - q, {'red': bean.encode('utf-8')}, 'application/json; charset=gbk') - self.assertEqual(r.response, [b'{"red": "' + bean.encode('gbk') + b'"}']) - - # input is object with gbk strings, output is gbk json - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond( - q, {'red': bean.encode('gbk')}, - 'application/json; charset=gbk', - encoding='gbk') - self.assertEqual(r.response, [b'{"red": "' + bean.encode('gbk') + b'"}']) - - def testAcceptGzip_compressesResponse(self): - fall_of_hyperion_canto1_stanza1 = '\n'.join([ - 'Fanatics have their dreams, wherewith they weave', - 'A paradise for a sect; the savage too', - 'From forth the loftiest fashion of his sleep', - 'Guesses at Heaven; pity these have not', - 'Trac\'d upon vellum or wild Indian leaf', - 'The shadows of melodious utterance.', - 'But bare of laurel they live, dream, and die;', - 'For Poesy alone can tell her dreams,', - 'With the fine spell of words alone can save', - 'Imagination from the sable charm', - 'And dumb enchantment. 
Who alive can say,', - '\'Thou art no Poet may\'st not tell thy dreams?\'', - 'Since every man whose soul is not a clod', - 'Hath visions, and would speak, if he had loved', - 'And been well nurtured in his mother tongue.', - 'Whether the dream now purpos\'d to rehearse', - 'Be poet\'s or fanatic\'s will be known', - 'When this warm scribe my hand is in the grave.', - ]) - - e1 = wtest.EnvironBuilder(headers={'Accept-Encoding': '*'}).get_environ() - any_encoding = wrappers.Request(e1) - - r = http_util.Respond( - any_encoding, fall_of_hyperion_canto1_stanza1, 'text/plain') - self.assertEqual(r.headers.get('Content-Encoding'), 'gzip') - self.assertEqual(_gunzip(r.response[0]), # pylint: disable=unsubscriptable-object - fall_of_hyperion_canto1_stanza1.encode('utf-8')) - - e2 = wtest.EnvironBuilder(headers={'Accept-Encoding': 'gzip'}).get_environ() - gzip_encoding = wrappers.Request(e2) - - r = http_util.Respond( - gzip_encoding, fall_of_hyperion_canto1_stanza1, 'text/plain') - self.assertEqual(r.headers.get('Content-Encoding'), 'gzip') - self.assertEqual(_gunzip(r.response[0]), # pylint: disable=unsubscriptable-object - fall_of_hyperion_canto1_stanza1.encode('utf-8')) - - r = http_util.Respond( - any_encoding, fall_of_hyperion_canto1_stanza1, 'image/png') - self.assertEqual( - r.response, [fall_of_hyperion_canto1_stanza1.encode('utf-8')]) - - def testAcceptGzip_alreadyCompressed_sendsPrecompressedResponse(self): - gzip_text = _gzip(b'hello hello hello world') - e = wtest.EnvironBuilder(headers={'Accept-Encoding': 'gzip'}).get_environ() - q = wrappers.Request(e) - r = http_util.Respond(q, gzip_text, 'text/plain', content_encoding='gzip') - self.assertEqual(r.response, [gzip_text]) # Still singly zipped - - def testPrecompressedResponse_noAcceptGzip_decompressesResponse(self): - orig_text = b'hello hello hello world' - gzip_text = _gzip(orig_text) - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, gzip_text, 'text/plain', 
content_encoding='gzip') - # Streaming gunzip produces file-wrapper application iterator as response, - # so rejoin it into the full response before comparison. - full_response = b''.join(r.response) - self.assertEqual(full_response, orig_text) - - def testPrecompressedResponse_streamingDecompression_catchesBadSize(self): - orig_text = b'hello hello hello world' - gzip_text = _gzip(orig_text) - # Corrupt the gzipped data's stored content size (last 4 bytes). - bad_text = gzip_text[:-4] + _bitflip(gzip_text[-4:]) - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, bad_text, 'text/plain', content_encoding='gzip') - # Streaming gunzip defers actual unzipping until response is used; once - # we iterate over the whole file-wrapper application iterator, the - # underlying GzipFile should be closed, and throw the size check error. - with six.assertRaisesRegex(self, IOError, 'Incorrect length'): - _ = list(r.response) - - def testJson_getsAutoSerialized(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, [1, 2, 3], 'application/json') - self.assertEqual(r.response, [b'[1, 2, 3]']) - - def testExpires_setsCruiseControl(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, 'hello world', 'text/html', expires=60) - self.assertEqual(r.headers.get('Cache-Control'), 'private, max-age=60') - - def testCsp(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond( - q, 'hello', 'text/html', csp_scripts_sha256s=['abcdefghi']) - expected_csp = ( - "default-src 'self';font-src 'self';frame-ancestors *;" - "frame-src 'self';img-src 'self' data: blob:;object-src 'none';" - "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';" - "script-src 'self' 'unsafe-eval' 'sha256-abcdefghi'" - ) - self.assertEqual(r.headers.get('Content-Security-Policy'), expected_csp) - - @mock.patch.object(http_util, '_CSP_SCRIPT_SELF', False) - def 
testCsp_noHash(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=None) - expected_csp = ( - "default-src 'self';font-src 'self';frame-ancestors *;" - "frame-src 'self';img-src 'self' data: blob:;object-src 'none';" - "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';" - "script-src 'unsafe-eval'" - ) - self.assertEqual(r.headers.get('Content-Security-Policy'), expected_csp) - - @mock.patch.object(http_util, '_CSP_SCRIPT_SELF', False) - @mock.patch.object(http_util, '_CSP_SCRIPT_UNSAFE_EVAL', False) - def testCsp_noHash_noUnsafeEval(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=None) - expected_csp = ( - "default-src 'self';font-src 'self';frame-ancestors *;" - "frame-src 'self';img-src 'self' data: blob:;object-src 'none';" - "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';" - "script-src 'none'" + def testHelloWorld(self): + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond(q, "hello world", "text/html") + self.assertEqual(r.status_code, 200) + self.assertEqual(r.response, [six.b("hello world")]) + self.assertEqual(r.headers.get("Content-Length"), "18") + + def testHeadRequest_doesNotWrite(self): + builder = wtest.EnvironBuilder(method="HEAD") + env = builder.get_environ() + request = wrappers.Request(env) + r = http_util.Respond(request, "hello world", "text/html") + self.assertEqual(r.status_code, 200) + self.assertEqual(r.response, []) + self.assertEqual(r.headers.get("Content-Length"), "18") + + def testPlainText_appendsUtf8ToContentType(self): + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond(q, "hello", "text/plain") + h = r.headers + self.assertEqual(h.get("Content-Type"), "text/plain; charset=utf-8") + + def testContentLength_isInBytes(self): + q = 
wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond(q, "爱", "text/plain") + self.assertEqual(r.headers.get("Content-Length"), "3") + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond(q, "爱".encode("utf-8"), "text/plain") + self.assertEqual(r.headers.get("Content-Length"), "3") + + def testResponseCharsetTranscoding(self): + bean = "要依法治国是赞美那些谁是公义的和惩罚恶人。 - 韩非" + + # input is unicode string, output is gbk string + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond(q, bean, "text/plain; charset=gbk") + self.assertEqual(r.response, [bean.encode("gbk")]) + + # input is utf-8 string, output is gbk string + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond( + q, bean.encode("utf-8"), "text/plain; charset=gbk" + ) + self.assertEqual(r.response, [bean.encode("gbk")]) + + # input is object with unicode strings, output is gbk json + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond(q, {"red": bean}, "application/json; charset=gbk") + self.assertEqual( + r.response, [b'{"red": "' + bean.encode("gbk") + b'"}'] + ) + + # input is object with utf-8 strings, output is gbk json + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond( + q, {"red": bean.encode("utf-8")}, "application/json; charset=gbk" + ) + self.assertEqual( + r.response, [b'{"red": "' + bean.encode("gbk") + b'"}'] + ) + + # input is object with gbk strings, output is gbk json + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond( + q, + {"red": bean.encode("gbk")}, + "application/json; charset=gbk", + encoding="gbk", + ) + self.assertEqual( + r.response, [b'{"red": "' + bean.encode("gbk") + b'"}'] + ) + + def testAcceptGzip_compressesResponse(self): + fall_of_hyperion_canto1_stanza1 = "\n".join( + [ + "Fanatics have their dreams, wherewith they weave", + "A paradise for a sect; the savage too", + 
"From forth the loftiest fashion of his sleep", + "Guesses at Heaven; pity these have not", + "Trac'd upon vellum or wild Indian leaf", + "The shadows of melodious utterance.", + "But bare of laurel they live, dream, and die;", + "For Poesy alone can tell her dreams,", + "With the fine spell of words alone can save", + "Imagination from the sable charm", + "And dumb enchantment. Who alive can say,", + "'Thou art no Poet may'st not tell thy dreams?'", + "Since every man whose soul is not a clod", + "Hath visions, and would speak, if he had loved", + "And been well nurtured in his mother tongue.", + "Whether the dream now purpos'd to rehearse", + "Be poet's or fanatic's will be known", + "When this warm scribe my hand is in the grave.", + ] + ) + + e1 = wtest.EnvironBuilder( + headers={"Accept-Encoding": "*"} + ).get_environ() + any_encoding = wrappers.Request(e1) + + r = http_util.Respond( + any_encoding, fall_of_hyperion_canto1_stanza1, "text/plain" + ) + self.assertEqual(r.headers.get("Content-Encoding"), "gzip") + self.assertEqual( + _gunzip(r.response[0]), # pylint: disable=unsubscriptable-object + fall_of_hyperion_canto1_stanza1.encode("utf-8"), + ) + + e2 = wtest.EnvironBuilder( + headers={"Accept-Encoding": "gzip"} + ).get_environ() + gzip_encoding = wrappers.Request(e2) + + r = http_util.Respond( + gzip_encoding, fall_of_hyperion_canto1_stanza1, "text/plain" + ) + self.assertEqual(r.headers.get("Content-Encoding"), "gzip") + self.assertEqual( + _gunzip(r.response[0]), # pylint: disable=unsubscriptable-object + fall_of_hyperion_canto1_stanza1.encode("utf-8"), + ) + + r = http_util.Respond( + any_encoding, fall_of_hyperion_canto1_stanza1, "image/png" + ) + self.assertEqual( + r.response, [fall_of_hyperion_canto1_stanza1.encode("utf-8")] + ) + + def testAcceptGzip_alreadyCompressed_sendsPrecompressedResponse(self): + gzip_text = _gzip(b"hello hello hello world") + e = wtest.EnvironBuilder( + headers={"Accept-Encoding": "gzip"} + ).get_environ() + q = 
wrappers.Request(e) + r = http_util.Respond( + q, gzip_text, "text/plain", content_encoding="gzip" + ) + self.assertEqual(r.response, [gzip_text]) # Still singly zipped + + def testPrecompressedResponse_noAcceptGzip_decompressesResponse(self): + orig_text = b"hello hello hello world" + gzip_text = _gzip(orig_text) + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond( + q, gzip_text, "text/plain", content_encoding="gzip" + ) + # Streaming gunzip produces file-wrapper application iterator as response, + # so rejoin it into the full response before comparison. + full_response = b"".join(r.response) + self.assertEqual(full_response, orig_text) + + def testPrecompressedResponse_streamingDecompression_catchesBadSize(self): + orig_text = b"hello hello hello world" + gzip_text = _gzip(orig_text) + # Corrupt the gzipped data's stored content size (last 4 bytes). + bad_text = gzip_text[:-4] + _bitflip(gzip_text[-4:]) + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond( + q, bad_text, "text/plain", content_encoding="gzip" + ) + # Streaming gunzip defers actual unzipping until response is used; once + # we iterate over the whole file-wrapper application iterator, the + # underlying GzipFile should be closed, and throw the size check error. 
+ with six.assertRaisesRegex(self, IOError, "Incorrect length"): + _ = list(r.response) + + def testJson_getsAutoSerialized(self): + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond(q, [1, 2, 3], "application/json") + self.assertEqual(r.response, [b"[1, 2, 3]"]) + + def testExpires_setsCruiseControl(self): + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond(q, "hello world", "text/html", expires=60) + self.assertEqual(r.headers.get("Cache-Control"), "private, max-age=60") + + def testCsp(self): + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond( + q, "hello", "text/html", csp_scripts_sha256s=["abcdefghi"] + ) + expected_csp = ( + "default-src 'self';font-src 'self';frame-ancestors *;" + "frame-src 'self';img-src 'self' data: blob:;object-src 'none';" + "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';" + "script-src 'self' 'unsafe-eval' 'sha256-abcdefghi'" + ) + self.assertEqual(r.headers.get("Content-Security-Policy"), expected_csp) + + @mock.patch.object(http_util, "_CSP_SCRIPT_SELF", False) + def testCsp_noHash(self): + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond( + q, "hello", "text/html", csp_scripts_sha256s=None + ) + expected_csp = ( + "default-src 'self';font-src 'self';frame-ancestors *;" + "frame-src 'self';img-src 'self' data: blob:;object-src 'none';" + "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';" + "script-src 'unsafe-eval'" + ) + self.assertEqual(r.headers.get("Content-Security-Policy"), expected_csp) + + @mock.patch.object(http_util, "_CSP_SCRIPT_SELF", False) + @mock.patch.object(http_util, "_CSP_SCRIPT_UNSAFE_EVAL", False) + def testCsp_noHash_noUnsafeEval(self): + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond( + q, "hello", "text/html", csp_scripts_sha256s=None + ) + expected_csp = ( + "default-src 'self';font-src 'self';frame-ancestors 
*;" + "frame-src 'self';img-src 'self' data: blob:;object-src 'none';" + "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';" + "script-src 'none'" + ) + self.assertEqual(r.headers.get("Content-Security-Policy"), expected_csp) + + @mock.patch.object(http_util, "_CSP_SCRIPT_SELF", True) + @mock.patch.object(http_util, "_CSP_SCRIPT_UNSAFE_EVAL", False) + def testCsp_onlySelf(self): + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond( + q, "hello", "text/html", csp_scripts_sha256s=None + ) + expected_csp = ( + "default-src 'self';font-src 'self';frame-ancestors *;" + "frame-src 'self';img-src 'self' data: blob:;object-src 'none';" + "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';" + "script-src 'self'" + ) + self.assertEqual(r.headers.get("Content-Security-Policy"), expected_csp) + + @mock.patch.object(http_util, "_CSP_SCRIPT_UNSAFE_EVAL", False) + def testCsp_disableUnsafeEval(self): + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond( + q, "hello", "text/html", csp_scripts_sha256s=["abcdefghi"] + ) + expected_csp = ( + "default-src 'self';font-src 'self';frame-ancestors *;" + "frame-src 'self';img-src 'self' data: blob:;object-src 'none';" + "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';" + "script-src 'self' 'sha256-abcdefghi'" + ) + self.assertEqual(r.headers.get("Content-Security-Policy"), expected_csp) + + @mock.patch.object( + http_util, "_CSP_IMG_DOMAINS_WHITELIST", ["https://example.com"] ) - self.assertEqual(r.headers.get('Content-Security-Policy'), expected_csp) - - @mock.patch.object(http_util, '_CSP_SCRIPT_SELF', True) - @mock.patch.object(http_util, '_CSP_SCRIPT_UNSAFE_EVAL', False) - def testCsp_onlySelf(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=None) - expected_csp = ( - "default-src 'self';font-src 'self';frame-ancestors *;" - "frame-src 
'self';img-src 'self' data: blob:;object-src 'none';" - "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';" - "script-src 'self'" + @mock.patch.object( + http_util, + "_CSP_SCRIPT_DOMAINS_WHITELIST", + ["https://tensorflow.org/tensorboard"], ) - self.assertEqual(r.headers.get('Content-Security-Policy'), expected_csp) - - @mock.patch.object(http_util, '_CSP_SCRIPT_UNSAFE_EVAL', False) - def testCsp_disableUnsafeEval(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond( - q, 'hello', 'text/html', csp_scripts_sha256s=['abcdefghi']) - expected_csp = ( - "default-src 'self';font-src 'self';frame-ancestors *;" - "frame-src 'self';img-src 'self' data: blob:;object-src 'none';" - "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';" - "script-src 'self' 'sha256-abcdefghi'" + @mock.patch.object( + http_util, "_CSP_STYLE_DOMAINS_WHITELIST", ["https://googol.com"] ) - self.assertEqual(r.headers.get('Content-Security-Policy'), expected_csp) - - @mock.patch.object(http_util, '_CSP_IMG_DOMAINS_WHITELIST', ['https://example.com']) - @mock.patch.object(http_util, '_CSP_SCRIPT_DOMAINS_WHITELIST', - ['https://tensorflow.org/tensorboard']) - @mock.patch.object(http_util, '_CSP_STYLE_DOMAINS_WHITELIST', ['https://googol.com']) - @mock.patch.object(http_util, '_CSP_FRAME_DOMAINS_WHITELIST', ['https://myframe.com']) - def testCsp_globalDomainWhiteList(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=['abcd']) - expected_csp = ( - "default-src 'self';font-src 'self';frame-ancestors *;" - "frame-src 'self' https://myframe.com;" - "img-src 'self' data: blob: https://example.com;" - "object-src 'none';style-src 'self' https://www.gstatic.com data: " - "'unsafe-inline' https://googol.com;script-src " - "https://tensorflow.org/tensorboard 'self' 'unsafe-eval' 'sha256-abcd'" + @mock.patch.object( + http_util, "_CSP_FRAME_DOMAINS_WHITELIST", 
["https://myframe.com"] ) - self.assertEqual(r.headers.get('Content-Security-Policy'), expected_csp) - - def testCsp_badGlobalDomainWhiteList(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - configs = [ - '_CSP_SCRIPT_DOMAINS_WHITELIST', - '_CSP_IMG_DOMAINS_WHITELIST', - '_CSP_STYLE_DOMAINS_WHITELIST', - '_CSP_FONT_DOMAINS_WHITELIST', - '_CSP_FRAME_DOMAINS_WHITELIST', - ] - - for config in configs: - with mock.patch.object(http_util, config, ['http://tensorflow.org']): - with self.assertRaisesRegex( - ValueError, '^Expected all whitelist to be a https URL'): - http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=['abcd']) - - # Cannot grant more trust to a script from a remote source. - with mock.patch.object(http_util, config, - ["'strict-dynamic' 'unsafe-eval' https://tensorflow.org/"]): - with self.assertRaisesRegex( - ValueError, '^Expected all whitelist to be a https URL'): - http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=['abcd']) - - # Attempt to terminate the script-src to specify a new one that allows ALL! - with mock.patch.object(http_util, config, ['https://tensorflow.org;script-src *']): - with self.assertRaisesRegex( - ValueError, '^Expected whitelist domain to not contain ";"'): - http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=['abcd']) - - # Attempt to use whitespace, delimit character, to specify a new one. 
- with mock.patch.object(http_util, config, ['https://tensorflow.org *']): - with self.assertRaisesRegex( - ValueError, '^Expected whitelist domain to not contain a whitespace'): - http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=['abcd']) + def testCsp_globalDomainWhiteList(self): + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + r = http_util.Respond( + q, "hello", "text/html", csp_scripts_sha256s=["abcd"] + ) + expected_csp = ( + "default-src 'self';font-src 'self';frame-ancestors *;" + "frame-src 'self' https://myframe.com;" + "img-src 'self' data: blob: https://example.com;" + "object-src 'none';style-src 'self' https://www.gstatic.com data: " + "'unsafe-inline' https://googol.com;script-src " + "https://tensorflow.org/tensorboard 'self' 'unsafe-eval' 'sha256-abcd'" + ) + self.assertEqual(r.headers.get("Content-Security-Policy"), expected_csp) + + def testCsp_badGlobalDomainWhiteList(self): + q = wrappers.Request(wtest.EnvironBuilder().get_environ()) + configs = [ + "_CSP_SCRIPT_DOMAINS_WHITELIST", + "_CSP_IMG_DOMAINS_WHITELIST", + "_CSP_STYLE_DOMAINS_WHITELIST", + "_CSP_FONT_DOMAINS_WHITELIST", + "_CSP_FRAME_DOMAINS_WHITELIST", + ] + + for config in configs: + with mock.patch.object( + http_util, config, ["http://tensorflow.org"] + ): + with self.assertRaisesRegex( + ValueError, "^Expected all whitelist to be a https URL" + ): + http_util.Respond( + q, + "hello", + "text/html", + csp_scripts_sha256s=["abcd"], + ) + + # Cannot grant more trust to a script from a remote source. + with mock.patch.object( + http_util, + config, + ["'strict-dynamic' 'unsafe-eval' https://tensorflow.org/"], + ): + with self.assertRaisesRegex( + ValueError, "^Expected all whitelist to be a https URL" + ): + http_util.Respond( + q, + "hello", + "text/html", + csp_scripts_sha256s=["abcd"], + ) + + # Attempt to terminate the script-src to specify a new one that allows ALL! 
+ with mock.patch.object( + http_util, config, ["https://tensorflow.org;script-src *"] + ): + with self.assertRaisesRegex( + ValueError, '^Expected whitelist domain to not contain ";"' + ): + http_util.Respond( + q, + "hello", + "text/html", + csp_scripts_sha256s=["abcd"], + ) + + # Attempt to use whitespace, delimit character, to specify a new one. + with mock.patch.object( + http_util, config, ["https://tensorflow.org *"] + ): + with self.assertRaisesRegex( + ValueError, + "^Expected whitelist domain to not contain a whitespace", + ): + http_util.Respond( + q, + "hello", + "text/html", + csp_scripts_sha256s=["abcd"], + ) def _gzip(bs): - out = six.BytesIO() - with gzip.GzipFile(fileobj=out, mode='wb') as f: - f.write(bs) - return out.getvalue() + out = six.BytesIO() + with gzip.GzipFile(fileobj=out, mode="wb") as f: + f.write(bs) + return out.getvalue() def _gunzip(bs): - with gzip.GzipFile(fileobj=six.BytesIO(bs), mode='rb') as f: - return f.read() + with gzip.GzipFile(fileobj=six.BytesIO(bs), mode="rb") as f: + return f.read() + def _bitflip(bs): - # Return bytestring with all its bits flipped. - return b''.join(struct.pack('B', 0xFF ^ struct.unpack_from('B', bs, i)[0]) - for i in range(len(bs))) + # Return bytestring with all its bits flipped. + return b"".join( + struct.pack("B", 0xFF ^ struct.unpack_from("B", bs, i)[0]) + for i in range(len(bs)) + ) + -if __name__ == '__main__': - tb_test.main() +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/backend/json_util.py b/tensorboard/backend/json_util.py index 65b2c01c7c..5ca23afc48 100644 --- a/tensorboard/backend/json_util.py +++ b/tensorboard/backend/json_util.py @@ -17,10 +17,10 @@ Python provides no way to override how json.dumps serializes Infinity/-Infinity/NaN; if allow_nan is true, it encodes them as -Infinity/-Infinity/NaN, in violation of the JSON spec and in violation of what -JSON.parse accepts. 
If it's false, it throws a ValueError, Neither subclassing -JSONEncoder nor passing a function in the |default| keyword argument overrides -this. +Infinity/-Infinity/NaN, in violation of the JSON spec and in violation +of what JSON.parse accepts. If it's false, it throws a ValueError, +Neither subclassing JSONEncoder nor passing a function in the |default| +keyword argument overrides this. """ from __future__ import absolute_import @@ -33,46 +33,45 @@ from tensorboard.compat import tf -_INFINITY = float('inf') -_NEGATIVE_INFINITY = float('-inf') +_INFINITY = float("inf") +_NEGATIVE_INFINITY = float("-inf") -def Cleanse(obj, encoding='utf-8'): - """Makes Python object appropriate for JSON serialization. +def Cleanse(obj, encoding="utf-8"): + """Makes Python object appropriate for JSON serialization. - - Replaces instances of Infinity/-Infinity/NaN with strings. - - Turns byte strings into unicode strings. - - Turns sets into sorted lists. - - Turns tuples into lists. + - Replaces instances of Infinity/-Infinity/NaN with strings. + - Turns byte strings into unicode strings. + - Turns sets into sorted lists. + - Turns tuples into lists. - Args: - obj: Python data structure. - encoding: Charset used to decode byte strings. + Args: + obj: Python data structure. + encoding: Charset used to decode byte strings. - Returns: - Unicode JSON data structure. - """ - if isinstance(obj, int): - return obj - elif isinstance(obj, float): - if obj == _INFINITY: - return 'Infinity' - elif obj == _NEGATIVE_INFINITY: - return '-Infinity' - elif math.isnan(obj): - return 'NaN' + Returns: + Unicode JSON data structure. 
+ """ + if isinstance(obj, int): + return obj + elif isinstance(obj, float): + if obj == _INFINITY: + return "Infinity" + elif obj == _NEGATIVE_INFINITY: + return "-Infinity" + elif math.isnan(obj): + return "NaN" + else: + return obj + elif isinstance(obj, bytes): + return tf.compat.as_text(obj, encoding) + elif isinstance(obj, (list, tuple)): + return [Cleanse(i, encoding) for i in obj] + elif isinstance(obj, set): + return [Cleanse(i, encoding) for i in sorted(obj)] + elif isinstance(obj, dict): + return collections.OrderedDict( + (Cleanse(k, encoding), Cleanse(v, encoding)) for k, v in obj.items() + ) else: - return obj - elif isinstance(obj, bytes): - return tf.compat.as_text(obj, encoding) - elif isinstance(obj, (list, tuple)): - return [Cleanse(i, encoding) for i in obj] - elif isinstance(obj, set): - return [Cleanse(i, encoding) for i in sorted(obj)] - elif isinstance(obj, dict): - return collections.OrderedDict( - (Cleanse(k, encoding), Cleanse(v, encoding)) - for k, v in obj.items() - ) - else: - return obj + return obj diff --git a/tensorboard/backend/json_util_test.py b/tensorboard/backend/json_util_test.py index a9b3424dbc..e0d030be52 100644 --- a/tensorboard/backend/json_util_test.py +++ b/tensorboard/backend/json_util_test.py @@ -23,54 +23,55 @@ from tensorboard import test as tb_test from tensorboard.backend import json_util -_INFINITY = float('inf') +_INFINITY = float("inf") class CleanseTest(tb_test.TestCase): - - def _assertWrapsAs(self, to_wrap, expected): - """Asserts that |to_wrap| becomes |expected| when wrapped.""" - actual = json_util.Cleanse(to_wrap) - for a, e in zip(actual, expected): - self.assertEqual(e, a) - - def testWrapsPrimitives(self): - self._assertWrapsAs(_INFINITY, 'Infinity') - self._assertWrapsAs(-_INFINITY, '-Infinity') - self._assertWrapsAs(float('nan'), 'NaN') - - def testWrapsObjectValues(self): - self._assertWrapsAs({'x': _INFINITY}, {'x': 'Infinity'}) - - def testWrapsObjectKeys(self): - self._assertWrapsAs({_INFINITY: 
'foo'}, {'Infinity': 'foo'}) - - def testWrapsInListsAndTuples(self): - self._assertWrapsAs([_INFINITY], ['Infinity']) - # map() returns a list even if the argument is a tuple. - self._assertWrapsAs((_INFINITY,), ['Infinity',]) - - def testWrapsRecursively(self): - self._assertWrapsAs({'x': [_INFINITY]}, {'x': ['Infinity']}) - - def testOrderedDict_preservesOrder(self): - # dict iteration order is not specified prior to Python 3.7, and is - # observably different from insertion order in CPython 2.7. - od = collections.OrderedDict() - for c in string.ascii_lowercase: - od[c] = c - self.assertEqual(len(od), 26, od) - self.assertEqual(list(od), list(json_util.Cleanse(od))) - - def testTuple_turnsIntoList(self): - self.assertEqual(json_util.Cleanse(('a', 'b')), ['a', 'b']) - - def testSet_turnsIntoSortedList(self): - self.assertEqual(json_util.Cleanse(set(['b', 'a'])), ['a', 'b']) - - def testByteString_turnsIntoUnicodeString(self): - self.assertEqual(json_util.Cleanse(b'\xc2\xa3'), u'\u00a3') # is # sterling - - -if __name__ == '__main__': - tb_test.main() + def _assertWrapsAs(self, to_wrap, expected): + """Asserts that |to_wrap| becomes |expected| when wrapped.""" + actual = json_util.Cleanse(to_wrap) + for a, e in zip(actual, expected): + self.assertEqual(e, a) + + def testWrapsPrimitives(self): + self._assertWrapsAs(_INFINITY, "Infinity") + self._assertWrapsAs(-_INFINITY, "-Infinity") + self._assertWrapsAs(float("nan"), "NaN") + + def testWrapsObjectValues(self): + self._assertWrapsAs({"x": _INFINITY}, {"x": "Infinity"}) + + def testWrapsObjectKeys(self): + self._assertWrapsAs({_INFINITY: "foo"}, {"Infinity": "foo"}) + + def testWrapsInListsAndTuples(self): + self._assertWrapsAs([_INFINITY], ["Infinity"]) + # map() returns a list even if the argument is a tuple. 
+ self._assertWrapsAs((_INFINITY,), ["Infinity",]) + + def testWrapsRecursively(self): + self._assertWrapsAs({"x": [_INFINITY]}, {"x": ["Infinity"]}) + + def testOrderedDict_preservesOrder(self): + # dict iteration order is not specified prior to Python 3.7, and is + # observably different from insertion order in CPython 2.7. + od = collections.OrderedDict() + for c in string.ascii_lowercase: + od[c] = c + self.assertEqual(len(od), 26, od) + self.assertEqual(list(od), list(json_util.Cleanse(od))) + + def testTuple_turnsIntoList(self): + self.assertEqual(json_util.Cleanse(("a", "b")), ["a", "b"]) + + def testSet_turnsIntoSortedList(self): + self.assertEqual(json_util.Cleanse(set(["b", "a"])), ["a", "b"]) + + def testByteString_turnsIntoUnicodeString(self): + self.assertEqual( + json_util.Cleanse(b"\xc2\xa3"), u"\u00a3" + ) # is # sterling + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/backend/path_prefix.py b/tensorboard/backend/path_prefix.py index 2192deb9f7..13ffeb61d0 100644 --- a/tensorboard/backend/path_prefix.py +++ b/tensorboard/backend/path_prefix.py @@ -27,39 +27,45 @@ class PathPrefixMiddleware(object): - """WSGI middleware for path prefixes. + """WSGI middleware for path prefixes. - All requests to this middleware must begin with the specified path - prefix (otherwise, a 404 will be returned immediately). Requests will - be forwarded to the underlying application with the path prefix - stripped and appended to `SCRIPT_NAME` (see the WSGI spec, PEP 3333, - for details). - """ + All requests to this middleware must begin with the specified path + prefix (otherwise, a 404 will be returned immediately). Requests + will be forwarded to the underlying application with the path prefix + stripped and appended to `SCRIPT_NAME` (see the WSGI spec, PEP 3333, + for details). + """ - def __init__(self, application, path_prefix): - """Initializes this middleware. 
+ def __init__(self, application, path_prefix): + """Initializes this middleware. - Args: - application: The WSGI application to wrap (see PEP 3333). - path_prefix: A string path prefix to be stripped from incoming - requests. If empty, this middleware is a no-op. If non-empty, - the path prefix must start with a slash and not end with one - (e.g., "/tensorboard"). - """ - if path_prefix.endswith("/"): - raise ValueError("Path prefix must not end with slash: %r" % path_prefix) - if path_prefix and not path_prefix.startswith("/"): - raise ValueError( - "Non-empty path prefix must start with slash: %r" % path_prefix - ) - self._application = application - self._path_prefix = path_prefix - self._strict_prefix = self._path_prefix + "/" + Args: + application: The WSGI application to wrap (see PEP 3333). + path_prefix: A string path prefix to be stripped from incoming + requests. If empty, this middleware is a no-op. If non-empty, + the path prefix must start with a slash and not end with one + (e.g., "/tensorboard"). 
+ """ + if path_prefix.endswith("/"): + raise ValueError( + "Path prefix must not end with slash: %r" % path_prefix + ) + if path_prefix and not path_prefix.startswith("/"): + raise ValueError( + "Non-empty path prefix must start with slash: %r" % path_prefix + ) + self._application = application + self._path_prefix = path_prefix + self._strict_prefix = self._path_prefix + "/" - def __call__(self, environ, start_response): - path = environ.get("PATH_INFO", "") - if path != self._path_prefix and not path.startswith(self._strict_prefix): - raise errors.NotFoundError() - environ["PATH_INFO"] = path[len(self._path_prefix):] - environ["SCRIPT_NAME"] = environ.get("SCRIPT_NAME", "") + self._path_prefix - return self._application(environ, start_response) + def __call__(self, environ, start_response): + path = environ.get("PATH_INFO", "") + if path != self._path_prefix and not path.startswith( + self._strict_prefix + ): + raise errors.NotFoundError() + environ["PATH_INFO"] = path[len(self._path_prefix) :] + environ["SCRIPT_NAME"] = ( + environ.get("SCRIPT_NAME", "") + self._path_prefix + ) + return self._application(environ, start_response) diff --git a/tensorboard/backend/path_prefix_test.py b/tensorboard/backend/path_prefix_test.py index 80dda313c1..6b62fbf76f 100644 --- a/tensorboard/backend/path_prefix_test.py +++ b/tensorboard/backend/path_prefix_test.py @@ -29,85 +29,85 @@ class PathPrefixMiddlewareTest(tb_test.TestCase): - """Tests for `PathPrefixMiddleware`.""" - - def _echo_app(self, environ, start_response): - # https://www.python.org/dev/peps/pep-0333/#environ-variables - data = { - "path": environ.get("PATH_INFO", ""), - "script": environ.get("SCRIPT_NAME", ""), - } - body = json.dumps(data, sort_keys=True) - start_response("200 OK", [("Content-Type", "application/json")]) - return [body] - - def _assert_ok(self, response, path, script): - self.assertEqual(response.status_code, 200) - actual = json.loads(response.get_data()) - expected = dict(path=path, 
script=script) - self.assertEqual(actual, expected) - - def test_bad_path_prefix_without_leading_slash(self): - with self.assertRaises(ValueError) as cm: - path_prefix.PathPrefixMiddleware(self._echo_app, "hmm") - msg = str(cm.exception) - self.assertIn("must start with slash", msg) - self.assertIn(repr("hmm"), msg) - - def test_bad_path_prefix_with_trailing_slash(self): - with self.assertRaises(ValueError) as cm: - path_prefix.PathPrefixMiddleware(self._echo_app, "/hmm/") - msg = str(cm.exception) - self.assertIn("must not end with slash", msg) - self.assertIn(repr("/hmm/"), msg) - - def test_empty_path_prefix(self): - app = path_prefix.PathPrefixMiddleware(self._echo_app, "") - server = werkzeug_test.Client(app, werkzeug.BaseResponse) - - with self.subTest("at empty"): - self._assert_ok(server.get(""), path="", script="") - - with self.subTest("at root"): - self._assert_ok(server.get("/"), path="/", script="") - - with self.subTest("at subpath"): - response = server.get("/foo/bar") - self._assert_ok(server.get("/foo/bar"), path="/foo/bar", script="") - - def test_nonempty_path_prefix(self): - app = path_prefix.PathPrefixMiddleware(self._echo_app, "/pfx") - server = werkzeug_test.Client(app, werkzeug.BaseResponse) - - with self.subTest("at root"): - response = server.get("/pfx") - self._assert_ok(response, path="", script="/pfx") - - with self.subTest("at root with slash"): - response = server.get("/pfx/") - self._assert_ok(response, path="/", script="/pfx") - - with self.subTest("at subpath"): - response = server.get("/pfx/foo/bar") - self._assert_ok(response, path="/foo/bar", script="/pfx") - - with self.subTest("at non-path-component extension"): - with self.assertRaises(errors.NotFoundError): - server.get("/pfxz") - - with self.subTest("above path prefix"): - with self.assertRaises(errors.NotFoundError): - server.get("/hmm") - - def test_composition(self): - app = self._echo_app - app = path_prefix.PathPrefixMiddleware(app, "/bar") - app = 
path_prefix.PathPrefixMiddleware(app, "/foo") - server = werkzeug_test.Client(app, werkzeug.BaseResponse) - - response = server.get("/foo/bar/baz/quux") - self._assert_ok(response, path="/baz/quux", script="/foo/bar") + """Tests for `PathPrefixMiddleware`.""" + + def _echo_app(self, environ, start_response): + # https://www.python.org/dev/peps/pep-0333/#environ-variables + data = { + "path": environ.get("PATH_INFO", ""), + "script": environ.get("SCRIPT_NAME", ""), + } + body = json.dumps(data, sort_keys=True) + start_response("200 OK", [("Content-Type", "application/json")]) + return [body] + + def _assert_ok(self, response, path, script): + self.assertEqual(response.status_code, 200) + actual = json.loads(response.get_data()) + expected = dict(path=path, script=script) + self.assertEqual(actual, expected) + + def test_bad_path_prefix_without_leading_slash(self): + with self.assertRaises(ValueError) as cm: + path_prefix.PathPrefixMiddleware(self._echo_app, "hmm") + msg = str(cm.exception) + self.assertIn("must start with slash", msg) + self.assertIn(repr("hmm"), msg) + + def test_bad_path_prefix_with_trailing_slash(self): + with self.assertRaises(ValueError) as cm: + path_prefix.PathPrefixMiddleware(self._echo_app, "/hmm/") + msg = str(cm.exception) + self.assertIn("must not end with slash", msg) + self.assertIn(repr("/hmm/"), msg) + + def test_empty_path_prefix(self): + app = path_prefix.PathPrefixMiddleware(self._echo_app, "") + server = werkzeug_test.Client(app, werkzeug.BaseResponse) + + with self.subTest("at empty"): + self._assert_ok(server.get(""), path="", script="") + + with self.subTest("at root"): + self._assert_ok(server.get("/"), path="/", script="") + + with self.subTest("at subpath"): + response = server.get("/foo/bar") + self._assert_ok(server.get("/foo/bar"), path="/foo/bar", script="") + + def test_nonempty_path_prefix(self): + app = path_prefix.PathPrefixMiddleware(self._echo_app, "/pfx") + server = werkzeug_test.Client(app, 
werkzeug.BaseResponse) + + with self.subTest("at root"): + response = server.get("/pfx") + self._assert_ok(response, path="", script="/pfx") + + with self.subTest("at root with slash"): + response = server.get("/pfx/") + self._assert_ok(response, path="/", script="/pfx") + + with self.subTest("at subpath"): + response = server.get("/pfx/foo/bar") + self._assert_ok(response, path="/foo/bar", script="/pfx") + + with self.subTest("at non-path-component extension"): + with self.assertRaises(errors.NotFoundError): + server.get("/pfxz") + + with self.subTest("above path prefix"): + with self.assertRaises(errors.NotFoundError): + server.get("/hmm") + + def test_composition(self): + app = self._echo_app + app = path_prefix.PathPrefixMiddleware(app, "/bar") + app = path_prefix.PathPrefixMiddleware(app, "/foo") + server = werkzeug_test.Client(app, werkzeug.BaseResponse) + + response = server.get("/foo/bar/baz/quux") + self._assert_ok(response, path="/baz/quux", script="/foo/bar") if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/backend/process_graph.py b/tensorboard/backend/process_graph.py index adf7f0e1b3..71a6b1c8e7 100644 --- a/tensorboard/backend/process_graph.py +++ b/tensorboard/backend/process_graph.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Graph post-processing logic. Used by both TensorBoard and mldash.""" +"""Graph post-processing logic. + +Used by both TensorBoard and mldash. +""" from __future__ import absolute_import @@ -22,47 +25,53 @@ from tensorboard.compat import tf -def prepare_graph_for_ui(graph, limit_attr_size=1024, - large_attrs_key='_too_large_attrs'): - """Prepares (modifies in-place) the graph to be served to the front-end. 
+def prepare_graph_for_ui( + graph, limit_attr_size=1024, large_attrs_key="_too_large_attrs" +): + """Prepares (modifies in-place) the graph to be served to the front-end. - For now, it supports filtering out attributes that are - too large to be shown in the graph UI. + For now, it supports filtering out attributes that are + too large to be shown in the graph UI. - Args: - graph: The GraphDef proto message. - limit_attr_size: Maximum allowed size in bytes, before the attribute - is considered large. Default is 1024 (1KB). Must be > 0 or None. - If None, there will be no filtering. - large_attrs_key: The attribute key that will be used for storing attributes - that are too large. Default is '_too_large_attrs'. Must be != None if - `limit_attr_size` is != None. + Args: + graph: The GraphDef proto message. + limit_attr_size: Maximum allowed size in bytes, before the attribute + is considered large. Default is 1024 (1KB). Must be > 0 or None. + If None, there will be no filtering. + large_attrs_key: The attribute key that will be used for storing attributes + that are too large. Default is '_too_large_attrs'. Must be != None if + `limit_attr_size` is != None. - Raises: - ValueError: If `large_attrs_key is None` while `limit_attr_size != None`. - ValueError: If `limit_attr_size` is defined, but <= 0. - """ - # Check input for validity. - if limit_attr_size is not None: - if large_attrs_key is None: - raise ValueError('large_attrs_key must be != None when limit_attr_size' - '!= None.') + Raises: + ValueError: If `large_attrs_key is None` while `limit_attr_size != None`. + ValueError: If `limit_attr_size` is defined, but <= 0. + """ + # Check input for validity. + if limit_attr_size is not None: + if large_attrs_key is None: + raise ValueError( + "large_attrs_key must be != None when limit_attr_size" + "!= None." 
+ ) - if limit_attr_size <= 0: - raise ValueError('limit_attr_size must be > 0, but is %d' % - limit_attr_size) + if limit_attr_size <= 0: + raise ValueError( + "limit_attr_size must be > 0, but is %d" % limit_attr_size + ) - # Filter only if a limit size is defined. - if limit_attr_size is not None: - for node in graph.node: - # Go through all the attributes and filter out ones bigger than the - # limit. - keys = list(node.attr.keys()) - for key in keys: - size = node.attr[key].ByteSize() - if size > limit_attr_size or size < 0: - del node.attr[key] - # Add the attribute key to the list of "too large" attributes. - # This is used in the info card in the graph UI to show the user - # that some attributes are too large to be shown. - node.attr[large_attrs_key].list.s.append(tf.compat.as_bytes(key)) + # Filter only if a limit size is defined. + if limit_attr_size is not None: + for node in graph.node: + # Go through all the attributes and filter out ones bigger than the + # limit. + keys = list(node.attr.keys()) + for key in keys: + size = node.attr[key].ByteSize() + if size > limit_attr_size or size < 0: + del node.attr[key] + # Add the attribute key to the list of "too large" attributes. + # This is used in the info card in the graph UI to show the user + # that some attributes are too large to be shown. + node.attr[large_attrs_key].list.s.append( + tf.compat.as_bytes(key) + ) diff --git a/tensorboard/backend/security_validator.py b/tensorboard/backend/security_validator.py index 388a74d9c4..ec0353a601 100644 --- a/tensorboard/backend/security_validator.py +++ b/tensorboard/backend/security_validator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Validates responses and their security features""" +"""Validates responses and their security features.""" from __future__ import absolute_import from __future__ import division @@ -48,145 +48,147 @@ def _maybe_raise_value_error(error_msg): - logger.warn("In 3.0, this warning will become an error:\n%s" % error_msg) - # TODO(3.x): raise a value error. + logger.warn("In 3.0, this warning will become an error:\n%s" % error_msg) + # TODO(3.x): raise a value error. class SecurityValidatorMiddleware(object): - """WSGI middleware validating security on response. + """WSGI middleware validating security on response. - It validates: - - responses have Content-Type - - responses have X-Content-Type-Options: nosniff - - text/html responses have CSP header. It also validates whether the CSP - headers pass basic requirement. e.g., default-src should be present, cannot - use "*" directive, and others. For more complete list, please refer to - _validate_csp_policies. + It validates: + - responses have Content-Type + - responses have X-Content-Type-Options: nosniff + - text/html responses have CSP header. It also validates whether the CSP + headers pass basic requirement. e.g., default-src should be present, cannot + use "*" directive, and others. For more complete list, please refer to + _validate_csp_policies. - Instances of this class are WSGI applications (see PEP 3333). - """ - - def __init__(self, application): - """Initializes an `SecurityValidatorMiddleware`. - - Args: - application: The WSGI application to wrap (see PEP 3333). + Instances of this class are WSGI applications (see PEP 3333). 
""" - self._application = application - - def __call__(self, environ, start_response): - - def start_response_proxy(status, headers, exc_info=None): - self._validate_headers(headers) - return start_response(status, headers, exc_info) - - return self._application(environ, start_response_proxy) - - def _validate_headers(self, headers_list): - headers = Headers(headers_list) - self._validate_content_type(headers) - self._validate_x_content_type_options(headers) - self._validate_csp_headers(headers) - - def _validate_content_type(self, headers): - if headers.get("Content-Type"): - return - - _maybe_raise_value_error("Content-Type is required on a Response") - - def _validate_x_content_type_options(self, headers): - option = headers.get("X-Content-Type-Options") - if option == "nosniff": - return - - _maybe_raise_value_error( - 'X-Content-Type-Options is required to be "nosniff"' - ) - - def _validate_csp_headers(self, headers): - mime_type, _ = http.parse_options_header(headers.get("Content-Type")) - if mime_type != _HTML_MIME_TYPE: - return - - csp_texts = headers.get_all("Content-Security-Policy") - policies = [] - - for csp_text in csp_texts: - policies += self._parse_serialized_csp(csp_text) - - self._validate_csp_policies(policies) - - def _validate_csp_policies(self, policies): - has_default_src = False - violations = [] - - for directive in policies: - name = directive.name - for value in directive.value: - has_default_src = has_default_src or name == _CSP_DEFAULT_SRC - - if value in _CSP_IGNORE.get(name, []): - # There are cases where certain directives are legitimate. - continue - - # TensorBoard follows principle of least privilege. However, to make it - # easier to conform to the security policy for plugin authors, - # TensorBoard trusts request and resources originating its server. Also, - # it can selectively trust domains as long as they use https protocol. - # Lastly, it can allow 'none' directive. 
- # TODO(stephanwlee): allow configuration for whitelist of domains for - # stricter enforcement. - # TODO(stephanwlee): deprecate the sha-based whitelisting. - if ( - value == "'self'" or value == "'none'" - or value.startswith("https:") - or value.startswith("'sha256-") - ): - continue - - msg = "Illegal Content-Security-Policy for {name}: {value}".format( - name=name, value=value - ) - violations.append(msg) - if not has_default_src: - violations.append("Requires default-src for Content-Security-Policy") + def __init__(self, application): + """Initializes an `SecurityValidatorMiddleware`. - if violations: - _maybe_raise_value_error("\n".join(violations)) + Args: + application: The WSGI application to wrap (see PEP 3333). + """ + self._application = application - def _parse_serialized_csp(self, csp_text): - # See https://www.w3.org/TR/CSP/#parse-serialized-policy. - # Below Steps are based on the algorithm stated in above spec. - # Deviations: - # - it does not warn and ignore duplicative directive (Step 2.5) + def __call__(self, environ, start_response): + def start_response_proxy(status, headers, exc_info=None): + self._validate_headers(headers) + return start_response(status, headers, exc_info) - # Step 2 - csp_srcs = csp_text.split(";") + return self._application(environ, start_response_proxy) - policy = [] - for token in csp_srcs: - # Step 2.1 - token = token.strip() + def _validate_headers(self, headers_list): + headers = Headers(headers_list) + self._validate_content_type(headers) + self._validate_x_content_type_options(headers) + self._validate_csp_headers(headers) - if not token: - # Step 2.2 - continue + def _validate_content_type(self, headers): + if headers.get("Content-Type"): + return - # Step 2.3 - token_frag = token.split(None, 1) - name = token_frag[0] + _maybe_raise_value_error("Content-Type is required on a Response") - values = token_frag[1] if len(token_frag) == 2 else "" + def _validate_x_content_type_options(self, headers): + option = 
headers.get("X-Content-Type-Options") + if option == "nosniff": + return - # Step 2.4 - name = name.lower() - - # Step 2.6 - value = values.split() - # Step 2.7 - directive = Directive(name=name, value=value) - # Step 2.8 - policy.append(directive) + _maybe_raise_value_error( + 'X-Content-Type-Options is required to be "nosniff"' + ) - return policy + def _validate_csp_headers(self, headers): + mime_type, _ = http.parse_options_header(headers.get("Content-Type")) + if mime_type != _HTML_MIME_TYPE: + return + + csp_texts = headers.get_all("Content-Security-Policy") + policies = [] + + for csp_text in csp_texts: + policies += self._parse_serialized_csp(csp_text) + + self._validate_csp_policies(policies) + + def _validate_csp_policies(self, policies): + has_default_src = False + violations = [] + + for directive in policies: + name = directive.name + for value in directive.value: + has_default_src = has_default_src or name == _CSP_DEFAULT_SRC + + if value in _CSP_IGNORE.get(name, []): + # There are cases where certain directives are legitimate. + continue + + # TensorBoard follows principle of least privilege. However, to make it + # easier to conform to the security policy for plugin authors, + # TensorBoard trusts request and resources originating its server. Also, + # it can selectively trust domains as long as they use https protocol. + # Lastly, it can allow 'none' directive. + # TODO(stephanwlee): allow configuration for whitelist of domains for + # stricter enforcement. + # TODO(stephanwlee): deprecate the sha-based whitelisting. 
+ if ( + value == "'self'" + or value == "'none'" + or value.startswith("https:") + or value.startswith("'sha256-") + ): + continue + + msg = "Illegal Content-Security-Policy for {name}: {value}".format( + name=name, value=value + ) + violations.append(msg) + + if not has_default_src: + violations.append( + "Requires default-src for Content-Security-Policy" + ) + + if violations: + _maybe_raise_value_error("\n".join(violations)) + + def _parse_serialized_csp(self, csp_text): + # See https://www.w3.org/TR/CSP/#parse-serialized-policy. + # Below Steps are based on the algorithm stated in above spec. + # Deviations: + # - it does not warn and ignore duplicative directive (Step 2.5) + + # Step 2 + csp_srcs = csp_text.split(";") + + policy = [] + for token in csp_srcs: + # Step 2.1 + token = token.strip() + + if not token: + # Step 2.2 + continue + + # Step 2.3 + token_frag = token.split(None, 1) + name = token_frag[0] + + values = token_frag[1] if len(token_frag) == 2 else "" + + # Step 2.4 + name = name.lower() + + # Step 2.6 + value = values.split() + # Step 2.7 + directive = Directive(name=name, value=value) + # Step 2.8 + policy.append(directive) + + return policy diff --git a/tensorboard/backend/security_validator_test.py b/tensorboard/backend/security_validator_test.py index 9a972c9685..a3cd343746 100644 --- a/tensorboard/backend/security_validator_test.py +++ b/tensorboard/backend/security_validator_test.py @@ -19,10 +19,10 @@ from __future__ import print_function try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import import werkzeug from werkzeug import test as werkzeug_test @@ -43,159 +43,154 @@ def create_headers( x_content_type_options="nosniff", content_security_policy="", ): - return Headers( - { - "Content-Type": content_type, - "X-Content-Type-Options": x_content_type_options, - 
"Content-Security-Policy": content_security_policy, - } - ) - - -class SecurityValidatorMiddlewareTest(tb_test.TestCase): - """Tests for `SecurityValidatorMiddleware`.""" - - def make_request_and_maybe_assert_warn( - self, - headers, - expected_warn_substr, - ): - - @werkzeug.Request.application - def _simple_app(req): - return werkzeug.Response("OK", headers=headers) - - app = security_validator.SecurityValidatorMiddleware(_simple_app) - server = werkzeug_test.Client(app, BaseResponse) - - with mock.patch.object(logger, "warn") as mock_warn: - server.get("") - - if expected_warn_substr is None: - mock_warn.assert_not_called() - else: - mock_warn.assert_called_with(_WARN_PREFIX + expected_warn_substr) - - def make_request_and_assert_no_warn( - self, - headers, - ): - self.make_request_and_maybe_assert_warn(headers, None) - - def test_validate_content_type(self): - self.make_request_and_assert_no_warn( - create_headers(content_type="application/json"), + return Headers( + { + "Content-Type": content_type, + "X-Content-Type-Options": x_content_type_options, + "Content-Security-Policy": content_security_policy, + } ) - self.make_request_and_maybe_assert_warn( - create_headers(content_type=""), - "Content-Type is required on a Response" - ) - def test_validate_x_content_type_options(self): - self.make_request_and_assert_no_warn( - create_headers(x_content_type_options="nosniff") - ) - - self.make_request_and_maybe_assert_warn( - create_headers(x_content_type_options=""), - 'X-Content-Type-Options is required to be "nosniff"', - ) - - def test_validate_csp_text_html(self): - self.make_request_and_assert_no_warn( - create_headers( - content_type="text/html; charset=UTF-8", - content_security_policy=( - "DEFAult-src 'self';script-src https://google.com;" - "style-src 'self' https://example; object-src " +class SecurityValidatorMiddlewareTest(tb_test.TestCase): + """Tests for `SecurityValidatorMiddleware`.""" + + def make_request_and_maybe_assert_warn( + self, headers, 
expected_warn_substr, + ): + @werkzeug.Request.application + def _simple_app(req): + return werkzeug.Response("OK", headers=headers) + + app = security_validator.SecurityValidatorMiddleware(_simple_app) + server = werkzeug_test.Client(app, BaseResponse) + + with mock.patch.object(logger, "warn") as mock_warn: + server.get("") + + if expected_warn_substr is None: + mock_warn.assert_not_called() + else: + mock_warn.assert_called_with(_WARN_PREFIX + expected_warn_substr) + + def make_request_and_assert_no_warn( + self, headers, + ): + self.make_request_and_maybe_assert_warn(headers, None) + + def test_validate_content_type(self): + self.make_request_and_assert_no_warn( + create_headers(content_type="application/json"), + ) + + self.make_request_and_maybe_assert_warn( + create_headers(content_type=""), + "Content-Type is required on a Response", + ) + + def test_validate_x_content_type_options(self): + self.make_request_and_assert_no_warn( + create_headers(x_content_type_options="nosniff") + ) + + self.make_request_and_maybe_assert_warn( + create_headers(x_content_type_options=""), + 'X-Content-Type-Options is required to be "nosniff"', + ) + + def test_validate_csp_text_html(self): + self.make_request_and_assert_no_warn( + create_headers( + content_type="text/html; charset=UTF-8", + content_security_policy=( + "DEFAult-src 'self';script-src https://google.com;" + "style-src 'self' https://example; object-src " + ), ), - ), - ) + ) - self.make_request_and_maybe_assert_warn( - create_headers( - content_type="text/html; charset=UTF-8", - content_security_policy="", - ), - "Requires default-src for Content-Security-Policy", - ) + self.make_request_and_maybe_assert_warn( + create_headers( + content_type="text/html; charset=UTF-8", + content_security_policy="", + ), + "Requires default-src for Content-Security-Policy", + ) - self.make_request_and_maybe_assert_warn( - create_headers( - content_type="text/html; charset=UTF-8", - content_security_policy="default-src *", - ), - 
"Illegal Content-Security-Policy for default-src: *", - ) + self.make_request_and_maybe_assert_warn( + create_headers( + content_type="text/html; charset=UTF-8", + content_security_policy="default-src *", + ), + "Illegal Content-Security-Policy for default-src: *", + ) - self.make_request_and_maybe_assert_warn( - create_headers( - content_type="text/html; charset=UTF-8", - content_security_policy="default-src 'self';script-src *", - ), - "Illegal Content-Security-Policy for script-src: *", - ) + self.make_request_and_maybe_assert_warn( + create_headers( + content_type="text/html; charset=UTF-8", + content_security_policy="default-src 'self';script-src *", + ), + "Illegal Content-Security-Policy for script-src: *", + ) + + self.make_request_and_maybe_assert_warn( + create_headers( + content_type="text/html; charset=UTF-8", + content_security_policy=( + "script-src * 'sha256-foo' 'nonce-bar';" + "style-src http://google.com;object-src *;" + "img-src 'unsafe-inline';default-src 'self';" + "script-src * 'strict-dynamic'" + ), + ), + "\n".join( + [ + "Illegal Content-Security-Policy for script-src: *", + "Illegal Content-Security-Policy for script-src: 'nonce-bar'", + "Illegal Content-Security-Policy for style-src: http://google.com", + "Illegal Content-Security-Policy for object-src: *", + "Illegal Content-Security-Policy for img-src: 'unsafe-inline'", + "Illegal Content-Security-Policy for script-src: *", + "Illegal Content-Security-Policy for script-src: 'strict-dynamic'", + ] + ), + ) - self.make_request_and_maybe_assert_warn( - create_headers( + def test_validate_csp_multiple_csp_headers(self): + base_headers = create_headers( content_type="text/html; charset=UTF-8", content_security_policy=( - "script-src * 'sha256-foo' 'nonce-bar';" - "style-src http://google.com;object-src *;" - "img-src 'unsafe-inline';default-src 'self';" - "script-src * 'strict-dynamic'" + "script-src * 'sha256-foo';" "style-src http://google.com" ), - ), - "\n".join( - [ - "Illegal 
Content-Security-Policy for script-src: *", - "Illegal Content-Security-Policy for script-src: 'nonce-bar'", - "Illegal Content-Security-Policy for style-src: http://google.com", - "Illegal Content-Security-Policy for object-src: *", - "Illegal Content-Security-Policy for img-src: 'unsafe-inline'", - "Illegal Content-Security-Policy for script-src: *", - "Illegal Content-Security-Policy for script-src: 'strict-dynamic'", - ] - ), - ) - - def test_validate_csp_multiple_csp_headers(self): - base_headers = create_headers( - content_type="text/html; charset=UTF-8", - content_security_policy=( - "script-src * 'sha256-foo';" - "style-src http://google.com" - ), - ) - base_headers.add( - "Content-Security-Policy", - "default-src 'self';script-src 'nonce-bar';object-src *", - ) - - self.make_request_and_maybe_assert_warn( - base_headers, - "\n".join( - [ - "Illegal Content-Security-Policy for script-src: *", - "Illegal Content-Security-Policy for style-src: http://google.com", - "Illegal Content-Security-Policy for script-src: 'nonce-bar'", - "Illegal Content-Security-Policy for object-src: *", - ] - ), - ) - - def test_validate_csp_non_text_html(self): - self.make_request_and_assert_no_warn( - create_headers( - content_type="application/xhtml", - content_security_policy=( - "script-src * 'sha256-foo' 'nonce-bar';" - "style-src http://google.com;object-src *;" + ) + base_headers.add( + "Content-Security-Policy", + "default-src 'self';script-src 'nonce-bar';object-src *", + ) + + self.make_request_and_maybe_assert_warn( + base_headers, + "\n".join( + [ + "Illegal Content-Security-Policy for script-src: *", + "Illegal Content-Security-Policy for style-src: http://google.com", + "Illegal Content-Security-Policy for script-src: 'nonce-bar'", + "Illegal Content-Security-Policy for object-src: *", + ] ), - ), - ) + ) + + def test_validate_csp_non_text_html(self): + self.make_request_and_assert_no_warn( + create_headers( + content_type="application/xhtml", + 
content_security_policy=( + "script-src * 'sha256-foo' 'nonce-bar';" + "style-src http://google.com;object-src *;" + ), + ), + ) if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/compat/__init__.py b/tensorboard/compat/__init__.py index 07ee19b17c..78cd610e9c 100644 --- a/tensorboard/compat/__init__.py +++ b/tensorboard/compat/__init__.py @@ -14,9 +14,9 @@ """Compatibility interfaces for TensorBoard. -This module provides logic for importing variations on the TensorFlow APIs, as -lazily loaded imports to help avoid circular dependency issues and defer the -search and loading of the module until necessary. +This module provides logic for importing variations on the TensorFlow +APIs, as lazily loaded imports to help avoid circular dependency issues +and defer the search and loading of the module until necessary. """ from __future__ import absolute_import @@ -28,78 +28,82 @@ import tensorboard.lazy as _lazy -@_lazy.lazy_load('tensorboard.compat.tf') +@_lazy.lazy_load("tensorboard.compat.tf") def tf(): - """Provide the root module of a TF-like API for use within TensorBoard. - - By default this is equivalent to `import tensorflow as tf`, but it can be used - in combination with //tensorboard/compat:tensorflow (to fall back to a stub TF - API implementation if the real one is not available) or with - //tensorboard/compat:no_tensorflow (to force unconditional use of the stub). - - Returns: - The root module of a TF-like API, if available. - - Raises: - ImportError: if a TF-like API is not available. - """ - try: - from tensorboard.compat import notf - except ImportError: + """Provide the root module of a TF-like API for use within TensorBoard. 
+ + By default this is equivalent to `import tensorflow as tf`, but it can be used + in combination with //tensorboard/compat:tensorflow (to fall back to a stub TF + API implementation if the real one is not available) or with + //tensorboard/compat:no_tensorflow (to force unconditional use of the stub). + + Returns: + The root module of a TF-like API, if available. + + Raises: + ImportError: if a TF-like API is not available. + """ try: - import tensorflow - return tensorflow + from tensorboard.compat import notf except ImportError: - pass - from tensorboard.compat import tensorflow_stub - return tensorflow_stub + try: + import tensorflow + + return tensorflow + except ImportError: + pass + from tensorboard.compat import tensorflow_stub + return tensorflow_stub -@_lazy.lazy_load('tensorboard.compat.tf2') + +@_lazy.lazy_load("tensorboard.compat.tf2") def tf2(): - """Provide the root module of a TF-2.0 API for use within TensorBoard. + """Provide the root module of a TF-2.0 API for use within TensorBoard. - Returns: - The root module of a TF-2.0 API, if available. + Returns: + The root module of a TF-2.0 API, if available. - Raises: - ImportError: if a TF-2.0 API is not available. - """ - # Import the `tf` compat API from this file and check if it's already TF 2.0. - if tf.__version__.startswith('2.'): - return tf - elif hasattr(tf, 'compat') and hasattr(tf.compat, 'v2'): - # As a fallback, try `tensorflow.compat.v2` if it's defined. - return tf.compat.v2 - raise ImportError('cannot import tensorflow 2.0 API') + Raises: + ImportError: if a TF-2.0 API is not available. + """ + # Import the `tf` compat API from this file and check if it's already TF 2.0. + if tf.__version__.startswith("2."): + return tf + elif hasattr(tf, "compat") and hasattr(tf.compat, "v2"): + # As a fallback, try `tensorflow.compat.v2` if it's defined. 
+ return tf.compat.v2 + raise ImportError("cannot import tensorflow 2.0 API") # TODO(https://github.com/tensorflow/tensorboard/issues/1711): remove this -@_lazy.lazy_load('tensorboard.compat._pywrap_tensorflow') +@_lazy.lazy_load("tensorboard.compat._pywrap_tensorflow") def _pywrap_tensorflow(): - """Provide pywrap_tensorflow access in TensorBoard. + """Provide pywrap_tensorflow access in TensorBoard. - pywrap_tensorflow cannot be accessed from tf.python.pywrap_tensorflow - and needs to be imported using - `from tensorflow.python import pywrap_tensorflow`. Therefore, we provide - a separate accessor function for it here. + pywrap_tensorflow cannot be accessed from tf.python.pywrap_tensorflow + and needs to be imported using + `from tensorflow.python import pywrap_tensorflow`. Therefore, we provide + a separate accessor function for it here. - NOTE: pywrap_tensorflow is not part of TensorFlow API and this - dependency will go away soon. + NOTE: pywrap_tensorflow is not part of TensorFlow API and this + dependency will go away soon. - Returns: - pywrap_tensorflow import, if available. + Returns: + pywrap_tensorflow import, if available. - Raises: - ImportError: if we couldn't import pywrap_tensorflow. - """ - try: - from tensorboard.compat import notf - except ImportError: + Raises: + ImportError: if we couldn't import pywrap_tensorflow. 
+ """ try: - from tensorflow.python import pywrap_tensorflow - return pywrap_tensorflow + from tensorboard.compat import notf except ImportError: - pass - from tensorboard.compat.tensorflow_stub import pywrap_tensorflow - return pywrap_tensorflow + try: + from tensorflow.python import pywrap_tensorflow + + return pywrap_tensorflow + except ImportError: + pass + from tensorboard.compat.tensorflow_stub import pywrap_tensorflow + + return pywrap_tensorflow diff --git a/tensorboard/compat/proto/proto_test.py b/tensorboard/compat/proto/proto_test.py index 49885f2b0f..d05559b6df 100644 --- a/tensorboard/compat/proto/proto_test.py +++ b/tensorboard/compat/proto/proto_test.py @@ -15,8 +15,8 @@ """Proto match tests between `tensorboard.compat.proto` and TensorFlow. These tests verify that the local copy of TensorFlow protos are the same -as those available directly from TensorFlow. Local protos are used to build -`tensorboard-notf` without a TensorFlow dependency. +as those available directly from TensorFlow. Local protos are used to +build `tensorboard-notf` without a TensorFlow dependency. 
""" from __future__ import absolute_import @@ -32,104 +32,162 @@ # Keep this list synced with BUILD in current directory PROTO_IMPORTS = [ - ('tensorflow.core.framework.allocation_description_pb2', - 'tensorboard.compat.proto.allocation_description_pb2'), - ('tensorflow.core.framework.api_def_pb2', - 'tensorboard.compat.proto.api_def_pb2'), - ('tensorflow.core.framework.attr_value_pb2', - 'tensorboard.compat.proto.attr_value_pb2'), - ('tensorflow.core.protobuf.cluster_pb2', - 'tensorboard.compat.proto.cluster_pb2'), - ('tensorflow.core.protobuf.config_pb2', - 'tensorboard.compat.proto.config_pb2'), - ('tensorflow.core.framework.cost_graph_pb2', - 'tensorboard.compat.proto.cost_graph_pb2'), - ('tensorflow.python.framework.cpp_shape_inference_pb2', - 'tensorboard.compat.proto.cpp_shape_inference_pb2'), - ('tensorflow.core.protobuf.debug_pb2', - 'tensorboard.compat.proto.debug_pb2'), - ('tensorflow.core.util.event_pb2', - 'tensorboard.compat.proto.event_pb2'), - ('tensorflow.core.framework.function_pb2', - 'tensorboard.compat.proto.function_pb2'), - ('tensorflow.core.framework.graph_pb2', - 'tensorboard.compat.proto.graph_pb2'), - ('tensorflow.core.protobuf.meta_graph_pb2', - 'tensorboard.compat.proto.meta_graph_pb2'), - ('tensorflow.core.framework.node_def_pb2', - 'tensorboard.compat.proto.node_def_pb2'), - ('tensorflow.core.framework.op_def_pb2', - 'tensorboard.compat.proto.op_def_pb2'), - ('tensorflow.core.framework.resource_handle_pb2', - 'tensorboard.compat.proto.resource_handle_pb2'), - ('tensorflow.core.protobuf.rewriter_config_pb2', - 'tensorboard.compat.proto.rewriter_config_pb2'), - ('tensorflow.core.protobuf.saved_object_graph_pb2', - 'tensorboard.compat.proto.saved_object_graph_pb2'), - ('tensorflow.core.protobuf.saver_pb2', - 'tensorboard.compat.proto.saver_pb2'), - ('tensorflow.core.framework.step_stats_pb2', - 'tensorboard.compat.proto.step_stats_pb2'), - ('tensorflow.core.protobuf.struct_pb2', - 'tensorboard.compat.proto.struct_pb2'), - 
('tensorflow.core.framework.summary_pb2', - 'tensorboard.compat.proto.summary_pb2'), - ('tensorflow.core.framework.tensor_pb2', - 'tensorboard.compat.proto.tensor_pb2'), - ('tensorflow.core.framework.tensor_description_pb2', - 'tensorboard.compat.proto.tensor_description_pb2'), - ('tensorflow.core.framework.tensor_shape_pb2', - 'tensorboard.compat.proto.tensor_shape_pb2'), - ('tensorflow.core.profiler.tfprof_log_pb2', - 'tensorboard.compat.proto.tfprof_log_pb2'), - ('tensorflow.core.protobuf.trackable_object_graph_pb2', - 'tensorboard.compat.proto.trackable_object_graph_pb2'), - ('tensorflow.core.framework.types_pb2', - 'tensorboard.compat.proto.types_pb2'), - ('tensorflow.core.framework.variable_pb2', - 'tensorboard.compat.proto.variable_pb2'), - ('tensorflow.core.framework.versions_pb2', - 'tensorboard.compat.proto.versions_pb2'), + ( + "tensorflow.core.framework.allocation_description_pb2", + "tensorboard.compat.proto.allocation_description_pb2", + ), + ( + "tensorflow.core.framework.api_def_pb2", + "tensorboard.compat.proto.api_def_pb2", + ), + ( + "tensorflow.core.framework.attr_value_pb2", + "tensorboard.compat.proto.attr_value_pb2", + ), + ( + "tensorflow.core.protobuf.cluster_pb2", + "tensorboard.compat.proto.cluster_pb2", + ), + ( + "tensorflow.core.protobuf.config_pb2", + "tensorboard.compat.proto.config_pb2", + ), + ( + "tensorflow.core.framework.cost_graph_pb2", + "tensorboard.compat.proto.cost_graph_pb2", + ), + ( + "tensorflow.python.framework.cpp_shape_inference_pb2", + "tensorboard.compat.proto.cpp_shape_inference_pb2", + ), + ( + "tensorflow.core.protobuf.debug_pb2", + "tensorboard.compat.proto.debug_pb2", + ), + ("tensorflow.core.util.event_pb2", "tensorboard.compat.proto.event_pb2"), + ( + "tensorflow.core.framework.function_pb2", + "tensorboard.compat.proto.function_pb2", + ), + ( + "tensorflow.core.framework.graph_pb2", + "tensorboard.compat.proto.graph_pb2", + ), + ( + "tensorflow.core.protobuf.meta_graph_pb2", + 
"tensorboard.compat.proto.meta_graph_pb2", + ), + ( + "tensorflow.core.framework.node_def_pb2", + "tensorboard.compat.proto.node_def_pb2", + ), + ( + "tensorflow.core.framework.op_def_pb2", + "tensorboard.compat.proto.op_def_pb2", + ), + ( + "tensorflow.core.framework.resource_handle_pb2", + "tensorboard.compat.proto.resource_handle_pb2", + ), + ( + "tensorflow.core.protobuf.rewriter_config_pb2", + "tensorboard.compat.proto.rewriter_config_pb2", + ), + ( + "tensorflow.core.protobuf.saved_object_graph_pb2", + "tensorboard.compat.proto.saved_object_graph_pb2", + ), + ( + "tensorflow.core.protobuf.saver_pb2", + "tensorboard.compat.proto.saver_pb2", + ), + ( + "tensorflow.core.framework.step_stats_pb2", + "tensorboard.compat.proto.step_stats_pb2", + ), + ( + "tensorflow.core.protobuf.struct_pb2", + "tensorboard.compat.proto.struct_pb2", + ), + ( + "tensorflow.core.framework.summary_pb2", + "tensorboard.compat.proto.summary_pb2", + ), + ( + "tensorflow.core.framework.tensor_pb2", + "tensorboard.compat.proto.tensor_pb2", + ), + ( + "tensorflow.core.framework.tensor_description_pb2", + "tensorboard.compat.proto.tensor_description_pb2", + ), + ( + "tensorflow.core.framework.tensor_shape_pb2", + "tensorboard.compat.proto.tensor_shape_pb2", + ), + ( + "tensorflow.core.profiler.tfprof_log_pb2", + "tensorboard.compat.proto.tfprof_log_pb2", + ), + ( + "tensorflow.core.protobuf.trackable_object_graph_pb2", + "tensorboard.compat.proto.trackable_object_graph_pb2", + ), + ( + "tensorflow.core.framework.types_pb2", + "tensorboard.compat.proto.types_pb2", + ), + ( + "tensorflow.core.framework.variable_pb2", + "tensorboard.compat.proto.variable_pb2", + ), + ( + "tensorflow.core.framework.versions_pb2", + "tensorboard.compat.proto.versions_pb2", + ), ] PROTO_REPLACEMENTS = [ - ('tensorflow/core/framework/', 'tensorboard/compat/proto/'), - ('tensorflow/core/protobuf/', 'tensorboard/compat/proto/'), - ('tensorflow/core/profiler/', 'tensorboard/compat/proto/'), - 
('tensorflow/python/framework/', 'tensorboard/compat/proto/'), - ('tensorflow/core/util/', 'tensorboard/compat/proto/'), - ('package: "tensorflow.tfprof"', 'package: "tensorboard"'), - ('package: "tensorflow"', 'package: "tensorboard"'), - ('type_name: ".tensorflow.tfprof', 'type_name: ".tensorboard'), - ('type_name: ".tensorflow', 'type_name: ".tensorboard'), + ("tensorflow/core/framework/", "tensorboard/compat/proto/"), + ("tensorflow/core/protobuf/", "tensorboard/compat/proto/"), + ("tensorflow/core/profiler/", "tensorboard/compat/proto/"), + ("tensorflow/python/framework/", "tensorboard/compat/proto/"), + ("tensorflow/core/util/", "tensorboard/compat/proto/"), + ('package: "tensorflow.tfprof"', 'package: "tensorboard"'), + ('package: "tensorflow"', 'package: "tensorboard"'), + ('type_name: ".tensorflow.tfprof', 'type_name: ".tensorboard'), + ('type_name: ".tensorflow', 'type_name: ".tensorboard'), ] class ProtoMatchTest(tf.test.TestCase): + def test_each_proto_matches_tensorflow(self): + for tf_path, tb_path in PROTO_IMPORTS: + tf_pb2 = importlib.import_module(tf_path) + tb_pb2 = importlib.import_module(tb_path) + expected = descriptor_pb2.FileDescriptorProto() + actual = descriptor_pb2.FileDescriptorProto() + tf_pb2.DESCRIPTOR.CopyToProto(expected) + tb_pb2.DESCRIPTOR.CopyToProto(actual) - def test_each_proto_matches_tensorflow(self): - for tf_path, tb_path in PROTO_IMPORTS: - tf_pb2 = importlib.import_module(tf_path) - tb_pb2 = importlib.import_module(tb_path) - expected = descriptor_pb2.FileDescriptorProto() - actual = descriptor_pb2.FileDescriptorProto() - tf_pb2.DESCRIPTOR.CopyToProto(expected) - tb_pb2.DESCRIPTOR.CopyToProto(actual) + # Convert expected to be actual since this matches the + # replacements done in proto/update.sh + actual = str(actual) + expected = str(expected) + for orig, repl in PROTO_REPLACEMENTS: + expected = expected.replace(orig, repl) - # Convert expected to be actual since this matches the - # replacements done in proto/update.sh 
- actual = str(actual) - expected = str(expected) - for orig, repl in PROTO_REPLACEMENTS: - expected = expected.replace(orig, repl) + diff = difflib.unified_diff( + actual.splitlines(1), expected.splitlines(1) + ) + diff = "".join(diff) - diff = difflib.unified_diff(actual.splitlines(1), - expected.splitlines(1)) - diff = ''.join(diff) + self.assertEquals( + diff, + "", + "{} and {} did not match:\n{}".format(tf_path, tb_path, diff), + ) - self.assertEquals(diff, '', - '{} and {} did not match:\n{}'.format(tf_path, tb_path, diff)) - -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/compat/tensorflow_stub/__init__.py b/tensorboard/compat/tensorflow_stub/__init__.py index b94f654a28..8b3cf3ffa2 100644 --- a/tensorboard/compat/tensorflow_stub/__init__.py +++ b/tensorboard/compat/tensorflow_stub/__init__.py @@ -38,4 +38,4 @@ compat.v1.errors = errors # Set a fake __version__ to help distinguish this as our own stub API. -__version__ = 'stub' +__version__ = "stub" diff --git a/tensorboard/compat/tensorflow_stub/app.py b/tensorboard/compat/tensorflow_stub/app.py index b49e6e47b4..6c3f265a41 100644 --- a/tensorboard/compat/tensorflow_stub/app.py +++ b/tensorboard/compat/tensorflow_stub/app.py @@ -31,13 +31,13 @@ def _usage(shorthelp): shorthelp: bool, if True, prints only flags from the main module, rather than all flags. """ - doc = _sys.modules['__main__'].__doc__ + doc = _sys.modules["__main__"].__doc__ if not doc: - doc = '\nUSAGE: %s [flags]\n' % _sys.argv[0] - doc = flags.text_wrap(doc, indent=' ', firstline_indent='') + doc = "\nUSAGE: %s [flags]\n" % _sys.argv[0] + doc = flags.text_wrap(doc, indent=" ", firstline_indent="") else: # Replace all '%s' with sys.argv[0], and all '%%' with '%'. 
- num_specifiers = doc.count('%') - 2 * doc.count('%%') + num_specifiers = doc.count("%") - 2 * doc.count("%%") try: doc %= (_sys.argv[0],) * num_specifiers except (OverflowError, TypeError, ValueError): @@ -50,9 +50,9 @@ def _usage(shorthelp): try: _sys.stdout.write(doc) if flag_str: - _sys.stdout.write('\nflags:\n') + _sys.stdout.write("\nflags:\n") _sys.stdout.write(flag_str) - _sys.stdout.write('\n') + _sys.stdout.write("\n") except IOError as e: # We avoid printing a huge backtrace if we get EPIPE, because # "foo.par --help | less" is a frequent use case. @@ -62,24 +62,27 @@ def _usage(shorthelp): class _HelpFlag(flags.BooleanFlag): """Special boolean flag that displays usage and raises SystemExit.""" - NAME = 'help' - SHORT_NAME = 'h' + + NAME = "help" + SHORT_NAME = "h" def __init__(self): super(_HelpFlag, self).__init__( - self.NAME, False, 'show this help', short_name=self.SHORT_NAME) + self.NAME, False, "show this help", short_name=self.SHORT_NAME + ) def parse(self, arg): if arg: _usage(shorthelp=True) print() - print('Try --helpfull to get a list of all flags.') + print("Try --helpfull to get a list of all flags.") _sys.exit(1) class _HelpshortFlag(_HelpFlag): """--helpshort is an alias for --help.""" - NAME = 'helpshort' + + NAME = "helpshort" SHORT_NAME = None @@ -87,12 +90,7 @@ class _HelpfullFlag(flags.BooleanFlag): """Display help for flags in main module and all dependent modules.""" def __init__(self): - super( - _HelpfullFlag, - self).__init__( - 'helpfull', - False, - 'show full help') + super(_HelpfullFlag, self).__init__("helpfull", False, "show full help") def parse(self, arg): if arg: @@ -122,7 +120,7 @@ def run(main=None, argv=None): # Parse known flags. argv = flags.FLAGS(_sys.argv if argv is None else argv, known_only=True) - main = main or _sys.modules['__main__'].main + main = main or _sys.modules["__main__"].main # Call the main function, passing through any arguments # to the final program. 
diff --git a/tensorboard/compat/tensorflow_stub/compat/__init__.py b/tensorboard/compat/tensorflow_stub/compat/__init__.py index a7dd0d582e..1938e8d65c 100644 --- a/tensorboard/compat/tensorflow_stub/compat/__init__.py +++ b/tensorboard/compat/tensorflow_stub/compat/__init__.py @@ -39,7 +39,8 @@ def as_bytes(bytes_or_text, encoding="utf-8"): - """Converts either bytes or unicode to `bytes`, using utf-8 encoding for text. + """Converts either bytes or unicode to `bytes`, using utf-8 encoding for + text. Args: bytes_or_text: A `bytes`, `str`, or `unicode` object. @@ -56,7 +57,9 @@ def as_bytes(bytes_or_text, encoding="utf-8"): elif isinstance(bytes_or_text, bytes): return bytes_or_text else: - raise TypeError("Expected binary or unicode string, got %r" % (bytes_or_text,)) + raise TypeError( + "Expected binary or unicode string, got %r" % (bytes_or_text,) + ) def as_text(bytes_or_text, encoding="utf-8"): @@ -77,7 +80,9 @@ def as_text(bytes_or_text, encoding="utf-8"): elif isinstance(bytes_or_text, bytes): return bytes_or_text.decode(encoding) else: - raise TypeError("Expected binary or unicode string, got %r" % bytes_or_text) + raise TypeError( + "Expected binary or unicode string, got %r" % bytes_or_text + ) # Convert an object to a `str` in both Python 2 and 3. @@ -109,7 +114,8 @@ def as_str_any(value): # @tf_export('compat.path_to_str') def path_to_str(path): - """Returns the file system path representation of a `PathLike` object, else as it is. + """Returns the file system path representation of a `PathLike` object, else + as it is. Args: path: An object that can be converted to path representation. diff --git a/tensorboard/compat/tensorflow_stub/dtypes.py b/tensorboard/compat/tensorflow_stub/dtypes.py index 9115a5b8a0..6a1f864123 100644 --- a/tensorboard/compat/tensorflow_stub/dtypes.py +++ b/tensorboard/compat/tensorflow_stub/dtypes.py @@ -74,7 +74,6 @@ def __init__(self, type_enum): Raises: TypeError: If `type_enum` is not a value `types_pb2.DataType`. 
- """ # TODO(mrry): Make the necessary changes (using __new__) to ensure # that calling this returns one of the interned values. @@ -136,7 +135,7 @@ def as_datatype_enum(self): @property def is_bool(self): - """Returns whether this is a boolean data type""" + """Returns whether this is a boolean data type.""" return self.base_dtype == bool @property @@ -150,9 +149,11 @@ def is_integer(self): @property def is_floating(self): - """Returns whether this is a (non-quantized, real) floating point type.""" + """Returns whether this is a (non-quantized, real) floating point + type.""" return ( - self.is_numpy_compatible and np.issubdtype(self.as_numpy_dtype, np.floating) + self.is_numpy_compatible + and np.issubdtype(self.as_numpy_dtype, np.floating) ) or self.base_dtype == bfloat16 @property @@ -186,7 +187,6 @@ def min(self): Raises: TypeError: if this is a non-numeric, unordered, or quantized type. - """ if self.is_quantized or self.base_dtype in ( bool, @@ -214,7 +214,6 @@ def max(self): Raises: TypeError: if this is a non-numeric, unordered, or quantized type. - """ if self.is_quantized or self.base_dtype in ( bool, @@ -239,6 +238,7 @@ def max(self): @property def limits(self, clip_negative=True): """Return intensity limits, i.e. (min, max) tuple, of the dtype. + Args: clip_negative : bool, optional If True, clip the negative range (i.e. return 0 for min intensity) @@ -247,7 +247,9 @@ def limits(self, clip_negative=True): min, max : tuple Lower and upper intensity limits. 
""" - min, max = dtype_range[self.as_numpy_dtype] # pylint: disable=redefined-builtin + min, max = dtype_range[ + self.as_numpy_dtype + ] # pylint: disable=redefined-builtin if clip_negative: min = 0 # pylint: disable=redefined-builtin return min, max @@ -329,9 +331,9 @@ def size(self): np.uint16: (0, 65535), np.int8: (-128, 127), np.int16: (-32768, 32767), - np.int64: (-2 ** 63, 2 ** 63 - 1), + np.int64: (-(2 ** 63), 2 ** 63 - 1), np.uint64: (0, 2 ** 64 - 1), - np.int32: (-2 ** 31, 2 ** 31 - 1), + np.int32: (-(2 ** 31), 2 ** 31 - 1), np.uint32: (0, 2 ** 32 - 1), np.float32: (-1, 1), np.float64: (-1, 1), @@ -523,7 +525,9 @@ def size(self): types_pb2.DT_RESOURCE_REF: "resource_ref", types_pb2.DT_VARIANT_REF: "variant_ref", } -_STRING_TO_TF = {value: _INTERN_TABLE[key] for key, value in _TYPE_TO_STRING.items()} +_STRING_TO_TF = { + value: _INTERN_TABLE[key] for key, value in _TYPE_TO_STRING.items() +} # Add non-canonical aliases. _STRING_TO_TF["half"] = float16 _STRING_TO_TF["half_ref"] = float16_ref @@ -687,4 +691,6 @@ def as_dtype(type_value): "Cannot convert {} to a dtype. {}".format(type_value, e) ) - raise TypeError("Cannot convert value %r to a TensorFlow DType." % type_value) + raise TypeError( + "Cannot convert value %r to a TensorFlow DType." % type_value + ) diff --git a/tensorboard/compat/tensorflow_stub/error_codes.py b/tensorboard/compat/tensorflow_stub/error_codes.py index d599ba24db..6d8a12e746 100644 --- a/tensorboard/compat/tensorflow_stub/error_codes.py +++ b/tensorboard/compat/tensorflow_stub/error_codes.py @@ -21,62 +21,62 @@ # The operation was cancelled (typically by the caller). CANCELLED = 1 -''' +""" Unknown error. An example of where this error may be returned is if a Status value received from another address space belongs to an error-space that is not known in this address space. Also errors raised by APIs that do not return enough error information may be converted to this error. 
-''' +""" UNKNOWN = 2 -''' +""" Client specified an invalid argument. Note that this differs from FAILED_PRECONDITION. INVALID_ARGUMENT indicates arguments that are problematic regardless of the state of the system (e.g., a malformed file name). -''' +""" INVALID_ARGUMENT = 3 -''' +""" Deadline expired before operation could complete. For operations that change the state of the system, this error may be returned even if the operation has completed successfully. For example, a successful response from a server could have been delayed long enough for the deadline to expire. -''' +""" DEADLINE_EXCEEDED = 4 -''' +""" Some requested entity (e.g., file or directory) was not found. For privacy reasons, this code *may* be returned when the client does not have the access right to the entity. -''' +""" NOT_FOUND = 5 -''' +""" Some entity that we attempted to create (e.g., file or directory) already exists. -''' +""" ALREADY_EXISTS = 6 -''' +""" The caller does not have permission to execute the specified operation. PERMISSION_DENIED must not be used for rejections caused by exhausting some resource (use RESOURCE_EXHAUSTED instead for those errors). PERMISSION_DENIED must not be used if the caller can not be identified (use UNAUTHENTICATED instead for those errors). -''' +""" PERMISSION_DENIED = 7 -''' +""" Some resource has been exhausted, perhaps a per-user quota, or perhaps the entire file system is out of space. -''' +""" RESOURCE_EXHAUSTED = 8 -''' +""" Operation was rejected because the system is not in a state required for the operation's execution. For example, directory to be deleted may be non-empty, an rmdir operation is applied to @@ -95,19 +95,19 @@ REST Get/Update/Delete on a resource and the resource on the server does not match the condition. E.g., conflicting read-modify-write on the same resource. 
-''' +""" FAILED_PRECONDITION = 9 -''' +""" The operation was aborted, typically due to a concurrency issue like sequencer check failures, transaction aborts, etc. See litmus test above for deciding between FAILED_PRECONDITION, ABORTED, and UNAVAILABLE. -''' +""" ABORTED = 10 -''' +""" Operation tried to iterate past the valid input range. E.g., seeking or reading past end of file. @@ -123,39 +123,39 @@ error) when it applies so that callers who are iterating through a space can easily look for an OUT_OF_RANGE error to detect when they are done. -''' +""" OUT_OF_RANGE = 11 # Operation is not implemented or not supported/enabled in this service. UNIMPLEMENTED = 12 -''' +""" Internal errors. Means some invariant expected by the underlying system has been broken. If you see one of these errors, something is very broken. -''' +""" INTERNAL = 13 -''' +""" The service is currently unavailable. This is a most likely a transient condition and may be corrected by retrying with a backoff. See litmus test above for deciding between FAILED_PRECONDITION, ABORTED, and UNAVAILABLE. -''' +""" UNAVAILABLE = 14 # Unrecoverable data loss or corruption. DATA_LOSS = 15 -''' +""" The request does not have valid authentication credentials for the operation. -''' +""" UNAUTHENTICATED = 16 -''' +""" An extra enum entry to prevent people from writing code that fails to compile when a new code is added. @@ -165,5 +165,5 @@ Nobody should rely on the value (currently 20) listed here. It may change in the future. 
-''' +""" DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_ = 20 diff --git a/tensorboard/compat/tensorflow_stub/errors.py b/tensorboard/compat/tensorflow_stub/errors.py index 4c8ba89613..3f9f6cfc8d 100644 --- a/tensorboard/compat/tensorflow_stub/errors.py +++ b/tensorboard/compat/tensorflow_stub/errors.py @@ -82,7 +82,8 @@ def node_def(self): def __str__(self): if self._op is not None: output = [ - "%s\n\nCaused by op %r, defined at:\n" % (self.message, self._op.name) + "%s\n\nCaused by op %r, defined at:\n" + % (self.message, self._op.name) ] curr_traceback_list = traceback.format_list(self._op.traceback) output.extend(curr_traceback_list) @@ -95,7 +96,9 @@ def __str__(self): % (original_op.name,) ) prev_traceback_list = curr_traceback_list - curr_traceback_list = traceback.format_list(original_op.traceback) + curr_traceback_list = traceback.format_list( + original_op.traceback + ) # Attempt to elide large common subsequences of the subsequent # stack traces. @@ -104,7 +107,9 @@ def __str__(self): is_eliding = False elide_count = 0 last_elided_line = None - for line, line_in_prev in zip(curr_traceback_list, prev_traceback_list): + for line, line_in_prev in zip( + curr_traceback_list, prev_traceback_list + ): if line == line_in_prev: if is_eliding: elide_count += 1 @@ -199,8 +204,6 @@ def __init__(self, node_def, op, message): super(CancelledError, self).__init__(node_def, op, message, CANCELLED) - - # @tf_export("errors.UnknownError") class UnknownError(OpError): """Unknown error. @@ -259,7 +262,8 @@ def __init__(self, node_def, op, message): # @tf_export("errors.NotFoundError") class NotFoundError(OpError): - """Raised when a requested entity (e.g., a file or directory) was not found. + """Raised when a requested entity (e.g., a file or directory) was not + found. 
For example, running the @{tf.WholeFileReader.read} @@ -288,7 +292,9 @@ class AlreadyExistsError(OpError): def __init__(self, node_def, op, message): """Creates an `AlreadyExistsError`.""" - super(AlreadyExistsError, self).__init__(node_def, op, message, ALREADY_EXISTS) + super(AlreadyExistsError, self).__init__( + node_def, op, message, ALREADY_EXISTS + ) # @tf_export("errors.PermissionDeniedError") @@ -345,7 +351,8 @@ def __init__(self, node_def, op, message): # @tf_export("errors.FailedPreconditionError") class FailedPreconditionError(OpError): - """Operation was rejected because the system is not in a state to execute it. + """Operation was rejected because the system is not in a state to execute + it. This exception is most commonly raised when running an operation that reads a @{tf.Variable} @@ -394,7 +401,9 @@ class OutOfRangeError(OpError): def __init__(self, node_def, op, message): """Creates an `OutOfRangeError`.""" - super(OutOfRangeError, self).__init__(node_def, op, message, OUT_OF_RANGE) + super(OutOfRangeError, self).__init__( + node_def, op, message, OUT_OF_RANGE + ) # @tf_export("errors.UnimplementedError") @@ -412,7 +421,9 @@ class UnimplementedError(OpError): def __init__(self, node_def, op, message): """Creates an `UnimplementedError`.""" - super(UnimplementedError, self).__init__(node_def, op, message, UNIMPLEMENTED) + super(UnimplementedError, self).__init__( + node_def, op, message, UNIMPLEMENTED + ) # @tf_export("errors.InternalError") @@ -441,7 +452,9 @@ class UnavailableError(OpError): def __init__(self, node_def, op, message): """Creates an `UnavailableError`.""" - super(UnavailableError, self).__init__(node_def, op, message, UNAVAILABLE) + super(UnavailableError, self).__init__( + node_def, op, message, UNAVAILABLE + ) # @tf_export("errors.DataLossError") diff --git a/tensorboard/compat/tensorflow_stub/flags.py b/tensorboard/compat/tensorflow_stub/flags.py index dfbcdd5cb5..98f67aba21 100644 --- 
a/tensorboard/compat/tensorflow_stub/flags.py +++ b/tensorboard/compat/tensorflow_stub/flags.py @@ -13,7 +13,10 @@ # limitations under the License. # ============================================================================== -"""Import router for absl.flags. See https://github.com/abseil/abseil-py.""" +"""Import router for absl.flags. + +See https://github.com/abseil/abseil-py. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -62,9 +65,9 @@ def wrapper(*args, **kwargs): class _FlagValuesWrapper(object): """Wrapper class for absl.flags.FLAGS. - The difference is that tf.compat.v1.flags.FLAGS implicitly parses flags with sys.argv - when accessing the FLAGS values before it's explicitly parsed, - while absl.flags.FLAGS raises an exception. + The difference is that tf.compat.v1.flags.FLAGS implicitly parses + flags with sys.argv when accessing the FLAGS values before it's + explicitly parsed, while absl.flags.FLAGS raises an exception. """ def __init__(self, flags_object): diff --git a/tensorboard/compat/tensorflow_stub/io/gfile.py b/tensorboard/compat/tensorflow_stub/io/gfile.py index d0dd2f588a..e56bb96917 100644 --- a/tensorboard/compat/tensorflow_stub/io/gfile.py +++ b/tensorboard/compat/tensorflow_stub/io/gfile.py @@ -14,10 +14,10 @@ # ============================================================================== """A limited reimplementation of the TensorFlow FileIO API. -The TensorFlow version wraps the C++ FileSystem API. Here we provide a pure -Python implementation, limited to the features required for TensorBoard. This -allows running TensorBoard without depending on TensorFlow for file operations. - +The TensorFlow version wraps the C++ FileSystem API. Here we provide a +pure Python implementation, limited to the features required for +TensorBoard. This allows running TensorBoard without depending on +TensorFlow for file operations. 
""" from __future__ import absolute_import from __future__ import division @@ -34,9 +34,11 @@ import sys import tempfile import uuid + try: import botocore.exceptions import boto3 + S3_ENABLED = True except ImportError: S3_ENABLED = False @@ -62,7 +64,7 @@ def register_filesystem(prefix, filesystem): - if ':' in prefix: + if ":" in prefix: raise ValueError("Filesystem prefix cannot contain a :") _REGISTERED_FILESYSTEMS[prefix] = filesystem @@ -119,7 +121,8 @@ def read(self, filename, binary_mode=False, size=None, continue_from=None): encoding = None if binary_mode else "utf8" if not exists(filename): raise errors.NotFoundError( - None, None, 'Not Found: ' + compat.as_text(filename)) + None, None, "Not Found: " + compat.as_text(filename) + ) offset = None if continue_from is not None: offset = continue_from.get("opaque_offset", None) @@ -134,8 +137,8 @@ def read(self, filename, binary_mode=False, size=None, continue_from=None): return (data, continuation_token) def write(self, filename, file_content, binary_mode=False): - """Writes string file contents to a file, overwriting any - existing contents. + """Writes string file contents to a file, overwriting any existing + contents. Args: filename: string, a path @@ -166,8 +169,7 @@ def glob(self, filename): return [ # Convert the filenames to string from bytes. 
compat.as_str_any(matching_filename) - for matching_filename in py_glob.glob( - compat.as_bytes(filename)) + for matching_filename in py_glob.glob(compat.as_bytes(filename)) ] else: return [ @@ -175,7 +177,8 @@ def glob(self, filename): compat.as_str_any(matching_filename) for single_filename in filename for matching_filename in py_glob.glob( - compat.as_bytes(single_filename)) + compat.as_bytes(single_filename) + ) ] def isdir(self, dirname): @@ -197,7 +200,8 @@ def makedirs(self, path): os.makedirs(path) except FileExistsError: raise errors.AlreadyExistsError( - None, None, "Directory already exists") + None, None, "Directory already exists" + ) def stat(self, filename): """Returns file statistics for a given path.""" @@ -221,10 +225,10 @@ def bucket_and_path(self, url): """Split an S3-prefixed URL into bucket and path.""" url = compat.as_str_any(url) if url.startswith("s3://"): - url = url[len("s3://"):] + url = url[len("s3://") :] idx = url.index("/") bucket = url[:idx] - path = url[(idx + 1):] + path = url[(idx + 1) :] return bucket, path def exists(self, filename): @@ -270,44 +274,44 @@ def read(self, filename, binary_mode=False, size=None, continue_from=None): if continue_from is not None: offset = continue_from.get("byte_offset", 0) - endpoint = '' + endpoint = "" if size is not None: # TODO(orionr): This endpoint risks splitting a multi-byte # character or splitting \r and \n in the case of CRLFs, # producing decoding errors below. 
endpoint = offset + size - if offset != 0 or endpoint != '': + if offset != 0 or endpoint != "": # Asked for a range, so modify the request - args['Range'] = 'bytes={}-{}'.format(offset, endpoint) + args["Range"] = "bytes={}-{}".format(offset, endpoint) try: - stream = s3.Object(bucket, path).get(**args)['Body'].read() + stream = s3.Object(bucket, path).get(**args)["Body"].read() except botocore.exceptions.ClientError as exc: - if exc.response['Error']['Code'] == '416': + if exc.response["Error"]["Code"] == "416": if size is not None: # Asked for too much, so request just to the end. Do this # in a second request so we don't check length in all cases. client = boto3.client("s3") obj = client.head_object(Bucket=bucket, Key=path) - content_length = obj['ContentLength'] + content_length = obj["ContentLength"] endpoint = min(content_length, offset + size) if offset == endpoint: # Asked for no bytes, so just return empty - stream = b'' + stream = b"" else: - args['Range'] = 'bytes={}-{}'.format(offset, endpoint) - stream = s3.Object(bucket, path).get(**args)['Body'].read() + args["Range"] = "bytes={}-{}".format(offset, endpoint) + stream = s3.Object(bucket, path).get(**args)["Body"].read() else: raise # `stream` should contain raw bytes here (i.e., there has been neither # decoding nor newline translation), so the byte offset increases by # the expected amount. - continuation_token = {'byte_offset': (offset + len(stream))} + continuation_token = {"byte_offset": (offset + len(stream))} if binary_mode: return (bytes(stream), continuation_token) else: - return (stream.decode('utf-8'), continuation_token) + return (stream.decode("utf-8"), continuation_token) def write(self, filename, file_content, binary_mode=False): """Writes string file contents to a file. 
@@ -322,7 +326,7 @@ def write(self, filename, file_content, binary_mode=False): # Always convert to bytes for writing if binary_mode: if not isinstance(file_content, six.binary_type): - raise TypeError('File content type must be bytes') + raise TypeError("File content type must be bytes") else: file_content = compat.as_bytes(file_content) client.put_object(Body=file_content, Bucket=bucket, Key=path) @@ -330,11 +334,12 @@ def write(self, filename, file_content, binary_mode=False): def glob(self, filename): """Returns a list of files that match the given pattern(s).""" # Only support prefix with * at the end and no ? in the string - star_i = filename.find('*') - quest_i = filename.find('?') + star_i = filename.find("*") + quest_i = filename.find("?") if quest_i >= 0: raise NotImplementedError( - "{} not supported by compat glob".format(filename)) + "{} not supported by compat glob".format(filename) + ) if star_i != len(filename) - 1: # Just return empty so we can use glob from directory watcher # @@ -349,7 +354,7 @@ def glob(self, filename): keys = [] for r in p.paginate(Bucket=bucket, Prefix=path): for o in r.get("Contents", []): - key = o["Key"][len(path):] + key = o["Key"][len(path) :] if key: # Skip the base dir, which would add an empty string keys.append(filename + key) return keys @@ -375,9 +380,10 @@ def listdir(self, dirname): keys = [] for r in p.paginate(Bucket=bucket, Prefix=path, Delimiter="/"): keys.extend( - o["Prefix"][len(path):-1] for o in r.get("CommonPrefixes", [])) + o["Prefix"][len(path) : -1] for o in r.get("CommonPrefixes", []) + ) for o in r.get("Contents", []): - key = o["Key"][len(path):] + key = o["Key"][len(path) :] if key: # Skip the base dir, which would add an empty string keys.append(key) return keys @@ -386,12 +392,13 @@ def makedirs(self, dirname): """Creates a directory and all parent/intermediate directories.""" if self.exists(dirname): raise errors.AlreadyExistsError( - None, None, "Directory already exists") + None, None, 
"Directory already exists" + ) client = boto3.client("s3") bucket, path = self.bucket_and_path(dirname) if not path.endswith("/"): path += "/" # This will make sure we don't override a file - client.put_object(Body='', Bucket=bucket, Key=path) + client.put_object(Body="", Bucket=bucket, Key=path) def stat(self, filename): """Returns file statistics for a given path.""" @@ -401,9 +408,9 @@ def stat(self, filename): bucket, path = self.bucket_and_path(filename) try: obj = client.head_object(Bucket=bucket, Key=path) - return StatData(obj['ContentLength']) + return StatData(obj["ContentLength"]) except botocore.exceptions.ClientError as exc: - if exc.response['Error']['Code'] == '404': + if exc.response["Error"]["Code"] == "404": raise errors.NotFoundError(None, None, "Could not find file") else: raise @@ -418,12 +425,13 @@ class GFile(object): # Only methods needed for TensorBoard are implemented. def __init__(self, filename, mode): - if mode not in ('r', 'rb', 'br', 'w', 'wb', 'bw'): + if mode not in ("r", "rb", "br", "w", "wb", "bw"): raise NotImplementedError( - "mode {} not supported by compat GFile".format(mode)) + "mode {} not supported by compat GFile".format(mode) + ) self.filename = compat.as_bytes(filename) self.fs = get_filesystem(self.filename) - self.fs_supports_append = hasattr(self.fs, 'append') + self.fs_supports_append = hasattr(self.fs, "append") self.buff = None # The buffer offset and the buffer chunk size are measured in the # natural units of the underlying stream, i.e. 
bytes for binary mode, @@ -433,8 +441,8 @@ def __init__(self, filename, mode): self.continuation_token = None self.write_temp = None self.write_started = False - self.binary_mode = 'b' in mode - self.write_mode = 'w' in mode + self.binary_mode = "b" in mode + self.write_mode = "w" in mode self.closed = False def __enter__(self): @@ -453,7 +461,7 @@ def _read_buffer_to_offset(self, new_buff_offset): old_buff_offset = self.buff_offset read_size = min(len(self.buff), new_buff_offset) - old_buff_offset self.buff_offset += read_size - return self.buff[old_buff_offset:old_buff_offset + read_size] + return self.buff[old_buff_offset : old_buff_offset + read_size] def read(self, n=None): """Reads contents of file to a string. @@ -467,7 +475,8 @@ def read(self, n=None): """ if self.write_mode: raise errors.PermissionDeniedError( - None, None, "File not opened in read mode") + None, None, "File not opened in read mode" + ) result = None if self.buff and len(self.buff) > self.buff_offset: @@ -485,7 +494,8 @@ def read(self, n=None): # read from filesystem read_size = max(self.buff_chunk_size, n) if n is not None else None (self.buff, self.continuation_token) = self.fs.read( - self.filename, self.binary_mode, read_size, self.continuation_token) + self.filename, self.binary_mode, read_size, self.continuation_token + ) self.buff_offset = 0 # add from filesystem @@ -499,18 +509,20 @@ def read(self, n=None): return result def write(self, file_content): - """Writes string file contents to file, clearing contents of the - file on first write and then appending on subsequent calls. + """Writes string file contents to file, clearing contents of the file + on first write and then appending on subsequent calls. 
Args: file_content: string, the contents """ if not self.write_mode: raise errors.PermissionDeniedError( - None, None, "File not opened in write mode") + None, None, "File not opened in write mode" + ) if self.closed: raise errors.FailedPreconditionError( - None, None, "File already closed") + None, None, "File already closed" + ) if self.fs_supports_append: if not self.write_started: @@ -535,12 +547,12 @@ def __next__(self): if not self.buff: # read one unit into the buffer line = self.read(1) - if line and (line[-1] == '\n' or not self.buff): + if line and (line[-1] == "\n" or not self.buff): return line if not self.buff: raise StopIteration() else: - index = self.buff.find('\n', self.buff_offset) + index = self.buff.find("\n", self.buff_offset) if index != -1: # include line until now plus newline chunk = self.read(index + 1 - self.buff_offset) @@ -550,7 +562,7 @@ def __next__(self): # read one unit past end of buffer chunk = self.read(len(self.buff) + 1 - self.buff_offset) line = line + chunk if line else chunk - if line and (line[-1] == '\n' or not self.buff): + if line and (line[-1] == "\n" or not self.buff): return line if not self.buff: raise StopIteration() @@ -561,7 +573,8 @@ def next(self): def flush(self): if self.closed: raise errors.FailedPreconditionError( - None, None, "File already closed") + None, None, "File already closed" + ) if not self.fs_supports_append: if self.write_temp is not None: @@ -576,7 +589,7 @@ def flush(self): def close(self): self.flush() - if self.write_temp is not None: + if self.write_temp is not None: self.write_temp.close() self.write_temp = None self.write_started = False @@ -723,39 +736,40 @@ def stat(filename): """ return get_filesystem(filename).stat(filename) + # Used for tests only def _write_string_to_file(filename, file_content): - """Writes a string to a given file. + """Writes a string to a given file. 
- Args: - filename: string, path to a file - file_content: string, contents that need to be written to the file + Args: + filename: string, path to a file + file_content: string, contents that need to be written to the file - Raises: - errors.OpError: If there are errors during the operation. - """ - with GFile(filename, mode="w") as f: - f.write(compat.as_text(file_content)) + Raises: + errors.OpError: If there are errors during the operation. + """ + with GFile(filename, mode="w") as f: + f.write(compat.as_text(file_content)) # Used for tests only def _read_file_to_string(filename, binary_mode=False): - """Reads the entire contents of a file to a string. - - Args: - filename: string, path to a file - binary_mode: whether to open the file in binary mode or not. This changes - the type of the object returned. - - Returns: - contents of the file as a string or bytes. - - Raises: - errors.OpError: Raises variety of errors that are subtypes e.g. - `NotFoundError` etc. - """ - if binary_mode: - f = GFile(filename, mode="rb") - else: - f = GFile(filename, mode="r") - return f.read() + """Reads the entire contents of a file to a string. + + Args: + filename: string, path to a file + binary_mode: whether to open the file in binary mode or not. This changes + the type of the object returned. + + Returns: + contents of the file as a string or bytes. + + Raises: + errors.OpError: Raises variety of errors that are subtypes e.g. + `NotFoundError` etc. 
+ """ + if binary_mode: + f = GFile(filename, mode="rb") + else: + f = GFile(filename, mode="r") + return f.read() diff --git a/tensorboard/compat/tensorflow_stub/io/gfile_s3_test.py b/tensorboard/compat/tensorflow_stub/io/gfile_s3_test.py index 85c3c5a71b..67a9dbf18c 100644 --- a/tensorboard/compat/tensorflow_stub/io/gfile_s3_test.py +++ b/tensorboard/compat/tensorflow_stub/io/gfile_s3_test.py @@ -33,11 +33,10 @@ class GFileTest(unittest.TestCase): - @mock_s3 def testExists(self): temp_dir = self._CreateDeepS3Structure() - ckpt_path = self._PathJoin(temp_dir, 'model.ckpt') + ckpt_path = self._PathJoin(temp_dir, "model.ckpt") self.assertTrue(gfile.exists(temp_dir)) self.assertTrue(gfile.exists(ckpt_path)) @@ -47,19 +46,19 @@ def testGlob(self): # S3 glob includes subdirectory content, which standard # filesystem does not. However, this is good for perf. expected = [ - 'a.tfevents.1', - 'bar/b.tfevents.1', - 'bar/baz/c.tfevents.1', - 'bar/baz/d.tfevents.1', - 'bar/quux/some_flume_output.txt', - 'bar/quux/some_more_flume_output.txt', - 'bar/red_herring.txt', - 'model.ckpt', - 'quuz/e.tfevents.1', - 'quuz/garply/corge/g.tfevents.1', - 'quuz/garply/f.tfevents.1', - 'quuz/garply/grault/h.tfevents.1', - 'waldo/fred/i.tfevents.1' + "a.tfevents.1", + "bar/b.tfevents.1", + "bar/baz/c.tfevents.1", + "bar/baz/d.tfevents.1", + "bar/quux/some_flume_output.txt", + "bar/quux/some_more_flume_output.txt", + "bar/red_herring.txt", + "model.ckpt", + "quuz/e.tfevents.1", + "quuz/garply/corge/g.tfevents.1", + "quuz/garply/f.tfevents.1", + "quuz/garply/grault/h.tfevents.1", + "waldo/fred/i.tfevents.1", ] expected_listing = [self._PathJoin(temp_dir, f) for f in expected] gotten_listing = gfile.glob(self._PathJoin(temp_dir, "*")) @@ -67,8 +66,9 @@ def testGlob(self): self, expected_listing, gotten_listing, - 'Files must match. Expected %r. Got %r.' % ( - expected_listing, gotten_listing)) + "Files must match. Expected %r. Got %r." 
+ % (expected_listing, gotten_listing), + ) @mock_s3 def testIsdir(self): @@ -82,11 +82,11 @@ def testListdir(self): expected_files = [ # Empty directory not returned # 'foo', - 'bar', - 'quuz', - 'a.tfevents.1', - 'model.ckpt', - 'waldo', + "bar", + "quuz", + "a.tfevents.1", + "model.ckpt", + "waldo", ] gotten_files = gfile.listdir(temp_dir) six.assertCountEqual(self, expected_files, gotten_files) @@ -94,14 +94,14 @@ def testListdir(self): @mock_s3 def testMakeDirs(self): temp_dir = self._CreateDeepS3Structure() - new_dir = self._PathJoin(temp_dir, 'newdir', 'subdir', 'subsubdir') + new_dir = self._PathJoin(temp_dir, "newdir", "subdir", "subsubdir") gfile.makedirs(new_dir) self.assertTrue(gfile.isdir(new_dir)) @mock_s3 def testMakeDirsAlreadyExists(self): temp_dir = self._CreateDeepS3Structure() - new_dir = self._PathJoin(temp_dir, 'bar', 'baz') + new_dir = self._PathJoin(temp_dir, "bar", "baz") with self.assertRaises(errors.AlreadyExistsError): gfile.makedirs(new_dir) @@ -110,40 +110,21 @@ def testWalk(self): temp_dir = self._CreateDeepS3Structure() self._CreateDeepS3Structure(temp_dir) expected = [ - ['', [ - 'a.tfevents.1', - 'model.ckpt', - ]], + ["", ["a.tfevents.1", "model.ckpt",]], # Empty directory not returned # ['foo', []], - ['bar', [ - 'b.tfevents.1', - 'red_herring.txt', - ]], - ['bar/baz', [ - 'c.tfevents.1', - 'd.tfevents.1', - ]], - ['bar/quux', [ - 'some_flume_output.txt', - 'some_more_flume_output.txt', - ]], - ['quuz', [ - 'e.tfevents.1', - ]], - ['quuz/garply', [ - 'f.tfevents.1', - ]], - ['quuz/garply/corge', [ - 'g.tfevents.1', - ]], - ['quuz/garply/grault', [ - 'h.tfevents.1', - ]], - ['waldo', []], - ['waldo/fred', [ - 'i.tfevents.1', - ]], + ["bar", ["b.tfevents.1", "red_herring.txt",]], + ["bar/baz", ["c.tfevents.1", "d.tfevents.1",]], + [ + "bar/quux", + ["some_flume_output.txt", "some_more_flume_output.txt",], + ], + ["quuz", ["e.tfevents.1",]], + ["quuz/garply", ["f.tfevents.1",]], + ["quuz/garply/corge", ["g.tfevents.1",]], + 
["quuz/garply/grault", ["h.tfevents.1",]], + ["waldo", []], + ["waldo/fred", ["i.tfevents.1",]], ] for pair in expected: # If this is not the top-level directory, prepend the high-level @@ -154,21 +135,21 @@ def testWalk(self): @mock_s3 def testStat(self): - ckpt_content = 'asdfasdfasdffoobarbuzz' + ckpt_content = "asdfasdfasdffoobarbuzz" temp_dir = self._CreateDeepS3Structure(ckpt_content=ckpt_content) - ckpt_path = self._PathJoin(temp_dir, 'model.ckpt') + ckpt_path = self._PathJoin(temp_dir, "model.ckpt") ckpt_stat = gfile.stat(ckpt_path) self.assertEqual(ckpt_stat.length, len(ckpt_content)) - bad_ckpt_path = self._PathJoin(temp_dir, 'bad_model.ckpt') + bad_ckpt_path = self._PathJoin(temp_dir, "bad_model.ckpt") with self.assertRaises(errors.NotFoundError): gfile.stat(bad_ckpt_path) @mock_s3 def testRead(self): - ckpt_content = 'asdfasdfasdffoobarbuzz' + ckpt_content = "asdfasdfasdffoobarbuzz" temp_dir = self._CreateDeepS3Structure(ckpt_content=ckpt_content) - ckpt_path = self._PathJoin(temp_dir, 'model.ckpt') - with gfile.GFile(ckpt_path, 'r') as f: + ckpt_path = self._PathJoin(temp_dir, "model.ckpt") + with gfile.GFile(ckpt_path, "r") as f: f.buff_chunk_size = 4 # Test buffering by reducing chunk size ckpt_read = f.read() self.assertEqual(ckpt_content, ckpt_read) @@ -176,120 +157,125 @@ def testRead(self): @mock_s3 def testReadLines(self): ckpt_lines = ( - [u'\n'] + [u'line {}\n'.format(i) for i in range(10)] + [u' '] + [u"\n"] + [u"line {}\n".format(i) for i in range(10)] + [u" "] ) - ckpt_content = u''.join(ckpt_lines) + ckpt_content = u"".join(ckpt_lines) temp_dir = self._CreateDeepS3Structure(ckpt_content=ckpt_content) - ckpt_path = self._PathJoin(temp_dir, 'model.ckpt') - with gfile.GFile(ckpt_path, 'r') as f: + ckpt_path = self._PathJoin(temp_dir, "model.ckpt") + with gfile.GFile(ckpt_path, "r") as f: f.buff_chunk_size = 4 # Test buffering by reducing chunk size ckpt_read_lines = list(f) self.assertEqual(ckpt_lines, ckpt_read_lines) @mock_s3 def 
testReadWithOffset(self): - ckpt_content = 'asdfasdfasdffoobarbuzz' - ckpt_b_content = b'asdfasdfasdffoobarbuzz' + ckpt_content = "asdfasdfasdffoobarbuzz" + ckpt_b_content = b"asdfasdfasdffoobarbuzz" temp_dir = self._CreateDeepS3Structure(ckpt_content=ckpt_content) - ckpt_path = self._PathJoin(temp_dir, 'model.ckpt') - with gfile.GFile(ckpt_path, 'r') as f: + ckpt_path = self._PathJoin(temp_dir, "model.ckpt") + with gfile.GFile(ckpt_path, "r") as f: f.buff_chunk_size = 4 # Test buffering by reducing chunk size ckpt_read = f.read(12) - self.assertEqual('asdfasdfasdf', ckpt_read) + self.assertEqual("asdfasdfasdf", ckpt_read) ckpt_read = f.read(6) - self.assertEqual('foobar', ckpt_read) + self.assertEqual("foobar", ckpt_read) ckpt_read = f.read(1) - self.assertEqual('b', ckpt_read) + self.assertEqual("b", ckpt_read) ckpt_read = f.read() - self.assertEqual('uzz', ckpt_read) + self.assertEqual("uzz", ckpt_read) ckpt_read = f.read(1000) - self.assertEqual('', ckpt_read) - with gfile.GFile(ckpt_path, 'rb') as f: + self.assertEqual("", ckpt_read) + with gfile.GFile(ckpt_path, "rb") as f: ckpt_read = f.read() self.assertEqual(ckpt_b_content, ckpt_read) @mock_s3 def testWrite(self): temp_dir = self._CreateDeepS3Structure() - ckpt_path = os.path.join(temp_dir, 'model2.ckpt') - ckpt_content = u'asdfasdfasdffoobarbuzz' - with gfile.GFile(ckpt_path, 'w') as f: + ckpt_path = os.path.join(temp_dir, "model2.ckpt") + ckpt_content = u"asdfasdfasdffoobarbuzz" + with gfile.GFile(ckpt_path, "w") as f: f.write(ckpt_content) - with gfile.GFile(ckpt_path, 'r') as f: + with gfile.GFile(ckpt_path, "r") as f: ckpt_read = f.read() self.assertEqual(ckpt_content, ckpt_read) @mock_s3 def testOverwrite(self): temp_dir = self._CreateDeepS3Structure() - ckpt_path = os.path.join(temp_dir, 'model2.ckpt') - ckpt_content = u'asdfasdfasdffoobarbuzz' - with gfile.GFile(ckpt_path, 'w') as f: - f.write(u'original') - with gfile.GFile(ckpt_path, 'w') as f: + ckpt_path = os.path.join(temp_dir, "model2.ckpt") 
+ ckpt_content = u"asdfasdfasdffoobarbuzz" + with gfile.GFile(ckpt_path, "w") as f: + f.write(u"original") + with gfile.GFile(ckpt_path, "w") as f: f.write(ckpt_content) - with gfile.GFile(ckpt_path, 'r') as f: + with gfile.GFile(ckpt_path, "r") as f: ckpt_read = f.read() self.assertEqual(ckpt_content, ckpt_read) @mock_s3 def testWriteMultiple(self): temp_dir = self._CreateDeepS3Structure() - ckpt_path = os.path.join(temp_dir, 'model2.ckpt') - ckpt_content = u'asdfasdfasdffoobarbuzz' * 5 - with gfile.GFile(ckpt_path, 'w') as f: + ckpt_path = os.path.join(temp_dir, "model2.ckpt") + ckpt_content = u"asdfasdfasdffoobarbuzz" * 5 + with gfile.GFile(ckpt_path, "w") as f: for i in range(0, len(ckpt_content), 3): - f.write(ckpt_content[i:i + 3]) + f.write(ckpt_content[i : i + 3]) # Test periodic flushing of the file if i % 9 == 0: f.flush() - with gfile.GFile(ckpt_path, 'r') as f: + with gfile.GFile(ckpt_path, "r") as f: ckpt_read = f.read() self.assertEqual(ckpt_content, ckpt_read) @mock_s3 def testWriteEmpty(self): temp_dir = self._CreateDeepS3Structure() - ckpt_path = os.path.join(temp_dir, 'model2.ckpt') - ckpt_content = u'' - with gfile.GFile(ckpt_path, 'w') as f: + ckpt_path = os.path.join(temp_dir, "model2.ckpt") + ckpt_content = u"" + with gfile.GFile(ckpt_path, "w") as f: f.write(ckpt_content) - with gfile.GFile(ckpt_path, 'r') as f: + with gfile.GFile(ckpt_path, "r") as f: ckpt_read = f.read() self.assertEqual(ckpt_content, ckpt_read) @mock_s3 def testWriteBinary(self): temp_dir = self._CreateDeepS3Structure() - ckpt_path = os.path.join(temp_dir, 'model.ckpt') - ckpt_content = b'asdfasdfasdffoobarbuzz' - with gfile.GFile(ckpt_path, 'wb') as f: + ckpt_path = os.path.join(temp_dir, "model.ckpt") + ckpt_content = b"asdfasdfasdffoobarbuzz" + with gfile.GFile(ckpt_path, "wb") as f: f.write(ckpt_content) - with gfile.GFile(ckpt_path, 'rb') as f: + with gfile.GFile(ckpt_path, "rb") as f: ckpt_read = f.read() self.assertEqual(ckpt_content, ckpt_read) @mock_s3 def 
testWriteMultipleBinary(self): temp_dir = self._CreateDeepS3Structure() - ckpt_path = os.path.join(temp_dir, 'model2.ckpt') - ckpt_content = b'asdfasdfasdffoobarbuzz' * 5 - with gfile.GFile(ckpt_path, 'wb') as f: + ckpt_path = os.path.join(temp_dir, "model2.ckpt") + ckpt_content = b"asdfasdfasdffoobarbuzz" * 5 + with gfile.GFile(ckpt_path, "wb") as f: for i in range(0, len(ckpt_content), 3): - f.write(ckpt_content[i:i + 3]) + f.write(ckpt_content[i : i + 3]) # Test periodic flushing of the file if i % 9 == 0: f.flush() - with gfile.GFile(ckpt_path, 'rb') as f: + with gfile.GFile(ckpt_path, "rb") as f: ckpt_read = f.read() self.assertEqual(ckpt_content, ckpt_read) def _PathJoin(self, *args): - """Join directory and path with slash and not local separator""" + """Join directory and path with slash and not local separator.""" return "/".join(args) - def _CreateDeepS3Structure(self, top_directory='top_dir', ckpt_content='', - region_name='us-east-1', bucket_name='test'): + def _CreateDeepS3Structure( + self, + top_directory="top_dir", + ckpt_content="", + region_name="us-east-1", + bucket_name="test", + ): """Creates a reasonable deep structure of S3 subdirectories with files. Args: @@ -302,62 +288,62 @@ def _CreateDeepS3Structure(self, top_directory='top_dir', ckpt_content='', Returns: S3 URL of the top directory in the form 's3://bucket/path' """ - s3_top_url = 's3://{}/{}'.format(bucket_name, top_directory) + s3_top_url = "s3://{}/{}".format(bucket_name, top_directory) # Add a few subdirectories. directory_names = ( # An empty directory. - 'foo', + "foo", # A directory with an events file (and a text file). - 'bar', + "bar", # A deeper directory with events files. - 'bar/baz', + "bar/baz", # A non-empty subdir that lacks event files (should be ignored). - 'bar/quux', + "bar/quux", # This 3-level deep set of subdirectories tests logic that replaces # the full glob string with an absolute path prefix if there is # only 1 subdirectory in the final mapping. 
- 'quuz/garply', - 'quuz/garply/corge', - 'quuz/garply/grault', + "quuz/garply", + "quuz/garply/corge", + "quuz/garply/grault", # A directory that lacks events files, but contains a subdirectory # with events files (first level should be ignored, second level # should be included). - 'waldo', - 'waldo/fred', + "waldo", + "waldo/fred", ) - client = boto3.client('s3', region_name=region_name) + client = boto3.client("s3", region_name=region_name) client.create_bucket(Bucket=bucket_name) - client.put_object(Body='', Bucket=bucket_name, Key=top_directory) + client.put_object(Body="", Bucket=bucket_name, Key=top_directory) for directory_name in directory_names: # Add an end slash - path = top_directory + '/' + directory_name + '/' + path = top_directory + "/" + directory_name + "/" # Create an empty object so the location exists - client.put_object(Body='', Bucket=bucket_name, Key=directory_name) + client.put_object(Body="", Bucket=bucket_name, Key=directory_name) # Add a few files to the directory. 
file_names = ( - 'a.tfevents.1', - 'model.ckpt', - 'bar/b.tfevents.1', - 'bar/red_herring.txt', - 'bar/baz/c.tfevents.1', - 'bar/baz/d.tfevents.1', - 'bar/quux/some_flume_output.txt', - 'bar/quux/some_more_flume_output.txt', - 'quuz/e.tfevents.1', - 'quuz/garply/f.tfevents.1', - 'quuz/garply/corge/g.tfevents.1', - 'quuz/garply/grault/h.tfevents.1', - 'waldo/fred/i.tfevents.1', + "a.tfevents.1", + "model.ckpt", + "bar/b.tfevents.1", + "bar/red_herring.txt", + "bar/baz/c.tfevents.1", + "bar/baz/d.tfevents.1", + "bar/quux/some_flume_output.txt", + "bar/quux/some_more_flume_output.txt", + "quuz/e.tfevents.1", + "quuz/garply/f.tfevents.1", + "quuz/garply/corge/g.tfevents.1", + "quuz/garply/grault/h.tfevents.1", + "waldo/fred/i.tfevents.1", ) for file_name in file_names: # Add an end slash - path = top_directory + '/' + file_name - if file_name == 'model.ckpt': + path = top_directory + "/" + file_name + if file_name == "model.ckpt": content = ckpt_content else: - content = '' + content = "" client.put_object(Body=content, Bucket=bucket_name, Key=path) return s3_top_url @@ -369,14 +355,18 @@ def _CompareFilesPerSubdirectory(self, expected, gotten): gotten: The gotten iterable of 2-tuples. """ expected_directory_to_files = { - result[0]: list(result[1]) for result in expected} + result[0]: list(result[1]) for result in expected + } gotten_directory_to_files = { # Note we ignore subdirectories and just compare files - result[0]: list(result[2]) for result in gotten} + result[0]: list(result[2]) + for result in gotten + } six.assertCountEqual( self, expected_directory_to_files.keys(), - gotten_directory_to_files.keys()) + gotten_directory_to_files.keys(), + ) for subdir, expected_listing in expected_directory_to_files.items(): gotten_listing = gotten_directory_to_files[subdir] @@ -384,9 +374,10 @@ def _CompareFilesPerSubdirectory(self, expected, gotten): self, expected_listing, gotten_listing, - 'Files for subdir %r must match. Expected %r. Got %r.' 
% ( - subdir, expected_listing, gotten_listing)) + "Files for subdir %r must match. Expected %r. Got %r." + % (subdir, expected_listing, gotten_listing), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tensorboard/compat/tensorflow_stub/io/gfile_test.py b/tensorboard/compat/tensorflow_stub/io/gfile_test.py index 3151a9f6f2..7a3ae7e8de 100644 --- a/tensorboard/compat/tensorflow_stub/io/gfile_test.py +++ b/tensorboard/compat/tensorflow_stub/io/gfile_test.py @@ -32,7 +32,7 @@ class GFileTest(tb_test.TestCase): def testExists(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) - ckpt_path = os.path.join(temp_dir, 'model.ckpt') + ckpt_path = os.path.join(temp_dir, "model.ckpt") self.assertTrue(gfile.exists(temp_dir)) self.assertTrue(gfile.exists(ckpt_path)) @@ -40,12 +40,12 @@ def testGlob(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) expected = [ - 'foo', - 'bar', - 'a.tfevents.1', - 'model.ckpt', - 'quuz', - 'waldo', + "foo", + "bar", + "a.tfevents.1", + "model.ckpt", + "quuz", + "waldo", ] expected_listing = [os.path.join(temp_dir, f) for f in expected] gotten_listing = gfile.glob(os.path.join(temp_dir, "*")) @@ -53,8 +53,9 @@ def testGlob(self): self, expected_listing, gotten_listing, - 'Files must match. Expected %r. Got %r.' % ( - expected_listing, gotten_listing)) + "Files must match. Expected %r. Got %r." 
+ % (expected_listing, gotten_listing), + ) def testIsdir(self): temp_dir = self.get_temp_dir() @@ -64,29 +65,26 @@ def testListdir(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) expected_files = ( - 'foo', - 'bar', - 'quuz', - 'a.tfevents.1', - 'model.ckpt', - 'waldo', + "foo", + "bar", + "quuz", + "a.tfevents.1", + "model.ckpt", + "waldo", ) - six.assertCountEqual( - self, - expected_files, - gfile.listdir(temp_dir)) + six.assertCountEqual(self, expected_files, gfile.listdir(temp_dir)) def testMakeDirs(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) - new_dir = os.path.join(temp_dir, 'newdir', 'subdir', 'subsubdir') + new_dir = os.path.join(temp_dir, "newdir", "subdir", "subsubdir") gfile.makedirs(new_dir) self.assertTrue(gfile.isdir(new_dir)) def testMakeDirsAlreadyExists(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) - new_dir = os.path.join(temp_dir, 'bar', 'baz') + new_dir = os.path.join(temp_dir, "bar", "baz") with self.assertRaises(errors.AlreadyExistsError): gfile.makedirs(new_dir) @@ -94,69 +92,53 @@ def testWalk(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) expected = [ - ['', [ - 'a.tfevents.1', - 'model.ckpt', - ]], - ['foo', []], - ['bar', [ - 'b.tfevents.1', - 'red_herring.txt', - ]], - ['bar/baz', [ - 'c.tfevents.1', - 'd.tfevents.1', - ]], - ['bar/quux', [ - 'some_flume_output.txt', - 'some_more_flume_output.txt', - ]], - ['quuz', [ - 'e.tfevents.1', - ]], - ['quuz/garply', [ - 'f.tfevents.1', - ]], - ['quuz/garply/corge', [ - 'g.tfevents.1', - ]], - ['quuz/garply/grault', [ - 'h.tfevents.1', - ]], - ['waldo', []], - ['waldo/fred', [ - 'i.tfevents.1', - ]], + ["", ["a.tfevents.1", "model.ckpt",]], + ["foo", []], + ["bar", ["b.tfevents.1", "red_herring.txt",]], + ["bar/baz", ["c.tfevents.1", "d.tfevents.1",]], + [ + "bar/quux", + ["some_flume_output.txt", "some_more_flume_output.txt",], + ], + ["quuz", 
["e.tfevents.1",]], + ["quuz/garply", ["f.tfevents.1",]], + ["quuz/garply/corge", ["g.tfevents.1",]], + ["quuz/garply/grault", ["h.tfevents.1",]], + ["waldo", []], + ["waldo/fred", ["i.tfevents.1",]], ] for pair in expected: # If this is not the top-level directory, prepend the high-level # directory. - pair[0] = os.path.join(temp_dir, - pair[0].replace('/', os.path.sep)) if pair[0] else temp_dir + pair[0] = ( + os.path.join(temp_dir, pair[0].replace("/", os.path.sep)) + if pair[0] + else temp_dir + ) gotten = gfile.walk(temp_dir) self._CompareFilesPerSubdirectory(expected, gotten) def testStat(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) - ckpt_path = os.path.join(temp_dir, 'model.ckpt') - ckpt_content = 'asdfasdfasdffoobarbuzz' - with open(ckpt_path, 'w') as f: + ckpt_path = os.path.join(temp_dir, "model.ckpt") + ckpt_content = "asdfasdfasdffoobarbuzz" + with open(ckpt_path, "w") as f: f.write(ckpt_content) ckpt_stat = gfile.stat(ckpt_path) self.assertEqual(ckpt_stat.length, len(ckpt_content)) - bad_ckpt_path = os.path.join(temp_dir, 'bad_model.ckpt') + bad_ckpt_path = os.path.join(temp_dir, "bad_model.ckpt") with self.assertRaises(errors.NotFoundError): gfile.stat(bad_ckpt_path) def testRead(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) - ckpt_path = os.path.join(temp_dir, 'model.ckpt') - ckpt_content = 'asdfasdfasdffoobarbuzz' - with open(ckpt_path, 'w') as f: + ckpt_path = os.path.join(temp_dir, "model.ckpt") + ckpt_content = "asdfasdfasdffoobarbuzz" + with open(ckpt_path, "w") as f: f.write(ckpt_content) - with gfile.GFile(ckpt_path, 'r') as f: + with gfile.GFile(ckpt_path, "r") as f: f.buff_chunk_size = 4 # Test buffering by reducing chunk size ckpt_read = f.read() self.assertEqual(ckpt_content, ckpt_read) @@ -164,24 +146,24 @@ def testRead(self): def testReadLines(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) - ckpt_path = os.path.join(temp_dir, 
'model.ckpt') + ckpt_path = os.path.join(temp_dir, "model.ckpt") # Note \r\n, which io.read will automatically replace with \n. # That substitution desynchronizes character offsets (omitting \r) from # the underlying byte offsets (counting \r). Multibyte characters would # similarly cause desynchronization. raw_ckpt_lines = ( - [u'\r\n'] + [u'line {}\r\n'.format(i) for i in range(10)] + [u' '] + [u"\r\n"] + [u"line {}\r\n".format(i) for i in range(10)] + [u" "] ) - expected_ckpt_lines = ( # without \r - [u'\n'] + [u'line {}\n'.format(i) for i in range(10)] + [u' '] + expected_ckpt_lines = ( # without \r + [u"\n"] + [u"line {}\n".format(i) for i in range(10)] + [u" "] ) # Write out newlines as given (i.e., \r\n) regardless of OS, so as to # test translation on read. - with io.open(ckpt_path, 'w', newline='') as f: - data = u''.join(raw_ckpt_lines) + with io.open(ckpt_path, "w", newline="") as f: + data = u"".join(raw_ckpt_lines) f.write(data) - with gfile.GFile(ckpt_path, 'r') as f: + with gfile.GFile(ckpt_path, "r") as f: f.buff_chunk_size = 4 # Test buffering by reducing chunk size read_ckpt_lines = list(f) self.assertEqual(expected_ckpt_lines, read_ckpt_lines) @@ -189,100 +171,100 @@ def testReadLines(self): def testReadWithOffset(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) - ckpt_path = os.path.join(temp_dir, 'model.ckpt') - ckpt_content = 'asdfasdfasdffoobarbuzz' - ckpt_b_content = b'asdfasdfasdffoobarbuzz' - with open(ckpt_path, 'w') as f: + ckpt_path = os.path.join(temp_dir, "model.ckpt") + ckpt_content = "asdfasdfasdffoobarbuzz" + ckpt_b_content = b"asdfasdfasdffoobarbuzz" + with open(ckpt_path, "w") as f: f.write(ckpt_content) - with gfile.GFile(ckpt_path, 'r') as f: + with gfile.GFile(ckpt_path, "r") as f: f.buff_chunk_size = 4 # Test buffering by reducing chunk size ckpt_read = f.read(12) - self.assertEqual('asdfasdfasdf', ckpt_read) + self.assertEqual("asdfasdfasdf", ckpt_read) ckpt_read = f.read(6) - 
self.assertEqual('foobar', ckpt_read) + self.assertEqual("foobar", ckpt_read) ckpt_read = f.read(1) - self.assertEqual('b', ckpt_read) + self.assertEqual("b", ckpt_read) ckpt_read = f.read() - self.assertEqual('uzz', ckpt_read) + self.assertEqual("uzz", ckpt_read) ckpt_read = f.read(1000) - self.assertEqual('', ckpt_read) - with gfile.GFile(ckpt_path, 'rb') as f: + self.assertEqual("", ckpt_read) + with gfile.GFile(ckpt_path, "rb") as f: ckpt_read = f.read() self.assertEqual(ckpt_b_content, ckpt_read) def testWrite(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) - ckpt_path = os.path.join(temp_dir, 'model2.ckpt') - ckpt_content = u'asdfasdfasdffoobarbuzz' - with gfile.GFile(ckpt_path, 'w') as f: + ckpt_path = os.path.join(temp_dir, "model2.ckpt") + ckpt_content = u"asdfasdfasdffoobarbuzz" + with gfile.GFile(ckpt_path, "w") as f: f.write(ckpt_content) - with open(ckpt_path, 'r') as f: + with open(ckpt_path, "r") as f: ckpt_read = f.read() self.assertEqual(ckpt_content, ckpt_read) def testOverwrite(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) - ckpt_path = os.path.join(temp_dir, 'model2.ckpt') - ckpt_content = u'asdfasdfasdffoobarbuzz' - with gfile.GFile(ckpt_path, 'w') as f: - f.write(u'original') - with gfile.GFile(ckpt_path, 'w') as f: + ckpt_path = os.path.join(temp_dir, "model2.ckpt") + ckpt_content = u"asdfasdfasdffoobarbuzz" + with gfile.GFile(ckpt_path, "w") as f: + f.write(u"original") + with gfile.GFile(ckpt_path, "w") as f: f.write(ckpt_content) - with open(ckpt_path, 'r') as f: + with open(ckpt_path, "r") as f: ckpt_read = f.read() self.assertEqual(ckpt_content, ckpt_read) def testWriteMultiple(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) - ckpt_path = os.path.join(temp_dir, 'model2.ckpt') - ckpt_content = u'asdfasdfasdffoobarbuzz' * 5 - with gfile.GFile(ckpt_path, 'w') as f: + ckpt_path = os.path.join(temp_dir, "model2.ckpt") + ckpt_content = 
u"asdfasdfasdffoobarbuzz" * 5 + with gfile.GFile(ckpt_path, "w") as f: for i in range(0, len(ckpt_content), 3): - f.write(ckpt_content[i:i + 3]) + f.write(ckpt_content[i : i + 3]) # Test periodic flushing of the file if i % 9 == 0: f.flush() - with open(ckpt_path, 'r') as f: + with open(ckpt_path, "r") as f: ckpt_read = f.read() self.assertEqual(ckpt_content, ckpt_read) def testWriteEmpty(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) - ckpt_path = os.path.join(temp_dir, 'model2.ckpt') - ckpt_content = u'' - with gfile.GFile(ckpt_path, 'w') as f: + ckpt_path = os.path.join(temp_dir, "model2.ckpt") + ckpt_content = u"" + with gfile.GFile(ckpt_path, "w") as f: f.write(ckpt_content) - with open(ckpt_path, 'r') as f: + with open(ckpt_path, "r") as f: ckpt_read = f.read() self.assertEqual(ckpt_content, ckpt_read) def testWriteBinary(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) - ckpt_path = os.path.join(temp_dir, 'model2.ckpt') - ckpt_content = b'asdfasdfasdffoobarbuzz' - with gfile.GFile(ckpt_path, 'wb') as f: + ckpt_path = os.path.join(temp_dir, "model2.ckpt") + ckpt_content = b"asdfasdfasdffoobarbuzz" + with gfile.GFile(ckpt_path, "wb") as f: f.write(ckpt_content) - with open(ckpt_path, 'rb') as f: + with open(ckpt_path, "rb") as f: ckpt_read = f.read() self.assertEqual(ckpt_content, ckpt_read) def testWriteMultipleBinary(self): temp_dir = self.get_temp_dir() self._CreateDeepDirectoryStructure(temp_dir) - ckpt_path = os.path.join(temp_dir, 'model2.ckpt') - ckpt_content = b'asdfasdfasdffoobarbuzz' * 5 - with gfile.GFile(ckpt_path, 'wb') as f: + ckpt_path = os.path.join(temp_dir, "model2.ckpt") + ckpt_content = b"asdfasdfasdffoobarbuzz" * 5 + with gfile.GFile(ckpt_path, "wb") as f: for i in range(0, len(ckpt_content), 3): - f.write(ckpt_content[i:i + 3]) + f.write(ckpt_content[i : i + 3]) # Test periodic flushing of the file if i % 9 == 0: f.flush() - with open(ckpt_path, 'rb') as f: + with 
open(ckpt_path, "rb") as f: ckpt_read = f.read() self.assertEqual(ckpt_content, ckpt_read) @@ -296,46 +278,46 @@ def _CreateDeepDirectoryStructure(self, top_directory): # Add a few subdirectories. directory_names = ( # An empty directory. - 'foo', + "foo", # A directory with an events file (and a text file). - 'bar', + "bar", # A deeper directory with events files. - 'bar/baz', + "bar/baz", # A non-empty subdir that lacks event files (should be ignored). - 'bar/quux', + "bar/quux", # This 3-level deep set of subdirectories tests logic that replaces # the full glob string with an absolute path prefix if there is # only 1 subdirectory in the final mapping. - 'quuz/garply', - 'quuz/garply/corge', - 'quuz/garply/grault', + "quuz/garply", + "quuz/garply/corge", + "quuz/garply/grault", # A directory that lacks events files, but contains a subdirectory # with events files (first level should be ignored, second level # should be included). - 'waldo', - 'waldo/fred', + "waldo", + "waldo/fred", ) for directory_name in directory_names: os.makedirs(os.path.join(top_directory, directory_name)) # Add a few files to the directory. 
file_names = ( - 'a.tfevents.1', - 'model.ckpt', - 'bar/b.tfevents.1', - 'bar/red_herring.txt', - 'bar/baz/c.tfevents.1', - 'bar/baz/d.tfevents.1', - 'bar/quux/some_flume_output.txt', - 'bar/quux/some_more_flume_output.txt', - 'quuz/e.tfevents.1', - 'quuz/garply/f.tfevents.1', - 'quuz/garply/corge/g.tfevents.1', - 'quuz/garply/grault/h.tfevents.1', - 'waldo/fred/i.tfevents.1', + "a.tfevents.1", + "model.ckpt", + "bar/b.tfevents.1", + "bar/red_herring.txt", + "bar/baz/c.tfevents.1", + "bar/baz/d.tfevents.1", + "bar/quux/some_flume_output.txt", + "bar/quux/some_more_flume_output.txt", + "quuz/e.tfevents.1", + "quuz/garply/f.tfevents.1", + "quuz/garply/corge/g.tfevents.1", + "quuz/garply/grault/h.tfevents.1", + "waldo/fred/i.tfevents.1", ) for file_name in file_names: - open(os.path.join(top_directory, file_name), 'w').close() + open(os.path.join(top_directory, file_name), "w").close() def _CompareFilesPerSubdirectory(self, expected, gotten): """Compares iterables of (subdirectory path, list of absolute paths) @@ -345,14 +327,18 @@ def _CompareFilesPerSubdirectory(self, expected, gotten): gotten: The gotten iterable of 2-tuples. """ expected_directory_to_files = { - result[0]: list(result[1]) for result in expected} + result[0]: list(result[1]) for result in expected + } gotten_directory_to_files = { # Note we ignore subdirectories and just compare files - result[0]: list(result[2]) for result in gotten} + result[0]: list(result[2]) + for result in gotten + } six.assertCountEqual( self, expected_directory_to_files.keys(), - gotten_directory_to_files.keys()) + gotten_directory_to_files.keys(), + ) for subdir, expected_listing in expected_directory_to_files.items(): gotten_listing = gotten_directory_to_files[subdir] @@ -360,9 +346,10 @@ def _CompareFilesPerSubdirectory(self, expected, gotten): self, expected_listing, gotten_listing, - 'Files for subdir %r must match. Expected %r. Got %r.' 
% ( - subdir, expected_listing, gotten_listing)) + "Files for subdir %r must match. Expected %r. Got %r." + % (subdir, expected_listing, gotten_listing), + ) -if __name__ == '__main__': +if __name__ == "__main__": tb_test.main() diff --git a/tensorboard/compat/tensorflow_stub/io/gfile_tf_test.py b/tensorboard/compat/tensorflow_stub/io/gfile_tf_test.py index b6ce254280..38b0b61d75 100644 --- a/tensorboard/compat/tensorflow_stub/io/gfile_tf_test.py +++ b/tensorboard/compat/tensorflow_stub/io/gfile_tf_test.py @@ -35,247 +35,274 @@ # the TensorFlow FileIO API (to the extent that it's implemented at all). # Many of the TF tests are removed because they do not apply here. -class FileIoTest(tb_test.TestCase): - def setUp(self): - self._base_dir = os.path.join(self.get_temp_dir(), "base_dir") - gfile.makedirs(self._base_dir) - - # It was a temp_dir anyway - # def tearDown(self): - # gfile.delete_recursively(self._base_dir) - - def testEmptyFilename(self): - f = gfile.GFile("", mode="r") - with self.assertRaises(errors.NotFoundError): - _ = f.read() - - def testFileDoesntExist(self): - file_path = os.path.join(self._base_dir, "temp_file") - self.assertFalse(gfile.exists(file_path)) - with self.assertRaises(errors.NotFoundError): - _ = gfile._read_file_to_string(file_path) - - def testWriteToString(self): - file_path = os.path.join(self._base_dir, "temp_file") - gfile._write_string_to_file(file_path, "testing") - self.assertTrue(gfile.exists(file_path)) - file_contents = gfile._read_file_to_string(file_path) - self.assertEqual("testing", file_contents) - - def testReadBinaryMode(self): - file_path = os.path.join(self._base_dir, "temp_file") - gfile._write_string_to_file(file_path, "testing") - with gfile.GFile(file_path, mode="rb") as f: - self.assertEqual(b"testing", f.read()) - - def testWriteBinaryMode(self): - file_path = os.path.join(self._base_dir, "temp_file") - gfile.GFile(file_path, "wb").write(compat.as_bytes("testing")) - with gfile.GFile(file_path, mode="r") as 
f: - self.assertEqual("testing", f.read()) - - def testMultipleFiles(self): - file_prefix = os.path.join(self._base_dir, "temp_file") - for i in range(5000): - - with gfile.GFile(file_prefix + str(i), mode="w") as f: - f.write("testing") - f.flush() - - with gfile.GFile(file_prefix + str(i), mode="r") as f: - self.assertEqual("testing", f.read()) - - def testMultipleWrites(self): - file_path = os.path.join(self._base_dir, "temp_file") - with gfile.GFile(file_path, mode="w") as f: - f.write("line1\n") - f.write("line2") - file_contents = gfile._read_file_to_string(file_path) - self.assertEqual("line1\nline2", file_contents) - - def testFileWriteBadMode(self): - file_path = os.path.join(self._base_dir, "temp_file") - with self.assertRaises(errors.PermissionDeniedError): - gfile.GFile(file_path, mode="r").write("testing") - - def testFileReadBadMode(self): - file_path = os.path.join(self._base_dir, "temp_file") - gfile.GFile(file_path, mode="w").write("testing") - self.assertTrue(gfile.exists(file_path)) - with self.assertRaises(errors.PermissionDeniedError): - gfile.GFile(file_path, mode="w").read() - - def testIsDirectory(self): - dir_path = os.path.join(self._base_dir, "test_dir") - # Failure for a non-existing dir. - self.assertFalse(gfile.isdir(dir_path)) - gfile.makedirs(dir_path) - self.assertTrue(gfile.isdir(dir_path)) - file_path = os.path.join(dir_path, "test_file") - gfile.GFile(file_path, mode="w").write("test") - # False for a file. - self.assertFalse(gfile.isdir(file_path)) - # Test that the value returned from `stat()` has `is_directory` set. 
- # file_statistics = gfile.stat(dir_path) - # self.assertTrue(file_statistics.is_directory) - - def testListDirectory(self): - dir_path = os.path.join(self._base_dir, "test_dir") - gfile.makedirs(dir_path) - files = ["file1.txt", "file2.txt", "file3.txt"] - for name in files: - file_path = os.path.join(dir_path, name) - gfile.GFile(file_path, mode="w").write("testing") - subdir_path = os.path.join(dir_path, "sub_dir") - gfile.makedirs(subdir_path) - subdir_file_path = os.path.join(subdir_path, "file4.txt") - gfile.GFile(subdir_file_path, mode="w").write("testing") - dir_list = gfile.listdir(dir_path) - self.assertItemsEqual(files + ["sub_dir"], dir_list) - - def testListDirectoryFailure(self): - dir_path = os.path.join(self._base_dir, "test_dir") - with self.assertRaises(errors.NotFoundError): - gfile.listdir(dir_path) - - def _setupWalkDirectories(self, dir_path): - # Creating a file structure as follows - # test_dir -> file: file1.txt; dirs: subdir1_1, subdir1_2, subdir1_3 - # subdir1_1 -> file: file3.txt - # subdir1_2 -> dir: subdir2 - gfile.makedirs(dir_path) - gfile.GFile( - os.path.join(dir_path, "file1.txt"), mode="w").write("testing") - sub_dirs1 = ["subdir1_1", "subdir1_2", "subdir1_3"] - for name in sub_dirs1: - gfile.makedirs(os.path.join(dir_path, name)) - gfile.GFile( - os.path.join(dir_path, "subdir1_1/file2.txt"), - mode="w").write("testing") - gfile.makedirs(os.path.join(dir_path, "subdir1_2/subdir2")) - - def testWalkInOrder(self): - dir_path = os.path.join(self._base_dir, "test_dir") - self._setupWalkDirectories(dir_path) - # Now test the walk (topdown = True) - all_dirs = [] - all_subdirs = [] - all_files = [] - for (w_dir, w_subdirs, w_files) in gfile.walk(dir_path, topdown=True): - all_dirs.append(w_dir) - all_subdirs.append(w_subdirs) - all_files.append(w_files) - self.assertItemsEqual(all_dirs, [dir_path] + [ - os.path.join(dir_path, item) - for item in - ["subdir1_1", "subdir1_2", "subdir1_2/subdir2", "subdir1_3"] - ]) - 
self.assertEqual(dir_path, all_dirs[0]) - self.assertLess( - all_dirs.index(os.path.join(dir_path, "subdir1_2")), - all_dirs.index(os.path.join(dir_path, "subdir1_2/subdir2"))) - self.assertItemsEqual(all_subdirs[1:5], [[], ["subdir2"], [], []]) - self.assertItemsEqual(all_subdirs[0], - ["subdir1_1", "subdir1_2", "subdir1_3"]) - self.assertItemsEqual(all_files, [["file1.txt"], ["file2.txt"], [], [], []]) - self.assertLess( - all_files.index(["file1.txt"]), all_files.index(["file2.txt"])) - - def testWalkPostOrder(self): - dir_path = os.path.join(self._base_dir, "test_dir") - self._setupWalkDirectories(dir_path) - # Now test the walk (topdown = False) - all_dirs = [] - all_subdirs = [] - all_files = [] - for (w_dir, w_subdirs, w_files) in gfile.walk(dir_path, topdown=False): - all_dirs.append(w_dir) - all_subdirs.append(w_subdirs) - all_files.append(w_files) - self.assertItemsEqual(all_dirs, [ - os.path.join(dir_path, item) - for item in - ["subdir1_1", "subdir1_2/subdir2", "subdir1_2", "subdir1_3"] - ] + [dir_path]) - self.assertEqual(dir_path, all_dirs[4]) - self.assertLess( - all_dirs.index(os.path.join(dir_path, "subdir1_2/subdir2")), - all_dirs.index(os.path.join(dir_path, "subdir1_2"))) - self.assertItemsEqual(all_subdirs[0:4], [[], [], ["subdir2"], []]) - self.assertItemsEqual(all_subdirs[4], - ["subdir1_1", "subdir1_2", "subdir1_3"]) - self.assertItemsEqual(all_files, [["file2.txt"], [], [], [], ["file1.txt"]]) - self.assertLess( - all_files.index(["file2.txt"]), all_files.index(["file1.txt"])) - - def testWalkFailure(self): - dir_path = os.path.join(self._base_dir, "test_dir") - # Try walking a directory that wasn't created. 
- all_dirs = [] - all_subdirs = [] - all_files = [] - for (w_dir, w_subdirs, w_files) in gfile.walk(dir_path, topdown=False): - all_dirs.append(w_dir) - all_subdirs.append(w_subdirs) - all_files.append(w_files) - self.assertItemsEqual(all_dirs, []) - self.assertItemsEqual(all_subdirs, []) - self.assertItemsEqual(all_files, []) - - def testStat(self): - file_path = os.path.join(self._base_dir, "temp_file") - gfile.GFile(file_path, mode="w").write("testing") - file_statistics = gfile.stat(file_path) - os_statistics = os.stat(file_path) - self.assertEqual(7, file_statistics.length) - - def testRead(self): - file_path = os.path.join(self._base_dir, "temp_file") - with gfile.GFile(file_path, mode="w") as f: - f.write("testing1\ntesting2\ntesting3\n\ntesting5") - with gfile.GFile(file_path, mode="r") as f: - self.assertEqual(36, gfile.stat(file_path).length) - self.assertEqual("testing1\n", f.read(9)) - self.assertEqual("testing2\n", f.read(9)) - self.assertEqual("t", f.read(1)) - self.assertEqual("esting3\n\ntesting5", f.read()) - - def testReadingIterator(self): - file_path = os.path.join(self._base_dir, "temp_file") - data = ["testing1\n", "testing2\n", "testing3\n", "\n", "testing5"] - with gfile.GFile(file_path, mode="w") as f: - f.write("".join(data)) - with gfile.GFile(file_path, mode="r") as f: - actual_data = [] - for line in f: - actual_data.append(line) - self.assertSequenceEqual(actual_data, data) - - def testUTF8StringPath(self): - file_path = os.path.join(self._base_dir, "UTF8测试_file") - gfile._write_string_to_file(file_path, "testing") - with gfile.GFile(file_path, mode="rb") as f: - self.assertEqual(b"testing", f.read()) - - def testEof(self): - """Test that reading past EOF does not raise an exception.""" - - file_path = os.path.join(self._base_dir, "temp_file") - - with gfile.GFile(file_path, mode="w") as f: - content = "testing" - f.write(content) - f.flush() - with gfile.GFile(file_path, mode="r") as f: - self.assertEqual(content, f.read(len(content) 
+ 1)) - - def testUTF8StringPathExists(self): - file_path = os.path.join(self._base_dir, "UTF8测试_file_exist") - gfile._write_string_to_file(file_path, "testing") - v = gfile.exists(file_path) - self.assertEqual(v, True) +class FileIoTest(tb_test.TestCase): + def setUp(self): + self._base_dir = os.path.join(self.get_temp_dir(), "base_dir") + gfile.makedirs(self._base_dir) + + # It was a temp_dir anyway + # def tearDown(self): + # gfile.delete_recursively(self._base_dir) + + def testEmptyFilename(self): + f = gfile.GFile("", mode="r") + with self.assertRaises(errors.NotFoundError): + _ = f.read() + + def testFileDoesntExist(self): + file_path = os.path.join(self._base_dir, "temp_file") + self.assertFalse(gfile.exists(file_path)) + with self.assertRaises(errors.NotFoundError): + _ = gfile._read_file_to_string(file_path) + + def testWriteToString(self): + file_path = os.path.join(self._base_dir, "temp_file") + gfile._write_string_to_file(file_path, "testing") + self.assertTrue(gfile.exists(file_path)) + file_contents = gfile._read_file_to_string(file_path) + self.assertEqual("testing", file_contents) + + def testReadBinaryMode(self): + file_path = os.path.join(self._base_dir, "temp_file") + gfile._write_string_to_file(file_path, "testing") + with gfile.GFile(file_path, mode="rb") as f: + self.assertEqual(b"testing", f.read()) + + def testWriteBinaryMode(self): + file_path = os.path.join(self._base_dir, "temp_file") + gfile.GFile(file_path, "wb").write(compat.as_bytes("testing")) + with gfile.GFile(file_path, mode="r") as f: + self.assertEqual("testing", f.read()) + + def testMultipleFiles(self): + file_prefix = os.path.join(self._base_dir, "temp_file") + for i in range(5000): + + with gfile.GFile(file_prefix + str(i), mode="w") as f: + f.write("testing") + f.flush() + + with gfile.GFile(file_prefix + str(i), mode="r") as f: + self.assertEqual("testing", f.read()) + + def testMultipleWrites(self): + file_path = os.path.join(self._base_dir, "temp_file") + with 
gfile.GFile(file_path, mode="w") as f: + f.write("line1\n") + f.write("line2") + file_contents = gfile._read_file_to_string(file_path) + self.assertEqual("line1\nline2", file_contents) + + def testFileWriteBadMode(self): + file_path = os.path.join(self._base_dir, "temp_file") + with self.assertRaises(errors.PermissionDeniedError): + gfile.GFile(file_path, mode="r").write("testing") + + def testFileReadBadMode(self): + file_path = os.path.join(self._base_dir, "temp_file") + gfile.GFile(file_path, mode="w").write("testing") + self.assertTrue(gfile.exists(file_path)) + with self.assertRaises(errors.PermissionDeniedError): + gfile.GFile(file_path, mode="w").read() + + def testIsDirectory(self): + dir_path = os.path.join(self._base_dir, "test_dir") + # Failure for a non-existing dir. + self.assertFalse(gfile.isdir(dir_path)) + gfile.makedirs(dir_path) + self.assertTrue(gfile.isdir(dir_path)) + file_path = os.path.join(dir_path, "test_file") + gfile.GFile(file_path, mode="w").write("test") + # False for a file. + self.assertFalse(gfile.isdir(file_path)) + # Test that the value returned from `stat()` has `is_directory` set. 
+ # file_statistics = gfile.stat(dir_path) + # self.assertTrue(file_statistics.is_directory) + + def testListDirectory(self): + dir_path = os.path.join(self._base_dir, "test_dir") + gfile.makedirs(dir_path) + files = ["file1.txt", "file2.txt", "file3.txt"] + for name in files: + file_path = os.path.join(dir_path, name) + gfile.GFile(file_path, mode="w").write("testing") + subdir_path = os.path.join(dir_path, "sub_dir") + gfile.makedirs(subdir_path) + subdir_file_path = os.path.join(subdir_path, "file4.txt") + gfile.GFile(subdir_file_path, mode="w").write("testing") + dir_list = gfile.listdir(dir_path) + self.assertItemsEqual(files + ["sub_dir"], dir_list) + + def testListDirectoryFailure(self): + dir_path = os.path.join(self._base_dir, "test_dir") + with self.assertRaises(errors.NotFoundError): + gfile.listdir(dir_path) + + def _setupWalkDirectories(self, dir_path): + # Creating a file structure as follows + # test_dir -> file: file1.txt; dirs: subdir1_1, subdir1_2, subdir1_3 + # subdir1_1 -> file: file3.txt + # subdir1_2 -> dir: subdir2 + gfile.makedirs(dir_path) + gfile.GFile(os.path.join(dir_path, "file1.txt"), mode="w").write( + "testing" + ) + sub_dirs1 = ["subdir1_1", "subdir1_2", "subdir1_3"] + for name in sub_dirs1: + gfile.makedirs(os.path.join(dir_path, name)) + gfile.GFile( + os.path.join(dir_path, "subdir1_1/file2.txt"), mode="w" + ).write("testing") + gfile.makedirs(os.path.join(dir_path, "subdir1_2/subdir2")) + + def testWalkInOrder(self): + dir_path = os.path.join(self._base_dir, "test_dir") + self._setupWalkDirectories(dir_path) + # Now test the walk (topdown = True) + all_dirs = [] + all_subdirs = [] + all_files = [] + for (w_dir, w_subdirs, w_files) in gfile.walk(dir_path, topdown=True): + all_dirs.append(w_dir) + all_subdirs.append(w_subdirs) + all_files.append(w_files) + self.assertItemsEqual( + all_dirs, + [dir_path] + + [ + os.path.join(dir_path, item) + for item in [ + "subdir1_1", + "subdir1_2", + "subdir1_2/subdir2", + "subdir1_3", + ] + ], 
+ ) + self.assertEqual(dir_path, all_dirs[0]) + self.assertLess( + all_dirs.index(os.path.join(dir_path, "subdir1_2")), + all_dirs.index(os.path.join(dir_path, "subdir1_2/subdir2")), + ) + self.assertItemsEqual(all_subdirs[1:5], [[], ["subdir2"], [], []]) + self.assertItemsEqual( + all_subdirs[0], ["subdir1_1", "subdir1_2", "subdir1_3"] + ) + self.assertItemsEqual( + all_files, [["file1.txt"], ["file2.txt"], [], [], []] + ) + self.assertLess( + all_files.index(["file1.txt"]), all_files.index(["file2.txt"]) + ) + + def testWalkPostOrder(self): + dir_path = os.path.join(self._base_dir, "test_dir") + self._setupWalkDirectories(dir_path) + # Now test the walk (topdown = False) + all_dirs = [] + all_subdirs = [] + all_files = [] + for (w_dir, w_subdirs, w_files) in gfile.walk(dir_path, topdown=False): + all_dirs.append(w_dir) + all_subdirs.append(w_subdirs) + all_files.append(w_files) + self.assertItemsEqual( + all_dirs, + [ + os.path.join(dir_path, item) + for item in [ + "subdir1_1", + "subdir1_2/subdir2", + "subdir1_2", + "subdir1_3", + ] + ] + + [dir_path], + ) + self.assertEqual(dir_path, all_dirs[4]) + self.assertLess( + all_dirs.index(os.path.join(dir_path, "subdir1_2/subdir2")), + all_dirs.index(os.path.join(dir_path, "subdir1_2")), + ) + self.assertItemsEqual(all_subdirs[0:4], [[], [], ["subdir2"], []]) + self.assertItemsEqual( + all_subdirs[4], ["subdir1_1", "subdir1_2", "subdir1_3"] + ) + self.assertItemsEqual( + all_files, [["file2.txt"], [], [], [], ["file1.txt"]] + ) + self.assertLess( + all_files.index(["file2.txt"]), all_files.index(["file1.txt"]) + ) + + def testWalkFailure(self): + dir_path = os.path.join(self._base_dir, "test_dir") + # Try walking a directory that wasn't created. 
+ all_dirs = [] + all_subdirs = [] + all_files = [] + for (w_dir, w_subdirs, w_files) in gfile.walk(dir_path, topdown=False): + all_dirs.append(w_dir) + all_subdirs.append(w_subdirs) + all_files.append(w_files) + self.assertItemsEqual(all_dirs, []) + self.assertItemsEqual(all_subdirs, []) + self.assertItemsEqual(all_files, []) + + def testStat(self): + file_path = os.path.join(self._base_dir, "temp_file") + gfile.GFile(file_path, mode="w").write("testing") + file_statistics = gfile.stat(file_path) + os_statistics = os.stat(file_path) + self.assertEqual(7, file_statistics.length) + + def testRead(self): + file_path = os.path.join(self._base_dir, "temp_file") + with gfile.GFile(file_path, mode="w") as f: + f.write("testing1\ntesting2\ntesting3\n\ntesting5") + with gfile.GFile(file_path, mode="r") as f: + self.assertEqual(36, gfile.stat(file_path).length) + self.assertEqual("testing1\n", f.read(9)) + self.assertEqual("testing2\n", f.read(9)) + self.assertEqual("t", f.read(1)) + self.assertEqual("esting3\n\ntesting5", f.read()) + + def testReadingIterator(self): + file_path = os.path.join(self._base_dir, "temp_file") + data = ["testing1\n", "testing2\n", "testing3\n", "\n", "testing5"] + with gfile.GFile(file_path, mode="w") as f: + f.write("".join(data)) + with gfile.GFile(file_path, mode="r") as f: + actual_data = [] + for line in f: + actual_data.append(line) + self.assertSequenceEqual(actual_data, data) + + def testUTF8StringPath(self): + file_path = os.path.join(self._base_dir, "UTF8测试_file") + gfile._write_string_to_file(file_path, "testing") + with gfile.GFile(file_path, mode="rb") as f: + self.assertEqual(b"testing", f.read()) + + def testEof(self): + """Test that reading past EOF does not raise an exception.""" + + file_path = os.path.join(self._base_dir, "temp_file") + + with gfile.GFile(file_path, mode="w") as f: + content = "testing" + f.write(content) + f.flush() + with gfile.GFile(file_path, mode="r") as f: + self.assertEqual(content, f.read(len(content) 
+ 1)) + + def testUTF8StringPathExists(self): + file_path = os.path.join(self._base_dir, "UTF8测试_file_exist") + gfile._write_string_to_file(file_path, "testing") + v = gfile.exists(file_path) + self.assertEqual(v, True) if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/compat/tensorflow_stub/pywrap_tensorflow.py b/tensorboard/compat/tensorflow_stub/pywrap_tensorflow.py index a41563e698..649ac598eb 100644 --- a/tensorboard/compat/tensorflow_stub/pywrap_tensorflow.py +++ b/tensorboard/compat/tensorflow_stub/pywrap_tensorflow.py @@ -41,11 +41,11 @@ def TF_bfloat16_type(): def masked_crc32c(data): x = u32(crc32c(data)) - return u32(((x >> 15) | u32(x << 17)) + 0xa282ead8) + return u32(((x >> 15) | u32(x << 17)) + 0xA282EAD8) def u32(x): - return x & 0xffffffff + return x & 0xFFFFFFFF # fmt: off @@ -125,6 +125,7 @@ def u32(x): def crc_update(crc, data): """Update CRC-32C checksum with data. + Args: crc: 32-bit checksum to update as long. data: byte array, string or iterable over bytes. @@ -139,13 +140,14 @@ def crc_update(crc, data): crc ^= _MASK for b in buf: - table_index = (crc ^ b) & 0xff + table_index = (crc ^ b) & 0xFF crc = (CRC_TABLE[table_index] ^ (crc >> 8)) & _MASK return crc ^ _MASK def crc_finalize(crc): """Finalize CRC-32C checksum. + This function should be called as last step of crc calculation. Args: crc: 32-bit checksum as long. @@ -157,6 +159,7 @@ def crc_finalize(crc): def crc32c(data): """Compute CRC-32C checksum of the data. + Args: data: byte array, string or iterable over bytes. 
Returns: @@ -166,28 +169,34 @@ def crc32c(data): class PyRecordReader_New: - def __init__(self, filename=None, start_offset=0, compression_type=None, - status=None): + def __init__( + self, filename=None, start_offset=0, compression_type=None, status=None + ): if filename is None: raise errors.NotFoundError( - None, None, 'No filename provided, cannot read Events') + None, None, "No filename provided, cannot read Events" + ) if not gfile.exists(filename): raise errors.NotFoundError( - None, None, - '{} does not point to valid Events file'.format(filename)) + None, + None, + "{} does not point to valid Events file".format(filename), + ) if start_offset: raise errors.UnimplementedError( - None, None, 'start offset not supported by compat reader') + None, None, "start offset not supported by compat reader" + ) if compression_type: # TODO: Handle gzip and zlib compressed files raise errors.UnimplementedError( - None, None, 'compression not supported by compat reader') + None, None, "compression not supported by compat reader" + ) self.filename = filename self.start_offset = start_offset self.compression_type = compression_type self.status = status self.curr_event = None - self.file_handle = gfile.GFile(self.filename, 'rb') + self.file_handle = gfile.GFile(self.filename, "rb") def GetNext(self): # Read the header @@ -195,18 +204,17 @@ def GetNext(self): header_str = self.file_handle.read(8) if len(header_str) != 8: # Hit EOF so raise and exit - raise errors.OutOfRangeError(None, None, 'No more events to read') - header = struct.unpack('Q', header_str) + raise errors.OutOfRangeError(None, None, "No more events to read") + header = struct.unpack("Q", header_str) # Read the crc32, which is 4 bytes, and check it against # the crc32 of the header crc_header_str = self.file_handle.read(4) - crc_header = struct.unpack('I', crc_header_str) + crc_header = struct.unpack("I", crc_header_str) header_crc_calc = masked_crc32c(header_str) if header_crc_calc != crc_header[0]: raise 
errors.DataLossError( - None, None, - '{} failed header crc32 check'.format(self.filename) + None, None, "{} failed header crc32 check".format(self.filename) ) # The length of the header tells us how many bytes the Event @@ -221,11 +229,12 @@ def GetNext(self): # has no crc32, in which case we skip. crc_event_str = self.file_handle.read(4) if crc_event_str: - crc_event = struct.unpack('I', crc_event_str) + crc_event = struct.unpack("I", crc_event_str) if event_crc_calc != crc_event[0]: raise errors.DataLossError( - None, None, - '{} failed event crc32 check'.format(self.filename) + None, + None, + "{} failed event crc32 check".format(self.filename), ) # Set the current event to be read later by record() call diff --git a/tensorboard/compat/tensorflow_stub/tensor_shape.py b/tensorboard/compat/tensorflow_stub/tensor_shape.py index 1133b0d087..bdb3f7e075 100644 --- a/tensorboard/compat/tensorflow_stub/tensor_shape.py +++ b/tensorboard/compat/tensorflow_stub/tensor_shape.py @@ -49,7 +49,8 @@ def __str__(self): return "?" if value is None else str(value) def __eq__(self, other): - """Returns true if `other` has the same known value as this Dimension.""" + """Returns true if `other` has the same known value as this + Dimension.""" try: other = as_dimension(other) except (TypeError, ValueError): @@ -98,10 +99,15 @@ def is_convertible_with(self, other): True if this Dimension and `other` are convertible. """ other = as_dimension(other) - return self._value is None or other.value is None or self._value == other.value + return ( + self._value is None + or other.value is None + or self._value == other.value + ) def assert_is_convertible_with(self, other): - """Raises an exception if `other` is not convertible with this Dimension. + """Raises an exception if `other` is not convertible with this + Dimension. Args: other: Another Dimension. @@ -111,10 +117,13 @@ def assert_is_convertible_with(self, other): is_convertible_with). 
""" if not self.is_convertible_with(other): - raise ValueError("Dimensions %s and %s are not convertible" % (self, other)) + raise ValueError( + "Dimensions %s and %s are not convertible" % (self, other) + ) def merge_with(self, other): - """Returns a Dimension that combines the information in `self` and `other`. + """Returns a Dimension that combines the information in `self` and + `other`. Dimensions are combined as follows: @@ -433,7 +442,8 @@ def __gt__(self, other): return self._value > other.value def __ge__(self, other): - """Returns True if `self` is known to be greater than or equal to `other`. + """Returns True if `self` is known to be greater than or equal to + `other`. Dimensions are compared as follows: @@ -554,7 +564,8 @@ def __str__(self): @property def dims(self): - """Returns a list of Dimensions, or None if the shape is unspecified.""" + """Returns a list of Dimensions, or None if the shape is + unspecified.""" return self._dims @dims.setter @@ -573,9 +584,12 @@ def ndims(self): return self._ndims def __len__(self): - """Returns the rank of this shape, or raises ValueError if unspecified.""" + """Returns the rank of this shape, or raises ValueError if + unspecified.""" if self._dims is None: - raise ValueError("Cannot take the length of Shape with unknown rank.") + raise ValueError( + "Cannot take the length of Shape with unknown rank." 
+ ) return self.ndims def __bool__(self): @@ -586,7 +600,8 @@ def __bool__(self): __nonzero__ = __bool__ def __iter__(self): - """Returns `self.dims` if the rank is known, otherwise raises ValueError.""" + """Returns `self.dims` if the rank is known, otherwise raises + ValueError.""" if self._dims is None: raise ValueError("Cannot iterate over a shape with unknown rank.") else: @@ -637,7 +652,8 @@ def __getitem__(self, key): return Dimension(None) def num_elements(self): - """Returns the total number of elements, or none for incomplete shapes.""" + """Returns the total number of elements, or none for incomplete + shapes.""" if self.is_fully_defined(): size = 1 for dim in self._dims: @@ -647,7 +663,8 @@ def num_elements(self): return None def merge_with(self, other): - """Returns a `TensorShape` combining the information in `self` and `other`. + """Returns a `TensorShape` combining the information in `self` and + `other`. The dimensions in `self` and `other` are merged elementwise, according to the rules defined for `Dimension.merge_with()`. @@ -673,7 +690,9 @@ def merge_with(self, other): new_dims.append(dim.merge_with(other[i])) return TensorShape(new_dims) except ValueError: - raise ValueError("Shapes %s and %s are not convertible" % (self, other)) + raise ValueError( + "Shapes %s and %s are not convertible" % (self, other) + ) def concatenate(self, other): """Returns the concatenation of the dimension in `self` and `other`. @@ -699,7 +718,8 @@ def concatenate(self, other): return TensorShape(self._dims + other.dims) def assert_same_rank(self, other): - """Raises an exception if `self` and `other` do not have convertible ranks. + """Raises an exception if `self` and `other` do not have convertible + ranks. Args: other: Another `TensorShape`. @@ -716,7 +736,8 @@ def assert_same_rank(self, other): ) def assert_has_rank(self, rank): - """Raises an exception if `self` is not convertible with the given `rank`. 
+ """Raises an exception if `self` is not convertible with the given + `rank`. Args: rank: An integer. @@ -762,7 +783,9 @@ def with_rank_at_least(self, rank): `rank`. """ if self.ndims is not None and self.ndims < rank: - raise ValueError("Shape %s must have rank at least %d" % (self, rank)) + raise ValueError( + "Shape %s must have rank at least %d" % (self, rank) + ) else: return self @@ -781,7 +804,9 @@ def with_rank_at_most(self, rank): `rank`. """ if self.ndims is not None and self.ndims > rank: - raise ValueError("Shape %s must have rank at most %d" % (self, rank)) + raise ValueError( + "Shape %s must have rank at most %d" % (self, rank) + ) else: return self @@ -821,7 +846,6 @@ def is_convertible_with(self, other): Returns: True iff `self` is convertible with `other`. - """ other = as_shape(other) if self._dims is not None and other.dims is not None: @@ -833,7 +857,8 @@ def is_convertible_with(self, other): return True def assert_is_convertible_with(self, other): - """Raises exception if `self` and `other` do not represent the same shape. + """Raises exception if `self` and `other` do not represent the same + shape. This method can be used to assert that there exists a shape that both `self` and `other` represent. @@ -845,10 +870,13 @@ def assert_is_convertible_with(self, other): ValueError: If `self` and `other` do not represent the same shape. """ if not self.is_convertible_with(other): - raise ValueError("Shapes %s and %s are inconvertible" % (self, other)) + raise ValueError( + "Shapes %s and %s are inconvertible" % (self, other) + ) def most_specific_convertible_shape(self, other): - """Returns the most specific TensorShape convertible with `self` and `other`. + """Returns the most specific TensorShape convertible with `self` and + `other`. * TensorShape([None, 1]) is the most specific TensorShape convertible with both TensorShape([2, 1]) and TensorShape([5, 1]). 
Note that @@ -868,7 +896,11 @@ def most_specific_convertible_shape(self, other): """ other = as_shape(other) - if self._dims is None or other.dims is None or self.ndims != other.ndims: + if ( + self._dims is None + or other.dims is None + or self.ndims != other.ndims + ): return unknown_shape() dims = [(Dimension(None))] * self.ndims @@ -884,7 +916,8 @@ def is_fully_defined(self): ) def assert_is_fully_defined(self): - """Raises an exception if `self` is not fully defined in every dimension. + """Raises an exception if `self` is not fully defined in every + dimension. Raises: ValueError: If `self` does not have a known value for every dimension. @@ -902,7 +935,9 @@ def as_list(self): ValueError: If `self` is an unknown shape with an unknown rank. """ if self._dims is None: - raise ValueError("as_list() is not defined on an unknown TensorShape.") + raise ValueError( + "as_list() is not defined on an unknown TensorShape." + ) return [dim.value for dim in self._dims] def as_proto(self): @@ -934,7 +969,9 @@ def __ne__(self, other): except TypeError: return NotImplemented if self.ndims is None or other.ndims is None: - raise ValueError("The inequality of unknown TensorShapes is undefined.") + raise ValueError( + "The inequality of unknown TensorShapes is undefined." + ) if self.ndims != other.ndims: return True return self._dims != other.dims diff --git a/tensorboard/data/provider.py b/tensorboard/data/provider.py index ad2dee8d49..047c574f28 100644 --- a/tensorboard/data/provider.py +++ b/tensorboard/data/provider.py @@ -27,750 +27,769 @@ @six.add_metaclass(abc.ABCMeta) class DataProvider(object): - """Interface for reading TensorBoard scalar, tensor, and blob data. + """Interface for reading TensorBoard scalar, tensor, and blob data. - These APIs are under development and subject to change. For instance, - providers may be asked to implement more filtering mechanisms, such as - downsampling strategies or domain restriction by step or wall time. 
+ These APIs are under development and subject to change. For instance, + providers may be asked to implement more filtering mechanisms, such as + downsampling strategies or domain restriction by step or wall time. - Unless otherwise noted, any methods on this class may raise errors - defined in `tensorboard.errors`, like `tensorboard.errors.NotFoundError`. - """ - - def data_location(self, experiment_id): - """Render a human-readable description of the data source. - - For instance, this might return a path to a directory on disk. - - The default implementation always returns the empty string. - - Args: - experiment_id: ID of enclosing experiment. - - Returns: - A string, which may be empty. + Unless otherwise noted, any methods on this class may raise errors + defined in `tensorboard.errors`, like `tensorboard.errors.NotFoundError`. """ - return "" - - @abc.abstractmethod - def list_runs(self, experiment_id): - """List all runs within an experiment. - Args: - experiment_id: ID of enclosing experiment. + def data_location(self, experiment_id): + """Render a human-readable description of the data source. + + For instance, this might return a path to a directory on disk. + + The default implementation always returns the empty string. + + Args: + experiment_id: ID of enclosing experiment. + + Returns: + A string, which may be empty. + """ + return "" + + @abc.abstractmethod + def list_runs(self, experiment_id): + """List all runs within an experiment. + + Args: + experiment_id: ID of enclosing experiment. + + Returns: + A collection of `Run` values. + + Raises: + tensorboard.errors.PublicError: See `DataProvider` class docstring. + """ + pass + + @abc.abstractmethod + def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None): + """List metadata about scalar time series. + + Args: + experiment_id: ID of enclosing experiment. + plugin_name: String name of the TensorBoard plugin that created + the data to be queried. Required. 
+ run_tag_filter: Optional `RunTagFilter` value. If omitted, all + runs and tags will be included. + + The result will only contain keys for run-tag combinations that + actually exist, which may not include all entries in the + `run_tag_filter`. + + Returns: + A nested map `d` such that `d[run][tag]` is a `ScalarTimeSeries` + value. + + Raises: + tensorboard.errors.PublicError: See `DataProvider` class docstring. + """ + pass + + @abc.abstractmethod + def read_scalars( + self, experiment_id, plugin_name, downsample=None, run_tag_filter=None + ): + """Read values from scalar time series. + + Args: + experiment_id: ID of enclosing experiment. + plugin_name: String name of the TensorBoard plugin that created + the data to be queried. Required. + downsample: Integer number of steps to which to downsample the + results (e.g., `1000`). Required. + run_tag_filter: Optional `RunTagFilter` value. If provided, a time + series will only be included in the result if its run and tag + both pass this filter. If `None`, all time series will be + included. + + The result will only contain keys for run-tag combinations that + actually exist, which may not include all entries in the + `run_tag_filter`. + + Returns: + A nested map `d` such that `d[run][tag]` is a list of + `ScalarDatum` values sorted by step. + + Raises: + tensorboard.errors.PublicError: See `DataProvider` class docstring. + """ + pass + + def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None): + """List metadata about tensor time series. + + Args: + experiment_id: ID of enclosing experiment. + plugin_name: String name of the TensorBoard plugin that created + the data to be queried. Required. + run_tag_filter: Optional `RunTagFilter` value. If omitted, all + runs and tags will be included. + + The result will only contain keys for run-tag combinations that + actually exist, which may not include all entries in the + `run_tag_filter`. 
+ + Returns: + A nested map `d` such that `d[run][tag]` is a `TensorTimeSeries` + value. + + Raises: + tensorboard.errors.PublicError: See `DataProvider` class docstring. + """ + pass + + def read_tensors( + self, experiment_id, plugin_name, downsample=None, run_tag_filter=None + ): + """Read values from tensor time series. + + Args: + experiment_id: ID of enclosing experiment. + plugin_name: String name of the TensorBoard plugin that created + the data to be queried. Required. + downsample: Integer number of steps to which to downsample the + results (e.g., `1000`). Required. + run_tag_filter: Optional `RunTagFilter` value. If provided, a time + series will only be included in the result if its run and tag + both pass this filter. If `None`, all time series will be + included. + + The result will only contain keys for run-tag combinations that + actually exist, which may not include all entries in the + `run_tag_filter`. + + Returns: + A nested map `d` such that `d[run][tag]` is a list of + `TensorDatum` values sorted by step. + + Raises: + tensorboard.errors.PublicError: See `DataProvider` class docstring. + """ + pass + + def list_blob_sequences( + self, experiment_id, plugin_name, run_tag_filter=None + ): + """List metadata about blob sequence time series. + + Args: + experiment_id: ID of enclosing experiment. + plugin_name: String name of the TensorBoard plugin that created the data + to be queried. Required. + run_tag_filter: Optional `RunTagFilter` value. If omitted, all runs and + tags will be included. The result will only contain keys for run-tag + combinations that actually exist, which may not include all entries in + the `run_tag_filter`. + + Returns: + A nested map `d` such that `d[run][tag]` is a `BlobSequenceTimeSeries` + value. + + Raises: + tensorboard.errors.PublicError: See `DataProvider` class docstring. 
+ """ + pass + + def read_blob_sequences( + self, experiment_id, plugin_name, downsample=None, run_tag_filter=None + ): + """Read values from blob sequence time series. + + Args: + experiment_id: ID of enclosing experiment. + plugin_name: String name of the TensorBoard plugin that created the data + to be queried. Required. + downsample: Integer number of steps to which to downsample the results + (e.g., `1000`). Required. + run_tag_filter: Optional `RunTagFilter` value. If provided, a time series + will only be included in the result if its run and tag both pass this + filter. If `None`, all time series will be included. The result will + only contain keys for run-tag combinations that actually exist, which + may not include all entries in the `run_tag_filter`. + + Returns: + A nested map `d` such that `d[run][tag]` is a list of + `BlobSequenceDatum` values sorted by step. + + Raises: + tensorboard.errors.PublicError: See `DataProvider` class docstring. + """ + pass + + def read_blob(self, blob_key): + """Read data for a single blob. + + Args: + blob_key: A key identifying the desired blob, as provided by + `read_blob_sequences(...)`. + + Returns: + Raw binary data as `bytes`. + + Raises: + tensorboard.errors.PublicError: See `DataProvider` class docstring. + """ + pass - Returns: - A collection of `Run` values. - Raises: - tensorboard.errors.PublicError: See `DataProvider` class docstring. +class Run(object): + """Metadata about a run. + + Attributes: + run_id: A unique opaque string identifier for this run. + run_name: A user-facing name for this run (as a `str`). + start_time: The wall time of the earliest recorded event in this + run, as `float` seconds since epoch, or `None` if this run has no + recorded events. """ - pass - - @abc.abstractmethod - def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None): - """List metadata about scalar time series. - Args: - experiment_id: ID of enclosing experiment. 
- plugin_name: String name of the TensorBoard plugin that created - the data to be queried. Required. - run_tag_filter: Optional `RunTagFilter` value. If omitted, all - runs and tags will be included. + __slots__ = ("_run_id", "_run_name", "_start_time") + + def __init__(self, run_id, run_name, start_time): + self._run_id = run_id + self._run_name = run_name + self._start_time = start_time + + @property + def run_id(self): + return self._run_id + + @property + def run_name(self): + return self._run_name + + @property + def start_time(self): + return self._start_time + + def __eq__(self, other): + if not isinstance(other, Run): + return False + if self._run_id != other._run_id: + return False + if self._run_name != other._run_name: + return False + if self._start_time != other._start_time: + return False + return True + + def __hash__(self): + return hash((self._run_id, self._run_name, self._start_time)) + + def __repr__(self): + return "Run(%s)" % ", ".join( + ( + "run_id=%r" % (self._run_id,), + "run_name=%r" % (self._run_name,), + "start_time=%r" % (self._start_time,), + ) + ) - The result will only contain keys for run-tag combinations that - actually exist, which may not include all entries in the - `run_tag_filter`. - Returns: - A nested map `d` such that `d[run][tag]` is a `ScalarTimeSeries` - value. +class _TimeSeries(object): + """Metadata about time series data for a particular run and tag. - Raises: - tensorboard.errors.PublicError: See `DataProvider` class docstring. + Superclass of `ScalarTimeSeries` and `BlobSequenceTimeSeries`. """ - pass - - @abc.abstractmethod - def read_scalars( - self, experiment_id, plugin_name, downsample=None, run_tag_filter=None - ): - """Read values from scalar time series. - - Args: - experiment_id: ID of enclosing experiment. - plugin_name: String name of the TensorBoard plugin that created - the data to be queried. Required. - downsample: Integer number of steps to which to downsample the - results (e.g., `1000`). 
Required. - run_tag_filter: Optional `RunTagFilter` value. If provided, a time - series will only be included in the result if its run and tag - both pass this filter. If `None`, all time series will be - included. - - The result will only contain keys for run-tag combinations that - actually exist, which may not include all entries in the - `run_tag_filter`. - - Returns: - A nested map `d` such that `d[run][tag]` is a list of - `ScalarDatum` values sorted by step. - - Raises: - tensorboard.errors.PublicError: See `DataProvider` class docstring. - """ - pass - def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None): - """List metadata about tensor time series. + __slots__ = ( + "_max_step", + "_max_wall_time", + "_plugin_content", + "_description", + "_display_name", + ) - Args: - experiment_id: ID of enclosing experiment. - plugin_name: String name of the TensorBoard plugin that created - the data to be queried. Required. - run_tag_filter: Optional `RunTagFilter` value. If omitted, all - runs and tags will be included. + def __init__( + self, max_step, max_wall_time, plugin_content, description, display_name + ): + self._max_step = max_step + self._max_wall_time = max_wall_time + self._plugin_content = plugin_content + self._description = description + self._display_name = display_name - The result will only contain keys for run-tag combinations that - actually exist, which may not include all entries in the - `run_tag_filter`. + @property + def max_step(self): + return self._max_step - Returns: - A nested map `d` such that `d[run][tag]` is a `TensorTimeSeries` - value. + @property + def max_wall_time(self): + return self._max_wall_time - Raises: - tensorboard.errors.PublicError: See `DataProvider` class docstring. - """ - pass - - def read_tensors( - self, experiment_id, plugin_name, downsample=None, run_tag_filter=None - ): - """Read values from tensor time series. - - Args: - experiment_id: ID of enclosing experiment. 
- plugin_name: String name of the TensorBoard plugin that created - the data to be queried. Required. - downsample: Integer number of steps to which to downsample the - results (e.g., `1000`). Required. - run_tag_filter: Optional `RunTagFilter` value. If provided, a time - series will only be included in the result if its run and tag - both pass this filter. If `None`, all time series will be - included. - - The result will only contain keys for run-tag combinations that - actually exist, which may not include all entries in the - `run_tag_filter`. - - Returns: - A nested map `d` such that `d[run][tag]` is a list of - `TensorDatum` values sorted by step. - - Raises: - tensorboard.errors.PublicError: See `DataProvider` class docstring. - """ - pass - - def list_blob_sequences( - self, experiment_id, plugin_name, run_tag_filter=None - ): - """List metadata about blob sequence time series. - - Args: - experiment_id: ID of enclosing experiment. - plugin_name: String name of the TensorBoard plugin that created the data - to be queried. Required. - run_tag_filter: Optional `RunTagFilter` value. If omitted, all runs and - tags will be included. The result will only contain keys for run-tag - combinations that actually exist, which may not include all entries in - the `run_tag_filter`. - - Returns: - A nested map `d` such that `d[run][tag]` is a `BlobSequenceTimeSeries` - value. - - Raises: - tensorboard.errors.PublicError: See `DataProvider` class docstring. - """ - pass - - def read_blob_sequences( - self, experiment_id, plugin_name, downsample=None, run_tag_filter=None - ): - """Read values from blob sequence time series. - - Args: - experiment_id: ID of enclosing experiment. - plugin_name: String name of the TensorBoard plugin that created the data - to be queried. Required. - downsample: Integer number of steps to which to downsample the results - (e.g., `1000`). Required. - run_tag_filter: Optional `RunTagFilter` value. 
If provided, a time series - will only be included in the result if its run and tag both pass this - filter. If `None`, all time series will be included. The result will - only contain keys for run-tag combinations that actually exist, which - may not include all entries in the `run_tag_filter`. - - Returns: - A nested map `d` such that `d[run][tag]` is a list of - `BlobSequenceDatum` values sorted by step. - - Raises: - tensorboard.errors.PublicError: See `DataProvider` class docstring. - """ - pass + @property + def plugin_content(self): + return self._plugin_content - def read_blob(self, blob_key): - """Read data for a single blob. + @property + def description(self): + return self._description - Args: - blob_key: A key identifying the desired blob, as provided by - `read_blob_sequences(...)`. + @property + def display_name(self): + return self._display_name - Returns: - Raw binary data as `bytes`. - Raises: - tensorboard.errors.PublicError: See `DataProvider` class docstring. +class ScalarTimeSeries(_TimeSeries): + """Metadata about a scalar time series for a particular run and tag. + + Attributes: + max_step: The largest step value of any datum in this scalar time series; a + nonnegative integer. + max_wall_time: The largest wall time of any datum in this time series, as + `float` seconds since epoch. + plugin_content: A bytestring of arbitrary plugin-specific metadata for this + time series, as provided to `tf.summary.write` in the + `plugin_data.content` field of the `metadata` argument. + description: An optional long-form Markdown description, as a `str` that is + empty if no description was specified. + display_name: An optional long-form Markdown description, as a `str` that is + empty if no description was specified. Deprecated; may be removed soon. 
""" - pass + def __eq__(self, other): + if not isinstance(other, ScalarTimeSeries): + return False + if self._max_step != other._max_step: + return False + if self._max_wall_time != other._max_wall_time: + return False + if self._plugin_content != other._plugin_content: + return False + if self._description != other._description: + return False + if self._display_name != other._display_name: + return False + return True + + def __hash__(self): + return hash( + ( + self._max_step, + self._max_wall_time, + self._plugin_content, + self._description, + self._display_name, + ) + ) + + def __repr__(self): + return "ScalarTimeSeries(%s)" % ", ".join( + ( + "max_step=%r" % (self._max_step,), + "max_wall_time=%r" % (self._max_wall_time,), + "plugin_content=%r" % (self._plugin_content,), + "description=%r" % (self._description,), + "display_name=%r" % (self._display_name,), + ) + ) -class Run(object): - """Metadata about a run. - - Attributes: - run_id: A unique opaque string identifier for this run. - run_name: A user-facing name for this run (as a `str`). - start_time: The wall time of the earliest recorded event in this - run, as `float` seconds since epoch, or `None` if this run has no - recorded events. 
- """ - - __slots__ = ("_run_id", "_run_name", "_start_time") - - def __init__(self, run_id, run_name, start_time): - self._run_id = run_id - self._run_name = run_name - self._start_time = start_time - - @property - def run_id(self): - return self._run_id - - @property - def run_name(self): - return self._run_name - - @property - def start_time(self): - return self._start_time - - def __eq__(self, other): - if not isinstance(other, Run): - return False - if self._run_id != other._run_id: - return False - if self._run_name != other._run_name: - return False - if self._start_time != other._start_time: - return False - return True - - def __hash__(self): - return hash((self._run_id, self._run_name, self._start_time)) - - def __repr__(self): - return "Run(%s)" % ", ".join(( - "run_id=%r" % (self._run_id,), - "run_name=%r" % (self._run_name,), - "start_time=%r" % (self._start_time,), - )) +class ScalarDatum(object): + """A single datum in a scalar time series for a run and tag. + + Attributes: + step: The global step at which this datum occurred; an integer. This + is a unique key among data of this time series. + wall_time: The real-world time at which this datum occurred, as + `float` seconds since epoch. + value: The scalar value for this datum; a `float`. + """ -class _TimeSeries(object): - """Metadata about time series data for a particular run and tag. 
+ __slots__ = ("_step", "_wall_time", "_value") + + def __init__(self, step, wall_time, value): + self._step = step + self._wall_time = wall_time + self._value = value + + @property + def step(self): + return self._step + + @property + def wall_time(self): + return self._wall_time + + @property + def value(self): + return self._value + + def __eq__(self, other): + if not isinstance(other, ScalarDatum): + return False + if self._step != other._step: + return False + if self._wall_time != other._wall_time: + return False + if self._value != other._value: + return False + return True + + def __hash__(self): + return hash((self._step, self._wall_time, self._value)) + + def __repr__(self): + return "ScalarDatum(%s)" % ", ".join( + ( + "step=%r" % (self._step,), + "wall_time=%r" % (self._wall_time,), + "value=%r" % (self._value,), + ) + ) - Superclass of `ScalarTimeSeries` and `BlobSequenceTimeSeries`. - """ - __slots__ = ( - "_max_step", - "_max_wall_time", - "_plugin_content", - "_description", - "_display_name", - ) +class TensorTimeSeries(_TimeSeries): + """Metadata about a tensor time series for a particular run and tag. + + Attributes: + max_step: The largest step value of any datum in this tensor time series; a + nonnegative integer. + max_wall_time: The largest wall time of any datum in this time series, as + `float` seconds since epoch. + plugin_content: A bytestring of arbitrary plugin-specific metadata for this + time series, as provided to `tf.summary.write` in the + `plugin_data.content` field of the `metadata` argument. + description: An optional long-form Markdown description, as a `str` that is + empty if no description was specified. + display_name: An optional long-form Markdown description, as a `str` that is + empty if no description was specified. Deprecated; may be removed soon. 
+ """ - def __init__( - self, max_step, max_wall_time, plugin_content, description, display_name - ): - self._max_step = max_step - self._max_wall_time = max_wall_time - self._plugin_content = plugin_content - self._description = description - self._display_name = display_name + def __eq__(self, other): + if not isinstance(other, TensorTimeSeries): + return False + if self._max_step != other._max_step: + return False + if self._max_wall_time != other._max_wall_time: + return False + if self._plugin_content != other._plugin_content: + return False + if self._description != other._description: + return False + if self._display_name != other._display_name: + return False + return True + + def __hash__(self): + return hash( + ( + self._max_step, + self._max_wall_time, + self._plugin_content, + self._description, + self._display_name, + ) + ) + + def __repr__(self): + return "TensorTimeSeries(%s)" % ", ".join( + ( + "max_step=%r" % (self._max_step,), + "max_wall_time=%r" % (self._max_wall_time,), + "plugin_content=%r" % (self._plugin_content,), + "description=%r" % (self._description,), + "display_name=%r" % (self._display_name,), + ) + ) - @property - def max_step(self): - return self._max_step - @property - def max_wall_time(self): - return self._max_wall_time +class TensorDatum(object): + """A single datum in a tensor time series for a run and tag. + + Attributes: + step: The global step at which this datum occurred; an integer. This + is a unique key among data of this time series. + wall_time: The real-world time at which this datum occurred, as + `float` seconds since epoch. + numpy: The `numpy.ndarray` value with the tensor contents of this + datum. 
+ """ - @property - def plugin_content(self): - return self._plugin_content + __slots__ = ("_step", "_wall_time", "_numpy") + + def __init__(self, step, wall_time, numpy): + self._step = step + self._wall_time = wall_time + self._numpy = numpy + + @property + def step(self): + return self._step + + @property + def wall_time(self): + return self._wall_time + + @property + def numpy(self): + return self._numpy + + def __eq__(self, other): + if not isinstance(other, TensorDatum): + return False + if self._step != other._step: + return False + if self._wall_time != other._wall_time: + return False + if not np.array_equal(self._numpy, other._numpy): + return False + return True + + # Unhashable type: numpy arrays are mutable. + __hash__ = None + + def __repr__(self): + return "TensorDatum(%s)" % ", ".join( + ( + "step=%r" % (self._step,), + "wall_time=%r" % (self._wall_time,), + "numpy=%r" % (self._numpy,), + ) + ) - @property - def description(self): - return self._description - @property - def display_name(self): - return self._display_name +class BlobSequenceTimeSeries(_TimeSeries): + """Metadata about a blob sequence time series for a particular run and tag. + + Attributes: + max_step: The largest step value of any datum in this scalar time series; a + nonnegative integer. + max_wall_time: The largest wall time of any datum in this time series, as + `float` seconds since epoch. + latest_max_index: The largest index in the sequence at max_step (0-based, + inclusive). + plugin_content: A bytestring of arbitrary plugin-specific metadata for this + time series, as provided to `tf.summary.write` in the + `plugin_data.content` field of the `metadata` argument. + description: An optional long-form Markdown description, as a `str` that is + empty if no description was specified. + display_name: An optional long-form Markdown description, as a `str` that is + empty if no description was specified. Deprecated; may be removed soon. 
+ """ + __slots__ = ("_latest_max_index",) + + def __init__( + self, + max_step, + max_wall_time, + latest_max_index, + plugin_content, + description, + display_name, + ): + super(BlobSequenceTimeSeries, self).__init__( + max_step, max_wall_time, plugin_content, description, display_name + ) + self._latest_max_index = latest_max_index + + @property + def latest_max_index(self): + return self._latest_max_index + + def __eq__(self, other): + if not isinstance(other, BlobSequenceTimeSeries): + return False + if self._max_step != other._max_step: + return False + if self._max_wall_time != other._max_wall_time: + return False + if self._latest_max_index != other._latest_max_index: + return False + if self._plugin_content != other._plugin_content: + return False + if self._description != other._description: + return False + if self._display_name != other._display_name: + return False + return True + + def __hash__(self): + return hash( + ( + self._max_step, + self._max_wall_time, + self._latest_max_index, + self._plugin_content, + self._description, + self._display_name, + ) + ) + + def __repr__(self): + return "BlobSequenceTimeSeries(%s)" % ", ".join( + ( + "max_step=%r" % (self._max_step,), + "max_wall_time=%r" % (self._max_wall_time,), + "latest_max_index=%r" % (self._latest_max_index,), + "plugin_content=%r" % (self._plugin_content,), + "description=%r" % (self._description,), + "display_name=%r" % (self._display_name,), + ) + ) -class ScalarTimeSeries(_TimeSeries): - """Metadata about a scalar time series for a particular run and tag. - - Attributes: - max_step: The largest step value of any datum in this scalar time series; a - nonnegative integer. - max_wall_time: The largest wall time of any datum in this time series, as - `float` seconds since epoch. - plugin_content: A bytestring of arbitrary plugin-specific metadata for this - time series, as provided to `tf.summary.write` in the - `plugin_data.content` field of the `metadata` argument. 
- description: An optional long-form Markdown description, as a `str` that is - empty if no description was specified. - display_name: An optional long-form Markdown description, as a `str` that is - empty if no description was specified. Deprecated; may be removed soon. - """ - - def __eq__(self, other): - if not isinstance(other, ScalarTimeSeries): - return False - if self._max_step != other._max_step: - return False - if self._max_wall_time != other._max_wall_time: - return False - if self._plugin_content != other._plugin_content: - return False - if self._description != other._description: - return False - if self._display_name != other._display_name: - return False - return True - - def __hash__(self): - return hash(( - self._max_step, - self._max_wall_time, - self._plugin_content, - self._description, - self._display_name, - )) - - def __repr__(self): - return "ScalarTimeSeries(%s)" % ", ".join(( - "max_step=%r" % (self._max_step,), - "max_wall_time=%r" % (self._max_wall_time,), - "plugin_content=%r" % (self._plugin_content,), - "description=%r" % (self._description,), - "display_name=%r" % (self._display_name,), - )) +class BlobReference(object): + """A reference to a blob. + + Attributes: + blob_key: A string containing a key uniquely identifying a blob, which + may be dereferenced via `provider.read_blob(blob_key)`. + + These keys must be constructed such that they can be included directly in + a URL, with no further encoding. Concretely, this means that they consist + exclusively of "unreserved characters" per RFC 3986, namely + [a-zA-Z0-9._~-]. These keys are case-sensitive; it may be wise for + implementations to normalize case to reduce confusion. The empty string + is not a valid key. + + Blob keys must not contain information that should be kept secret. + Privacy-sensitive applications should use random keys (e.g. UUIDs), or + encrypt keys containing secret fields. 
+ url: (optional) A string containing a URL from which the blob data may be + fetched directly, bypassing the data provider. URLs may be a vector + for data leaks (e.g. via browser history, web proxies, etc.), so these + URLs should not expose secret information. + """ -class ScalarDatum(object): - """A single datum in a scalar time series for a run and tag. - - Attributes: - step: The global step at which this datum occurred; an integer. This - is a unique key among data of this time series. - wall_time: The real-world time at which this datum occurred, as - `float` seconds since epoch. - value: The scalar value for this datum; a `float`. - """ - - __slots__ = ("_step", "_wall_time", "_value") - - def __init__(self, step, wall_time, value): - self._step = step - self._wall_time = wall_time - self._value = value - - @property - def step(self): - return self._step - - @property - def wall_time(self): - return self._wall_time - - @property - def value(self): - return self._value - - def __eq__(self, other): - if not isinstance(other, ScalarDatum): - return False - if self._step != other._step: - return False - if self._wall_time != other._wall_time: - return False - if self._value != other._value: - return False - return True - - def __hash__(self): - return hash((self._step, self._wall_time, self._value)) - - def __repr__(self): - return "ScalarDatum(%s)" % ", ".join(( - "step=%r" % (self._step,), - "wall_time=%r" % (self._wall_time,), - "value=%r" % (self._value,), - )) + __slots__ = ("_url", "_blob_key") + def __init__(self, blob_key, url=None): + self._blob_key = blob_key + self._url = url -class TensorTimeSeries(_TimeSeries): - """Metadata about a tensor time series for a particular run and tag. - - Attributes: - max_step: The largest step value of any datum in this tensor time series; a - nonnegative integer. - max_wall_time: The largest wall time of any datum in this time series, as - `float` seconds since epoch. 
- plugin_content: A bytestring of arbitrary plugin-specific metadata for this - time series, as provided to `tf.summary.write` in the - `plugin_data.content` field of the `metadata` argument. - description: An optional long-form Markdown description, as a `str` that is - empty if no description was specified. - display_name: An optional long-form Markdown description, as a `str` that is - empty if no description was specified. Deprecated; may be removed soon. - """ - - def __eq__(self, other): - if not isinstance(other, TensorTimeSeries): - return False - if self._max_step != other._max_step: - return False - if self._max_wall_time != other._max_wall_time: - return False - if self._plugin_content != other._plugin_content: - return False - if self._description != other._description: - return False - if self._display_name != other._display_name: - return False - return True - - def __hash__(self): - return hash(( - self._max_step, - self._max_wall_time, - self._plugin_content, - self._description, - self._display_name, - )) - - def __repr__(self): - return "TensorTimeSeries(%s)" % ", ".join(( - "max_step=%r" % (self._max_step,), - "max_wall_time=%r" % (self._max_wall_time,), - "plugin_content=%r" % (self._plugin_content,), - "description=%r" % (self._description,), - "display_name=%r" % (self._display_name,), - )) + @property + def blob_key(self): + """Provide a key uniquely identifying a blob. + Callers should consider these keys to be opaque-- i.e., to have + no intrinsic meaning. Some data providers may use random IDs; + but others may encode information into the key, in which case + callers must make no attempt to decode it. + """ + return self._blob_key -class TensorDatum(object): - """A single datum in a tensor time series for a run and tag. - - Attributes: - step: The global step at which this datum occurred; an integer. This - is a unique key among data of this time series. 
- wall_time: The real-world time at which this datum occurred, as - `float` seconds since epoch. - numpy: The `numpy.ndarray` value with the tensor contents of this - datum. - """ - - __slots__ = ("_step", "_wall_time", "_numpy") - - def __init__(self, step, wall_time, numpy): - self._step = step - self._wall_time = wall_time - self._numpy = numpy - - @property - def step(self): - return self._step - - @property - def wall_time(self): - return self._wall_time - - @property - def numpy(self): - return self._numpy - - def __eq__(self, other): - if not isinstance(other, TensorDatum): - return False - if self._step != other._step: - return False - if self._wall_time != other._wall_time: - return False - if not np.array_equal(self._numpy, other._numpy): - return False - return True - - # Unhashable type: numpy arrays are mutable. - __hash__ = None - - def __repr__(self): - return "TensorDatum(%s)" % ", ".join(( - "step=%r" % (self._step,), - "wall_time=%r" % (self._wall_time,), - "numpy=%r" % (self._numpy,), - )) + @property + def url(self): + """Provide the direct-access URL for this blob, if available. + Note that this method is *not* expected to construct a URL to + the data-loading endpoint provided by TensorBoard. If this + method returns None, then the caller should proceed to use + `blob_key()` to build the URL, as needed. + """ + return self._url -class BlobSequenceTimeSeries(_TimeSeries): - """Metadata about a blob sequence time series for a particular run and tag. - - Attributes: - max_step: The largest step value of any datum in this scalar time series; a - nonnegative integer. - max_wall_time: The largest wall time of any datum in this time series, as - `float` seconds since epoch. - latest_max_index: The largest index in the sequence at max_step (0-based, - inclusive). - plugin_content: A bytestring of arbitrary plugin-specific metadata for this - time series, as provided to `tf.summary.write` in the - `plugin_data.content` field of the `metadata` argument. 
- description: An optional long-form Markdown description, as a `str` that is - empty if no description was specified. - display_name: An optional long-form Markdown description, as a `str` that is - empty if no description was specified. Deprecated; may be removed soon. - """ - - __slots__ = ("_latest_max_index",) - - def __init__( - self, - max_step, - max_wall_time, - latest_max_index, - plugin_content, - description, - display_name, - ): - super(BlobSequenceTimeSeries, self).__init__( - max_step, max_wall_time, plugin_content, description, display_name - ) - self._latest_max_index = latest_max_index - - @property - def latest_max_index(self): - return self._latest_max_index - - def __eq__(self, other): - if not isinstance(other, BlobSequenceTimeSeries): - return False - if self._max_step != other._max_step: - return False - if self._max_wall_time != other._max_wall_time: - return False - if self._latest_max_index != other._latest_max_index: - return False - if self._plugin_content != other._plugin_content: - return False - if self._description != other._description: - return False - if self._display_name != other._display_name: - return False - return True - - def __hash__(self): - return hash(( - self._max_step, - self._max_wall_time, - self._latest_max_index, - self._plugin_content, - self._description, - self._display_name, - )) - - def __repr__(self): - return "BlobSequenceTimeSeries(%s)" % ", ".join(( - "max_step=%r" % (self._max_step,), - "max_wall_time=%r" % (self._max_wall_time,), - "latest_max_index=%r" % (self._latest_max_index,), - "plugin_content=%r" % (self._plugin_content,), - "description=%r" % (self._description,), - "display_name=%r" % (self._display_name,), - )) + def __eq__(self, other): + if not isinstance(other, BlobReference): + return False + if self._blob_key != other._blob_key: + return False + if self._url != other._url: + return False + return True + def __hash__(self): + return hash((self._blob_key, self._url)) -class 
BlobReference(object): - """A reference to a blob. - - Attributes: - blob_key: A string containing a key uniquely identifying a blob, which - may be dereferenced via `provider.read_blob(blob_key)`. - - These keys must be constructed such that they can be included directly in - a URL, with no further encoding. Concretely, this means that they consist - exclusively of "unreserved characters" per RFC 3986, namely - [a-zA-Z0-9._~-]. These keys are case-sensitive; it may be wise for - implementations to normalize case to reduce confusion. The empty string - is not a valid key. - - Blob keys must not contain information that should be kept secret. - Privacy-sensitive applications should use random keys (e.g. UUIDs), or - encrypt keys containing secret fields. - url: (optional) A string containing a URL from which the blob data may be - fetched directly, bypassing the data provider. URLs may be a vector - for data leaks (e.g. via browser history, web proxies, etc.), so these - URLs should not expose secret information. - """ - - __slots__ = ("_url", "_blob_key") - - def __init__(self, blob_key, url=None): - self._blob_key = blob_key - self._url = url - - @property - def blob_key(self): - """Provide a key uniquely identifying a blob. - - Callers should consider these keys to be opaque-- i.e., to have no intrinsic - meaning. Some data providers may use random IDs; but others may encode - information into the key, in which case callers must make no attempt to - decode it. - """ - return self._blob_key + def __repr__(self): + return "BlobReference(%s)" % ", ".join( + ("blob_key=%r" % (self._blob_key,), "url=%r" % (self._url,)) + ) - @property - def url(self): - """Provide the direct-access URL for this blob, if available. - Note that this method is *not* expected to construct a URL to the - data-loading endpoint provided by TensorBoard. If this method returns - None, then the caller should proceed to use `blob_key()` to build the URL, - as needed. 
+class BlobSequenceDatum(object): + """A single datum in a blob sequence time series for a run and tag. + + Attributes: + step: The global step at which this datum occurred; an integer. This is a + unique key among data of this time series. + wall_time: The real-world time at which this datum occurred, as `float` + seconds since epoch. + values: A tuple of `BlobReference` objects, providing access to elements of + this sequence. """ - return self._url - - def __eq__(self, other): - if not isinstance(other, BlobReference): - return False - if self._blob_key != other._blob_key: - return False - if self._url != other._url: - return False - return True - - def __hash__(self): - return hash((self._blob_key, self._url)) - - def __repr__(self): - return "BlobReference(%s)" % ", ".join( - ("blob_key=%r" % (self._blob_key,), "url=%r" % (self._url,)) - ) - -class BlobSequenceDatum(object): - """A single datum in a blob sequence time series for a run and tag. - - Attributes: - step: The global step at which this datum occurred; an integer. This is a - unique key among data of this time series. - wall_time: The real-world time at which this datum occurred, as `float` - seconds since epoch. - values: A tuple of `BlobReference` objects, providing access to elements of - this sequence. 
- """ - - __slots__ = ("_step", "_wall_time", "_values") - - def __init__(self, step, wall_time, values): - self._step = step - self._wall_time = wall_time - self._values = values - - @property - def step(self): - return self._step - - @property - def wall_time(self): - return self._wall_time - - @property - def values(self): - return self._values - - def __eq__(self, other): - if not isinstance(other, BlobSequenceDatum): - return False - if self._step != other._step: - return False - if self._wall_time != other._wall_time: - return False - if self._values != other._values: - return False - return True - - def __hash__(self): - return hash((self._step, self._wall_time, self._values)) - - def __repr__(self): - return "BlobSequenceDatum(%s)" % ", ".join(( - "step=%r" % (self._step,), - "wall_time=%r" % (self._wall_time,), - "values=%r" % (self._values,), - )) + __slots__ = ("_step", "_wall_time", "_values") + + def __init__(self, step, wall_time, values): + self._step = step + self._wall_time = wall_time + self._values = values + + @property + def step(self): + return self._step + + @property + def wall_time(self): + return self._wall_time + + @property + def values(self): + return self._values + + def __eq__(self, other): + if not isinstance(other, BlobSequenceDatum): + return False + if self._step != other._step: + return False + if self._wall_time != other._wall_time: + return False + if self._values != other._values: + return False + return True + + def __hash__(self): + return hash((self._step, self._wall_time, self._values)) + + def __repr__(self): + return "BlobSequenceDatum(%s)" % ", ".join( + ( + "step=%r" % (self._step,), + "wall_time=%r" % (self._wall_time,), + "values=%r" % (self._values,), + ) + ) class RunTagFilter(object): - """Filters data by run and tag names.""" - - def __init__(self, runs=None, tags=None): - """Construct a `RunTagFilter`. 
- - A time series passes this filter if both its run *and* its tag are - included in the corresponding whitelists. - - Order and multiplicity are ignored; `runs` and `tags` are treated as - sets. - - Args: - runs: Collection of run names, as strings, or `None` to admit all - runs. - tags: Collection of tag names, as strings, or `None` to admit all - tags. - """ - self._runs = None if runs is None else frozenset(runs) - self._tags = None if tags is None else frozenset(tags) - - @property - def runs(self): - return self._runs - - @property - def tags(self): - return self._tags - - def __repr__(self): - return "RunTagFilter(%s)" % ", ".join(( - "runs=%r" % (self._runs,), - "tags=%r" % (self._tags,), - )) + """Filters data by run and tag names.""" + + def __init__(self, runs=None, tags=None): + """Construct a `RunTagFilter`. + + A time series passes this filter if both its run *and* its tag are + included in the corresponding whitelists. + + Order and multiplicity are ignored; `runs` and `tags` are treated as + sets. + + Args: + runs: Collection of run names, as strings, or `None` to admit all + runs. + tags: Collection of tag names, as strings, or `None` to admit all + tags. 
+ """ + self._runs = None if runs is None else frozenset(runs) + self._tags = None if tags is None else frozenset(tags) + + @property + def runs(self): + return self._runs + + @property + def tags(self): + return self._tags + + def __repr__(self): + return "RunTagFilter(%s)" % ", ".join( + ("runs=%r" % (self._runs,), "tags=%r" % (self._tags,),) + ) diff --git a/tensorboard/data/provider_test.py b/tensorboard/data/provider_test.py index ae2e72f36f..8a8e8fc17c 100644 --- a/tensorboard/data/provider_test.py +++ b/tensorboard/data/provider_test.py @@ -26,270 +26,300 @@ class DataProviderTest(tb_test.TestCase): - def test_abstract(self): - with six.assertRaisesRegex(self, TypeError, "abstract class"): - provider.DataProvider() + def test_abstract(self): + with six.assertRaisesRegex(self, TypeError, "abstract class"): + provider.DataProvider() class RunTest(tb_test.TestCase): - def test_eq(self): - a1 = provider.Run(run_id="a", run_name="aa", start_time=1.25) - a2 = provider.Run(run_id="a", run_name="aa", start_time=1.25) - b = provider.Run(run_id="b", run_name="bb", start_time=-1.75) - self.assertEqual(a1, a2) - self.assertNotEqual(a1, b) - self.assertNotEqual(b, object()) - - def test_repr(self): - x = provider.Run(run_id="alpha", run_name="bravo", start_time=1.25) - repr_ = repr(x) - self.assertIn(repr(x.run_id), repr_) - self.assertIn(repr(x.run_name), repr_) - self.assertIn(repr(x.start_time), repr_) + def test_eq(self): + a1 = provider.Run(run_id="a", run_name="aa", start_time=1.25) + a2 = provider.Run(run_id="a", run_name="aa", start_time=1.25) + b = provider.Run(run_id="b", run_name="bb", start_time=-1.75) + self.assertEqual(a1, a2) + self.assertNotEqual(a1, b) + self.assertNotEqual(b, object()) + + def test_repr(self): + x = provider.Run(run_id="alpha", run_name="bravo", start_time=1.25) + repr_ = repr(x) + self.assertIn(repr(x.run_id), repr_) + self.assertIn(repr(x.run_name), repr_) + self.assertIn(repr(x.start_time), repr_) class 
ScalarTimeSeriesTest(tb_test.TestCase): - def test_repr(self): - x = provider.ScalarTimeSeries( - max_step=77, - max_wall_time=1234.5, - plugin_content=b"AB\xCD\xEF!\x00", - description="test test", - display_name="one two", - ) - repr_ = repr(x) - self.assertIn(repr(x.max_step), repr_) - self.assertIn(repr(x.max_wall_time), repr_) - self.assertIn(repr(x.plugin_content), repr_) - self.assertIn(repr(x.description), repr_) - self.assertIn(repr(x.display_name), repr_) - - def test_eq(self): - x1 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two") - x2 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two") - x3 = provider.ScalarTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum") - self.assertEqual(x1, x2) - self.assertNotEqual(x1, x3) - self.assertNotEqual(x1, object()) - - def test_hash(self): - x1 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two") - x2 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two") - x3 = provider.ScalarTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum") - self.assertEqual(hash(x1), hash(x2)) - # The next check is technically not required by the `__hash__` - # contract, but _should_ pass; failure on this assertion would at - # least warrant some scrutiny. 
- self.assertNotEqual(hash(x1), hash(x3)) + def test_repr(self): + x = provider.ScalarTimeSeries( + max_step=77, + max_wall_time=1234.5, + plugin_content=b"AB\xCD\xEF!\x00", + description="test test", + display_name="one two", + ) + repr_ = repr(x) + self.assertIn(repr(x.max_step), repr_) + self.assertIn(repr(x.max_wall_time), repr_) + self.assertIn(repr(x.plugin_content), repr_) + self.assertIn(repr(x.description), repr_) + self.assertIn(repr(x.display_name), repr_) + + def test_eq(self): + x1 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two") + x2 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two") + x3 = provider.ScalarTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum") + self.assertEqual(x1, x2) + self.assertNotEqual(x1, x3) + self.assertNotEqual(x1, object()) + + def test_hash(self): + x1 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two") + x2 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two") + x3 = provider.ScalarTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum") + self.assertEqual(hash(x1), hash(x2)) + # The next check is technically not required by the `__hash__` + # contract, but _should_ pass; failure on this assertion would at + # least warrant some scrutiny. 
+ self.assertNotEqual(hash(x1), hash(x3)) class ScalarDatumTest(tb_test.TestCase): - def test_repr(self): - x = provider.ScalarDatum(step=123, wall_time=234.5, value=-0.125) - repr_ = repr(x) - self.assertIn(repr(x.step), repr_) - self.assertIn(repr(x.wall_time), repr_) - self.assertIn(repr(x.value), repr_) - - def test_eq(self): - x1 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25) - x2 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25) - x3 = provider.ScalarDatum(step=23, wall_time=3.25, value=-0.5) - self.assertEqual(x1, x2) - self.assertNotEqual(x1, x3) - self.assertNotEqual(x1, object()) - - def test_hash(self): - x1 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25) - x2 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25) - x3 = provider.ScalarDatum(step=23, wall_time=3.25, value=-0.5) - self.assertEqual(hash(x1), hash(x2)) - # The next check is technically not required by the `__hash__` - # contract, but _should_ pass; failure on this assertion would at - # least warrant some scrutiny. 
- self.assertNotEqual(hash(x1), hash(x3)) + def test_repr(self): + x = provider.ScalarDatum(step=123, wall_time=234.5, value=-0.125) + repr_ = repr(x) + self.assertIn(repr(x.step), repr_) + self.assertIn(repr(x.wall_time), repr_) + self.assertIn(repr(x.value), repr_) + + def test_eq(self): + x1 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25) + x2 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25) + x3 = provider.ScalarDatum(step=23, wall_time=3.25, value=-0.5) + self.assertEqual(x1, x2) + self.assertNotEqual(x1, x3) + self.assertNotEqual(x1, object()) + + def test_hash(self): + x1 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25) + x2 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25) + x3 = provider.ScalarDatum(step=23, wall_time=3.25, value=-0.5) + self.assertEqual(hash(x1), hash(x2)) + # The next check is technically not required by the `__hash__` + # contract, but _should_ pass; failure on this assertion would at + # least warrant some scrutiny. 
+ self.assertNotEqual(hash(x1), hash(x3)) class TensorTimeSeriesTest(tb_test.TestCase): - def test_repr(self): - x = provider.TensorTimeSeries( - max_step=77, - max_wall_time=1234.5, - plugin_content=b"AB\xCD\xEF!\x00", - description="test test", - display_name="one two", - ) - repr_ = repr(x) - self.assertIn(repr(x.max_step), repr_) - self.assertIn(repr(x.max_wall_time), repr_) - self.assertIn(repr(x.plugin_content), repr_) - self.assertIn(repr(x.description), repr_) - self.assertIn(repr(x.display_name), repr_) - - def test_eq(self): - x1 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two") - x2 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two") - x3 = provider.TensorTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum") - self.assertEqual(x1, x2) - self.assertNotEqual(x1, x3) - self.assertNotEqual(x1, object()) - - def test_hash(self): - x1 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two") - x2 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two") - x3 = provider.TensorTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum") - self.assertEqual(hash(x1), hash(x2)) - # The next check is technically not required by the `__hash__` - # contract, but _should_ pass; failure on this assertion would at - # least warrant some scrutiny. 
- self.assertNotEqual(hash(x1), hash(x3)) + def test_repr(self): + x = provider.TensorTimeSeries( + max_step=77, + max_wall_time=1234.5, + plugin_content=b"AB\xCD\xEF!\x00", + description="test test", + display_name="one two", + ) + repr_ = repr(x) + self.assertIn(repr(x.max_step), repr_) + self.assertIn(repr(x.max_wall_time), repr_) + self.assertIn(repr(x.plugin_content), repr_) + self.assertIn(repr(x.description), repr_) + self.assertIn(repr(x.display_name), repr_) + + def test_eq(self): + x1 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two") + x2 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two") + x3 = provider.TensorTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum") + self.assertEqual(x1, x2) + self.assertNotEqual(x1, x3) + self.assertNotEqual(x1, object()) + + def test_hash(self): + x1 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two") + x2 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two") + x3 = provider.TensorTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum") + self.assertEqual(hash(x1), hash(x2)) + # The next check is technically not required by the `__hash__` + # contract, but _should_ pass; failure on this assertion would at + # least warrant some scrutiny. 
+ self.assertNotEqual(hash(x1), hash(x3)) class TensorDatumTest(tb_test.TestCase): - def test_repr(self): - x = provider.TensorDatum(step=123, wall_time=234.5, numpy=np.array(-0.25)) - repr_ = repr(x) - self.assertIn(repr(x.step), repr_) - self.assertIn(repr(x.wall_time), repr_) - self.assertIn(repr(x.numpy), repr_) - - def test_eq(self): - nd = np.array - x1 = provider.TensorDatum(step=12, wall_time=0.25, numpy=nd([1.0, 2.0])) - x2 = provider.TensorDatum(step=12, wall_time=0.25, numpy=nd([1.0, 2.0])) - x3 = provider.TensorDatum(step=23, wall_time=3.25, numpy=nd([-0.5, -2.5])) - self.assertEqual(x1, x2) - self.assertNotEqual(x1, x3) - self.assertNotEqual(x1, object()) - - def test_eq_with_rank0_tensor(self): - x1 = provider.TensorDatum(step=12, wall_time=0.25, numpy=np.array([1.25])) - x2 = provider.TensorDatum(step=12, wall_time=0.25, numpy=np.array([1.25])) - x3 = provider.TensorDatum(step=23, wall_time=3.25, numpy=np.array([1.25])) - self.assertEqual(x1, x2) - self.assertNotEqual(x1, x3) - self.assertNotEqual(x1, object()) - - def test_hash(self): - x = provider.TensorDatum(step=12, wall_time=0.25, numpy=np.array([1.25])) - with six.assertRaisesRegex(self, TypeError, "unhashable type"): - hash(x) + def test_repr(self): + x = provider.TensorDatum( + step=123, wall_time=234.5, numpy=np.array(-0.25) + ) + repr_ = repr(x) + self.assertIn(repr(x.step), repr_) + self.assertIn(repr(x.wall_time), repr_) + self.assertIn(repr(x.numpy), repr_) + + def test_eq(self): + nd = np.array + x1 = provider.TensorDatum(step=12, wall_time=0.25, numpy=nd([1.0, 2.0])) + x2 = provider.TensorDatum(step=12, wall_time=0.25, numpy=nd([1.0, 2.0])) + x3 = provider.TensorDatum( + step=23, wall_time=3.25, numpy=nd([-0.5, -2.5]) + ) + self.assertEqual(x1, x2) + self.assertNotEqual(x1, x3) + self.assertNotEqual(x1, object()) + + def test_eq_with_rank0_tensor(self): + x1 = provider.TensorDatum( + step=12, wall_time=0.25, numpy=np.array([1.25]) + ) + x2 = provider.TensorDatum( + step=12, 
wall_time=0.25, numpy=np.array([1.25]) + ) + x3 = provider.TensorDatum( + step=23, wall_time=3.25, numpy=np.array([1.25]) + ) + self.assertEqual(x1, x2) + self.assertNotEqual(x1, x3) + self.assertNotEqual(x1, object()) + + def test_hash(self): + x = provider.TensorDatum( + step=12, wall_time=0.25, numpy=np.array([1.25]) + ) + with six.assertRaisesRegex(self, TypeError, "unhashable type"): + hash(x) class BlobSequenceTimeSeriesTest(tb_test.TestCase): - - def test_repr(self): - x = provider.BlobSequenceTimeSeries( - max_step=77, - max_wall_time=1234.5, - latest_max_index=6, - plugin_content=b"AB\xCD\xEF!\x00", - description="test test", - display_name="one two", - ) - repr_ = repr(x) - self.assertIn(repr(x.max_step), repr_) - self.assertIn(repr(x.max_wall_time), repr_) - self.assertIn(repr(x.latest_max_index), repr_) - self.assertIn(repr(x.plugin_content), repr_) - self.assertIn(repr(x.description), repr_) - self.assertIn(repr(x.display_name), repr_) - - def test_eq(self): - x1 = provider.BlobSequenceTimeSeries(77, 1234.5, 6, b"\x12", "one", "two") - x2 = provider.BlobSequenceTimeSeries(77, 1234.5, 6, b"\x12", "one", "two") - x3 = provider.BlobSequenceTimeSeries(66, 4321.0, 7, b"\x7F", "hmm", "hum") - self.assertEqual(x1, x2) - self.assertNotEqual(x1, x3) - self.assertNotEqual(x1, object()) - - def test_hash(self): - x1 = provider.BlobSequenceTimeSeries(77, 1234.5, 6, b"\x12", "one", "two") - x2 = provider.BlobSequenceTimeSeries(77, 1234.5, 6, b"\x12", "one", "two") - x3 = provider.BlobSequenceTimeSeries(66, 4321.0, 7, b"\x7F", "hmm", "hum") - self.assertEqual(hash(x1), hash(x2)) - # The next check is technically not required by the `__hash__` - # contract, but _should_ pass; failure on this assertion would at - # least warrant some scrutiny. 
- self.assertNotEqual(hash(x1), hash(x3)) + def test_repr(self): + x = provider.BlobSequenceTimeSeries( + max_step=77, + max_wall_time=1234.5, + latest_max_index=6, + plugin_content=b"AB\xCD\xEF!\x00", + description="test test", + display_name="one two", + ) + repr_ = repr(x) + self.assertIn(repr(x.max_step), repr_) + self.assertIn(repr(x.max_wall_time), repr_) + self.assertIn(repr(x.latest_max_index), repr_) + self.assertIn(repr(x.plugin_content), repr_) + self.assertIn(repr(x.description), repr_) + self.assertIn(repr(x.display_name), repr_) + + def test_eq(self): + x1 = provider.BlobSequenceTimeSeries( + 77, 1234.5, 6, b"\x12", "one", "two" + ) + x2 = provider.BlobSequenceTimeSeries( + 77, 1234.5, 6, b"\x12", "one", "two" + ) + x3 = provider.BlobSequenceTimeSeries( + 66, 4321.0, 7, b"\x7F", "hmm", "hum" + ) + self.assertEqual(x1, x2) + self.assertNotEqual(x1, x3) + self.assertNotEqual(x1, object()) + + def test_hash(self): + x1 = provider.BlobSequenceTimeSeries( + 77, 1234.5, 6, b"\x12", "one", "two" + ) + x2 = provider.BlobSequenceTimeSeries( + 77, 1234.5, 6, b"\x12", "one", "two" + ) + x3 = provider.BlobSequenceTimeSeries( + 66, 4321.0, 7, b"\x7F", "hmm", "hum" + ) + self.assertEqual(hash(x1), hash(x2)) + # The next check is technically not required by the `__hash__` + # contract, but _should_ pass; failure on this assertion would at + # least warrant some scrutiny. 
+ self.assertNotEqual(hash(x1), hash(x3)) class BlobReferenceTest(tb_test.TestCase): - - def test_repr(self): - x = provider.BlobReference(url="foo", blob_key="baz") - repr_ = repr(x) - self.assertIn(repr(x.url), repr_) - self.assertIn(repr(x.blob_key), repr_) - - def test_eq(self): - x1 = provider.BlobReference(url="foo", blob_key="baz") - x2 = provider.BlobReference(url="foo", blob_key="baz") - x3 = provider.BlobReference(url="foo", blob_key="qux") - self.assertEqual(x1, x2) - self.assertNotEqual(x1, x3) - self.assertNotEqual(x1, object()) - - def test_hash(self): - x1 = provider.BlobReference(url="foo", blob_key="baz") - x2 = provider.BlobReference(url="foo", blob_key="baz") - x3 = provider.BlobReference(url="foo", blob_key="qux") - self.assertEqual(hash(x1), hash(x2)) - # The next check is technically not required by the `__hash__` - # contract, but _should_ pass; failure on this assertion would at - # least warrant some scrutiny. - self.assertNotEqual(hash(x1), hash(x3)) + def test_repr(self): + x = provider.BlobReference(url="foo", blob_key="baz") + repr_ = repr(x) + self.assertIn(repr(x.url), repr_) + self.assertIn(repr(x.blob_key), repr_) + + def test_eq(self): + x1 = provider.BlobReference(url="foo", blob_key="baz") + x2 = provider.BlobReference(url="foo", blob_key="baz") + x3 = provider.BlobReference(url="foo", blob_key="qux") + self.assertEqual(x1, x2) + self.assertNotEqual(x1, x3) + self.assertNotEqual(x1, object()) + + def test_hash(self): + x1 = provider.BlobReference(url="foo", blob_key="baz") + x2 = provider.BlobReference(url="foo", blob_key="baz") + x3 = provider.BlobReference(url="foo", blob_key="qux") + self.assertEqual(hash(x1), hash(x2)) + # The next check is technically not required by the `__hash__` + # contract, but _should_ pass; failure on this assertion would at + # least warrant some scrutiny. 
+ self.assertNotEqual(hash(x1), hash(x3)) class BlobSequenceDatumTest(tb_test.TestCase): - - def test_repr(self): - x = provider.BlobSequenceDatum( - step=123, wall_time=234.5, values=("foo", "bar", "baz")) - repr_ = repr(x) - self.assertIn(repr(x.step), repr_) - self.assertIn(repr(x.wall_time), repr_) - self.assertIn(repr(x.values), repr_) - - def test_eq(self): - x1 = provider.BlobSequenceDatum( - step=12, wall_time=0.25, values=("foo", "bar", "baz")) - x2 = provider.BlobSequenceDatum( - step=12, wall_time=0.25, values=("foo", "bar", "baz")) - x3 = provider.BlobSequenceDatum(step=23, wall_time=3.25, values=("qux",)) - self.assertEqual(x1, x2) - self.assertNotEqual(x1, x3) - self.assertNotEqual(x1, object()) - - def test_hash(self): - x1 = provider.BlobSequenceDatum( - step=12, wall_time=0.25, values=("foo", "bar", "baz")) - x2 = provider.BlobSequenceDatum( - step=12, wall_time=0.25, values=("foo", "bar", "baz")) - x3 = provider.BlobSequenceDatum(step=23, wall_time=3.25, values=("qux",)) - self.assertEqual(hash(x1), hash(x2)) - # The next check is technically not required by the `__hash__` - # contract, but _should_ pass; failure on this assertion would at - # least warrant some scrutiny. 
- self.assertNotEqual(hash(x1), hash(x3)) + def test_repr(self): + x = provider.BlobSequenceDatum( + step=123, wall_time=234.5, values=("foo", "bar", "baz") + ) + repr_ = repr(x) + self.assertIn(repr(x.step), repr_) + self.assertIn(repr(x.wall_time), repr_) + self.assertIn(repr(x.values), repr_) + + def test_eq(self): + x1 = provider.BlobSequenceDatum( + step=12, wall_time=0.25, values=("foo", "bar", "baz") + ) + x2 = provider.BlobSequenceDatum( + step=12, wall_time=0.25, values=("foo", "bar", "baz") + ) + x3 = provider.BlobSequenceDatum( + step=23, wall_time=3.25, values=("qux",) + ) + self.assertEqual(x1, x2) + self.assertNotEqual(x1, x3) + self.assertNotEqual(x1, object()) + + def test_hash(self): + x1 = provider.BlobSequenceDatum( + step=12, wall_time=0.25, values=("foo", "bar", "baz") + ) + x2 = provider.BlobSequenceDatum( + step=12, wall_time=0.25, values=("foo", "bar", "baz") + ) + x3 = provider.BlobSequenceDatum( + step=23, wall_time=3.25, values=("qux",) + ) + self.assertEqual(hash(x1), hash(x2)) + # The next check is technically not required by the `__hash__` + # contract, but _should_ pass; failure on this assertion would at + # least warrant some scrutiny. 
+ self.assertNotEqual(hash(x1), hash(x3)) class RunTagFilterTest(tb_test.TestCase): - def test_defensive_copy(self): - runs = ["r1"] - tags = ["t1"] - f = provider.RunTagFilter(runs, tags) - runs.append("r2") - tags.pop() - self.assertEqual(frozenset(f.runs), frozenset(["r1"])) - self.assertEqual(frozenset(f.tags), frozenset(["t1"])) - - def test_repr(self): - x = provider.RunTagFilter(runs=["one", "two"], tags=["three", "four"]) - repr_ = repr(x) - self.assertIn(repr(x.runs), repr_) - self.assertIn(repr(x.tags), repr_) + def test_defensive_copy(self): + runs = ["r1"] + tags = ["t1"] + f = provider.RunTagFilter(runs, tags) + runs.append("r2") + tags.pop() + self.assertEqual(frozenset(f.runs), frozenset(["r1"])) + self.assertEqual(frozenset(f.tags), frozenset(["t1"])) + + def test_repr(self): + x = provider.RunTagFilter(runs=["one", "two"], tags=["three", "four"]) + repr_ = repr(x) + self.assertIn(repr(x.runs), repr_) + self.assertIn(repr(x.tags), repr_) if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/data_compat.py b/tensorboard/data_compat.py index fe50829d58..4232b62f7d 100644 --- a/tensorboard/data_compat.py +++ b/tensorboard/data_compat.py @@ -30,81 +30,89 @@ def migrate_value(value): - """Convert `value` to a new-style value, if necessary and possible. - - An "old-style" value is a value that uses any `value` field other than - the `tensor` field. A "new-style" value is a value that uses the - `tensor` field. TensorBoard continues to support old-style values on - disk; this method converts them to new-style values so that further - code need only deal with one data format. - - Arguments: - value: A `Summary.Value` object. This argument is not modified. - - Returns: - If the `value` is an old-style value for which there is a new-style - equivalent, the result is the new-style value. Otherwise---if the - value is already new-style or does not yet have a new-style - equivalent---the value will be returned unchanged. 
- - :type value: Summary.Value - :rtype: Summary.Value - """ - handler = { - 'histo': _migrate_histogram_value, - 'image': _migrate_image_value, - 'audio': _migrate_audio_value, - 'simple_value': _migrate_scalar_value, - }.get(value.WhichOneof('value')) - return handler(value) if handler else value + """Convert `value` to a new-style value, if necessary and possible. + + An "old-style" value is a value that uses any `value` field other than + the `tensor` field. A "new-style" value is a value that uses the + `tensor` field. TensorBoard continues to support old-style values on + disk; this method converts them to new-style values so that further + code need only deal with one data format. + + Arguments: + value: A `Summary.Value` object. This argument is not modified. + + Returns: + If the `value` is an old-style value for which there is a new-style + equivalent, the result is the new-style value. Otherwise---if the + value is already new-style or does not yet have a new-style + equivalent---the value will be returned unchanged. 
+ + :type value: Summary.Value + :rtype: Summary.Value + """ + handler = { + "histo": _migrate_histogram_value, + "image": _migrate_image_value, + "audio": _migrate_audio_value, + "simple_value": _migrate_scalar_value, + }.get(value.WhichOneof("value")) + return handler(value) if handler else value def make_summary(tag, metadata, data): tensor_proto = tensor_util.make_tensor_proto(data) - return summary_pb2.Summary.Value(tag=tag, - metadata=metadata, - tensor=tensor_proto) + return summary_pb2.Summary.Value( + tag=tag, metadata=metadata, tensor=tensor_proto + ) def _migrate_histogram_value(value): - histogram_value = value.histo - bucket_lefts = [histogram_value.min] + histogram_value.bucket_limit[:-1] - bucket_rights = histogram_value.bucket_limit[:-1] + [histogram_value.max] - bucket_counts = histogram_value.bucket - buckets = np.array([bucket_lefts, bucket_rights, bucket_counts], dtype=np.float32).transpose() + histogram_value = value.histo + bucket_lefts = [histogram_value.min] + histogram_value.bucket_limit[:-1] + bucket_rights = histogram_value.bucket_limit[:-1] + [histogram_value.max] + bucket_counts = histogram_value.bucket + buckets = np.array( + [bucket_lefts, bucket_rights, bucket_counts], dtype=np.float32 + ).transpose() - summary_metadata = histogram_metadata.create_summary_metadata( - display_name=value.metadata.display_name or value.tag, - description=value.metadata.summary_description) + summary_metadata = histogram_metadata.create_summary_metadata( + display_name=value.metadata.display_name or value.tag, + description=value.metadata.summary_description, + ) - return make_summary(value.tag, summary_metadata, buckets) + return make_summary(value.tag, summary_metadata, buckets) def _migrate_image_value(value): - image_value = value.image - data = [tf.compat.as_bytes(str(image_value.width)), - tf.compat.as_bytes(str(image_value.height)), - tf.compat.as_bytes(image_value.encoded_image_string)] + image_value = value.image + data = [ + 
tf.compat.as_bytes(str(image_value.width)), + tf.compat.as_bytes(str(image_value.height)), + tf.compat.as_bytes(image_value.encoded_image_string), + ] - summary_metadata = image_metadata.create_summary_metadata( - display_name=value.metadata.display_name or value.tag, - description=value.metadata.summary_description) - return make_summary(value.tag, summary_metadata, data) + summary_metadata = image_metadata.create_summary_metadata( + display_name=value.metadata.display_name or value.tag, + description=value.metadata.summary_description, + ) + return make_summary(value.tag, summary_metadata, data) def _migrate_audio_value(value): - audio_value = value.audio - data = [[audio_value.encoded_audio_string, b'']] # empty label - summary_metadata = audio_metadata.create_summary_metadata( - display_name=value.metadata.display_name or value.tag, - description=value.metadata.summary_description, - encoding=audio_metadata.Encoding.Value('WAV')) - return make_summary(value.tag, summary_metadata, data) + audio_value = value.audio + data = [[audio_value.encoded_audio_string, b""]] # empty label + summary_metadata = audio_metadata.create_summary_metadata( + display_name=value.metadata.display_name or value.tag, + description=value.metadata.summary_description, + encoding=audio_metadata.Encoding.Value("WAV"), + ) + return make_summary(value.tag, summary_metadata, data) def _migrate_scalar_value(value): - scalar_value = value.simple_value - summary_metadata = scalar_metadata.create_summary_metadata( - display_name=value.metadata.display_name or value.tag, - description=value.metadata.summary_description) - return make_summary(value.tag, summary_metadata, scalar_value) + scalar_value = value.simple_value + summary_metadata = scalar_metadata.create_summary_metadata( + display_name=value.metadata.display_name or value.tag, + description=value.metadata.summary_description, + ) + return make_summary(value.tag, summary_metadata, scalar_value) diff --git a/tensorboard/data_compat_test.py 
b/tensorboard/data_compat_test.py index c5e395fe84..73ca9e86a4 100644 --- a/tensorboard/data_compat_test.py +++ b/tensorboard/data_compat_test.py @@ -32,179 +32,204 @@ from tensorboard.util import tensor_util - class MigrateValueTest(tf.test.TestCase): - """Tests for `migrate_value`. - - These tests should ensure that all first-party new-style values are - passed through unchanged, that all supported old-style values are - converted to new-style values, and that other old-style values are - passed through unchanged. - """ - - def _value_from_op(self, op): - with tf.compat.v1.Session() as sess: - summary_pbtxt = sess.run(op) - summary = summary_pb2.Summary() - summary.ParseFromString(summary_pbtxt) - # There may be multiple values (e.g., for an image summary that emits - # multiple images in one batch). That's fine; we'll choose any - # representative value, assuming that they're homogeneous. - assert summary.value - return summary.value[0] - - def _assert_noop(self, value): - original_pbtxt = value.SerializeToString() - result = data_compat.migrate_value(value) - self.assertEqual(value, result) - self.assertEqual(original_pbtxt, value.SerializeToString()) - - def test_scalar(self): - with tf.compat.v1.Graph().as_default(): - old_op = tf.compat.v1.summary.scalar('important_constants', tf.constant(0x5f3759df)) - old_value = self._value_from_op(old_op) - assert old_value.HasField('simple_value'), old_value - new_value = data_compat.migrate_value(old_value) - - self.assertEqual('important_constants', new_value.tag) - expected_metadata = scalar_metadata.create_summary_metadata( - display_name='important_constants', - description='') - self.assertEqual(expected_metadata, new_value.metadata) - self.assertTrue(new_value.HasField('tensor')) - data = tensor_util.make_ndarray(new_value.tensor) - self.assertEqual((), data.shape) - low_precision_value = np.array(0x5f3759df).astype('float32').item() - self.assertEqual(low_precision_value, data.item()) - - def test_audio(self): - 
with tf.compat.v1.Graph().as_default(): - audio = tf.reshape(tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2)) - old_op = tf.compat.v1.summary.audio('k488', audio, 44100) - old_value = self._value_from_op(old_op) - assert old_value.HasField('audio'), old_value - new_value = data_compat.migrate_value(old_value) - - self.assertEqual('k488/audio/0', new_value.tag) - expected_metadata = audio_metadata.create_summary_metadata( - display_name='k488/audio/0', - description='', - encoding=audio_metadata.Encoding.Value('WAV')) - self.assertEqual(expected_metadata, new_value.metadata) - self.assertTrue(new_value.HasField('tensor')) - data = tensor_util.make_ndarray(new_value.tensor) - self.assertEqual((1, 2), data.shape) - self.assertEqual(tf.compat.as_bytes(old_value.audio.encoded_audio_string), - data[0][0]) - self.assertEqual(b'', data[0][1]) # empty label - - def test_text(self): - with tf.compat.v1.Graph().as_default(): - op = tf.compat.v1.summary.text('lorem_ipsum', tf.constant('dolor sit amet')) - value = self._value_from_op(op) - assert value.HasField('tensor'), value - self._assert_noop(value) - - def test_fully_populated_tensor(self): - with tf.compat.v1.Graph().as_default(): - metadata = summary_pb2.SummaryMetadata( - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name='font_of_wisdom', - content=b'adobe_garamond')) - op = tf.compat.v1.summary.tensor_summary( - name='tensorpocalypse', - tensor=tf.constant([[0.0, 2.0], [float('inf'), float('nan')]]), - display_name='TENSORPOCALYPSE', - summary_description='look on my works ye mighty and despair', - summary_metadata=metadata) - value = self._value_from_op(op) - assert value.HasField('tensor'), value - self._assert_noop(value) - - def test_image(self): - with tf.compat.v1.Graph().as_default(): - old_op = tf.compat.v1.summary.image('mona_lisa', - tf.image.convert_image_dtype( - tf.random.normal(shape=[1, 400, 200, 3]), - tf.uint8, - saturate=True)) - old_value = self._value_from_op(old_op) - assert 
old_value.HasField('image'), old_value - new_value = data_compat.migrate_value(old_value) - - self.assertEqual('mona_lisa/image/0', new_value.tag) - expected_metadata = image_metadata.create_summary_metadata( - display_name='mona_lisa/image/0', description='') - self.assertEqual(expected_metadata, new_value.metadata) - self.assertTrue(new_value.HasField('tensor')) - (width, height, data) = tensor_util.make_ndarray(new_value.tensor) - self.assertEqual(b'200', width) - self.assertEqual(b'400', height) - self.assertEqual( - tf.compat.as_bytes(old_value.image.encoded_image_string), data) - - def test_histogram(self): - with tf.compat.v1.Graph().as_default(): - old_op = tf.compat.v1.summary.histogram('important_data', - tf.random.normal(shape=[23, 45])) - old_value = self._value_from_op(old_op) - assert old_value.HasField('histo'), old_value - new_value = data_compat.migrate_value(old_value) - - self.assertEqual('important_data', new_value.tag) - expected_metadata = histogram_metadata.create_summary_metadata( - display_name='important_data', description='') - self.assertEqual(expected_metadata, new_value.metadata) - self.assertTrue(new_value.HasField('tensor')) - buckets = tensor_util.make_ndarray(new_value.tensor) - self.assertEqual(old_value.histo.min, buckets[0][0]) - self.assertEqual(old_value.histo.max, buckets[-1][1]) - self.assertEqual(23 * 45, buckets[:, 2].astype(int).sum()) - - def test_new_style_histogram(self): - with tf.compat.v1.Graph().as_default(): - op = histogram_summary.op('important_data', - tf.random.normal(shape=[10, 10]), - bucket_count=100, - display_name='Important data', - description='secrets of the universe') - value = self._value_from_op(op) - assert value.HasField('tensor'), value - self._assert_noop(value) - - def test_new_style_image(self): - with tf.compat.v1.Graph().as_default(): - op = image_summary.op( - 'mona_lisa', - tf.image.convert_image_dtype( - tf.random.normal(shape=[1, 400, 200, 3]), tf.uint8, saturate=True), - 
display_name='The Mona Lisa', - description='A renowned portrait by da Vinci.') - value = self._value_from_op(op) - assert value.HasField('tensor'), value - self._assert_noop(value) - - def test_new_style_audio(self): - with tf.compat.v1.Graph().as_default(): - audio = tf.reshape(tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2)) - op = audio_summary.op('k488', - tf.cast(audio, tf.float32), - sample_rate=44100, - display_name='Piano Concerto No.23', - description='In **A major**.') - value = self._value_from_op(op) - assert value.HasField('tensor'), value - self._assert_noop(value) - - def test_new_style_scalar(self): - with tf.compat.v1.Graph().as_default(): - op = scalar_summary.op('important_constants', tf.constant(0x5f3759df), - display_name='Important constants', - description='evil floating point bit magic') - value = self._value_from_op(op) - assert value.HasField('tensor'), value - self._assert_noop(value) - - -if __name__ == '__main__': - tf.test.main() + """Tests for `migrate_value`. + + These tests should ensure that all first-party new-style values are + passed through unchanged, that all supported old-style values are + converted to new-style values, and that other old-style values are + passed through unchanged. + """ + + def _value_from_op(self, op): + with tf.compat.v1.Session() as sess: + summary_pbtxt = sess.run(op) + summary = summary_pb2.Summary() + summary.ParseFromString(summary_pbtxt) + # There may be multiple values (e.g., for an image summary that emits + # multiple images in one batch). That's fine; we'll choose any + # representative value, assuming that they're homogeneous. 
+ assert summary.value + return summary.value[0] + + def _assert_noop(self, value): + original_pbtxt = value.SerializeToString() + result = data_compat.migrate_value(value) + self.assertEqual(value, result) + self.assertEqual(original_pbtxt, value.SerializeToString()) + + def test_scalar(self): + with tf.compat.v1.Graph().as_default(): + old_op = tf.compat.v1.summary.scalar( + "important_constants", tf.constant(0x5F3759DF) + ) + old_value = self._value_from_op(old_op) + assert old_value.HasField("simple_value"), old_value + new_value = data_compat.migrate_value(old_value) + + self.assertEqual("important_constants", new_value.tag) + expected_metadata = scalar_metadata.create_summary_metadata( + display_name="important_constants", description="" + ) + self.assertEqual(expected_metadata, new_value.metadata) + self.assertTrue(new_value.HasField("tensor")) + data = tensor_util.make_ndarray(new_value.tensor) + self.assertEqual((), data.shape) + low_precision_value = np.array(0x5F3759DF).astype("float32").item() + self.assertEqual(low_precision_value, data.item()) + + def test_audio(self): + with tf.compat.v1.Graph().as_default(): + audio = tf.reshape(tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2)) + old_op = tf.compat.v1.summary.audio("k488", audio, 44100) + old_value = self._value_from_op(old_op) + assert old_value.HasField("audio"), old_value + new_value = data_compat.migrate_value(old_value) + + self.assertEqual("k488/audio/0", new_value.tag) + expected_metadata = audio_metadata.create_summary_metadata( + display_name="k488/audio/0", + description="", + encoding=audio_metadata.Encoding.Value("WAV"), + ) + self.assertEqual(expected_metadata, new_value.metadata) + self.assertTrue(new_value.HasField("tensor")) + data = tensor_util.make_ndarray(new_value.tensor) + self.assertEqual((1, 2), data.shape) + self.assertEqual( + tf.compat.as_bytes(old_value.audio.encoded_audio_string), data[0][0] + ) + self.assertEqual(b"", data[0][1]) # empty label + + def test_text(self): + 
with tf.compat.v1.Graph().as_default(): + op = tf.compat.v1.summary.text( + "lorem_ipsum", tf.constant("dolor sit amet") + ) + value = self._value_from_op(op) + assert value.HasField("tensor"), value + self._assert_noop(value) + + def test_fully_populated_tensor(self): + with tf.compat.v1.Graph().as_default(): + metadata = summary_pb2.SummaryMetadata( + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name="font_of_wisdom", content=b"adobe_garamond" + ) + ) + op = tf.compat.v1.summary.tensor_summary( + name="tensorpocalypse", + tensor=tf.constant([[0.0, 2.0], [float("inf"), float("nan")]]), + display_name="TENSORPOCALYPSE", + summary_description="look on my works ye mighty and despair", + summary_metadata=metadata, + ) + value = self._value_from_op(op) + assert value.HasField("tensor"), value + self._assert_noop(value) + + def test_image(self): + with tf.compat.v1.Graph().as_default(): + old_op = tf.compat.v1.summary.image( + "mona_lisa", + tf.image.convert_image_dtype( + tf.random.normal(shape=[1, 400, 200, 3]), + tf.uint8, + saturate=True, + ), + ) + old_value = self._value_from_op(old_op) + assert old_value.HasField("image"), old_value + new_value = data_compat.migrate_value(old_value) + + self.assertEqual("mona_lisa/image/0", new_value.tag) + expected_metadata = image_metadata.create_summary_metadata( + display_name="mona_lisa/image/0", description="" + ) + self.assertEqual(expected_metadata, new_value.metadata) + self.assertTrue(new_value.HasField("tensor")) + (width, height, data) = tensor_util.make_ndarray(new_value.tensor) + self.assertEqual(b"200", width) + self.assertEqual(b"400", height) + self.assertEqual( + tf.compat.as_bytes(old_value.image.encoded_image_string), data + ) + + def test_histogram(self): + with tf.compat.v1.Graph().as_default(): + old_op = tf.compat.v1.summary.histogram( + "important_data", tf.random.normal(shape=[23, 45]) + ) + old_value = self._value_from_op(old_op) + assert old_value.HasField("histo"), old_value + 
new_value = data_compat.migrate_value(old_value) + + self.assertEqual("important_data", new_value.tag) + expected_metadata = histogram_metadata.create_summary_metadata( + display_name="important_data", description="" + ) + self.assertEqual(expected_metadata, new_value.metadata) + self.assertTrue(new_value.HasField("tensor")) + buckets = tensor_util.make_ndarray(new_value.tensor) + self.assertEqual(old_value.histo.min, buckets[0][0]) + self.assertEqual(old_value.histo.max, buckets[-1][1]) + self.assertEqual(23 * 45, buckets[:, 2].astype(int).sum()) + + def test_new_style_histogram(self): + with tf.compat.v1.Graph().as_default(): + op = histogram_summary.op( + "important_data", + tf.random.normal(shape=[10, 10]), + bucket_count=100, + display_name="Important data", + description="secrets of the universe", + ) + value = self._value_from_op(op) + assert value.HasField("tensor"), value + self._assert_noop(value) + + def test_new_style_image(self): + with tf.compat.v1.Graph().as_default(): + op = image_summary.op( + "mona_lisa", + tf.image.convert_image_dtype( + tf.random.normal(shape=[1, 400, 200, 3]), + tf.uint8, + saturate=True, + ), + display_name="The Mona Lisa", + description="A renowned portrait by da Vinci.", + ) + value = self._value_from_op(op) + assert value.HasField("tensor"), value + self._assert_noop(value) + + def test_new_style_audio(self): + with tf.compat.v1.Graph().as_default(): + audio = tf.reshape(tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2)) + op = audio_summary.op( + "k488", + tf.cast(audio, tf.float32), + sample_rate=44100, + display_name="Piano Concerto No.23", + description="In **A major**.", + ) + value = self._value_from_op(op) + assert value.HasField("tensor"), value + self._assert_noop(value) + + def test_new_style_scalar(self): + with tf.compat.v1.Graph().as_default(): + op = scalar_summary.op( + "important_constants", + tf.constant(0x5F3759DF), + display_name="Important constants", + description="evil floating point bit magic", + ) + 
value = self._value_from_op(op) + assert value.HasField("tensor"), value + self._assert_noop(value) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/default.py b/tensorboard/default.py index 9ef04d2adc..963a8fef14 100644 --- a/tensorboard/default.py +++ b/tensorboard/default.py @@ -47,7 +47,7 @@ from tensorboard.plugins.hparams import hparams_plugin from tensorboard.plugins.image import images_plugin from tensorboard.plugins.interactive_inference import ( - interactive_inference_plugin_loader + interactive_inference_plugin_loader, ) from tensorboard.plugins.pr_curve import pr_curves_plugin from tensorboard.plugins.profile import profile_plugin_loader @@ -80,39 +80,42 @@ mesh_plugin.MeshPlugin, ] + def get_plugins(): - """Returns a list specifying TensorBoard's default first-party plugins. + """Returns a list specifying TensorBoard's default first-party plugins. - Plugins are specified in this list either via a TBLoader instance to load the - plugin, or the TBPlugin class itself which will be loaded using a BasicLoader. + Plugins are specified in this list either via a TBLoader instance to load the + plugin, or the TBPlugin class itself which will be loaded using a BasicLoader. - This list can be passed to the `tensorboard.program.TensorBoard` API. + This list can be passed to the `tensorboard.program.TensorBoard` API. - Returns: - The list of default plugins. + Returns: + The list of default plugins. - :rtype: list[Type[base_plugin.TBLoader] | Type[base_plugin.TBPlugin]] - """ + :rtype: list[Type[base_plugin.TBLoader] | Type[base_plugin.TBPlugin]] + """ - return _PLUGINS[:] + return _PLUGINS[:] def get_dynamic_plugins(): - """Returns a list specifying TensorBoard's dynamically loaded plugins. + """Returns a list specifying TensorBoard's dynamically loaded plugins. - A dynamic TensorBoard plugin is specified using entry_points [1] and it is - the robust way to integrate plugins into TensorBoard. 
+ A dynamic TensorBoard plugin is specified using entry_points [1] and it is + the robust way to integrate plugins into TensorBoard. - This list can be passed to the `tensorboard.program.TensorBoard` API. + This list can be passed to the `tensorboard.program.TensorBoard` API. - Returns: - The list of dynamic plugins. + Returns: + The list of dynamic plugins. - :rtype: list[Type[base_plugin.TBLoader] | Type[base_plugin.TBPlugin]] + :rtype: list[Type[base_plugin.TBLoader] | Type[base_plugin.TBPlugin]] - [1]: https://packaging.python.org/specifications/entry-points/ - """ - return [ - entry_point.load() - for entry_point in pkg_resources.iter_entry_points('tensorboard_plugins') - ] + [1]: https://packaging.python.org/specifications/entry-points/ + """ + return [ + entry_point.load() + for entry_point in pkg_resources.iter_entry_points( + "tensorboard_plugins" + ) + ] diff --git a/tensorboard/default_test.py b/tensorboard/default_test.py index 39665af169..17bc6c802b 100644 --- a/tensorboard/default_test.py +++ b/tensorboard/default_test.py @@ -19,10 +19,10 @@ from __future__ import print_function try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import import pkg_resources @@ -32,43 +32,42 @@ class FakePlugin(base_plugin.TBPlugin): - """FakePlugin for testing.""" + """FakePlugin for testing.""" - plugin_name = 'fake' + plugin_name = "fake" class FakeEntryPoint(pkg_resources.EntryPoint): - """EntryPoint class that fake loads FakePlugin.""" + """EntryPoint class that fake loads FakePlugin.""" - @classmethod - def create(cls): - """Creates an instance of FakeEntryPoint. + @classmethod + def create(cls): + """Creates an instance of FakeEntryPoint. 
- Returns: - instance of FakeEntryPoint - """ - return cls('foo', 'bar') + Returns: + instance of FakeEntryPoint + """ + return cls("foo", "bar") - def load(self): - """Returns FakePlugin instead of resolving module. + def load(self): + """Returns FakePlugin instead of resolving module. - Returns: - FakePlugin - """ - return FakePlugin + Returns: + FakePlugin + """ + return FakePlugin class DefaultTest(test.TestCase): + @mock.patch.object(pkg_resources, "iter_entry_points") + def test_get_dynamic_plugin(self, mock_iter_entry_points): + mock_iter_entry_points.return_value = [FakeEntryPoint.create()] - @mock.patch.object(pkg_resources, 'iter_entry_points') - def test_get_dynamic_plugin(self, mock_iter_entry_points): - mock_iter_entry_points.return_value = [FakeEntryPoint.create()] + actual_plugins = default.get_dynamic_plugins() - actual_plugins = default.get_dynamic_plugins() - - mock_iter_entry_points.assert_called_with('tensorboard_plugins') - self.assertEqual(actual_plugins, [FakePlugin]) + mock_iter_entry_points.assert_called_with("tensorboard_plugins") + self.assertEqual(actual_plugins, [FakePlugin]) if __name__ == "__main__": - test.main() + test.main() diff --git a/tensorboard/defs/tb_proto_library_test.py b/tensorboard/defs/tb_proto_library_test.py index 425e1d06f9..3fc3bb5763 100644 --- a/tensorboard/defs/tb_proto_library_test.py +++ b/tensorboard/defs/tb_proto_library_test.py @@ -26,19 +26,19 @@ class TbProtoLibraryTest(tb_test.TestCase): - """Tests for `tb_proto_library`.""" + """Tests for `tb_proto_library`.""" - def tests_with_deps(self): - foo = test_base_pb2.Foo() - foo.foo = 1 - bar = test_downstream_pb2.Bar() - bar.foo.foo = 1 - self.assertEqual(foo, bar.foo) + def tests_with_deps(self): + foo = test_base_pb2.Foo() + foo.foo = 1 + bar = test_downstream_pb2.Bar() + bar.foo.foo = 1 + self.assertEqual(foo, bar.foo) - def test_service_deps(self): - self.assertIsInstance(test_base_pb2_grpc.FooServiceServicer, type) - 
self.assertIsInstance(test_downstream_pb2_grpc.BarServiceServicer, type) + def test_service_deps(self): + self.assertIsInstance(test_base_pb2_grpc.FooServiceServicer, type) + self.assertIsInstance(test_downstream_pb2_grpc.BarServiceServicer, type) if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/defs/web_test_python_stub.template.py b/tensorboard/defs/web_test_python_stub.template.py index bce71ad88d..59b40f1515 100644 --- a/tensorboard/defs/web_test_python_stub.template.py +++ b/tensorboard/defs/web_test_python_stub.template.py @@ -20,10 +20,8 @@ from tensorboard.functionaltests import wct_test_driver -Test = wct_test_driver.create_test_class( - "{BINARY_PATH}", - "{WEB_PATH}") +Test = wct_test_driver.create_test_class("{BINARY_PATH}", "{WEB_PATH}") if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/tensorboard/encode_png_benchmark.py b/tensorboard/encode_png_benchmark.py index 5550bd74cf..703482a7e3 100644 --- a/tensorboard/encode_png_benchmark.py +++ b/tensorboard/encode_png_benchmark.py @@ -64,80 +64,91 @@ def bench(image, thread_count): - """Encode `image` to PNG on `thread_count` threads in parallel. - - Returns: - A `float` representing number of seconds that it takes all threads - to finish encoding `image`. - """ - threads = [threading.Thread(target=lambda: encoder.encode_png(image)) - for _ in xrange(thread_count)] - start_time = datetime.datetime.now() - for thread in threads: - thread.start() - for thread in threads: - thread.join() - end_time = datetime.datetime.now() - delta = (end_time - start_time).total_seconds() - return delta + """Encode `image` to PNG on `thread_count` threads in parallel. + + Returns: + A `float` representing number of seconds that it takes all threads + to finish encoding `image`. 
+ """ + threads = [ + threading.Thread(target=lambda: encoder.encode_png(image)) + for _ in xrange(thread_count) + ] + start_time = datetime.datetime.now() + for thread in threads: + thread.start() + for thread in threads: + thread.join() + end_time = datetime.datetime.now() + delta = (end_time - start_time).total_seconds() + return delta def _image_of_size(image_size): - """Generate a square RGB test image of the given side length.""" - return np.random.uniform(0, 256, [image_size, image_size, 3]).astype(np.uint8) + """Generate a square RGB test image of the given side length.""" + return np.random.uniform(0, 256, [image_size, image_size, 3]).astype( + np.uint8 + ) def _format_line(headers, fields): - """Format a line of a table. - - Arguments: - headers: A list of strings that are used as the table headers. - fields: A list of the same length as `headers` where `fields[i]` is - the entry for `headers[i]` in this row. Elements can be of - arbitrary types. Pass `headers` to print the header row. - - Returns: - A pretty string. - """ - assert len(fields) == len(headers), (fields, headers) - fields = ["%2.4f" % field if isinstance(field, float) else str(field) - for field in fields] - return ' '.join(' ' * max(0, len(header) - len(field)) + field - for (header, field) in zip(headers, fields)) + """Format a line of a table. + + Arguments: + headers: A list of strings that are used as the table headers. + fields: A list of the same length as `headers` where `fields[i]` is + the entry for `headers[i]` in this row. Elements can be of + arbitrary types. Pass `headers` to print the header row. + + Returns: + A pretty string. 
+ """ + assert len(fields) == len(headers), (fields, headers) + fields = [ + "%2.4f" % field if isinstance(field, float) else str(field) + for field in fields + ] + return " ".join( + " " * max(0, len(header) - len(field)) + field + for (header, field) in zip(headers, fields) + ) def main(unused_argv): - logging.set_verbosity(logging.INFO) - np.random.seed(0) - - thread_counts = [1, 2, 4, 6, 8, 10, 12, 14, 16, 32] - - logger.info("Warming up...") - warmup_image = _image_of_size(256) - for thread_count in thread_counts: - bench(warmup_image, thread_count) - - logger.info("Running...") - results = {} - image = _image_of_size(4096) - headers = ('THREADS', 'TOTAL_TIME', 'UNIT_TIME', 'SPEEDUP', 'PARALLELISM') - logger.info(_format_line(headers, headers)) - for thread_count in thread_counts: - time.sleep(1.0) - total_time = min(bench(image, thread_count) - for _ in xrange(3)) # best-of-three timing - unit_time = total_time / thread_count - if total_time < 2.0: - logger.warn("This benchmark is running too quickly! This " - "may cause misleading timing data. 
Consider " - "increasing the image size until it takes at " - "least 2.0s to encode one image.") - results[thread_count] = unit_time - speedup = results[1] / results[thread_count] - parallelism = speedup / thread_count - fields = (thread_count, total_time, unit_time, speedup, parallelism) - logger.info(_format_line(headers, fields)) - - -if __name__ == '__main__': - app.run(main) + logging.set_verbosity(logging.INFO) + np.random.seed(0) + + thread_counts = [1, 2, 4, 6, 8, 10, 12, 14, 16, 32] + + logger.info("Warming up...") + warmup_image = _image_of_size(256) + for thread_count in thread_counts: + bench(warmup_image, thread_count) + + logger.info("Running...") + results = {} + image = _image_of_size(4096) + headers = ("THREADS", "TOTAL_TIME", "UNIT_TIME", "SPEEDUP", "PARALLELISM") + logger.info(_format_line(headers, headers)) + for thread_count in thread_counts: + time.sleep(1.0) + total_time = min( + bench(image, thread_count) for _ in xrange(3) + ) # best-of-three timing + unit_time = total_time / thread_count + if total_time < 2.0: + logger.warn( + "This benchmark is running too quickly! This " + "may cause misleading timing data. Consider " + "increasing the image size until it takes at " + "least 2.0s to encode one image." 
+ ) + results[thread_count] = unit_time + speedup = results[1] / results[thread_count] + parallelism = speedup / thread_count + fields = (thread_count, total_time, unit_time, speedup, parallelism) + logger.info(_format_line(headers, fields)) + + +if __name__ == "__main__": + app.run(main) diff --git a/tensorboard/errors.py b/tensorboard/errors.py index f937eaa545..4ad906c1c3 100644 --- a/tensorboard/errors.py +++ b/tensorboard/errors.py @@ -30,67 +30,67 @@ class PublicError(RuntimeError): - """An error whose text does not contain sensitive information.""" + """An error whose text does not contain sensitive information.""" - http_code = 500 # default; subclasses should override + http_code = 500 # default; subclasses should override - def __init__(self, details): - super(PublicError, self).__init__(details) + def __init__(self, details): + super(PublicError, self).__init__(details) class InvalidArgumentError(PublicError): - """Client specified an invalid argument. + """Client specified an invalid argument. - The text of this error is assumed not to contain sensitive data, - and so may appear in (e.g.) the response body of a failed HTTP - request. + The text of this error is assumed not to contain sensitive data, + and so may appear in (e.g.) the response body of a failed HTTP + request. - Corresponds to HTTP 400 Bad Request or Google error code `INVALID_ARGUMENT`. - """ + Corresponds to HTTP 400 Bad Request or Google error code `INVALID_ARGUMENT`. + """ - http_code = 400 + http_code = 400 - def __init__(self, details=None): - msg = _format_message("Invalid argument", details) - super(InvalidArgumentError, self).__init__(msg) + def __init__(self, details=None): + msg = _format_message("Invalid argument", details) + super(InvalidArgumentError, self).__init__(msg) class NotFoundError(PublicError): - """Some requested entity (e.g., file or directory) was not found. + """Some requested entity (e.g., file or directory) was not found. 
- The text of this error is assumed not to contain sensitive data, - and so may appear in (e.g.) the response body of a failed HTTP - request. + The text of this error is assumed not to contain sensitive data, + and so may appear in (e.g.) the response body of a failed HTTP + request. - Corresponds to HTTP 404 Not Found or Google error code `NOT_FOUND`. - """ + Corresponds to HTTP 404 Not Found or Google error code `NOT_FOUND`. + """ - http_code = 404 + http_code = 404 - def __init__(self, details=None): - msg = _format_message("Not found", details) - super(NotFoundError, self).__init__(msg) + def __init__(self, details=None): + msg = _format_message("Not found", details) + super(NotFoundError, self).__init__(msg) class PermissionDeniedError(PublicError): - """The caller does not have permission to execute the specified operation. + """The caller does not have permission to execute the specified operation. - The text of this error is assumed not to contain sensitive data, - and so may appear in (e.g.) the response body of a failed HTTP - request. + The text of this error is assumed not to contain sensitive data, + and so may appear in (e.g.) the response body of a failed HTTP + request. - Corresponds to HTTP 403 Forbidden or Google error code `PERMISSION_DENIED`. - """ + Corresponds to HTTP 403 Forbidden or Google error code `PERMISSION_DENIED`. 
+ """ - http_code = 403 + http_code = 403 - def __init__(self, details=None): - msg = _format_message("Permission denied", details) - super(PermissionDeniedError, self).__init__(msg) + def __init__(self, details=None): + msg = _format_message("Permission denied", details) + super(PermissionDeniedError, self).__init__(msg) def _format_message(code_name, details): - if details is None: - return code_name - else: - return "%s: %s" % (code_name, details) + if details is None: + return code_name + else: + return "%s: %s" % (code_name, details) diff --git a/tensorboard/errors_test.py b/tensorboard/errors_test.py index b0c6ac0156..050065a6d3 100644 --- a/tensorboard/errors_test.py +++ b/tensorboard/errors_test.py @@ -23,52 +23,49 @@ class InvalidArgumentErrorTest(tb_test.TestCase): + def test_no_details(self): + e = errors.InvalidArgumentError() + expected_msg = "Invalid argument" + self.assertEqual(str(e), expected_msg) - def test_no_details(self): - e = errors.InvalidArgumentError() - expected_msg = "Invalid argument" - self.assertEqual(str(e), expected_msg) + def test_with_details(self): + e = errors.InvalidArgumentError("expected absolute path; got './foo'") + expected_msg = "Invalid argument: expected absolute path; got './foo'" + self.assertEqual(str(e), expected_msg) - def test_with_details(self): - e = errors.InvalidArgumentError("expected absolute path; got './foo'") - expected_msg = "Invalid argument: expected absolute path; got './foo'" - self.assertEqual(str(e), expected_msg) - - def test_http_code(self): - self.assertEqual(errors.InvalidArgumentError().http_code, 400) + def test_http_code(self): + self.assertEqual(errors.InvalidArgumentError().http_code, 400) class NotFoundErrorTest(tb_test.TestCase): + def test_no_details(self): + e = errors.NotFoundError() + expected_msg = "Not found" + self.assertEqual(str(e), expected_msg) - def test_no_details(self): - e = errors.NotFoundError() - expected_msg = "Not found" - self.assertEqual(str(e), expected_msg) - - 
def test_with_details(self): - e = errors.NotFoundError("no scalar data for run=foo, tag=bar") - expected_msg = "Not found: no scalar data for run=foo, tag=bar" - self.assertEqual(str(e), expected_msg) + def test_with_details(self): + e = errors.NotFoundError("no scalar data for run=foo, tag=bar") + expected_msg = "Not found: no scalar data for run=foo, tag=bar" + self.assertEqual(str(e), expected_msg) - def test_http_code(self): - self.assertEqual(errors.NotFoundError().http_code, 404) + def test_http_code(self): + self.assertEqual(errors.NotFoundError().http_code, 404) class PermissionDeniedErrorTest(tb_test.TestCase): + def test_no_details(self): + e = errors.PermissionDeniedError() + expected_msg = "Permission denied" + self.assertEqual(str(e), expected_msg) - def test_no_details(self): - e = errors.PermissionDeniedError() - expected_msg = "Permission denied" - self.assertEqual(str(e), expected_msg) - - def test_with_details(self): - e = errors.PermissionDeniedError("this data is top secret") - expected_msg = "Permission denied: this data is top secret" - self.assertEqual(str(e), expected_msg) + def test_with_details(self): + e = errors.PermissionDeniedError("this data is top secret") + expected_msg = "Permission denied: this data is top secret" + self.assertEqual(str(e), expected_msg) - def test_http_code(self): - self.assertEqual(errors.PermissionDeniedError().http_code, 403) + def test_http_code(self): + self.assertEqual(errors.PermissionDeniedError().http_code, 403) if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/examples/plugins/example_basic/setup.py b/tensorboard/examples/plugins/example_basic/setup.py index e774ecbeec..54da6e0c50 100644 --- a/tensorboard/examples/plugins/example_basic/setup.py +++ b/tensorboard/examples/plugins/example_basic/setup.py @@ -25,9 +25,7 @@ version="0.1.0", description="Sample TensorBoard plugin.", packages=["tensorboard_plugin_example"], - package_data={ - "tensorboard_plugin_example": 
["static/**"], - }, + package_data={"tensorboard_plugin_example": ["static/**"],}, entry_points={ "tensorboard_plugins": [ "example_basic = tensorboard_plugin_example.plugin:ExamplePlugin", diff --git a/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/demo.py b/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/demo.py index 3f7260481b..536e2c1b35 100644 --- a/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/demo.py +++ b/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/demo.py @@ -29,18 +29,17 @@ def main(unused_argv): - writer = tf.summary.create_file_writer("demo_logs") - with writer.as_default(): - summary_v2.greeting( - "guestbook", - "Alice", - step=0, - description="Sign your name!", - ) - summary_v2.greeting("guestbook", "Bob", step=1) # no need for `description` - summary_v2.greeting("guestbook", "Cheryl", step=2) - summary_v2.greeting("more_names", "David", step=4) + writer = tf.summary.create_file_writer("demo_logs") + with writer.as_default(): + summary_v2.greeting( + "guestbook", "Alice", step=0, description="Sign your name!", + ) + summary_v2.greeting( + "guestbook", "Bob", step=1 + ) # no need for `description` + summary_v2.greeting("guestbook", "Cheryl", step=2) + summary_v2.greeting("more_names", "David", step=4) if __name__ == "__main__": - app.run(main) + app.run(main) diff --git a/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/plugin.py b/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/plugin.py index acfccb3dbd..a8bacd0750 100644 --- a/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/plugin.py +++ b/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/plugin.py @@ -32,59 +32,66 @@ class ExamplePlugin(base_plugin.TBPlugin): - plugin_name = metadata.PLUGIN_NAME + plugin_name = metadata.PLUGIN_NAME - def __init__(self, context): - self._multiplexer = context.multiplexer + def 
__init__(self, context): + self._multiplexer = context.multiplexer - def is_active(self): - return bool(self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME)) + def is_active(self): + return bool( + self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) + ) - def get_plugin_apps(self): - return { - "/index.js": self._serve_js, - "/tags": self._serve_tags, - "/greetings": self._serve_greetings, - } + def get_plugin_apps(self): + return { + "/index.js": self._serve_js, + "/tags": self._serve_tags, + "/greetings": self._serve_greetings, + } - def frontend_metadata(self): - return base_plugin.FrontendMetadata(es_module_path="/index.js") + def frontend_metadata(self): + return base_plugin.FrontendMetadata(es_module_path="/index.js") - @wrappers.Request.application - def _serve_js(self, request): - del request # unused - filepath = os.path.join(os.path.dirname(__file__), "static", "index.js") - with open(filepath) as infile: - contents = infile.read() - return werkzeug.Response(contents, content_type="application/javascript") + @wrappers.Request.application + def _serve_js(self, request): + del request # unused + filepath = os.path.join(os.path.dirname(__file__), "static", "index.js") + with open(filepath) as infile: + contents = infile.read() + return werkzeug.Response( + contents, content_type="application/javascript" + ) - @wrappers.Request.application - def _serve_tags(self, request): - del request # unused - mapping = self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) - result = {run: {} for run in self._multiplexer.Runs()} - for (run, tag_to_content) in six.iteritems(mapping): - for tag in tag_to_content: - summary_metadata = self._multiplexer.SummaryMetadata(run, tag) - result[run][tag] = { - u"description": summary_metadata.summary_description, - } - contents = json.dumps(result, sort_keys=True) - return werkzeug.Response(contents, content_type="application/json") + @wrappers.Request.application + def _serve_tags(self, request): + 
del request # unused + mapping = self._multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME + ) + result = {run: {} for run in self._multiplexer.Runs()} + for (run, tag_to_content) in six.iteritems(mapping): + for tag in tag_to_content: + summary_metadata = self._multiplexer.SummaryMetadata(run, tag) + result[run][tag] = { + u"description": summary_metadata.summary_description, + } + contents = json.dumps(result, sort_keys=True) + return werkzeug.Response(contents, content_type="application/json") - @wrappers.Request.application - def _serve_greetings(self, request): - run = request.args.get("run") - tag = request.args.get("tag") - if run is None or tag is None: - raise werkzeug.exceptions.BadRequest("Must specify run and tag") - try: - data = [ - np.asscalar(tensor_util.make_ndarray(event.tensor_proto)) - .decode("utf-8") - for event in self._multiplexer.Tensors(run, tag) - ] - except KeyError: - raise werkzeug.exceptions.BadRequest("Invalid run or tag") - contents = json.dumps(data, sort_keys=True) - return werkzeug.Response(contents, content_type="application/json") + @wrappers.Request.application + def _serve_greetings(self, request): + run = request.args.get("run") + tag = request.args.get("tag") + if run is None or tag is None: + raise werkzeug.exceptions.BadRequest("Must specify run and tag") + try: + data = [ + np.asscalar( + tensor_util.make_ndarray(event.tensor_proto) + ).decode("utf-8") + for event in self._multiplexer.Tensors(run, tag) + ] + except KeyError: + raise werkzeug.exceptions.BadRequest("Invalid run or tag") + contents = json.dumps(data, sort_keys=True) + return werkzeug.Response(contents, content_type="application/json") diff --git a/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/summary_v2.py b/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/summary_v2.py index 04c80d0f58..fd8da23c0d 100644 --- a/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/summary_v2.py +++ 
b/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/summary_v2.py @@ -26,42 +26,42 @@ def greeting(name, guest, step=None, description=None): - """Write a "greeting" summary. + """Write a "greeting" summary. - Arguments: - name: A name for this summary. The summary tag used for TensorBoard will - be this name prefixed by any active name scopes. - guest: A rank-0 string `Tensor`. - step: Explicit `int64`-castable monotonic step value for this summary. If - omitted, this defaults to `tf.summary.experimental.get_step()`, which must - not be None. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will + be this name prefixed by any active name scopes. + guest: A rank-0 string `Tensor`. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. - Returns: - True on success, or false if no summary was written because no default - summary writer was available. + Returns: + True on success, or false if no summary was written because no default + summary writer was available. - Raises: - ValueError: if a default writer exists, but no step was provided and - `tf.summary.experimental.get_step()` is None. - """ - with tf.summary.experimental.summary_scope( - name, "greeting_summary", values=[guest, step], - ) as (tag, _): - return tf.summary.write( - tag=tag, - tensor=tf.strings.join(["Hello, ", guest, "!"]), - step=step, - metadata=_create_summary_metadata(description), - ) + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. 
+ """ + with tf.summary.experimental.summary_scope( + name, "greeting_summary", values=[guest, step], + ) as (tag, _): + return tf.summary.write( + tag=tag, + tensor=tf.strings.join(["Hello, ", guest, "!"]), + step=step, + metadata=_create_summary_metadata(description), + ) def _create_summary_metadata(description): - return summary_pb2.SummaryMetadata( - summary_description=description, - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name=metadata.PLUGIN_NAME, - content=b"", # no need for summary-specific metadata - ), - ) + return summary_pb2.SummaryMetadata( + summary_description=description, + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name=metadata.PLUGIN_NAME, + content=b"", # no need for summary-specific metadata + ), + ) diff --git a/tensorboard/functionaltests/core_test.py b/tensorboard/functionaltests/core_test.py index 0a3a6f13f0..3953c5b99f 100644 --- a/tensorboard/functionaltests/core_test.py +++ b/tensorboard/functionaltests/core_test.py @@ -37,171 +37,186 @@ class BasicTest(unittest.TestCase): - """Tests that the basic chrome is displayed when there is no data.""" - - @classmethod - def setUpClass(cls): - src_dir = os.environ["TEST_SRCDIR"] - binary = os.path.join(src_dir, - "org_tensorflow_tensorboard/tensorboard/tensorboard") - cls.logdir = tempfile.mkdtemp(prefix='core_test_%s_logdir_' % cls.__name__) - cls.setUpData() - cls.port = _BASE_PORT + _PORT_OFFSETS[cls] - cls.process = subprocess.Popen( - [binary, "--port", str(cls.port), "--logdir", cls.logdir]) - - @classmethod - def setUpData(cls): - # Overridden by DashboardsTest. 
- pass - - @classmethod - def tearDownClass(cls): - cls.process.kill() - cls.process.wait() - - def setUp(self): - self.driver = webtest.new_webdriver_session() - self.driver.get("http://localhost:%s" % self.port) - self.wait = wait.WebDriverWait(self.driver, 10) - - def tearDown(self): - try: - self.driver.quit() - finally: - self.driver = None - self.wait = None - - def testToolbarTitleDisplays(self): - self.wait.until( - expected_conditions.text_to_be_present_in_element(( - by.By.CLASS_NAME, "toolbar-title"), "TensorBoard")) - - def testLogdirDisplays(self): - self.wait.until( - expected_conditions.text_to_be_present_in_element(( - by.By.ID, "data_location"), self.logdir)) + """Tests that the basic chrome is displayed when there is no data.""" + + @classmethod + def setUpClass(cls): + src_dir = os.environ["TEST_SRCDIR"] + binary = os.path.join( + src_dir, "org_tensorflow_tensorboard/tensorboard/tensorboard" + ) + cls.logdir = tempfile.mkdtemp( + prefix="core_test_%s_logdir_" % cls.__name__ + ) + cls.setUpData() + cls.port = _BASE_PORT + _PORT_OFFSETS[cls] + cls.process = subprocess.Popen( + [binary, "--port", str(cls.port), "--logdir", cls.logdir] + ) + + @classmethod + def setUpData(cls): + # Overridden by DashboardsTest. 
+ pass + + @classmethod + def tearDownClass(cls): + cls.process.kill() + cls.process.wait() + + def setUp(self): + self.driver = webtest.new_webdriver_session() + self.driver.get("http://localhost:%s" % self.port) + self.wait = wait.WebDriverWait(self.driver, 10) + + def tearDown(self): + try: + self.driver.quit() + finally: + self.driver = None + self.wait = None + + def testToolbarTitleDisplays(self): + self.wait.until( + expected_conditions.text_to_be_present_in_element( + (by.By.CLASS_NAME, "toolbar-title"), "TensorBoard" + ) + ) + + def testLogdirDisplays(self): + self.wait.until( + expected_conditions.text_to_be_present_in_element( + (by.By.ID, "data_location"), self.logdir + ) + ) + class DashboardsTest(BasicTest): - """Tests basic behavior when there is some data in TensorBoard. - - This extends `BasicTest`, so it inherits its methods to test that the - basic chrome is displayed. We also check that we can navigate around - the various dashboards. - """ - - @classmethod - def setUpData(cls): - scalars_demo.run_all(cls.logdir, verbose=False) - audio_demo.run_all(cls.logdir, verbose=False) - - def testLogdirDisplays(self): - # TensorBoard doesn't have logdir display when there is data - pass - - def testDashboardSelection(self): - """Test that we can navigate among the different dashboards.""" - selectors = { - "scalars_tab": "paper-tab[data-dashboard=scalars]", - "audio_tab": "paper-tab[data-dashboard=audio]", - "graphs_tab": "paper-tab[data-dashboard=graphs]", - "inactive_dropdown": "paper-dropdown-menu[label*=Inactive]", - "images_menu_item": "paper-item[data-dashboard=images]", - "reload_button": "paper-icon-button#reload-button", - } - elements = {} - for (name, selector) in selectors.items(): - locator = (by.By.CSS_SELECTOR, selector) - self.wait.until(expected_conditions.presence_of_element_located(locator)) - elements[name] = self.driver.find_element_by_css_selector(selector) - - # The implementation of paper-* components doesn't seem to play nice - # 
with Selenium's `element.is_selected()` and `element.is_enabled()` - # methods. Instead, we check the appropriate WAI-ARIA attributes. - # (Also, though the docs for `get_attribute` say that the string - # `"false"` is returned as `False`, this appears to be the case - # _sometimes_ but not _always_, so we should take special care to - # handle that.) - def is_selected(element): - attribute = element.get_attribute("aria-selected") - return attribute and attribute != "false" - - def is_enabled(element): - attribute = element.get_attribute("aria-disabled") - is_disabled = attribute and attribute != "false" - return not is_disabled - - def assert_selected_dashboard(polymer_component_name): - expected = {polymer_component_name} - actual = { - container.find_element_by_css_selector("*").tag_name # first child - for container - in self.driver.find_elements_by_css_selector(".dashboard-container") - if container.is_displayed() - } - self.assertEqual(expected, actual) - - # The scalar and audio dashboards should be active, and the scalar - # dashboard should be selected by default. The images menu item - # should not be visible, as it's within the drop-down menu. - self.assertTrue(elements["scalars_tab"].is_displayed()) - self.assertTrue(elements["audio_tab"].is_displayed()) - self.assertTrue(elements["graphs_tab"].is_displayed()) - self.assertTrue(is_selected(elements["scalars_tab"])) - self.assertFalse(is_selected(elements["audio_tab"])) - self.assertFalse(elements["images_menu_item"].is_displayed()) - self.assertFalse(is_selected(elements["images_menu_item"])) - assert_selected_dashboard("tf-scalar-dashboard") - - # While we're on the scalar dashboard, we should be allowed to - # reload the data. - self.assertTrue(is_enabled(elements["reload_button"])) - - # We should be able to activate the audio dashboard. 
- elements["audio_tab"].click() - self.assertFalse(is_selected(elements["scalars_tab"])) - self.assertTrue(is_selected(elements["audio_tab"])) - self.assertFalse(is_selected(elements["graphs_tab"])) - self.assertFalse(is_selected(elements["images_menu_item"])) - assert_selected_dashboard("tf-audio-dashboard") - self.assertTrue(is_enabled(elements["reload_button"])) - - # We should then be able to open the dropdown and navigate to the - # image dashboard. (We have to wait until it's visible because of the - # dropdown menu's animations.) - elements["inactive_dropdown"].click() - self.wait.until( - expected_conditions.visibility_of(elements["images_menu_item"])) - self.assertTrue(elements["images_menu_item"].is_displayed()) - elements["images_menu_item"].click() - self.assertFalse(is_selected(elements["scalars_tab"])) - self.assertFalse(is_selected(elements["audio_tab"])) - self.assertFalse(is_selected(elements["graphs_tab"])) - self.assertTrue(is_selected(elements["images_menu_item"])) - assert_selected_dashboard("tf-image-dashboard") - self.assertTrue(is_enabled(elements["reload_button"])) - - # Next, we should be able to navigate back to an active dashboard. - # If we choose the graphs dashboard, the reload feature should be - # disabled. - elements["graphs_tab"].click() - self.assertFalse(elements["images_menu_item"].is_displayed()) - self.assertFalse(is_selected(elements["scalars_tab"])) - self.assertFalse(is_selected(elements["audio_tab"])) - self.assertTrue(is_selected(elements["graphs_tab"])) - self.assertFalse(is_selected(elements["images_menu_item"])) - assert_selected_dashboard("tf-graph-dashboard") - self.assertFalse(is_enabled(elements["reload_button"])) - - # Finally, we should be able to navigate back to the scalar dashboard. 
- elements["scalars_tab"].click() - self.assertTrue(is_selected(elements["scalars_tab"])) - self.assertFalse(is_selected(elements["audio_tab"])) - self.assertFalse(is_selected(elements["graphs_tab"])) - self.assertFalse(is_selected(elements["images_menu_item"])) - assert_selected_dashboard("tf-scalar-dashboard") - self.assertTrue(is_enabled(elements["reload_button"])) + """Tests basic behavior when there is some data in TensorBoard. + + This extends `BasicTest`, so it inherits its methods to test that + the basic chrome is displayed. We also check that we can navigate + around the various dashboards. + """ + + @classmethod + def setUpData(cls): + scalars_demo.run_all(cls.logdir, verbose=False) + audio_demo.run_all(cls.logdir, verbose=False) + + def testLogdirDisplays(self): + # TensorBoard doesn't have logdir display when there is data + pass + + def testDashboardSelection(self): + """Test that we can navigate among the different dashboards.""" + selectors = { + "scalars_tab": "paper-tab[data-dashboard=scalars]", + "audio_tab": "paper-tab[data-dashboard=audio]", + "graphs_tab": "paper-tab[data-dashboard=graphs]", + "inactive_dropdown": "paper-dropdown-menu[label*=Inactive]", + "images_menu_item": "paper-item[data-dashboard=images]", + "reload_button": "paper-icon-button#reload-button", + } + elements = {} + for (name, selector) in selectors.items(): + locator = (by.By.CSS_SELECTOR, selector) + self.wait.until( + expected_conditions.presence_of_element_located(locator) + ) + elements[name] = self.driver.find_element_by_css_selector(selector) + + # The implementation of paper-* components doesn't seem to play nice + # with Selenium's `element.is_selected()` and `element.is_enabled()` + # methods. Instead, we check the appropriate WAI-ARIA attributes. + # (Also, though the docs for `get_attribute` say that the string + # `"false"` is returned as `False`, this appears to be the case + # _sometimes_ but not _always_, so we should take special care to + # handle that.) 
+ def is_selected(element): + attribute = element.get_attribute("aria-selected") + return attribute and attribute != "false" + + def is_enabled(element): + attribute = element.get_attribute("aria-disabled") + is_disabled = attribute and attribute != "false" + return not is_disabled + + def assert_selected_dashboard(polymer_component_name): + expected = {polymer_component_name} + actual = { + container.find_element_by_css_selector( + "*" + ).tag_name # first child + for container in self.driver.find_elements_by_css_selector( + ".dashboard-container" + ) + if container.is_displayed() + } + self.assertEqual(expected, actual) + + # The scalar and audio dashboards should be active, and the scalar + # dashboard should be selected by default. The images menu item + # should not be visible, as it's within the drop-down menu. + self.assertTrue(elements["scalars_tab"].is_displayed()) + self.assertTrue(elements["audio_tab"].is_displayed()) + self.assertTrue(elements["graphs_tab"].is_displayed()) + self.assertTrue(is_selected(elements["scalars_tab"])) + self.assertFalse(is_selected(elements["audio_tab"])) + self.assertFalse(elements["images_menu_item"].is_displayed()) + self.assertFalse(is_selected(elements["images_menu_item"])) + assert_selected_dashboard("tf-scalar-dashboard") + + # While we're on the scalar dashboard, we should be allowed to + # reload the data. + self.assertTrue(is_enabled(elements["reload_button"])) + + # We should be able to activate the audio dashboard. + elements["audio_tab"].click() + self.assertFalse(is_selected(elements["scalars_tab"])) + self.assertTrue(is_selected(elements["audio_tab"])) + self.assertFalse(is_selected(elements["graphs_tab"])) + self.assertFalse(is_selected(elements["images_menu_item"])) + assert_selected_dashboard("tf-audio-dashboard") + self.assertTrue(is_enabled(elements["reload_button"])) + + # We should then be able to open the dropdown and navigate to the + # image dashboard. 
(We have to wait until it's visible because of the + # dropdown menu's animations.) + elements["inactive_dropdown"].click() + self.wait.until( + expected_conditions.visibility_of(elements["images_menu_item"]) + ) + self.assertTrue(elements["images_menu_item"].is_displayed()) + elements["images_menu_item"].click() + self.assertFalse(is_selected(elements["scalars_tab"])) + self.assertFalse(is_selected(elements["audio_tab"])) + self.assertFalse(is_selected(elements["graphs_tab"])) + self.assertTrue(is_selected(elements["images_menu_item"])) + assert_selected_dashboard("tf-image-dashboard") + self.assertTrue(is_enabled(elements["reload_button"])) + + # Next, we should be able to navigate back to an active dashboard. + # If we choose the graphs dashboard, the reload feature should be + # disabled. + elements["graphs_tab"].click() + self.assertFalse(elements["images_menu_item"].is_displayed()) + self.assertFalse(is_selected(elements["scalars_tab"])) + self.assertFalse(is_selected(elements["audio_tab"])) + self.assertTrue(is_selected(elements["graphs_tab"])) + self.assertFalse(is_selected(elements["images_menu_item"])) + assert_selected_dashboard("tf-graph-dashboard") + self.assertFalse(is_enabled(elements["reload_button"])) + + # Finally, we should be able to navigate back to the scalar dashboard. 
+ elements["scalars_tab"].click() + self.assertTrue(is_selected(elements["scalars_tab"])) + self.assertFalse(is_selected(elements["audio_tab"])) + self.assertFalse(is_selected(elements["graphs_tab"])) + self.assertFalse(is_selected(elements["images_menu_item"])) + assert_selected_dashboard("tf-scalar-dashboard") + self.assertTrue(is_enabled(elements["reload_button"])) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/tensorboard/functionaltests/wct_test_driver.py b/tensorboard/functionaltests/wct_test_driver.py index d0e09eb9a7..9fbecd806b 100644 --- a/tensorboard/functionaltests/wct_test_driver.py +++ b/tensorboard/functionaltests/wct_test_driver.py @@ -33,114 +33,124 @@ # As emitted in the "Listening on:" line of the WebfilesServer output. # We extract only the port because the hostname can reroute through corp # DNS and force auth, which fails in tests. -_URL_RE = re.compile(br'http://[^:]*:([0-9]+)/') +_URL_RE = re.compile(br"http://[^:]*:([0-9]+)/") _SUITE_PASSED_RE = re.compile(r'.*test suite passed"$') -_SUITE_FAILED_RE = re.compile(r'.*failing test.*') +_SUITE_FAILED_RE = re.compile(r".*failing test.*") + def create_test_class(binary_path, web_path): - """Create a unittest.TestCase class to run WebComponentTester tests. - - Arguments: - binary_path: relative path to a `tf_web_library` target; - e.g.: "tensorboard/components/vz_foo/test/test_web_library" - web_path: absolute web path to the tests page in the above web - library; e.g.: "/vz-foo/test/tests.html" - - Result: - A new subclass of `unittest.TestCase`. Bind this to a variable in - the test file's main module. - """ - - class BrowserLogIndicatesResult(object): - def __init__(self): - self.passed = False - self.log = [] - - def __call__(self, driver): - # Scan through the log entries and search for a line indicating whether - # the test passed or failed. The method 'driver.get_log' also seems to - # clear the log so we aggregate it in self.log for printing later on. 
- new_log = driver.get_log("browser") - new_messages = [entry["message"] for entry in new_log] - self.log = self.log + new_log - if self._log_matches(new_messages, _SUITE_FAILED_RE): - self.passed = False - return True - if self._log_matches(new_messages, _SUITE_PASSED_RE): - self.passed = True - return True - # Here, we still don't know. - return False - - def _log_matches(self, messages, regexp): - for message in messages: - if regexp.match(message): - return True - return False - - class WebComponentTesterTest(unittest.TestCase): - """Tests that a family of unit tests completes successfully.""" - - def setUp(cls): - src_dir = os.environ["TEST_SRCDIR"] - binary = os.path.join( - src_dir, - "org_tensorflow_tensorboard/" + binary_path) - cls.process = subprocess.Popen( - [binary], stdin=None, stdout=None, stderr=subprocess.PIPE) - - lines = [] - hit_eof = False - while True: - line = cls.process.stderr.readline() - if line == b"": - # b"" means reached EOF; b"\n" means empty line. - hit_eof = True - break - lines.append(line) - if b"Listening on:" in line: - match = _URL_RE.search(line) - if match: - cls.port = int(match.group(1)) - break - else: - raise ValueError("Failed to parse listening-on line: %r" % line) - if len(lines) >= 1024: - # Sanity check---something is wrong. Let us fail fast rather - # than spending the 15-minute test timeout consuming - # potentially empty logs. 
- hit_eof = True - break - if hit_eof: - full_output = "\n".join(repr(line) for line in lines) - raise ValueError( - "Did not find listening-on line in output:\n%s" % full_output) - - def tearDown(cls): - cls.process.kill() - cls.process.wait() - - def test(self): - driver = webtest.new_webdriver_session( - capabilities={"loggingPrefs": {"browser": "ALL"}} - ) - url = "http://localhost:%s%s" % (self.port, web_path) - driver.get(url) - browser_log_indicates_result = BrowserLogIndicatesResult() - try: - wait.WebDriverWait(driver, 10).until(browser_log_indicates_result) - if not browser_log_indicates_result.passed: - self.fail() - finally: - # Print log as an aid for debugging. - log = browser_log_indicates_result.log + driver.get_log("browser") - self._print_log(log) - - def _print_log(self, entries): - print("Browser log follows:") - print("--------------------") - print(" | ".join(entries[0].keys())) - for entry in entries: - print(" | ".join(str(v) for v in entry.values())) - - return WebComponentTesterTest + """Create a unittest.TestCase class to run WebComponentTester tests. + + Arguments: + binary_path: relative path to a `tf_web_library` target; + e.g.: "tensorboard/components/vz_foo/test/test_web_library" + web_path: absolute web path to the tests page in the above web + library; e.g.: "/vz-foo/test/tests.html" + + Result: + A new subclass of `unittest.TestCase`. Bind this to a variable in + the test file's main module. + """ + + class BrowserLogIndicatesResult(object): + def __init__(self): + self.passed = False + self.log = [] + + def __call__(self, driver): + # Scan through the log entries and search for a line indicating whether + # the test passed or failed. The method 'driver.get_log' also seems to + # clear the log so we aggregate it in self.log for printing later on. 
+ new_log = driver.get_log("browser") + new_messages = [entry["message"] for entry in new_log] + self.log = self.log + new_log + if self._log_matches(new_messages, _SUITE_FAILED_RE): + self.passed = False + return True + if self._log_matches(new_messages, _SUITE_PASSED_RE): + self.passed = True + return True + # Here, we still don't know. + return False + + def _log_matches(self, messages, regexp): + for message in messages: + if regexp.match(message): + return True + return False + + class WebComponentTesterTest(unittest.TestCase): + """Tests that a family of unit tests completes successfully.""" + + def setUp(cls): + src_dir = os.environ["TEST_SRCDIR"] + binary = os.path.join( + src_dir, "org_tensorflow_tensorboard/" + binary_path + ) + cls.process = subprocess.Popen( + [binary], stdin=None, stdout=None, stderr=subprocess.PIPE + ) + + lines = [] + hit_eof = False + while True: + line = cls.process.stderr.readline() + if line == b"": + # b"" means reached EOF; b"\n" means empty line. + hit_eof = True + break + lines.append(line) + if b"Listening on:" in line: + match = _URL_RE.search(line) + if match: + cls.port = int(match.group(1)) + break + else: + raise ValueError( + "Failed to parse listening-on line: %r" % line + ) + if len(lines) >= 1024: + # Sanity check---something is wrong. Let us fail fast rather + # than spending the 15-minute test timeout consuming + # potentially empty logs. 
+ hit_eof = True + break + if hit_eof: + full_output = "\n".join(repr(line) for line in lines) + raise ValueError( + "Did not find listening-on line in output:\n%s" + % full_output + ) + + def tearDown(cls): + cls.process.kill() + cls.process.wait() + + def test(self): + driver = webtest.new_webdriver_session( + capabilities={"loggingPrefs": {"browser": "ALL"}} + ) + url = "http://localhost:%s%s" % (self.port, web_path) + driver.get(url) + browser_log_indicates_result = BrowserLogIndicatesResult() + try: + wait.WebDriverWait(driver, 10).until( + browser_log_indicates_result + ) + if not browser_log_indicates_result.passed: + self.fail() + finally: + # Print log as an aid for debugging. + log = browser_log_indicates_result.log + driver.get_log( + "browser" + ) + self._print_log(log) + + def _print_log(self, entries): + print("Browser log follows:") + print("--------------------") + print(" | ".join(entries[0].keys())) + for entry in entries: + print(" | ".join(str(v) for v in entry.values())) + + return WebComponentTesterTest diff --git a/tensorboard/lazy.py b/tensorboard/lazy.py index a891903965..65d51e489b 100644 --- a/tensorboard/lazy.py +++ b/tensorboard/lazy.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TensorBoard is a webapp for understanding TensorFlow runs and graphs. -""" +"""TensorBoard is a webapp for understanding TensorFlow runs and graphs.""" from __future__ import absolute_import from __future__ import division @@ -25,69 +24,80 @@ def lazy_load(name): - """Decorator to define a function that lazily loads the module 'name'. - - This can be used to defer importing troublesome dependencies - e.g. ones that - are large and infrequently used, or that cause a dependency cycle - - until they are actually used. 
- - Args: - name: the fully-qualified name of the module; typically the last segment - of 'name' matches the name of the decorated function - - Returns: - Decorator function that produces a lazy-loading module 'name' backed by the - underlying decorated function. - """ - def wrapper(load_fn): - # Wrap load_fn to call it exactly once and update __dict__ afterwards to - # make future lookups efficient (only failed lookups call __getattr__). - @_memoize - def load_once(self): - if load_once.loading: - raise ImportError("Circular import when resolving LazyModule %r" % name) - load_once.loading = True - try: - module = load_fn() - finally: + """Decorator to define a function that lazily loads the module 'name'. + + This can be used to defer importing troublesome dependencies - e.g. ones that + are large and infrequently used, or that cause a dependency cycle - + until they are actually used. + + Args: + name: the fully-qualified name of the module; typically the last segment + of 'name' matches the name of the decorated function + + Returns: + Decorator function that produces a lazy-loading module 'name' backed by the + underlying decorated function. + """ + + def wrapper(load_fn): + # Wrap load_fn to call it exactly once and update __dict__ afterwards to + # make future lookups efficient (only failed lookups call __getattr__). + @_memoize + def load_once(self): + if load_once.loading: + raise ImportError( + "Circular import when resolving LazyModule %r" % name + ) + load_once.loading = True + try: + module = load_fn() + finally: + load_once.loading = False + self.__dict__.update(module.__dict__) + load_once.loaded = True + return module + load_once.loading = False - self.__dict__.update(module.__dict__) - load_once.loaded = True - return module - load_once.loading = False - load_once.loaded = False + load_once.loaded = False + + # Define a module that proxies getattr() and dir() to the result of calling + # load_once() the first time it's needed. 
 The class is nested so we can close
+        # over load_once() and avoid polluting the module's attrs with our own state.
+        class LazyModule(types.ModuleType):
+            def __getattr__(self, attr_name):
+                return getattr(load_once(self), attr_name)
-  # Define a module that proxies getattr() and dir() to the result of calling
-  # load_once() the first time it's needed. The class is nested so we can close
-  # over load_once() and avoid polluting the module's attrs with our own state.
-  class LazyModule(types.ModuleType):
-    def __getattr__(self, attr_name):
-      return getattr(load_once(self), attr_name)
+            def __dir__(self):
+                return dir(load_once(self))
-    def __dir__(self):
-      return dir(load_once(self))
+            def __repr__(self):
+                if load_once.loaded:
+                    return "<%r via LazyModule (loaded)>" % load_once(self)
+                return (
+                    "<module %r via LazyModule (not yet loaded)>"
+                    % self.__name__
+                )
-    def __repr__(self):
-      if load_once.loaded:
-        return '<%r via LazyModule (loaded)>' % load_once(self)
-      return '<module %r via LazyModule (not yet loaded)>' % self.__name__
+        return LazyModule(name)
-    return LazyModule(name)
-  return wrapper
+    return wrapper
 
 
 def _memoize(f):
-  """Memoizing decorator for f, which must have exactly 1 hashable argument."""
-  nothing = object()  # Unique "no value" sentinel object.
-  cache = {}
-  # Use a reentrant lock so that if f references the resulting wrapper we die
-  # with recursion depth exceeded instead of deadlocking.
-  lock = threading.RLock()
-  @functools.wraps(f)
-  def wrapper(arg):
-    if cache.get(arg, nothing) is nothing:
-      with lock:
+    """Memoizing decorator for f, which must have exactly 1 hashable
+    argument."""
+    nothing = object()  # Unique "no value" sentinel object.
+    cache = {}
+    # Use a reentrant lock so that if f references the resulting wrapper we die
+    # with recursion depth exceeded instead of deadlocking.
+ lock = threading.RLock() + + @functools.wraps(f) + def wrapper(arg): if cache.get(arg, nothing) is nothing: - cache[arg] = f(arg) - return cache[arg] - return wrapper + with lock: + if cache.get(arg, nothing) is nothing: + cache[arg] = f(arg) + return cache[arg] + + return wrapper diff --git a/tensorboard/lazy_test.py b/tensorboard/lazy_test.py index 8ff3f7a204..9732b9df68 100644 --- a/tensorboard/lazy_test.py +++ b/tensorboard/lazy_test.py @@ -25,85 +25,95 @@ class LazyTest(unittest.TestCase): - - def test_self_composition(self): - """A lazy module should be able to load another lazy module.""" - # This test can fail if the `LazyModule` implementation stores the - # cached module as a field on the module itself rather than a - # closure value. (See pull request review comments on #1781 for - # details.) - - @lazy.lazy_load("inner") - def inner(): - import collections - return collections - - @lazy.lazy_load("outer") - def outer(): - return inner - - x1 = outer.namedtuple - x2 = inner.namedtuple - self.assertEqual(x1, x2) - - def test_lazy_cycle(self): - """A cycle among lazy modules should error, not deadlock or spin.""" - # This test can fail if `_memoize` uses a non-reentrant lock. (See - # pull request review comments on #1781 for details.) 
-
-    @lazy.lazy_load("inner")
-    def inner():
-      return outer.foo
-
-    @lazy.lazy_load("outer")
-    def outer():
-      return inner
-
-    expected_message = "Circular import when resolving LazyModule 'inner'"
-    with six.assertRaisesRegex(self, ImportError, expected_message):
-      outer.bar
-
-  def test_repr_before_load(self):
-    @lazy.lazy_load("foo")
-    def foo():
-      self.fail("Should not need to resolve this module.")
-    self.assertEquals(repr(foo), "<module 'foo' via LazyModule (not yet loaded)>")
-
-  def test_repr_after_load(self):
-    import collections
-    @lazy.lazy_load("foo")
-    def foo():
-      return collections
-    foo.namedtuple
-    self.assertEquals(repr(foo), "<%r via LazyModule (loaded)>" % collections)
-
-  def test_failed_load_idempotent(self):
-    expected_message = "you will never stop me"
-    @lazy.lazy_load("bad")
-    def bad():
-      raise ValueError(expected_message)
-    with six.assertRaisesRegex(self, ValueError, expected_message):
-      bad.day
-    with six.assertRaisesRegex(self, ValueError, expected_message):
-      bad.day
-
-  def test_loads_only_once_even_when_result_equal_to_everything(self):
-    # This would fail if the implementation of `_memoize` used `==`
-    # rather than `is` to check for the sentinel value.
-    class EqualToEverything(object):
-      def __eq__(self, other):
-        return True
-
-    count_box = [0]
-    @lazy.lazy_load("foo")
-    def foo():
-      count_box[0] += 1
-      return EqualToEverything()
-
-    dir(foo)
-    dir(foo)
-    self.assertEqual(count_box[0], 1)
-
-
-if __name__ == '__main__':
-  unittest.main()
+    def test_self_composition(self):
+        """A lazy module should be able to load another lazy module."""
+        # This test can fail if the `LazyModule` implementation stores the
+        # cached module as a field on the module itself rather than a
+        # closure value. (See pull request review comments on #1781 for
+        # details.)
+
+        @lazy.lazy_load("inner")
+        def inner():
+            import collections
+
+            return collections
+
+        @lazy.lazy_load("outer")
+        def outer():
+            return inner
+
+        x1 = outer.namedtuple
+        x2 = inner.namedtuple
+        self.assertEqual(x1, x2)
+
+    def test_lazy_cycle(self):
+        """A cycle among lazy modules should error, not deadlock or spin."""
+        # This test can fail if `_memoize` uses a non-reentrant lock. (See
+        # pull request review comments on #1781 for details.)
+
+        @lazy.lazy_load("inner")
+        def inner():
+            return outer.foo
+
+        @lazy.lazy_load("outer")
+        def outer():
+            return inner
+
+        expected_message = "Circular import when resolving LazyModule 'inner'"
+        with six.assertRaisesRegex(self, ImportError, expected_message):
+            outer.bar
+
+    def test_repr_before_load(self):
+        @lazy.lazy_load("foo")
+        def foo():
+            self.fail("Should not need to resolve this module.")
+
+        self.assertEquals(
+            repr(foo), "<module 'foo' via LazyModule (not yet loaded)>"
+        )
+
+    def test_repr_after_load(self):
+        import collections
+
+        @lazy.lazy_load("foo")
+        def foo():
+            return collections
+
+        foo.namedtuple
+        self.assertEquals(
+            repr(foo), "<%r via LazyModule (loaded)>" % collections
+        )
+
+    def test_failed_load_idempotent(self):
+        expected_message = "you will never stop me"
+
+        @lazy.lazy_load("bad")
+        def bad():
+            raise ValueError(expected_message)
+
+        with six.assertRaisesRegex(self, ValueError, expected_message):
+            bad.day
+        with six.assertRaisesRegex(self, ValueError, expected_message):
+            bad.day
+
+    def test_loads_only_once_even_when_result_equal_to_everything(self):
+        # This would fail if the implementation of `_memoize` used `==`
+        # rather than `is` to check for the sentinel value.
+ class EqualToEverything(object): + def __eq__(self, other): + return True + + count_box = [0] + + @lazy.lazy_load("foo") + def foo(): + count_box[0] += 1 + return EqualToEverything() + + dir(foo) + dir(foo) + self.assertEqual(count_box[0], 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tensorboard/lib_test.py b/tensorboard/lib_test.py index aab06b023e..b7d890fc16 100644 --- a/tensorboard/lib_test.py +++ b/tensorboard/lib_test.py @@ -22,22 +22,22 @@ class ReloadTensorBoardTest(unittest.TestCase): + def test_functional_after_reload(self): + self.assertNotIn("tensorboard", sys.modules) + import tensorboard as tensorboard # it makes the Google sync happy - def test_functional_after_reload(self): - self.assertNotIn("tensorboard", sys.modules) - import tensorboard as tensorboard # it makes the Google sync happy - submodules = ["notebook", "program", "summary"] - dirs_before = { - module_name: dir(getattr(tensorboard, module_name)) - for module_name in submodules - } - tensorboard = moves.reload_module(tensorboard) - dirs_after = { - module_name: dir(getattr(tensorboard, module_name)) - for module_name in submodules - } - self.assertEqual(dirs_before, dirs_after) + submodules = ["notebook", "program", "summary"] + dirs_before = { + module_name: dir(getattr(tensorboard, module_name)) + for module_name in submodules + } + tensorboard = moves.reload_module(tensorboard) + dirs_after = { + module_name: dir(getattr(tensorboard, module_name)) + for module_name in submodules + } + self.assertEqual(dirs_before, dirs_after) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/tensorboard/main.py b/tensorboard/main.py index 2495f335dc..0976831220 100644 --- a/tensorboard/main.py +++ b/tensorboard/main.py @@ -32,7 +32,7 @@ # pattern of reads used by TensorBoard for logdirs. See for details: # https://github.com/tensorflow/tensorboard/issues/1225 # This must be set before the first import of tensorflow. 
-os.environ['GCS_READ_CACHE_DISABLED'] = '1' +os.environ["GCS_READ_CACHE_DISABLED"] = "1" import sys @@ -47,33 +47,39 @@ logger = tb_logging.get_logger() + def run_main(): - """Initializes flags and calls main().""" - program.setup_environment() - - if getattr(tf, '__version__', 'stub') == 'stub': - print("TensorFlow installation not found - running with reduced feature set.", - file=sys.stderr) - - tensorboard = program.TensorBoard( - default.get_plugins() + default.get_dynamic_plugins(), - program.get_default_assets_zip_provider(), - subcommands=[uploader_main.UploaderSubcommand()]) - try: - from absl import app - # Import this to check that app.run() will accept the flags_parser argument. - from absl.flags import argparse_flags - app.run(tensorboard.main, flags_parser=tensorboard.configure) - raise AssertionError("absl.app.run() shouldn't return") - except ImportError: - pass - except base_plugin.FlagsError as e: - print("Error: %s" % e, file=sys.stderr) - sys.exit(1) - - tensorboard.configure(sys.argv) - sys.exit(tensorboard.main()) - - -if __name__ == '__main__': - run_main() + """Initializes flags and calls main().""" + program.setup_environment() + + if getattr(tf, "__version__", "stub") == "stub": + print( + "TensorFlow installation not found - running with reduced feature set.", + file=sys.stderr, + ) + + tensorboard = program.TensorBoard( + default.get_plugins() + default.get_dynamic_plugins(), + program.get_default_assets_zip_provider(), + subcommands=[uploader_main.UploaderSubcommand()], + ) + try: + from absl import app + + # Import this to check that app.run() will accept the flags_parser argument. 
+ from absl.flags import argparse_flags + + app.run(tensorboard.main, flags_parser=tensorboard.configure) + raise AssertionError("absl.app.run() shouldn't return") + except ImportError: + pass + except base_plugin.FlagsError as e: + print("Error: %s" % e, file=sys.stderr) + sys.exit(1) + + tensorboard.configure(sys.argv) + sys.exit(tensorboard.main()) + + +if __name__ == "__main__": + run_main() diff --git a/tensorboard/manager.py b/tensorboard/manager.py index f6495bdc00..589c227fb7 100644 --- a/tensorboard/manager.py +++ b/tensorboard/manager.py @@ -41,12 +41,7 @@ # https://github.com/tensorflow/tensorboard/issues/2017. _FieldType = collections.namedtuple( "_FieldType", - ( - "serialized_type", - "runtime_type", - "serialize", - "deserialize", - ), + ("serialized_type", "runtime_type", "serialize", "deserialize",), ) _type_int = _FieldType( serialized_type=int, @@ -62,270 +57,268 @@ ) # Information about a running TensorBoard instance. -_TENSORBOARD_INFO_FIELDS = collections.OrderedDict(( - ("version", _type_str), - ("start_time", _type_int), # seconds since epoch - ("pid", _type_int), - ("port", _type_int), - ("path_prefix", _type_str), # may be empty - ("logdir", _type_str), # may be empty - ("db", _type_str), # may be empty - ("cache_key", _type_str), # opaque, as given by `cache_key` below -)) +_TENSORBOARD_INFO_FIELDS = collections.OrderedDict( + ( + ("version", _type_str), + ("start_time", _type_int), # seconds since epoch + ("pid", _type_int), + ("port", _type_int), + ("path_prefix", _type_str), # may be empty + ("logdir", _type_str), # may be empty + ("db", _type_str), # may be empty + ("cache_key", _type_str), # opaque, as given by `cache_key` below + ) +) TensorBoardInfo = collections.namedtuple( - "TensorBoardInfo", - _TENSORBOARD_INFO_FIELDS, + "TensorBoardInfo", _TENSORBOARD_INFO_FIELDS, ) def data_source_from_info(info): - """Format the data location for the given TensorBoardInfo. + """Format the data location for the given TensorBoardInfo. 
- Args: - info: A TensorBoardInfo value. + Args: + info: A TensorBoardInfo value. - Returns: - A human-readable string describing the logdir or database connection - used by the server: e.g., "logdir /tmp/logs". - """ - if info.db: - return "db %s" % info.db - else: - return "logdir %s" % info.logdir + Returns: + A human-readable string describing the logdir or database connection + used by the server: e.g., "logdir /tmp/logs". + """ + if info.db: + return "db %s" % info.db + else: + return "logdir %s" % info.logdir def _info_to_string(info): - """Convert a `TensorBoardInfo` to string form to be stored on disk. - - The format returned by this function is opaque and should only be - interpreted by `_info_from_string`. - - Args: - info: A valid `TensorBoardInfo` object. - - Raises: - ValueError: If any field on `info` is not of the correct type. - - Returns: - A string representation of the provided `TensorBoardInfo`. - """ - for key in _TENSORBOARD_INFO_FIELDS: - field_type = _TENSORBOARD_INFO_FIELDS[key] - if not isinstance(getattr(info, key), field_type.runtime_type): - raise ValueError( - "expected %r of type %s, but found: %r" % - (key, field_type.runtime_type, getattr(info, key)) - ) - if info.version != version.VERSION: - raise ValueError( - "expected 'version' to be %r, but found: %r" % - (version.VERSION, info.version) - ) - json_value = { - k: _TENSORBOARD_INFO_FIELDS[k].serialize(getattr(info, k)) - for k in _TENSORBOARD_INFO_FIELDS - } - return json.dumps(json_value, sort_keys=True, indent=4) + """Convert a `TensorBoardInfo` to string form to be stored on disk. + + The format returned by this function is opaque and should only be + interpreted by `_info_from_string`. + + Args: + info: A valid `TensorBoardInfo` object. + + Raises: + ValueError: If any field on `info` is not of the correct type. + + Returns: + A string representation of the provided `TensorBoardInfo`. 
+ """ + for key in _TENSORBOARD_INFO_FIELDS: + field_type = _TENSORBOARD_INFO_FIELDS[key] + if not isinstance(getattr(info, key), field_type.runtime_type): + raise ValueError( + "expected %r of type %s, but found: %r" + % (key, field_type.runtime_type, getattr(info, key)) + ) + if info.version != version.VERSION: + raise ValueError( + "expected 'version' to be %r, but found: %r" + % (version.VERSION, info.version) + ) + json_value = { + k: _TENSORBOARD_INFO_FIELDS[k].serialize(getattr(info, k)) + for k in _TENSORBOARD_INFO_FIELDS + } + return json.dumps(json_value, sort_keys=True, indent=4) def _info_from_string(info_string): - """Parse a `TensorBoardInfo` object from its string representation. - - Args: - info_string: A string representation of a `TensorBoardInfo`, as - produced by a previous call to `_info_to_string`. - - Returns: - A `TensorBoardInfo` value. - - Raises: - ValueError: If the provided string is not valid JSON, or if it is - missing any required fields, or if any field is of incorrect type. - """ - - try: - json_value = json.loads(info_string) - except ValueError: - raise ValueError("invalid JSON: %r" % (info_string,)) - if not isinstance(json_value, dict): - raise ValueError("not a JSON object: %r" % (json_value,)) - expected_keys = frozenset(_TENSORBOARD_INFO_FIELDS) - actual_keys = frozenset(json_value) - missing_keys = expected_keys - actual_keys - if missing_keys: - raise ValueError( - "TensorBoardInfo missing keys: %r" - % (sorted(missing_keys),) - ) - # For forward compatibility, silently ignore unknown keys. + """Parse a `TensorBoardInfo` object from its string representation. - # Validate and deserialize fields. 
- fields = {} - for key in _TENSORBOARD_INFO_FIELDS: - field_type = _TENSORBOARD_INFO_FIELDS[key] - if not isinstance(json_value[key], field_type.serialized_type): - raise ValueError( - "expected %r of type %s, but found: %r" % - (key, field_type.serialized_type, json_value[key]) - ) - fields[key] = field_type.deserialize(json_value[key]) + Args: + info_string: A string representation of a `TensorBoardInfo`, as + produced by a previous call to `_info_to_string`. - return TensorBoardInfo(**fields) + Returns: + A `TensorBoardInfo` value. + + Raises: + ValueError: If the provided string is not valid JSON, or if it is + missing any required fields, or if any field is of incorrect type. + """ + + try: + json_value = json.loads(info_string) + except ValueError: + raise ValueError("invalid JSON: %r" % (info_string,)) + if not isinstance(json_value, dict): + raise ValueError("not a JSON object: %r" % (json_value,)) + expected_keys = frozenset(_TENSORBOARD_INFO_FIELDS) + actual_keys = frozenset(json_value) + missing_keys = expected_keys - actual_keys + if missing_keys: + raise ValueError( + "TensorBoardInfo missing keys: %r" % (sorted(missing_keys),) + ) + # For forward compatibility, silently ignore unknown keys. + + # Validate and deserialize fields. + fields = {} + for key in _TENSORBOARD_INFO_FIELDS: + field_type = _TENSORBOARD_INFO_FIELDS[key] + if not isinstance(json_value[key], field_type.serialized_type): + raise ValueError( + "expected %r of type %s, but found: %r" + % (key, field_type.serialized_type, json_value[key]) + ) + fields[key] = field_type.deserialize(json_value[key]) + + return TensorBoardInfo(**fields) def cache_key(working_directory, arguments, configure_kwargs): - """Compute a `TensorBoardInfo.cache_key` field. - - The format returned by this function is opaque. Clients may only - inspect it by comparing it for equality with other results from this - function. 
- - Args: - working_directory: The directory from which TensorBoard was launched - and relative to which paths like `--logdir` and `--db` are - resolved. - arguments: The command-line args to TensorBoard, as `sys.argv[1:]`. - Should be a list (or tuple), not an unparsed string. If you have a - raw shell command, use `shlex.split` before passing it to this - function. - configure_kwargs: A dictionary of additional argument values to - override the textual `arguments`, with the same semantics as in - `tensorboard.program.TensorBoard.configure`. May be an empty - dictionary. - - Returns: - A string such that if two (prospective or actual) TensorBoard - invocations have the same cache key then it is safe to use one in - place of the other. The converse is not guaranteed: it is often safe - to change the order of TensorBoard arguments, or to explicitly set - them to their default values, or to move them between `arguments` - and `configure_kwargs`, but such invocations may yield distinct - cache keys. - """ - if not isinstance(arguments, (list, tuple)): - raise TypeError( - "'arguments' should be a list of arguments, but found: %r " - "(use `shlex.split` if given a string)" - % (arguments,) + """Compute a `TensorBoardInfo.cache_key` field. + + The format returned by this function is opaque. Clients may only + inspect it by comparing it for equality with other results from this + function. + + Args: + working_directory: The directory from which TensorBoard was launched + and relative to which paths like `--logdir` and `--db` are + resolved. + arguments: The command-line args to TensorBoard, as `sys.argv[1:]`. + Should be a list (or tuple), not an unparsed string. If you have a + raw shell command, use `shlex.split` before passing it to this + function. + configure_kwargs: A dictionary of additional argument values to + override the textual `arguments`, with the same semantics as in + `tensorboard.program.TensorBoard.configure`. May be an empty + dictionary. 
+ + Returns: + A string such that if two (prospective or actual) TensorBoard + invocations have the same cache key then it is safe to use one in + place of the other. The converse is not guaranteed: it is often safe + to change the order of TensorBoard arguments, or to explicitly set + them to their default values, or to move them between `arguments` + and `configure_kwargs`, but such invocations may yield distinct + cache keys. + """ + if not isinstance(arguments, (list, tuple)): + raise TypeError( + "'arguments' should be a list of arguments, but found: %r " + "(use `shlex.split` if given a string)" % (arguments,) + ) + datum = { + "working_directory": working_directory, + "arguments": arguments, + "configure_kwargs": configure_kwargs, + } + raw = base64.b64encode( + json.dumps(datum, sort_keys=True, separators=(",", ":")).encode("utf-8") ) - datum = { - "working_directory": working_directory, - "arguments": arguments, - "configure_kwargs": configure_kwargs, - } - raw = base64.b64encode( - json.dumps(datum, sort_keys=True, separators=(",", ":")).encode("utf-8") - ) - # `raw` is of type `bytes`, even though it only contains ASCII - # characters; we want it to be `str` in both Python 2 and 3. - return str(raw.decode("ascii")) + # `raw` is of type `bytes`, even though it only contains ASCII + # characters; we want it to be `str` in both Python 2 and 3. + return str(raw.decode("ascii")) def _get_info_dir(): - """Get path to directory in which to store info files. - - The directory returned by this function is "owned" by this module. If - the contents of the directory are modified other than via the public - functions of this module, subsequent behavior is undefined. - - The directory will be created if it does not exist. - """ - path = os.path.join(tempfile.gettempdir(), ".tensorboard-info") - try: - os.makedirs(path) - except OSError as e: - if e.errno == errno.EEXIST and os.path.isdir(path): - pass + """Get path to directory in which to store info files. 
+ + The directory returned by this function is "owned" by this module. If + the contents of the directory are modified other than via the public + functions of this module, subsequent behavior is undefined. + + The directory will be created if it does not exist. + """ + path = os.path.join(tempfile.gettempdir(), ".tensorboard-info") + try: + os.makedirs(path) + except OSError as e: + if e.errno == errno.EEXIST and os.path.isdir(path): + pass + else: + raise else: - raise - else: - os.chmod(path, 0o777) - return path + os.chmod(path, 0o777) + return path def _get_info_file_path(): - """Get path to info file for the current process. + """Get path to info file for the current process. - As with `_get_info_dir`, the info directory will be created if it does - not exist. - """ - return os.path.join(_get_info_dir(), "pid-%d.info" % os.getpid()) + As with `_get_info_dir`, the info directory will be created if it + does not exist. + """ + return os.path.join(_get_info_dir(), "pid-%d.info" % os.getpid()) def write_info_file(tensorboard_info): - """Write TensorBoardInfo to the current process's info file. + """Write TensorBoardInfo to the current process's info file. - This should be called by `main` once the server is ready. When the - server shuts down, `remove_info_file` should be called. + This should be called by `main` once the server is ready. When the + server shuts down, `remove_info_file` should be called. - Args: - tensorboard_info: A valid `TensorBoardInfo` object. + Args: + tensorboard_info: A valid `TensorBoardInfo` object. - Raises: - ValueError: If any field on `info` is not of the correct type. - """ - payload = "%s\n" % _info_to_string(tensorboard_info) - with open(_get_info_file_path(), "w") as outfile: - outfile.write(payload) + Raises: + ValueError: If any field on `info` is not of the correct type. 
+ """ + payload = "%s\n" % _info_to_string(tensorboard_info) + with open(_get_info_file_path(), "w") as outfile: + outfile.write(payload) def remove_info_file(): - """Remove the current process's TensorBoardInfo file, if it exists. - - If the file does not exist, no action is taken and no error is raised. - """ - try: - os.unlink(_get_info_file_path()) - except OSError as e: - if e.errno == errno.ENOENT: - # The user may have wiped their temporary directory or something. - # Not a problem: we're already in the state that we want to be in. - pass - else: - raise + """Remove the current process's TensorBoardInfo file, if it exists. + + If the file does not exist, no action is taken and no error is + raised. + """ + try: + os.unlink(_get_info_file_path()) + except OSError as e: + if e.errno == errno.ENOENT: + # The user may have wiped their temporary directory or something. + # Not a problem: we're already in the state that we want to be in. + pass + else: + raise def get_all(): - """Return TensorBoardInfo values for running TensorBoard processes. - - This function may not provide a perfect snapshot of the set of running - processes. Its result set may be incomplete if the user has cleaned - their /tmp/ directory while TensorBoard processes are running. It may - contain extraneous entries if TensorBoard processes exited uncleanly - (e.g., with SIGKILL or SIGQUIT). - - Entries in the info directory that do not represent valid - `TensorBoardInfo` values will be silently ignored. - - Returns: - A fresh list of `TensorBoardInfo` objects. - """ - info_dir = _get_info_dir() - results = [] - for filename in os.listdir(info_dir): - filepath = os.path.join(info_dir, filename) - try: - with open(filepath) as infile: - contents = infile.read() - except IOError as e: - if e.errno == errno.EACCES: - # May have been written by this module in a process whose - # `umask` includes some bits of 0o444. 
- continue - else: - raise - try: - info = _info_from_string(contents) - except ValueError: - # Ignore unrecognized files, logging at debug only. - tb_logging.get_logger().debug( - "invalid info file: %r", - filepath, - exc_info=True, - ) - else: - results.append(info) - return results + """Return TensorBoardInfo values for running TensorBoard processes. + + This function may not provide a perfect snapshot of the set of running + processes. Its result set may be incomplete if the user has cleaned + their /tmp/ directory while TensorBoard processes are running. It may + contain extraneous entries if TensorBoard processes exited uncleanly + (e.g., with SIGKILL or SIGQUIT). + + Entries in the info directory that do not represent valid + `TensorBoardInfo` values will be silently ignored. + + Returns: + A fresh list of `TensorBoardInfo` objects. + """ + info_dir = _get_info_dir() + results = [] + for filename in os.listdir(info_dir): + filepath = os.path.join(info_dir, filename) + try: + with open(filepath) as infile: + contents = infile.read() + except IOError as e: + if e.errno == errno.EACCES: + # May have been written by this module in a process whose + # `umask` includes some bits of 0o444. + continue + else: + raise + try: + info = _info_from_string(contents) + except ValueError: + # Ignore unrecognized files, logging at debug only. + tb_logging.get_logger().debug( + "invalid info file: %r", filepath, exc_info=True, + ) + else: + results.append(info) + return results # The following five types enumerate the possible return values of the @@ -375,102 +368,102 @@ def get_all(): def start(arguments, timeout=datetime.timedelta(seconds=60)): - """Start a new TensorBoard instance, or reuse a compatible one. - - If the cache key determined by the provided arguments and the current - working directory (see `cache_key`) matches the cache key of a running - TensorBoard process (see `get_all`), that process will be reused. 
- - Otherwise, a new TensorBoard process will be spawned with the provided - arguments, using the `tensorboard` binary from the system path. - - Args: - arguments: List of strings to be passed as arguments to - `tensorboard`. (If you have a raw command-line string, see - `shlex.split`.) - timeout: `datetime.timedelta` object describing how long to wait for - the subprocess to initialize a TensorBoard server and write its - `TensorBoardInfo` file. If the info file is not written within - this time period, `start` will assume that the subprocess is stuck - in a bad state, and will give up on waiting for it and return a - `StartTimedOut` result. Note that in such a case the subprocess - will not be killed. Default value is 60 seconds. - - Returns: - A `StartReused`, `StartLaunched`, `StartFailed`, or `StartTimedOut` - object. - """ - match = _find_matching_instance( - cache_key( - working_directory=os.getcwd(), - arguments=arguments, - configure_kwargs={}, - ), - ) - if match: - return StartReused(info=match) - - (stdout_fd, stdout_path) = tempfile.mkstemp(prefix=".tensorboard-stdout-") - (stderr_fd, stderr_path) = tempfile.mkstemp(prefix=".tensorboard-stderr-") - start_time_seconds = time.time() - explicit_tb = os.environ.get("TENSORBOARD_BINARY", None) - try: - p = subprocess.Popen( - ["tensorboard" if explicit_tb is None else explicit_tb] + arguments, - stdout=stdout_fd, - stderr=stderr_fd, + """Start a new TensorBoard instance, or reuse a compatible one. + + If the cache key determined by the provided arguments and the current + working directory (see `cache_key`) matches the cache key of a running + TensorBoard process (see `get_all`), that process will be reused. + + Otherwise, a new TensorBoard process will be spawned with the provided + arguments, using the `tensorboard` binary from the system path. + + Args: + arguments: List of strings to be passed as arguments to + `tensorboard`. (If you have a raw command-line string, see + `shlex.split`.) 
+ timeout: `datetime.timedelta` object describing how long to wait for + the subprocess to initialize a TensorBoard server and write its + `TensorBoardInfo` file. If the info file is not written within + this time period, `start` will assume that the subprocess is stuck + in a bad state, and will give up on waiting for it and return a + `StartTimedOut` result. Note that in such a case the subprocess + will not be killed. Default value is 60 seconds. + + Returns: + A `StartReused`, `StartLaunched`, `StartFailed`, or `StartTimedOut` + object. + """ + match = _find_matching_instance( + cache_key( + working_directory=os.getcwd(), + arguments=arguments, + configure_kwargs={}, + ), ) - except OSError as e: - return StartExecFailed(os_error=e, explicit_binary=explicit_tb) - finally: - os.close(stdout_fd) - os.close(stderr_fd) - - poll_interval_seconds = 0.5 - end_time_seconds = start_time_seconds + timeout.total_seconds() - while time.time() < end_time_seconds: - time.sleep(poll_interval_seconds) - subprocess_result = p.poll() - if subprocess_result is not None: - return StartFailed( - exit_code=subprocess_result, - stdout=_maybe_read_file(stdout_path), - stderr=_maybe_read_file(stderr_path), - ) - for info in get_all(): - if info.pid == p.pid and info.start_time >= start_time_seconds: - return StartLaunched(info=info) - else: - return StartTimedOut(pid=p.pid) + if match: + return StartReused(info=match) + + (stdout_fd, stdout_path) = tempfile.mkstemp(prefix=".tensorboard-stdout-") + (stderr_fd, stderr_path) = tempfile.mkstemp(prefix=".tensorboard-stderr-") + start_time_seconds = time.time() + explicit_tb = os.environ.get("TENSORBOARD_BINARY", None) + try: + p = subprocess.Popen( + ["tensorboard" if explicit_tb is None else explicit_tb] + arguments, + stdout=stdout_fd, + stderr=stderr_fd, + ) + except OSError as e: + return StartExecFailed(os_error=e, explicit_binary=explicit_tb) + finally: + os.close(stdout_fd) + os.close(stderr_fd) + + poll_interval_seconds = 0.5 + 
end_time_seconds = start_time_seconds + timeout.total_seconds() + while time.time() < end_time_seconds: + time.sleep(poll_interval_seconds) + subprocess_result = p.poll() + if subprocess_result is not None: + return StartFailed( + exit_code=subprocess_result, + stdout=_maybe_read_file(stdout_path), + stderr=_maybe_read_file(stderr_path), + ) + for info in get_all(): + if info.pid == p.pid and info.start_time >= start_time_seconds: + return StartLaunched(info=info) + else: + return StartTimedOut(pid=p.pid) def _find_matching_instance(cache_key): - """Find a running TensorBoard instance compatible with the cache key. + """Find a running TensorBoard instance compatible with the cache key. - Returns: - A `TensorBoardInfo` object, or `None` if none matches the cache key. - """ - infos = get_all() - candidates = [info for info in infos if info.cache_key == cache_key] - for candidate in sorted(candidates, key=lambda x: x.port): - # TODO(@wchargin): Check here that the provided port is still live. - return candidate - return None + Returns: + A `TensorBoardInfo` object, or `None` if none matches the cache key. + """ + infos = get_all() + candidates = [info for info in infos if info.cache_key == cache_key] + for candidate in sorted(candidates, key=lambda x: x.port): + # TODO(@wchargin): Check here that the provided port is still live. + return candidate + return None def _maybe_read_file(filename): - """Read the given file, if it exists. - - Args: - filename: A path to a file. - - Returns: - A string containing the file contents, or `None` if the file does - not exist. - """ - try: - with open(filename) as infile: - return infile.read() - except IOError as e: - if e.errno == errno.ENOENT: - return None + """Read the given file, if it exists. + + Args: + filename: A path to a file. + + Returns: + A string containing the file contents, or `None` if the file does + not exist. 
+ """ + try: + with open(filename) as infile: + return infile.read() + except IOError as e: + if e.errno == errno.ENOENT: + return None diff --git a/tensorboard/manager_e2e_test.py b/tensorboard/manager_e2e_test.py index 463054c482..9b8843ca4d 100644 --- a/tensorboard/manager_e2e_test.py +++ b/tensorboard/manager_e2e_test.py @@ -35,338 +35,344 @@ import tensorflow as tf try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import from tensorboard import manager class ManagerEndToEndTest(tf.test.TestCase): - - def setUp(self): - super(ManagerEndToEndTest, self).setUp() - - # Spy on subprocesses spawned so that we can kill them. - self.popens = [] - class PopenSpy(subprocess.Popen): - def __init__(p, *args, **kwargs): - super(PopenSpy, p).__init__(*args, **kwargs) - self.popens.append(p) - popen_patcher = mock.patch.object(subprocess, "Popen", PopenSpy) - popen_patcher.start() - - # Make sure that temporary files (including .tensorboard-info) are - # created under our purview. - self.tmproot = os.path.join(self.get_temp_dir(), "tmproot") - os.mkdir(self.tmproot) - self._patch_environ({"TMPDIR": self.tmproot}) - tempfile.tempdir = None # force `gettempdir` to reinitialize from env - self.assertEqual(tempfile.gettempdir(), self.tmproot) - self.info_dir = manager._get_info_dir() # ensure that directory exists - - # Add our Bazel-provided `tensorboard` to the system path. (The - # //tensorboard:tensorboard target is made available in the same - # directory as //tensorboard:manager_e2e_test.) 
- tensorboard_binary_dir = os.path.dirname(os.environ["TEST_BINARY"]) - self._patch_environ({ - "PATH": os.pathsep.join((tensorboard_binary_dir, os.environ["PATH"])), - }) - self._ensure_tensorboard_on_path(tensorboard_binary_dir) - - def tearDown(self): - failed_kills = [] - for p in self.popens: - try: - p.kill() - except Exception as e: - if isinstance(e, OSError) and e.errno == errno.ESRCH: - # ESRCH 3 No such process: e.g., it already exited. - pass - else: - # We really want to make sure to try to kill all these - # processes. Continue killing; fail the test later. - failed_kills.append(e) - for p in self.popens: - p.wait() - self.assertEqual(failed_kills, []) - - def _patch_environ(self, partial_environ): - patcher = mock.patch.dict(os.environ, partial_environ) - patcher.start() - self.addCleanup(patcher.stop) - - def _ensure_tensorboard_on_path(self, expected_binary_dir): - """Ensure that `tensorboard(1)` refers to our own binary. - - Raises: - subprocess.CalledProcessError: If there is no `tensorboard` on the - path. - AssertionError: If the `tensorboard` on the path is not under the - provided directory. - """ - # In Python 3.3+, we could use `shutil.which` to inspect the path. - # For Python 2 compatibility, we shell out to the host platform's - # standard utility. - command = "where" if os.name == "nt" else "which" - binary = subprocess.check_output([command, "tensorboard"]) - self.assertTrue( - binary.startswith(expected_binary_dir.encode("utf-8")), - "expected %r to start with %r" % (binary, expected_binary_dir), - ) - - def _stub_tensorboard(self, name, program): - """Install a stub version of TensorBoard. - - Args: - name: A short description of the stub's behavior. This will appear - in the file path, which may appear in error messages. - program: The contents of the stub: this should probably be a - string that starts with "#!/bin/sh" and then contains a POSIX - shell script. 
- """ - tempdir = tempfile.mkdtemp(prefix="tensorboard-stub-%s-" % name) - # (this directory is under our test directory; no need to clean it up) - filepath = os.path.join(tempdir, "tensorboard") - with open(filepath, "w") as outfile: - outfile.write(program) - os.chmod(filepath, 0o777) - self._patch_environ({ - "PATH": os.pathsep.join((tempdir, os.environ["PATH"])), - }) - self._ensure_tensorboard_on_path(expected_binary_dir=tempdir) - - def _assert_live(self, info, expected_logdir): - url = "http://localhost:%d%s/data/logdir" % (info.port, info.path_prefix) - with contextlib.closing(urllib.request.urlopen(url)) as infile: - data = infile.read() - self.assertEqual(json.loads(data), {"logdir": expected_logdir}) - - def test_simple_start(self): - start_result = manager.start(["--logdir=./logs", "--port=0"]) - self.assertIsInstance(start_result, manager.StartLaunched) - self._assert_live(start_result.info, expected_logdir="./logs") - - def test_reuse(self): - r1 = manager.start(["--logdir=./logs", "--port=0"]) - self.assertIsInstance(r1, manager.StartLaunched) - r2 = manager.start(["--logdir=./logs", "--port=0"]) - self.assertIsInstance(r2, manager.StartReused) - self.assertEqual(r1.info, r2.info) - infos = manager.get_all() - self.assertEqual(infos, [r1.info]) - self._assert_live(r1.info, expected_logdir="./logs") - - def test_launch_new_because_incompatible(self): - r1 = manager.start(["--logdir=./logs", "--port=0"]) - self.assertIsInstance(r1, manager.StartLaunched) - r2 = manager.start(["--logdir=./adders", "--port=0"]) - self.assertIsInstance(r2, manager.StartLaunched) - self.assertNotEqual(r1.info.port, r2.info.port) - self.assertNotEqual(r1.info.pid, r2.info.pid) - infos = manager.get_all() - self.assertItemsEqual(infos, [r1.info, r2.info]) - self._assert_live(r1.info, expected_logdir="./logs") - self._assert_live(r2.info, expected_logdir="./adders") - - def test_launch_new_because_info_file_deleted(self): - r1 = manager.start(["--logdir=./logs", "--port=0"]) 
- self.assertIsInstance(r1, manager.StartLaunched) - - # Now suppose that someone comes and wipes /tmp/... - self.assertEqual(len(manager.get_all()), 1, manager.get_all()) - shutil.rmtree(self.tmproot) - os.mkdir(self.tmproot) - self.assertEqual(len(manager.get_all()), 0, manager.get_all()) - - # ...so that starting even the same command forces a relaunch: - r2 = manager.start(["--logdir=./logs", "--port=0"]) - self.assertIsInstance(r2, manager.StartLaunched) # (picked a new port) - self.assertEqual(r1.info.cache_key, r2.info.cache_key) - infos = manager.get_all() - self.assertItemsEqual(infos, [r2.info]) - self._assert_live(r1.info, expected_logdir="./logs") - self._assert_live(r2.info, expected_logdir="./logs") - - def test_reuse_after_kill(self): - if os.name == "nt": - self.skipTest("Can't send SIGTERM or SIGINT on Windows.") - r1 = manager.start(["--logdir=./logs", "--port=0"]) - self.assertIsInstance(r1, manager.StartLaunched) - os.kill(r1.info.pid, signal.SIGTERM) - os.waitpid(r1.info.pid, 0) - r2 = manager.start(["--logdir=./logs", "--port=0"]) - self.assertIsInstance(r2, manager.StartLaunched) - self.assertEqual(r1.info.cache_key, r2.info.cache_key) - # It's not technically guaranteed by POSIX that the following holds, - # but it will unless the OS preemptively recycles PIDs or we somehow - # cycled exactly through the whole PID space. Neither Linux nor - # macOS recycles PIDs, so we should be fine. - self.assertNotEqual(r1.info.pid, r2.info.pid) - self._assert_live(r2.info, expected_logdir="./logs") - - def test_exit_failure(self): - if os.name == "nt": - # TODO(@wchargin): This could in principle work on Windows. - self.skipTest("Requires a POSIX shell for the stub script.") - self._stub_tensorboard( - name="fail-with-77", - program=textwrap.dedent( - r""" + def setUp(self): + super(ManagerEndToEndTest, self).setUp() + + # Spy on subprocesses spawned so that we can kill them. 
+ self.popens = [] + + class PopenSpy(subprocess.Popen): + def __init__(p, *args, **kwargs): + super(PopenSpy, p).__init__(*args, **kwargs) + self.popens.append(p) + + popen_patcher = mock.patch.object(subprocess, "Popen", PopenSpy) + popen_patcher.start() + + # Make sure that temporary files (including .tensorboard-info) are + # created under our purview. + self.tmproot = os.path.join(self.get_temp_dir(), "tmproot") + os.mkdir(self.tmproot) + self._patch_environ({"TMPDIR": self.tmproot}) + tempfile.tempdir = None # force `gettempdir` to reinitialize from env + self.assertEqual(tempfile.gettempdir(), self.tmproot) + self.info_dir = manager._get_info_dir() # ensure that directory exists + + # Add our Bazel-provided `tensorboard` to the system path. (The + # //tensorboard:tensorboard target is made available in the same + # directory as //tensorboard:manager_e2e_test.) + tensorboard_binary_dir = os.path.dirname(os.environ["TEST_BINARY"]) + self._patch_environ( + { + "PATH": os.pathsep.join( + (tensorboard_binary_dir, os.environ["PATH"]) + ), + } + ) + self._ensure_tensorboard_on_path(tensorboard_binary_dir) + + def tearDown(self): + failed_kills = [] + for p in self.popens: + try: + p.kill() + except Exception as e: + if isinstance(e, OSError) and e.errno == errno.ESRCH: + # ESRCH 3 No such process: e.g., it already exited. + pass + else: + # We really want to make sure to try to kill all these + # processes. Continue killing; fail the test later. + failed_kills.append(e) + for p in self.popens: + p.wait() + self.assertEqual(failed_kills, []) + + def _patch_environ(self, partial_environ): + patcher = mock.patch.dict(os.environ, partial_environ) + patcher.start() + self.addCleanup(patcher.stop) + + def _ensure_tensorboard_on_path(self, expected_binary_dir): + """Ensure that `tensorboard(1)` refers to our own binary. + + Raises: + subprocess.CalledProcessError: If there is no `tensorboard` on the + path. 
+ AssertionError: If the `tensorboard` on the path is not under the + provided directory. + """ + # In Python 3.3+, we could use `shutil.which` to inspect the path. + # For Python 2 compatibility, we shell out to the host platform's + # standard utility. + command = "where" if os.name == "nt" else "which" + binary = subprocess.check_output([command, "tensorboard"]) + self.assertTrue( + binary.startswith(expected_binary_dir.encode("utf-8")), + "expected %r to start with %r" % (binary, expected_binary_dir), + ) + + def _stub_tensorboard(self, name, program): + """Install a stub version of TensorBoard. + + Args: + name: A short description of the stub's behavior. This will appear + in the file path, which may appear in error messages. + program: The contents of the stub: this should probably be a + string that starts with "#!/bin/sh" and then contains a POSIX + shell script. + """ + tempdir = tempfile.mkdtemp(prefix="tensorboard-stub-%s-" % name) + # (this directory is under our test directory; no need to clean it up) + filepath = os.path.join(tempdir, "tensorboard") + with open(filepath, "w") as outfile: + outfile.write(program) + os.chmod(filepath, 0o777) + self._patch_environ( + {"PATH": os.pathsep.join((tempdir, os.environ["PATH"])),} + ) + self._ensure_tensorboard_on_path(expected_binary_dir=tempdir) + + def _assert_live(self, info, expected_logdir): + url = "http://localhost:%d%s/data/logdir" % ( + info.port, + info.path_prefix, + ) + with contextlib.closing(urllib.request.urlopen(url)) as infile: + data = infile.read() + self.assertEqual(json.loads(data), {"logdir": expected_logdir}) + + def test_simple_start(self): + start_result = manager.start(["--logdir=./logs", "--port=0"]) + self.assertIsInstance(start_result, manager.StartLaunched) + self._assert_live(start_result.info, expected_logdir="./logs") + + def test_reuse(self): + r1 = manager.start(["--logdir=./logs", "--port=0"]) + self.assertIsInstance(r1, manager.StartLaunched) + r2 = 
manager.start(["--logdir=./logs", "--port=0"]) + self.assertIsInstance(r2, manager.StartReused) + self.assertEqual(r1.info, r2.info) + infos = manager.get_all() + self.assertEqual(infos, [r1.info]) + self._assert_live(r1.info, expected_logdir="./logs") + + def test_launch_new_because_incompatible(self): + r1 = manager.start(["--logdir=./logs", "--port=0"]) + self.assertIsInstance(r1, manager.StartLaunched) + r2 = manager.start(["--logdir=./adders", "--port=0"]) + self.assertIsInstance(r2, manager.StartLaunched) + self.assertNotEqual(r1.info.port, r2.info.port) + self.assertNotEqual(r1.info.pid, r2.info.pid) + infos = manager.get_all() + self.assertItemsEqual(infos, [r1.info, r2.info]) + self._assert_live(r1.info, expected_logdir="./logs") + self._assert_live(r2.info, expected_logdir="./adders") + + def test_launch_new_because_info_file_deleted(self): + r1 = manager.start(["--logdir=./logs", "--port=0"]) + self.assertIsInstance(r1, manager.StartLaunched) + + # Now suppose that someone comes and wipes /tmp/... 
+ self.assertEqual(len(manager.get_all()), 1, manager.get_all()) + shutil.rmtree(self.tmproot) + os.mkdir(self.tmproot) + self.assertEqual(len(manager.get_all()), 0, manager.get_all()) + + # ...so that starting even the same command forces a relaunch: + r2 = manager.start(["--logdir=./logs", "--port=0"]) + self.assertIsInstance(r2, manager.StartLaunched) # (picked a new port) + self.assertEqual(r1.info.cache_key, r2.info.cache_key) + infos = manager.get_all() + self.assertItemsEqual(infos, [r2.info]) + self._assert_live(r1.info, expected_logdir="./logs") + self._assert_live(r2.info, expected_logdir="./logs") + + def test_reuse_after_kill(self): + if os.name == "nt": + self.skipTest("Can't send SIGTERM or SIGINT on Windows.") + r1 = manager.start(["--logdir=./logs", "--port=0"]) + self.assertIsInstance(r1, manager.StartLaunched) + os.kill(r1.info.pid, signal.SIGTERM) + os.waitpid(r1.info.pid, 0) + r2 = manager.start(["--logdir=./logs", "--port=0"]) + self.assertIsInstance(r2, manager.StartLaunched) + self.assertEqual(r1.info.cache_key, r2.info.cache_key) + # It's not technically guaranteed by POSIX that the following holds, + # but it will unless the OS preemptively recycles PIDs or we somehow + # cycled exactly through the whole PID space. Neither Linux nor + # macOS recycles PIDs, so we should be fine. + self.assertNotEqual(r1.info.pid, r2.info.pid) + self._assert_live(r2.info, expected_logdir="./logs") + + def test_exit_failure(self): + if os.name == "nt": + # TODO(@wchargin): This could in principle work on Windows. 
+ self.skipTest("Requires a POSIX shell for the stub script.") + self._stub_tensorboard( + name="fail-with-77", + program=textwrap.dedent( + r""" #!/bin/sh printf >&2 'fatal: something bad happened\n' printf 'also some stdout\n' exit 77 """.lstrip(), - ), - ) - start_result = manager.start(["--logdir=./logs", "--port=0"]) - self.assertIsInstance(start_result, manager.StartFailed) - self.assertEqual( - start_result, - manager.StartFailed( - exit_code=77, - stderr="fatal: something bad happened\n", - stdout="also some stdout\n", - ), - ) - self.assertEqual(manager.get_all(), []) - - def test_exit_success(self): - # TensorBoard exiting with success but not writing the info file is - # still a failure to launch. - if os.name == "nt": - # TODO(@wchargin): This could in principle work on Windows. - self.skipTest("Requires a POSIX shell for the stub script.") - self._stub_tensorboard( - name="fail-with-0", - program=textwrap.dedent( - r""" + ), + ) + start_result = manager.start(["--logdir=./logs", "--port=0"]) + self.assertIsInstance(start_result, manager.StartFailed) + self.assertEqual( + start_result, + manager.StartFailed( + exit_code=77, + stderr="fatal: something bad happened\n", + stdout="also some stdout\n", + ), + ) + self.assertEqual(manager.get_all(), []) + + def test_exit_success(self): + # TensorBoard exiting with success but not writing the info file is + # still a failure to launch. + if os.name == "nt": + # TODO(@wchargin): This could in principle work on Windows. 
+ self.skipTest("Requires a POSIX shell for the stub script.") + self._stub_tensorboard( + name="fail-with-0", + program=textwrap.dedent( + r""" #!/bin/sh printf >&2 'info: something good happened\n' printf 'also some standard output\n' exit 0 """.lstrip(), - ), - ) - start_result = manager.start(["--logdir=./logs", "--port=0"]) - self.assertIsInstance(start_result, manager.StartFailed) - self.assertEqual( - start_result, - manager.StartFailed( - exit_code=0, - stderr="info: something good happened\n", - stdout="also some standard output\n", - ), - ) - self.assertEqual(manager.get_all(), []) - - def test_failure_unreadable_stdio(self): - if os.name == "nt": - # TODO(@wchargin): This could in principle work on Windows. - self.skipTest("Requires a POSIX shell for the stub script.") - self._stub_tensorboard( - name="fail-and-nuke-tmp", - program=textwrap.dedent( - r""" + ), + ) + start_result = manager.start(["--logdir=./logs", "--port=0"]) + self.assertIsInstance(start_result, manager.StartFailed) + self.assertEqual( + start_result, + manager.StartFailed( + exit_code=0, + stderr="info: something good happened\n", + stdout="also some standard output\n", + ), + ) + self.assertEqual(manager.get_all(), []) + + def test_failure_unreadable_stdio(self): + if os.name == "nt": + # TODO(@wchargin): This could in principle work on Windows. + self.skipTest("Requires a POSIX shell for the stub script.") + self._stub_tensorboard( + name="fail-and-nuke-tmp", + program=textwrap.dedent( + r""" #!/bin/sh rm -r %s exit 22 - """.lstrip() % pipes.quote(self.tmproot), - ), - ) - start_result = manager.start(["--logdir=./logs", "--port=0"]) - self.assertIsInstance(start_result, manager.StartFailed) - self.assertEqual( - start_result, - manager.StartFailed( - exit_code=22, - stderr=None, - stdout=None, - ), - ) - self.assertEqual(manager.get_all(), []) - - def test_timeout(self): - if os.name == "nt": - # TODO(@wchargin): This could in principle work on Windows. 
- self.skipTest("Requires a POSIX shell for the stub script.") - tempdir = tempfile.mkdtemp() - pid_file = os.path.join(tempdir, "pidfile") - self._stub_tensorboard( - name="wait-a-minute", - program=textwrap.dedent( - r""" + """.lstrip() + % pipes.quote(self.tmproot), + ), + ) + start_result = manager.start(["--logdir=./logs", "--port=0"]) + self.assertIsInstance(start_result, manager.StartFailed) + self.assertEqual( + start_result, + manager.StartFailed(exit_code=22, stderr=None, stdout=None,), + ) + self.assertEqual(manager.get_all(), []) + + def test_timeout(self): + if os.name == "nt": + # TODO(@wchargin): This could in principle work on Windows. + self.skipTest("Requires a POSIX shell for the stub script.") + tempdir = tempfile.mkdtemp() + pid_file = os.path.join(tempdir, "pidfile") + self._stub_tensorboard( + name="wait-a-minute", + program=textwrap.dedent( + r""" #!/bin/sh printf >%s '%%s' "$$" printf >&2 'warn: I am tired\n' sleep 60 - """.lstrip() % pipes.quote(os.path.realpath(pid_file)), - ), - ) - start_result = manager.start( - ["--logdir=./logs", "--port=0"], - timeout=datetime.timedelta(seconds=1), - ) - self.assertIsInstance(start_result, manager.StartTimedOut) - with open(pid_file) as infile: - expected_pid = int(infile.read()) - self.assertEqual(start_result, manager.StartTimedOut(pid=expected_pid)) - self.assertEqual(manager.get_all(), []) - - def test_tensorboard_binary_environment_variable(self): - if os.name == "nt": - # TODO(@wchargin): This could in principle work on Windows. 
- self.skipTest("Requires a POSIX shell for the stub script.") - tempdir = tempfile.mkdtemp() - filepath = os.path.join(tempdir, "tensorbad") - program = textwrap.dedent( - r""" + """.lstrip() + % pipes.quote(os.path.realpath(pid_file)), + ), + ) + start_result = manager.start( + ["--logdir=./logs", "--port=0"], + timeout=datetime.timedelta(seconds=1), + ) + self.assertIsInstance(start_result, manager.StartTimedOut) + with open(pid_file) as infile: + expected_pid = int(infile.read()) + self.assertEqual(start_result, manager.StartTimedOut(pid=expected_pid)) + self.assertEqual(manager.get_all(), []) + + def test_tensorboard_binary_environment_variable(self): + if os.name == "nt": + # TODO(@wchargin): This could in principle work on Windows. + self.skipTest("Requires a POSIX shell for the stub script.") + tempdir = tempfile.mkdtemp() + filepath = os.path.join(tempdir, "tensorbad") + program = textwrap.dedent( + r""" #!/bin/sh printf >&2 'tensorbad: fatal: something bad happened\n' printf 'tensorbad: also some stdout\n' exit 77 """.lstrip() - ) - with open(filepath, "w") as outfile: - outfile.write(program) - os.chmod(filepath, 0o777) - self._patch_environ({"TENSORBOARD_BINARY": filepath}) - - start_result = manager.start(["--logdir=./logs", "--port=0"]) - self.assertIsInstance(start_result, manager.StartFailed) - self.assertEqual( - start_result, - manager.StartFailed( - exit_code=77, - stderr="tensorbad: fatal: something bad happened\n", - stdout="tensorbad: also some stdout\n", - ), - ) - self.assertEqual(manager.get_all(), []) - - def test_exec_failure_with_explicit_binary(self): - path = os.path.join(".", "non", "existent") - self._patch_environ({"TENSORBOARD_BINARY": path}) - - start_result = manager.start(["--logdir=./logs", "--port=0"]) - self.assertIsInstance(start_result, manager.StartExecFailed) - self.assertEqual(start_result.os_error.errno, errno.ENOENT) - self.assertEqual(start_result.explicit_binary, path) - - def 
test_exec_failure_with_no_explicit_binary(self): - if os.name == "nt": - # Can't use ENOENT without an absolute path (it's not treated as - # an exec failure). - self.skipTest("Not clear how to trigger this case on Windows.") - self._patch_environ({"PATH": "nope"}) - - start_result = manager.start(["--logdir=./logs", "--port=0"]) - self.assertIsInstance(start_result, manager.StartExecFailed) - self.assertEqual(start_result.os_error.errno, errno.ENOENT) - self.assertIs(start_result.explicit_binary, None) + ) + with open(filepath, "w") as outfile: + outfile.write(program) + os.chmod(filepath, 0o777) + self._patch_environ({"TENSORBOARD_BINARY": filepath}) + + start_result = manager.start(["--logdir=./logs", "--port=0"]) + self.assertIsInstance(start_result, manager.StartFailed) + self.assertEqual( + start_result, + manager.StartFailed( + exit_code=77, + stderr="tensorbad: fatal: something bad happened\n", + stdout="tensorbad: also some stdout\n", + ), + ) + self.assertEqual(manager.get_all(), []) + + def test_exec_failure_with_explicit_binary(self): + path = os.path.join(".", "non", "existent") + self._patch_environ({"TENSORBOARD_BINARY": path}) + + start_result = manager.start(["--logdir=./logs", "--port=0"]) + self.assertIsInstance(start_result, manager.StartExecFailed) + self.assertEqual(start_result.os_error.errno, errno.ENOENT) + self.assertEqual(start_result.explicit_binary, path) + + def test_exec_failure_with_no_explicit_binary(self): + if os.name == "nt": + # Can't use ENOENT without an absolute path (it's not treated as + # an exec failure). 
+ self.skipTest("Not clear how to trigger this case on Windows.") + self._patch_environ({"PATH": "nope"}) + + start_result = manager.start(["--logdir=./logs", "--port=0"]) + self.assertIsInstance(start_result, manager.StartExecFailed) + self.assertEqual(start_result.os_error.errno, errno.ENOENT) + self.assertIs(start_result.explicit_binary, None) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/manager_test.py b/tensorboard/manager_test.py index 805d867e76..d76cc40751 100644 --- a/tensorboard/manager_test.py +++ b/tensorboard/manager_test.py @@ -28,10 +28,10 @@ import six try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import from tensorboard import manager from tensorboard import test as tb_test @@ -40,365 +40,365 @@ def _make_info(i=0): - """Make a sample TensorBoardInfo object. - - Args: - i: Seed; vary this value to produce slightly different outputs. - - Returns: - A type-correct `TensorBoardInfo` object. - """ - return manager.TensorBoardInfo( - version=version.VERSION, - start_time=1548973541 + i, - port=6060 + i, - pid=76540 + i, - path_prefix="/foo", - logdir="~/my_data/", - db="", - cache_key="asdf", - ) + """Make a sample TensorBoardInfo object. + + Args: + i: Seed; vary this value to produce slightly different outputs. + + Returns: + A type-correct `TensorBoardInfo` object. + """ + return manager.TensorBoardInfo( + version=version.VERSION, + start_time=1548973541 + i, + port=6060 + i, + pid=76540 + i, + path_prefix="/foo", + logdir="~/my_data/", + db="", + cache_key="asdf", + ) class TensorBoardInfoTest(tb_test.TestCase): - """Unit tests for TensorBoardInfo typechecking and serialization.""" - - def test_roundtrip_serialization(self): - # This is also tested indirectly as part of `manager` integration - # tests, in `test_get_all`. 
- info = _make_info() - also_info = manager._info_from_string(manager._info_to_string(info)) - self.assertEqual(also_info, info) - - def test_serialization_rejects_bad_types(self): - bad_time = datetime.datetime.fromtimestamp(1549061116) # not an int - info = _make_info()._replace(start_time=bad_time) - with six.assertRaisesRegex( - self, - ValueError, - r"expected 'start_time' of type.*int.*, but found: datetime\."): - manager._info_to_string(info) - - def test_serialization_rejects_wrong_version(self): - info = _make_info()._replace(version="reversion") - with six.assertRaisesRegex( - self, - ValueError, - "expected 'version' to be '.*', but found: 'reversion'"): - manager._info_to_string(info) - - def test_deserialization_rejects_bad_json(self): - bad_input = "parse me if you dare" - with six.assertRaisesRegex( - self, - ValueError, - "invalid JSON:"): - manager._info_from_string(bad_input) - - def test_deserialization_rejects_non_object_json(self): - bad_input = "[1, 2]" - with six.assertRaisesRegex( - self, - ValueError, - re.escape("not a JSON object: [1, 2]")): - manager._info_from_string(bad_input) - - def test_deserialization_rejects_missing_version(self): - info = _make_info() - json_value = json.loads(manager._info_to_string(info)) - del json_value["version"] - bad_input = json.dumps(json_value) - with six.assertRaisesRegex( - self, - ValueError, - re.escape("missing keys: ['version']")): - manager._info_from_string(bad_input) - - def test_deserialization_accepts_future_version(self): - info = _make_info() - json_value = json.loads(manager._info_to_string(info)) - json_value["version"] = "99.99.99a20991232" - input_ = json.dumps(json_value) - result = manager._info_from_string(input_) - self.assertEqual(result.version, "99.99.99a20991232") - - def test_deserialization_ignores_extra_keys(self): - info = _make_info() - json_value = json.loads(manager._info_to_string(info)) - json_value["unlikely"] = "story" - bad_input = json.dumps(json_value) - result = 
manager._info_from_string(bad_input) - self.assertIsInstance(result, manager.TensorBoardInfo) - - def test_deserialization_rejects_missing_keys(self): - info = _make_info() - json_value = json.loads(manager._info_to_string(info)) - del json_value["start_time"] - bad_input = json.dumps(json_value) - with six.assertRaisesRegex( - self, - ValueError, - re.escape("missing keys: ['start_time']")): - manager._info_from_string(bad_input) - - def test_deserialization_rejects_bad_types(self): - info = _make_info() - json_value = json.loads(manager._info_to_string(info)) - json_value["start_time"] = "2001-02-03T04:05:06" - bad_input = json.dumps(json_value) - with six.assertRaisesRegex( - self, - ValueError, - "expected 'start_time' of type.*int.*, but found:.*" - "'2001-02-03T04:05:06'"): - manager._info_from_string(bad_input) - - def test_logdir_data_source_format(self): - info = _make_info()._replace(logdir="~/foo", db="") - self.assertEqual(manager.data_source_from_info(info), "logdir ~/foo") - - def test_db_data_source_format(self): - info = _make_info()._replace(logdir="", db="sqlite:~/bar") - self.assertEqual(manager.data_source_from_info(info), "db sqlite:~/bar") + """Unit tests for TensorBoardInfo typechecking and serialization.""" + + def test_roundtrip_serialization(self): + # This is also tested indirectly as part of `manager` integration + # tests, in `test_get_all`. 
+ info = _make_info() + also_info = manager._info_from_string(manager._info_to_string(info)) + self.assertEqual(also_info, info) + + def test_serialization_rejects_bad_types(self): + bad_time = datetime.datetime.fromtimestamp(1549061116) # not an int + info = _make_info()._replace(start_time=bad_time) + with six.assertRaisesRegex( + self, + ValueError, + r"expected 'start_time' of type.*int.*, but found: datetime\.", + ): + manager._info_to_string(info) + + def test_serialization_rejects_wrong_version(self): + info = _make_info()._replace(version="reversion") + with six.assertRaisesRegex( + self, + ValueError, + "expected 'version' to be '.*', but found: 'reversion'", + ): + manager._info_to_string(info) + + def test_deserialization_rejects_bad_json(self): + bad_input = "parse me if you dare" + with six.assertRaisesRegex(self, ValueError, "invalid JSON:"): + manager._info_from_string(bad_input) + + def test_deserialization_rejects_non_object_json(self): + bad_input = "[1, 2]" + with six.assertRaisesRegex( + self, ValueError, re.escape("not a JSON object: [1, 2]") + ): + manager._info_from_string(bad_input) + + def test_deserialization_rejects_missing_version(self): + info = _make_info() + json_value = json.loads(manager._info_to_string(info)) + del json_value["version"] + bad_input = json.dumps(json_value) + with six.assertRaisesRegex( + self, ValueError, re.escape("missing keys: ['version']") + ): + manager._info_from_string(bad_input) + + def test_deserialization_accepts_future_version(self): + info = _make_info() + json_value = json.loads(manager._info_to_string(info)) + json_value["version"] = "99.99.99a20991232" + input_ = json.dumps(json_value) + result = manager._info_from_string(input_) + self.assertEqual(result.version, "99.99.99a20991232") + + def test_deserialization_ignores_extra_keys(self): + info = _make_info() + json_value = json.loads(manager._info_to_string(info)) + json_value["unlikely"] = "story" + bad_input = json.dumps(json_value) + result = 
manager._info_from_string(bad_input) + self.assertIsInstance(result, manager.TensorBoardInfo) + + def test_deserialization_rejects_missing_keys(self): + info = _make_info() + json_value = json.loads(manager._info_to_string(info)) + del json_value["start_time"] + bad_input = json.dumps(json_value) + with six.assertRaisesRegex( + self, ValueError, re.escape("missing keys: ['start_time']") + ): + manager._info_from_string(bad_input) + + def test_deserialization_rejects_bad_types(self): + info = _make_info() + json_value = json.loads(manager._info_to_string(info)) + json_value["start_time"] = "2001-02-03T04:05:06" + bad_input = json.dumps(json_value) + with six.assertRaisesRegex( + self, + ValueError, + "expected 'start_time' of type.*int.*, but found:.*" + "'2001-02-03T04:05:06'", + ): + manager._info_from_string(bad_input) + + def test_logdir_data_source_format(self): + info = _make_info()._replace(logdir="~/foo", db="") + self.assertEqual(manager.data_source_from_info(info), "logdir ~/foo") + + def test_db_data_source_format(self): + info = _make_info()._replace(logdir="", db="sqlite:~/bar") + self.assertEqual(manager.data_source_from_info(info), "db sqlite:~/bar") class CacheKeyTest(tb_test.TestCase): - """Unit tests for `manager.cache_key`.""" + """Unit tests for `manager.cache_key`.""" - def test_result_is_str(self): - result = manager.cache_key( - working_directory="/home/me", - arguments=["--logdir", "something"], - configure_kwargs={}, - ) - self.assertIsInstance(result, str) - - def test_depends_on_working_directory(self): - results = [ - manager.cache_key( - working_directory=d, + def test_result_is_str(self): + result = manager.cache_key( + working_directory="/home/me", arguments=["--logdir", "something"], configure_kwargs={}, ) - for d in ("/home/me", "/home/you") - ] - self.assertEqual(len(results), len(set(results))) - - def test_depends_on_arguments(self): - results = [ - manager.cache_key( + self.assertIsInstance(result, str) + + def 
test_depends_on_working_directory(self): + results = [ + manager.cache_key( + working_directory=d, + arguments=["--logdir", "something"], + configure_kwargs={}, + ) + for d in ("/home/me", "/home/you") + ] + self.assertEqual(len(results), len(set(results))) + + def test_depends_on_arguments(self): + results = [ + manager.cache_key( + working_directory="/home/me", + arguments=arguments, + configure_kwargs={}, + ) + for arguments in ( + ["--logdir=something"], + ["--logdir", "something"], + ["--logdir", "", "something"], + ["--logdir", "", "something", ""], + ) + ] + self.assertEqual(len(results), len(set(results))) + + def test_depends_on_configure_kwargs(self): + results = [ + manager.cache_key( + working_directory="/home/me", + arguments=[], + configure_kwargs=configure_kwargs, + ) + for configure_kwargs in ( + {"logdir": "something"}, + {"logdir": "something_else"}, + {"logdir": "something", "port": "6006"}, + ) + ] + self.assertEqual(len(results), len(set(results))) + + def test_arguments_and_configure_kwargs_independent(self): + # This test documents current behavior; its existence shouldn't be + # interpreted as mandating the behavior. In fact, it would be nice + # for `arguments` and `configure_kwargs` to be semantically merged + # in the cache key computation, but we don't currently do that. 
+ results = [ + manager.cache_key( + working_directory="/home/me", + arguments=["--logdir", "something"], + configure_kwargs={}, + ), + manager.cache_key( + working_directory="/home/me", + arguments=[], + configure_kwargs={"logdir": "something"}, + ), + ] + self.assertEqual(len(results), len(set(results))) + + def test_arguments_list_vs_tuple_irrelevant(self): + with_list = manager.cache_key( working_directory="/home/me", - arguments=arguments, + arguments=["--logdir", "something"], configure_kwargs={}, ) - for arguments in ( - ["--logdir=something"], - ["--logdir", "something"], - ["--logdir", "", "something"], - ["--logdir", "", "something", ""], - ) - ] - self.assertEqual(len(results), len(set(results))) - - def test_depends_on_configure_kwargs(self): - results = [ - manager.cache_key( - working_directory="/home/me", - arguments=[], - configure_kwargs=configure_kwargs, - ) - for configure_kwargs in ( - {"logdir": "something"}, - {"logdir": "something_else"}, - {"logdir": "something", "port": "6006"}, - ) - ] - self.assertEqual(len(results), len(set(results))) - - def test_arguments_and_configure_kwargs_independent(self): - # This test documents current behavior; its existence shouldn't be - # interpreted as mandating the behavior. In fact, it would be nice - # for `arguments` and `configure_kwargs` to be semantically merged - # in the cache key computation, but we don't currently do that. 
- results = [ - manager.cache_key( + with_tuple = manager.cache_key( working_directory="/home/me", - arguments=["--logdir", "something"], + arguments=("--logdir", "something"), configure_kwargs={}, - ), - manager.cache_key( - working_directory="/home/me", - arguments=[], - configure_kwargs={"logdir": "something"}, - ), - ] - self.assertEqual(len(results), len(set(results))) - - def test_arguments_list_vs_tuple_irrelevant(self): - with_list = manager.cache_key( - working_directory="/home/me", - arguments=["--logdir", "something"], - configure_kwargs={}, - ) - with_tuple = manager.cache_key( - working_directory="/home/me", - arguments=("--logdir", "something"), - configure_kwargs={}, - ) - self.assertEqual(with_list, with_tuple) + ) + self.assertEqual(with_list, with_tuple) class TensorBoardInfoIoTest(tb_test.TestCase): - """Tests for `write_info_file`, `remove_info_file`, and `get_all`.""" - - def setUp(self): - super(TensorBoardInfoIoTest, self).setUp() - patcher = mock.patch.dict(os.environ, {"TMPDIR": self.get_temp_dir()}) - patcher.start() - self.addCleanup(patcher.stop) - tempfile.tempdir = None # force `gettempdir` to reinitialize from env - self.info_dir = manager._get_info_dir() # ensure that directory exists - - def _list_info_dir(self): - return os.listdir(self.info_dir) - - def assertMode(self, path, expected): - """Assert that the permission bits of a file are as expected. - - Args: - path: File to stat. - expected: `int`; a subset of 0o777. - - Raises: - AssertionError: If the permissions bits of `path` do not match - `expected`. 
- """ - stat_result = os.stat(path) - format_mode = lambda m: "0o%03o" % m - self.assertEqual( - format_mode(stat_result.st_mode & 0o777), - format_mode(expected), - ) - - def test_fails_if_info_dir_name_is_taken_by_a_regular_file(self): - os.rmdir(self.info_dir) - with open(self.info_dir, "w") as outfile: - pass - with self.assertRaises(OSError) as cm: - manager._get_info_dir() - self.assertEqual(cm.exception.errno, errno.EEXIST, cm.exception) - - @mock.patch("os.getpid", lambda: 76540) - def test_directory_world_accessible(self): - """Test that the TensorBoardInfo directory is world-accessible. + """Tests for `write_info_file`, `remove_info_file`, and `get_all`.""" + + def setUp(self): + super(TensorBoardInfoIoTest, self).setUp() + patcher = mock.patch.dict(os.environ, {"TMPDIR": self.get_temp_dir()}) + patcher.start() + self.addCleanup(patcher.stop) + tempfile.tempdir = None # force `gettempdir` to reinitialize from env + self.info_dir = manager._get_info_dir() # ensure that directory exists + + def _list_info_dir(self): + return os.listdir(self.info_dir) + + def assertMode(self, path, expected): + """Assert that the permission bits of a file are as expected. + + Args: + path: File to stat. + expected: `int`; a subset of 0o777. + + Raises: + AssertionError: If the permissions bits of `path` do not match + `expected`. + """ + stat_result = os.stat(path) + format_mode = lambda m: "0o%03o" % m + self.assertEqual( + format_mode(stat_result.st_mode & 0o777), format_mode(expected), + ) - Regression test for issue #2010: - - """ - if os.name == "nt": - self.skipTest("Windows does not use POSIX-style permissions.") - os.rmdir(self.info_dir) - # The default umask is typically 0o022, in which case this test is - # nontrivial. In the unlikely case that the umask is 0o000, we'll - # still be covered by the "restrictive umask" test case below. 
- manager.write_info_file(_make_info()) - self.assertMode(self.info_dir, 0o777) - self.assertEqual(self._list_info_dir(), ["pid-76540.info"]) - - @mock.patch("os.getpid", lambda: 76540) - def test_writing_file_with_restrictive_umask(self): - if os.name == "nt": - self.skipTest("Windows does not use POSIX-style permissions.") - os.rmdir(self.info_dir) - # Even if umask prevents owner-access, our I/O should still work. - old_umask = os.umask(0o777) - try: - # Sanity-check that, without special accommodation, this would - # create inaccessible directories... - sanity_dir = os.path.join(self.get_temp_dir(), "canary") - os.mkdir(sanity_dir) - self.assertMode(sanity_dir, 0o000) - - manager.write_info_file(_make_info()) - self.assertMode(self.info_dir, 0o777) - self.assertEqual(self._list_info_dir(), ["pid-76540.info"]) - finally: - self.assertEqual(oct(os.umask(old_umask)), oct(0o777)) - - @mock.patch("os.getpid", lambda: 76540) - def test_write_remove_info_file(self): - info = _make_info() - self.assertEqual(self._list_info_dir(), []) - manager.write_info_file(info) - filename = "pid-76540.info" - expected_filepath = os.path.join(self.info_dir, filename) - self.assertEqual(self._list_info_dir(), [filename]) - with open(expected_filepath) as infile: - self.assertEqual(manager._info_from_string(infile.read()), info) - manager.remove_info_file() - self.assertEqual(self._list_info_dir(), []) - - def test_write_info_file_rejects_bad_types(self): - # The particulars of validation are tested more thoroughly in - # `TensorBoardInfoTest` above. 
- bad_time = datetime.datetime.fromtimestamp(1549061116) - info = _make_info()._replace(start_time=bad_time) - with six.assertRaisesRegex( - self, - ValueError, - r"expected 'start_time' of type.*int.*, but found: datetime\."): - manager.write_info_file(info) - self.assertEqual(self._list_info_dir(), []) - - def test_write_info_file_rejects_wrong_version(self): - # The particulars of validation are tested more thoroughly in - # `TensorBoardInfoTest` above. - info = _make_info()._replace(version="reversion") - with six.assertRaisesRegex( - self, - ValueError, - "expected 'version' to be '.*', but found: 'reversion'"): - manager.write_info_file(info) - self.assertEqual(self._list_info_dir(), []) - - def test_remove_nonexistent(self): - # Should be a no-op, except to create the info directory if - # necessary. In particular, should not raise any exception. - manager.remove_info_file() - - def test_get_all(self): - def add_info(i): - with mock.patch("os.getpid", lambda: 76540 + i): - manager.write_info_file(_make_info(i)) - def remove_info(i): - with mock.patch("os.getpid", lambda: 76540 + i): + def test_fails_if_info_dir_name_is_taken_by_a_regular_file(self): + os.rmdir(self.info_dir) + with open(self.info_dir, "w") as outfile: + pass + with self.assertRaises(OSError) as cm: + manager._get_info_dir() + self.assertEqual(cm.exception.errno, errno.EEXIST, cm.exception) + + @mock.patch("os.getpid", lambda: 76540) + def test_directory_world_accessible(self): + """Test that the TensorBoardInfo directory is world-accessible. + + Regression test for issue #2010: + + """ + if os.name == "nt": + self.skipTest("Windows does not use POSIX-style permissions.") + os.rmdir(self.info_dir) + # The default umask is typically 0o022, in which case this test is + # nontrivial. In the unlikely case that the umask is 0o000, we'll + # still be covered by the "restrictive umask" test case below. 
+ manager.write_info_file(_make_info()) + self.assertMode(self.info_dir, 0o777) + self.assertEqual(self._list_info_dir(), ["pid-76540.info"]) + + @mock.patch("os.getpid", lambda: 76540) + def test_writing_file_with_restrictive_umask(self): + if os.name == "nt": + self.skipTest("Windows does not use POSIX-style permissions.") + os.rmdir(self.info_dir) + # Even if umask prevents owner-access, our I/O should still work. + old_umask = os.umask(0o777) + try: + # Sanity-check that, without special accommodation, this would + # create inaccessible directories... + sanity_dir = os.path.join(self.get_temp_dir(), "canary") + os.mkdir(sanity_dir) + self.assertMode(sanity_dir, 0o000) + + manager.write_info_file(_make_info()) + self.assertMode(self.info_dir, 0o777) + self.assertEqual(self._list_info_dir(), ["pid-76540.info"]) + finally: + self.assertEqual(oct(os.umask(old_umask)), oct(0o777)) + + @mock.patch("os.getpid", lambda: 76540) + def test_write_remove_info_file(self): + info = _make_info() + self.assertEqual(self._list_info_dir(), []) + manager.write_info_file(info) + filename = "pid-76540.info" + expected_filepath = os.path.join(self.info_dir, filename) + self.assertEqual(self._list_info_dir(), [filename]) + with open(expected_filepath) as infile: + self.assertEqual(manager._info_from_string(infile.read()), info) manager.remove_info_file() - self.assertItemsEqual(manager.get_all(), []) - add_info(1) - self.assertItemsEqual(manager.get_all(), [_make_info(1)]) - add_info(2) - self.assertItemsEqual(manager.get_all(), [_make_info(1), _make_info(2)]) - remove_info(1) - self.assertItemsEqual(manager.get_all(), [_make_info(2)]) - add_info(3) - self.assertItemsEqual(manager.get_all(), [_make_info(2), _make_info(3)]) - remove_info(3) - self.assertItemsEqual(manager.get_all(), [_make_info(2)]) - remove_info(2) - self.assertItemsEqual(manager.get_all(), []) - - def test_get_all_ignores_bad_files(self): - with open(os.path.join(self.info_dir, "pid-1234.info"), "w") as outfile: - 
outfile.write("good luck parsing this\n") - with open(os.path.join(self.info_dir, "pid-5678.info"), "w") as outfile: - outfile.write('{"valid_json":"yes","valid_tbinfo":"no"}\n') - with open(os.path.join(self.info_dir, "pid-9012.info"), "w") as outfile: - outfile.write('if a tbinfo has st_mode==0, does it make a sound?\n') - os.chmod(os.path.join(self.info_dir, "pid-9012.info"), 0o000) - with mock.patch.object(tb_logging.get_logger(), "debug") as fn: - self.assertEqual(manager.get_all(), []) - self.assertEqual(fn.call_count, 2) # 2 invalid, 1 unreadable (silent) + self.assertEqual(self._list_info_dir(), []) + + def test_write_info_file_rejects_bad_types(self): + # The particulars of validation are tested more thoroughly in + # `TensorBoardInfoTest` above. + bad_time = datetime.datetime.fromtimestamp(1549061116) + info = _make_info()._replace(start_time=bad_time) + with six.assertRaisesRegex( + self, + ValueError, + r"expected 'start_time' of type.*int.*, but found: datetime\.", + ): + manager.write_info_file(info) + self.assertEqual(self._list_info_dir(), []) + + def test_write_info_file_rejects_wrong_version(self): + # The particulars of validation are tested more thoroughly in + # `TensorBoardInfoTest` above. + info = _make_info()._replace(version="reversion") + with six.assertRaisesRegex( + self, + ValueError, + "expected 'version' to be '.*', but found: 'reversion'", + ): + manager.write_info_file(info) + self.assertEqual(self._list_info_dir(), []) + + def test_remove_nonexistent(self): + # Should be a no-op, except to create the info directory if + # necessary. In particular, should not raise any exception. 
+ manager.remove_info_file() + + def test_get_all(self): + def add_info(i): + with mock.patch("os.getpid", lambda: 76540 + i): + manager.write_info_file(_make_info(i)) + + def remove_info(i): + with mock.patch("os.getpid", lambda: 76540 + i): + manager.remove_info_file() + + self.assertItemsEqual(manager.get_all(), []) + add_info(1) + self.assertItemsEqual(manager.get_all(), [_make_info(1)]) + add_info(2) + self.assertItemsEqual(manager.get_all(), [_make_info(1), _make_info(2)]) + remove_info(1) + self.assertItemsEqual(manager.get_all(), [_make_info(2)]) + add_info(3) + self.assertItemsEqual(manager.get_all(), [_make_info(2), _make_info(3)]) + remove_info(3) + self.assertItemsEqual(manager.get_all(), [_make_info(2)]) + remove_info(2) + self.assertItemsEqual(manager.get_all(), []) + + def test_get_all_ignores_bad_files(self): + with open(os.path.join(self.info_dir, "pid-1234.info"), "w") as outfile: + outfile.write("good luck parsing this\n") + with open(os.path.join(self.info_dir, "pid-5678.info"), "w") as outfile: + outfile.write('{"valid_json":"yes","valid_tbinfo":"no"}\n') + with open(os.path.join(self.info_dir, "pid-9012.info"), "w") as outfile: + outfile.write("if a tbinfo has st_mode==0, does it make a sound?\n") + os.chmod(os.path.join(self.info_dir, "pid-9012.info"), 0o000) + with mock.patch.object(tb_logging.get_logger(), "debug") as fn: + self.assertEqual(manager.get_all(), []) + self.assertEqual(fn.call_count, 2) # 2 invalid, 1 unreadable (silent) if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/notebook.py b/tensorboard/notebook.py index 05c0c304e2..406aad3a76 100644 --- a/tensorboard/notebook.py +++ b/tensorboard/notebook.py @@ -30,13 +30,15 @@ import time try: - import html - html_escape = html.escape - del html + import html + + html_escape = html.escape + del html except ImportError: - import cgi - html_escape = cgi.escape - del cgi + import cgi + + html_escape = cgi.escape + del cgi from tensorboard import 
manager @@ -49,304 +51,299 @@ def _get_context(): - """Determine the most specific context that we're in. - - Returns: - _CONTEXT_COLAB: If in Colab with an IPython notebook context. - _CONTEXT_IPYTHON: If not in Colab, but we are in an IPython notebook - context (e.g., from running `jupyter notebook` at the command - line). - _CONTEXT_NONE: Otherwise (e.g., by running a Python script at the - command-line or using the `ipython` interactive shell). - """ - # In Colab, the `google.colab` module is available, but the shell - # returned by `IPython.get_ipython` does not have a `get_trait` - # method. - try: - import google.colab - import IPython - except ImportError: - pass - else: - if IPython.get_ipython() is not None: - # We'll assume that we're in a Colab notebook context. - return _CONTEXT_COLAB - - # In an IPython command line shell or Jupyter notebook, we can - # directly query whether we're in a notebook context. - try: - import IPython - except ImportError: - pass - else: - ipython = IPython.get_ipython() - if ipython is not None and ipython.has_trait("kernel"): - return _CONTEXT_IPYTHON - - # Otherwise, we're not in a known notebook context. - return _CONTEXT_NONE + """Determine the most specific context that we're in. + + Returns: + _CONTEXT_COLAB: If in Colab with an IPython notebook context. + _CONTEXT_IPYTHON: If not in Colab, but we are in an IPython notebook + context (e.g., from running `jupyter notebook` at the command + line). + _CONTEXT_NONE: Otherwise (e.g., by running a Python script at the + command-line or using the `ipython` interactive shell). + """ + # In Colab, the `google.colab` module is available, but the shell + # returned by `IPython.get_ipython` does not have a `get_trait` + # method. + try: + import google.colab + import IPython + except ImportError: + pass + else: + if IPython.get_ipython() is not None: + # We'll assume that we're in a Colab notebook context. 
+ return _CONTEXT_COLAB + + # In an IPython command line shell or Jupyter notebook, we can + # directly query whether we're in a notebook context. + try: + import IPython + except ImportError: + pass + else: + ipython = IPython.get_ipython() + if ipython is not None and ipython.has_trait("kernel"): + return _CONTEXT_IPYTHON + + # Otherwise, we're not in a known notebook context. + return _CONTEXT_NONE def load_ipython_extension(ipython): - """Deprecated: use `%load_ext tensorboard` instead. + """Deprecated: use `%load_ext tensorboard` instead. Raises: RuntimeError: Always. """ - raise RuntimeError( - "Use '%load_ext tensorboard' instead of '%load_ext tensorboard.notebook'." - ) + raise RuntimeError( + "Use '%load_ext tensorboard' instead of '%load_ext tensorboard.notebook'." + ) def _load_ipython_extension(ipython): - """Load the TensorBoard notebook extension. + """Load the TensorBoard notebook extension. - Intended to be called from `%load_ext tensorboard`. Do not invoke this - directly. + Intended to be called from `%load_ext tensorboard`. Do not invoke this + directly. - Args: - ipython: An `IPython.InteractiveShell` instance. - """ - _register_magics(ipython) + Args: + ipython: An `IPython.InteractiveShell` instance. + """ + _register_magics(ipython) def _register_magics(ipython): - """Register IPython line/cell magics. + """Register IPython line/cell magics. - Args: - ipython: An `InteractiveShell` instance. - """ - ipython.register_magic_function( - _start_magic, - magic_kind="line", - magic_name="tensorboard", - ) + Args: + ipython: An `InteractiveShell` instance. + """ + ipython.register_magic_function( + _start_magic, magic_kind="line", magic_name="tensorboard", + ) def _start_magic(line): - """Implementation of the `%tensorboard` line magic.""" - return start(line) + """Implementation of the `%tensorboard` line magic.""" + return start(line) def start(args_string): - """Launch and display a TensorBoard instance as if at the command line. 
- - Args: - args_string: Command-line arguments to TensorBoard, to be - interpreted by `shlex.split`: e.g., "--logdir ./logs --port 0". - Shell metacharacters are not supported: e.g., "--logdir 2>&1" will - point the logdir at the literal directory named "2>&1". - """ - context = _get_context() - try: - import IPython - import IPython.display - except ImportError: - IPython = None - - if context == _CONTEXT_NONE: - handle = None - print("Launching TensorBoard...") - else: - handle = IPython.display.display( - IPython.display.Pretty("Launching TensorBoard..."), - display_id=True, - ) - - def print_or_update(message): - if handle is None: - print(message) + """Launch and display a TensorBoard instance as if at the command line. + + Args: + args_string: Command-line arguments to TensorBoard, to be + interpreted by `shlex.split`: e.g., "--logdir ./logs --port 0". + Shell metacharacters are not supported: e.g., "--logdir 2>&1" will + point the logdir at the literal directory named "2>&1". + """ + context = _get_context() + try: + import IPython + import IPython.display + except ImportError: + IPython = None + + if context == _CONTEXT_NONE: + handle = None + print("Launching TensorBoard...") else: - handle.update(IPython.display.Pretty(message)) + handle = IPython.display.display( + IPython.display.Pretty("Launching TensorBoard..."), display_id=True, + ) - parsed_args = shlex.split(args_string, comments=True, posix=True) - start_result = manager.start(parsed_args) + def print_or_update(message): + if handle is None: + print(message) + else: + handle.update(IPython.display.Pretty(message)) - if isinstance(start_result, manager.StartLaunched): - _display( - port=start_result.info.port, - print_message=False, - display_handle=handle, - ) + parsed_args = shlex.split(args_string, comments=True, posix=True) + start_result = manager.start(parsed_args) - elif isinstance(start_result, manager.StartReused): - template = ( - "Reusing TensorBoard on port {port} (pid {pid}), started 
{delta} ago. " - "(Use '!kill {pid}' to kill it.)" - ) - message = template.format( - port=start_result.info.port, - pid=start_result.info.pid, - delta=_time_delta_from_info(start_result.info), - ) - print_or_update(message) - _display( - port=start_result.info.port, - print_message=False, - display_handle=None, - ) + if isinstance(start_result, manager.StartLaunched): + _display( + port=start_result.info.port, + print_message=False, + display_handle=handle, + ) - elif isinstance(start_result, manager.StartFailed): - def format_stream(name, value): - if value == "": - return "" - elif value is None: - return "\n" % name - else: - return "\nContents of %s:\n%s" % (name, value.strip()) - message = ( - "ERROR: Failed to launch TensorBoard (exited with %d).%s%s" % - ( - start_result.exit_code, - format_stream("stderr", start_result.stderr), - format_stream("stdout", start_result.stdout), + elif isinstance(start_result, manager.StartReused): + template = ( + "Reusing TensorBoard on port {port} (pid {pid}), started {delta} ago. 
" + "(Use '!kill {pid}' to kill it.)" ) - ) - print_or_update(message) + message = template.format( + port=start_result.info.port, + pid=start_result.info.pid, + delta=_time_delta_from_info(start_result.info), + ) + print_or_update(message) + _display( + port=start_result.info.port, + print_message=False, + display_handle=None, + ) + + elif isinstance(start_result, manager.StartFailed): + + def format_stream(name, value): + if value == "": + return "" + elif value is None: + return "\n" % name + else: + return "\nContents of %s:\n%s" % (name, value.strip()) + + message = ( + "ERROR: Failed to launch TensorBoard (exited with %d).%s%s" + % ( + start_result.exit_code, + format_stream("stderr", start_result.stderr), + format_stream("stdout", start_result.stdout), + ) + ) + print_or_update(message) - elif isinstance(start_result, manager.StartExecFailed): - the_tensorboard_binary = ( - "%r (set by the `TENSORBOARD_BINARY` environment variable)" + elif isinstance(start_result, manager.StartExecFailed): + the_tensorboard_binary = ( + "%r (set by the `TENSORBOARD_BINARY` environment variable)" % (start_result.explicit_binary,) - if start_result.explicit_binary is not None - else "`tensorboard`" - ) - if start_result.os_error.errno == errno.ENOENT: - message = ( - "ERROR: Could not find %s. Please ensure that your PATH contains " - "an executable `tensorboard` program, or explicitly specify the path " - "to a TensorBoard binary by setting the `TENSORBOARD_BINARY` " - "environment variable." - % (the_tensorboard_binary,) - ) - else: - message = ( - "ERROR: Failed to start %s: %s" - % (the_tensorboard_binary, start_result.os_error) - ) - print_or_update(textwrap.fill(message)) - - elif isinstance(start_result, manager.StartTimedOut): - message = ( - "ERROR: Timed out waiting for TensorBoard to start. " - "It may still be running as pid %d." 
- % start_result.pid - ) - print_or_update(message) + if start_result.explicit_binary is not None + else "`tensorboard`" + ) + if start_result.os_error.errno == errno.ENOENT: + message = ( + "ERROR: Could not find %s. Please ensure that your PATH contains " + "an executable `tensorboard` program, or explicitly specify the path " + "to a TensorBoard binary by setting the `TENSORBOARD_BINARY` " + "environment variable." % (the_tensorboard_binary,) + ) + else: + message = "ERROR: Failed to start %s: %s" % ( + the_tensorboard_binary, + start_result.os_error, + ) + print_or_update(textwrap.fill(message)) + + elif isinstance(start_result, manager.StartTimedOut): + message = ( + "ERROR: Timed out waiting for TensorBoard to start. " + "It may still be running as pid %d." % start_result.pid + ) + print_or_update(message) - else: - raise TypeError( - "Unexpected result from `manager.start`: %r.\n" - "This is a TensorBoard bug; please report it." - % start_result - ) + else: + raise TypeError( + "Unexpected result from `manager.start`: %r.\n" + "This is a TensorBoard bug; please report it." % start_result + ) def _time_delta_from_info(info): - """Format the elapsed time for the given TensorBoardInfo. + """Format the elapsed time for the given TensorBoardInfo. - Args: - info: A TensorBoardInfo value. + Args: + info: A TensorBoardInfo value. - Returns: - A human-readable string describing the time since the server - described by `info` started: e.g., "2 days, 0:48:58". - """ - delta_seconds = int(time.time()) - info.start_time - return str(datetime.timedelta(seconds=delta_seconds)) + Returns: + A human-readable string describing the time since the server + described by `info` started: e.g., "2 days, 0:48:58". + """ + delta_seconds = int(time.time()) - info.start_time + return str(datetime.timedelta(seconds=delta_seconds)) def display(port=None, height=None): - """Display a TensorBoard instance already running on this machine. 
- - Args: - port: The port on which the TensorBoard server is listening, as an - `int`, or `None` to automatically select the most recently - launched TensorBoard. - height: The height of the frame into which to render the TensorBoard - UI, as an `int` number of pixels, or `None` to use a default value - (currently 800). - """ - _display(port=port, height=height, print_message=True, display_handle=None) + """Display a TensorBoard instance already running on this machine. + Args: + port: The port on which the TensorBoard server is listening, as an + `int`, or `None` to automatically select the most recently + launched TensorBoard. + height: The height of the frame into which to render the TensorBoard + UI, as an `int` number of pixels, or `None` to use a default value + (currently 800). + """ + _display(port=port, height=height, print_message=True, display_handle=None) -def _display(port=None, height=None, print_message=False, display_handle=None): - """Internal version of `display`. - - Args: - port: As with `display`. - height: As with `display`. - print_message: True to print which TensorBoard instance was selected - for display (if applicable), or False otherwise. - display_handle: If not None, an IPython display handle into which to - render TensorBoard. - """ - if height is None: - height = 800 - - if port is None: - infos = manager.get_all() - if not infos: - raise ValueError("Can't display TensorBoard: no known instances running.") - else: - info = max(manager.get_all(), key=lambda x: x.start_time) - port = info.port - else: - infos = [i for i in manager.get_all() if i.port == port] - info = ( - max(infos, key=lambda x: x.start_time) - if infos - else None - ) - if print_message: - if info is not None: - message = ( - "Selecting TensorBoard with {data_source} " - "(started {delta} ago; port {port}, pid {pid})." 
- ).format( - data_source=manager.data_source_from_info(info), - delta=_time_delta_from_info(info), - port=info.port, - pid=info.pid, - ) - print(message) +def _display(port=None, height=None, print_message=False, display_handle=None): + """Internal version of `display`. + + Args: + port: As with `display`. + height: As with `display`. + print_message: True to print which TensorBoard instance was selected + for display (if applicable), or False otherwise. + display_handle: If not None, an IPython display handle into which to + render TensorBoard. + """ + if height is None: + height = 800 + + if port is None: + infos = manager.get_all() + if not infos: + raise ValueError( + "Can't display TensorBoard: no known instances running." + ) + else: + info = max(manager.get_all(), key=lambda x: x.start_time) + port = info.port else: - # The user explicitly provided a port, and we don't have any - # additional information. There's nothing useful to say. - pass - - fn = { - _CONTEXT_COLAB: _display_colab, - _CONTEXT_IPYTHON: _display_ipython, - _CONTEXT_NONE: _display_cli, - }[_get_context()] - return fn(port=port, height=height, display_handle=display_handle) + infos = [i for i in manager.get_all() if i.port == port] + info = max(infos, key=lambda x: x.start_time) if infos else None + + if print_message: + if info is not None: + message = ( + "Selecting TensorBoard with {data_source} " + "(started {delta} ago; port {port}, pid {pid})." + ).format( + data_source=manager.data_source_from_info(info), + delta=_time_delta_from_info(info), + port=info.port, + pid=info.pid, + ) + print(message) + else: + # The user explicitly provided a port, and we don't have any + # additional information. There's nothing useful to say. 
+ pass + + fn = { + _CONTEXT_COLAB: _display_colab, + _CONTEXT_IPYTHON: _display_ipython, + _CONTEXT_NONE: _display_cli, + }[_get_context()] + return fn(port=port, height=height, display_handle=display_handle) def _display_colab(port, height, display_handle): - """Display a TensorBoard instance in a Colab output frame. - - The Colab VM is not directly exposed to the network, so the Colab - runtime provides a service worker tunnel to proxy requests from the - end user's browser through to servers running on the Colab VM: the - output frame may issue requests to https://localhost: (HTTPS - only), which will be forwarded to the specified port on the VM. - - It does not suffice to create an `iframe` and let the service worker - redirect its traffic (` """ - replacements = [ - ("%HTML_ID%", html_escape(frame_id, quote=True)), - ("%JSON_ID%", json.dumps(frame_id)), - ("%PORT%", "%d" % port), - ("%HEIGHT%", "%d" % height), - ] - for (k, v) in replacements: - shell = shell.replace(k, v) - iframe = IPython.display.HTML(shell) - if display_handle: - display_handle.update(iframe) - else: - IPython.display.display(iframe) + replacements = [ + ("%HTML_ID%", html_escape(frame_id, quote=True)), + ("%JSON_ID%", json.dumps(frame_id)), + ("%PORT%", "%d" % port), + ("%HEIGHT%", "%d" % height), + ] + for (k, v) in replacements: + shell = shell.replace(k, v) + iframe = IPython.display.HTML(shell) + if display_handle: + display_handle.update(iframe) + else: + IPython.display.display(iframe) def _display_cli(port, height, display_handle): - del height # unused - del display_handle # unused - message = "Please visit http://localhost:%d in a web browser." % port - print(message) + del height # unused + del display_handle # unused + message = "Please visit http://localhost:%d in a web browser." % port + print(message) def list(): - """Print a listing of known running TensorBoard instances. + """Print a listing of known running TensorBoard instances. 
+ + TensorBoard instances that were killed uncleanly (e.g., with SIGKILL + or SIGQUIT) may appear in this list even if they are no longer + running. Conversely, this list may be missing some entries if your + operating system's temporary directory has been cleared since a + still-running TensorBoard instance started. + """ + infos = manager.get_all() + if not infos: + print("No known TensorBoard instances running.") + return - TensorBoard instances that were killed uncleanly (e.g., with SIGKILL - or SIGQUIT) may appear in this list even if they are no longer - running. Conversely, this list may be missing some entries if your - operating system's temporary directory has been cleared since a - still-running TensorBoard instance started. - """ - infos = manager.get_all() - if not infos: - print("No known TensorBoard instances running.") - return - - print("Known TensorBoard instances:") - for info in infos: - template = " - port {port}: {data_source} (started {delta} ago; pid {pid})" - print(template.format( - port=info.port, - data_source=manager.data_source_from_info(info), - delta=_time_delta_from_info(info), - pid=info.pid, - )) + print("Known TensorBoard instances:") + for info in infos: + template = ( + " - port {port}: {data_source} (started {delta} ago; pid {pid})" + ) + print( + template.format( + port=info.port, + data_source=manager.data_source_from_info(info), + delta=_time_delta_from_info(info), + pid=info.pid, + ) + ) diff --git a/tensorboard/pip_package/deterministic_tar_gz.py b/tensorboard/pip_package/deterministic_tar_gz.py index 77746ec085..7ab5cbe959 100644 --- a/tensorboard/pip_package/deterministic_tar_gz.py +++ b/tensorboard/pip_package/deterministic_tar_gz.py @@ -44,56 +44,55 @@ def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "archive", - metavar="ARCHIVE", - help="name for the output `.tar.gz` archive", - ) - parser.add_argument( - "files", - metavar="files", - nargs="*", - help="files to include in the archive; 
basenames must be distinct", - ) - args = parser.parse_args() - archive = args.archive - files = args.files - del args - - if len(frozenset(os.path.basename(f) for f in files)) != len(files): - sys.stderr.write("Input basenames must be distinct; got: %r\n" % files) - sys.exit(1) - - # (`fd` will be closed by `fdopen` context manager below) - fd = os.open(archive, os.O_WRONLY | os.O_CREAT, 0o644) - with \ - os.fdopen(fd, "wb") as out_file, \ - gzip.GzipFile("wb", fileobj=out_file, mtime=0) as gzip_file, \ - tarfile.open(fileobj=gzip_file, mode="w:") as tar_file: - for f in files: - arcname = os.path.basename(f) - tar_file.add(f, filter=cleanse, recursive=False, arcname=arcname) + parser = argparse.ArgumentParser() + parser.add_argument( + "archive", + metavar="ARCHIVE", + help="name for the output `.tar.gz` archive", + ) + parser.add_argument( + "files", + metavar="files", + nargs="*", + help="files to include in the archive; basenames must be distinct", + ) + args = parser.parse_args() + archive = args.archive + files = args.files + del args + + if len(frozenset(os.path.basename(f) for f in files)) != len(files): + sys.stderr.write("Input basenames must be distinct; got: %r\n" % files) + sys.exit(1) + + # (`fd` will be closed by `fdopen` context manager below) + fd = os.open(archive, os.O_WRONLY | os.O_CREAT, 0o644) + with os.fdopen(fd, "wb") as out_file, gzip.GzipFile( + "wb", fileobj=out_file, mtime=0 + ) as gzip_file, tarfile.open(fileobj=gzip_file, mode="w:") as tar_file: + for f in files: + arcname = os.path.basename(f) + tar_file.add(f, filter=cleanse, recursive=False, arcname=arcname) def cleanse(tarinfo): - """Cleanse sources of nondeterminism from tar entries. + """Cleanse sources of nondeterminism from tar entries. - To be passed as the `filter` kwarg to `tarfile.TarFile.add`. + To be passed as the `filter` kwarg to `tarfile.TarFile.add`. - Args: - tarinfo: A `tarfile.TarInfo` object to be mutated. 
+ Args: + tarinfo: A `tarfile.TarInfo` object to be mutated. - Returns: - The same `tarinfo` object, but mutated. - """ - tarinfo.uid = 0 - tarinfo.gid = 0 - tarinfo.uname = "root" - tarinfo.gname = "root" - tarinfo.mtime = 0 - return tarinfo + Returns: + The same `tarinfo` object, but mutated. + """ + tarinfo.uid = 0 + tarinfo.gid = 0 + tarinfo.uname = "root" + tarinfo.gname = "root" + tarinfo.mtime = 0 + return tarinfo if __name__ == "__main__": - main() + main() diff --git a/tensorboard/pip_package/deterministic_tar_gz_test.py b/tensorboard/pip_package/deterministic_tar_gz_test.py index aed295fe03..d62ddab823 100644 --- a/tensorboard/pip_package/deterministic_tar_gz_test.py +++ b/tensorboard/pip_package/deterministic_tar_gz_test.py @@ -27,86 +27,92 @@ class DeterministicTarGzTest(tb_test.TestCase): - - def setUp(self): - self._tool_path = os.path.join( - os.path.dirname(os.environ["TEST_BINARY"]), - "deterministic_tar_gz", - ) - - def _run_tool(self, args): - return subprocess.check_output([self._tool_path] + args) - - def _write_file(self, directory, filename, contents, utime=None): - """Write a file and set its access and modification times. - - Args: - directory: Path to parent directory for the file, as a `str`. - filename: Name of file inside directory, as a `str`. - contents: File contents, as a `str`. - utime: If not `None`, a 2-tuple of numbers (`int`s or `float`s) - representing seconds since epoch for `atime` and `mtime`, - respectively, as in the second argument to `os.utime`. Defaults - to a fixed value; the file's timestamps will always be set. - - Returns: - The new file path. 
- """ - filepath = os.path.join(directory, filename) - with open(filepath, "w") as outfile: - outfile.write(contents) - if utime is None: - utime = (123, 456) - os.utime(filepath, utime) - return filepath - - def test_correct_contents(self): - tempdir = self.get_temp_dir() - archive = os.path.join(tempdir, "out.tar.gz") - directory = os.path.join(tempdir, "src") - os.mkdir(directory) - self._run_tool([ - archive, - self._write_file(directory, "1.txt", "one"), - self._write_file(directory, "2.txt", "two"), - ]) - with gzip.open(archive) as gzip_file: - with tarfile.open(fileobj=gzip_file, mode="r:") as tar_file: - self.assertEqual(tar_file.getnames(), ["1.txt", "2.txt"]) # in order - self.assertEqual(tar_file.extractfile("1.txt").read(), b"one") - self.assertEqual(tar_file.extractfile("2.txt").read(), b"two") - - def test_invariant_under_mtime(self): - tempdir = self.get_temp_dir() - - archive_1 = os.path.join(tempdir, "out_1.tar.gz") - directory_1 = os.path.join(tempdir, "src_1") - os.mkdir(directory_1) - self._run_tool([ - archive_1, - self._write_file(directory_1, "1.txt", "one", utime=(1, 2)), - self._write_file(directory_1, "2.txt", "two", utime=(3, 4)), - ]) - - archive_2 = os.path.join(tempdir, "out_2.tar.gz") - directory_2 = os.path.join(tempdir, "src_2") - os.mkdir(directory_2) - self._run_tool([ - archive_2, - self._write_file(directory_2, "1.txt", "one", utime=(7, 8)), - self._write_file(directory_2, "2.txt", "two", utime=(5, 6)), - ]) - - with open(archive_1, "rb") as infile: - archive_1_contents = infile.read() - with open(archive_2, "rb") as infile: - archive_2_contents = infile.read() - - self.assertEqual(archive_1_contents, archive_2_contents) - - def test_invariant_under_owner_and_group_names(self): - self.skipTest("Can't really test this; no way to chown.") + def setUp(self): + self._tool_path = os.path.join( + os.path.dirname(os.environ["TEST_BINARY"]), "deterministic_tar_gz", + ) + + def _run_tool(self, args): + return 
subprocess.check_output([self._tool_path] + args) + + def _write_file(self, directory, filename, contents, utime=None): + """Write a file and set its access and modification times. + + Args: + directory: Path to parent directory for the file, as a `str`. + filename: Name of file inside directory, as a `str`. + contents: File contents, as a `str`. + utime: If not `None`, a 2-tuple of numbers (`int`s or `float`s) + representing seconds since epoch for `atime` and `mtime`, + respectively, as in the second argument to `os.utime`. Defaults + to a fixed value; the file's timestamps will always be set. + + Returns: + The new file path. + """ + filepath = os.path.join(directory, filename) + with open(filepath, "w") as outfile: + outfile.write(contents) + if utime is None: + utime = (123, 456) + os.utime(filepath, utime) + return filepath + + def test_correct_contents(self): + tempdir = self.get_temp_dir() + archive = os.path.join(tempdir, "out.tar.gz") + directory = os.path.join(tempdir, "src") + os.mkdir(directory) + self._run_tool( + [ + archive, + self._write_file(directory, "1.txt", "one"), + self._write_file(directory, "2.txt", "two"), + ] + ) + with gzip.open(archive) as gzip_file: + with tarfile.open(fileobj=gzip_file, mode="r:") as tar_file: + self.assertEqual( + tar_file.getnames(), ["1.txt", "2.txt"] + ) # in order + self.assertEqual(tar_file.extractfile("1.txt").read(), b"one") + self.assertEqual(tar_file.extractfile("2.txt").read(), b"two") + + def test_invariant_under_mtime(self): + tempdir = self.get_temp_dir() + + archive_1 = os.path.join(tempdir, "out_1.tar.gz") + directory_1 = os.path.join(tempdir, "src_1") + os.mkdir(directory_1) + self._run_tool( + [ + archive_1, + self._write_file(directory_1, "1.txt", "one", utime=(1, 2)), + self._write_file(directory_1, "2.txt", "two", utime=(3, 4)), + ] + ) + + archive_2 = os.path.join(tempdir, "out_2.tar.gz") + directory_2 = os.path.join(tempdir, "src_2") + os.mkdir(directory_2) + self._run_tool( + [ + archive_2, + 
self._write_file(directory_2, "1.txt", "one", utime=(7, 8)), + self._write_file(directory_2, "2.txt", "two", utime=(5, 6)), + ] + ) + + with open(archive_1, "rb") as infile: + archive_1_contents = infile.read() + with open(archive_2, "rb") as infile: + archive_2_contents = infile.read() + + self.assertEqual(archive_1_contents, archive_2_contents) + + def test_invariant_under_owner_and_group_names(self): + self.skipTest("Can't really test this; no way to chown.") if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/pip_package/setup.py b/tensorboard/pip_package/setup.py index 52c4c911f7..a54e48896a 100644 --- a/tensorboard/pip_package/setup.py +++ b/tensorboard/pip_package/setup.py @@ -23,83 +23,81 @@ REQUIRED_PACKAGES = [ - 'absl-py >= 0.4', + "absl-py >= 0.4", # futures is a backport of the python 3.2+ concurrent.futures module 'futures >= 3.1.1; python_version < "3"', - 'grpcio >= 1.24.3', - 'google-auth >= 1.6.3, < 2', - 'google-auth-oauthlib >= 0.4.1, < 0.5', - 'markdown >= 2.6.8', - 'numpy >= 1.12.0', - 'protobuf >= 3.6.0', - 'requests >= 2.21.0, < 3', - 'setuptools >= 41.0.0', - 'six >= 1.10.0', - 'werkzeug >= 0.11.15', + "grpcio >= 1.24.3", + "google-auth >= 1.6.3, < 2", + "google-auth-oauthlib >= 0.4.1, < 0.5", + "markdown >= 2.6.8", + "numpy >= 1.12.0", + "protobuf >= 3.6.0", + "requests >= 2.21.0, < 3", + "setuptools >= 41.0.0", + "six >= 1.10.0", + "werkzeug >= 0.11.15", # python3 specifically requires wheel 0.26 'wheel; python_version < "3"', 'wheel >= 0.26; python_version >= "3"', ] CONSOLE_SCRIPTS = [ - 'tensorboard = tensorboard.main:run_main', + "tensorboard = tensorboard.main:run_main", ] + def get_readme(): - with open('README.rst') as f: - return f.read() + with open("README.rst") as f: + return f.read() + setup( - name='tensorboard', - version=tensorboard.version.VERSION.replace('-', ''), - description='TensorBoard lets you watch Tensors Flow', + name="tensorboard", + 
version=tensorboard.version.VERSION.replace("-", ""), + description="TensorBoard lets you watch Tensors Flow", long_description=get_readme(), - url='https://github.com/tensorflow/tensorboard', - author='Google Inc.', - author_email='packages@tensorflow.org', + url="https://github.com/tensorflow/tensorboard", + author="Google Inc.", + author_email="packages@tensorflow.org", # Contained modules and scripts. packages=find_packages(), entry_points={ - 'console_scripts': CONSOLE_SCRIPTS, - 'tensorboard_plugins': [ - 'projector = tensorboard.plugins.projector.projector_plugin:ProjectorPlugin', + "console_scripts": CONSOLE_SCRIPTS, + "tensorboard_plugins": [ + "projector = tensorboard.plugins.projector.projector_plugin:ProjectorPlugin", ], }, package_data={ - 'tensorboard': [ - 'webfiles.zip', - ], - 'tensorboard.plugins.beholder': [ - 'resources/*', - ], + "tensorboard": ["webfiles.zip",], + "tensorboard.plugins.beholder": ["resources/*",], # Must keep this in sync with tf_projector_plugin:projector_assets - 'tensorboard.plugins.projector': [ - 'tf_projector_plugin/index.js', - 'tf_projector_plugin/projector_binary.html', - 'tf_projector_plugin/projector_binary.js', + "tensorboard.plugins.projector": [ + "tf_projector_plugin/index.js", + "tf_projector_plugin/projector_binary.html", + "tf_projector_plugin/projector_binary.js", ], }, # Disallow python 3.0 and 3.1 which lack a 'futures' module (see above). - python_requires='>= 2.7, != 3.0.*, != 3.1.*', + python_requires=">= 2.7, != 3.0.*, != 3.1.*", install_requires=REQUIRED_PACKAGES, tests_require=REQUIRED_PACKAGES, # PyPI package information. 
classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Software Development :: Libraries :: Python Modules', - 'Topic :: Software Development :: Libraries', + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Software Development :: Libraries", ], - license='Apache 2.0', - keywords='tensorflow tensorboard tensor machine learning visualizer', + license="Apache 2.0", + keywords="tensorflow tensorboard tensor machine learning visualizer", ) diff --git a/tensorboard/plugin_util.py b/tensorboard/plugin_util.py index c259f001d2..eb2132148a 100644 --- a/tensorboard/plugin_util.py +++ b/tensorboard/plugin_util.py @@ -19,6 +19,7 @@ from __future__ import print_function import bleach + # pylint: disable=g-bad-import-order # Google-only: import markdown_freewisdom import markdown @@ -28,80 +29,84 @@ _ALLOWED_ATTRIBUTES = { - 'a': ['href', 'title'], - 'img': ['src', 'title', 'alt'], + "a": ["href", "title"], + "img": ["src", "title", "alt"], } _ALLOWED_TAGS = [ - 'ul', - 
'ol', - 'li', - 'p', - 'pre', - 'code', - 'blockquote', - 'h1', - 'h2', - 'h3', - 'h4', - 'h5', - 'h6', - 'hr', - 'br', - 'strong', - 'em', - 'a', - 'img', - 'table', - 'thead', - 'tbody', - 'td', - 'tr', - 'th', + "ul", + "ol", + "li", + "p", + "pre", + "code", + "blockquote", + "h1", + "h2", + "h3", + "h4", + "h5", + "h6", + "hr", + "br", + "strong", + "em", + "a", + "img", + "table", + "thead", + "tbody", + "td", + "tr", + "th", ] def markdown_to_safe_html(markdown_string): - """Convert Markdown to HTML that's safe to splice into the DOM. - - Arguments: - markdown_string: A Unicode string or UTF-8--encoded bytestring - containing Markdown source. Markdown tables are supported. - - Returns: - A string containing safe HTML. - """ - warning = '' - # Convert to utf-8 whenever we have a binary input. - if isinstance(markdown_string, six.binary_type): - markdown_string_decoded = markdown_string.decode('utf-8') - # Remove null bytes and warn if there were any, since it probably means - # we were given a bad encoding. - markdown_string = markdown_string_decoded.replace(u'\x00', u'') - num_null_bytes = len(markdown_string_decoded) - len(markdown_string) - if num_null_bytes: - warning = ('\n') % num_null_bytes - - string_html = markdown.markdown( - markdown_string, extensions=['markdown.extensions.tables']) - string_sanitized = bleach.clean( - string_html, tags=_ALLOWED_TAGS, attributes=_ALLOWED_ATTRIBUTES) - return warning + string_sanitized + """Convert Markdown to HTML that's safe to splice into the DOM. + + Arguments: + markdown_string: A Unicode string or UTF-8--encoded bytestring + containing Markdown source. Markdown tables are supported. + + Returns: + A string containing safe HTML. + """ + warning = "" + # Convert to utf-8 whenever we have a binary input. 
+ if isinstance(markdown_string, six.binary_type): + markdown_string_decoded = markdown_string.decode("utf-8") + # Remove null bytes and warn if there were any, since it probably means + # we were given a bad encoding. + markdown_string = markdown_string_decoded.replace(u"\x00", u"") + num_null_bytes = len(markdown_string_decoded) - len(markdown_string) + if num_null_bytes: + warning = ( + "\n" + ) % num_null_bytes + + string_html = markdown.markdown( + markdown_string, extensions=["markdown.extensions.tables"] + ) + string_sanitized = bleach.clean( + string_html, tags=_ALLOWED_TAGS, attributes=_ALLOWED_ATTRIBUTES + ) + return warning + string_sanitized def experiment_id(environ): - """Determine the experiment ID associated with a WSGI request. + """Determine the experiment ID associated with a WSGI request. - Each request to TensorBoard has an associated experiment ID, which is - always a string and may be empty. This experiment ID should be passed - to data providers. + Each request to TensorBoard has an associated experiment ID, which is + always a string and may be empty. This experiment ID should be passed + to data providers. - Args: - environ: A WSGI environment `dict`. For a Werkzeug request, this is - `request.environ`. + Args: + environ: A WSGI environment `dict`. For a Werkzeug request, this is + `request.environ`. - Returns: - A experiment ID, as a possibly-empty `str`. - """ - return environ.get(_experiment_id.WSGI_ENVIRON_KEY, "") + Returns: + A experiment ID, as a possibly-empty `str`. 
+ """ + return environ.get(_experiment_id.WSGI_ENVIRON_KEY, "") diff --git a/tensorboard/plugin_util_test.py b/tensorboard/plugin_util_test.py index ae029ccb67..23d6948bdd 100644 --- a/tensorboard/plugin_util_test.py +++ b/tensorboard/plugin_util_test.py @@ -26,24 +26,25 @@ class MarkdownToSafeHTMLTest(tb_test.TestCase): - - def _test(self, markdown_string, expected): - actual = plugin_util.markdown_to_safe_html(markdown_string) - self.assertEqual(expected, actual) - - def test_empty_input(self): - self._test(u'', u'') - - def test_basic_formatting(self): - self._test(u'# _Hello_, **world!**\n\n' - 'Check out [my website](http://example.com)!', - u'

Hello, world!

\n' - '

Check out my website!

') - - def test_table_formatting(self): - self._test( - textwrap.dedent( - u"""\ + def _test(self, markdown_string, expected): + actual = plugin_util.markdown_to_safe_html(markdown_string) + self.assertEqual(expected, actual) + + def test_empty_input(self): + self._test(u"", u"") + + def test_basic_formatting(self): + self._test( + u"# _Hello_, **world!**\n\n" + "Check out [my website](http://example.com)!", + u"

Hello, world!

\n" + '

Check out my website!

', + ) + + def test_table_formatting(self): + self._test( + textwrap.dedent( + u"""\ Here is some data: TensorBoard usage | Happiness @@ -52,9 +53,10 @@ def test_table_formatting(self): 0.5 | 0.5 1.0 | 1.0 - Wouldn't you agree?"""), - textwrap.dedent( - u"""\ + Wouldn't you agree?""" + ), + textwrap.dedent( + u"""\

Here is some data:

@@ -78,65 +80,77 @@ def test_table_formatting(self):
-

Wouldn't you agree?

""")) - - def test_whitelisted_tags_and_attributes_allowed(self): - s = (u'Check out ' - 'my website!') - self._test(s, u'

%s

' % s) - - def test_arbitrary_tags_and_attributes_removed(self): - self._test(u'We should bring back the blink tag; ' - '' - 'sign the petition!', - u'

We should bring back the ' - '<blink>blink tag</blink>; ' - 'sign the petition!

') - - def test_javascript_hrefs_sanitized(self): - self._test(u'A sketchy link for you', - u'

A sketchy link for you

') - - def test_byte_strings_interpreted_as_utf8(self): - s = u'> Look\u2014some UTF-8!'.encode('utf-8') - assert isinstance(s, six.binary_type), (type(s), six.binary_type) - self._test(s, - u'
\n

Look\u2014some UTF-8!

\n
') - - def test_unicode_strings_passed_through(self): - s = u'> Look\u2014some UTF-8!' - assert not isinstance(s, six.binary_type), (type(s), six.binary_type) - self._test(s, - u'
\n

Look\u2014some UTF-8!

\n
') - - def test_null_bytes_stripped_before_markdown_processing(self): - # If this function is mistakenly called with UTF-16 or UTF-32 encoded text, - # there will probably be a bunch of null bytes. These would be stripped by - # the sanitizer no matter what, but make sure we remove them before markdown - # interpretation to avoid affecting output (e.g. middle-word underscores - # would generate erroneous tags like "underscore") and add an - # HTML comment with a warning. - s = u'un_der_score'.encode('utf-32-le') - # UTF-32 encoding of ASCII will have 3 null bytes per char. 36 = 3 * 12. - self._test(s, - u'\n' - '

un_der_score

') +

Wouldn't you agree?

""" + ), + ) + + def test_whitelisted_tags_and_attributes_allowed(self): + s = ( + u'Check out ' + "my website!" + ) + self._test(s, u"

%s

" % s) + + def test_arbitrary_tags_and_attributes_removed(self): + self._test( + u"We should bring back the blink tag; " + '' + "sign the petition!", + u"

We should bring back the " + "<blink>blink tag</blink>; " + 'sign the petition!

', + ) + + def test_javascript_hrefs_sanitized(self): + self._test( + u'A sketchy link for you', + u"

A sketchy link for you

", + ) + + def test_byte_strings_interpreted_as_utf8(self): + s = u"> Look\u2014some UTF-8!".encode("utf-8") + assert isinstance(s, six.binary_type), (type(s), six.binary_type) + self._test( + s, u"
\n

Look\u2014some UTF-8!

\n
" + ) + + def test_unicode_strings_passed_through(self): + s = u"> Look\u2014some UTF-8!" + assert not isinstance(s, six.binary_type), (type(s), six.binary_type) + self._test( + s, u"
\n

Look\u2014some UTF-8!

\n
" + ) + + def test_null_bytes_stripped_before_markdown_processing(self): + # If this function is mistakenly called with UTF-16 or UTF-32 encoded text, + # there will probably be a bunch of null bytes. These would be stripped by + # the sanitizer no matter what, but make sure we remove them before markdown + # interpretation to avoid affecting output (e.g. middle-word underscores + # would generate erroneous tags like "underscore") and add an + # HTML comment with a warning. + s = u"un_der_score".encode("utf-32-le") + # UTF-32 encoding of ASCII will have 3 null bytes per char. 36 = 3 * 12. + self._test( + s, + u"\n" + "

un_der_score

", + ) class ExperimentIdTest(tb_test.TestCase): - """Tests for `plugin_util.experiment_id`.""" + """Tests for `plugin_util.experiment_id`.""" - def test_default(self): - # This shouldn't happen; the `ExperimentIdMiddleware` always set an - # experiment ID. In case something goes wrong, degrade gracefully. - environ = {} - self.assertEqual(plugin_util.experiment_id(environ), "") + def test_default(self): + # This shouldn't happen; the `ExperimentIdMiddleware` always set an + # experiment ID. In case something goes wrong, degrade gracefully. + environ = {} + self.assertEqual(plugin_util.experiment_id(environ), "") - def test_present(self): - environ = {experiment_id.WSGI_ENVIRON_KEY: "123"} - self.assertEqual(plugin_util.experiment_id(environ), "123") + def test_present(self): + environ = {experiment_id.WSGI_ENVIRON_KEY: "123"} + self.assertEqual(plugin_util.experiment_id(environ), "123") -if __name__ == '__main__': - tb_test.main() +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/plugins/audio/audio_demo.py b/tensorboard/plugins/audio/audio_demo.py index 347ba77523..ec7df2c88a 100644 --- a/tensorboard/plugins/audio/audio_demo.py +++ b/tensorboard/plugins/audio/audio_demo.py @@ -30,157 +30,178 @@ FLAGS = flags.FLAGS -flags.DEFINE_string('logdir', '/tmp/audio_demo', - 'Directory into which to write TensorBoard data.') +flags.DEFINE_string( + "logdir", + "/tmp/audio_demo", + "Directory into which to write TensorBoard data.", +) -flags.DEFINE_integer('steps', 50, - 'Number of frequencies of each waveform to generate.') +flags.DEFINE_integer( + "steps", 50, "Number of frequencies of each waveform to generate." +) # Parameters for the audio output. 
-flags.DEFINE_integer('sample_rate', 44100, 'Sample rate, in Hz.') -flags.DEFINE_float('duration', 2.0, 'Duration of each waveform, in s.') +flags.DEFINE_integer("sample_rate", 44100, "Sample rate, in Hz.") +flags.DEFINE_float("duration", 2.0, "Duration of each waveform, in s.") def _samples(): - """Compute how many samples should be included in each waveform.""" - return int(FLAGS.sample_rate * FLAGS.duration) + """Compute how many samples should be included in each waveform.""" + return int(FLAGS.sample_rate * FLAGS.duration) def run(logdir, run_name, wave_name, wave_constructor): - """Generate wave data of the given form. - - The provided function `wave_constructor` should accept a scalar tensor - of type float32, representing the frequency (in Hz) at which to - construct a wave, and return a tensor of shape [1, _samples(), `n`] - representing audio data (for some number of channels `n`). - - Waves will be generated at frequencies ranging from A4 to A5. - - Arguments: - logdir: the top-level directory into which to write summary data - run_name: the name of this run; will be created as a subdirectory - under logdir - wave_name: the name of the wave being generated - wave_constructor: see above - """ - tf.compat.v1.reset_default_graph() - tf.compat.v1.set_random_seed(0) - - # On each step `i`, we'll set this placeholder to `i`. This allows us - # to know "what time it is" at each step. - step_placeholder = tf.compat.v1.placeholder(tf.float32, shape=[]) - - # We want to linearly interpolate a frequency between A4 (440 Hz) and - # A5 (880 Hz). - with tf.name_scope('compute_frequency'): - f_min = 440.0 - f_max = 880.0 - t = step_placeholder / (FLAGS.steps - 1) - frequency = f_min * (1.0 - t) + f_max * t - - # Let's log this frequency, just so that we can make sure that it's as - # expected. - tf.compat.v1.summary.scalar('frequency', frequency) - - # Now, we pass this to the wave constructor to get our waveform. 
Doing - # so within a name scope means that any summaries that the wave - # constructor produces will be namespaced. - with tf.name_scope(wave_name): - waveform = wave_constructor(frequency) - - # We also have the opportunity to annotate each audio clip with a - # label. This is a good place to include the frequency, because it'll - # be visible immediately next to the audio clip. - with tf.name_scope('compute_labels'): - samples = tf.shape(input=waveform)[0] - wave_types = tf.tile(["*Wave type:* `%s`." % wave_name], [samples]) - frequencies = tf.strings.join([ - "*Frequency:* ", - tf.tile([tf.as_string(frequency, precision=2)], [samples]), - " Hz.", - ]) - samples = tf.strings.join([ - "*Sample:* ", tf.as_string(tf.range(samples) + 1), - " of ", tf.as_string(samples), ".", - ]) - labels = tf.strings.join([wave_types, frequencies, samples], separator=" ") - - # We can place a description next to the summary in TensorBoard. This - # is a good place to explain what the summary represents, methodology - # for creating it, etc. Let's include the source code of the function - # that generated the wave. - source = '\n'.join(' %s' % line.rstrip() - for line in inspect.getsourcelines(wave_constructor)[0]) - description = ("A wave of type `%r`, generated via:\n\n%s" - % (wave_name, source)) - - # Here's the crucial piece: we interpret this result as audio. - summary.op('waveform', waveform, FLAGS.sample_rate, - labels=labels, - display_name=wave_name, - description=description) - - # Now, we can collect up all the summaries and begin the run. - summ = tf.compat.v1.summary.merge_all() - - sess = tf.compat.v1.Session() - writer = tf.summary.FileWriter(os.path.join(logdir, run_name)) - writer.add_graph(sess.graph) - sess.run(tf.compat.v1.global_variables_initializer()) - for step in xrange(FLAGS.steps): - s = sess.run(summ, feed_dict={step_placeholder: float(step)}) - writer.add_summary(s, global_step=step) - writer.close() + """Generate wave data of the given form. 
+ + The provided function `wave_constructor` should accept a scalar tensor + of type float32, representing the frequency (in Hz) at which to + construct a wave, and return a tensor of shape [1, _samples(), `n`] + representing audio data (for some number of channels `n`). + + Waves will be generated at frequencies ranging from A4 to A5. + + Arguments: + logdir: the top-level directory into which to write summary data + run_name: the name of this run; will be created as a subdirectory + under logdir + wave_name: the name of the wave being generated + wave_constructor: see above + """ + tf.compat.v1.reset_default_graph() + tf.compat.v1.set_random_seed(0) + + # On each step `i`, we'll set this placeholder to `i`. This allows us + # to know "what time it is" at each step. + step_placeholder = tf.compat.v1.placeholder(tf.float32, shape=[]) + + # We want to linearly interpolate a frequency between A4 (440 Hz) and + # A5 (880 Hz). + with tf.name_scope("compute_frequency"): + f_min = 440.0 + f_max = 880.0 + t = step_placeholder / (FLAGS.steps - 1) + frequency = f_min * (1.0 - t) + f_max * t + + # Let's log this frequency, just so that we can make sure that it's as + # expected. + tf.compat.v1.summary.scalar("frequency", frequency) + + # Now, we pass this to the wave constructor to get our waveform. Doing + # so within a name scope means that any summaries that the wave + # constructor produces will be namespaced. + with tf.name_scope(wave_name): + waveform = wave_constructor(frequency) + + # We also have the opportunity to annotate each audio clip with a + # label. This is a good place to include the frequency, because it'll + # be visible immediately next to the audio clip. + with tf.name_scope("compute_labels"): + samples = tf.shape(input=waveform)[0] + wave_types = tf.tile(["*Wave type:* `%s`." 
% wave_name], [samples]) + frequencies = tf.strings.join( + [ + "*Frequency:* ", + tf.tile([tf.as_string(frequency, precision=2)], [samples]), + " Hz.", + ] + ) + samples = tf.strings.join( + [ + "*Sample:* ", + tf.as_string(tf.range(samples) + 1), + " of ", + tf.as_string(samples), + ".", + ] + ) + labels = tf.strings.join( + [wave_types, frequencies, samples], separator=" " + ) + + # We can place a description next to the summary in TensorBoard. This + # is a good place to explain what the summary represents, methodology + # for creating it, etc. Let's include the source code of the function + # that generated the wave. + source = "\n".join( + " %s" % line.rstrip() + for line in inspect.getsourcelines(wave_constructor)[0] + ) + description = "A wave of type `%r`, generated via:\n\n%s" % ( + wave_name, + source, + ) + + # Here's the crucial piece: we interpret this result as audio. + summary.op( + "waveform", + waveform, + FLAGS.sample_rate, + labels=labels, + display_name=wave_name, + description=description, + ) + + # Now, we can collect up all the summaries and begin the run. + summ = tf.compat.v1.summary.merge_all() + + sess = tf.compat.v1.Session() + writer = tf.summary.FileWriter(os.path.join(logdir, run_name)) + writer.add_graph(sess.graph) + sess.run(tf.compat.v1.global_variables_initializer()) + for step in xrange(FLAGS.steps): + s = sess.run(summ, feed_dict={step_placeholder: float(step)}) + writer.add_summary(s, global_step=step) + writer.close() # Now, let's take a look at the kinds of waves that we can generate. 
def sine_wave(frequency): - """Emit a sine wave at the given frequency.""" - xs = tf.reshape(tf.range(_samples(), dtype=tf.float32), [1, _samples(), 1]) - ts = xs / FLAGS.sample_rate - return tf.sin(2 * math.pi * frequency * ts) + """Emit a sine wave at the given frequency.""" + xs = tf.reshape(tf.range(_samples(), dtype=tf.float32), [1, _samples(), 1]) + ts = xs / FLAGS.sample_rate + return tf.sin(2 * math.pi * frequency * ts) def square_wave(frequency): - """Emit a square wave at the given frequency.""" - # The square is just the sign of the sine! - return tf.sign(sine_wave(frequency)) + """Emit a square wave at the given frequency.""" + # The square is just the sign of the sine! + return tf.sign(sine_wave(frequency)) def triangle_wave(frequency): - """Emit a triangle wave at the given frequency.""" - xs = tf.reshape(tf.range(_samples(), dtype=tf.float32), [1, _samples(), 1]) - ts = xs / FLAGS.sample_rate - # - # A triangle wave looks like this: - # - # /\ /\ - # / \ / \ - # \ / \ / - # \/ \/ - # - # If we look at just half a period (the first four slashes in the - # diagram above), we can see that it looks like a transformed absolute - # value function. - # - # Let's start by computing the times relative to the start of each - # half-wave pulse (each individual "mountain" or "valley", of which - # there are four in the above diagram). - half_pulse_index = ts * (frequency * 2) - half_pulse_angle = half_pulse_index % 1.0 # in [0, 1] - # - # Now, we can see that each positive half-pulse ("mountain") has - # amplitude given by A(z) = 0.5 - abs(z - 0.5), and then normalized: - absolute_amplitude = (0.5 - tf.abs(half_pulse_angle - 0.5)) / 0.5 - # - # But every other half-pulse is negative, so we should invert these. - half_pulse_parity = tf.sign(1 - (half_pulse_index % 2.0)) - amplitude = half_pulse_parity * absolute_amplitude - # - # This is precisely the desired result, so we're done! 
- return amplitude + """Emit a triangle wave at the given frequency.""" + xs = tf.reshape(tf.range(_samples(), dtype=tf.float32), [1, _samples(), 1]) + ts = xs / FLAGS.sample_rate + # + # A triangle wave looks like this: + # + # /\ /\ + # / \ / \ + # \ / \ / + # \/ \/ + # + # If we look at just half a period (the first four slashes in the + # diagram above), we can see that it looks like a transformed absolute + # value function. + # + # Let's start by computing the times relative to the start of each + # half-wave pulse (each individual "mountain" or "valley", of which + # there are four in the above diagram). + half_pulse_index = ts * (frequency * 2) + half_pulse_angle = half_pulse_index % 1.0 # in [0, 1] + # + # Now, we can see that each positive half-pulse ("mountain") has + # amplitude given by A(z) = 0.5 - abs(z - 0.5), and then normalized: + absolute_amplitude = (0.5 - tf.abs(half_pulse_angle - 0.5)) / 0.5 + # + # But every other half-pulse is negative, so we should invert these. + half_pulse_parity = tf.sign(1 - (half_pulse_index % 2.0)) + amplitude = half_pulse_parity * absolute_amplitude + # + # This is precisely the desired result, so we're done! + return amplitude # If we want to get fancy, we can use our above waves as primitives to @@ -188,74 +209,79 @@ def triangle_wave(frequency): def bisine_wave(frequency): - """Emit two sine waves, in stereo at different octaves.""" - # - # We can first our existing sine generator to generate two different - # waves. - f_hi = frequency - f_lo = frequency / 2.0 - with tf.name_scope('hi'): - sine_hi = sine_wave(f_hi) - with tf.name_scope('lo'): - sine_lo = sine_wave(f_lo) - # - # Now, we have two tensors of shape [1, _samples(), 1]. By concatenating - # them along axis 2, we get a tensor of shape [1, _samples(), 2]---a - # stereo waveform. 
- return tf.concat([sine_lo, sine_hi], axis=2) + """Emit two sine waves, in stereo at different octaves.""" + # + # We can first our existing sine generator to generate two different + # waves. + f_hi = frequency + f_lo = frequency / 2.0 + with tf.name_scope("hi"): + sine_hi = sine_wave(f_hi) + with tf.name_scope("lo"): + sine_lo = sine_wave(f_lo) + # + # Now, we have two tensors of shape [1, _samples(), 1]. By concatenating + # them along axis 2, we get a tensor of shape [1, _samples(), 2]---a + # stereo waveform. + return tf.concat([sine_lo, sine_hi], axis=2) def bisine_wahwah_wave(frequency): - """Emit two sine waves with balance oscillating left and right.""" - # - # This is clearly intended to build on the bisine wave defined above, - # so we can start by generating that. - waves_a = bisine_wave(frequency) - # - # Then, by reversing axis 2, we swap the stereo channels. By mixing - # this with `waves_a`, we'll be able to create the desired effect. - waves_b = tf.reverse(waves_a, axis=[2]) - # - # Let's have the balance oscillate from left to right four times. - iterations = 4 - # - # Now, we compute the balance for each sample: `ts` has values - # in [0, 1] that indicate how much we should use `waves_a`. - xs = tf.reshape(tf.range(_samples(), dtype=tf.float32), [1, _samples(), 1]) - thetas = xs / _samples() * iterations - ts = (tf.sin(math.pi * 2 * thetas) + 1) / 2 - # - # Finally, we can mix the two together, and we're done. - wave = ts * waves_a + (1.0 - ts) * waves_b - # - # Alternately, we can make the effect more pronounced by exaggerating - # the sample data. Let's emit both variations. - exaggerated_wave = wave ** 3.0 - return tf.concat([wave, exaggerated_wave], axis=0) + """Emit two sine waves with balance oscillating left and right.""" + # + # This is clearly intended to build on the bisine wave defined above, + # so we can start by generating that. + waves_a = bisine_wave(frequency) + # + # Then, by reversing axis 2, we swap the stereo channels. 
By mixing + # this with `waves_a`, we'll be able to create the desired effect. + waves_b = tf.reverse(waves_a, axis=[2]) + # + # Let's have the balance oscillate from left to right four times. + iterations = 4 + # + # Now, we compute the balance for each sample: `ts` has values + # in [0, 1] that indicate how much we should use `waves_a`. + xs = tf.reshape(tf.range(_samples(), dtype=tf.float32), [1, _samples(), 1]) + thetas = xs / _samples() * iterations + ts = (tf.sin(math.pi * 2 * thetas) + 1) / 2 + # + # Finally, we can mix the two together, and we're done. + wave = ts * waves_a + (1.0 - ts) * waves_b + # + # Alternately, we can make the effect more pronounced by exaggerating + # the sample data. Let's emit both variations. + exaggerated_wave = wave ** 3.0 + return tf.concat([wave, exaggerated_wave], axis=0) def run_all(logdir, verbose=False): - """Generate waves of the shapes defined above. - - Arguments: - logdir: the directory into which to store all the runs' data - verbose: if true, print out each run's name as it begins - """ - waves = [sine_wave, square_wave, triangle_wave, - bisine_wave, bisine_wahwah_wave] - for (i, wave_constructor) in enumerate(waves): - wave_name = wave_constructor.__name__ - run_name = 'wave:%02d,%s' % (i + 1, wave_name) - if verbose: - print('--- Running: %s' % run_name) - run(logdir, run_name, wave_name, wave_constructor) + """Generate waves of the shapes defined above. + + Arguments: + logdir: the directory into which to store all the runs' data + verbose: if true, print out each run's name as it begins + """ + waves = [ + sine_wave, + square_wave, + triangle_wave, + bisine_wave, + bisine_wahwah_wave, + ] + for (i, wave_constructor) in enumerate(waves): + wave_name = wave_constructor.__name__ + run_name = "wave:%02d,%s" % (i + 1, wave_name) + if verbose: + print("--- Running: %s" % run_name) + run(logdir, run_name, wave_name, wave_constructor) def main(unused_argv): - print('Saving output to %s.' 
% FLAGS.logdir) - run_all(FLAGS.logdir, verbose=True) - print('Done. Output saved to %s.' % FLAGS.logdir) + print("Saving output to %s." % FLAGS.logdir) + run_all(FLAGS.logdir, verbose=True) + print("Done. Output saved to %s." % FLAGS.logdir) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/tensorboard/plugins/audio/audio_plugin.py b/tensorboard/plugins/audio/audio_plugin.py index 542ab6d312..906e12aeb0 100644 --- a/tensorboard/plugins/audio/audio_plugin.py +++ b/tensorboard/plugins/audio/audio_plugin.py @@ -30,204 +30,230 @@ from tensorboard.util import tensor_util -_DEFAULT_MIME_TYPE = 'application/octet-stream' +_DEFAULT_MIME_TYPE = "application/octet-stream" _MIME_TYPES = { - metadata.Encoding.Value('WAV'): 'audio/wav', + metadata.Encoding.Value("WAV"): "audio/wav", } class AudioPlugin(base_plugin.TBPlugin): - """Audio Plugin for TensorBoard.""" - - plugin_name = metadata.PLUGIN_NAME - - def __init__(self, context): - """Instantiates AudioPlugin via TensorBoard core. - - Args: - context: A base_plugin.TBContext instance. - """ - self._multiplexer = context.multiplexer - - def get_plugin_apps(self): - return { - '/audio': self._serve_audio_metadata, - '/individualAudio': self._serve_individual_audio, - '/tags': self._serve_tags, - } - - def is_active(self): - """The audio plugin is active iff any run has at least one relevant tag.""" - if not self._multiplexer: - return False - return bool(self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME)) - - def frontend_metadata(self): - return base_plugin.FrontendMetadata(element_name='tf-audio-dashboard') - - def _index_impl(self): - """Return information about the tags in each run. - - Result is a dictionary of the form - - { - "runName1": { - "tagName1": { - "displayName": "The first tag", - "description": "

Long ago there was just one tag...

", - "samples": 3 - }, - "tagName2": ..., - ... - }, - "runName2": ..., - ... + """Audio Plugin for TensorBoard.""" + + plugin_name = metadata.PLUGIN_NAME + + def __init__(self, context): + """Instantiates AudioPlugin via TensorBoard core. + + Args: + context: A base_plugin.TBContext instance. + """ + self._multiplexer = context.multiplexer + + def get_plugin_apps(self): + return { + "/audio": self._serve_audio_metadata, + "/individualAudio": self._serve_individual_audio, + "/tags": self._serve_tags, } - For each tag, `samples` is the greatest number of audio clips that - appear at any particular step. (It's not related to "samples of a - waveform.") For example, if for tag `minibatch_input` there are - five audio clips at step 0 and ten audio clips at step 1, then the - dictionary for `"minibatch_input"` will contain `"samples": 10`. - """ - runs = self._multiplexer.Runs() - result = {run: {} for run in runs} - - mapping = self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) - for (run, tag_to_content) in six.iteritems(mapping): - for tag in tag_to_content: - summary_metadata = self._multiplexer.SummaryMetadata(run, tag) - tensor_events = self._multiplexer.Tensors(run, tag) - samples = max([self._number_of_samples(event.tensor_proto) - for event in tensor_events] + [0]) - result[run][tag] = {'displayName': summary_metadata.display_name, - 'description': plugin_util.markdown_to_safe_html( - summary_metadata.summary_description), - 'samples': samples} - - return result - - def _number_of_samples(self, tensor_proto): - """Count the number of samples of an audio TensorProto.""" - # We directly inspect the `tensor_shape` of the proto instead of - # using the preferred `tensor_util.make_ndarray(...).shape`, because - # these protos can contain a large amount of encoded audio data, - # and we don't want to have to convert them all to numpy arrays - # just to look at their shape. 
- return tensor_proto.tensor_shape.dim[0].size - - def _filter_by_sample(self, tensor_events, sample): - return [tensor_event for tensor_event in tensor_events - if self._number_of_samples(tensor_event.tensor_proto) > sample] - - @wrappers.Request.application - def _serve_audio_metadata(self, request): - """Given a tag and list of runs, serve a list of metadata for audio. - - Note that the actual audio data are not sent; instead, we respond - with URLs to the audio. The frontend should treat these URLs as - opaque and should not try to parse information about them or - generate them itself, as the format may change. - - Args: - request: A werkzeug.wrappers.Request object. - - Returns: - A werkzeug.Response application. - """ - tag = request.args.get('tag') - run = request.args.get('run') - sample = int(request.args.get('sample', 0)) - - events = self._multiplexer.Tensors(run, tag) - try: - response = self._audio_response_for_run(events, run, tag, sample) - except KeyError: - return http_util.Respond( - request, 'Invalid run or tag', 'text/plain', code=400 - ) - return http_util.Respond(request, response, 'application/json') - - def _audio_response_for_run(self, tensor_events, run, tag, sample): - """Builds a JSON-serializable object with information about audio. - - Args: - tensor_events: A list of image event_accumulator.TensorEvent objects. - run: The name of the run. - tag: The name of the tag the audio entries all belong to. - sample: The zero-indexed sample of the audio sample for which to - retrieve information. For instance, setting `sample` to `2` will - fetch information about only the third audio clip of each batch, - and steps with fewer than three audio clips will be omitted from - the results. - - Returns: - A list of dictionaries containing the wall time, step, URL, width, and - height for each audio entry. 
- """ - response = [] - index = 0 - filtered_events = self._filter_by_sample(tensor_events, sample) - content_type = self._get_mime_type(run, tag) - for (index, tensor_event) in enumerate(filtered_events): - data = tensor_util.make_ndarray(tensor_event.tensor_proto) - label = data[sample, 1] - response.append({ - 'wall_time': tensor_event.wall_time, - 'step': tensor_event.step, - 'label': plugin_util.markdown_to_safe_html(label), - 'contentType': content_type, - 'query': self._query_for_individual_audio(run, tag, sample, index) - }) - return response - - def _query_for_individual_audio(self, run, tag, sample, index): - """Builds a URL for accessing the specified audio. - - This should be kept in sync with _serve_audio_metadata. Note that the URL is - *not* guaranteed to always return the same audio, since audio may be - unloaded from the reservoir as new audio entries come in. - - Args: - run: The name of the run. - tag: The tag. - index: The index of the audio entry. Negative values are OK. - - Returns: - A string representation of a URL that will load the index-th sampled audio - in the given run with the given tag. 
- """ - query_string = urllib.parse.urlencode({ - 'run': run, - 'tag': tag, - 'sample': sample, - 'index': index, - }) - return query_string - - def _get_mime_type(self, run, tag): - content = self._multiplexer.SummaryMetadata(run, tag).plugin_data.content - parsed = metadata.parse_plugin_metadata(content) - return _MIME_TYPES.get(parsed.encoding, _DEFAULT_MIME_TYPE) - - @wrappers.Request.application - def _serve_individual_audio(self, request): - """Serve encoded audio data.""" - tag = request.args.get('tag') - run = request.args.get('run') - index = int(request.args.get('index', '0')) - sample = int(request.args.get('sample', '0')) - try: - events = self._filter_by_sample(self._multiplexer.Tensors(run, tag), sample) - data = tensor_util.make_ndarray(events[index].tensor_proto)[sample, 0] - except (KeyError, IndexError): - return http_util.Respond( - request, 'Invalid run, tag, index, or sample', 'text/plain', code=400 - ) - mime_type = self._get_mime_type(run, tag) - return http_util.Respond(request, data, mime_type) - - @wrappers.Request.application - def _serve_tags(self, request): - index = self._index_impl() - return http_util.Respond(request, index, 'application/json') + def is_active(self): + """The audio plugin is active iff any run has at least one relevant + tag.""" + if not self._multiplexer: + return False + return bool( + self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) + ) + + def frontend_metadata(self): + return base_plugin.FrontendMetadata(element_name="tf-audio-dashboard") + + def _index_impl(self): + """Return information about the tags in each run. + + Result is a dictionary of the form + + { + "runName1": { + "tagName1": { + "displayName": "The first tag", + "description": "

Long ago there was just one tag...

", + "samples": 3 + }, + "tagName2": ..., + ... + }, + "runName2": ..., + ... + } + + For each tag, `samples` is the greatest number of audio clips that + appear at any particular step. (It's not related to "samples of a + waveform.") For example, if for tag `minibatch_input` there are + five audio clips at step 0 and ten audio clips at step 1, then the + dictionary for `"minibatch_input"` will contain `"samples": 10`. + """ + runs = self._multiplexer.Runs() + result = {run: {} for run in runs} + + mapping = self._multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME + ) + for (run, tag_to_content) in six.iteritems(mapping): + for tag in tag_to_content: + summary_metadata = self._multiplexer.SummaryMetadata(run, tag) + tensor_events = self._multiplexer.Tensors(run, tag) + samples = max( + [ + self._number_of_samples(event.tensor_proto) + for event in tensor_events + ] + + [0] + ) + result[run][tag] = { + "displayName": summary_metadata.display_name, + "description": plugin_util.markdown_to_safe_html( + summary_metadata.summary_description + ), + "samples": samples, + } + + return result + + def _number_of_samples(self, tensor_proto): + """Count the number of samples of an audio TensorProto.""" + # We directly inspect the `tensor_shape` of the proto instead of + # using the preferred `tensor_util.make_ndarray(...).shape`, because + # these protos can contain a large amount of encoded audio data, + # and we don't want to have to convert them all to numpy arrays + # just to look at their shape. + return tensor_proto.tensor_shape.dim[0].size + + def _filter_by_sample(self, tensor_events, sample): + return [ + tensor_event + for tensor_event in tensor_events + if self._number_of_samples(tensor_event.tensor_proto) > sample + ] + + @wrappers.Request.application + def _serve_audio_metadata(self, request): + """Given a tag and list of runs, serve a list of metadata for audio. 
+ + Note that the actual audio data are not sent; instead, we respond + with URLs to the audio. The frontend should treat these URLs as + opaque and should not try to parse information about them or + generate them itself, as the format may change. + + Args: + request: A werkzeug.wrappers.Request object. + + Returns: + A werkzeug.Response application. + """ + tag = request.args.get("tag") + run = request.args.get("run") + sample = int(request.args.get("sample", 0)) + + events = self._multiplexer.Tensors(run, tag) + try: + response = self._audio_response_for_run(events, run, tag, sample) + except KeyError: + return http_util.Respond( + request, "Invalid run or tag", "text/plain", code=400 + ) + return http_util.Respond(request, response, "application/json") + + def _audio_response_for_run(self, tensor_events, run, tag, sample): + """Builds a JSON-serializable object with information about audio. + + Args: + tensor_events: A list of image event_accumulator.TensorEvent objects. + run: The name of the run. + tag: The name of the tag the audio entries all belong to. + sample: The zero-indexed sample of the audio sample for which to + retrieve information. For instance, setting `sample` to `2` will + fetch information about only the third audio clip of each batch, + and steps with fewer than three audio clips will be omitted from + the results. + + Returns: + A list of dictionaries containing the wall time, step, URL, width, and + height for each audio entry. 
+ """ + response = [] + index = 0 + filtered_events = self._filter_by_sample(tensor_events, sample) + content_type = self._get_mime_type(run, tag) + for (index, tensor_event) in enumerate(filtered_events): + data = tensor_util.make_ndarray(tensor_event.tensor_proto) + label = data[sample, 1] + response.append( + { + "wall_time": tensor_event.wall_time, + "step": tensor_event.step, + "label": plugin_util.markdown_to_safe_html(label), + "contentType": content_type, + "query": self._query_for_individual_audio( + run, tag, sample, index + ), + } + ) + return response + + def _query_for_individual_audio(self, run, tag, sample, index): + """Builds a URL for accessing the specified audio. + + This should be kept in sync with _serve_audio_metadata. Note that the URL is + *not* guaranteed to always return the same audio, since audio may be + unloaded from the reservoir as new audio entries come in. + + Args: + run: The name of the run. + tag: The tag. + index: The index of the audio entry. Negative values are OK. + + Returns: + A string representation of a URL that will load the index-th sampled audio + in the given run with the given tag. 
+ """ + query_string = urllib.parse.urlencode( + {"run": run, "tag": tag, "sample": sample, "index": index,} + ) + return query_string + + def _get_mime_type(self, run, tag): + content = self._multiplexer.SummaryMetadata( + run, tag + ).plugin_data.content + parsed = metadata.parse_plugin_metadata(content) + return _MIME_TYPES.get(parsed.encoding, _DEFAULT_MIME_TYPE) + + @wrappers.Request.application + def _serve_individual_audio(self, request): + """Serve encoded audio data.""" + tag = request.args.get("tag") + run = request.args.get("run") + index = int(request.args.get("index", "0")) + sample = int(request.args.get("sample", "0")) + try: + events = self._filter_by_sample( + self._multiplexer.Tensors(run, tag), sample + ) + data = tensor_util.make_ndarray(events[index].tensor_proto)[ + sample, 0 + ] + except (KeyError, IndexError): + return http_util.Respond( + request, + "Invalid run, tag, index, or sample", + "text/plain", + code=400, + ) + mime_type = self._get_mime_type(run, tag) + return http_util.Respond(request, data, mime_type) + + @wrappers.Request.application + def _serve_tags(self, request): + index = self._index_impl() + return http_util.Respond(request, index, "application/json") diff --git a/tensorboard/plugins/audio/audio_plugin_test.py b/tensorboard/plugins/audio/audio_plugin_test.py index 0672c7bffd..088ef0acca 100644 --- a/tensorboard/plugins/audio/audio_plugin_test.py +++ b/tensorboard/plugins/audio/audio_plugin_test.py @@ -32,7 +32,9 @@ from werkzeug import wrappers from tensorboard.backend import application -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.audio import audio_plugin from tensorboard.plugins.audio import summary @@ -40,201 +42,237 @@ class AudioPluginTest(tf.test.TestCase): + def setUp(self): + self.log_dir = 
tempfile.mkdtemp() - def setUp(self): - self.log_dir = tempfile.mkdtemp() - - # We use numpy.random to generate audio. We seed to avoid non-determinism - # in this test. - numpy.random.seed(42) - - # Create old-style audio summaries for run "foo". - tf.compat.v1.reset_default_graph() - with tf.compat.v1.Graph().as_default(): - sess = tf.compat.v1.Session() - placeholder = tf.compat.v1.placeholder(tf.float32) - tf.compat.v1.summary.audio(name="baz", tensor=placeholder, sample_rate=44100) - merged_summary_op = tf.compat.v1.summary.merge_all() - foo_directory = os.path.join(self.log_dir, "foo") - with test_util.FileWriterCache.get(foo_directory) as writer: - writer.add_graph(sess.graph) - for step in xrange(2): - # The floats (sample data) range from -1 to 1. - writer.add_summary(sess.run(merged_summary_op, feed_dict={ - placeholder: numpy.random.rand(42, 22050) * 2 - 1 - }), global_step=step) - - # Create new-style audio summaries for run "bar". - tf.compat.v1.reset_default_graph() - with tf.compat.v1.Graph().as_default(): - sess = tf.compat.v1.Session() - audio_placeholder = tf.compat.v1.placeholder(tf.float32) - labels_placeholder = tf.compat.v1.placeholder(tf.string) - summary.op("quux", audio_placeholder, sample_rate=44100, - labels=labels_placeholder, - description="how do you pronounce that, anyway?") - merged_summary_op = tf.compat.v1.summary.merge_all() - bar_directory = os.path.join(self.log_dir, "bar") - with test_util.FileWriterCache.get(bar_directory) as writer: - writer.add_graph(sess.graph) - for step in xrange(2): - # The floats (sample data) range from -1 to 1. - writer.add_summary(sess.run(merged_summary_op, feed_dict={ - audio_placeholder: numpy.random.rand(42, 11025, 1) * 2 - 1, - labels_placeholder: [ - tf.compat.as_bytes('step **%s**, sample %s' % (step, sample)) - for sample in xrange(42) - ], - }), global_step=step) - - # Start a server with the plugin. 
- multiplexer = event_multiplexer.EventMultiplexer({ - "foo": foo_directory, - "bar": bar_directory, - }) - multiplexer.Reload() - context = base_plugin.TBContext( - logdir=self.log_dir, multiplexer=multiplexer) - self.plugin = audio_plugin.AudioPlugin(context) - wsgi_app = application.TensorBoardWSGI([self.plugin]) - self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) - - def tearDown(self): - shutil.rmtree(self.log_dir, ignore_errors=True) - - def _DeserializeResponse(self, byte_content): - """Deserializes byte content that is a JSON encoding. - - Args: - byte_content: The byte content of a response. - - Returns: - The deserialized python object decoded from JSON. - """ - return json.loads(byte_content.decode("utf-8")) - - def testRoutesProvided(self): - """Tests that the plugin offers the correct routes.""" - routes = self.plugin.get_plugin_apps() - self.assertIsInstance(routes["/audio"], collections.Callable) - self.assertIsInstance(routes["/individualAudio"], collections.Callable) - self.assertIsInstance(routes["/tags"], collections.Callable) - - def testOldStyleAudioRoute(self): - """Tests that the /audio routes returns correct old-style data.""" - response = self.server.get( - "/data/plugin/audio/audio?run=foo&tag=baz/audio/0&sample=0") - self.assertEqual(200, response.status_code) - - # Verify that the correct entries are returned. - entries = self._DeserializeResponse(response.get_data()) - self.assertEqual(2, len(entries)) - - # Verify that the 1st entry is correct. - entry = entries[0] - self.assertEqual("audio/wav", entry["contentType"]) - self.assertEqual("", entry["label"]) - self.assertEqual(0, entry["step"]) - parsed_query = urllib.parse.parse_qs(entry["query"]) - self.assertListEqual(["foo"], parsed_query["run"]) - self.assertListEqual(["baz/audio/0"], parsed_query["tag"]) - self.assertListEqual(["0"], parsed_query["sample"]) - self.assertListEqual(["0"], parsed_query["index"]) - - # Verify that the 2nd entry is correct. 
- entry = entries[1] - self.assertEqual("audio/wav", entry["contentType"]) - self.assertEqual("", entry["label"]) - self.assertEqual(1, entry["step"]) - parsed_query = urllib.parse.parse_qs(entry["query"]) - self.assertListEqual(["foo"], parsed_query["run"]) - self.assertListEqual(["baz/audio/0"], parsed_query["tag"]) - self.assertListEqual(["0"], parsed_query["sample"]) - self.assertListEqual(["1"], parsed_query["index"]) - - def testNewStyleAudioRoute(self): - """Tests that the /audio routes returns correct new-style data.""" - response = self.server.get( - "/data/plugin/audio/audio?run=bar&tag=quux/audio_summary&sample=0") - self.assertEqual(200, response.status_code) - - # Verify that the correct entries are returned. - entries = self._DeserializeResponse(response.get_data()) - self.assertEqual(2, len(entries)) - - # Verify that the 1st entry is correct. - entry = entries[0] - self.assertEqual("audio/wav", entry["contentType"]) - self.assertEqual( - "

step %s, sample 0

" % entry["step"], - entry["label"]) - self.assertEqual(0, entry["step"]) - parsed_query = urllib.parse.parse_qs(entry["query"]) - self.assertListEqual(["bar"], parsed_query["run"]) - self.assertListEqual(["quux/audio_summary"], parsed_query["tag"]) - self.assertListEqual(["0"], parsed_query["sample"]) - self.assertListEqual(["0"], parsed_query["index"]) - - # Verify that the 2nd entry is correct. - entry = entries[1] - self.assertEqual("audio/wav", entry["contentType"]) - self.assertEqual( - "

step %s, sample 0

" % entry["step"], - entry["label"]) - self.assertEqual(1, entry["step"]) - parsed_query = urllib.parse.parse_qs(entry["query"]) - self.assertListEqual(["bar"], parsed_query["run"]) - self.assertListEqual(["quux/audio_summary"], parsed_query["tag"]) - self.assertListEqual(["0"], parsed_query["sample"]) - self.assertListEqual(["1"], parsed_query["index"]) - - def testOldStyleIndividualAudioRoute(self): - """Tests fetching an individual audio clip from an old-style summary.""" - response = self.server.get( - "/data/plugin/audio/individualAudio" - "?run=foo&tag=baz/audio/0&sample=0&index=0") - self.assertEqual(200, response.status_code) - self.assertEqual("audio/wav", response.headers.get("content-type")) - - def testNewStyleIndividualAudioRoute(self): - """Tests fetching an individual audio clip from an old-style summary.""" - response = self.server.get( - "/data/plugin/audio/individualAudio" - "?run=bar&tag=quux/audio_summary&sample=0&index=0") - self.assertEqual(200, response.status_code) - self.assertEqual("audio/wav", response.headers.get("content-type")) - - def testTagsRoute(self): - """Tests that the /tags route offers the correct run to tag mapping.""" - response = self.server.get("/data/plugin/audio/tags") - self.assertEqual(200, response.status_code) - self.assertDictEqual({ - "foo": { - "baz/audio/0": { - "displayName": "baz/audio/0", - "description": "", - "samples": 1, - }, - "baz/audio/1": { - "displayName": "baz/audio/1", - "description": "", - "samples": 1, - }, - "baz/audio/2": { - "displayName": "baz/audio/2", - "description": "", - "samples": 1, - }, - }, - "bar": { - "quux/audio_summary": { - "displayName": "quux", - "description": "

how do you pronounce that, anyway?

", - "samples": 3, # 42 inputs, but max_outputs=3 + # We use numpy.random to generate audio. We seed to avoid non-determinism + # in this test. + numpy.random.seed(42) + + # Create old-style audio summaries for run "foo". + tf.compat.v1.reset_default_graph() + with tf.compat.v1.Graph().as_default(): + sess = tf.compat.v1.Session() + placeholder = tf.compat.v1.placeholder(tf.float32) + tf.compat.v1.summary.audio( + name="baz", tensor=placeholder, sample_rate=44100 + ) + merged_summary_op = tf.compat.v1.summary.merge_all() + foo_directory = os.path.join(self.log_dir, "foo") + with test_util.FileWriterCache.get(foo_directory) as writer: + writer.add_graph(sess.graph) + for step in xrange(2): + # The floats (sample data) range from -1 to 1. + writer.add_summary( + sess.run( + merged_summary_op, + feed_dict={ + placeholder: numpy.random.rand(42, 22050) * 2 + - 1 + }, + ), + global_step=step, + ) + + # Create new-style audio summaries for run "bar". + tf.compat.v1.reset_default_graph() + with tf.compat.v1.Graph().as_default(): + sess = tf.compat.v1.Session() + audio_placeholder = tf.compat.v1.placeholder(tf.float32) + labels_placeholder = tf.compat.v1.placeholder(tf.string) + summary.op( + "quux", + audio_placeholder, + sample_rate=44100, + labels=labels_placeholder, + description="how do you pronounce that, anyway?", + ) + merged_summary_op = tf.compat.v1.summary.merge_all() + bar_directory = os.path.join(self.log_dir, "bar") + with test_util.FileWriterCache.get(bar_directory) as writer: + writer.add_graph(sess.graph) + for step in xrange(2): + # The floats (sample data) range from -1 to 1. + writer.add_summary( + sess.run( + merged_summary_op, + feed_dict={ + audio_placeholder: numpy.random.rand( + 42, 11025, 1 + ) + * 2 + - 1, + labels_placeholder: [ + tf.compat.as_bytes( + "step **%s**, sample %s" + % (step, sample) + ) + for sample in xrange(42) + ], + }, + ), + global_step=step, + ) + + # Start a server with the plugin. 
+ multiplexer = event_multiplexer.EventMultiplexer( + {"foo": foo_directory, "bar": bar_directory,} + ) + multiplexer.Reload() + context = base_plugin.TBContext( + logdir=self.log_dir, multiplexer=multiplexer + ) + self.plugin = audio_plugin.AudioPlugin(context) + wsgi_app = application.TensorBoardWSGI([self.plugin]) + self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) + + def tearDown(self): + shutil.rmtree(self.log_dir, ignore_errors=True) + + def _DeserializeResponse(self, byte_content): + """Deserializes byte content that is a JSON encoding. + + Args: + byte_content: The byte content of a response. + + Returns: + The deserialized python object decoded from JSON. + """ + return json.loads(byte_content.decode("utf-8")) + + def testRoutesProvided(self): + """Tests that the plugin offers the correct routes.""" + routes = self.plugin.get_plugin_apps() + self.assertIsInstance(routes["/audio"], collections.Callable) + self.assertIsInstance(routes["/individualAudio"], collections.Callable) + self.assertIsInstance(routes["/tags"], collections.Callable) + + def testOldStyleAudioRoute(self): + """Tests that the /audio routes returns correct old-style data.""" + response = self.server.get( + "/data/plugin/audio/audio?run=foo&tag=baz/audio/0&sample=0" + ) + self.assertEqual(200, response.status_code) + + # Verify that the correct entries are returned. + entries = self._DeserializeResponse(response.get_data()) + self.assertEqual(2, len(entries)) + + # Verify that the 1st entry is correct. + entry = entries[0] + self.assertEqual("audio/wav", entry["contentType"]) + self.assertEqual("", entry["label"]) + self.assertEqual(0, entry["step"]) + parsed_query = urllib.parse.parse_qs(entry["query"]) + self.assertListEqual(["foo"], parsed_query["run"]) + self.assertListEqual(["baz/audio/0"], parsed_query["tag"]) + self.assertListEqual(["0"], parsed_query["sample"]) + self.assertListEqual(["0"], parsed_query["index"]) + + # Verify that the 2nd entry is correct. 
+ entry = entries[1] + self.assertEqual("audio/wav", entry["contentType"]) + self.assertEqual("", entry["label"]) + self.assertEqual(1, entry["step"]) + parsed_query = urllib.parse.parse_qs(entry["query"]) + self.assertListEqual(["foo"], parsed_query["run"]) + self.assertListEqual(["baz/audio/0"], parsed_query["tag"]) + self.assertListEqual(["0"], parsed_query["sample"]) + self.assertListEqual(["1"], parsed_query["index"]) + + def testNewStyleAudioRoute(self): + """Tests that the /audio routes returns correct new-style data.""" + response = self.server.get( + "/data/plugin/audio/audio?run=bar&tag=quux/audio_summary&sample=0" + ) + self.assertEqual(200, response.status_code) + + # Verify that the correct entries are returned. + entries = self._DeserializeResponse(response.get_data()) + self.assertEqual(2, len(entries)) + + # Verify that the 1st entry is correct. + entry = entries[0] + self.assertEqual("audio/wav", entry["contentType"]) + self.assertEqual( + "

step %s, sample 0

" % entry["step"], + entry["label"], + ) + self.assertEqual(0, entry["step"]) + parsed_query = urllib.parse.parse_qs(entry["query"]) + self.assertListEqual(["bar"], parsed_query["run"]) + self.assertListEqual(["quux/audio_summary"], parsed_query["tag"]) + self.assertListEqual(["0"], parsed_query["sample"]) + self.assertListEqual(["0"], parsed_query["index"]) + + # Verify that the 2nd entry is correct. + entry = entries[1] + self.assertEqual("audio/wav", entry["contentType"]) + self.assertEqual( + "

step %s, sample 0

" % entry["step"], + entry["label"], + ) + self.assertEqual(1, entry["step"]) + parsed_query = urllib.parse.parse_qs(entry["query"]) + self.assertListEqual(["bar"], parsed_query["run"]) + self.assertListEqual(["quux/audio_summary"], parsed_query["tag"]) + self.assertListEqual(["0"], parsed_query["sample"]) + self.assertListEqual(["1"], parsed_query["index"]) + + def testOldStyleIndividualAudioRoute(self): + """Tests fetching an individual audio clip from an old-style + summary.""" + response = self.server.get( + "/data/plugin/audio/individualAudio" + "?run=foo&tag=baz/audio/0&sample=0&index=0" + ) + self.assertEqual(200, response.status_code) + self.assertEqual("audio/wav", response.headers.get("content-type")) + + def testNewStyleIndividualAudioRoute(self): + """Tests fetching an individual audio clip from an old-style + summary.""" + response = self.server.get( + "/data/plugin/audio/individualAudio" + "?run=bar&tag=quux/audio_summary&sample=0&index=0" + ) + self.assertEqual(200, response.status_code) + self.assertEqual("audio/wav", response.headers.get("content-type")) + + def testTagsRoute(self): + """Tests that the /tags route offers the correct run to tag mapping.""" + response = self.server.get("/data/plugin/audio/tags") + self.assertEqual(200, response.status_code) + self.assertDictEqual( + { + "foo": { + "baz/audio/0": { + "displayName": "baz/audio/0", + "description": "", + "samples": 1, + }, + "baz/audio/1": { + "displayName": "baz/audio/1", + "description": "", + "samples": 1, + }, + "baz/audio/2": { + "displayName": "baz/audio/2", + "description": "", + "samples": 1, + }, + }, + "bar": { + "quux/audio_summary": { + "displayName": "quux", + "description": "

how do you pronounce that, anyway?

", + "samples": 3, # 42 inputs, but max_outputs=3 + }, + }, }, - }, - }, self._DeserializeResponse(response.get_data())) + self._DeserializeResponse(response.get_data()), + ) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/audio/metadata.py b/tensorboard/plugins/audio/metadata.py index 2b4baaedef..d6d5cdc6db 100644 --- a/tensorboard/plugins/audio/metadata.py +++ b/tensorboard/plugins/audio/metadata.py @@ -24,7 +24,7 @@ logger = tb_logging.get_logger() -PLUGIN_NAME = 'audio' +PLUGIN_NAME = "audio" # The most recent value for the `version` field of the `AudioPluginData` # proto. @@ -35,40 +35,45 @@ def create_summary_metadata(display_name, description, encoding): - """Create a `SummaryMetadata` proto for audio plugin data. + """Create a `SummaryMetadata` proto for audio plugin data. - Returns: - A `SummaryMetadata` protobuf object. - """ - content = plugin_data_pb2.AudioPluginData( - version=PROTO_VERSION, encoding=encoding) - metadata = summary_pb2.SummaryMetadata( - display_name=display_name, - summary_description=description, - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name=PLUGIN_NAME, - content=content.SerializeToString())) - return metadata + Returns: + A `SummaryMetadata` protobuf object. + """ + content = plugin_data_pb2.AudioPluginData( + version=PROTO_VERSION, encoding=encoding + ) + metadata = summary_pb2.SummaryMetadata( + display_name=display_name, + summary_description=description, + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, content=content.SerializeToString() + ), + ) + return metadata def parse_plugin_metadata(content): - """Parse summary metadata to a Python object. + """Parse summary metadata to a Python object. - Arguments: - content: The `content` field of a `SummaryMetadata` proto - corresponding to the audio plugin. + Arguments: + content: The `content` field of a `SummaryMetadata` proto + corresponding to the audio plugin. 
- Returns: - An `AudioPluginData` protobuf object. - """ - if not isinstance(content, bytes): - raise TypeError('Content type must be bytes') - result = plugin_data_pb2.AudioPluginData.FromString(content) - if result.version == 0: - return result - else: - logger.warn( - 'Unknown metadata version: %s. The latest version known to ' - 'this build of TensorBoard is %s; perhaps a newer build is ' - 'available?', result.version, PROTO_VERSION) - return result + Returns: + An `AudioPluginData` protobuf object. + """ + if not isinstance(content, bytes): + raise TypeError("Content type must be bytes") + result = plugin_data_pb2.AudioPluginData.FromString(content) + if result.version == 0: + return result + else: + logger.warn( + "Unknown metadata version: %s. The latest version known to " + "this build of TensorBoard is %s; perhaps a newer build is " + "available?", + result.version, + PROTO_VERSION, + ) + return result diff --git a/tensorboard/plugins/audio/summary.py b/tensorboard/plugins/audio/summary.py index 6498623dde..506a07e6c5 100644 --- a/tensorboard/plugins/audio/summary.py +++ b/tensorboard/plugins/audio/summary.py @@ -41,176 +41,190 @@ audio = summary_v2.audio -def op(name, - audio, - sample_rate, - labels=None, - max_outputs=3, - encoding=None, - display_name=None, - description=None, - collections=None): - """Create a legacy audio summary op for use in a TensorFlow graph. - - Arguments: - name: A unique name for the generated summary node. - audio: A `Tensor` representing audio data with shape `[k, t, c]`, - where `k` is the number of audio clips, `t` is the number of - frames, and `c` is the number of channels. Elements should be - floating-point values in `[-1.0, 1.0]`. Any of the dimensions may - be statically unknown (i.e., `None`). - sample_rate: An `int` or rank-0 `int32` `Tensor` that represents the - sample rate, in Hz. Must be positive. 
- labels: Optional `string` `Tensor`, a vector whose length is the - first dimension of `audio`, where `labels[i]` contains arbitrary - textual information about `audio[i]`. (For instance, this could be - some text that a TTS system was supposed to produce.) Markdown is - supported. Contents should be UTF-8. - max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this - many audio clips will be emitted at each step. When more than - `max_outputs` many clips are provided, the first `max_outputs` - many clips will be used and the rest silently discarded. - encoding: A constant `str` (not string tensor) indicating the - desired encoding. You can choose any format you like, as long as - it's "wav". Please see the "API compatibility note" below. - display_name: Optional name for this summary in TensorBoard, as a - constant `str`. Defaults to `name`. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. - collections: Optional list of graph collections keys. The new - summary op is added to these collections. Defaults to - `[Graph Keys.SUMMARIES]`. - - Returns: - A TensorFlow summary op. - - API compatibility note: The default value of the `encoding` - argument is _not_ guaranteed to remain unchanged across TensorBoard - versions. In the future, we will by default encode as FLAC instead of - as WAV. If the specific format is important to you, please provide a - file format explicitly. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. 
- import tensorflow.compat.v1 as tf - - if display_name is None: - display_name = name - if encoding is None: - encoding = 'wav' - - if encoding == 'wav': - encoding = metadata.Encoding.Value('WAV') - encoder = functools.partial(tf.audio.encode_wav, sample_rate=sample_rate) - else: - raise ValueError('Unknown encoding: %r' % encoding) - - with tf.name_scope(name), \ - tf.control_dependencies([tf.assert_rank(audio, 3)]): +def op( + name, + audio, + sample_rate, + labels=None, + max_outputs=3, + encoding=None, + display_name=None, + description=None, + collections=None, +): + """Create a legacy audio summary op for use in a TensorFlow graph. + + Arguments: + name: A unique name for the generated summary node. + audio: A `Tensor` representing audio data with shape `[k, t, c]`, + where `k` is the number of audio clips, `t` is the number of + frames, and `c` is the number of channels. Elements should be + floating-point values in `[-1.0, 1.0]`. Any of the dimensions may + be statically unknown (i.e., `None`). + sample_rate: An `int` or rank-0 `int32` `Tensor` that represents the + sample rate, in Hz. Must be positive. + labels: Optional `string` `Tensor`, a vector whose length is the + first dimension of `audio`, where `labels[i]` contains arbitrary + textual information about `audio[i]`. (For instance, this could be + some text that a TTS system was supposed to produce.) Markdown is + supported. Contents should be UTF-8. + max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this + many audio clips will be emitted at each step. When more than + `max_outputs` many clips are provided, the first `max_outputs` + many clips will be used and the rest silently discarded. + encoding: A constant `str` (not string tensor) indicating the + desired encoding. You can choose any format you like, as long as + it's "wav". Please see the "API compatibility note" below. + display_name: Optional name for this summary in TensorBoard, as a + constant `str`. Defaults to `name`. 
+ description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. + collections: Optional list of graph collections keys. The new + summary op is added to these collections. Defaults to + `[Graph Keys.SUMMARIES]`. + + Returns: + A TensorFlow summary op. + + API compatibility note: The default value of the `encoding` + argument is _not_ guaranteed to remain unchanged across TensorBoard + versions. In the future, we will by default encode as FLAC instead of + as WAV. If the specific format is important to you, please provide a + file format explicitly. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. + import tensorflow.compat.v1 as tf + + if display_name is None: + display_name = name + if encoding is None: + encoding = "wav" + + if encoding == "wav": + encoding = metadata.Encoding.Value("WAV") + encoder = functools.partial( + tf.audio.encode_wav, sample_rate=sample_rate + ) + else: + raise ValueError("Unknown encoding: %r" % encoding) + + with tf.name_scope(name), tf.control_dependencies( + [tf.assert_rank(audio, 3)] + ): + limited_audio = audio[:max_outputs] + encoded_audio = tf.map_fn( + encoder, limited_audio, dtype=tf.string, name="encode_each_audio" + ) + if labels is None: + limited_labels = tf.tile([""], tf.shape(input=limited_audio)[:1]) + else: + limited_labels = labels[:max_outputs] + tensor = tf.transpose(a=tf.stack([encoded_audio, limited_labels])) + summary_metadata = metadata.create_summary_metadata( + display_name=display_name, + description=description, + encoding=encoding, + ) + return tf.summary.tensor_summary( + name="audio_summary", + tensor=tensor, + collections=collections, + summary_metadata=summary_metadata, + ) + + +def pb( + name, + audio, + sample_rate, + labels=None, + max_outputs=3, + encoding=None, + display_name=None, + description=None, +): + """Create a legacy audio summary protobuf. 
+ + This behaves as if you were to create an `op` with the same arguments + (wrapped with constant tensors where appropriate) and then execute + that summary op in a TensorFlow session. + + Arguments: + name: A unique name for the generated summary node. + audio: An `np.array` representing audio data with shape `[k, t, c]`, + where `k` is the number of audio clips, `t` is the number of + frames, and `c` is the number of channels. Elements should be + floating-point values in `[-1.0, 1.0]`. + sample_rate: An `int` that represents the sample rate, in Hz. + Must be positive. + labels: Optional list (or rank-1 `np.array`) of textstrings or UTF-8 + bytestrings whose length is the first dimension of `audio`, where + `labels[i]` contains arbitrary textual information about + `audio[i]`. (For instance, this could be some text that a TTS + system was supposed to produce.) Markdown is supported. + max_outputs: Optional `int`. At most this many audio clips will be + emitted. When more than `max_outputs` many clips are provided, the + first `max_outputs` many clips will be used and the rest silently + discarded. + encoding: A constant `str` indicating the desired encoding. You + can choose any format you like, as long as it's "wav". Please see + the "API compatibility note" below. + display_name: Optional name for this summary in TensorBoard, as a + `str`. Defaults to `name`. + description: Optional long-form description for this summary, as a + `str`. Markdown is supported. Defaults to empty. + + Returns: + A `tf.Summary` protobuf object. + + API compatibility note: The default value of the `encoding` + argument is _not_ guaranteed to remain unchanged across TensorBoard + versions. In the future, we will by default encode as FLAC instead of + as WAV. If the specific format is important to you, please provide a + file format explicitly. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. 
+ import tensorflow.compat.v1 as tf + + audio = np.array(audio) + if audio.ndim != 3: + raise ValueError("Shape %r must have rank 3" % (audio.shape,)) + if encoding is None: + encoding = "wav" + + if encoding == "wav": + encoding = metadata.Encoding.Value("WAV") + encoder = functools.partial( + encoder_util.encode_wav, samples_per_second=sample_rate + ) + else: + raise ValueError("Unknown encoding: %r" % encoding) + limited_audio = audio[:max_outputs] - encoded_audio = tf.map_fn(encoder, limited_audio, - dtype=tf.string, - name='encode_each_audio') if labels is None: - limited_labels = tf.tile([''], tf.shape(input=limited_audio)[:1]) + limited_labels = [b""] * len(limited_audio) else: - limited_labels = labels[:max_outputs] - tensor = tf.transpose(a=tf.stack([encoded_audio, limited_labels])) + limited_labels = [ + tf.compat.as_bytes(label) for label in labels[:max_outputs] + ] + + encoded_audio = [encoder(a) for a in limited_audio] + content = np.array([encoded_audio, limited_labels]).transpose() + tensor = tf.make_tensor_proto(content, dtype=tf.string) + + if display_name is None: + display_name = name summary_metadata = metadata.create_summary_metadata( - display_name=display_name, - description=description, - encoding=encoding) - return tf.summary.tensor_summary(name='audio_summary', - tensor=tensor, - collections=collections, - summary_metadata=summary_metadata) - - -def pb(name, - audio, - sample_rate, - labels=None, - max_outputs=3, - encoding=None, - display_name=None, - description=None): - """Create a legacy audio summary protobuf. - - This behaves as if you were to create an `op` with the same arguments - (wrapped with constant tensors where appropriate) and then execute - that summary op in a TensorFlow session. - - Arguments: - name: A unique name for the generated summary node. 
- audio: An `np.array` representing audio data with shape `[k, t, c]`, - where `k` is the number of audio clips, `t` is the number of - frames, and `c` is the number of channels. Elements should be - floating-point values in `[-1.0, 1.0]`. - sample_rate: An `int` that represents the sample rate, in Hz. - Must be positive. - labels: Optional list (or rank-1 `np.array`) of textstrings or UTF-8 - bytestrings whose length is the first dimension of `audio`, where - `labels[i]` contains arbitrary textual information about - `audio[i]`. (For instance, this could be some text that a TTS - system was supposed to produce.) Markdown is supported. - max_outputs: Optional `int`. At most this many audio clips will be - emitted. When more than `max_outputs` many clips are provided, the - first `max_outputs` many clips will be used and the rest silently - discarded. - encoding: A constant `str` indicating the desired encoding. You - can choose any format you like, as long as it's "wav". Please see - the "API compatibility note" below. - display_name: Optional name for this summary in TensorBoard, as a - `str`. Defaults to `name`. - description: Optional long-form description for this summary, as a - `str`. Markdown is supported. Defaults to empty. - - Returns: - A `tf.Summary` protobuf object. - - API compatibility note: The default value of the `encoding` - argument is _not_ guaranteed to remain unchanged across TensorBoard - versions. In the future, we will by default encode as FLAC instead of - as WAV. If the specific format is important to you, please provide a - file format explicitly. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. 
- import tensorflow.compat.v1 as tf - - audio = np.array(audio) - if audio.ndim != 3: - raise ValueError('Shape %r must have rank 3' % (audio.shape,)) - if encoding is None: - encoding = 'wav' - - if encoding == 'wav': - encoding = metadata.Encoding.Value('WAV') - encoder = functools.partial(encoder_util.encode_wav, - samples_per_second=sample_rate) - else: - raise ValueError('Unknown encoding: %r' % encoding) - - limited_audio = audio[:max_outputs] - if labels is None: - limited_labels = [b''] * len(limited_audio) - else: - limited_labels = [tf.compat.as_bytes(label) - for label in labels[:max_outputs]] - - encoded_audio = [encoder(a) for a in limited_audio] - content = np.array([encoded_audio, limited_labels]).transpose() - tensor = tf.make_tensor_proto(content, dtype=tf.string) - - if display_name is None: - display_name = name - summary_metadata = metadata.create_summary_metadata( - display_name=display_name, - description=description, - encoding=encoding) - tf_summary_metadata = tf.SummaryMetadata.FromString( - summary_metadata.SerializeToString()) - - summary = tf.Summary() - summary.value.add(tag='%s/audio_summary' % name, - metadata=tf_summary_metadata, - tensor=tensor) - return summary + display_name=display_name, description=description, encoding=encoding + ) + tf_summary_metadata = tf.SummaryMetadata.FromString( + summary_metadata.SerializeToString() + ) + + summary = tf.Summary() + summary.value.add( + tag="%s/audio_summary" % name, + metadata=tf_summary_metadata, + tensor=tensor, + ) + return summary diff --git a/tensorboard/plugins/audio/summary_test.py b/tensorboard/plugins/audio/summary_test.py index 25fabdd54e..e4ba03bea1 100644 --- a/tensorboard/plugins/audio/summary_test.py +++ b/tensorboard/plugins/audio/summary_test.py @@ -33,212 +33,223 @@ try: - tf2.__version__ # Force lazy import to resolve + tf2.__version__ # Force lazy import to resolve except ImportError: - tf2 = None + tf2 = None try: - tf.compat.v1.enable_eager_execution() + 
tf.compat.v1.enable_eager_execution() except AttributeError: - # TF 2.0 doesn't have this symbol because eager is the default. - pass + # TF 2.0 doesn't have this symbol because eager is the default. + pass -audio_ops = getattr(tf, 'audio', None) +audio_ops = getattr(tf, "audio", None) if audio_ops is None: - # Fallback for older versions of TF without tf.audio. - from tensorflow.python.ops import gen_audio_ops as audio_ops + # Fallback for older versions of TF without tf.audio. + from tensorflow.python.ops import gen_audio_ops as audio_ops class SummaryBaseTest(object): - - def setUp(self): - super(SummaryBaseTest, self).setUp() - self.samples_rate = 44100 - self.audio_count = 1 - self.num_samples = 20 - self.num_channels = 2 - - def _generate_audio(self, **kwargs): - size = [ - kwargs.get('k', self.audio_count), - kwargs.get('n', self.num_samples), - kwargs.get('c', self.num_channels), - ] - return np.sin(np.linspace(0.0, 100.0, np.prod(size), - dtype=np.float32)).reshape(size) - - def audio(self, *args, **kwargs): - raise NotImplementedError() - - def test_metadata(self): - data = np.array(1, np.float32, ndmin=3) - description = 'Piano Concerto No. 
23 (K488), in **A major.**' - pb = self.audio('k488', data, 44100, description=description) - self.assertEqual(len(pb.value), 1) - summary_metadata = pb.value[0].metadata - self.assertEqual(summary_metadata.summary_description, description) - plugin_data = summary_metadata.plugin_data - self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME) - content = summary_metadata.plugin_data.content - parsed = metadata.parse_plugin_metadata(content) - self.assertEqual(parsed.encoding, metadata.Encoding.Value('WAV')) - - def test_wav_format_roundtrip(self): - audio = self._generate_audio(c=1) - pb = self.audio('k888', audio, 44100) - encoded = tensor_util.make_ndarray(pb.value[0].tensor) - decoded, sample_rate = audio_ops.decode_wav(encoded.flat[0]) - # WAV roundtrip goes from float32 to int16 and back, so expect some - # precision loss, but not more than 2 applications of rounding error from - # mapping the range [-1.0, 1.0] to 2^16. - epsilon = 2 * 2.0 / (2**16) - self.assertAllClose(audio[0], decoded, atol=epsilon) - self.assertEqual(44100, sample_rate.numpy()) - - def _test_dimensions(self, audio): - pb = self.audio('k888', audio, 44100) - self.assertEqual(1, len(pb.value)) - results = tensor_util.make_ndarray(pb.value[0].tensor) - for i, (encoded, _) in enumerate(results): - decoded, _ = audio_ops.decode_wav(encoded) - self.assertEqual(audio[i].shape, decoded.shape) - - def test_dimensions(self): - # Check mono and stereo. 
- self._test_dimensions(self._generate_audio(c=1)) - self._test_dimensions(self._generate_audio(c=2)) - - def test_audio_count_zero(self): - shape = (0, self.num_samples, 2) - audio = np.array([]).reshape(shape).astype(np.float32) - pb = self.audio('k488', audio, 44100, max_outputs=3) - self.assertEqual(1, len(pb.value)) - results = tensor_util.make_ndarray(pb.value[0].tensor) - self.assertEqual(results.shape, (0, 2)) - - def test_audio_count_less_than_max_outputs(self): - max_outputs = 3 - data = self._generate_audio(k=(max_outputs - 1)) - pb = self.audio('k488', data, 44100, max_outputs=max_outputs) - self.assertEqual(1, len(pb.value)) - results = tensor_util.make_ndarray(pb.value[0].tensor) - self.assertEqual(results.shape, (len(data), 2)) - - def test_audio_count_when_more_than_max(self): - max_outputs = 3 - data = self._generate_audio(k=(max_outputs + 1)) - pb = self.audio('k488', data, 44100, max_outputs=max_outputs) - self.assertEqual(1, len(pb.value)) - results = tensor_util.make_ndarray(pb.value[0].tensor) - self.assertEqual(results.shape, (max_outputs, 2)) - - def test_requires_nonnegative_max_outputs(self): - data = np.array(1, np.float32, ndmin=3) - with six.assertRaisesRegex( - self, (ValueError, tf.errors.InvalidArgumentError), '>= 0'): - self.audio('k488', data, 44100, max_outputs=-1) - - def test_requires_rank_3(self): - with six.assertRaisesRegex(self, ValueError, 'must have rank 3'): - self.audio('k488', np.array([[1]]), 44100) - - def test_requires_wav(self): - data = np.array(1, np.float32, ndmin=3) - with six.assertRaisesRegex(self, ValueError, 'Unknown encoding'): - self.audio('k488', data, 44100, encoding='pptx') + def setUp(self): + super(SummaryBaseTest, self).setUp() + self.samples_rate = 44100 + self.audio_count = 1 + self.num_samples = 20 + self.num_channels = 2 + + def _generate_audio(self, **kwargs): + size = [ + kwargs.get("k", self.audio_count), + kwargs.get("n", self.num_samples), + kwargs.get("c", self.num_channels), + ] + return 
np.sin( + np.linspace(0.0, 100.0, np.prod(size), dtype=np.float32) + ).reshape(size) + + def audio(self, *args, **kwargs): + raise NotImplementedError() + + def test_metadata(self): + data = np.array(1, np.float32, ndmin=3) + description = "Piano Concerto No. 23 (K488), in **A major.**" + pb = self.audio("k488", data, 44100, description=description) + self.assertEqual(len(pb.value), 1) + summary_metadata = pb.value[0].metadata + self.assertEqual(summary_metadata.summary_description, description) + plugin_data = summary_metadata.plugin_data + self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME) + content = summary_metadata.plugin_data.content + parsed = metadata.parse_plugin_metadata(content) + self.assertEqual(parsed.encoding, metadata.Encoding.Value("WAV")) + + def test_wav_format_roundtrip(self): + audio = self._generate_audio(c=1) + pb = self.audio("k888", audio, 44100) + encoded = tensor_util.make_ndarray(pb.value[0].tensor) + decoded, sample_rate = audio_ops.decode_wav(encoded.flat[0]) + # WAV roundtrip goes from float32 to int16 and back, so expect some + # precision loss, but not more than 2 applications of rounding error from + # mapping the range [-1.0, 1.0] to 2^16. + epsilon = 2 * 2.0 / (2 ** 16) + self.assertAllClose(audio[0], decoded, atol=epsilon) + self.assertEqual(44100, sample_rate.numpy()) + + def _test_dimensions(self, audio): + pb = self.audio("k888", audio, 44100) + self.assertEqual(1, len(pb.value)) + results = tensor_util.make_ndarray(pb.value[0].tensor) + for i, (encoded, _) in enumerate(results): + decoded, _ = audio_ops.decode_wav(encoded) + self.assertEqual(audio[i].shape, decoded.shape) + + def test_dimensions(self): + # Check mono and stereo. 
+ self._test_dimensions(self._generate_audio(c=1)) + self._test_dimensions(self._generate_audio(c=2)) + + def test_audio_count_zero(self): + shape = (0, self.num_samples, 2) + audio = np.array([]).reshape(shape).astype(np.float32) + pb = self.audio("k488", audio, 44100, max_outputs=3) + self.assertEqual(1, len(pb.value)) + results = tensor_util.make_ndarray(pb.value[0].tensor) + self.assertEqual(results.shape, (0, 2)) + + def test_audio_count_less_than_max_outputs(self): + max_outputs = 3 + data = self._generate_audio(k=(max_outputs - 1)) + pb = self.audio("k488", data, 44100, max_outputs=max_outputs) + self.assertEqual(1, len(pb.value)) + results = tensor_util.make_ndarray(pb.value[0].tensor) + self.assertEqual(results.shape, (len(data), 2)) + + def test_audio_count_when_more_than_max(self): + max_outputs = 3 + data = self._generate_audio(k=(max_outputs + 1)) + pb = self.audio("k488", data, 44100, max_outputs=max_outputs) + self.assertEqual(1, len(pb.value)) + results = tensor_util.make_ndarray(pb.value[0].tensor) + self.assertEqual(results.shape, (max_outputs, 2)) + + def test_requires_nonnegative_max_outputs(self): + data = np.array(1, np.float32, ndmin=3) + with six.assertRaisesRegex( + self, (ValueError, tf.errors.InvalidArgumentError), ">= 0" + ): + self.audio("k488", data, 44100, max_outputs=-1) + + def test_requires_rank_3(self): + with six.assertRaisesRegex(self, ValueError, "must have rank 3"): + self.audio("k488", np.array([[1]]), 44100) + + def test_requires_wav(self): + data = np.array(1, np.float32, ndmin=3) + with six.assertRaisesRegex(self, ValueError, "Unknown encoding"): + self.audio("k488", data, 44100, encoding="pptx") class SummaryV1PbTest(SummaryBaseTest, tf.test.TestCase): - def setUp(self): - super(SummaryV1PbTest, self).setUp() + def setUp(self): + super(SummaryV1PbTest, self).setUp() - def audio(self, *args, **kwargs): - return summary.pb(*args, **kwargs) + def audio(self, *args, **kwargs): + return summary.pb(*args, **kwargs) - def 
test_tag(self): - data = np.array(1, np.float32, ndmin=3) - self.assertEqual('a/audio_summary', - self.audio('a', data, 44100).value[0].tag) - self.assertEqual('a/b/audio_summary', - self.audio('a/b', data, 44100).value[0].tag) + def test_tag(self): + data = np.array(1, np.float32, ndmin=3) + self.assertEqual( + "a/audio_summary", self.audio("a", data, 44100).value[0].tag + ) + self.assertEqual( + "a/b/audio_summary", self.audio("a/b", data, 44100).value[0].tag + ) - def test_requires_nonnegative_max_outputs(self): - self.skipTest('summary V1 pb does not actually enforce this') + def test_requires_nonnegative_max_outputs(self): + self.skipTest("summary V1 pb does not actually enforce this") class SummaryV1OpTest(SummaryBaseTest, tf.test.TestCase): - def setUp(self): - super(SummaryV1OpTest, self).setUp() - - def audio(self, *args, **kwargs): - return tf.compat.v1.Summary.FromString(summary.op(*args, **kwargs).numpy()) - - def test_tag(self): - data = np.array(1, np.float32, ndmin=3) - self.assertEqual('a/audio_summary', - self.audio('a', data, 44100).value[0].tag) - self.assertEqual('a/b/audio_summary', - self.audio('a/b', data, 44100).value[0].tag) - - def test_scoped_tag(self): - data = np.array(1, np.float32, ndmin=3) - with tf.name_scope('scope'): - self.assertEqual('scope/a/audio_summary', - self.audio('a', data, 44100).value[0].tag) - - def test_audio_count_zero(self): - self.skipTest('fails under eager because map_fn() returns float dtype') - - def test_requires_nonnegative_max_outputs(self): - self.skipTest('summary V1 op does not actually enforce this') + def setUp(self): + super(SummaryV1OpTest, self).setUp() + + def audio(self, *args, **kwargs): + return tf.compat.v1.Summary.FromString( + summary.op(*args, **kwargs).numpy() + ) + + def test_tag(self): + data = np.array(1, np.float32, ndmin=3) + self.assertEqual( + "a/audio_summary", self.audio("a", data, 44100).value[0].tag + ) + self.assertEqual( + "a/b/audio_summary", self.audio("a/b", data, 
44100).value[0].tag + ) + + def test_scoped_tag(self): + data = np.array(1, np.float32, ndmin=3) + with tf.name_scope("scope"): + self.assertEqual( + "scope/a/audio_summary", + self.audio("a", data, 44100).value[0].tag, + ) + + def test_audio_count_zero(self): + self.skipTest("fails under eager because map_fn() returns float dtype") + + def test_requires_nonnegative_max_outputs(self): + self.skipTest("summary V1 op does not actually enforce this") class SummaryV2OpTest(SummaryBaseTest, tf.test.TestCase): - def setUp(self): - super(SummaryV2OpTest, self).setUp() - if tf2 is None: - self.skipTest('TF v2 summary API not available') - - def audio(self, *args, **kwargs): - return self.audio_event(*args, **kwargs).summary - - def audio_event(self, *args, **kwargs): - kwargs.setdefault('step', 1) - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - summary.audio(*args, **kwargs) - writer.close() - event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), '*'))) - self.assertEqual(len(event_files), 1) - events = list(tf.compat.v1.train.summary_iterator(event_files[0])) - # Expect a boilerplate event for the file_version, then the summary one. - self.assertEqual(len(events), 2) - # Delete the event file to reset to an empty directory for later calls. - # TODO(nickfelt): use a unique subdirectory per writer instead. - os.remove(event_files[0]) - return events[1] - - def test_scoped_tag(self): - data = np.array(1, np.float32, ndmin=3) - with tf.name_scope('scope'): - self.assertEqual('scope/a', self.audio('a', data, 44100).value[0].tag) - - def test_step(self): - data = np.array(1, np.float32, ndmin=3) - event = self.audio_event('a', data, 44100, step=333) - self.assertEqual(333, event.step) - - def test_default_step(self): - data = np.array(1, np.float32, ndmin=3) - try: - tf2.summary.experimental.set_step(333) - # TODO(nickfelt): change test logic so we can just omit `step` entirely. 
- event = self.audio_event('a', data, 44100, step=None) - self.assertEqual(333, event.step) - finally: - # Reset to default state for other tests. - tf2.summary.experimental.set_step(None) - - -if __name__ == '__main__': - tf.test.main() + def setUp(self): + super(SummaryV2OpTest, self).setUp() + if tf2 is None: + self.skipTest("TF v2 summary API not available") + + def audio(self, *args, **kwargs): + return self.audio_event(*args, **kwargs).summary + + def audio_event(self, *args, **kwargs): + kwargs.setdefault("step", 1) + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + summary.audio(*args, **kwargs) + writer.close() + event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), "*"))) + self.assertEqual(len(event_files), 1) + events = list(tf.compat.v1.train.summary_iterator(event_files[0])) + # Expect a boilerplate event for the file_version, then the summary one. + self.assertEqual(len(events), 2) + # Delete the event file to reset to an empty directory for later calls. + # TODO(nickfelt): use a unique subdirectory per writer instead. + os.remove(event_files[0]) + return events[1] + + def test_scoped_tag(self): + data = np.array(1, np.float32, ndmin=3) + with tf.name_scope("scope"): + self.assertEqual( + "scope/a", self.audio("a", data, 44100).value[0].tag + ) + + def test_step(self): + data = np.array(1, np.float32, ndmin=3) + event = self.audio_event("a", data, 44100, step=333) + self.assertEqual(333, event.step) + + def test_default_step(self): + data = np.array(1, np.float32, ndmin=3) + try: + tf2.summary.experimental.set_step(333) + # TODO(nickfelt): change test logic so we can just omit `step` entirely. + event = self.audio_event("a", data, 44100, step=None) + self.assertEqual(333, event.step) + finally: + # Reset to default state for other tests. 
+ tf2.summary.experimental.set_step(None) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/audio/summary_v2.py b/tensorboard/plugins/audio/summary_v2.py index 157c7db537..a4d078a82f 100644 --- a/tensorboard/plugins/audio/summary_v2.py +++ b/tensorboard/plugins/audio/summary_v2.py @@ -32,87 +32,97 @@ from tensorboard.util import lazy_tensor_creator -def audio(name, - data, - sample_rate, - step=None, - max_outputs=3, - encoding=None, - description=None): - """Write an audio summary. +def audio( + name, + data, + sample_rate, + step=None, + max_outputs=3, + encoding=None, + description=None, +): + """Write an audio summary. - Arguments: - name: A name for this summary. The summary tag used for TensorBoard will - be this name prefixed by any active name scopes. - data: A `Tensor` representing audio data with shape `[k, t, c]`, - where `k` is the number of audio clips, `t` is the number of - frames, and `c` is the number of channels. Elements should be - floating-point values in `[-1.0, 1.0]`. Any of the dimensions may - be statically unknown (i.e., `None`). - sample_rate: An `int` or rank-0 `int32` `Tensor` that represents the - sample rate, in Hz. Must be positive. - step: Explicit `int64`-castable monotonic step value for this summary. If - omitted, this defaults to `tf.summary.experimental.get_step()`, which must - not be None. - max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this - many audio clips will be emitted at each step. When more than - `max_outputs` many clips are provided, the first `max_outputs` - many clips will be used and the rest silently discarded. - encoding: Optional constant `str` for the desired encoding. Only "wav" - is currently supported, but this is not guaranteed to remain the - default, so if you want "wav" in particular, set this explicitly. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. 
+ Arguments: + name: A name for this summary. The summary tag used for TensorBoard will + be this name prefixed by any active name scopes. + data: A `Tensor` representing audio data with shape `[k, t, c]`, + where `k` is the number of audio clips, `t` is the number of + frames, and `c` is the number of channels. Elements should be + floating-point values in `[-1.0, 1.0]`. Any of the dimensions may + be statically unknown (i.e., `None`). + sample_rate: An `int` or rank-0 `int32` `Tensor` that represents the + sample rate, in Hz. Must be positive. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this + many audio clips will be emitted at each step. When more than + `max_outputs` many clips are provided, the first `max_outputs` + many clips will be used and the rest silently discarded. + encoding: Optional constant `str` for the desired encoding. Only "wav" + is currently supported, but this is not guaranteed to remain the + default, so if you want "wav" in particular, set this explicitly. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. - Returns: - True on success, or false if no summary was emitted because no default - summary writer was available. + Returns: + True on success, or false if no summary was emitted because no default + summary writer was available. - Raises: - ValueError: if a default writer exists, but no step was provided and - `tf.summary.experimental.get_step()` is None. - """ - audio_ops = getattr(tf, 'audio', None) - if audio_ops is None: - # Fallback for older versions of TF without tf.audio. - from tensorflow.python.ops import gen_audio_ops as audio_ops + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. 
+ """ + audio_ops = getattr(tf, "audio", None) + if audio_ops is None: + # Fallback for older versions of TF without tf.audio. + from tensorflow.python.ops import gen_audio_ops as audio_ops - if encoding is None: - encoding = 'wav' - if encoding != 'wav': - raise ValueError('Unknown encoding: %r' % encoding) - summary_metadata = metadata.create_summary_metadata( - display_name=None, - description=description, - encoding=metadata.Encoding.Value('WAV')) - inputs = [data, sample_rate, max_outputs, step] - # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback - summary_scope = ( - getattr(tf.summary.experimental, 'summary_scope', None) or - tf.summary.summary_scope) - with summary_scope( - name, 'audio_summary', values=inputs) as (tag, _): - # Defer audio encoding preprocessing by passing it as a callable to write(), - # wrapped in a LazyTensorCreator for backwards compatibility, so that we - # only do this work when summaries are actually written. - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - tf.debugging.assert_rank(data, 3) - tf.debugging.assert_non_negative(max_outputs) - limited_audio = data[:max_outputs] - encode_fn = functools.partial(audio_ops.encode_wav, - sample_rate=sample_rate) - encoded_audio = tf.map_fn(encode_fn, limited_audio, - dtype=tf.string, - name='encode_each_audio') - # Workaround for map_fn returning float dtype for an empty elems input. 
- encoded_audio = tf.cond( - tf.shape(input=encoded_audio)[0] > 0, - lambda: encoded_audio, lambda: tf.constant([], tf.string)) - limited_labels = tf.tile([''], tf.shape(input=limited_audio)[:1]) - return tf.transpose(a=tf.stack([encoded_audio, limited_labels])) + if encoding is None: + encoding = "wav" + if encoding != "wav": + raise ValueError("Unknown encoding: %r" % encoding) + summary_metadata = metadata.create_summary_metadata( + display_name=None, + description=description, + encoding=metadata.Encoding.Value("WAV"), + ) + inputs = [data, sample_rate, max_outputs, step] + # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback + summary_scope = ( + getattr(tf.summary.experimental, "summary_scope", None) + or tf.summary.summary_scope + ) + with summary_scope(name, "audio_summary", values=inputs) as (tag, _): + # Defer audio encoding preprocessing by passing it as a callable to write(), + # wrapped in a LazyTensorCreator for backwards compatibility, so that we + # only do this work when summaries are actually written. + @lazy_tensor_creator.LazyTensorCreator + def lazy_tensor(): + tf.debugging.assert_rank(data, 3) + tf.debugging.assert_non_negative(max_outputs) + limited_audio = data[:max_outputs] + encode_fn = functools.partial( + audio_ops.encode_wav, sample_rate=sample_rate + ) + encoded_audio = tf.map_fn( + encode_fn, + limited_audio, + dtype=tf.string, + name="encode_each_audio", + ) + # Workaround for map_fn returning float dtype for an empty elems input. + encoded_audio = tf.cond( + tf.shape(input=encoded_audio)[0] > 0, + lambda: encoded_audio, + lambda: tf.constant([], tf.string), + ) + limited_labels = tf.tile([""], tf.shape(input=limited_audio)[:1]) + return tf.transpose(a=tf.stack([encoded_audio, limited_labels])) - # To ensure that audio encoding logic is only executed when summaries - # are written, we pass callable to `tensor` parameter. 
- return tf.summary.write( - tag=tag, tensor=lazy_tensor, step=step, metadata=summary_metadata) + # To ensure that audio encoding logic is only executed when summaries + # are written, we pass callable to `tensor` parameter. + return tf.summary.write( + tag=tag, tensor=lazy_tensor, step=step, metadata=summary_metadata + ) diff --git a/tensorboard/plugins/base_plugin.py b/tensorboard/plugins/base_plugin.py index 2b87567e80..04f93dd00e 100644 --- a/tensorboard/plugins/base_plugin.py +++ b/tensorboard/plugins/base_plugin.py @@ -14,8 +14,8 @@ # ============================================================================== """TensorBoard Plugin abstract base class. -Every plugin in TensorBoard must extend and implement the abstract methods of -this base class. +Every plugin in TensorBoard must extend and implement the abstract +methods of this base class. """ from __future__ import absolute_import @@ -31,331 +31,339 @@ @six.add_metaclass(ABCMeta) class TBPlugin(object): - """TensorBoard plugin interface. + """TensorBoard plugin interface. + + Every plugin must extend from this class. + + Subclasses should have a trivial constructor that takes a TBContext + argument. Any operation that might throw an exception should either be + done lazily or made safe with a TBLoader subclass, so the plugin won't + negatively impact the rest of TensorBoard. + + Fields: + plugin_name: The plugin_name will also be a prefix in the http + handlers, e.g. `data/plugins/$PLUGIN_NAME/$HANDLER` The plugin + name must be unique for each registered plugin, or a ValueError + will be thrown when the application is constructed. The plugin + name must only contain characters among [A-Za-z0-9_.-], and must + be nonempty, or a ValueError will similarly be thrown. + """ - Every plugin must extend from this class. + plugin_name = None - Subclasses should have a trivial constructor that takes a TBContext - argument. 
Any operation that might throw an exception should either be - done lazily or made safe with a TBLoader subclass, so the plugin won't - negatively impact the rest of TensorBoard. + def __init__(self, context): + """Initializes this plugin. - Fields: - plugin_name: The plugin_name will also be a prefix in the http - handlers, e.g. `data/plugins/$PLUGIN_NAME/$HANDLER` The plugin - name must be unique for each registered plugin, or a ValueError - will be thrown when the application is constructed. The plugin - name must only contain characters among [A-Za-z0-9_.-], and must - be nonempty, or a ValueError will similarly be thrown. - """ + The default implementation does nothing. Subclasses are encouraged + to override this and save any necessary fields from the `context`. - plugin_name = None + Args: + context: A `base_plugin.TBContext` object. + """ + pass - def __init__(self, context): - """Initializes this plugin. + @abstractmethod + def get_plugin_apps(self): + """Returns a set of WSGI applications that the plugin implements. - The default implementation does nothing. Subclasses are encouraged - to override this and save any necessary fields from the `context`. + Each application gets registered with the tensorboard app and is served + under a prefix path that includes the name of the plugin. - Args: - context: A `base_plugin.TBContext` object. - """ - pass + Returns: + A dict mapping route paths to WSGI applications. Each route path + should include a leading slash. + """ + raise NotImplementedError() - @abstractmethod - def get_plugin_apps(self): - """Returns a set of WSGI applications that the plugin implements. + @abstractmethod + def is_active(self): + """Determines whether this plugin is active. - Each application gets registered with the tensorboard app and is served - under a prefix path that includes the name of the plugin. + A plugin may not be active for instance if it lacks relevant data. 
If a + plugin is inactive, the frontend may avoid issuing requests to its routes. - Returns: - A dict mapping route paths to WSGI applications. Each route path - should include a leading slash. - """ - raise NotImplementedError() + Returns: + A boolean value. Whether this plugin is active. + """ + raise NotImplementedError() - @abstractmethod - def is_active(self): - """Determines whether this plugin is active. + def frontend_metadata(self): + """Defines how the plugin will be displayed on the frontend. - A plugin may not be active for instance if it lacks relevant data. If a - plugin is inactive, the frontend may avoid issuing requests to its routes. + The base implementation returns a default value. Subclasses + should override this and specify either an `es_module_path` or + (for legacy plugins) an `element_name`, and are encouraged to + set any other relevant attributes. + """ + return FrontendMetadata() - Returns: - A boolean value. Whether this plugin is active. - """ - raise NotImplementedError() - def frontend_metadata(self): - """Defines how the plugin will be displayed on the frontend. +class FrontendMetadata(object): + """Metadata required to render a plugin on the frontend. - The base implementation returns a default value. Subclasses should - override this and specify either an `es_module_path` or (for legacy - plugins) an `element_name`, and are encouraged to set any other - relevant attributes. + Each argument to the constructor is publicly accessible under a + field of the same name. See constructor docs for further details. """ - return FrontendMetadata() - -class FrontendMetadata(object): - """Metadata required to render a plugin on the frontend. - - Each argument to the constructor is publicly accessible under a field - of the same name. See constructor docs for further details. 
- """ - - def __init__( - self, - disable_reload=None, - element_name=None, - es_module_path=None, - remove_dom=None, - tab_name=None, - is_ng_component=False, - ): - """Creates a `FrontendMetadata` value. - - The argument list is sorted and may be extended in the future; - therefore, callers must pass only named arguments to this - constructor. - - Args: - disable_reload: Whether to disable the reload button and - auto-reload timer. A `bool`; defaults to `False`. - element_name: For legacy plugins, name of the custom element - defining the plugin frontend: e.g., `"tf-scalar-dashboard"`. - A `str` or `None` (for iframed plugins). Mutually exclusive - with `es_module_path`. - es_module_path: ES module to use as an entry point to this plugin. - A `str` that is a key in the result of `get_plugin_apps()`, or - `None` for legacy plugins bundled with TensorBoard as part of - `webfiles.zip`. Mutually exclusive with legacy `element_name` - remove_dom: Whether to remove the plugin DOM when switching to a - different plugin, to trigger the Polymer 'detached' event. - A `bool`; defaults to `False`. - tab_name: Name to show in the menu item for this dashboard within - the navigation bar. May differ from the plugin name: for - instance, the tab name should not use underscores to separate - words. Should be a `str` or `None` (the default; indicates to - use the plugin name as the tab name). - is_ng_component: Set to `True` only for built-in Agnular plugins. - In this case, the `plugin_name` property of the Plugin, which is - mapped to the `id` property in JavaScript's `UiPluginMetadata` type, - is used to select the Angular component. A `True` value is mutually - exclusive with `element_name` and `es_module_path`. 
- """ - self._disable_reload = False if disable_reload is None else disable_reload - self._element_name = element_name - self._es_module_path = es_module_path - self._remove_dom = False if remove_dom is None else remove_dom - self._tab_name = tab_name - self._is_ng_component = is_ng_component - - @property - def disable_reload(self): - return self._disable_reload - - @property - def element_name(self): - return self._element_name - - @property - def is_ng_component(self): - return self._is_ng_component - - @property - def es_module_path(self): - return self._es_module_path - - @property - def remove_dom(self): - return self._remove_dom - - @property - def tab_name(self): - return self._tab_name - - def __eq__(self, other): - if not isinstance(other, FrontendMetadata): - return False - if self._disable_reload != other._disable_reload: - return False - if self._disable_reload != other._disable_reload: - return False - if self._element_name != other._element_name: - return False - if self._es_module_path != other._es_module_path: - return False - if self._remove_dom != other._remove_dom: - return False - if self._tab_name != other._tab_name: - return False - return True - - def __hash__(self): - return hash(( - self._disable_reload, - self._element_name, - self._es_module_path, - self._remove_dom, - self._tab_name, - self._is_ng_component, - )) - - def __repr__(self): - return "FrontendMetadata(%s)" % ", ".join(( - "disable_reload=%r" % self._disable_reload, - "element_name=%r" % self._element_name, - "es_module_path=%r" % self._es_module_path, - "remove_dom=%r" % self._remove_dom, - "tab_name=%r" % self._tab_name, - "is_ng_component=%r" % self._is_ng_component, - )) + def __init__( + self, + disable_reload=None, + element_name=None, + es_module_path=None, + remove_dom=None, + tab_name=None, + is_ng_component=False, + ): + """Creates a `FrontendMetadata` value. 
+
+        The argument list is sorted and may be extended in the future;
+        therefore, callers must pass only named arguments to this
+        constructor.
+
+        Args:
+            disable_reload: Whether to disable the reload button and
+                auto-reload timer. A `bool`; defaults to `False`.
+            element_name: For legacy plugins, name of the custom element
+                defining the plugin frontend: e.g., `"tf-scalar-dashboard"`.
+                A `str` or `None` (for iframed plugins). Mutually exclusive
+                with `es_module_path`.
+            es_module_path: ES module to use as an entry point to this plugin.
+                A `str` that is a key in the result of `get_plugin_apps()`, or
+                `None` for legacy plugins bundled with TensorBoard as part of
+                `webfiles.zip`. Mutually exclusive with legacy `element_name`
+            remove_dom: Whether to remove the plugin DOM when switching to a
+                different plugin, to trigger the Polymer 'detached' event.
+                A `bool`; defaults to `False`.
+            tab_name: Name to show in the menu item for this dashboard within
+                the navigation bar. May differ from the plugin name: for
+                instance, the tab name should not use underscores to separate
+                words. Should be a `str` or `None` (the default; indicates to
+                use the plugin name as the tab name).
+            is_ng_component: Set to `True` only for built-in Angular plugins.
+                In this case, the `plugin_name` property of the Plugin, which is
+                mapped to the `id` property in JavaScript's `UiPluginMetadata` type,
+                is used to select the Angular component. A `True` value is mutually
+                exclusive with `element_name` and `es_module_path`.
+        """
+        self._disable_reload = (
+            False if disable_reload is None else disable_reload
+        )
+        self._element_name = element_name
+        self._es_module_path = es_module_path
+        self._remove_dom = False if remove_dom is None else remove_dom
+        self._tab_name = tab_name
+        self._is_ng_component = is_ng_component
+
+    @property
+    def disable_reload(self):
+        return self._disable_reload
+
+    @property
+    def element_name(self):
+        return self._element_name
+
+    @property
+    def is_ng_component(self):
+        return self._is_ng_component
+
+    @property
+    def es_module_path(self):
+        return self._es_module_path
+
+    @property
+    def remove_dom(self):
+        return self._remove_dom
+
+    @property
+    def tab_name(self):
+        return self._tab_name
+
+    def __eq__(self, other):
+        if not isinstance(other, FrontendMetadata):
+            return False
+        if self._disable_reload != other._disable_reload:
+            return False
+        if self._is_ng_component != other._is_ng_component:
+            return False
+        if self._element_name != other._element_name:
+            return False
+        if self._es_module_path != other._es_module_path:
+            return False
+        if self._remove_dom != other._remove_dom:
+            return False
+        if self._tab_name != other._tab_name:
+            return False
+        return True
+
+    def __hash__(self):
+        return hash(
+            (
+                self._disable_reload,
+                self._element_name,
+                self._es_module_path,
+                self._remove_dom,
+                self._tab_name,
+                self._is_ng_component,
+            )
+        )
+
+    def __repr__(self):
+        return "FrontendMetadata(%s)" % ", ".join(
+            (
+                "disable_reload=%r" % self._disable_reload,
+                "element_name=%r" % self._element_name,
+                "es_module_path=%r" % self._es_module_path,
+                "remove_dom=%r" % self._remove_dom,
+                "tab_name=%r" % self._tab_name,
+                "is_ng_component=%r" % self._is_ng_component,
+            )
+        )


 class TBContext(object):
-    """Magic container of information passed from TensorBoard core to plugins.
-
-    A TBContext instance is passed to the constructor of a TBPlugin class. Plugins
-    are strongly encouraged to assume that any of these fields can be None.
In - cases when a field is considered mandatory by a plugin, it can either crash - with ValueError, or silently choose to disable itself by returning False from - its is_active method. - - All fields in this object are thread safe. - """ - - def __init__( - self, - assets_zip_provider=None, - data_provider=None, - db_connection_provider=None, - db_uri=None, - flags=None, - logdir=None, - multiplexer=None, - plugin_name_to_instance=None, - window_title=None): - """Instantiates magic container. - - The argument list is sorted and may be extended in the future; therefore, - callers must pass only named arguments to this constructor. - - Args: - assets_zip_provider: A function that returns a newly opened file handle - for a zip file containing all static assets. The file names inside the - zip file are considered absolute paths on the web server. The file - handle this function returns must be closed. It is assumed that you - will pass this file handle to zipfile.ZipFile. This zip file should - also have been created by the tensorboard_zip_file build rule. - data_provider: Instance of `tensorboard.data.provider.DataProvider`. May - be `None` if `flags.generic_data` is set to `"false"`. - db_connection_provider: Function taking no arguments that returns a - PEP-249 database Connection object, or None if multiplexer should be - used instead. The returned value must be closed, and is safe to use in - a `with` statement. It is also safe to assume that calling this - function is cheap. The returned connection must only be used by a - single thread. Things like connection pooling are considered - implementation details of the provider. - db_uri: The string db URI TensorBoard was started with. If this is set, - the logdir should be None. - flags: An object of the runtime flags provided to TensorBoard to their - values. - logdir: The string logging directory TensorBoard was started with. If this - is set, the db_uri should be None. 
- multiplexer: An EventMultiplexer with underlying TB data. Plugins should - copy this data over to the database when the db fields are set. - plugin_name_to_instance: A mapping between plugin name to instance. - Plugins may use this property to access other plugins. The context - object is passed to plugins during their construction, so a given - plugin may be absent from this mapping until it is registered. Plugin - logic should handle cases in which a plugin is absent from this - mapping, lest a KeyError is raised. - window_title: A string specifying the window title. + """Magic container of information passed from TensorBoard core to plugins. + + A TBContext instance is passed to the constructor of a TBPlugin class. Plugins + are strongly encouraged to assume that any of these fields can be None. In + cases when a field is considered mandatory by a plugin, it can either crash + with ValueError, or silently choose to disable itself by returning False from + its is_active method. + + All fields in this object are thread safe. """ - self.assets_zip_provider = assets_zip_provider - self.data_provider = data_provider - self.db_connection_provider = db_connection_provider - self.db_uri = db_uri - self.flags = flags - self.logdir = logdir - self.multiplexer = multiplexer - self.plugin_name_to_instance = plugin_name_to_instance - self.window_title = window_title + + def __init__( + self, + assets_zip_provider=None, + data_provider=None, + db_connection_provider=None, + db_uri=None, + flags=None, + logdir=None, + multiplexer=None, + plugin_name_to_instance=None, + window_title=None, + ): + """Instantiates magic container. + + The argument list is sorted and may be extended in the future; therefore, + callers must pass only named arguments to this constructor. + + Args: + assets_zip_provider: A function that returns a newly opened file handle + for a zip file containing all static assets. 
The file names inside the + zip file are considered absolute paths on the web server. The file + handle this function returns must be closed. It is assumed that you + will pass this file handle to zipfile.ZipFile. This zip file should + also have been created by the tensorboard_zip_file build rule. + data_provider: Instance of `tensorboard.data.provider.DataProvider`. May + be `None` if `flags.generic_data` is set to `"false"`. + db_connection_provider: Function taking no arguments that returns a + PEP-249 database Connection object, or None if multiplexer should be + used instead. The returned value must be closed, and is safe to use in + a `with` statement. It is also safe to assume that calling this + function is cheap. The returned connection must only be used by a + single thread. Things like connection pooling are considered + implementation details of the provider. + db_uri: The string db URI TensorBoard was started with. If this is set, + the logdir should be None. + flags: An object of the runtime flags provided to TensorBoard to their + values. + logdir: The string logging directory TensorBoard was started with. If this + is set, the db_uri should be None. + multiplexer: An EventMultiplexer with underlying TB data. Plugins should + copy this data over to the database when the db fields are set. + plugin_name_to_instance: A mapping between plugin name to instance. + Plugins may use this property to access other plugins. The context + object is passed to plugins during their construction, so a given + plugin may be absent from this mapping until it is registered. Plugin + logic should handle cases in which a plugin is absent from this + mapping, lest a KeyError is raised. + window_title: A string specifying the window title. 
+ """ + self.assets_zip_provider = assets_zip_provider + self.data_provider = data_provider + self.db_connection_provider = db_connection_provider + self.db_uri = db_uri + self.flags = flags + self.logdir = logdir + self.multiplexer = multiplexer + self.plugin_name_to_instance = plugin_name_to_instance + self.window_title = window_title class TBLoader(object): - """TBPlugin factory base class. + """TBPlugin factory base class. - Plugins can override this class to customize how a plugin is loaded at - startup. This might entail adding command-line arguments, checking if - optional dependencies are installed, and potentially also specializing - the plugin class at runtime. + Plugins can override this class to customize how a plugin is loaded at + startup. This might entail adding command-line arguments, checking if + optional dependencies are installed, and potentially also specializing + the plugin class at runtime. - When plugins use optional dependencies, the loader needs to be - specified in its own module. That way it's guaranteed to be - importable, even if the `TBPlugin` itself can't be imported. + When plugins use optional dependencies, the loader needs to be + specified in its own module. That way it's guaranteed to be + importable, even if the `TBPlugin` itself can't be imported. - Subclasses must have trivial constructors. - """ + Subclasses must have trivial constructors. + """ - def define_flags(self, parser): - """Adds plugin-specific CLI flags to parser. + def define_flags(self, parser): + """Adds plugin-specific CLI flags to parser. - The default behavior is to do nothing. + The default behavior is to do nothing. - When overriding this method, it's recommended that plugins call the - `parser.add_argument_group(plugin_name)` method for readability. No - flags should be specified that would cause `parse_args([])` to fail. 
+ When overriding this method, it's recommended that plugins call the + `parser.add_argument_group(plugin_name)` method for readability. No + flags should be specified that would cause `parse_args([])` to fail. - Args: - parser: The argument parsing object, which may be mutated. - """ - pass + Args: + parser: The argument parsing object, which may be mutated. + """ + pass - def fix_flags(self, flags): - """Allows flag values to be corrected or validated after parsing. + def fix_flags(self, flags): + """Allows flag values to be corrected or validated after parsing. - Args: - flags: The parsed argparse.Namespace object. + Args: + flags: The parsed argparse.Namespace object. - Raises: - base_plugin.FlagsError: If a flag is invalid or a required - flag is not passed. - """ - pass + Raises: + base_plugin.FlagsError: If a flag is invalid or a required + flag is not passed. + """ + pass - def load(self, context): - """Loads a TBPlugin instance during the setup phase. + def load(self, context): + """Loads a TBPlugin instance during the setup phase. - Args: - context: The TBContext instance. + Args: + context: The TBContext instance. - Returns: - A plugin instance or None if it could not be loaded. Loaders that return - None are skipped. + Returns: + A plugin instance or None if it could not be loaded. Loaders that return + None are skipped. - :type context: TBContext - :rtype: TBPlugin | None - """ - return None + :type context: TBContext + :rtype: TBPlugin | None + """ + return None class BasicLoader(TBLoader): - """Simple TBLoader that's sufficient for most plugins.""" + """Simple TBLoader that's sufficient for most plugins.""" - def __init__(self, plugin_class): - """Creates simple plugin instance maker. + def __init__(self, plugin_class): + """Creates simple plugin instance maker. 
- :param plugin_class: :class:`TBPlugin` - """ - self.plugin_class = plugin_class + :param plugin_class: :class:`TBPlugin` + """ + self.plugin_class = plugin_class - def load(self, context): - return self.plugin_class(context) + def load(self, context): + return self.plugin_class(context) class FlagsError(ValueError): - """Raised when a command line flag is not specified or is invalid.""" - pass + """Raised when a command line flag is not specified or is invalid.""" + + pass diff --git a/tensorboard/plugins/base_plugin_test.py b/tensorboard/plugins/base_plugin_test.py index 772fff2e05..dbb67fa037 100644 --- a/tensorboard/plugins/base_plugin_test.py +++ b/tensorboard/plugins/base_plugin_test.py @@ -23,50 +23,49 @@ class FrontendMetadataTest(tb_test.TestCase): + def _create_metadata(self): + return base_plugin.FrontendMetadata( + disable_reload="my disable_reload", + element_name="my element_name", + es_module_path="my es_module_path", + remove_dom="my remove_dom", + tab_name="my tab_name", + ) - def _create_metadata(self): - return base_plugin.FrontendMetadata( - disable_reload="my disable_reload", - element_name="my element_name", - es_module_path="my es_module_path", - remove_dom="my remove_dom", - tab_name="my tab_name", - ) + def test_basics(self): + md = self._create_metadata() + self.assertEqual(md.disable_reload, "my disable_reload") + self.assertEqual(md.element_name, "my element_name") + self.assertEqual(md.es_module_path, "my es_module_path") + self.assertEqual(md.remove_dom, "my remove_dom") + self.assertEqual(md.tab_name, "my tab_name") - def test_basics(self): - md = self._create_metadata() - self.assertEqual(md.disable_reload, "my disable_reload") - self.assertEqual(md.element_name, "my element_name") - self.assertEqual(md.es_module_path, "my es_module_path") - self.assertEqual(md.remove_dom, "my remove_dom") - self.assertEqual(md.tab_name, "my tab_name") + def test_repr(self): + repr_ = repr(self._create_metadata()) + self.assertIn(repr("my 
disable_reload"), repr_) + self.assertIn(repr("my element_name"), repr_) + self.assertIn(repr("my es_module_path"), repr_) + self.assertIn(repr("my remove_dom"), repr_) + self.assertIn(repr("my tab_name"), repr_) - def test_repr(self): - repr_ = repr(self._create_metadata()) - self.assertIn(repr("my disable_reload"), repr_) - self.assertIn(repr("my element_name"), repr_) - self.assertIn(repr("my es_module_path"), repr_) - self.assertIn(repr("my remove_dom"), repr_) - self.assertIn(repr("my tab_name"), repr_) + def test_eq(self): + md1 = base_plugin.FrontendMetadata(element_name="foo") + md2 = base_plugin.FrontendMetadata(element_name="foo") + md3 = base_plugin.FrontendMetadata(element_name="bar") + self.assertEqual(md1, md2) + self.assertNotEqual(md1, md3) + self.assertNotEqual(md1, "hmm") - def test_eq(self): - md1 = base_plugin.FrontendMetadata(element_name="foo") - md2 = base_plugin.FrontendMetadata(element_name="foo") - md3 = base_plugin.FrontendMetadata(element_name="bar") - self.assertEqual(md1, md2) - self.assertNotEqual(md1, md3) - self.assertNotEqual(md1, "hmm") + def test_hash(self): + md1 = base_plugin.FrontendMetadata(element_name="foo") + md2 = base_plugin.FrontendMetadata(element_name="foo") + md3 = base_plugin.FrontendMetadata(element_name="bar") + self.assertEqual(hash(md1), hash(md2)) + # The next check is technically not required by the `__hash__` + # contract, but _should_ pass; failure on this assertion would at + # least warrant some scrutiny. + self.assertNotEqual(hash(md1), hash(md3)) - def test_hash(self): - md1 = base_plugin.FrontendMetadata(element_name="foo") - md2 = base_plugin.FrontendMetadata(element_name="foo") - md3 = base_plugin.FrontendMetadata(element_name="bar") - self.assertEqual(hash(md1), hash(md2)) - # The next check is technically not required by the `__hash__` - # contract, but _should_ pass; failure on this assertion would at - # least warrant some scrutiny. 
- self.assertNotEqual(hash(md1), hash(md3)) - -if __name__ == '__main__': - tb_test.main() +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/plugins/beholder/__init__.py b/tensorboard/plugins/beholder/__init__.py index d5d71a2d9d..2c6234da13 100644 --- a/tensorboard/plugins/beholder/__init__.py +++ b/tensorboard/plugins/beholder/__init__.py @@ -14,10 +14,10 @@ # Only import Beholder API when tensorflow is available. try: - # pylint: disable=unused-import - import tensorflow + # pylint: disable=unused-import + import tensorflow except ImportError: - pass + pass else: - from tensorboard.plugins.beholder.beholder import Beholder - from tensorboard.plugins.beholder.beholder import BeholderHook + from tensorboard.plugins.beholder.beholder import Beholder + from tensorboard.plugins.beholder.beholder import BeholderHook diff --git a/tensorboard/plugins/beholder/beholder.py b/tensorboard/plugins/beholder/beholder.py index 10ea9ddc8f..30f445f4d7 100644 --- a/tensorboard/plugins/beholder/beholder.py +++ b/tensorboard/plugins/beholder/beholder.py @@ -23,10 +23,19 @@ import tensorflow as tf from tensorboard.plugins.beholder import im_util -from tensorboard.plugins.beholder.file_system_tools import read_pickle,\ - write_pickle, write_file -from tensorboard.plugins.beholder.shared_config import PLUGIN_NAME, TAG_NAME,\ - SUMMARY_FILENAME, DEFAULT_CONFIG, CONFIG_FILENAME, SUMMARY_COLLECTION_KEY_NAME +from tensorboard.plugins.beholder.file_system_tools import ( + read_pickle, + write_pickle, + write_file, +) +from tensorboard.plugins.beholder.shared_config import ( + PLUGIN_NAME, + TAG_NAME, + SUMMARY_FILENAME, + DEFAULT_CONFIG, + CONFIG_FILENAME, + SUMMARY_COLLECTION_KEY_NAME, +) from tensorboard.plugins.beholder import video_writing from tensorboard.plugins.beholder.visualizer import Visualizer from tensorboard.util import tb_logging @@ -36,187 +45,192 @@ class Beholder(object): - - def __init__(self, logdir): - self.PLUGIN_LOGDIR = logdir + '/plugins/' + 
PLUGIN_NAME - - self.is_recording = False - self.video_writer = video_writing.VideoWriter( - self.PLUGIN_LOGDIR, - outputs=[ - video_writing.FFmpegVideoOutput, - video_writing.PNGVideoOutput]) - - self.frame_placeholder = tf.compat.v1.placeholder(tf.uint8, [None, None, None]) - self.summary_op = tf.compat.v1.summary.tensor_summary(TAG_NAME, - self.frame_placeholder, - collections=[ - SUMMARY_COLLECTION_KEY_NAME - ]) - - self.last_image_shape = [] - self.last_update_time = time.time() - self.config_last_modified_time = -1 - self.previous_config = dict(DEFAULT_CONFIG) - - if not tf.io.gfile.exists(self.PLUGIN_LOGDIR + '/config.pkl'): - tf.io.gfile.makedirs(self.PLUGIN_LOGDIR) - write_pickle(DEFAULT_CONFIG, '{}/{}'.format(self.PLUGIN_LOGDIR, - CONFIG_FILENAME)) - - self.visualizer = Visualizer(self.PLUGIN_LOGDIR) - - - def _get_config(self): - '''Reads the config file from disk or creates a new one.''' - filename = '{}/{}'.format(self.PLUGIN_LOGDIR, CONFIG_FILENAME) - modified_time = os.path.getmtime(filename) - - if modified_time != self.config_last_modified_time: - config = read_pickle(filename, default=self.previous_config) - self.previous_config = config - else: - config = self.previous_config - - self.config_last_modified_time = modified_time - return config - - - def _write_summary(self, session, frame): - '''Writes the frame to disk as a tensor summary.''' - summary = session.run(self.summary_op, feed_dict={ - self.frame_placeholder: frame - }) - path = '{}/{}'.format(self.PLUGIN_LOGDIR, SUMMARY_FILENAME) - write_file(summary, path) - - - def _get_final_image(self, session, config, arrays=None, frame=None): - if config['values'] == 'frames': - if frame is None: - final_image = im_util.get_image_relative_to_script('frame-missing.png') - else: - frame = frame() if callable(frame) else frame - final_image = im_util.scale_image_for_display(frame) - - elif config['values'] == 'arrays': - if arrays is None: - final_image = 
im_util.get_image_relative_to_script('arrays-missing.png') - # TODO: hack to clear the info. Should be cleaner. - self.visualizer._save_section_info([], []) - else: - final_image = self.visualizer.build_frame(arrays) - - elif config['values'] == 'trainable_variables': - arrays = [session.run(x) for x in tf.compat.v1.trainable_variables()] - final_image = self.visualizer.build_frame(arrays) - - if len(final_image.shape) == 2: - # Map grayscale images to 3D tensors. - final_image = np.expand_dims(final_image, -1) - - return final_image - - - def _enough_time_has_passed(self, FPS): - '''For limiting how often frames are computed.''' - if FPS == 0: - return False - else: - earliest_time = self.last_update_time + (1.0 / FPS) - return time.time() >= earliest_time - - - def _update_frame(self, session, arrays, frame, config): - final_image = self._get_final_image(session, config, arrays, frame) - self._write_summary(session, final_image) - self.last_image_shape = final_image.shape - - return final_image - - - def _update_recording(self, frame, config): - '''Adds a frame to the current video output.''' - # pylint: disable=redefined-variable-type - should_record = config['is_recording'] - - if should_record: - if not self.is_recording: - self.is_recording = True - logger.info( - 'Starting recording using %s', - self.video_writer.current_output().name()) - self.video_writer.write_frame(frame) - elif self.is_recording: - self.is_recording = False - self.video_writer.finish() - logger.info('Finished recording') - - - # TODO: blanket try and except for production? I don't someone's script to die - # after weeks of running because of a visualization. - def update(self, session, arrays=None, frame=None): - '''Creates a frame and writes it to disk. - - Args: - arrays: a list of np arrays. Use the "custom" option in the client. - frame: a 2D np array. This way the plugin can be used for video of any - kind, not just the visualization that comes with the plugin. 
- - frame can also be a function, which only is evaluated when the - "frame" option is selected by the client. - ''' - new_config = self._get_config() - - if self._enough_time_has_passed(self.previous_config['FPS']): - self.visualizer.update(new_config) - self.last_update_time = time.time() - final_image = self._update_frame(session, arrays, frame, new_config) - self._update_recording(final_image, new_config) - - - ############################################################################## - - @staticmethod - def gradient_helper(optimizer, loss, var_list=None): - '''A helper to get the gradients out at each step. - - Args: - optimizer: the optimizer op. - loss: the op that computes your loss value. - - Returns: the gradient tensors and the train_step op. - ''' - if var_list is None: - var_list = tf.compat.v1.trainable_variables() - - grads_and_vars = optimizer.compute_gradients(loss, var_list=var_list) - grads = [pair[0] for pair in grads_and_vars] - - return grads, optimizer.apply_gradients(grads_and_vars) + def __init__(self, logdir): + self.PLUGIN_LOGDIR = logdir + "/plugins/" + PLUGIN_NAME + + self.is_recording = False + self.video_writer = video_writing.VideoWriter( + self.PLUGIN_LOGDIR, + outputs=[ + video_writing.FFmpegVideoOutput, + video_writing.PNGVideoOutput, + ], + ) + + self.frame_placeholder = tf.compat.v1.placeholder( + tf.uint8, [None, None, None] + ) + self.summary_op = tf.compat.v1.summary.tensor_summary( + TAG_NAME, + self.frame_placeholder, + collections=[SUMMARY_COLLECTION_KEY_NAME], + ) + + self.last_image_shape = [] + self.last_update_time = time.time() + self.config_last_modified_time = -1 + self.previous_config = dict(DEFAULT_CONFIG) + + if not tf.io.gfile.exists(self.PLUGIN_LOGDIR + "/config.pkl"): + tf.io.gfile.makedirs(self.PLUGIN_LOGDIR) + write_pickle( + DEFAULT_CONFIG, + "{}/{}".format(self.PLUGIN_LOGDIR, CONFIG_FILENAME), + ) + + self.visualizer = Visualizer(self.PLUGIN_LOGDIR) + + def _get_config(self): + """Reads the config file 
from disk or creates a new one.""" + filename = "{}/{}".format(self.PLUGIN_LOGDIR, CONFIG_FILENAME) + modified_time = os.path.getmtime(filename) + + if modified_time != self.config_last_modified_time: + config = read_pickle(filename, default=self.previous_config) + self.previous_config = config + else: + config = self.previous_config + + self.config_last_modified_time = modified_time + return config + + def _write_summary(self, session, frame): + """Writes the frame to disk as a tensor summary.""" + summary = session.run( + self.summary_op, feed_dict={self.frame_placeholder: frame} + ) + path = "{}/{}".format(self.PLUGIN_LOGDIR, SUMMARY_FILENAME) + write_file(summary, path) + + def _get_final_image(self, session, config, arrays=None, frame=None): + if config["values"] == "frames": + if frame is None: + final_image = im_util.get_image_relative_to_script( + "frame-missing.png" + ) + else: + frame = frame() if callable(frame) else frame + final_image = im_util.scale_image_for_display(frame) + + elif config["values"] == "arrays": + if arrays is None: + final_image = im_util.get_image_relative_to_script( + "arrays-missing.png" + ) + # TODO: hack to clear the info. Should be cleaner. + self.visualizer._save_section_info([], []) + else: + final_image = self.visualizer.build_frame(arrays) + + elif config["values"] == "trainable_variables": + arrays = [ + session.run(x) for x in tf.compat.v1.trainable_variables() + ] + final_image = self.visualizer.build_frame(arrays) + + if len(final_image.shape) == 2: + # Map grayscale images to 3D tensors. 
+ final_image = np.expand_dims(final_image, -1) + + return final_image + + def _enough_time_has_passed(self, FPS): + """For limiting how often frames are computed.""" + if FPS == 0: + return False + else: + earliest_time = self.last_update_time + (1.0 / FPS) + return time.time() >= earliest_time + + def _update_frame(self, session, arrays, frame, config): + final_image = self._get_final_image(session, config, arrays, frame) + self._write_summary(session, final_image) + self.last_image_shape = final_image.shape + + return final_image + + def _update_recording(self, frame, config): + """Adds a frame to the current video output.""" + # pylint: disable=redefined-variable-type + should_record = config["is_recording"] + + if should_record: + if not self.is_recording: + self.is_recording = True + logger.info( + "Starting recording using %s", + self.video_writer.current_output().name(), + ) + self.video_writer.write_frame(frame) + elif self.is_recording: + self.is_recording = False + self.video_writer.finish() + logger.info("Finished recording") + + # TODO: blanket try and except for production? I don't someone's script to die + # after weeks of running because of a visualization. + def update(self, session, arrays=None, frame=None): + """Creates a frame and writes it to disk. + + Args: + arrays: a list of np arrays. Use the "custom" option in the client. + frame: a 2D np array. This way the plugin can be used for video of any + kind, not just the visualization that comes with the plugin. + + frame can also be a function, which only is evaluated when the + "frame" option is selected by the client. 
+ """ + new_config = self._get_config() + + if self._enough_time_has_passed(self.previous_config["FPS"]): + self.visualizer.update(new_config) + self.last_update_time = time.time() + final_image = self._update_frame(session, arrays, frame, new_config) + self._update_recording(final_image, new_config) + + ############################################################################## + + @staticmethod + def gradient_helper(optimizer, loss, var_list=None): + """A helper to get the gradients out at each step. + + Args: + optimizer: the optimizer op. + loss: the op that computes your loss value. + + Returns: the gradient tensors and the train_step op. + """ + if var_list is None: + var_list = tf.compat.v1.trainable_variables() + + grads_and_vars = optimizer.compute_gradients(loss, var_list=var_list) + grads = [pair[0] for pair in grads_and_vars] + + return grads, optimizer.apply_gradients(grads_and_vars) class BeholderHook(tf.estimator.SessionRunHook): - """SessionRunHook implementation that runs Beholder every step. - - Convenient when using tf.train.MonitoredSession: - ```python - beholder_hook = BeholderHook(LOG_DIRECTORY) - with MonitoredSession(..., hooks=[beholder_hook]) as sess: - sess.run(train_op) - ``` - """ - def __init__(self, logdir): - """Creates new Hook instance - - Args: - logdir: Directory where Beholder should write data. + """SessionRunHook implementation that runs Beholder every step. + + Convenient when using tf.train.MonitoredSession: + ```python + beholder_hook = BeholderHook(LOG_DIRECTORY) + with MonitoredSession(..., hooks=[beholder_hook]) as sess: + sess.run(train_op) + ``` """ - self._logdir = logdir - self.beholder = None - def begin(self): - self.beholder = Beholder(self._logdir) + def __init__(self, logdir): + """Creates new Hook instance. + + Args: + logdir: Directory where Beholder should write data. 
+ """ + self._logdir = logdir + self.beholder = None + + def begin(self): + self.beholder = Beholder(self._logdir) - def after_run(self, run_context, unused_run_values): - self.beholder.update(run_context.session) + def after_run(self, run_context, unused_run_values): + self.beholder.update(run_context.session) diff --git a/tensorboard/plugins/beholder/beholder_demo.py b/tensorboard/plugins/beholder/beholder_demo.py index 583e991415..ccec17bca8 100644 --- a/tensorboard/plugins/beholder/beholder_demo.py +++ b/tensorboard/plugins/beholder/beholder_demo.py @@ -32,186 +32,229 @@ FLAGS = None -LOG_DIRECTORY = '/tmp/beholder-demo' +LOG_DIRECTORY = "/tmp/beholder-demo" + def train(): - mnist_data = mnist.input_data.read_data_sets( - FLAGS.data_dir, one_hot=True, fake_data=FLAGS.fake_data) - - sess = tf.compat.v1.InteractiveSession() - - with tf.name_scope('input'): - x = tf.compat.v1.placeholder(tf.float32, [None, 784], name='x-input') - y_ = tf.compat.v1.placeholder(tf.float32, [None, 10], name='y-input') - - with tf.name_scope('input_reshape'): - image_shaped_input = tf.reshape(x, [-1, 28, 28, 1]) - tf.compat.v1.summary.image('input', image_shaped_input, 10) - - def weight_variable(shape): - """Create a weight variable with appropriate initialization.""" - initial = tf.random.truncated_normal(shape, stddev=0.01) - return tf.Variable(initial) - - def bias_variable(shape): - """Create a bias variable with appropriate initialization.""" - initial = tf.constant(0.1, shape=shape) - return tf.Variable(initial) - - def variable_summaries(var): - """Attach a lot of summaries to a Tensor (for TensorBoard visualization).""" - with tf.name_scope('summaries'): - mean = tf.reduce_mean(input_tensor=var) - tf.compat.v1.summary.scalar('mean', mean) - with tf.name_scope('stddev'): - stddev = tf.sqrt(tf.reduce_mean(input_tensor=tf.square(var - mean))) - tf.compat.v1.summary.scalar('stddev', stddev) - tf.compat.v1.summary.scalar('max', tf.reduce_max(input_tensor=var)) - 
tf.compat.v1.summary.scalar('min', tf.reduce_min(input_tensor=var)) - tf.compat.v1.summary.histogram('histogram', var) - - def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu): - """Reusable code for making a simple neural net layer. - - It does a matrix multiply, bias add, and then uses ReLU to nonlinearize. - It also sets up name scoping so that the resultant graph is easy to read, - and adds a number of summary ops. - """ - # Adding a name scope ensures logical grouping of the layers in the graph. - with tf.name_scope(layer_name): - # This Variable will hold the state of the weights for the layer - with tf.name_scope('weights'): - weights = weight_variable([input_dim, output_dim]) - variable_summaries(weights) - with tf.name_scope('biases'): - biases = bias_variable([output_dim]) - variable_summaries(biases) - with tf.name_scope('Wx_plus_b'): - preactivate = tf.matmul(input_tensor, weights) + biases - tf.compat.v1.summary.histogram('pre_activations', preactivate) - activations = act(preactivate, name='activation') - tf.compat.v1.summary.histogram('activations', activations) - return activations - - #conv1 - kernel = tf.Variable(tf.random.truncated_normal([5, 5, 1, 10], - dtype=tf.float32, - stddev=1e-1), - name='conv-weights') - conv = tf.nn.conv2d(image_shaped_input, kernel, [1, 1, 1, 1], padding='VALID') - biases_init = tf.constant( - 0.0, shape=[kernel.get_shape().as_list()[-1]], dtype=tf.float32) - biases = tf.Variable(biases_init, trainable=True, name='biases') - out = tf.nn.bias_add(conv, biases) - conv1 = tf.nn.relu(out, name='relu') - - #conv2 - kernel2_init = tf.random.truncated_normal( - [3, 3, 10, 20], dtype=tf.float32, stddev=1e-1) - kernel2 = tf.Variable(kernel2_init, name='conv-weights2') - conv2 = tf.nn.conv2d(conv1, kernel2, [1, 1, 1, 1], padding='VALID') - biases2_init = tf.constant( - 0.0, shape=[kernel2.get_shape().as_list()[-1]], dtype=tf.float32) - biases2 = tf.Variable(biases2_init, trainable=True, name='biases') - 
out2 = tf.nn.bias_add(conv2, biases2) - conv2 = tf.nn.relu(out2, name='relu') - - flattened = tf.contrib.layers.flatten(conv2) - hidden1 = nn_layer( - flattened, flattened.get_shape().as_list()[1], 10, 'layer1') - - with tf.name_scope('dropout'): - keep_prob = tf.compat.v1.placeholder(tf.float32) - tf.compat.v1.summary.scalar('dropout_keep_probability', keep_prob) - dropped = tf.nn.dropout(hidden1, 1 - keep_prob) - - y = nn_layer(dropped, 10, 10, 'layer2', act=tf.identity) - - with tf.name_scope('cross_entropy'): - diff = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y) - with tf.name_scope('total'): - cross_entropy = tf.reduce_mean(input_tensor=diff) - tf.compat.v1.summary.scalar('cross_entropy', cross_entropy) - - with tf.name_scope('train'): - optimizer = tf.compat.v1.train.AdamOptimizer(FLAGS.learning_rate) - gradients, train_step = beholder_lib.Beholder.gradient_helper( - optimizer, cross_entropy) - - with tf.name_scope('accuracy'): - with tf.name_scope('correct_prediction'): - correct_prediction = tf.equal(tf.argmax(input=y, axis=1), tf.argmax(input=y_, axis=1)) - with tf.name_scope('accuracy'): - accuracy = tf.reduce_mean(input_tensor=tf.cast(correct_prediction, tf.float32)) - tf.compat.v1.summary.scalar('accuracy', accuracy) - - merged = tf.compat.v1.summary.merge_all() - train_writer = tf.summary.FileWriter(LOG_DIRECTORY + '/train', sess.graph) - test_writer = tf.summary.FileWriter(LOG_DIRECTORY + '/test') - tf.compat.v1.global_variables_initializer().run() - - beholder = beholder_lib.Beholder(logdir=LOG_DIRECTORY) - - - def feed_dict(is_train): - if is_train or FLAGS.fake_data: - xs, ys = mnist_data.train.next_batch(100, fake_data=FLAGS.fake_data) - k = FLAGS.dropout - else: - xs, ys = mnist_data.test.images, mnist_data.test.labels - k = 1.0 - return {x: xs, y_: ys, keep_prob: k} - - for i in range(FLAGS.max_steps): - summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False)) - test_writer.add_summary(summary, i) - print('Accuracy 
at step %s: %s' % (i, acc)) - print('i', i) - feed_dictionary = feed_dict(True) - summary, gradient_arrays, activations, _ = sess.run( - [ - merged, - gradients, - [image_shaped_input, conv1, conv2, hidden1, y], - train_step - ], - feed_dict=feed_dictionary) - first_of_batch = sess.run(x, feed_dict=feed_dictionary)[0].reshape(28, 28) - beholder.update( - session=sess, - arrays=activations + [first_of_batch] + gradient_arrays, - frame=first_of_batch, + mnist_data = mnist.input_data.read_data_sets( + FLAGS.data_dir, one_hot=True, fake_data=FLAGS.fake_data + ) + + sess = tf.compat.v1.InteractiveSession() + + with tf.name_scope("input"): + x = tf.compat.v1.placeholder(tf.float32, [None, 784], name="x-input") + y_ = tf.compat.v1.placeholder(tf.float32, [None, 10], name="y-input") + + with tf.name_scope("input_reshape"): + image_shaped_input = tf.reshape(x, [-1, 28, 28, 1]) + tf.compat.v1.summary.image("input", image_shaped_input, 10) + + def weight_variable(shape): + """Create a weight variable with appropriate initialization.""" + initial = tf.random.truncated_normal(shape, stddev=0.01) + return tf.Variable(initial) + + def bias_variable(shape): + """Create a bias variable with appropriate initialization.""" + initial = tf.constant(0.1, shape=shape) + return tf.Variable(initial) + + def variable_summaries(var): + """Attach a lot of summaries to a Tensor (for TensorBoard + visualization).""" + with tf.name_scope("summaries"): + mean = tf.reduce_mean(input_tensor=var) + tf.compat.v1.summary.scalar("mean", mean) + with tf.name_scope("stddev"): + stddev = tf.sqrt( + tf.reduce_mean(input_tensor=tf.square(var - mean)) + ) + tf.compat.v1.summary.scalar("stddev", stddev) + tf.compat.v1.summary.scalar("max", tf.reduce_max(input_tensor=var)) + tf.compat.v1.summary.scalar("min", tf.reduce_min(input_tensor=var)) + tf.compat.v1.summary.histogram("histogram", var) + + def nn_layer( + input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu + ): + """Reusable code for making a 
simple neural net layer. + + It does a matrix multiply, bias add, and then uses ReLU to + nonlinearize. It also sets up name scoping so that the resultant + graph is easy to read, and adds a number of summary ops. + """ + # Adding a name scope ensures logical grouping of the layers in the graph. + with tf.name_scope(layer_name): + # This Variable will hold the state of the weights for the layer + with tf.name_scope("weights"): + weights = weight_variable([input_dim, output_dim]) + variable_summaries(weights) + with tf.name_scope("biases"): + biases = bias_variable([output_dim]) + variable_summaries(biases) + with tf.name_scope("Wx_plus_b"): + preactivate = tf.matmul(input_tensor, weights) + biases + tf.compat.v1.summary.histogram("pre_activations", preactivate) + activations = act(preactivate, name="activation") + tf.compat.v1.summary.histogram("activations", activations) + return activations + + # conv1 + kernel = tf.Variable( + tf.random.truncated_normal( + [5, 5, 1, 10], dtype=tf.float32, stddev=1e-1 + ), + name="conv-weights", + ) + conv = tf.nn.conv2d( + image_shaped_input, kernel, [1, 1, 1, 1], padding="VALID" + ) + biases_init = tf.constant( + 0.0, shape=[kernel.get_shape().as_list()[-1]], dtype=tf.float32 ) - train_writer.add_summary(summary, i) + biases = tf.Variable(biases_init, trainable=True, name="biases") + out = tf.nn.bias_add(conv, biases) + conv1 = tf.nn.relu(out, name="relu") + + # conv2 + kernel2_init = tf.random.truncated_normal( + [3, 3, 10, 20], dtype=tf.float32, stddev=1e-1 + ) + kernel2 = tf.Variable(kernel2_init, name="conv-weights2") + conv2 = tf.nn.conv2d(conv1, kernel2, [1, 1, 1, 1], padding="VALID") + biases2_init = tf.constant( + 0.0, shape=[kernel2.get_shape().as_list()[-1]], dtype=tf.float32 + ) + biases2 = tf.Variable(biases2_init, trainable=True, name="biases") + out2 = tf.nn.bias_add(conv2, biases2) + conv2 = tf.nn.relu(out2, name="relu") + + flattened = tf.contrib.layers.flatten(conv2) + hidden1 = nn_layer( + flattened, 
flattened.get_shape().as_list()[1], 10, "layer1" + ) + + with tf.name_scope("dropout"): + keep_prob = tf.compat.v1.placeholder(tf.float32) + tf.compat.v1.summary.scalar("dropout_keep_probability", keep_prob) + dropped = tf.nn.dropout(hidden1, 1 - keep_prob) + + y = nn_layer(dropped, 10, 10, "layer2", act=tf.identity) + + with tf.name_scope("cross_entropy"): + diff = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y) + with tf.name_scope("total"): + cross_entropy = tf.reduce_mean(input_tensor=diff) + tf.compat.v1.summary.scalar("cross_entropy", cross_entropy) + + with tf.name_scope("train"): + optimizer = tf.compat.v1.train.AdamOptimizer(FLAGS.learning_rate) + gradients, train_step = beholder_lib.Beholder.gradient_helper( + optimizer, cross_entropy + ) + + with tf.name_scope("accuracy"): + with tf.name_scope("correct_prediction"): + correct_prediction = tf.equal( + tf.argmax(input=y, axis=1), tf.argmax(input=y_, axis=1) + ) + with tf.name_scope("accuracy"): + accuracy = tf.reduce_mean( + input_tensor=tf.cast(correct_prediction, tf.float32) + ) + tf.compat.v1.summary.scalar("accuracy", accuracy) + + merged = tf.compat.v1.summary.merge_all() + train_writer = tf.summary.FileWriter(LOG_DIRECTORY + "/train", sess.graph) + test_writer = tf.summary.FileWriter(LOG_DIRECTORY + "/test") + tf.compat.v1.global_variables_initializer().run() + + beholder = beholder_lib.Beholder(logdir=LOG_DIRECTORY) + + def feed_dict(is_train): + if is_train or FLAGS.fake_data: + xs, ys = mnist_data.train.next_batch(100, fake_data=FLAGS.fake_data) + k = FLAGS.dropout + else: + xs, ys = mnist_data.test.images, mnist_data.test.labels + k = 1.0 + return {x: xs, y_: ys, keep_prob: k} + + for i in range(FLAGS.max_steps): + summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False)) + test_writer.add_summary(summary, i) + print("Accuracy at step %s: %s" % (i, acc)) + print("i", i) + feed_dictionary = feed_dict(True) + summary, gradient_arrays, activations, _ = sess.run( + [ + 
merged, + gradients, + [image_shaped_input, conv1, conv2, hidden1, y], + train_step, + ], + feed_dict=feed_dictionary, + ) + first_of_batch = sess.run(x, feed_dict=feed_dictionary)[0].reshape( + 28, 28 + ) + beholder.update( + session=sess, + arrays=activations + [first_of_batch] + gradient_arrays, + frame=first_of_batch, + ) + train_writer.add_summary(summary, i) + + train_writer.close() + test_writer.close() - train_writer.close() - test_writer.close() def main(_): - if not tf.io.gfile.exists(LOG_DIRECTORY): - tf.io.gfile.makedirs(LOG_DIRECTORY) - train() - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--fake_data', nargs='?', const=True, type=bool, - default=False, - help='If true, uses fake data for unit testing.') - parser.add_argument('--max_steps', type=int, default=1000000, - help='Number of steps to run trainer.') - parser.add_argument('--learning_rate', type=float, default=0.001, - help='Initial learning rate') - parser.add_argument('--dropout', type=float, default=0.9, - help='Keep probability for training dropout.') - parser.add_argument( - '--data_dir', - type=str, - default='/tmp/tensorflow/mnist/input_data', - help='Directory for storing input data') - parser.add_argument( - '--log_dir', - type=str, - default='/tmp/tensorflow/mnist/logs/mnist_with_summaries', - help='Summaries log directory') - FLAGS, unparsed = parser.parse_known_args() - app.run(main=main, argv=[sys.argv[0]] + unparsed) + if not tf.io.gfile.exists(LOG_DIRECTORY): + tf.io.gfile.makedirs(LOG_DIRECTORY) + train() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--fake_data", + nargs="?", + const=True, + type=bool, + default=False, + help="If true, uses fake data for unit testing.", + ) + parser.add_argument( + "--max_steps", + type=int, + default=1000000, + help="Number of steps to run trainer.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=0.001, + help="Initial learning 
rate", + ) + parser.add_argument( + "--dropout", + type=float, + default=0.9, + help="Keep probability for training dropout.", + ) + parser.add_argument( + "--data_dir", + type=str, + default="/tmp/tensorflow/mnist/input_data", + help="Directory for storing input data", + ) + parser.add_argument( + "--log_dir", + type=str, + default="/tmp/tensorflow/mnist/logs/mnist_with_summaries", + help="Summaries log directory", + ) + FLAGS, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorboard/plugins/beholder/beholder_plugin.py b/tensorboard/plugins/beholder/beholder_plugin.py index b01a192de8..87a382688b 100644 --- a/tensorboard/plugins/beholder/beholder_plugin.py +++ b/tensorboard/plugins/beholder/beholder_plugin.py @@ -35,162 +35,180 @@ logger = tb_logging.get_logger() -DEFAULT_INFO = [{ - 'name': 'Waiting for data...', -}] +DEFAULT_INFO = [{"name": "Waiting for data...",}] class BeholderPlugin(base_plugin.TBPlugin): - """ - TensorBoard plugin for viewing model data as a live video during training. 
- """ - - plugin_name = shared_config.PLUGIN_NAME - - def __init__(self, context): - self._lock = threading.Lock() - self._MULTIPLEXER = context.multiplexer - self.PLUGIN_LOGDIR = pau.PluginDirectory( - context.logdir, shared_config.PLUGIN_NAME) - self.FPS = 10 - self._config_file_lock = threading.Lock() - self.most_recent_frame = None - self.most_recent_info = DEFAULT_INFO - - def get_plugin_apps(self): - return { - '/change-config': self._serve_change_config, - '/beholder-frame': self._serve_beholder_frame, - '/section-info': self._serve_section_info, - '/ping': self._serve_ping, - '/is-active': self._serve_is_active, - } - - def is_active(self): - summary_filename = '{}/{}'.format( - self.PLUGIN_LOGDIR, shared_config.SUMMARY_FILENAME) - info_filename = '{}/{}'.format( - self.PLUGIN_LOGDIR, shared_config.SECTION_INFO_FILENAME) - return tf.io.gfile.exists(summary_filename) and\ - tf.io.gfile.exists(info_filename) - - def frontend_metadata(self): - # TODO(#2338): Keep this in sync with the `registerDashboard` call - # on the frontend until that call is removed. - return base_plugin.FrontendMetadata( - element_name='tf-beholder-dashboard', - remove_dom=True, - ) - - def is_config_writable(self): - try: - if not tf.io.gfile.exists(self.PLUGIN_LOGDIR): - tf.io.gfile.makedirs(self.PLUGIN_LOGDIR) - config_filename = '{}/{}'.format( - self.PLUGIN_LOGDIR, shared_config.CONFIG_FILENAME) - with self._config_file_lock: - file_system_tools.write_pickle( - file_system_tools.read_pickle( - config_filename, shared_config.DEFAULT_CONFIG), - config_filename) - return True - except tf.errors.PermissionDeniedError as e: - logger.warn( - 'Unable to write Beholder config, controls will be disabled: %s', e) - return False - - @wrappers.Request.application - def _serve_is_active(self, request): - is_active = self.is_active() - # If the plugin isn't active, don't check if the configuration is writable - # since that will leave traces on disk; instead return True (the default). 
- is_config_writable = self.is_config_writable() if is_active else True - response = { - 'is_active': is_active, - 'is_config_writable': is_config_writable, - } - return http_util.Respond(request, response, 'application/json') - - def _fetch_current_frame(self): - path = '{}/{}'.format(self.PLUGIN_LOGDIR, shared_config.SUMMARY_FILENAME) - with self._lock: - try: - frame = file_system_tools.read_tensor_summary(path).astype(np.uint8) - self.most_recent_frame = frame - return frame - except (message.DecodeError, IOError, tf.errors.NotFoundError): - if self.most_recent_frame is None: - self.most_recent_frame = im_util.get_image_relative_to_script( - 'no-data.png') - return self.most_recent_frame - - @wrappers.Request.application - def _serve_change_config(self, request): - config = {} - - for key, value in request.form.items(): - try: - config[key] = int(value) - except ValueError: - if value == 'false': - config[key] = False - elif value == 'true': - config[key] = True - else: - config[key] = value - - self.FPS = config['FPS'] - - with self._config_file_lock: - file_system_tools.write_pickle( - config, - '{}/{}'.format(self.PLUGIN_LOGDIR, shared_config.CONFIG_FILENAME)) - return http_util.Respond(request, {'config': config}, 'application/json') - - @wrappers.Request.application - def _serve_section_info(self, request): - path = '{}/{}'.format( - self.PLUGIN_LOGDIR, shared_config.SECTION_INFO_FILENAME) - with self._lock: - default = self.most_recent_info - info = file_system_tools.read_pickle(path, default=default) - if info is not default: - with self._lock: - self.most_recent_info = info - return http_util.Respond(request, info, 'application/json') - - def _frame_generator(self): - while True: - last_duration = 0 - - if self.FPS == 0: - continue - else: - time.sleep(max(0, 1/(self.FPS) - last_duration)) - - start_time = time.time() - array = self._fetch_current_frame() - image_bytes = encoder.encode_png(array) - - frame_text = b'--frame\r\n' - content_type = 
b'Content-Type: image/png\r\n\r\n' - - response_content = frame_text + content_type + image_bytes + b'\r\n\r\n' - - last_duration = time.time() - start_time - yield response_content - - - @wrappers.Request.application - def _serve_beholder_frame(self, request): # pylint: disable=unused-argument - # Thanks to Miguel Grinberg for this technique: - # https://blog.miguelgrinberg.com/post/video-streaming-with-flask - mimetype = 'multipart/x-mixed-replace; boundary=frame' - return http_util.Respond(request, - self._frame_generator(), - mimetype, - code=200) - - @wrappers.Request.application - def _serve_ping(self, request): # pylint: disable=unused-argument - return http_util.Respond(request, {'status': 'alive'}, 'application/json') + """TensorBoard plugin for viewing model data as a live video during + training.""" + + plugin_name = shared_config.PLUGIN_NAME + + def __init__(self, context): + self._lock = threading.Lock() + self._MULTIPLEXER = context.multiplexer + self.PLUGIN_LOGDIR = pau.PluginDirectory( + context.logdir, shared_config.PLUGIN_NAME + ) + self.FPS = 10 + self._config_file_lock = threading.Lock() + self.most_recent_frame = None + self.most_recent_info = DEFAULT_INFO + + def get_plugin_apps(self): + return { + "/change-config": self._serve_change_config, + "/beholder-frame": self._serve_beholder_frame, + "/section-info": self._serve_section_info, + "/ping": self._serve_ping, + "/is-active": self._serve_is_active, + } + + def is_active(self): + summary_filename = "{}/{}".format( + self.PLUGIN_LOGDIR, shared_config.SUMMARY_FILENAME + ) + info_filename = "{}/{}".format( + self.PLUGIN_LOGDIR, shared_config.SECTION_INFO_FILENAME + ) + return tf.io.gfile.exists(summary_filename) and tf.io.gfile.exists( + info_filename + ) + + def frontend_metadata(self): + # TODO(#2338): Keep this in sync with the `registerDashboard` call + # on the frontend until that call is removed. 
+ return base_plugin.FrontendMetadata( + element_name="tf-beholder-dashboard", remove_dom=True, + ) + + def is_config_writable(self): + try: + if not tf.io.gfile.exists(self.PLUGIN_LOGDIR): + tf.io.gfile.makedirs(self.PLUGIN_LOGDIR) + config_filename = "{}/{}".format( + self.PLUGIN_LOGDIR, shared_config.CONFIG_FILENAME + ) + with self._config_file_lock: + file_system_tools.write_pickle( + file_system_tools.read_pickle( + config_filename, shared_config.DEFAULT_CONFIG + ), + config_filename, + ) + return True + except tf.errors.PermissionDeniedError as e: + logger.warn( + "Unable to write Beholder config, controls will be disabled: %s", + e, + ) + return False + + @wrappers.Request.application + def _serve_is_active(self, request): + is_active = self.is_active() + # If the plugin isn't active, don't check if the configuration is writable + # since that will leave traces on disk; instead return True (the default). + is_config_writable = self.is_config_writable() if is_active else True + response = { + "is_active": is_active, + "is_config_writable": is_config_writable, + } + return http_util.Respond(request, response, "application/json") + + def _fetch_current_frame(self): + path = "{}/{}".format( + self.PLUGIN_LOGDIR, shared_config.SUMMARY_FILENAME + ) + with self._lock: + try: + frame = file_system_tools.read_tensor_summary(path).astype( + np.uint8 + ) + self.most_recent_frame = frame + return frame + except (message.DecodeError, IOError, tf.errors.NotFoundError): + if self.most_recent_frame is None: + self.most_recent_frame = im_util.get_image_relative_to_script( + "no-data.png" + ) + return self.most_recent_frame + + @wrappers.Request.application + def _serve_change_config(self, request): + config = {} + + for key, value in request.form.items(): + try: + config[key] = int(value) + except ValueError: + if value == "false": + config[key] = False + elif value == "true": + config[key] = True + else: + config[key] = value + + self.FPS = config["FPS"] + + with 
self._config_file_lock: + file_system_tools.write_pickle( + config, + "{}/{}".format( + self.PLUGIN_LOGDIR, shared_config.CONFIG_FILENAME + ), + ) + return http_util.Respond( + request, {"config": config}, "application/json" + ) + + @wrappers.Request.application + def _serve_section_info(self, request): + path = "{}/{}".format( + self.PLUGIN_LOGDIR, shared_config.SECTION_INFO_FILENAME + ) + with self._lock: + default = self.most_recent_info + info = file_system_tools.read_pickle(path, default=default) + if info is not default: + with self._lock: + self.most_recent_info = info + return http_util.Respond(request, info, "application/json") + + def _frame_generator(self): + while True: + last_duration = 0 + + if self.FPS == 0: + continue + else: + time.sleep(max(0, 1 / (self.FPS) - last_duration)) + + start_time = time.time() + array = self._fetch_current_frame() + image_bytes = encoder.encode_png(array) + + frame_text = b"--frame\r\n" + content_type = b"Content-Type: image/png\r\n\r\n" + + response_content = ( + frame_text + content_type + image_bytes + b"\r\n\r\n" + ) + + last_duration = time.time() - start_time + yield response_content + + @wrappers.Request.application + def _serve_beholder_frame(self, request): # pylint: disable=unused-argument + # Thanks to Miguel Grinberg for this technique: + # https://blog.miguelgrinberg.com/post/video-streaming-with-flask + mimetype = "multipart/x-mixed-replace; boundary=frame" + return http_util.Respond( + request, self._frame_generator(), mimetype, code=200 + ) + + @wrappers.Request.application + def _serve_ping(self, request): # pylint: disable=unused-argument + return http_util.Respond( + request, {"status": "alive"}, "application/json" + ) diff --git a/tensorboard/plugins/beholder/beholder_plugin_loader.py b/tensorboard/plugins/beholder/beholder_plugin_loader.py index 66dd00fca4..c7762a34a4 100644 --- a/tensorboard/plugins/beholder/beholder_plugin_loader.py +++ b/tensorboard/plugins/beholder/beholder_plugin_loader.py @@ 
-22,25 +22,26 @@ class BeholderPluginLoader(base_plugin.TBLoader): - """BeholderPlugin factory. + """BeholderPlugin factory. - This class checks for `tensorflow` install and dependency. - """ + This class checks for `tensorflow` install and dependency. + """ - def load(self, context): - """Returns the plugin, if possible. + def load(self, context): + """Returns the plugin, if possible. - Args: - context: The TBContext flags. + Args: + context: The TBContext flags. - Returns: - A BeholderPlugin instance or None if it couldn't be loaded. - """ - try: - # pylint: disable=unused-import - import tensorflow - except ImportError: - return - - from tensorboard.plugins.beholder.beholder_plugin import BeholderPlugin - return BeholderPlugin(context) + Returns: + A BeholderPlugin instance or None if it couldn't be loaded. + """ + try: + # pylint: disable=unused-import + import tensorflow + except ImportError: + return + + from tensorboard.plugins.beholder.beholder_plugin import BeholderPlugin + + return BeholderPlugin(context) diff --git a/tensorboard/plugins/beholder/beholder_test.py b/tensorboard/plugins/beholder/beholder_test.py index efd0195526..ba3608ac30 100644 --- a/tensorboard/plugins/beholder/beholder_test.py +++ b/tensorboard/plugins/beholder/beholder_test.py @@ -24,30 +24,29 @@ class BeholderTest(tf.test.TestCase): - - def setUp(self): - self._current_time_seconds = 1554232353 - - def advance_time(self, delta_seconds): - self._current_time_seconds += delta_seconds - - def get_time(self): - return self._current_time_seconds - - @test_util.run_v1_only("Requires sessions") - def test_update(self): - with tf.test.mock.patch("time.time", self.get_time): - b = beholder.Beholder(self.get_temp_dir()) - array = np.array([[0, 1], [1, 0]]) - with tf.Session() as sess: - v = tf.Variable([0, 0], trainable=True) - sess.run(tf.global_variables_initializer()) - # Beholder only updates if at least one frame has passed. 
The - # default FPS value is 10, but in any case 100 seconds ought to - # do it. - self.advance_time(delta_seconds=100) - b.update(session=sess, arrays=[array]) + def setUp(self): + self._current_time_seconds = 1554232353 + + def advance_time(self, delta_seconds): + self._current_time_seconds += delta_seconds + + def get_time(self): + return self._current_time_seconds + + @test_util.run_v1_only("Requires sessions") + def test_update(self): + with tf.test.mock.patch("time.time", self.get_time): + b = beholder.Beholder(self.get_temp_dir()) + array = np.array([[0, 1], [1, 0]]) + with tf.Session() as sess: + v = tf.Variable([0, 0], trainable=True) + sess.run(tf.global_variables_initializer()) + # Beholder only updates if at least one frame has passed. The + # default FPS value is 10, but in any case 100 seconds ought to + # do it. + self.advance_time(delta_seconds=100) + b.update(session=sess, arrays=[array]) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/beholder/colormaps.py b/tensorboard/plugins/beholder/colormaps.py index 9dca3b97ed..099c308d77 100644 --- a/tensorboard/plugins/beholder/colormaps.py +++ b/tensorboard/plugins/beholder/colormaps.py @@ -47,1045 +47,1057 @@ def _convert(colormap_data): - colormap = (np.array(colormap_data) * 255).astype(np.uint8) - colormap.setflags(write=False) - return colormap + colormap = (np.array(colormap_data) * 255).astype(np.uint8) + colormap.setflags(write=False) + return colormap magma = _convert( - ((0.001462, 0.000466, 0.013866), - (0.002258, 0.001295, 0.018331), - (0.003279, 0.002305, 0.023708), - (0.004512, 0.003490, 0.029965), - (0.005950, 0.004843, 0.037130), - (0.007588, 0.006356, 0.044973), - (0.009426, 0.008022, 0.052844), - (0.011465, 0.009828, 0.060750), - (0.013708, 0.011771, 0.068667), - (0.016156, 0.013840, 0.076603), - (0.018815, 0.016026, 0.084584), - (0.021692, 0.018320, 0.092610), - (0.024792, 0.020715, 0.100676), - (0.028123, 0.023201, 0.108787), - (0.031696, 
0.025765, 0.116965), - (0.035520, 0.028397, 0.125209), - (0.039608, 0.031090, 0.133515), - (0.043830, 0.033830, 0.141886), - (0.048062, 0.036607, 0.150327), - (0.052320, 0.039407, 0.158841), - (0.056615, 0.042160, 0.167446), - (0.060949, 0.044794, 0.176129), - (0.065330, 0.047318, 0.184892), - (0.069764, 0.049726, 0.193735), - (0.074257, 0.052017, 0.202660), - (0.078815, 0.054184, 0.211667), - (0.083446, 0.056225, 0.220755), - (0.088155, 0.058133, 0.229922), - (0.092949, 0.059904, 0.239164), - (0.097833, 0.061531, 0.248477), - (0.102815, 0.063010, 0.257854), - (0.107899, 0.064335, 0.267289), - (0.113094, 0.065492, 0.276784), - (0.118405, 0.066479, 0.286321), - (0.123833, 0.067295, 0.295879), - (0.129380, 0.067935, 0.305443), - (0.135053, 0.068391, 0.315000), - (0.140858, 0.068654, 0.324538), - (0.146785, 0.068738, 0.334011), - (0.152839, 0.068637, 0.343404), - (0.159018, 0.068354, 0.352688), - (0.165308, 0.067911, 0.361816), - (0.171713, 0.067305, 0.370771), - (0.178212, 0.066576, 0.379497), - (0.184801, 0.065732, 0.387973), - (0.191460, 0.064818, 0.396152), - (0.198177, 0.063862, 0.404009), - (0.204935, 0.062907, 0.411514), - (0.211718, 0.061992, 0.418647), - (0.218512, 0.061158, 0.425392), - (0.225302, 0.060445, 0.431742), - (0.232077, 0.059889, 0.437695), - (0.238826, 0.059517, 0.443256), - (0.245543, 0.059352, 0.448436), - (0.252220, 0.059415, 0.453248), - (0.258857, 0.059706, 0.457710), - (0.265447, 0.060237, 0.461840), - (0.271994, 0.060994, 0.465660), - (0.278493, 0.061978, 0.469190), - (0.284951, 0.063168, 0.472451), - (0.291366, 0.064553, 0.475462), - (0.297740, 0.066117, 0.478243), - (0.304081, 0.067835, 0.480812), - (0.310382, 0.069702, 0.483186), - (0.316654, 0.071690, 0.485380), - (0.322899, 0.073782, 0.487408), - (0.329114, 0.075972, 0.489287), - (0.335308, 0.078236, 0.491024), - (0.341482, 0.080564, 0.492631), - (0.347636, 0.082946, 0.494121), - (0.353773, 0.085373, 0.495501), - (0.359898, 0.087831, 0.496778), - (0.366012, 0.090314, 0.497960), - 
(0.372116, 0.092816, 0.499053), - (0.378211, 0.095332, 0.500067), - (0.384299, 0.097855, 0.501002), - (0.390384, 0.100379, 0.501864), - (0.396467, 0.102902, 0.502658), - (0.402548, 0.105420, 0.503386), - (0.408629, 0.107930, 0.504052), - (0.414709, 0.110431, 0.504662), - (0.420791, 0.112920, 0.505215), - (0.426877, 0.115395, 0.505714), - (0.432967, 0.117855, 0.506160), - (0.439062, 0.120298, 0.506555), - (0.445163, 0.122724, 0.506901), - (0.451271, 0.125132, 0.507198), - (0.457386, 0.127522, 0.507448), - (0.463508, 0.129893, 0.507652), - (0.469640, 0.132245, 0.507809), - (0.475780, 0.134577, 0.507921), - (0.481929, 0.136891, 0.507989), - (0.488088, 0.139186, 0.508011), - (0.494258, 0.141462, 0.507988), - (0.500438, 0.143719, 0.507920), - (0.506629, 0.145958, 0.507806), - (0.512831, 0.148179, 0.507648), - (0.519045, 0.150383, 0.507443), - (0.525270, 0.152569, 0.507192), - (0.531507, 0.154739, 0.506895), - (0.537755, 0.156894, 0.506551), - (0.544015, 0.159033, 0.506159), - (0.550287, 0.161158, 0.505719), - (0.556571, 0.163269, 0.505230), - (0.562866, 0.165368, 0.504692), - (0.569172, 0.167454, 0.504105), - (0.575490, 0.169530, 0.503466), - (0.581819, 0.171596, 0.502777), - (0.588158, 0.173652, 0.502035), - (0.594508, 0.175701, 0.501241), - (0.600868, 0.177743, 0.500394), - (0.607238, 0.179779, 0.499492), - (0.613617, 0.181811, 0.498536), - (0.620005, 0.183840, 0.497524), - (0.626401, 0.185867, 0.496456), - (0.632805, 0.187893, 0.495332), - (0.639216, 0.189921, 0.494150), - (0.645633, 0.191952, 0.492910), - (0.652056, 0.193986, 0.491611), - (0.658483, 0.196027, 0.490253), - (0.664915, 0.198075, 0.488836), - (0.671349, 0.200133, 0.487358), - (0.677786, 0.202203, 0.485819), - (0.684224, 0.204286, 0.484219), - (0.690661, 0.206384, 0.482558), - (0.697098, 0.208501, 0.480835), - (0.703532, 0.210638, 0.479049), - (0.709962, 0.212797, 0.477201), - (0.716387, 0.214982, 0.475290), - (0.722805, 0.217194, 0.473316), - (0.729216, 0.219437, 0.471279), - (0.735616, 0.221713, 
0.469180), - (0.742004, 0.224025, 0.467018), - (0.748378, 0.226377, 0.464794), - (0.754737, 0.228772, 0.462509), - (0.761077, 0.231214, 0.460162), - (0.767398, 0.233705, 0.457755), - (0.773695, 0.236249, 0.455289), - (0.779968, 0.238851, 0.452765), - (0.786212, 0.241514, 0.450184), - (0.792427, 0.244242, 0.447543), - (0.798608, 0.247040, 0.444848), - (0.804752, 0.249911, 0.442102), - (0.810855, 0.252861, 0.439305), - (0.816914, 0.255895, 0.436461), - (0.822926, 0.259016, 0.433573), - (0.828886, 0.262229, 0.430644), - (0.834791, 0.265540, 0.427671), - (0.840636, 0.268953, 0.424666), - (0.846416, 0.272473, 0.421631), - (0.852126, 0.276106, 0.418573), - (0.857763, 0.279857, 0.415496), - (0.863320, 0.283729, 0.412403), - (0.868793, 0.287728, 0.409303), - (0.874176, 0.291859, 0.406205), - (0.879464, 0.296125, 0.403118), - (0.884651, 0.300530, 0.400047), - (0.889731, 0.305079, 0.397002), - (0.894700, 0.309773, 0.393995), - (0.899552, 0.314616, 0.391037), - (0.904281, 0.319610, 0.388137), - (0.908884, 0.324755, 0.385308), - (0.913354, 0.330052, 0.382563), - (0.917689, 0.335500, 0.379915), - (0.921884, 0.341098, 0.377376), - (0.925937, 0.346844, 0.374959), - (0.929845, 0.352734, 0.372677), - (0.933606, 0.358764, 0.370541), - (0.937221, 0.364929, 0.368567), - (0.940687, 0.371224, 0.366762), - (0.944006, 0.377643, 0.365136), - (0.947180, 0.384178, 0.363701), - (0.950210, 0.390820, 0.362468), - (0.953099, 0.397563, 0.361438), - (0.955849, 0.404400, 0.360619), - (0.958464, 0.411324, 0.360014), - (0.960949, 0.418323, 0.359630), - (0.963310, 0.425390, 0.359469), - (0.965549, 0.432519, 0.359529), - (0.967671, 0.439703, 0.359810), - (0.969680, 0.446936, 0.360311), - (0.971582, 0.454210, 0.361030), - (0.973381, 0.461520, 0.361965), - (0.975082, 0.468861, 0.363111), - (0.976690, 0.476226, 0.364466), - (0.978210, 0.483612, 0.366025), - (0.979645, 0.491014, 0.367783), - (0.981000, 0.498428, 0.369734), - (0.982279, 0.505851, 0.371874), - (0.983485, 0.513280, 0.374198), - (0.984622, 
0.520713, 0.376698), - (0.985693, 0.528148, 0.379371), - (0.986700, 0.535582, 0.382210), - (0.987646, 0.543015, 0.385210), - (0.988533, 0.550446, 0.388365), - (0.989363, 0.557873, 0.391671), - (0.990138, 0.565296, 0.395122), - (0.990871, 0.572706, 0.398714), - (0.991558, 0.580107, 0.402441), - (0.992196, 0.587502, 0.406299), - (0.992785, 0.594891, 0.410283), - (0.993326, 0.602275, 0.414390), - (0.993834, 0.609644, 0.418613), - (0.994309, 0.616999, 0.422950), - (0.994738, 0.624350, 0.427397), - (0.995122, 0.631696, 0.431951), - (0.995480, 0.639027, 0.436607), - (0.995810, 0.646344, 0.441361), - (0.996096, 0.653659, 0.446213), - (0.996341, 0.660969, 0.451160), - (0.996580, 0.668256, 0.456192), - (0.996775, 0.675541, 0.461314), - (0.996925, 0.682828, 0.466526), - (0.997077, 0.690088, 0.471811), - (0.997186, 0.697349, 0.477182), - (0.997254, 0.704611, 0.482635), - (0.997325, 0.711848, 0.488154), - (0.997351, 0.719089, 0.493755), - (0.997351, 0.726324, 0.499428), - (0.997341, 0.733545, 0.505167), - (0.997285, 0.740772, 0.510983), - (0.997228, 0.747981, 0.516859), - (0.997138, 0.755190, 0.522806), - (0.997019, 0.762398, 0.528821), - (0.996898, 0.769591, 0.534892), - (0.996727, 0.776795, 0.541039), - (0.996571, 0.783977, 0.547233), - (0.996369, 0.791167, 0.553499), - (0.996162, 0.798348, 0.559820), - (0.995932, 0.805527, 0.566202), - (0.995680, 0.812706, 0.572645), - (0.995424, 0.819875, 0.579140), - (0.995131, 0.827052, 0.585701), - (0.994851, 0.834213, 0.592307), - (0.994524, 0.841387, 0.598983), - (0.994222, 0.848540, 0.605696), - (0.993866, 0.855711, 0.612482), - (0.993545, 0.862859, 0.619299), - (0.993170, 0.870024, 0.626189), - (0.992831, 0.877168, 0.633109), - (0.992440, 0.884330, 0.640099), - (0.992089, 0.891470, 0.647116), - (0.991688, 0.898627, 0.654202), - (0.991332, 0.905763, 0.661309), - (0.990930, 0.912915, 0.668481), - (0.990570, 0.920049, 0.675675), - (0.990175, 0.927196, 0.682926), - (0.989815, 0.934329, 0.690198), - (0.989434, 0.941470, 0.697519), - 
(0.989077, 0.948604, 0.704863), - (0.988717, 0.955742, 0.712242), - (0.988367, 0.962878, 0.719649), - (0.988033, 0.970012, 0.727077), - (0.987691, 0.977154, 0.734536), - (0.987387, 0.984288, 0.742002), - (0.987053, 0.991438, 0.749504))) + ( + (0.001462, 0.000466, 0.013866), + (0.002258, 0.001295, 0.018331), + (0.003279, 0.002305, 0.023708), + (0.004512, 0.003490, 0.029965), + (0.005950, 0.004843, 0.037130), + (0.007588, 0.006356, 0.044973), + (0.009426, 0.008022, 0.052844), + (0.011465, 0.009828, 0.060750), + (0.013708, 0.011771, 0.068667), + (0.016156, 0.013840, 0.076603), + (0.018815, 0.016026, 0.084584), + (0.021692, 0.018320, 0.092610), + (0.024792, 0.020715, 0.100676), + (0.028123, 0.023201, 0.108787), + (0.031696, 0.025765, 0.116965), + (0.035520, 0.028397, 0.125209), + (0.039608, 0.031090, 0.133515), + (0.043830, 0.033830, 0.141886), + (0.048062, 0.036607, 0.150327), + (0.052320, 0.039407, 0.158841), + (0.056615, 0.042160, 0.167446), + (0.060949, 0.044794, 0.176129), + (0.065330, 0.047318, 0.184892), + (0.069764, 0.049726, 0.193735), + (0.074257, 0.052017, 0.202660), + (0.078815, 0.054184, 0.211667), + (0.083446, 0.056225, 0.220755), + (0.088155, 0.058133, 0.229922), + (0.092949, 0.059904, 0.239164), + (0.097833, 0.061531, 0.248477), + (0.102815, 0.063010, 0.257854), + (0.107899, 0.064335, 0.267289), + (0.113094, 0.065492, 0.276784), + (0.118405, 0.066479, 0.286321), + (0.123833, 0.067295, 0.295879), + (0.129380, 0.067935, 0.305443), + (0.135053, 0.068391, 0.315000), + (0.140858, 0.068654, 0.324538), + (0.146785, 0.068738, 0.334011), + (0.152839, 0.068637, 0.343404), + (0.159018, 0.068354, 0.352688), + (0.165308, 0.067911, 0.361816), + (0.171713, 0.067305, 0.370771), + (0.178212, 0.066576, 0.379497), + (0.184801, 0.065732, 0.387973), + (0.191460, 0.064818, 0.396152), + (0.198177, 0.063862, 0.404009), + (0.204935, 0.062907, 0.411514), + (0.211718, 0.061992, 0.418647), + (0.218512, 0.061158, 0.425392), + (0.225302, 0.060445, 0.431742), + (0.232077, 0.059889, 
0.437695), + (0.238826, 0.059517, 0.443256), + (0.245543, 0.059352, 0.448436), + (0.252220, 0.059415, 0.453248), + (0.258857, 0.059706, 0.457710), + (0.265447, 0.060237, 0.461840), + (0.271994, 0.060994, 0.465660), + (0.278493, 0.061978, 0.469190), + (0.284951, 0.063168, 0.472451), + (0.291366, 0.064553, 0.475462), + (0.297740, 0.066117, 0.478243), + (0.304081, 0.067835, 0.480812), + (0.310382, 0.069702, 0.483186), + (0.316654, 0.071690, 0.485380), + (0.322899, 0.073782, 0.487408), + (0.329114, 0.075972, 0.489287), + (0.335308, 0.078236, 0.491024), + (0.341482, 0.080564, 0.492631), + (0.347636, 0.082946, 0.494121), + (0.353773, 0.085373, 0.495501), + (0.359898, 0.087831, 0.496778), + (0.366012, 0.090314, 0.497960), + (0.372116, 0.092816, 0.499053), + (0.378211, 0.095332, 0.500067), + (0.384299, 0.097855, 0.501002), + (0.390384, 0.100379, 0.501864), + (0.396467, 0.102902, 0.502658), + (0.402548, 0.105420, 0.503386), + (0.408629, 0.107930, 0.504052), + (0.414709, 0.110431, 0.504662), + (0.420791, 0.112920, 0.505215), + (0.426877, 0.115395, 0.505714), + (0.432967, 0.117855, 0.506160), + (0.439062, 0.120298, 0.506555), + (0.445163, 0.122724, 0.506901), + (0.451271, 0.125132, 0.507198), + (0.457386, 0.127522, 0.507448), + (0.463508, 0.129893, 0.507652), + (0.469640, 0.132245, 0.507809), + (0.475780, 0.134577, 0.507921), + (0.481929, 0.136891, 0.507989), + (0.488088, 0.139186, 0.508011), + (0.494258, 0.141462, 0.507988), + (0.500438, 0.143719, 0.507920), + (0.506629, 0.145958, 0.507806), + (0.512831, 0.148179, 0.507648), + (0.519045, 0.150383, 0.507443), + (0.525270, 0.152569, 0.507192), + (0.531507, 0.154739, 0.506895), + (0.537755, 0.156894, 0.506551), + (0.544015, 0.159033, 0.506159), + (0.550287, 0.161158, 0.505719), + (0.556571, 0.163269, 0.505230), + (0.562866, 0.165368, 0.504692), + (0.569172, 0.167454, 0.504105), + (0.575490, 0.169530, 0.503466), + (0.581819, 0.171596, 0.502777), + (0.588158, 0.173652, 0.502035), + (0.594508, 0.175701, 0.501241), + (0.600868, 
0.177743, 0.500394), + (0.607238, 0.179779, 0.499492), + (0.613617, 0.181811, 0.498536), + (0.620005, 0.183840, 0.497524), + (0.626401, 0.185867, 0.496456), + (0.632805, 0.187893, 0.495332), + (0.639216, 0.189921, 0.494150), + (0.645633, 0.191952, 0.492910), + (0.652056, 0.193986, 0.491611), + (0.658483, 0.196027, 0.490253), + (0.664915, 0.198075, 0.488836), + (0.671349, 0.200133, 0.487358), + (0.677786, 0.202203, 0.485819), + (0.684224, 0.204286, 0.484219), + (0.690661, 0.206384, 0.482558), + (0.697098, 0.208501, 0.480835), + (0.703532, 0.210638, 0.479049), + (0.709962, 0.212797, 0.477201), + (0.716387, 0.214982, 0.475290), + (0.722805, 0.217194, 0.473316), + (0.729216, 0.219437, 0.471279), + (0.735616, 0.221713, 0.469180), + (0.742004, 0.224025, 0.467018), + (0.748378, 0.226377, 0.464794), + (0.754737, 0.228772, 0.462509), + (0.761077, 0.231214, 0.460162), + (0.767398, 0.233705, 0.457755), + (0.773695, 0.236249, 0.455289), + (0.779968, 0.238851, 0.452765), + (0.786212, 0.241514, 0.450184), + (0.792427, 0.244242, 0.447543), + (0.798608, 0.247040, 0.444848), + (0.804752, 0.249911, 0.442102), + (0.810855, 0.252861, 0.439305), + (0.816914, 0.255895, 0.436461), + (0.822926, 0.259016, 0.433573), + (0.828886, 0.262229, 0.430644), + (0.834791, 0.265540, 0.427671), + (0.840636, 0.268953, 0.424666), + (0.846416, 0.272473, 0.421631), + (0.852126, 0.276106, 0.418573), + (0.857763, 0.279857, 0.415496), + (0.863320, 0.283729, 0.412403), + (0.868793, 0.287728, 0.409303), + (0.874176, 0.291859, 0.406205), + (0.879464, 0.296125, 0.403118), + (0.884651, 0.300530, 0.400047), + (0.889731, 0.305079, 0.397002), + (0.894700, 0.309773, 0.393995), + (0.899552, 0.314616, 0.391037), + (0.904281, 0.319610, 0.388137), + (0.908884, 0.324755, 0.385308), + (0.913354, 0.330052, 0.382563), + (0.917689, 0.335500, 0.379915), + (0.921884, 0.341098, 0.377376), + (0.925937, 0.346844, 0.374959), + (0.929845, 0.352734, 0.372677), + (0.933606, 0.358764, 0.370541), + (0.937221, 0.364929, 0.368567), + 
(0.940687, 0.371224, 0.366762), + (0.944006, 0.377643, 0.365136), + (0.947180, 0.384178, 0.363701), + (0.950210, 0.390820, 0.362468), + (0.953099, 0.397563, 0.361438), + (0.955849, 0.404400, 0.360619), + (0.958464, 0.411324, 0.360014), + (0.960949, 0.418323, 0.359630), + (0.963310, 0.425390, 0.359469), + (0.965549, 0.432519, 0.359529), + (0.967671, 0.439703, 0.359810), + (0.969680, 0.446936, 0.360311), + (0.971582, 0.454210, 0.361030), + (0.973381, 0.461520, 0.361965), + (0.975082, 0.468861, 0.363111), + (0.976690, 0.476226, 0.364466), + (0.978210, 0.483612, 0.366025), + (0.979645, 0.491014, 0.367783), + (0.981000, 0.498428, 0.369734), + (0.982279, 0.505851, 0.371874), + (0.983485, 0.513280, 0.374198), + (0.984622, 0.520713, 0.376698), + (0.985693, 0.528148, 0.379371), + (0.986700, 0.535582, 0.382210), + (0.987646, 0.543015, 0.385210), + (0.988533, 0.550446, 0.388365), + (0.989363, 0.557873, 0.391671), + (0.990138, 0.565296, 0.395122), + (0.990871, 0.572706, 0.398714), + (0.991558, 0.580107, 0.402441), + (0.992196, 0.587502, 0.406299), + (0.992785, 0.594891, 0.410283), + (0.993326, 0.602275, 0.414390), + (0.993834, 0.609644, 0.418613), + (0.994309, 0.616999, 0.422950), + (0.994738, 0.624350, 0.427397), + (0.995122, 0.631696, 0.431951), + (0.995480, 0.639027, 0.436607), + (0.995810, 0.646344, 0.441361), + (0.996096, 0.653659, 0.446213), + (0.996341, 0.660969, 0.451160), + (0.996580, 0.668256, 0.456192), + (0.996775, 0.675541, 0.461314), + (0.996925, 0.682828, 0.466526), + (0.997077, 0.690088, 0.471811), + (0.997186, 0.697349, 0.477182), + (0.997254, 0.704611, 0.482635), + (0.997325, 0.711848, 0.488154), + (0.997351, 0.719089, 0.493755), + (0.997351, 0.726324, 0.499428), + (0.997341, 0.733545, 0.505167), + (0.997285, 0.740772, 0.510983), + (0.997228, 0.747981, 0.516859), + (0.997138, 0.755190, 0.522806), + (0.997019, 0.762398, 0.528821), + (0.996898, 0.769591, 0.534892), + (0.996727, 0.776795, 0.541039), + (0.996571, 0.783977, 0.547233), + (0.996369, 0.791167, 
0.553499), + (0.996162, 0.798348, 0.559820), + (0.995932, 0.805527, 0.566202), + (0.995680, 0.812706, 0.572645), + (0.995424, 0.819875, 0.579140), + (0.995131, 0.827052, 0.585701), + (0.994851, 0.834213, 0.592307), + (0.994524, 0.841387, 0.598983), + (0.994222, 0.848540, 0.605696), + (0.993866, 0.855711, 0.612482), + (0.993545, 0.862859, 0.619299), + (0.993170, 0.870024, 0.626189), + (0.992831, 0.877168, 0.633109), + (0.992440, 0.884330, 0.640099), + (0.992089, 0.891470, 0.647116), + (0.991688, 0.898627, 0.654202), + (0.991332, 0.905763, 0.661309), + (0.990930, 0.912915, 0.668481), + (0.990570, 0.920049, 0.675675), + (0.990175, 0.927196, 0.682926), + (0.989815, 0.934329, 0.690198), + (0.989434, 0.941470, 0.697519), + (0.989077, 0.948604, 0.704863), + (0.988717, 0.955742, 0.712242), + (0.988367, 0.962878, 0.719649), + (0.988033, 0.970012, 0.727077), + (0.987691, 0.977154, 0.734536), + (0.987387, 0.984288, 0.742002), + (0.987053, 0.991438, 0.749504), + ) +) inferno = _convert( - ((0.001462, 0.000466, 0.013866), - (0.002267, 0.001270, 0.018570), - (0.003299, 0.002249, 0.024239), - (0.004547, 0.003392, 0.030909), - (0.006006, 0.004692, 0.038558), - (0.007676, 0.006136, 0.046836), - (0.009561, 0.007713, 0.055143), - (0.011663, 0.009417, 0.063460), - (0.013995, 0.011225, 0.071862), - (0.016561, 0.013136, 0.080282), - (0.019373, 0.015133, 0.088767), - (0.022447, 0.017199, 0.097327), - (0.025793, 0.019331, 0.105930), - (0.029432, 0.021503, 0.114621), - (0.033385, 0.023702, 0.123397), - (0.037668, 0.025921, 0.132232), - (0.042253, 0.028139, 0.141141), - (0.046915, 0.030324, 0.150164), - (0.051644, 0.032474, 0.159254), - (0.056449, 0.034569, 0.168414), - (0.061340, 0.036590, 0.177642), - (0.066331, 0.038504, 0.186962), - (0.071429, 0.040294, 0.196354), - (0.076637, 0.041905, 0.205799), - (0.081962, 0.043328, 0.215289), - (0.087411, 0.044556, 0.224813), - (0.092990, 0.045583, 0.234358), - (0.098702, 0.046402, 0.243904), - (0.104551, 0.047008, 0.253430), - (0.110536, 0.047399, 
0.262912), - (0.116656, 0.047574, 0.272321), - (0.122908, 0.047536, 0.281624), - (0.129285, 0.047293, 0.290788), - (0.135778, 0.046856, 0.299776), - (0.142378, 0.046242, 0.308553), - (0.149073, 0.045468, 0.317085), - (0.155850, 0.044559, 0.325338), - (0.162689, 0.043554, 0.333277), - (0.169575, 0.042489, 0.340874), - (0.176493, 0.041402, 0.348111), - (0.183429, 0.040329, 0.354971), - (0.190367, 0.039309, 0.361447), - (0.197297, 0.038400, 0.367535), - (0.204209, 0.037632, 0.373238), - (0.211095, 0.037030, 0.378563), - (0.217949, 0.036615, 0.383522), - (0.224763, 0.036405, 0.388129), - (0.231538, 0.036405, 0.392400), - (0.238273, 0.036621, 0.396353), - (0.244967, 0.037055, 0.400007), - (0.251620, 0.037705, 0.403378), - (0.258234, 0.038571, 0.406485), - (0.264810, 0.039647, 0.409345), - (0.271347, 0.040922, 0.411976), - (0.277850, 0.042353, 0.414392), - (0.284321, 0.043933, 0.416608), - (0.290763, 0.045644, 0.418637), - (0.297178, 0.047470, 0.420491), - (0.303568, 0.049396, 0.422182), - (0.309935, 0.051407, 0.423721), - (0.316282, 0.053490, 0.425116), - (0.322610, 0.055634, 0.426377), - (0.328921, 0.057827, 0.427511), - (0.335217, 0.060060, 0.428524), - (0.341500, 0.062325, 0.429425), - (0.347771, 0.064616, 0.430217), - (0.354032, 0.066925, 0.430906), - (0.360284, 0.069247, 0.431497), - (0.366529, 0.071579, 0.431994), - (0.372768, 0.073915, 0.432400), - (0.379001, 0.076253, 0.432719), - (0.385228, 0.078591, 0.432955), - (0.391453, 0.080927, 0.433109), - (0.397674, 0.083257, 0.433183), - (0.403894, 0.085580, 0.433179), - (0.410113, 0.087896, 0.433098), - (0.416331, 0.090203, 0.432943), - (0.422549, 0.092501, 0.432714), - (0.428768, 0.094790, 0.432412), - (0.434987, 0.097069, 0.432039), - (0.441207, 0.099338, 0.431594), - (0.447428, 0.101597, 0.431080), - (0.453651, 0.103848, 0.430498), - (0.459875, 0.106089, 0.429846), - (0.466100, 0.108322, 0.429125), - (0.472328, 0.110547, 0.428334), - (0.478558, 0.112764, 0.427475), - (0.484789, 0.114974, 0.426548), - (0.491022, 
0.117179, 0.425552), - (0.497257, 0.119379, 0.424488), - (0.503493, 0.121575, 0.423356), - (0.509730, 0.123769, 0.422156), - (0.515967, 0.125960, 0.420887), - (0.522206, 0.128150, 0.419549), - (0.528444, 0.130341, 0.418142), - (0.534683, 0.132534, 0.416667), - (0.540920, 0.134729, 0.415123), - (0.547157, 0.136929, 0.413511), - (0.553392, 0.139134, 0.411829), - (0.559624, 0.141346, 0.410078), - (0.565854, 0.143567, 0.408258), - (0.572081, 0.145797, 0.406369), - (0.578304, 0.148039, 0.404411), - (0.584521, 0.150294, 0.402385), - (0.590734, 0.152563, 0.400290), - (0.596940, 0.154848, 0.398125), - (0.603139, 0.157151, 0.395891), - (0.609330, 0.159474, 0.393589), - (0.615513, 0.161817, 0.391219), - (0.621685, 0.164184, 0.388781), - (0.627847, 0.166575, 0.386276), - (0.633998, 0.168992, 0.383704), - (0.640135, 0.171438, 0.381065), - (0.646260, 0.173914, 0.378359), - (0.652369, 0.176421, 0.375586), - (0.658463, 0.178962, 0.372748), - (0.664540, 0.181539, 0.369846), - (0.670599, 0.184153, 0.366879), - (0.676638, 0.186807, 0.363849), - (0.682656, 0.189501, 0.360757), - (0.688653, 0.192239, 0.357603), - (0.694627, 0.195021, 0.354388), - (0.700576, 0.197851, 0.351113), - (0.706500, 0.200728, 0.347777), - (0.712396, 0.203656, 0.344383), - (0.718264, 0.206636, 0.340931), - (0.724103, 0.209670, 0.337424), - (0.729909, 0.212759, 0.333861), - (0.735683, 0.215906, 0.330245), - (0.741423, 0.219112, 0.326576), - (0.747127, 0.222378, 0.322856), - (0.752794, 0.225706, 0.319085), - (0.758422, 0.229097, 0.315266), - (0.764010, 0.232554, 0.311399), - (0.769556, 0.236077, 0.307485), - (0.775059, 0.239667, 0.303526), - (0.780517, 0.243327, 0.299523), - (0.785929, 0.247056, 0.295477), - (0.791293, 0.250856, 0.291390), - (0.796607, 0.254728, 0.287264), - (0.801871, 0.258674, 0.283099), - (0.807082, 0.262692, 0.278898), - (0.812239, 0.266786, 0.274661), - (0.817341, 0.270954, 0.270390), - (0.822386, 0.275197, 0.266085), - (0.827372, 0.279517, 0.261750), - (0.832299, 0.283913, 0.257383), - 
(0.837165, 0.288385, 0.252988), - (0.841969, 0.292933, 0.248564), - (0.846709, 0.297559, 0.244113), - (0.851384, 0.302260, 0.239636), - (0.855992, 0.307038, 0.235133), - (0.860533, 0.311892, 0.230606), - (0.865006, 0.316822, 0.226055), - (0.869409, 0.321827, 0.221482), - (0.873741, 0.326906, 0.216886), - (0.878001, 0.332060, 0.212268), - (0.882188, 0.337287, 0.207628), - (0.886302, 0.342586, 0.202968), - (0.890341, 0.347957, 0.198286), - (0.894305, 0.353399, 0.193584), - (0.898192, 0.358911, 0.188860), - (0.902003, 0.364492, 0.184116), - (0.905735, 0.370140, 0.179350), - (0.909390, 0.375856, 0.174563), - (0.912966, 0.381636, 0.169755), - (0.916462, 0.387481, 0.164924), - (0.919879, 0.393389, 0.160070), - (0.923215, 0.399359, 0.155193), - (0.926470, 0.405389, 0.150292), - (0.929644, 0.411479, 0.145367), - (0.932737, 0.417627, 0.140417), - (0.935747, 0.423831, 0.135440), - (0.938675, 0.430091, 0.130438), - (0.941521, 0.436405, 0.125409), - (0.944285, 0.442772, 0.120354), - (0.946965, 0.449191, 0.115272), - (0.949562, 0.455660, 0.110164), - (0.952075, 0.462178, 0.105031), - (0.954506, 0.468744, 0.099874), - (0.956852, 0.475356, 0.094695), - (0.959114, 0.482014, 0.089499), - (0.961293, 0.488716, 0.084289), - (0.963387, 0.495462, 0.079073), - (0.965397, 0.502249, 0.073859), - (0.967322, 0.509078, 0.068659), - (0.969163, 0.515946, 0.063488), - (0.970919, 0.522853, 0.058367), - (0.972590, 0.529798, 0.053324), - (0.974176, 0.536780, 0.048392), - (0.975677, 0.543798, 0.043618), - (0.977092, 0.550850, 0.039050), - (0.978422, 0.557937, 0.034931), - (0.979666, 0.565057, 0.031409), - (0.980824, 0.572209, 0.028508), - (0.981895, 0.579392, 0.026250), - (0.982881, 0.586606, 0.024661), - (0.983779, 0.593849, 0.023770), - (0.984591, 0.601122, 0.023606), - (0.985315, 0.608422, 0.024202), - (0.985952, 0.615750, 0.025592), - (0.986502, 0.623105, 0.027814), - (0.986964, 0.630485, 0.030908), - (0.987337, 0.637890, 0.034916), - (0.987622, 0.645320, 0.039886), - (0.987819, 0.652773, 
0.045581), - (0.987926, 0.660250, 0.051750), - (0.987945, 0.667748, 0.058329), - (0.987874, 0.675267, 0.065257), - (0.987714, 0.682807, 0.072489), - (0.987464, 0.690366, 0.079990), - (0.987124, 0.697944, 0.087731), - (0.986694, 0.705540, 0.095694), - (0.986175, 0.713153, 0.103863), - (0.985566, 0.720782, 0.112229), - (0.984865, 0.728427, 0.120785), - (0.984075, 0.736087, 0.129527), - (0.983196, 0.743758, 0.138453), - (0.982228, 0.751442, 0.147565), - (0.981173, 0.759135, 0.156863), - (0.980032, 0.766837, 0.166353), - (0.978806, 0.774545, 0.176037), - (0.977497, 0.782258, 0.185923), - (0.976108, 0.789974, 0.196018), - (0.974638, 0.797692, 0.206332), - (0.973088, 0.805409, 0.216877), - (0.971468, 0.813122, 0.227658), - (0.969783, 0.820825, 0.238686), - (0.968041, 0.828515, 0.249972), - (0.966243, 0.836191, 0.261534), - (0.964394, 0.843848, 0.273391), - (0.962517, 0.851476, 0.285546), - (0.960626, 0.859069, 0.298010), - (0.958720, 0.866624, 0.310820), - (0.956834, 0.874129, 0.323974), - (0.954997, 0.881569, 0.337475), - (0.953215, 0.888942, 0.351369), - (0.951546, 0.896226, 0.365627), - (0.950018, 0.903409, 0.380271), - (0.948683, 0.910473, 0.395289), - (0.947594, 0.917399, 0.410665), - (0.946809, 0.924168, 0.426373), - (0.946392, 0.930761, 0.442367), - (0.946403, 0.937159, 0.458592), - (0.946903, 0.943348, 0.474970), - (0.947937, 0.949318, 0.491426), - (0.949545, 0.955063, 0.507860), - (0.951740, 0.960587, 0.524203), - (0.954529, 0.965896, 0.540361), - (0.957896, 0.971003, 0.556275), - (0.961812, 0.975924, 0.571925), - (0.966249, 0.980678, 0.587206), - (0.971162, 0.985282, 0.602154), - (0.976511, 0.989753, 0.616760), - (0.982257, 0.994109, 0.631017), - (0.988362, 0.998364, 0.644924))) + ( + (0.001462, 0.000466, 0.013866), + (0.002267, 0.001270, 0.018570), + (0.003299, 0.002249, 0.024239), + (0.004547, 0.003392, 0.030909), + (0.006006, 0.004692, 0.038558), + (0.007676, 0.006136, 0.046836), + (0.009561, 0.007713, 0.055143), + (0.011663, 0.009417, 0.063460), + 
(0.013995, 0.011225, 0.071862), + (0.016561, 0.013136, 0.080282), + (0.019373, 0.015133, 0.088767), + (0.022447, 0.017199, 0.097327), + (0.025793, 0.019331, 0.105930), + (0.029432, 0.021503, 0.114621), + (0.033385, 0.023702, 0.123397), + (0.037668, 0.025921, 0.132232), + (0.042253, 0.028139, 0.141141), + (0.046915, 0.030324, 0.150164), + (0.051644, 0.032474, 0.159254), + (0.056449, 0.034569, 0.168414), + (0.061340, 0.036590, 0.177642), + (0.066331, 0.038504, 0.186962), + (0.071429, 0.040294, 0.196354), + (0.076637, 0.041905, 0.205799), + (0.081962, 0.043328, 0.215289), + (0.087411, 0.044556, 0.224813), + (0.092990, 0.045583, 0.234358), + (0.098702, 0.046402, 0.243904), + (0.104551, 0.047008, 0.253430), + (0.110536, 0.047399, 0.262912), + (0.116656, 0.047574, 0.272321), + (0.122908, 0.047536, 0.281624), + (0.129285, 0.047293, 0.290788), + (0.135778, 0.046856, 0.299776), + (0.142378, 0.046242, 0.308553), + (0.149073, 0.045468, 0.317085), + (0.155850, 0.044559, 0.325338), + (0.162689, 0.043554, 0.333277), + (0.169575, 0.042489, 0.340874), + (0.176493, 0.041402, 0.348111), + (0.183429, 0.040329, 0.354971), + (0.190367, 0.039309, 0.361447), + (0.197297, 0.038400, 0.367535), + (0.204209, 0.037632, 0.373238), + (0.211095, 0.037030, 0.378563), + (0.217949, 0.036615, 0.383522), + (0.224763, 0.036405, 0.388129), + (0.231538, 0.036405, 0.392400), + (0.238273, 0.036621, 0.396353), + (0.244967, 0.037055, 0.400007), + (0.251620, 0.037705, 0.403378), + (0.258234, 0.038571, 0.406485), + (0.264810, 0.039647, 0.409345), + (0.271347, 0.040922, 0.411976), + (0.277850, 0.042353, 0.414392), + (0.284321, 0.043933, 0.416608), + (0.290763, 0.045644, 0.418637), + (0.297178, 0.047470, 0.420491), + (0.303568, 0.049396, 0.422182), + (0.309935, 0.051407, 0.423721), + (0.316282, 0.053490, 0.425116), + (0.322610, 0.055634, 0.426377), + (0.328921, 0.057827, 0.427511), + (0.335217, 0.060060, 0.428524), + (0.341500, 0.062325, 0.429425), + (0.347771, 0.064616, 0.430217), + (0.354032, 0.066925, 
0.430906), + (0.360284, 0.069247, 0.431497), + (0.366529, 0.071579, 0.431994), + (0.372768, 0.073915, 0.432400), + (0.379001, 0.076253, 0.432719), + (0.385228, 0.078591, 0.432955), + (0.391453, 0.080927, 0.433109), + (0.397674, 0.083257, 0.433183), + (0.403894, 0.085580, 0.433179), + (0.410113, 0.087896, 0.433098), + (0.416331, 0.090203, 0.432943), + (0.422549, 0.092501, 0.432714), + (0.428768, 0.094790, 0.432412), + (0.434987, 0.097069, 0.432039), + (0.441207, 0.099338, 0.431594), + (0.447428, 0.101597, 0.431080), + (0.453651, 0.103848, 0.430498), + (0.459875, 0.106089, 0.429846), + (0.466100, 0.108322, 0.429125), + (0.472328, 0.110547, 0.428334), + (0.478558, 0.112764, 0.427475), + (0.484789, 0.114974, 0.426548), + (0.491022, 0.117179, 0.425552), + (0.497257, 0.119379, 0.424488), + (0.503493, 0.121575, 0.423356), + (0.509730, 0.123769, 0.422156), + (0.515967, 0.125960, 0.420887), + (0.522206, 0.128150, 0.419549), + (0.528444, 0.130341, 0.418142), + (0.534683, 0.132534, 0.416667), + (0.540920, 0.134729, 0.415123), + (0.547157, 0.136929, 0.413511), + (0.553392, 0.139134, 0.411829), + (0.559624, 0.141346, 0.410078), + (0.565854, 0.143567, 0.408258), + (0.572081, 0.145797, 0.406369), + (0.578304, 0.148039, 0.404411), + (0.584521, 0.150294, 0.402385), + (0.590734, 0.152563, 0.400290), + (0.596940, 0.154848, 0.398125), + (0.603139, 0.157151, 0.395891), + (0.609330, 0.159474, 0.393589), + (0.615513, 0.161817, 0.391219), + (0.621685, 0.164184, 0.388781), + (0.627847, 0.166575, 0.386276), + (0.633998, 0.168992, 0.383704), + (0.640135, 0.171438, 0.381065), + (0.646260, 0.173914, 0.378359), + (0.652369, 0.176421, 0.375586), + (0.658463, 0.178962, 0.372748), + (0.664540, 0.181539, 0.369846), + (0.670599, 0.184153, 0.366879), + (0.676638, 0.186807, 0.363849), + (0.682656, 0.189501, 0.360757), + (0.688653, 0.192239, 0.357603), + (0.694627, 0.195021, 0.354388), + (0.700576, 0.197851, 0.351113), + (0.706500, 0.200728, 0.347777), + (0.712396, 0.203656, 0.344383), + (0.718264, 
0.206636, 0.340931), + (0.724103, 0.209670, 0.337424), + (0.729909, 0.212759, 0.333861), + (0.735683, 0.215906, 0.330245), + (0.741423, 0.219112, 0.326576), + (0.747127, 0.222378, 0.322856), + (0.752794, 0.225706, 0.319085), + (0.758422, 0.229097, 0.315266), + (0.764010, 0.232554, 0.311399), + (0.769556, 0.236077, 0.307485), + (0.775059, 0.239667, 0.303526), + (0.780517, 0.243327, 0.299523), + (0.785929, 0.247056, 0.295477), + (0.791293, 0.250856, 0.291390), + (0.796607, 0.254728, 0.287264), + (0.801871, 0.258674, 0.283099), + (0.807082, 0.262692, 0.278898), + (0.812239, 0.266786, 0.274661), + (0.817341, 0.270954, 0.270390), + (0.822386, 0.275197, 0.266085), + (0.827372, 0.279517, 0.261750), + (0.832299, 0.283913, 0.257383), + (0.837165, 0.288385, 0.252988), + (0.841969, 0.292933, 0.248564), + (0.846709, 0.297559, 0.244113), + (0.851384, 0.302260, 0.239636), + (0.855992, 0.307038, 0.235133), + (0.860533, 0.311892, 0.230606), + (0.865006, 0.316822, 0.226055), + (0.869409, 0.321827, 0.221482), + (0.873741, 0.326906, 0.216886), + (0.878001, 0.332060, 0.212268), + (0.882188, 0.337287, 0.207628), + (0.886302, 0.342586, 0.202968), + (0.890341, 0.347957, 0.198286), + (0.894305, 0.353399, 0.193584), + (0.898192, 0.358911, 0.188860), + (0.902003, 0.364492, 0.184116), + (0.905735, 0.370140, 0.179350), + (0.909390, 0.375856, 0.174563), + (0.912966, 0.381636, 0.169755), + (0.916462, 0.387481, 0.164924), + (0.919879, 0.393389, 0.160070), + (0.923215, 0.399359, 0.155193), + (0.926470, 0.405389, 0.150292), + (0.929644, 0.411479, 0.145367), + (0.932737, 0.417627, 0.140417), + (0.935747, 0.423831, 0.135440), + (0.938675, 0.430091, 0.130438), + (0.941521, 0.436405, 0.125409), + (0.944285, 0.442772, 0.120354), + (0.946965, 0.449191, 0.115272), + (0.949562, 0.455660, 0.110164), + (0.952075, 0.462178, 0.105031), + (0.954506, 0.468744, 0.099874), + (0.956852, 0.475356, 0.094695), + (0.959114, 0.482014, 0.089499), + (0.961293, 0.488716, 0.084289), + (0.963387, 0.495462, 0.079073), + 
(0.965397, 0.502249, 0.073859), + (0.967322, 0.509078, 0.068659), + (0.969163, 0.515946, 0.063488), + (0.970919, 0.522853, 0.058367), + (0.972590, 0.529798, 0.053324), + (0.974176, 0.536780, 0.048392), + (0.975677, 0.543798, 0.043618), + (0.977092, 0.550850, 0.039050), + (0.978422, 0.557937, 0.034931), + (0.979666, 0.565057, 0.031409), + (0.980824, 0.572209, 0.028508), + (0.981895, 0.579392, 0.026250), + (0.982881, 0.586606, 0.024661), + (0.983779, 0.593849, 0.023770), + (0.984591, 0.601122, 0.023606), + (0.985315, 0.608422, 0.024202), + (0.985952, 0.615750, 0.025592), + (0.986502, 0.623105, 0.027814), + (0.986964, 0.630485, 0.030908), + (0.987337, 0.637890, 0.034916), + (0.987622, 0.645320, 0.039886), + (0.987819, 0.652773, 0.045581), + (0.987926, 0.660250, 0.051750), + (0.987945, 0.667748, 0.058329), + (0.987874, 0.675267, 0.065257), + (0.987714, 0.682807, 0.072489), + (0.987464, 0.690366, 0.079990), + (0.987124, 0.697944, 0.087731), + (0.986694, 0.705540, 0.095694), + (0.986175, 0.713153, 0.103863), + (0.985566, 0.720782, 0.112229), + (0.984865, 0.728427, 0.120785), + (0.984075, 0.736087, 0.129527), + (0.983196, 0.743758, 0.138453), + (0.982228, 0.751442, 0.147565), + (0.981173, 0.759135, 0.156863), + (0.980032, 0.766837, 0.166353), + (0.978806, 0.774545, 0.176037), + (0.977497, 0.782258, 0.185923), + (0.976108, 0.789974, 0.196018), + (0.974638, 0.797692, 0.206332), + (0.973088, 0.805409, 0.216877), + (0.971468, 0.813122, 0.227658), + (0.969783, 0.820825, 0.238686), + (0.968041, 0.828515, 0.249972), + (0.966243, 0.836191, 0.261534), + (0.964394, 0.843848, 0.273391), + (0.962517, 0.851476, 0.285546), + (0.960626, 0.859069, 0.298010), + (0.958720, 0.866624, 0.310820), + (0.956834, 0.874129, 0.323974), + (0.954997, 0.881569, 0.337475), + (0.953215, 0.888942, 0.351369), + (0.951546, 0.896226, 0.365627), + (0.950018, 0.903409, 0.380271), + (0.948683, 0.910473, 0.395289), + (0.947594, 0.917399, 0.410665), + (0.946809, 0.924168, 0.426373), + (0.946392, 0.930761, 
0.442367), + (0.946403, 0.937159, 0.458592), + (0.946903, 0.943348, 0.474970), + (0.947937, 0.949318, 0.491426), + (0.949545, 0.955063, 0.507860), + (0.951740, 0.960587, 0.524203), + (0.954529, 0.965896, 0.540361), + (0.957896, 0.971003, 0.556275), + (0.961812, 0.975924, 0.571925), + (0.966249, 0.980678, 0.587206), + (0.971162, 0.985282, 0.602154), + (0.976511, 0.989753, 0.616760), + (0.982257, 0.994109, 0.631017), + (0.988362, 0.998364, 0.644924), + ) +) plasma = _convert( - ((0.050383, 0.029803, 0.527975), - (0.063536, 0.028426, 0.533124), - (0.075353, 0.027206, 0.538007), - (0.086222, 0.026125, 0.542658), - (0.096379, 0.025165, 0.547103), - (0.105980, 0.024309, 0.551368), - (0.115124, 0.023556, 0.555468), - (0.123903, 0.022878, 0.559423), - (0.132381, 0.022258, 0.563250), - (0.140603, 0.021687, 0.566959), - (0.148607, 0.021154, 0.570562), - (0.156421, 0.020651, 0.574065), - (0.164070, 0.020171, 0.577478), - (0.171574, 0.019706, 0.580806), - (0.178950, 0.019252, 0.584054), - (0.186213, 0.018803, 0.587228), - (0.193374, 0.018354, 0.590330), - (0.200445, 0.017902, 0.593364), - (0.207435, 0.017442, 0.596333), - (0.214350, 0.016973, 0.599239), - (0.221197, 0.016497, 0.602083), - (0.227983, 0.016007, 0.604867), - (0.234715, 0.015502, 0.607592), - (0.241396, 0.014979, 0.610259), - (0.248032, 0.014439, 0.612868), - (0.254627, 0.013882, 0.615419), - (0.261183, 0.013308, 0.617911), - (0.267703, 0.012716, 0.620346), - (0.274191, 0.012109, 0.622722), - (0.280648, 0.011488, 0.625038), - (0.287076, 0.010855, 0.627295), - (0.293478, 0.010213, 0.629490), - (0.299855, 0.009561, 0.631624), - (0.306210, 0.008902, 0.633694), - (0.312543, 0.008239, 0.635700), - (0.318856, 0.007576, 0.637640), - (0.325150, 0.006915, 0.639512), - (0.331426, 0.006261, 0.641316), - (0.337683, 0.005618, 0.643049), - (0.343925, 0.004991, 0.644710), - (0.350150, 0.004382, 0.646298), - (0.356359, 0.003798, 0.647810), - (0.362553, 0.003243, 0.649245), - (0.368733, 0.002724, 0.650601), - (0.374897, 0.002245, 
0.651876), - (0.381047, 0.001814, 0.653068), - (0.387183, 0.001434, 0.654177), - (0.393304, 0.001114, 0.655199), - (0.399411, 0.000859, 0.656133), - (0.405503, 0.000678, 0.656977), - (0.411580, 0.000577, 0.657730), - (0.417642, 0.000564, 0.658390), - (0.423689, 0.000646, 0.658956), - (0.429719, 0.000831, 0.659425), - (0.435734, 0.001127, 0.659797), - (0.441732, 0.001540, 0.660069), - (0.447714, 0.002080, 0.660240), - (0.453677, 0.002755, 0.660310), - (0.459623, 0.003574, 0.660277), - (0.465550, 0.004545, 0.660139), - (0.471457, 0.005678, 0.659897), - (0.477344, 0.006980, 0.659549), - (0.483210, 0.008460, 0.659095), - (0.489055, 0.010127, 0.658534), - (0.494877, 0.011990, 0.657865), - (0.500678, 0.014055, 0.657088), - (0.506454, 0.016333, 0.656202), - (0.512206, 0.018833, 0.655209), - (0.517933, 0.021563, 0.654109), - (0.523633, 0.024532, 0.652901), - (0.529306, 0.027747, 0.651586), - (0.534952, 0.031217, 0.650165), - (0.540570, 0.034950, 0.648640), - (0.546157, 0.038954, 0.647010), - (0.551715, 0.043136, 0.645277), - (0.557243, 0.047331, 0.643443), - (0.562738, 0.051545, 0.641509), - (0.568201, 0.055778, 0.639477), - (0.573632, 0.060028, 0.637349), - (0.579029, 0.064296, 0.635126), - (0.584391, 0.068579, 0.632812), - (0.589719, 0.072878, 0.630408), - (0.595011, 0.077190, 0.627917), - (0.600266, 0.081516, 0.625342), - (0.605485, 0.085854, 0.622686), - (0.610667, 0.090204, 0.619951), - (0.615812, 0.094564, 0.617140), - (0.620919, 0.098934, 0.614257), - (0.625987, 0.103312, 0.611305), - (0.631017, 0.107699, 0.608287), - (0.636008, 0.112092, 0.605205), - (0.640959, 0.116492, 0.602065), - (0.645872, 0.120898, 0.598867), - (0.650746, 0.125309, 0.595617), - (0.655580, 0.129725, 0.592317), - (0.660374, 0.134144, 0.588971), - (0.665129, 0.138566, 0.585582), - (0.669845, 0.142992, 0.582154), - (0.674522, 0.147419, 0.578688), - (0.679160, 0.151848, 0.575189), - (0.683758, 0.156278, 0.571660), - (0.688318, 0.160709, 0.568103), - (0.692840, 0.165141, 0.564522), - (0.697324, 
0.169573, 0.560919), - (0.701769, 0.174005, 0.557296), - (0.706178, 0.178437, 0.553657), - (0.710549, 0.182868, 0.550004), - (0.714883, 0.187299, 0.546338), - (0.719181, 0.191729, 0.542663), - (0.723444, 0.196158, 0.538981), - (0.727670, 0.200586, 0.535293), - (0.731862, 0.205013, 0.531601), - (0.736019, 0.209439, 0.527908), - (0.740143, 0.213864, 0.524216), - (0.744232, 0.218288, 0.520524), - (0.748289, 0.222711, 0.516834), - (0.752312, 0.227133, 0.513149), - (0.756304, 0.231555, 0.509468), - (0.760264, 0.235976, 0.505794), - (0.764193, 0.240396, 0.502126), - (0.768090, 0.244817, 0.498465), - (0.771958, 0.249237, 0.494813), - (0.775796, 0.253658, 0.491171), - (0.779604, 0.258078, 0.487539), - (0.783383, 0.262500, 0.483918), - (0.787133, 0.266922, 0.480307), - (0.790855, 0.271345, 0.476706), - (0.794549, 0.275770, 0.473117), - (0.798216, 0.280197, 0.469538), - (0.801855, 0.284626, 0.465971), - (0.805467, 0.289057, 0.462415), - (0.809052, 0.293491, 0.458870), - (0.812612, 0.297928, 0.455338), - (0.816144, 0.302368, 0.451816), - (0.819651, 0.306812, 0.448306), - (0.823132, 0.311261, 0.444806), - (0.826588, 0.315714, 0.441316), - (0.830018, 0.320172, 0.437836), - (0.833422, 0.324635, 0.434366), - (0.836801, 0.329105, 0.430905), - (0.840155, 0.333580, 0.427455), - (0.843484, 0.338062, 0.424013), - (0.846788, 0.342551, 0.420579), - (0.850066, 0.347048, 0.417153), - (0.853319, 0.351553, 0.413734), - (0.856547, 0.356066, 0.410322), - (0.859750, 0.360588, 0.406917), - (0.862927, 0.365119, 0.403519), - (0.866078, 0.369660, 0.400126), - (0.869203, 0.374212, 0.396738), - (0.872303, 0.378774, 0.393355), - (0.875376, 0.383347, 0.389976), - (0.878423, 0.387932, 0.386600), - (0.881443, 0.392529, 0.383229), - (0.884436, 0.397139, 0.379860), - (0.887402, 0.401762, 0.376494), - (0.890340, 0.406398, 0.373130), - (0.893250, 0.411048, 0.369768), - (0.896131, 0.415712, 0.366407), - (0.898984, 0.420392, 0.363047), - (0.901807, 0.425087, 0.359688), - (0.904601, 0.429797, 0.356329), - 
(0.907365, 0.434524, 0.352970), - (0.910098, 0.439268, 0.349610), - (0.912800, 0.444029, 0.346251), - (0.915471, 0.448807, 0.342890), - (0.918109, 0.453603, 0.339529), - (0.920714, 0.458417, 0.336166), - (0.923287, 0.463251, 0.332801), - (0.925825, 0.468103, 0.329435), - (0.928329, 0.472975, 0.326067), - (0.930798, 0.477867, 0.322697), - (0.933232, 0.482780, 0.319325), - (0.935630, 0.487712, 0.315952), - (0.937990, 0.492667, 0.312575), - (0.940313, 0.497642, 0.309197), - (0.942598, 0.502639, 0.305816), - (0.944844, 0.507658, 0.302433), - (0.947051, 0.512699, 0.299049), - (0.949217, 0.517763, 0.295662), - (0.951344, 0.522850, 0.292275), - (0.953428, 0.527960, 0.288883), - (0.955470, 0.533093, 0.285490), - (0.957469, 0.538250, 0.282096), - (0.959424, 0.543431, 0.278701), - (0.961336, 0.548636, 0.275305), - (0.963203, 0.553865, 0.271909), - (0.965024, 0.559118, 0.268513), - (0.966798, 0.564396, 0.265118), - (0.968526, 0.569700, 0.261721), - (0.970205, 0.575028, 0.258325), - (0.971835, 0.580382, 0.254931), - (0.973416, 0.585761, 0.251540), - (0.974947, 0.591165, 0.248151), - (0.976428, 0.596595, 0.244767), - (0.977856, 0.602051, 0.241387), - (0.979233, 0.607532, 0.238013), - (0.980556, 0.613039, 0.234646), - (0.981826, 0.618572, 0.231287), - (0.983041, 0.624131, 0.227937), - (0.984199, 0.629718, 0.224595), - (0.985301, 0.635330, 0.221265), - (0.986345, 0.640969, 0.217948), - (0.987332, 0.646633, 0.214648), - (0.988260, 0.652325, 0.211364), - (0.989128, 0.658043, 0.208100), - (0.989935, 0.663787, 0.204859), - (0.990681, 0.669558, 0.201642), - (0.991365, 0.675355, 0.198453), - (0.991985, 0.681179, 0.195295), - (0.992541, 0.687030, 0.192170), - (0.993032, 0.692907, 0.189084), - (0.993456, 0.698810, 0.186041), - (0.993814, 0.704741, 0.183043), - (0.994103, 0.710698, 0.180097), - (0.994324, 0.716681, 0.177208), - (0.994474, 0.722691, 0.174381), - (0.994553, 0.728728, 0.171622), - (0.994561, 0.734791, 0.168938), - (0.994495, 0.740880, 0.166335), - (0.994355, 0.746995, 
0.163821), - (0.994141, 0.753137, 0.161404), - (0.993851, 0.759304, 0.159092), - (0.993482, 0.765499, 0.156891), - (0.993033, 0.771720, 0.154808), - (0.992505, 0.777967, 0.152855), - (0.991897, 0.784239, 0.151042), - (0.991209, 0.790537, 0.149377), - (0.990439, 0.796859, 0.147870), - (0.989587, 0.803205, 0.146529), - (0.988648, 0.809579, 0.145357), - (0.987621, 0.815978, 0.144363), - (0.986509, 0.822401, 0.143557), - (0.985314, 0.828846, 0.142945), - (0.984031, 0.835315, 0.142528), - (0.982653, 0.841812, 0.142303), - (0.981190, 0.848329, 0.142279), - (0.979644, 0.854866, 0.142453), - (0.977995, 0.861432, 0.142808), - (0.976265, 0.868016, 0.143351), - (0.974443, 0.874622, 0.144061), - (0.972530, 0.881250, 0.144923), - (0.970533, 0.887896, 0.145919), - (0.968443, 0.894564, 0.147014), - (0.966271, 0.901249, 0.148180), - (0.964021, 0.907950, 0.149370), - (0.961681, 0.914672, 0.150520), - (0.959276, 0.921407, 0.151566), - (0.956808, 0.928152, 0.152409), - (0.954287, 0.934908, 0.152921), - (0.951726, 0.941671, 0.152925), - (0.949151, 0.948435, 0.152178), - (0.946602, 0.955190, 0.150328), - (0.944152, 0.961916, 0.146861), - (0.941896, 0.968590, 0.140956), - (0.940015, 0.975158, 0.131326))) + ( + (0.050383, 0.029803, 0.527975), + (0.063536, 0.028426, 0.533124), + (0.075353, 0.027206, 0.538007), + (0.086222, 0.026125, 0.542658), + (0.096379, 0.025165, 0.547103), + (0.105980, 0.024309, 0.551368), + (0.115124, 0.023556, 0.555468), + (0.123903, 0.022878, 0.559423), + (0.132381, 0.022258, 0.563250), + (0.140603, 0.021687, 0.566959), + (0.148607, 0.021154, 0.570562), + (0.156421, 0.020651, 0.574065), + (0.164070, 0.020171, 0.577478), + (0.171574, 0.019706, 0.580806), + (0.178950, 0.019252, 0.584054), + (0.186213, 0.018803, 0.587228), + (0.193374, 0.018354, 0.590330), + (0.200445, 0.017902, 0.593364), + (0.207435, 0.017442, 0.596333), + (0.214350, 0.016973, 0.599239), + (0.221197, 0.016497, 0.602083), + (0.227983, 0.016007, 0.604867), + (0.234715, 0.015502, 0.607592), + 
(0.241396, 0.014979, 0.610259), + (0.248032, 0.014439, 0.612868), + (0.254627, 0.013882, 0.615419), + (0.261183, 0.013308, 0.617911), + (0.267703, 0.012716, 0.620346), + (0.274191, 0.012109, 0.622722), + (0.280648, 0.011488, 0.625038), + (0.287076, 0.010855, 0.627295), + (0.293478, 0.010213, 0.629490), + (0.299855, 0.009561, 0.631624), + (0.306210, 0.008902, 0.633694), + (0.312543, 0.008239, 0.635700), + (0.318856, 0.007576, 0.637640), + (0.325150, 0.006915, 0.639512), + (0.331426, 0.006261, 0.641316), + (0.337683, 0.005618, 0.643049), + (0.343925, 0.004991, 0.644710), + (0.350150, 0.004382, 0.646298), + (0.356359, 0.003798, 0.647810), + (0.362553, 0.003243, 0.649245), + (0.368733, 0.002724, 0.650601), + (0.374897, 0.002245, 0.651876), + (0.381047, 0.001814, 0.653068), + (0.387183, 0.001434, 0.654177), + (0.393304, 0.001114, 0.655199), + (0.399411, 0.000859, 0.656133), + (0.405503, 0.000678, 0.656977), + (0.411580, 0.000577, 0.657730), + (0.417642, 0.000564, 0.658390), + (0.423689, 0.000646, 0.658956), + (0.429719, 0.000831, 0.659425), + (0.435734, 0.001127, 0.659797), + (0.441732, 0.001540, 0.660069), + (0.447714, 0.002080, 0.660240), + (0.453677, 0.002755, 0.660310), + (0.459623, 0.003574, 0.660277), + (0.465550, 0.004545, 0.660139), + (0.471457, 0.005678, 0.659897), + (0.477344, 0.006980, 0.659549), + (0.483210, 0.008460, 0.659095), + (0.489055, 0.010127, 0.658534), + (0.494877, 0.011990, 0.657865), + (0.500678, 0.014055, 0.657088), + (0.506454, 0.016333, 0.656202), + (0.512206, 0.018833, 0.655209), + (0.517933, 0.021563, 0.654109), + (0.523633, 0.024532, 0.652901), + (0.529306, 0.027747, 0.651586), + (0.534952, 0.031217, 0.650165), + (0.540570, 0.034950, 0.648640), + (0.546157, 0.038954, 0.647010), + (0.551715, 0.043136, 0.645277), + (0.557243, 0.047331, 0.643443), + (0.562738, 0.051545, 0.641509), + (0.568201, 0.055778, 0.639477), + (0.573632, 0.060028, 0.637349), + (0.579029, 0.064296, 0.635126), + (0.584391, 0.068579, 0.632812), + (0.589719, 0.072878, 
0.630408), + (0.595011, 0.077190, 0.627917), + (0.600266, 0.081516, 0.625342), + (0.605485, 0.085854, 0.622686), + (0.610667, 0.090204, 0.619951), + (0.615812, 0.094564, 0.617140), + (0.620919, 0.098934, 0.614257), + (0.625987, 0.103312, 0.611305), + (0.631017, 0.107699, 0.608287), + (0.636008, 0.112092, 0.605205), + (0.640959, 0.116492, 0.602065), + (0.645872, 0.120898, 0.598867), + (0.650746, 0.125309, 0.595617), + (0.655580, 0.129725, 0.592317), + (0.660374, 0.134144, 0.588971), + (0.665129, 0.138566, 0.585582), + (0.669845, 0.142992, 0.582154), + (0.674522, 0.147419, 0.578688), + (0.679160, 0.151848, 0.575189), + (0.683758, 0.156278, 0.571660), + (0.688318, 0.160709, 0.568103), + (0.692840, 0.165141, 0.564522), + (0.697324, 0.169573, 0.560919), + (0.701769, 0.174005, 0.557296), + (0.706178, 0.178437, 0.553657), + (0.710549, 0.182868, 0.550004), + (0.714883, 0.187299, 0.546338), + (0.719181, 0.191729, 0.542663), + (0.723444, 0.196158, 0.538981), + (0.727670, 0.200586, 0.535293), + (0.731862, 0.205013, 0.531601), + (0.736019, 0.209439, 0.527908), + (0.740143, 0.213864, 0.524216), + (0.744232, 0.218288, 0.520524), + (0.748289, 0.222711, 0.516834), + (0.752312, 0.227133, 0.513149), + (0.756304, 0.231555, 0.509468), + (0.760264, 0.235976, 0.505794), + (0.764193, 0.240396, 0.502126), + (0.768090, 0.244817, 0.498465), + (0.771958, 0.249237, 0.494813), + (0.775796, 0.253658, 0.491171), + (0.779604, 0.258078, 0.487539), + (0.783383, 0.262500, 0.483918), + (0.787133, 0.266922, 0.480307), + (0.790855, 0.271345, 0.476706), + (0.794549, 0.275770, 0.473117), + (0.798216, 0.280197, 0.469538), + (0.801855, 0.284626, 0.465971), + (0.805467, 0.289057, 0.462415), + (0.809052, 0.293491, 0.458870), + (0.812612, 0.297928, 0.455338), + (0.816144, 0.302368, 0.451816), + (0.819651, 0.306812, 0.448306), + (0.823132, 0.311261, 0.444806), + (0.826588, 0.315714, 0.441316), + (0.830018, 0.320172, 0.437836), + (0.833422, 0.324635, 0.434366), + (0.836801, 0.329105, 0.430905), + (0.840155, 
0.333580, 0.427455), + (0.843484, 0.338062, 0.424013), + (0.846788, 0.342551, 0.420579), + (0.850066, 0.347048, 0.417153), + (0.853319, 0.351553, 0.413734), + (0.856547, 0.356066, 0.410322), + (0.859750, 0.360588, 0.406917), + (0.862927, 0.365119, 0.403519), + (0.866078, 0.369660, 0.400126), + (0.869203, 0.374212, 0.396738), + (0.872303, 0.378774, 0.393355), + (0.875376, 0.383347, 0.389976), + (0.878423, 0.387932, 0.386600), + (0.881443, 0.392529, 0.383229), + (0.884436, 0.397139, 0.379860), + (0.887402, 0.401762, 0.376494), + (0.890340, 0.406398, 0.373130), + (0.893250, 0.411048, 0.369768), + (0.896131, 0.415712, 0.366407), + (0.898984, 0.420392, 0.363047), + (0.901807, 0.425087, 0.359688), + (0.904601, 0.429797, 0.356329), + (0.907365, 0.434524, 0.352970), + (0.910098, 0.439268, 0.349610), + (0.912800, 0.444029, 0.346251), + (0.915471, 0.448807, 0.342890), + (0.918109, 0.453603, 0.339529), + (0.920714, 0.458417, 0.336166), + (0.923287, 0.463251, 0.332801), + (0.925825, 0.468103, 0.329435), + (0.928329, 0.472975, 0.326067), + (0.930798, 0.477867, 0.322697), + (0.933232, 0.482780, 0.319325), + (0.935630, 0.487712, 0.315952), + (0.937990, 0.492667, 0.312575), + (0.940313, 0.497642, 0.309197), + (0.942598, 0.502639, 0.305816), + (0.944844, 0.507658, 0.302433), + (0.947051, 0.512699, 0.299049), + (0.949217, 0.517763, 0.295662), + (0.951344, 0.522850, 0.292275), + (0.953428, 0.527960, 0.288883), + (0.955470, 0.533093, 0.285490), + (0.957469, 0.538250, 0.282096), + (0.959424, 0.543431, 0.278701), + (0.961336, 0.548636, 0.275305), + (0.963203, 0.553865, 0.271909), + (0.965024, 0.559118, 0.268513), + (0.966798, 0.564396, 0.265118), + (0.968526, 0.569700, 0.261721), + (0.970205, 0.575028, 0.258325), + (0.971835, 0.580382, 0.254931), + (0.973416, 0.585761, 0.251540), + (0.974947, 0.591165, 0.248151), + (0.976428, 0.596595, 0.244767), + (0.977856, 0.602051, 0.241387), + (0.979233, 0.607532, 0.238013), + (0.980556, 0.613039, 0.234646), + (0.981826, 0.618572, 0.231287), + 
(0.983041, 0.624131, 0.227937), + (0.984199, 0.629718, 0.224595), + (0.985301, 0.635330, 0.221265), + (0.986345, 0.640969, 0.217948), + (0.987332, 0.646633, 0.214648), + (0.988260, 0.652325, 0.211364), + (0.989128, 0.658043, 0.208100), + (0.989935, 0.663787, 0.204859), + (0.990681, 0.669558, 0.201642), + (0.991365, 0.675355, 0.198453), + (0.991985, 0.681179, 0.195295), + (0.992541, 0.687030, 0.192170), + (0.993032, 0.692907, 0.189084), + (0.993456, 0.698810, 0.186041), + (0.993814, 0.704741, 0.183043), + (0.994103, 0.710698, 0.180097), + (0.994324, 0.716681, 0.177208), + (0.994474, 0.722691, 0.174381), + (0.994553, 0.728728, 0.171622), + (0.994561, 0.734791, 0.168938), + (0.994495, 0.740880, 0.166335), + (0.994355, 0.746995, 0.163821), + (0.994141, 0.753137, 0.161404), + (0.993851, 0.759304, 0.159092), + (0.993482, 0.765499, 0.156891), + (0.993033, 0.771720, 0.154808), + (0.992505, 0.777967, 0.152855), + (0.991897, 0.784239, 0.151042), + (0.991209, 0.790537, 0.149377), + (0.990439, 0.796859, 0.147870), + (0.989587, 0.803205, 0.146529), + (0.988648, 0.809579, 0.145357), + (0.987621, 0.815978, 0.144363), + (0.986509, 0.822401, 0.143557), + (0.985314, 0.828846, 0.142945), + (0.984031, 0.835315, 0.142528), + (0.982653, 0.841812, 0.142303), + (0.981190, 0.848329, 0.142279), + (0.979644, 0.854866, 0.142453), + (0.977995, 0.861432, 0.142808), + (0.976265, 0.868016, 0.143351), + (0.974443, 0.874622, 0.144061), + (0.972530, 0.881250, 0.144923), + (0.970533, 0.887896, 0.145919), + (0.968443, 0.894564, 0.147014), + (0.966271, 0.901249, 0.148180), + (0.964021, 0.907950, 0.149370), + (0.961681, 0.914672, 0.150520), + (0.959276, 0.921407, 0.151566), + (0.956808, 0.928152, 0.152409), + (0.954287, 0.934908, 0.152921), + (0.951726, 0.941671, 0.152925), + (0.949151, 0.948435, 0.152178), + (0.946602, 0.955190, 0.150328), + (0.944152, 0.961916, 0.146861), + (0.941896, 0.968590, 0.140956), + (0.940015, 0.975158, 0.131326), + ) +) viridis = _convert( - ((0.267004, 0.004874, 0.329415), - 
(0.268510, 0.009605, 0.335427), - (0.269944, 0.014625, 0.341379), - (0.271305, 0.019942, 0.347269), - (0.272594, 0.025563, 0.353093), - (0.273809, 0.031497, 0.358853), - (0.274952, 0.037752, 0.364543), - (0.276022, 0.044167, 0.370164), - (0.277018, 0.050344, 0.375715), - (0.277941, 0.056324, 0.381191), - (0.278791, 0.062145, 0.386592), - (0.279566, 0.067836, 0.391917), - (0.280267, 0.073417, 0.397163), - (0.280894, 0.078907, 0.402329), - (0.281446, 0.084320, 0.407414), - (0.281924, 0.089666, 0.412415), - (0.282327, 0.094955, 0.417331), - (0.282656, 0.100196, 0.422160), - (0.282910, 0.105393, 0.426902), - (0.283091, 0.110553, 0.431554), - (0.283197, 0.115680, 0.436115), - (0.283229, 0.120777, 0.440584), - (0.283187, 0.125848, 0.444960), - (0.283072, 0.130895, 0.449241), - (0.282884, 0.135920, 0.453427), - (0.282623, 0.140926, 0.457517), - (0.282290, 0.145912, 0.461510), - (0.281887, 0.150881, 0.465405), - (0.281412, 0.155834, 0.469201), - (0.280868, 0.160771, 0.472899), - (0.280255, 0.165693, 0.476498), - (0.279574, 0.170599, 0.479997), - (0.278826, 0.175490, 0.483397), - (0.278012, 0.180367, 0.486697), - (0.277134, 0.185228, 0.489898), - (0.276194, 0.190074, 0.493001), - (0.275191, 0.194905, 0.496005), - (0.274128, 0.199721, 0.498911), - (0.273006, 0.204520, 0.501721), - (0.271828, 0.209303, 0.504434), - (0.270595, 0.214069, 0.507052), - (0.269308, 0.218818, 0.509577), - (0.267968, 0.223549, 0.512008), - (0.266580, 0.228262, 0.514349), - (0.265145, 0.232956, 0.516599), - (0.263663, 0.237631, 0.518762), - (0.262138, 0.242286, 0.520837), - (0.260571, 0.246922, 0.522828), - (0.258965, 0.251537, 0.524736), - (0.257322, 0.256130, 0.526563), - (0.255645, 0.260703, 0.528312), - (0.253935, 0.265254, 0.529983), - (0.252194, 0.269783, 0.531579), - (0.250425, 0.274290, 0.533103), - (0.248629, 0.278775, 0.534556), - (0.246811, 0.283237, 0.535941), - (0.244972, 0.287675, 0.537260), - (0.243113, 0.292092, 0.538516), - (0.241237, 0.296485, 0.539709), - (0.239346, 0.300855, 
0.540844), - (0.237441, 0.305202, 0.541921), - (0.235526, 0.309527, 0.542944), - (0.233603, 0.313828, 0.543914), - (0.231674, 0.318106, 0.544834), - (0.229739, 0.322361, 0.545706), - (0.227802, 0.326594, 0.546532), - (0.225863, 0.330805, 0.547314), - (0.223925, 0.334994, 0.548053), - (0.221989, 0.339161, 0.548752), - (0.220057, 0.343307, 0.549413), - (0.218130, 0.347432, 0.550038), - (0.216210, 0.351535, 0.550627), - (0.214298, 0.355619, 0.551184), - (0.212395, 0.359683, 0.551710), - (0.210503, 0.363727, 0.552206), - (0.208623, 0.367752, 0.552675), - (0.206756, 0.371758, 0.553117), - (0.204903, 0.375746, 0.553533), - (0.203063, 0.379716, 0.553925), - (0.201239, 0.383670, 0.554294), - (0.199430, 0.387607, 0.554642), - (0.197636, 0.391528, 0.554969), - (0.195860, 0.395433, 0.555276), - (0.194100, 0.399323, 0.555565), - (0.192357, 0.403199, 0.555836), - (0.190631, 0.407061, 0.556089), - (0.188923, 0.410910, 0.556326), - (0.187231, 0.414746, 0.556547), - (0.185556, 0.418570, 0.556753), - (0.183898, 0.422383, 0.556944), - (0.182256, 0.426184, 0.557120), - (0.180629, 0.429975, 0.557282), - (0.179019, 0.433756, 0.557430), - (0.177423, 0.437527, 0.557565), - (0.175841, 0.441290, 0.557685), - (0.174274, 0.445044, 0.557792), - (0.172719, 0.448791, 0.557885), - (0.171176, 0.452530, 0.557965), - (0.169646, 0.456262, 0.558030), - (0.168126, 0.459988, 0.558082), - (0.166617, 0.463708, 0.558119), - (0.165117, 0.467423, 0.558141), - (0.163625, 0.471133, 0.558148), - (0.162142, 0.474838, 0.558140), - (0.160665, 0.478540, 0.558115), - (0.159194, 0.482237, 0.558073), - (0.157729, 0.485932, 0.558013), - (0.156270, 0.489624, 0.557936), - (0.154815, 0.493313, 0.557840), - (0.153364, 0.497000, 0.557724), - (0.151918, 0.500685, 0.557587), - (0.150476, 0.504369, 0.557430), - (0.149039, 0.508051, 0.557250), - (0.147607, 0.511733, 0.557049), - (0.146180, 0.515413, 0.556823), - (0.144759, 0.519093, 0.556572), - (0.143343, 0.522773, 0.556295), - (0.141935, 0.526453, 0.555991), - (0.140536, 
0.530132, 0.555659), - (0.139147, 0.533812, 0.555298), - (0.137770, 0.537492, 0.554906), - (0.136408, 0.541173, 0.554483), - (0.135066, 0.544853, 0.554029), - (0.133743, 0.548535, 0.553541), - (0.132444, 0.552216, 0.553018), - (0.131172, 0.555899, 0.552459), - (0.129933, 0.559582, 0.551864), - (0.128729, 0.563265, 0.551229), - (0.127568, 0.566949, 0.550556), - (0.126453, 0.570633, 0.549841), - (0.125394, 0.574318, 0.549086), - (0.124395, 0.578002, 0.548287), - (0.123463, 0.581687, 0.547445), - (0.122606, 0.585371, 0.546557), - (0.121831, 0.589055, 0.545623), - (0.121148, 0.592739, 0.544641), - (0.120565, 0.596422, 0.543611), - (0.120092, 0.600104, 0.542530), - (0.119738, 0.603785, 0.541400), - (0.119512, 0.607464, 0.540218), - (0.119423, 0.611141, 0.538982), - (0.119483, 0.614817, 0.537692), - (0.119699, 0.618490, 0.536347), - (0.120081, 0.622161, 0.534946), - (0.120638, 0.625828, 0.533488), - (0.121380, 0.629492, 0.531973), - (0.122312, 0.633153, 0.530398), - (0.123444, 0.636809, 0.528763), - (0.124780, 0.640461, 0.527068), - (0.126326, 0.644107, 0.525311), - (0.128087, 0.647749, 0.523491), - (0.130067, 0.651384, 0.521608), - (0.132268, 0.655014, 0.519661), - (0.134692, 0.658636, 0.517649), - (0.137339, 0.662252, 0.515571), - (0.140210, 0.665859, 0.513427), - (0.143303, 0.669459, 0.511215), - (0.146616, 0.673050, 0.508936), - (0.150148, 0.676631, 0.506589), - (0.153894, 0.680203, 0.504172), - (0.157851, 0.683765, 0.501686), - (0.162016, 0.687316, 0.499129), - (0.166383, 0.690856, 0.496502), - (0.170948, 0.694384, 0.493803), - (0.175707, 0.697900, 0.491033), - (0.180653, 0.701402, 0.488189), - (0.185783, 0.704891, 0.485273), - (0.191090, 0.708366, 0.482284), - (0.196571, 0.711827, 0.479221), - (0.202219, 0.715272, 0.476084), - (0.208030, 0.718701, 0.472873), - (0.214000, 0.722114, 0.469588), - (0.220124, 0.725509, 0.466226), - (0.226397, 0.728888, 0.462789), - (0.232815, 0.732247, 0.459277), - (0.239374, 0.735588, 0.455688), - (0.246070, 0.738910, 0.452024), - 
(0.252899, 0.742211, 0.448284), - (0.259857, 0.745492, 0.444467), - (0.266941, 0.748751, 0.440573), - (0.274149, 0.751988, 0.436601), - (0.281477, 0.755203, 0.432552), - (0.288921, 0.758394, 0.428426), - (0.296479, 0.761561, 0.424223), - (0.304148, 0.764704, 0.419943), - (0.311925, 0.767822, 0.415586), - (0.319809, 0.770914, 0.411152), - (0.327796, 0.773980, 0.406640), - (0.335885, 0.777018, 0.402049), - (0.344074, 0.780029, 0.397381), - (0.352360, 0.783011, 0.392636), - (0.360741, 0.785964, 0.387814), - (0.369214, 0.788888, 0.382914), - (0.377779, 0.791781, 0.377939), - (0.386433, 0.794644, 0.372886), - (0.395174, 0.797475, 0.367757), - (0.404001, 0.800275, 0.362552), - (0.412913, 0.803041, 0.357269), - (0.421908, 0.805774, 0.351910), - (0.430983, 0.808473, 0.346476), - (0.440137, 0.811138, 0.340967), - (0.449368, 0.813768, 0.335384), - (0.458674, 0.816363, 0.329727), - (0.468053, 0.818921, 0.323998), - (0.477504, 0.821444, 0.318195), - (0.487026, 0.823929, 0.312321), - (0.496615, 0.826376, 0.306377), - (0.506271, 0.828786, 0.300362), - (0.515992, 0.831158, 0.294279), - (0.525776, 0.833491, 0.288127), - (0.535621, 0.835785, 0.281908), - (0.545524, 0.838039, 0.275626), - (0.555484, 0.840254, 0.269281), - (0.565498, 0.842430, 0.262877), - (0.575563, 0.844566, 0.256415), - (0.585678, 0.846661, 0.249897), - (0.595839, 0.848717, 0.243329), - (0.606045, 0.850733, 0.236712), - (0.616293, 0.852709, 0.230052), - (0.626579, 0.854645, 0.223353), - (0.636902, 0.856542, 0.216620), - (0.647257, 0.858400, 0.209861), - (0.657642, 0.860219, 0.203082), - (0.668054, 0.861999, 0.196293), - (0.678489, 0.863742, 0.189503), - (0.688944, 0.865448, 0.182725), - (0.699415, 0.867117, 0.175971), - (0.709898, 0.868751, 0.169257), - (0.720391, 0.870350, 0.162603), - (0.730889, 0.871916, 0.156029), - (0.741388, 0.873449, 0.149561), - (0.751884, 0.874951, 0.143228), - (0.762373, 0.876424, 0.137064), - (0.772852, 0.877868, 0.131109), - (0.783315, 0.879285, 0.125405), - (0.793760, 0.880678, 
0.120005), - (0.804182, 0.882046, 0.114965), - (0.814576, 0.883393, 0.110347), - (0.824940, 0.884720, 0.106217), - (0.835270, 0.886029, 0.102646), - (0.845561, 0.887322, 0.099702), - (0.855810, 0.888601, 0.097452), - (0.866013, 0.889868, 0.095953), - (0.876168, 0.891125, 0.095250), - (0.886271, 0.892374, 0.095374), - (0.896320, 0.893616, 0.096335), - (0.906311, 0.894855, 0.098125), - (0.916242, 0.896091, 0.100717), - (0.926106, 0.897330, 0.104071), - (0.935904, 0.898570, 0.108131), - (0.945636, 0.899815, 0.112838), - (0.955300, 0.901065, 0.118128), - (0.964894, 0.902323, 0.123941), - (0.974417, 0.903590, 0.130215), - (0.983868, 0.904867, 0.136897), - (0.993248, 0.906157, 0.143936))) + ( + (0.267004, 0.004874, 0.329415), + (0.268510, 0.009605, 0.335427), + (0.269944, 0.014625, 0.341379), + (0.271305, 0.019942, 0.347269), + (0.272594, 0.025563, 0.353093), + (0.273809, 0.031497, 0.358853), + (0.274952, 0.037752, 0.364543), + (0.276022, 0.044167, 0.370164), + (0.277018, 0.050344, 0.375715), + (0.277941, 0.056324, 0.381191), + (0.278791, 0.062145, 0.386592), + (0.279566, 0.067836, 0.391917), + (0.280267, 0.073417, 0.397163), + (0.280894, 0.078907, 0.402329), + (0.281446, 0.084320, 0.407414), + (0.281924, 0.089666, 0.412415), + (0.282327, 0.094955, 0.417331), + (0.282656, 0.100196, 0.422160), + (0.282910, 0.105393, 0.426902), + (0.283091, 0.110553, 0.431554), + (0.283197, 0.115680, 0.436115), + (0.283229, 0.120777, 0.440584), + (0.283187, 0.125848, 0.444960), + (0.283072, 0.130895, 0.449241), + (0.282884, 0.135920, 0.453427), + (0.282623, 0.140926, 0.457517), + (0.282290, 0.145912, 0.461510), + (0.281887, 0.150881, 0.465405), + (0.281412, 0.155834, 0.469201), + (0.280868, 0.160771, 0.472899), + (0.280255, 0.165693, 0.476498), + (0.279574, 0.170599, 0.479997), + (0.278826, 0.175490, 0.483397), + (0.278012, 0.180367, 0.486697), + (0.277134, 0.185228, 0.489898), + (0.276194, 0.190074, 0.493001), + (0.275191, 0.194905, 0.496005), + (0.274128, 0.199721, 0.498911), + 
(0.273006, 0.204520, 0.501721), + (0.271828, 0.209303, 0.504434), + (0.270595, 0.214069, 0.507052), + (0.269308, 0.218818, 0.509577), + (0.267968, 0.223549, 0.512008), + (0.266580, 0.228262, 0.514349), + (0.265145, 0.232956, 0.516599), + (0.263663, 0.237631, 0.518762), + (0.262138, 0.242286, 0.520837), + (0.260571, 0.246922, 0.522828), + (0.258965, 0.251537, 0.524736), + (0.257322, 0.256130, 0.526563), + (0.255645, 0.260703, 0.528312), + (0.253935, 0.265254, 0.529983), + (0.252194, 0.269783, 0.531579), + (0.250425, 0.274290, 0.533103), + (0.248629, 0.278775, 0.534556), + (0.246811, 0.283237, 0.535941), + (0.244972, 0.287675, 0.537260), + (0.243113, 0.292092, 0.538516), + (0.241237, 0.296485, 0.539709), + (0.239346, 0.300855, 0.540844), + (0.237441, 0.305202, 0.541921), + (0.235526, 0.309527, 0.542944), + (0.233603, 0.313828, 0.543914), + (0.231674, 0.318106, 0.544834), + (0.229739, 0.322361, 0.545706), + (0.227802, 0.326594, 0.546532), + (0.225863, 0.330805, 0.547314), + (0.223925, 0.334994, 0.548053), + (0.221989, 0.339161, 0.548752), + (0.220057, 0.343307, 0.549413), + (0.218130, 0.347432, 0.550038), + (0.216210, 0.351535, 0.550627), + (0.214298, 0.355619, 0.551184), + (0.212395, 0.359683, 0.551710), + (0.210503, 0.363727, 0.552206), + (0.208623, 0.367752, 0.552675), + (0.206756, 0.371758, 0.553117), + (0.204903, 0.375746, 0.553533), + (0.203063, 0.379716, 0.553925), + (0.201239, 0.383670, 0.554294), + (0.199430, 0.387607, 0.554642), + (0.197636, 0.391528, 0.554969), + (0.195860, 0.395433, 0.555276), + (0.194100, 0.399323, 0.555565), + (0.192357, 0.403199, 0.555836), + (0.190631, 0.407061, 0.556089), + (0.188923, 0.410910, 0.556326), + (0.187231, 0.414746, 0.556547), + (0.185556, 0.418570, 0.556753), + (0.183898, 0.422383, 0.556944), + (0.182256, 0.426184, 0.557120), + (0.180629, 0.429975, 0.557282), + (0.179019, 0.433756, 0.557430), + (0.177423, 0.437527, 0.557565), + (0.175841, 0.441290, 0.557685), + (0.174274, 0.445044, 0.557792), + (0.172719, 0.448791, 
0.557885), + (0.171176, 0.452530, 0.557965), + (0.169646, 0.456262, 0.558030), + (0.168126, 0.459988, 0.558082), + (0.166617, 0.463708, 0.558119), + (0.165117, 0.467423, 0.558141), + (0.163625, 0.471133, 0.558148), + (0.162142, 0.474838, 0.558140), + (0.160665, 0.478540, 0.558115), + (0.159194, 0.482237, 0.558073), + (0.157729, 0.485932, 0.558013), + (0.156270, 0.489624, 0.557936), + (0.154815, 0.493313, 0.557840), + (0.153364, 0.497000, 0.557724), + (0.151918, 0.500685, 0.557587), + (0.150476, 0.504369, 0.557430), + (0.149039, 0.508051, 0.557250), + (0.147607, 0.511733, 0.557049), + (0.146180, 0.515413, 0.556823), + (0.144759, 0.519093, 0.556572), + (0.143343, 0.522773, 0.556295), + (0.141935, 0.526453, 0.555991), + (0.140536, 0.530132, 0.555659), + (0.139147, 0.533812, 0.555298), + (0.137770, 0.537492, 0.554906), + (0.136408, 0.541173, 0.554483), + (0.135066, 0.544853, 0.554029), + (0.133743, 0.548535, 0.553541), + (0.132444, 0.552216, 0.553018), + (0.131172, 0.555899, 0.552459), + (0.129933, 0.559582, 0.551864), + (0.128729, 0.563265, 0.551229), + (0.127568, 0.566949, 0.550556), + (0.126453, 0.570633, 0.549841), + (0.125394, 0.574318, 0.549086), + (0.124395, 0.578002, 0.548287), + (0.123463, 0.581687, 0.547445), + (0.122606, 0.585371, 0.546557), + (0.121831, 0.589055, 0.545623), + (0.121148, 0.592739, 0.544641), + (0.120565, 0.596422, 0.543611), + (0.120092, 0.600104, 0.542530), + (0.119738, 0.603785, 0.541400), + (0.119512, 0.607464, 0.540218), + (0.119423, 0.611141, 0.538982), + (0.119483, 0.614817, 0.537692), + (0.119699, 0.618490, 0.536347), + (0.120081, 0.622161, 0.534946), + (0.120638, 0.625828, 0.533488), + (0.121380, 0.629492, 0.531973), + (0.122312, 0.633153, 0.530398), + (0.123444, 0.636809, 0.528763), + (0.124780, 0.640461, 0.527068), + (0.126326, 0.644107, 0.525311), + (0.128087, 0.647749, 0.523491), + (0.130067, 0.651384, 0.521608), + (0.132268, 0.655014, 0.519661), + (0.134692, 0.658636, 0.517649), + (0.137339, 0.662252, 0.515571), + (0.140210, 
0.665859, 0.513427), + (0.143303, 0.669459, 0.511215), + (0.146616, 0.673050, 0.508936), + (0.150148, 0.676631, 0.506589), + (0.153894, 0.680203, 0.504172), + (0.157851, 0.683765, 0.501686), + (0.162016, 0.687316, 0.499129), + (0.166383, 0.690856, 0.496502), + (0.170948, 0.694384, 0.493803), + (0.175707, 0.697900, 0.491033), + (0.180653, 0.701402, 0.488189), + (0.185783, 0.704891, 0.485273), + (0.191090, 0.708366, 0.482284), + (0.196571, 0.711827, 0.479221), + (0.202219, 0.715272, 0.476084), + (0.208030, 0.718701, 0.472873), + (0.214000, 0.722114, 0.469588), + (0.220124, 0.725509, 0.466226), + (0.226397, 0.728888, 0.462789), + (0.232815, 0.732247, 0.459277), + (0.239374, 0.735588, 0.455688), + (0.246070, 0.738910, 0.452024), + (0.252899, 0.742211, 0.448284), + (0.259857, 0.745492, 0.444467), + (0.266941, 0.748751, 0.440573), + (0.274149, 0.751988, 0.436601), + (0.281477, 0.755203, 0.432552), + (0.288921, 0.758394, 0.428426), + (0.296479, 0.761561, 0.424223), + (0.304148, 0.764704, 0.419943), + (0.311925, 0.767822, 0.415586), + (0.319809, 0.770914, 0.411152), + (0.327796, 0.773980, 0.406640), + (0.335885, 0.777018, 0.402049), + (0.344074, 0.780029, 0.397381), + (0.352360, 0.783011, 0.392636), + (0.360741, 0.785964, 0.387814), + (0.369214, 0.788888, 0.382914), + (0.377779, 0.791781, 0.377939), + (0.386433, 0.794644, 0.372886), + (0.395174, 0.797475, 0.367757), + (0.404001, 0.800275, 0.362552), + (0.412913, 0.803041, 0.357269), + (0.421908, 0.805774, 0.351910), + (0.430983, 0.808473, 0.346476), + (0.440137, 0.811138, 0.340967), + (0.449368, 0.813768, 0.335384), + (0.458674, 0.816363, 0.329727), + (0.468053, 0.818921, 0.323998), + (0.477504, 0.821444, 0.318195), + (0.487026, 0.823929, 0.312321), + (0.496615, 0.826376, 0.306377), + (0.506271, 0.828786, 0.300362), + (0.515992, 0.831158, 0.294279), + (0.525776, 0.833491, 0.288127), + (0.535621, 0.835785, 0.281908), + (0.545524, 0.838039, 0.275626), + (0.555484, 0.840254, 0.269281), + (0.565498, 0.842430, 0.262877), + 
(0.575563, 0.844566, 0.256415), + (0.585678, 0.846661, 0.249897), + (0.595839, 0.848717, 0.243329), + (0.606045, 0.850733, 0.236712), + (0.616293, 0.852709, 0.230052), + (0.626579, 0.854645, 0.223353), + (0.636902, 0.856542, 0.216620), + (0.647257, 0.858400, 0.209861), + (0.657642, 0.860219, 0.203082), + (0.668054, 0.861999, 0.196293), + (0.678489, 0.863742, 0.189503), + (0.688944, 0.865448, 0.182725), + (0.699415, 0.867117, 0.175971), + (0.709898, 0.868751, 0.169257), + (0.720391, 0.870350, 0.162603), + (0.730889, 0.871916, 0.156029), + (0.741388, 0.873449, 0.149561), + (0.751884, 0.874951, 0.143228), + (0.762373, 0.876424, 0.137064), + (0.772852, 0.877868, 0.131109), + (0.783315, 0.879285, 0.125405), + (0.793760, 0.880678, 0.120005), + (0.804182, 0.882046, 0.114965), + (0.814576, 0.883393, 0.110347), + (0.824940, 0.884720, 0.106217), + (0.835270, 0.886029, 0.102646), + (0.845561, 0.887322, 0.099702), + (0.855810, 0.888601, 0.097452), + (0.866013, 0.889868, 0.095953), + (0.876168, 0.891125, 0.095250), + (0.886271, 0.892374, 0.095374), + (0.896320, 0.893616, 0.096335), + (0.906311, 0.894855, 0.098125), + (0.916242, 0.896091, 0.100717), + (0.926106, 0.897330, 0.104071), + (0.935904, 0.898570, 0.108131), + (0.945636, 0.899815, 0.112838), + (0.955300, 0.901065, 0.118128), + (0.964894, 0.902323, 0.123941), + (0.974417, 0.903590, 0.130215), + (0.983868, 0.904867, 0.136897), + (0.993248, 0.906157, 0.143936), + ) +) -__all__ = ['magma', 'inferno', 'plasma', 'viridis'] +__all__ = ["magma", "inferno", "plasma", "viridis"] diff --git a/tensorboard/plugins/beholder/file_system_tools.py b/tensorboard/plugins/beholder/file_system_tools.py index f533182e10..296edcf219 100644 --- a/tensorboard/plugins/beholder/file_system_tools.py +++ b/tensorboard/plugins/beholder/file_system_tools.py @@ -27,42 +27,42 @@ logger = tb_logging.get_logger() -def write_file(contents, path, mode='wb'): - with tf.io.gfile.GFile(path, mode) as new_file: - new_file.write(contents) +def 
write_file(contents, path, mode="wb"): + with tf.io.gfile.GFile(path, mode) as new_file: + new_file.write(contents) def read_tensor_summary(path): - with tf.io.gfile.GFile(path, 'rb') as summary_file: - summary_string = summary_file.read() + with tf.io.gfile.GFile(path, "rb") as summary_file: + summary_string = summary_file.read() - if not summary_string: - raise message.DecodeError('Empty summary.') + if not summary_string: + raise message.DecodeError("Empty summary.") - summary_proto = summary_pb2.Summary() - summary_proto.ParseFromString(summary_string) - tensor_proto = summary_proto.value[0].tensor - array = tensor_util.make_ndarray(tensor_proto) + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_string) + tensor_proto = summary_proto.value[0].tensor + array = tensor_util.make_ndarray(tensor_proto) - return array + return array def write_pickle(obj, path): - with tf.io.gfile.GFile(path, 'wb') as new_file: - pickle.dump(obj, new_file) + with tf.io.gfile.GFile(path, "wb") as new_file: + pickle.dump(obj, new_file) def read_pickle(path, default=None): - try: - with tf.io.gfile.GFile(path, 'rb') as pickle_file: - result = pickle.load(pickle_file) - - except (IOError, EOFError, ValueError, tf.errors.NotFoundError) as e: - if not isinstance(e, tf.errors.NotFoundError): - logger.error('Error reading pickle value: %s', e) - if default is not None: - result = default - else: - raise - - return result + try: + with tf.io.gfile.GFile(path, "rb") as pickle_file: + result = pickle.load(pickle_file) + + except (IOError, EOFError, ValueError, tf.errors.NotFoundError) as e: + if not isinstance(e, tf.errors.NotFoundError): + logger.error("Error reading pickle value: %s", e) + if default is not None: + result = default + else: + raise + + return result diff --git a/tensorboard/plugins/beholder/im_util.py b/tensorboard/plugins/beholder/im_util.py index fd80ff9b30..db3e6e9ef6 100644 --- a/tensorboard/plugins/beholder/im_util.py +++ 
b/tensorboard/plugins/beholder/im_util.py @@ -28,130 +28,128 @@ # pylint: disable=not-context-manager + def global_extrema(arrays): - return min([x.min() for x in arrays]), max([x.max() for x in arrays]) + return min([x.min() for x in arrays]), max([x.max() for x in arrays]) def scale_sections(sections, scaling_scope): - ''' + """ input: unscaled sections. returns: sections scaled to [0, 255] - ''' - new_sections = [] + """ + new_sections = [] - if scaling_scope == 'layer': - for section in sections: - new_sections.append(scale_image_for_display(section)) + if scaling_scope == "layer": + for section in sections: + new_sections.append(scale_image_for_display(section)) - elif scaling_scope == 'network': - global_min, global_max = global_extrema(sections) + elif scaling_scope == "network": + global_min, global_max = global_extrema(sections) - for section in sections: - new_sections.append(scale_image_for_display(section, - global_min, - global_max)) - return new_sections + for section in sections: + new_sections.append( + scale_image_for_display(section, global_min, global_max) + ) + return new_sections def scale_image_for_display(image, minimum=None, maximum=None): - image = image.astype(float) + image = image.astype(float) - minimum = image.min() if minimum is None else minimum - image -= minimum + minimum = image.min() if minimum is None else minimum + image -= minimum - maximum = image.max() if maximum is None else maximum + maximum = image.max() if maximum is None else maximum - if maximum == 0: - return image - else: - image *= 255 / maximum - return image.astype(np.uint8) + if maximum == 0: + return image + else: + image *= 255 / maximum + return image.astype(np.uint8) def pad_to_shape(array, shape, constant=245): - padding = [] + padding = [] - for actual_dim, target_dim in zip(array.shape, shape): - start_padding = 0 - end_padding = target_dim - actual_dim + for actual_dim, target_dim in zip(array.shape, shape): + start_padding = 0 + end_padding = target_dim 
- actual_dim - padding.append((start_padding, end_padding)) + padding.append((start_padding, end_padding)) - return np.pad(array, padding, mode='constant', constant_values=constant) + return np.pad(array, padding, mode="constant", constant_values=constant) -def apply_colormap(image, colormap='magma'): - if colormap == 'grayscale': - return image - cm = getattr(colormaps, colormap) - return image if cm is None else cm[image] +def apply_colormap(image, colormap="magma"): + if colormap == "grayscale": + return image + cm = getattr(colormaps, colormap) + return image if cm is None else cm[image] class PNGDecoder(op_evaluator.PersistentOpEvaluator): + def __init__(self): + super(PNGDecoder, self).__init__() + self._image_placeholder = None + self._decode_op = None - def __init__(self): - super(PNGDecoder, self).__init__() - self._image_placeholder = None - self._decode_op = None - - - def initialize_graph(self): - self._image_placeholder = tf.compat.v1.placeholder(dtype=tf.string) - self._decode_op = tf.image.decode_png(self._image_placeholder) + def initialize_graph(self): + self._image_placeholder = tf.compat.v1.placeholder(dtype=tf.string) + self._decode_op = tf.image.decode_png(self._image_placeholder) - - # pylint: disable=arguments-differ - def run(self, image): - return self._decode_op.eval(feed_dict={ - self._image_placeholder: image, - }) + # pylint: disable=arguments-differ + def run(self, image): + return self._decode_op.eval(feed_dict={self._image_placeholder: image,}) class Resizer(op_evaluator.PersistentOpEvaluator): - - def __init__(self): - super(Resizer, self).__init__() - self._image_placeholder = None - self._size_placeholder = None - self._resize_op = None - - - def initialize_graph(self): - self._image_placeholder = tf.compat.v1.placeholder(dtype=tf.float32) - self._size_placeholder = tf.compat.v1.placeholder(dtype=tf.int32) - self._resize_op = tf.compat.v1.image.resize_nearest_neighbor( - self._image_placeholder, - self._size_placeholder, - ) - - # 
pylint: disable=arguments-differ - def run(self, image, height, width): - if len(image.shape) == 2: - image = image.reshape([image.shape[0], image.shape[1], 1]) - - resized = np.squeeze(self._resize_op.eval(feed_dict={ - self._image_placeholder: [image], - self._size_placeholder: [height, width] - })) - - return resized + def __init__(self): + super(Resizer, self).__init__() + self._image_placeholder = None + self._size_placeholder = None + self._resize_op = None + + def initialize_graph(self): + self._image_placeholder = tf.compat.v1.placeholder(dtype=tf.float32) + self._size_placeholder = tf.compat.v1.placeholder(dtype=tf.int32) + self._resize_op = tf.compat.v1.image.resize_nearest_neighbor( + self._image_placeholder, self._size_placeholder, + ) + + # pylint: disable=arguments-differ + def run(self, image, height, width): + if len(image.shape) == 2: + image = image.reshape([image.shape[0], image.shape[1], 1]) + + resized = np.squeeze( + self._resize_op.eval( + feed_dict={ + self._image_placeholder: [image], + self._size_placeholder: [height, width], + } + ) + ) + + return resized decode_png = PNGDecoder() resize = Resizer() + def read_image(filename): - with tf.io.gfile.GFile(filename, 'rb') as image_file: - return np.array(decode_png(image_file.read())) + with tf.io.gfile.GFile(filename, "rb") as image_file: + return np.array(decode_png(image_file.read())) def write_image(array, filename): - with tf.io.gfile.GFile(filename, 'w') as image_file: - image_file.write(encoder.encode_png(array)) + with tf.io.gfile.GFile(filename, "w") as image_file: + image_file.write(encoder.encode_png(array)) def get_image_relative_to_script(filename): - script_directory = os.path.dirname(__file__) - filename = os.path.join(script_directory, 'resources', filename) + script_directory = os.path.dirname(__file__) + filename = os.path.join(script_directory, "resources", filename) - return read_image(filename) + return read_image(filename) diff --git 
a/tensorboard/plugins/beholder/shared_config.py b/tensorboard/plugins/beholder/shared_config.py index b987a50cfb..0660ddfa5d 100644 --- a/tensorboard/plugins/beholder/shared_config.py +++ b/tensorboard/plugins/beholder/shared_config.py @@ -16,22 +16,22 @@ from __future__ import division from __future__ import print_function -PLUGIN_NAME = 'beholder' -TAG_NAME = 'beholder-frame' -SUMMARY_FILENAME = 'frame.summary' -CONFIG_FILENAME = 'config.pkl' -SECTION_INFO_FILENAME = 'section-info.pkl' -SUMMARY_COLLECTION_KEY_NAME = 'summaries_beholder' +PLUGIN_NAME = "beholder" +TAG_NAME = "beholder-frame" +SUMMARY_FILENAME = "frame.summary" +CONFIG_FILENAME = "config.pkl" +SECTION_INFO_FILENAME = "section-info.pkl" +SUMMARY_COLLECTION_KEY_NAME = "summaries_beholder" DEFAULT_CONFIG = { - 'values': 'trainable_variables', - 'mode': 'variance', - 'scaling': 'layer', - 'window_size': 15, - 'FPS': 10, - 'is_recording': False, - 'show_all': False, - 'colormap': 'magma' + "values": "trainable_variables", + "mode": "variance", + "scaling": "layer", + "window_size": 15, + "FPS": 10, + "is_recording": False, + "show_all": False, + "colormap": "magma", } SECTION_HEIGHT = 128 diff --git a/tensorboard/plugins/beholder/video_writing.py b/tensorboard/plugins/beholder/video_writing.py index b2b2d2598a..d02af7c338 100644 --- a/tensorboard/plugins/beholder/video_writing.py +++ b/tensorboard/plugins/beholder/video_writing.py @@ -31,172 +31,194 @@ class VideoWriter(object): - """Video file writer that can use different output types. - - Each VideoWriter instance writes video files to a specified directory, using - the first available VideoOutput from the provided list. 
- """ - - def __init__(self, directory, outputs): - self.directory = directory - # Filter to the available outputs - self.outputs = [out for out in outputs if out.available()] - if not self.outputs: - raise IOError('No available video outputs') - self.output_index = 0 - self.output = None - self.frame_shape = None - - def current_output(self): - return self.outputs[self.output_index] - - def write_frame(self, np_array): - # Reset whenever we encounter a new frame shape. - if self.frame_shape != np_array.shape: - if self.output: - self.output.close() - self.output = None - self.frame_shape = np_array.shape - logger.info('Starting video with frame shape: %s', self.frame_shape) - # Write the frame, advancing across output types as necessary. - original_output_index = self.output_index - for self.output_index in range(original_output_index, len(self.outputs)): - try: - if not self.output: - new_output = self.outputs[self.output_index] - if self.output_index > original_output_index: - logger.warn( - 'Falling back to video output %s', new_output.name()) - self.output = new_output(self.directory, self.frame_shape) - self.output.emit_frame(np_array) - return - except (IOError, OSError) as e: - logger.warn( - 'Video output type %s not available: %s', - self.current_output().name(), str(e)) + """Video file writer that can use different output types. + + Each VideoWriter instance writes video files to a specified + directory, using the first available VideoOutput from the provided + list. + """ + + def __init__(self, directory, outputs): + self.directory = directory + # Filter to the available outputs + self.outputs = [out for out in outputs if out.available()] + if not self.outputs: + raise IOError("No available video outputs") + self.output_index = 0 + self.output = None + self.frame_shape = None + + def current_output(self): + return self.outputs[self.output_index] + + def write_frame(self, np_array): + # Reset whenever we encounter a new frame shape. 
+ if self.frame_shape != np_array.shape: + if self.output: + self.output.close() + self.output = None + self.frame_shape = np_array.shape + logger.info("Starting video with frame shape: %s", self.frame_shape) + # Write the frame, advancing across output types as necessary. + original_output_index = self.output_index + for self.output_index in range( + original_output_index, len(self.outputs) + ): + try: + if not self.output: + new_output = self.outputs[self.output_index] + if self.output_index > original_output_index: + logger.warn( + "Falling back to video output %s", new_output.name() + ) + self.output = new_output(self.directory, self.frame_shape) + self.output.emit_frame(np_array) + return + except (IOError, OSError) as e: + logger.warn( + "Video output type %s not available: %s", + self.current_output().name(), + str(e), + ) + if self.output: + self.output.close() + self.output = None + raise IOError("Exhausted available video outputs") + + def finish(self): if self.output: - self.output.close() + self.output.close() self.output = None - raise IOError('Exhausted available video outputs') - - def finish(self): - if self.output: - self.output.close() - self.output = None - self.frame_shape = None - # Reconsider failed outputs when video is manually restarted. - self.output_index = 0 + self.frame_shape = None + # Reconsider failed outputs when video is manually restarted. 
+ self.output_index = 0 @six.add_metaclass(abc.ABCMeta) class VideoOutput(object): - """Base class for video outputs supported by VideoWriter.""" + """Base class for video outputs supported by VideoWriter.""" - # Would add @abc.abstractmethod in python 3.3+ - @classmethod - def available(cls): - raise NotImplementedError() + # Would add @abc.abstractmethod in python 3.3+ + @classmethod + def available(cls): + raise NotImplementedError() - @classmethod - def name(cls): - return cls.__name__ + @classmethod + def name(cls): + return cls.__name__ - @abc.abstractmethod - def emit_frame(self, np_array): - raise NotImplementedError() + @abc.abstractmethod + def emit_frame(self, np_array): + raise NotImplementedError() - @abc.abstractmethod - def close(self): - raise NotImplementedError() + @abc.abstractmethod + def close(self): + raise NotImplementedError() class PNGVideoOutput(VideoOutput): - """Video output implemented by writing individual PNGs to disk.""" + """Video output implemented by writing individual PNGs to disk.""" - @classmethod - def available(cls): - return True + @classmethod + def available(cls): + return True - def __init__(self, directory, frame_shape): - del frame_shape # unused - self.directory = directory + '/video-frames-{}'.format(time.time()) - self.frame_num = 0 - tf.io.gfile.makedirs(self.directory) + def __init__(self, directory, frame_shape): + del frame_shape # unused + self.directory = directory + "/video-frames-{}".format(time.time()) + self.frame_num = 0 + tf.io.gfile.makedirs(self.directory) - def emit_frame(self, np_array): - filename = self.directory + '/{:05}.png'.format(self.frame_num) - im_util.write_image(np_array.astype(np.uint8), filename) - self.frame_num += 1 + def emit_frame(self, np_array): + filename = self.directory + "/{:05}.png".format(self.frame_num) + im_util.write_image(np_array.astype(np.uint8), filename) + self.frame_num += 1 - def close(self): - pass + def close(self): + pass class FFmpegVideoOutput(VideoOutput): - 
"""Video output implemented by streaming to FFmpeg with .mp4 output.""" - - @classmethod - def available(cls): - # Silently check if ffmpeg is available. - try: - with open(os.devnull, 'wb') as devnull: - subprocess.check_call( - ['ffmpeg', '-version'], stdout=devnull, stderr=devnull) - return True - except (OSError, subprocess.CalledProcessError): - return False - - def __init__(self, directory, frame_shape): - self.filename = directory + '/video-{}.webm'.format(time.time()) - if len(frame_shape) != 3: - raise ValueError( - 'Expected rank-3 array for frame, got %s' % str(frame_shape)) - # Set input pixel format based on channel count. - if frame_shape[2] == 1: - pix_fmt = 'gray' - elif frame_shape[2] == 3: - pix_fmt = 'rgb24' - else: - raise ValueError('Unsupported channel count %d' % frame_shape[2]) - - command = [ - 'ffmpeg', - '-y', # Overwite output - # Input options - raw video file format and codec. - '-f', 'rawvideo', - '-vcodec', 'rawvideo', - '-s', '%dx%d' % (frame_shape[1], frame_shape[0]), # Width x height. - '-pix_fmt', pix_fmt, - '-r', '15', # Frame rate: arbitrarily use 15 frames per second. - '-i', '-', # Use stdin. - '-an', # No audio. - # Output options - use lossless VP9 codec inside .webm. - '-vcodec', 'libvpx-vp9', - '-lossless', '1', - # Using YUV is most compatible, though conversion from RGB skews colors. - '-pix_fmt', 'yuv420p', - self.filename - ] - PIPE = subprocess.PIPE - self.ffmpeg = subprocess.Popen( - command, stdin=PIPE, stdout=PIPE, stderr=PIPE) - - def _handle_error(self): - _, stderr = self.ffmpeg.communicate() - bar = '=' * 40 - logger.error( - 'Error writing to FFmpeg:\n%s\n%s\n%s', bar, stderr.rstrip('\n'), bar) - - def emit_frame(self, np_array): - try: - self.ffmpeg.stdin.write(np_array.tobytes()) - self.ffmpeg.stdin.flush() - except IOError: - self._handle_error() - raise IOError('Failure invoking FFmpeg') - - def close(self): - if self.ffmpeg.poll() is None: - # Close stdin and consume and discard stderr/stdout. 
- self.ffmpeg.communicate() - self.ffmpeg = None + """Video output implemented by streaming to FFmpeg with .mp4 output.""" + + @classmethod + def available(cls): + # Silently check if ffmpeg is available. + try: + with open(os.devnull, "wb") as devnull: + subprocess.check_call( + ["ffmpeg", "-version"], stdout=devnull, stderr=devnull + ) + return True + except (OSError, subprocess.CalledProcessError): + return False + + def __init__(self, directory, frame_shape): + self.filename = directory + "/video-{}.webm".format(time.time()) + if len(frame_shape) != 3: + raise ValueError( + "Expected rank-3 array for frame, got %s" % str(frame_shape) + ) + # Set input pixel format based on channel count. + if frame_shape[2] == 1: + pix_fmt = "gray" + elif frame_shape[2] == 3: + pix_fmt = "rgb24" + else: + raise ValueError("Unsupported channel count %d" % frame_shape[2]) + + command = [ + "ffmpeg", + "-y", # Overwite output + # Input options - raw video file format and codec. + "-f", + "rawvideo", + "-vcodec", + "rawvideo", + "-s", + "%dx%d" % (frame_shape[1], frame_shape[0]), # Width x height. + "-pix_fmt", + pix_fmt, + "-r", + "15", # Frame rate: arbitrarily use 15 frames per second. + "-i", + "-", # Use stdin. + "-an", # No audio. + # Output options - use lossless VP9 codec inside .webm. + "-vcodec", + "libvpx-vp9", + "-lossless", + "1", + # Using YUV is most compatible, though conversion from RGB skews colors. 
+ "-pix_fmt", + "yuv420p", + self.filename, + ] + PIPE = subprocess.PIPE + self.ffmpeg = subprocess.Popen( + command, stdin=PIPE, stdout=PIPE, stderr=PIPE + ) + + def _handle_error(self): + _, stderr = self.ffmpeg.communicate() + bar = "=" * 40 + logger.error( + "Error writing to FFmpeg:\n%s\n%s\n%s", + bar, + stderr.rstrip("\n"), + bar, + ) + + def emit_frame(self, np_array): + try: + self.ffmpeg.stdin.write(np_array.tobytes()) + self.ffmpeg.stdin.flush() + except IOError: + self._handle_error() + raise IOError("Failure invoking FFmpeg") + + def close(self): + if self.ffmpeg.poll() is None: + # Close stdin and consume and discard stderr/stdout. + self.ffmpeg.communicate() + self.ffmpeg = None diff --git a/tensorboard/plugins/beholder/visualizer.py b/tensorboard/plugins/beholder/visualizer.py index b8ea95c6ec..46fbca7a19 100644 --- a/tensorboard/plugins/beholder/visualizer.py +++ b/tensorboard/plugins/beholder/visualizer.py @@ -23,286 +23,300 @@ import tensorflow as tf from tensorboard.plugins.beholder import im_util -from tensorboard.plugins.beholder.shared_config import SECTION_HEIGHT,\ - IMAGE_WIDTH, DEFAULT_CONFIG, SECTION_INFO_FILENAME +from tensorboard.plugins.beholder.shared_config import ( + SECTION_HEIGHT, + IMAGE_WIDTH, + DEFAULT_CONFIG, + SECTION_INFO_FILENAME, +) from tensorboard.plugins.beholder.file_system_tools import write_pickle MIN_SQUARE_SIZE = 3 class Visualizer(object): - - def __init__(self, logdir): - self.logdir = logdir - self.sections_over_time = deque([], DEFAULT_CONFIG['window_size']) - self.config = dict(DEFAULT_CONFIG) - self.old_config = dict(DEFAULT_CONFIG) - - - def _reshape_conv_array(self, array, section_height, image_width): - '''Reshape a rank 4 array to be rank 2, where each column of block_width is - a filter, and each row of block height is an input channel. 
For example: - - [[[[ 11, 21, 31, 41], - [ 51, 61, 71, 81], - [ 91, 101, 111, 121]], - [[ 12, 22, 32, 42], - [ 52, 62, 72, 82], - [ 92, 102, 112, 122]], - [[ 13, 23, 33, 43], - [ 53, 63, 73, 83], - [ 93, 103, 113, 123]]], - [[[ 14, 24, 34, 44], - [ 54, 64, 74, 84], - [ 94, 104, 114, 124]], - [[ 15, 25, 35, 45], - [ 55, 65, 75, 85], - [ 95, 105, 115, 125]], - [[ 16, 26, 36, 46], - [ 56, 66, 76, 86], - [ 96, 106, 116, 126]]], - [[[ 17, 27, 37, 47], - [ 57, 67, 77, 87], - [ 97, 107, 117, 127]], - [[ 18, 28, 38, 48], - [ 58, 68, 78, 88], - [ 98, 108, 118, 128]], - [[ 19, 29, 39, 49], - [ 59, 69, 79, 89], - [ 99, 109, 119, 129]]]] - - should be reshaped to: - - [[ 11, 12, 13, 21, 22, 23, 31, 32, 33, 41, 42, 43], - [ 14, 15, 16, 24, 25, 26, 34, 35, 36, 44, 45, 46], - [ 17, 18, 19, 27, 28, 29, 37, 38, 39, 47, 48, 49], - [ 51, 52, 53, 61, 62, 63, 71, 72, 73, 81, 82, 83], - [ 54, 55, 56, 64, 65, 66, 74, 75, 76, 84, 85, 86], - [ 57, 58, 59, 67, 68, 69, 77, 78, 79, 87, 88, 89], - [ 91, 92, 93, 101, 102, 103, 111, 112, 113, 121, 122, 123], - [ 94, 95, 96, 104, 105, 106, 114, 115, 116, 124, 125, 126], - [ 97, 98, 99, 107, 108, 109, 117, 118, 119, 127, 128, 129]] - ''' - - # E.g. [100, 24, 24, 10]: this shouldn't be reshaped like normal. - if array.shape[1] == array.shape[2] and array.shape[0] != array.shape[1]: - array = np.rollaxis(np.rollaxis(array, 2), 2) - - block_height, block_width, in_channels = array.shape[:3] - rows = [] - - max_element_count = section_height * int(image_width / MIN_SQUARE_SIZE) - element_count = 0 - - for i in range(in_channels): - rows.append(array[:, :, i, :].reshape(block_height, -1, order='F')) - - # This line should be left in this position. Gives it one extra row. 
- if element_count >= max_element_count and not self.config['show_all']: - break - - element_count += block_height * in_channels * block_width - - return np.vstack(rows) - - - def _reshape_irregular_array(self, array, section_height, image_width): - '''Reshapes arrays of ranks not in {1, 2, 4} - ''' - section_area = section_height * image_width - flattened_array = np.ravel(array) - - if not self.config['show_all']: - flattened_array = flattened_array[:int(section_area/MIN_SQUARE_SIZE)] - - cell_count = np.prod(flattened_array.shape) - cell_area = section_area / cell_count - - cell_side_length = max(1, floor(sqrt(cell_area))) - row_count = max(1, int(section_height / cell_side_length)) - col_count = int(cell_count / row_count) - - # Reshape the truncated array so that it has the same aspect ratio as - # the section. - - # Truncate whatever remaining values there are that don't fit. Hopefully - # it doesn't matter that the last few (< section count) aren't there. - section = np.reshape(flattened_array[:row_count * col_count], - (row_count, col_count)) - - return section - - - def _determine_image_width(self, arrays, show_all): - final_width = IMAGE_WIDTH - - if show_all: - for array in arrays: + def __init__(self, logdir): + self.logdir = logdir + self.sections_over_time = deque([], DEFAULT_CONFIG["window_size"]) + self.config = dict(DEFAULT_CONFIG) + self.old_config = dict(DEFAULT_CONFIG) + + def _reshape_conv_array(self, array, section_height, image_width): + """Reshape a rank 4 array to be rank 2, where each column of + block_width is a filter, and each row of block height is an input + channel. 
For example: + + [[[[ 11, 21, 31, 41], + [ 51, 61, 71, 81], + [ 91, 101, 111, 121]], + [[ 12, 22, 32, 42], + [ 52, 62, 72, 82], + [ 92, 102, 112, 122]], + [[ 13, 23, 33, 43], + [ 53, 63, 73, 83], + [ 93, 103, 113, 123]]], + [[[ 14, 24, 34, 44], + [ 54, 64, 74, 84], + [ 94, 104, 114, 124]], + [[ 15, 25, 35, 45], + [ 55, 65, 75, 85], + [ 95, 105, 115, 125]], + [[ 16, 26, 36, 46], + [ 56, 66, 76, 86], + [ 96, 106, 116, 126]]], + [[[ 17, 27, 37, 47], + [ 57, 67, 77, 87], + [ 97, 107, 117, 127]], + [[ 18, 28, 38, 48], + [ 58, 68, 78, 88], + [ 98, 108, 118, 128]], + [[ 19, 29, 39, 49], + [ 59, 69, 79, 89], + [ 99, 109, 119, 129]]]] + + should be reshaped to: + + [[ 11, 12, 13, 21, 22, 23, 31, 32, 33, 41, 42, 43], + [ 14, 15, 16, 24, 25, 26, 34, 35, 36, 44, 45, 46], + [ 17, 18, 19, 27, 28, 29, 37, 38, 39, 47, 48, 49], + [ 51, 52, 53, 61, 62, 63, 71, 72, 73, 81, 82, 83], + [ 54, 55, 56, 64, 65, 66, 74, 75, 76, 84, 85, 86], + [ 57, 58, 59, 67, 68, 69, 77, 78, 79, 87, 88, 89], + [ 91, 92, 93, 101, 102, 103, 111, 112, 113, 121, 122, 123], + [ 94, 95, 96, 104, 105, 106, 114, 115, 116, 124, 125, 126], + [ 97, 98, 99, 107, 108, 109, 117, 118, 119, 127, 128, 129]] + """ + + # E.g. [100, 24, 24, 10]: this shouldn't be reshaped like normal. + if ( + array.shape[1] == array.shape[2] + and array.shape[0] != array.shape[1] + ): + array = np.rollaxis(np.rollaxis(array, 2), 2) + + block_height, block_width, in_channels = array.shape[:3] + rows = [] + + max_element_count = section_height * int(image_width / MIN_SQUARE_SIZE) + element_count = 0 + + for i in range(in_channels): + rows.append(array[:, :, i, :].reshape(block_height, -1, order="F")) + + # This line should be left in this position. Gives it one extra row. 
+ if ( + element_count >= max_element_count + and not self.config["show_all"] + ): + break + + element_count += block_height * in_channels * block_width + + return np.vstack(rows) + + def _reshape_irregular_array(self, array, section_height, image_width): + """Reshapes arrays of ranks not in {1, 2, 4}""" + section_area = section_height * image_width + flattened_array = np.ravel(array) + + if not self.config["show_all"]: + flattened_array = flattened_array[ + : int(section_area / MIN_SQUARE_SIZE) + ] + + cell_count = np.prod(flattened_array.shape) + cell_area = section_area / cell_count + + cell_side_length = max(1, floor(sqrt(cell_area))) + row_count = max(1, int(section_height / cell_side_length)) + col_count = int(cell_count / row_count) + + # Reshape the truncated array so that it has the same aspect ratio as + # the section. + + # Truncate whatever remaining values there are that don't fit. Hopefully + # it doesn't matter that the last few (< section count) aren't there. + section = np.reshape( + flattened_array[: row_count * col_count], (row_count, col_count) + ) + + return section + + def _determine_image_width(self, arrays, show_all): + final_width = IMAGE_WIDTH + + if show_all: + for array in arrays: + rank = len(array.shape) + + if rank == 1: + width = len(array) + elif rank == 2: + width = array.shape[1] + elif rank == 4: + width = array.shape[1] * array.shape[3] + else: + width = IMAGE_WIDTH + + if width > final_width: + final_width = width + + return final_width + + def _determine_section_height(self, array, show_all): rank = len(array.shape) - - if rank == 1: - width = len(array) - elif rank == 2: - width = array.shape[1] - elif rank == 4: - width = array.shape[1] * array.shape[3] - else: - width = IMAGE_WIDTH - - if width > final_width: - final_width = width - - return final_width - - - def _determine_section_height(self, array, show_all): - rank = len(array.shape) - height = SECTION_HEIGHT - - if show_all: - if rank == 1: height = SECTION_HEIGHT - if 
rank == 2: - height = max(SECTION_HEIGHT, array.shape[0]) - elif rank == 4: - height = max(SECTION_HEIGHT, array.shape[0] * array.shape[2]) - else: - height = max(SECTION_HEIGHT, np.prod(array.shape) // IMAGE_WIDTH) - - return height - - def _arrays_to_sections(self, arrays): - ''' + if show_all: + if rank == 1: + height = SECTION_HEIGHT + if rank == 2: + height = max(SECTION_HEIGHT, array.shape[0]) + elif rank == 4: + height = max(SECTION_HEIGHT, array.shape[0] * array.shape[2]) + else: + height = max( + SECTION_HEIGHT, np.prod(array.shape) // IMAGE_WIDTH + ) + + return height + + def _arrays_to_sections(self, arrays): + """ input: unprocessed numpy arrays. returns: columns of the size that they will appear in the image, not scaled for display. That needs to wait until after variance is computed. - ''' - sections = [] - sections_to_resize_later = {} - show_all = self.config['show_all'] - image_width = self._determine_image_width(arrays, show_all) - - for array_number, array in enumerate(arrays): - rank = len(array.shape) - section_height = self._determine_section_height(array, show_all) - - if rank == 1: - section = np.atleast_2d(array) - elif rank == 2: - section = array - elif rank == 4: - section = self._reshape_conv_array(array, section_height, image_width) - else: - section = self._reshape_irregular_array(array, - section_height, - image_width) - # Only calculate variance for what we have to. In some cases (biases), - # the section is larger than the array, so we don't want to calculate - # variance for the same value over and over - better to resize later. - # About a 6-7x speedup for a big network with a big variance window. 
- section_size = section_height * image_width - array_size = np.prod(array.shape) - - if section_size > array_size: - sections.append(section) - sections_to_resize_later[array_number] = section_height - else: - sections.append(im_util.resize(section, section_height, image_width)) - - self.sections_over_time.append(sections) - - if self.config['mode'] == 'variance': - sections = self._sections_to_variance_sections(self.sections_over_time) - - for array_number, height in sections_to_resize_later.items(): - sections[array_number] = im_util.resize(sections[array_number], - height, - image_width) - return sections - - - def _sections_to_variance_sections(self, sections_over_time): - '''Computes the variance of corresponding sections over time. - - Returns: - a list of np arrays. - ''' - variance_sections = [] - - for i in range(len(sections_over_time[0])): - time_sections = [sections[i] for sections in sections_over_time] - variance = np.var(time_sections, axis=0) - variance_sections.append(variance) - - return variance_sections - - - def _sections_to_image(self, sections): - padding_size = 5 - - sections = im_util.scale_sections(sections, self.config['scaling']) - - final_stack = [sections[0]] - padding = np.zeros((padding_size, sections[0].shape[1])) - - for section in sections[1:]: - final_stack.append(padding) - final_stack.append(section) - - return np.vstack(final_stack).astype(np.uint8) - - - def _maybe_clear_deque(self): - '''Clears the deque if certain parts of the config have changed.''' - - for config_item in ['values', 'mode', 'show_all']: - if self.config[config_item] != self.old_config[config_item]: - self.sections_over_time.clear() - break - - self.old_config = self.config - - window_size = self.config['window_size'] - if window_size != self.sections_over_time.maxlen: - self.sections_over_time = deque(self.sections_over_time, window_size) - - - def _save_section_info(self, arrays, sections): - infos = [] - - if self.config['values'] == 
'trainable_variables': - names = [x.name for x in tf.compat.v1.trainable_variables()] - else: - names = range(len(arrays)) - - for array, section, name in zip(arrays, sections, names): - info = {} + """ + sections = [] + sections_to_resize_later = {} + show_all = self.config["show_all"] + image_width = self._determine_image_width(arrays, show_all) + + for array_number, array in enumerate(arrays): + rank = len(array.shape) + section_height = self._determine_section_height(array, show_all) + + if rank == 1: + section = np.atleast_2d(array) + elif rank == 2: + section = array + elif rank == 4: + section = self._reshape_conv_array( + array, section_height, image_width + ) + else: + section = self._reshape_irregular_array( + array, section_height, image_width + ) + # Only calculate variance for what we have to. In some cases (biases), + # the section is larger than the array, so we don't want to calculate + # variance for the same value over and over - better to resize later. + # About a 6-7x speedup for a big network with a big variance window. + section_size = section_height * image_width + array_size = np.prod(array.shape) + + if section_size > array_size: + sections.append(section) + sections_to_resize_later[array_number] = section_height + else: + sections.append( + im_util.resize(section, section_height, image_width) + ) + + self.sections_over_time.append(sections) + + if self.config["mode"] == "variance": + sections = self._sections_to_variance_sections( + self.sections_over_time + ) + + for array_number, height in sections_to_resize_later.items(): + sections[array_number] = im_util.resize( + sections[array_number], height, image_width + ) + return sections + + def _sections_to_variance_sections(self, sections_over_time): + """Computes the variance of corresponding sections over time. + + Returns: + a list of np arrays. 
+ """ + variance_sections = [] + + for i in range(len(sections_over_time[0])): + time_sections = [sections[i] for sections in sections_over_time] + variance = np.var(time_sections, axis=0) + variance_sections.append(variance) + + return variance_sections + + def _sections_to_image(self, sections): + padding_size = 5 + + sections = im_util.scale_sections(sections, self.config["scaling"]) + + final_stack = [sections[0]] + padding = np.zeros((padding_size, sections[0].shape[1])) + + for section in sections[1:]: + final_stack.append(padding) + final_stack.append(section) + + return np.vstack(final_stack).astype(np.uint8) + + def _maybe_clear_deque(self): + """Clears the deque if certain parts of the config have changed.""" + + for config_item in ["values", "mode", "show_all"]: + if self.config[config_item] != self.old_config[config_item]: + self.sections_over_time.clear() + break + + self.old_config = self.config + + window_size = self.config["window_size"] + if window_size != self.sections_over_time.maxlen: + self.sections_over_time = deque( + self.sections_over_time, window_size + ) + + def _save_section_info(self, arrays, sections): + infos = [] + + if self.config["values"] == "trainable_variables": + names = [x.name for x in tf.compat.v1.trainable_variables()] + else: + names = range(len(arrays)) - info['name'] = name - info['shape'] = str(array.shape) - info['min'] = '{:.3e}'.format(section.min()) - info['mean'] = '{:.3e}'.format(section.mean()) - info['max'] = '{:.3e}'.format(section.max()) - info['range'] = '{:.3e}'.format(section.max() - section.min()) - info['height'] = section.shape[0] + for array, section, name in zip(arrays, sections, names): + info = {} - infos.append(info) + info["name"] = name + info["shape"] = str(array.shape) + info["min"] = "{:.3e}".format(section.min()) + info["mean"] = "{:.3e}".format(section.mean()) + info["max"] = "{:.3e}".format(section.max()) + info["range"] = "{:.3e}".format(section.max() - section.min()) + info["height"] = 
section.shape[0] - write_pickle(infos, '{}/{}'.format(self.logdir, SECTION_INFO_FILENAME)) + infos.append(info) + write_pickle(infos, "{}/{}".format(self.logdir, SECTION_INFO_FILENAME)) - def build_frame(self, arrays): - self._maybe_clear_deque() + def build_frame(self, arrays): + self._maybe_clear_deque() - arrays = arrays if isinstance(arrays, list) else [arrays] + arrays = arrays if isinstance(arrays, list) else [arrays] - sections = self._arrays_to_sections(arrays) - self._save_section_info(arrays, sections) - final_image = self._sections_to_image(sections) - final_image = im_util.apply_colormap(final_image, self.config['colormap']) + sections = self._arrays_to_sections(arrays) + self._save_section_info(arrays, sections) + final_image = self._sections_to_image(sections) + final_image = im_util.apply_colormap( + final_image, self.config["colormap"] + ) - return final_image + return final_image - def update(self, config): - self.config = config + def update(self, config): + self.config = config diff --git a/tensorboard/plugins/core/core_plugin.py b/tensorboard/plugins/core/core_plugin.py index d09c2e5e23..4dfe678f1e 100644 --- a/tensorboard/plugins/core/core_plugin.py +++ b/tensorboard/plugins/core/core_plugin.py @@ -44,179 +44,199 @@ class CorePlugin(base_plugin.TBPlugin): - """Core plugin for TensorBoard. + """Core plugin for TensorBoard. - This plugin serves runs, configuration data, and static assets. This plugin - should always be present in a TensorBoard WSGI application. - """ - - plugin_name = 'core' - - def __init__(self, context): - """Instantiates CorePlugin. - - Args: - context: A base_plugin.TBContext instance. 
- """ - logdir_spec = context.flags.logdir_spec if context.flags else '' - self._logdir = context.logdir or logdir_spec - self._db_uri = context.db_uri - self._window_title = context.window_title - self._multiplexer = context.multiplexer - self._db_connection_provider = context.db_connection_provider - self._assets_zip_provider = context.assets_zip_provider - if context.flags and context.flags.generic_data == 'true': - self._data_provider = context.data_provider - else: - self._data_provider = None - - def is_active(self): - return True - - def get_plugin_apps(self): - apps = { - '/___rPc_sWiTcH___': self._send_404_without_logging, - '/audio': self._redirect_to_index, - '/data/environment': self._serve_environment, - '/data/logdir': self._serve_logdir, - '/data/runs': self._serve_runs, - '/data/experiments': self._serve_experiments, - '/data/experiment_runs': self._serve_experiment_runs, - '/data/window_properties': self._serve_window_properties, - '/events': self._redirect_to_index, - '/favicon.ico': self._send_404_without_logging, - '/graphs': self._redirect_to_index, - '/histograms': self._redirect_to_index, - '/images': self._redirect_to_index, - } - apps.update(self.get_resource_apps()) - return apps - - def get_resource_apps(self): - apps = {} - if not self._assets_zip_provider: - return apps - - with self._assets_zip_provider() as fp: - with zipfile.ZipFile(fp) as zip_: - for path in zip_.namelist(): - gzipped_asset_bytes = _gzip(zip_.read(path)) - wsgi_app = functools.partial( - self._serve_asset, path, gzipped_asset_bytes) - apps['/' + path] = wsgi_app - apps['/'] = apps['/index.html'] - return apps - - @wrappers.Request.application - def _send_404_without_logging(self, request): - return http_util.Respond(request, 'Not found', 'text/plain', code=404) - - @wrappers.Request.application - def _redirect_to_index(self, unused_request): - return utils.redirect('/') - - @wrappers.Request.application - def _serve_asset(self, path, gzipped_asset_bytes, request): - 
"""Serves a pre-gzipped static asset from the zip file.""" - mimetype = mimetypes.guess_type(path)[0] or 'application/octet-stream' - return http_util.Respond( - request, gzipped_asset_bytes, mimetype, content_encoding='gzip') - - @wrappers.Request.application - def _serve_environment(self, request): - """Serve a JSON object containing some base properties used by the frontend. - - * data_location is either a path to a directory or an address to a - database (depending on which mode TensorBoard is running in). - * window_title is the title of the TensorBoard web page. + This plugin serves runs, configuration data, and static assets. This + plugin should always be present in a TensorBoard WSGI application. """ - if self._data_provider: - experiment = plugin_util.experiment_id(request.environ) - data_location = self._data_provider.data_location(experiment) - else: - data_location = self._logdir or self._db_uri - return http_util.Respond( - request, - { - 'data_location': data_location, - 'window_title': self._window_title, - }, - 'application/json') - - @wrappers.Request.application - def _serve_logdir(self, request): - """Respond with a JSON object containing this TensorBoard's logdir.""" - # TODO(chihuahua): Remove this method once the frontend instead uses the - # /data/environment route (and no deps throughout Google use the - # /data/logdir route). - return http_util.Respond( - request, {'logdir': self._logdir}, 'application/json') - - @wrappers.Request.application - def _serve_window_properties(self, request): - """Serve a JSON object containing this TensorBoard's window properties.""" - # TODO(chihuahua): Remove this method once the frontend instead uses the - # /data/environment route. - return http_util.Respond( - request, {'window_title': self._window_title}, 'application/json') - - @wrappers.Request.application - def _serve_runs(self, request): - """Serve a JSON array of run names, ordered by run started time. 
- - Sort order is by started time (aka first event time) with empty times sorted - last, and then ties are broken by sorting on the run name. - """ - if self._data_provider: - experiment = plugin_util.experiment_id(request.environ) - runs = sorted( - self._data_provider.list_runs(experiment_id=experiment), - key=lambda run: ( - run.start_time if run.start_time is not None else float('inf'), - run.run_name, - ) - ) - run_names = [run.run_name for run in runs] - elif self._db_connection_provider: - db = self._db_connection_provider() - cursor = db.execute(''' + + plugin_name = "core" + + def __init__(self, context): + """Instantiates CorePlugin. + + Args: + context: A base_plugin.TBContext instance. + """ + logdir_spec = context.flags.logdir_spec if context.flags else "" + self._logdir = context.logdir or logdir_spec + self._db_uri = context.db_uri + self._window_title = context.window_title + self._multiplexer = context.multiplexer + self._db_connection_provider = context.db_connection_provider + self._assets_zip_provider = context.assets_zip_provider + if context.flags and context.flags.generic_data == "true": + self._data_provider = context.data_provider + else: + self._data_provider = None + + def is_active(self): + return True + + def get_plugin_apps(self): + apps = { + "/___rPc_sWiTcH___": self._send_404_without_logging, + "/audio": self._redirect_to_index, + "/data/environment": self._serve_environment, + "/data/logdir": self._serve_logdir, + "/data/runs": self._serve_runs, + "/data/experiments": self._serve_experiments, + "/data/experiment_runs": self._serve_experiment_runs, + "/data/window_properties": self._serve_window_properties, + "/events": self._redirect_to_index, + "/favicon.ico": self._send_404_without_logging, + "/graphs": self._redirect_to_index, + "/histograms": self._redirect_to_index, + "/images": self._redirect_to_index, + } + apps.update(self.get_resource_apps()) + return apps + + def get_resource_apps(self): + apps = {} + if not 
self._assets_zip_provider: + return apps + + with self._assets_zip_provider() as fp: + with zipfile.ZipFile(fp) as zip_: + for path in zip_.namelist(): + gzipped_asset_bytes = _gzip(zip_.read(path)) + wsgi_app = functools.partial( + self._serve_asset, path, gzipped_asset_bytes + ) + apps["/" + path] = wsgi_app + apps["/"] = apps["/index.html"] + return apps + + @wrappers.Request.application + def _send_404_without_logging(self, request): + return http_util.Respond(request, "Not found", "text/plain", code=404) + + @wrappers.Request.application + def _redirect_to_index(self, unused_request): + return utils.redirect("/") + + @wrappers.Request.application + def _serve_asset(self, path, gzipped_asset_bytes, request): + """Serves a pre-gzipped static asset from the zip file.""" + mimetype = mimetypes.guess_type(path)[0] or "application/octet-stream" + return http_util.Respond( + request, gzipped_asset_bytes, mimetype, content_encoding="gzip" + ) + + @wrappers.Request.application + def _serve_environment(self, request): + """Serve a JSON object containing some base properties used by the + frontend. + + * data_location is either a path to a directory or an address to a + database (depending on which mode TensorBoard is running in). + * window_title is the title of the TensorBoard web page. + """ + if self._data_provider: + experiment = plugin_util.experiment_id(request.environ) + data_location = self._data_provider.data_location(experiment) + else: + data_location = self._logdir or self._db_uri + return http_util.Respond( + request, + { + "data_location": data_location, + "window_title": self._window_title, + }, + "application/json", + ) + + @wrappers.Request.application + def _serve_logdir(self, request): + """Respond with a JSON object containing this TensorBoard's logdir.""" + # TODO(chihuahua): Remove this method once the frontend instead uses the + # /data/environment route (and no deps throughout Google use the + # /data/logdir route). 
+ return http_util.Respond( + request, {"logdir": self._logdir}, "application/json" + ) + + @wrappers.Request.application + def _serve_window_properties(self, request): + """Serve a JSON object containing this TensorBoard's window + properties.""" + # TODO(chihuahua): Remove this method once the frontend instead uses the + # /data/environment route. + return http_util.Respond( + request, {"window_title": self._window_title}, "application/json" + ) + + @wrappers.Request.application + def _serve_runs(self, request): + """Serve a JSON array of run names, ordered by run started time. + + Sort order is by started time (aka first event time) with empty + times sorted last, and then ties are broken by sorting on the + run name. + """ + if self._data_provider: + experiment = plugin_util.experiment_id(request.environ) + runs = sorted( + self._data_provider.list_runs(experiment_id=experiment), + key=lambda run: ( + run.start_time + if run.start_time is not None + else float("inf"), + run.run_name, + ), + ) + run_names = [run.run_name for run in runs] + elif self._db_connection_provider: + db = self._db_connection_provider() + cursor = db.execute( + """ SELECT run_name, started_time IS NULL as started_time_nulls_last, started_time FROM Runs ORDER BY started_time_nulls_last, started_time, run_name - ''') - run_names = [row[0] for row in cursor] - else: - # Python's list.sort is stable, so to order by started time and - # then by name, we can just do the sorts in the reverse order. - run_names = sorted(self._multiplexer.Runs()) - def get_first_event_timestamp(run_name): - try: - return self._multiplexer.FirstEventTimestamp(run_name) - except ValueError as e: - logger.warn( - 'Unable to get first event timestamp for run %s: %s', run_name, e) - # Put runs without a timestamp at the end. 
- return float('inf') - run_names.sort(key=get_first_event_timestamp) - return http_util.Respond(request, run_names, 'application/json') - - @wrappers.Request.application - def _serve_experiments(self, request): - """Serve a JSON array of experiments. Experiments are ordered by experiment - started time (aka first event time) with empty times sorted last, and then - ties are broken by sorting on the experiment name. - """ - results = self.list_experiments_impl() - return http_util.Respond(request, results, 'application/json') - - def list_experiments_impl(self): - results = [] - if self._db_connection_provider: - db = self._db_connection_provider() - cursor = db.execute(''' + """ + ) + run_names = [row[0] for row in cursor] + else: + # Python's list.sort is stable, so to order by started time and + # then by name, we can just do the sorts in the reverse order. + run_names = sorted(self._multiplexer.Runs()) + + def get_first_event_timestamp(run_name): + try: + return self._multiplexer.FirstEventTimestamp(run_name) + except ValueError as e: + logger.warn( + "Unable to get first event timestamp for run %s: %s", + run_name, + e, + ) + # Put runs without a timestamp at the end. + return float("inf") + + run_names.sort(key=get_first_event_timestamp) + return http_util.Respond(request, run_names, "application/json") + + @wrappers.Request.application + def _serve_experiments(self, request): + """Serve a JSON array of experiments. + + Experiments are ordered by experiment started time (aka first + event time) with empty times sorted last, and then ties are + broken by sorting on the experiment name. 
+ """ + results = self.list_experiments_impl() + return http_util.Respond(request, results, "application/json") + + def list_experiments_impl(self): + results = [] + if self._db_connection_provider: + db = self._db_connection_provider() + cursor = db.execute( + """ SELECT experiment_id, experiment_name, @@ -225,30 +245,33 @@ def list_experiments_impl(self): FROM Experiments ORDER BY started_time_nulls_last, started_time, experiment_name, experiment_id - ''') - results = [{ - "id": row[0], - "name": row[1], - "startTime": row[2], - } for row in cursor] - - return results - - @wrappers.Request.application - def _serve_experiment_runs(self, request): - """Serve a JSON runs of an experiment, specified with query param - `experiment`, with their nested data, tag, populated. Runs returned are - ordered by started time (aka first event time) with empty times sorted last, - and then ties are broken by sorting on the run name. Tags are sorted by - its name, displayName, and lastly, inserted time. - """ - results = [] - if self._db_connection_provider: - exp_id = plugin_util.experiment_id(request.environ) - runs_dict = collections.OrderedDict() - - db = self._db_connection_provider() - cursor = db.execute(''' + """ + ) + results = [ + {"id": row[0], "name": row[1], "startTime": row[2],} + for row in cursor + ] + + return results + + @wrappers.Request.application + def _serve_experiment_runs(self, request): + """Serve a JSON runs of an experiment, specified with query param + `experiment`, with their nested data, tag, populated. + + Runs returned are ordered by started time (aka first event time) + with empty times sorted last, and then ties are broken by + sorting on the run name. Tags are sorted by its name, + displayName, and lastly, inserted time. 
+ """ + results = [] + if self._db_connection_provider: + exp_id = plugin_util.experiment_id(request.environ) + runs_dict = collections.OrderedDict() + + db = self._db_connection_provider() + cursor = db.execute( + """ SELECT Runs.run_id, Runs.run_name, @@ -270,52 +293,58 @@ def _serve_experiment_runs(self, request): Tags.tag_name, Tags.display_name, Tags.inserted_time; - ''', (exp_id,)) - for row in cursor: - run_id = row[0] - if not run_id in runs_dict: - runs_dict[run_id] = { - "id": run_id, - "name": row[1], - "startTime": math.floor(row[2]), - "tags": [], - } - # tag can be missing. - if row[4]: - runs_dict[run_id].get("tags").append({ - "id": row[4], - "displayName": row[6], - "name": row[5], - "pluginName": row[7], - }) - results = list(runs_dict.values()) - return http_util.Respond(request, results, 'application/json') + """, + (exp_id,), + ) + for row in cursor: + run_id = row[0] + if not run_id in runs_dict: + runs_dict[run_id] = { + "id": run_id, + "name": row[1], + "startTime": math.floor(row[2]), + "tags": [], + } + # tag can be missing. + if row[4]: + runs_dict[run_id].get("tags").append( + { + "id": row[4], + "displayName": row[6], + "name": row[5], + "pluginName": row[7], + } + ) + results = list(runs_dict.values()) + return http_util.Respond(request, results, "application/json") + class CorePluginLoader(base_plugin.TBLoader): - """CorePlugin factory.""" - - def define_flags(self, parser): - """Adds standard TensorBoard CLI flags to parser.""" - parser.add_argument( - '--logdir', - metavar='PATH', - type=str, - default='', - help='''\ + """CorePlugin factory.""" + + def define_flags(self, parser): + """Adds standard TensorBoard CLI flags to parser.""" + parser.add_argument( + "--logdir", + metavar="PATH", + type=str, + default="", + help="""\ Directory where TensorBoard will look to find TensorFlow event files that it can display. TensorBoard will recursively walk the directory structure rooted at logdir, looking for .*tfevents.* files. 
A leading tilde will be expanded with the semantics of Python's os.expanduser function. -''') - - parser.add_argument( - '--logdir_spec', - metavar='PATH_SPEC', - type=str, - default='', - help='''\ +""", + ) + + parser.add_argument( + "--logdir_spec", + metavar="PATH_SPEC", + type=str, + default="", + help="""\ Like `--logdir`, but with special interpretation for commas and colons: commas separate multiple runs, where a colon specifies a new name for a run. For example: @@ -325,77 +354,84 @@ def define_flags(self, parser): log directories recursively; for finer-grained control, prefer using a symlink tree. Some features may not work when using `--logdir_spec` instead of `--logdir`. -''') - - parser.add_argument( - '--host', - metavar='ADDR', - type=str, - default=None, # like localhost, but prints a note about `--bind_all` - help='''\ +""", + ) + + parser.add_argument( + "--host", + metavar="ADDR", + type=str, + default=None, # like localhost, but prints a note about `--bind_all` + help="""\ What host to listen to (default: localhost). To serve to the entire local network on both IPv4 and IPv6, see `--bind_all`, with which this option is mutually exclusive. -''') +""", + ) - parser.add_argument( - '--bind_all', - action='store_true', - help='''\ + parser.add_argument( + "--bind_all", + action="store_true", + help="""\ Serve on all public interfaces. This will expose your TensorBoard instance to the network on both IPv4 and IPv6 (where available). Mutually exclusive with `--host`. -''') - - - parser.add_argument( - '--port', - metavar='PORT', - type=lambda s: (None if s == "default" else int(s)), - default="default", - help='''\ +""", + ) + + parser.add_argument( + "--port", + metavar="PORT", + type=lambda s: (None if s == "default" else int(s)), + default="default", + help="""\ Port to serve TensorBoard on. 
Pass 0 to request an unused port selected by the operating system, or pass "default" to try to bind to the default port (%s) but search for a nearby free port if the default port is unavailable. (default: "default").\ -''' % DEFAULT_PORT) - - parser.add_argument( - '--purge_orphaned_data', - metavar='BOOL', - # Custom str-to-bool converter since regular bool() doesn't work. - type=lambda v: {'true': True, 'false': False}.get(v.lower(), v), - choices=[True, False], - default=True, - help='''\ +""" + % DEFAULT_PORT, + ) + + parser.add_argument( + "--purge_orphaned_data", + metavar="BOOL", + # Custom str-to-bool converter since regular bool() doesn't work. + type=lambda v: {"true": True, "false": False}.get(v.lower(), v), + choices=[True, False], + default=True, + help="""\ Whether to purge data that may have been orphaned due to TensorBoard restarts. Setting --purge_orphaned_data=False can be used to debug data disappearance. (default: %(default)s)\ -''') - - parser.add_argument( - '--db', - metavar='URI', - type=str, - default='', - help='''\ +""", + ) + + parser.add_argument( + "--db", + metavar="URI", + type=str, + default="", + help="""\ [experimental] sets SQL database URI and enables DB backend mode, which is read-only unless --db_import is also passed.\ -''') +""", + ) - parser.add_argument( - '--db_import', - action='store_true', - help='''\ + parser.add_argument( + "--db_import", + action="store_true", + help="""\ [experimental] enables DB read-and-import mode, which in combination with --logdir imports event files into a DB backend on the fly. The backing DB is temporary unless --db is also passed to specify a DB path to use.\ -''') +""", + ) - parser.add_argument( - '--inspect', - action='store_true', - help='''\ + parser.add_argument( + "--inspect", + action="store_true", + help="""\ Prints digests of event files to command line. 
This is useful when no data is shown on TensorBoard, or the data shown @@ -407,39 +443,43 @@ def define_flags(self, parser): `tensorboard --inspect --logdir mylogdir --tag loss` See tensorboard/backend/event_processing/event_file_inspector.py for more info.\ -''') - - # This flag has a "_tb" suffix to avoid conflicting with an internal flag - # named --version. Note that due to argparse auto-expansion of unambiguous - # flag prefixes, you can still invoke this as `tensorboard --version`. - parser.add_argument( - '--version_tb', - action='store_true', - help='Prints the version of Tensorboard') - - parser.add_argument( - '--tag', - metavar='TAG', - type=str, - default='', - help='tag to query for; used with --inspect') - - parser.add_argument( - '--event_file', - metavar='PATH', - type=str, - default='', - help='''\ +""", + ) + + # This flag has a "_tb" suffix to avoid conflicting with an internal flag + # named --version. Note that due to argparse auto-expansion of unambiguous + # flag prefixes, you can still invoke this as `tensorboard --version`. + parser.add_argument( + "--version_tb", + action="store_true", + help="Prints the version of Tensorboard", + ) + + parser.add_argument( + "--tag", + metavar="TAG", + type=str, + default="", + help="tag to query for; used with --inspect", + ) + + parser.add_argument( + "--event_file", + metavar="PATH", + type=str, + default="", + help="""\ The particular event file to query for. Only used if --inspect is present and --logdir is not specified.\ -''') - - parser.add_argument( - '--path_prefix', - metavar='PATH', - type=str, - default='', - help='''\ +""", + ) + + parser.add_argument( + "--path_prefix", + metavar="PATH", + type=str, + default="", + help="""\ An optional, relative prefix to the path, e.g. "/path/to/tensorboard". resulting in the new base url being located at localhost:6006/path/to/tensorboard under default settings. A leading @@ -447,59 +487,64 @@ def define_flags(self, parser): optional and has no effect. 
The path_prefix can be leveraged for path based routing of an ELB when the website base_url is not available e.g. "example.site.com/path/to/tensorboard/".\ -''') - - parser.add_argument( - '--window_title', - metavar='TEXT', - type=str, - default='', - help='changes title of browser window') - - parser.add_argument( - '--max_reload_threads', - metavar='COUNT', - type=int, - default=1, - help='''\ +""", + ) + + parser.add_argument( + "--window_title", + metavar="TEXT", + type=str, + default="", + help="changes title of browser window", + ) + + parser.add_argument( + "--max_reload_threads", + metavar="COUNT", + type=int, + default=1, + help="""\ The max number of threads that TensorBoard can use to reload runs. Not relevant for db read-only mode. Each thread reloads one run at a time. (default: %(default)s)\ -''') - - parser.add_argument( - '--reload_interval', - metavar='SECONDS', - type=float, - default=5.0, - help='''\ +""", + ) + + parser.add_argument( + "--reload_interval", + metavar="SECONDS", + type=float, + default=5.0, + help="""\ How often the backend should load more data, in seconds. Set to 0 to load just once at startup and a negative number to never reload at all. Not relevant for DB read-only mode. (default: %(default)s)\ -''') - - parser.add_argument( - '--reload_task', - metavar='TYPE', - type=str, - default='auto', - choices=['auto', 'thread', 'process', 'blocking'], - help='''\ +""", + ) + + parser.add_argument( + "--reload_task", + metavar="TYPE", + type=str, + default="auto", + choices=["auto", "thread", "process", "blocking"], + help="""\ [experimental] The mechanism to use for the background data reload task. The default "auto" option will conditionally use threads for legacy reloading and a child process for DB import reloading. The "process" option is only useful with DB import mode. The "blocking" option will block startup until reload finishes, and requires --load_interval=0. 
(default: %(default)s)\ -''') - - parser.add_argument( - '--reload_multifile', - metavar='BOOL', - # Custom str-to-bool converter since regular bool() doesn't work. - type=lambda v: {'true': True, 'false': False}.get(v.lower(), v), - choices=[True, False], - default=None, - help='''\ +""", + ) + + parser.add_argument( + "--reload_multifile", + metavar="BOOL", + # Custom str-to-bool converter since regular bool() doesn't work. + type=lambda v: {"true": True, "false": False}.get(v.lower(), v), + choices=[True, False], + default=None, + help="""\ [experimental] If true, this enables experimental support for continuously polling multiple event files in each run directory for newly appended data (rather than only polling the last event file). Event files will only be @@ -507,14 +552,15 @@ def define_flags(self, parser): defined by --reload_multifile_inactive_secs, to limit resource usage. Beware of running out of memory if the logdir contains many active event files. (default: false)\ -''') - - parser.add_argument( - '--reload_multifile_inactive_secs', - metavar='SECONDS', - type=int, - default=4000, - help='''\ +""", + ) + + parser.add_argument( + "--reload_multifile_inactive_secs", + metavar="SECONDS", + type=int, + default=4000, + help="""\ [experimental] Configures the age threshold in seconds at which an event file that has no event wall time more recent than that will be considered an inactive file and no longer polled (to limit resource usage). If set to -1, @@ -523,25 +569,27 @@ def define_flags(self, parser): last-file-only polling strategy (akin to --reload_multifile=false). 
(default: %(default)s - intended to ensure an event file remains active if it receives new data at least once per hour)\ -''') - - parser.add_argument( - '--generic_data', - metavar='TYPE', - type=str, - default='auto', - choices=['false', 'auto', 'true'], - help='''\ +""", + ) + + parser.add_argument( + "--generic_data", + metavar="TYPE", + type=str, + default="auto", + choices=["false", "auto", "true"], + help="""\ [experimental] Whether to use generic data provider infrastructure. The "auto" option enables this only for dashboards that are considered stable under the new codepaths. (default: %(default)s)\ -''') - - parser.add_argument( - '--samples_per_plugin', - type=str, - default='', - help='''\ +""", + ) + + parser.add_argument( + "--samples_per_plugin", + type=str, + default="", + help="""\ An optional comma separated list of plugin_name=num_samples pairs to explicitly specify how many samples to keep per tag for that plugin. For unspecified plugins, TensorBoard randomly downsamples logged summaries @@ -550,45 +598,54 @@ def define_flags(self, parser): means keep all samples of that type. For instance "scalars=500,images=0" keeps 500 scalars and all images. Most users should not need to set this flag.\ -''') - - def fix_flags(self, flags): - """Fixes standard TensorBoard CLI flags to parser.""" - FlagsError = base_plugin.FlagsError - if flags.version_tb: - pass - elif flags.inspect: - if flags.logdir_spec: - raise FlagsError('--logdir_spec is not supported with --inspect.') - if flags.logdir and flags.event_file: - raise FlagsError( - 'Must specify either --logdir or --event_file, but not both.') - if not (flags.logdir or flags.event_file): - raise FlagsError('Must specify either --logdir or --event_file.') - elif flags.logdir and flags.logdir_spec: - raise FlagsError( - 'May not specify both --logdir and --logdir_spec') - elif not flags.db and not flags.logdir and not flags.logdir_spec: - raise FlagsError('A logdir or db must be specified. 
' - 'For example `tensorboard --logdir mylogdir` ' - 'or `tensorboard --db sqlite:~/.tensorboard.db`. ' - 'Run `tensorboard --helpfull` for details and examples.') - elif flags.host is not None and flags.bind_all: - raise FlagsError('Must not specify both --host and --bind_all.') - - flags.path_prefix = flags.path_prefix.rstrip('/') - if flags.path_prefix and not flags.path_prefix.startswith('/'): - raise FlagsError( - 'Path prefix must start with slash, but got: %r.' % flags.path_prefix) - - def load(self, context): - """Creates CorePlugin instance.""" - return CorePlugin(context) +""", + ) + + def fix_flags(self, flags): + """Fixes standard TensorBoard CLI flags to parser.""" + FlagsError = base_plugin.FlagsError + if flags.version_tb: + pass + elif flags.inspect: + if flags.logdir_spec: + raise FlagsError( + "--logdir_spec is not supported with --inspect." + ) + if flags.logdir and flags.event_file: + raise FlagsError( + "Must specify either --logdir or --event_file, but not both." + ) + if not (flags.logdir or flags.event_file): + raise FlagsError( + "Must specify either --logdir or --event_file." + ) + elif flags.logdir and flags.logdir_spec: + raise FlagsError("May not specify both --logdir and --logdir_spec") + elif not flags.db and not flags.logdir and not flags.logdir_spec: + raise FlagsError( + "A logdir or db must be specified. " + "For example `tensorboard --logdir mylogdir` " + "or `tensorboard --db sqlite:~/.tensorboard.db`. " + "Run `tensorboard --helpfull` for details and examples." + ) + elif flags.host is not None and flags.bind_all: + raise FlagsError("Must not specify both --host and --bind_all.") + + flags.path_prefix = flags.path_prefix.rstrip("/") + if flags.path_prefix and not flags.path_prefix.startswith("/"): + raise FlagsError( + "Path prefix must start with slash, but got: %r." 
+ % flags.path_prefix + ) + + def load(self, context): + """Creates CorePlugin instance.""" + return CorePlugin(context) def _gzip(bytestring): - out = six.BytesIO() - # Set mtime to zero for deterministic results across TensorBoard launches. - with gzip.GzipFile(fileobj=out, mode='wb', compresslevel=3, mtime=0) as f: - f.write(bytestring) - return out.getvalue() + out = six.BytesIO() + # Set mtime to zero for deterministic results across TensorBoard launches. + with gzip.GzipFile(fileobj=out, mode="wb", compresslevel=3, mtime=0) as f: + f.write(bytestring) + return out.getvalue() diff --git a/tensorboard/plugins/core/core_plugin_test.py b/tensorboard/plugins/core/core_plugin_test.py index 8b928cd5dc..f7877b78cc 100644 --- a/tensorboard/plugins/core/core_plugin_test.py +++ b/tensorboard/plugins/core/core_plugin_test.py @@ -33,372 +33,436 @@ from werkzeug import wrappers from tensorboard.backend import application -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.compat.proto import graph_pb2 from tensorboard.compat.proto import meta_graph_pb2 from tensorboard.plugins import base_plugin from tensorboard.plugins.core import core_plugin from tensorboard.util import test_util -FAKE_INDEX_HTML = b'fake-index' +FAKE_INDEX_HTML = b"fake-index" class FakeFlags(object): - def __init__( - self, - bind_all=False, - host=None, - inspect=False, - version_tb=False, - logdir='', - logdir_spec='', - event_file='', - db='', - path_prefix=''): - self.bind_all = bind_all - self.host = host - self.inspect = inspect - self.version_tb = version_tb - self.logdir = logdir - self.logdir_spec = logdir_spec - self.event_file = event_file - self.db = db - self.path_prefix = path_prefix + def __init__( + self, + bind_all=False, + host=None, + inspect=False, + version_tb=False, + logdir="", + logdir_spec="", + event_file="", + 
db="", + path_prefix="", + ): + self.bind_all = bind_all + self.host = host + self.inspect = inspect + self.version_tb = version_tb + self.logdir = logdir + self.logdir_spec = logdir_spec + self.event_file = event_file + self.db = db + self.path_prefix = path_prefix class CorePluginTest(tf.test.TestCase): - _only_use_meta_graph = False # Server data contains only a GraphDef - - def setUp(self): - super(CorePluginTest, self).setUp() - self.temp_dir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, self.temp_dir) - self.db_path = os.path.join(self.temp_dir, 'db.db') - self.db = sqlite3.connect(self.db_path) - self.db_uri = 'sqlite:' + self.db_path - self._start_logdir_based_server(self.temp_dir) - self._start_db_based_server() - - def testRoutesProvided(self): - """Tests that the plugin offers the correct routes.""" - routes = self.logdir_based_plugin.get_plugin_apps() - self.assertIsInstance(routes['/data/logdir'], collections.Callable) - self.assertIsInstance(routes['/data/runs'], collections.Callable) - - def testFlag(self): - loader = core_plugin.CorePluginLoader() - loader.fix_flags(FakeFlags(version_tb=True)) - loader.fix_flags(FakeFlags(inspect=True, logdir='/tmp')) - loader.fix_flags(FakeFlags(inspect=True, event_file='/tmp/event.out')) - loader.fix_flags(FakeFlags(inspect=False, logdir='/tmp')) - loader.fix_flags(FakeFlags(inspect=False, db='sqlite:foo')) - # User can pass both, although the behavior is not clearly defined. 
- loader.fix_flags(FakeFlags(inspect=False, logdir='/tmp', db="sqlite:foo")) - - logdir_or_db_req = r'A logdir or db must be specified' - one_of_event_or_logdir_req = r'Must specify either --logdir.*but not both.$' - event_or_logdir_req = r'Must specify either --logdir or --event_file.$' - - with six.assertRaisesRegex(self, ValueError, event_or_logdir_req): - loader.fix_flags(FakeFlags(inspect=True)) - with six.assertRaisesRegex(self, ValueError, event_or_logdir_req): - loader.fix_flags(FakeFlags(inspect=True, db='sqlite:~/db.sqlite')) - with six.assertRaisesRegex(self, ValueError, one_of_event_or_logdir_req): - loader.fix_flags(FakeFlags(inspect=True, logdir='/tmp', - event_file='/tmp/event.out')) - with six.assertRaisesRegex(self, ValueError, logdir_or_db_req): - loader.fix_flags(FakeFlags(inspect=False)) - with six.assertRaisesRegex(self, ValueError, logdir_or_db_req): - loader.fix_flags(FakeFlags(inspect=False, event_file='/tmp/event.out')) - - def testPathPrefix_stripsTrailingSlashes(self): - loader = core_plugin.CorePluginLoader() - for path_prefix in ('/hello', '/hello/', '/hello//', '/hello///'): - flag = FakeFlags(inspect=False, logdir='/tmp', path_prefix=path_prefix) - loader.fix_flags(flag) - self.assertEqual( - flag.path_prefix, - '/hello', - 'got %r (input %r)' % (flag.path_prefix, path_prefix), - ) - - def testPathPrefix_mustStartWithSlash(self): - loader = core_plugin.CorePluginLoader() - flag = FakeFlags(inspect=False, logdir='/tmp', path_prefix='noslash') - with self.assertRaises(base_plugin.FlagsError) as cm: - loader.fix_flags(flag) - msg = str(cm.exception) - self.assertIn('must start with slash', msg) - self.assertIn(repr('noslash'), msg) - - def testIndex_returnsActualHtml(self): - """Test the format of the /data/runs endpoint.""" - response = self.logdir_based_server.get('/') - self.assertEqual(200, response.status_code) - self.assertStartsWith(response.headers.get('Content-Type'), 'text/html') - html = response.get_data() - 
self.assertEqual(html, FAKE_INDEX_HTML) - - def testDataPaths_disableAllCaching(self): - """Test the format of the /data/runs endpoint.""" - for path in ('/data/runs', '/data/logdir'): - response = self.logdir_based_server.get(path) - self.assertEqual(200, response.status_code, msg=path) - self.assertEqual('0', response.headers.get('Expires'), msg=path) - - def testEnvironmentForDbUri(self): - """Test that the environment route correctly returns the database URI.""" - parsed_object = self._get_json(self.db_based_server, '/data/environment') - self.assertEqual(parsed_object['data_location'], self.db_uri) - - def testEnvironmentForLogdir(self): - """Test that the environment route correctly returns the logdir.""" - parsed_object = self._get_json( - self.logdir_based_server, '/data/environment') - self.assertEqual(parsed_object['data_location'], self.logdir) - - def testEnvironmentForWindowTitle(self): - """Test that the environment route correctly returns the window title.""" - parsed_object_db = self._get_json( - self.db_based_server, '/data/environment') - parsed_object_logdir = self._get_json( - self.logdir_based_server, '/data/environment') - self.assertEqual( - parsed_object_db['window_title'], parsed_object_logdir['window_title']) - self.assertEqual(parsed_object_db['window_title'], 'title foo') - - def testLogdir(self): - """Test the format of the data/logdir endpoint.""" - parsed_object = self._get_json(self.logdir_based_server, '/data/logdir') - self.assertEqual(parsed_object, {'logdir': self.logdir}) - - @test_util.run_v1_only('Uses tf.contrib when adding runs.') - def testRuns(self): - """Test the format of the /data/runs endpoint.""" - self._add_run('run1') - run_json = self._get_json(self.db_based_server, '/data/runs') - self.assertEqual(run_json, ['run1']) - run_json = self._get_json(self.logdir_based_server, '/data/runs') - self.assertEqual(run_json, ['run1']) - - @test_util.run_v1_only('Uses tf.contrib when adding runs.') - def testExperiments(self): 
- """Test the format of the /data/experiments endpoint.""" - self._add_run('run1', experiment_name = 'exp1') - self._add_run('run2', experiment_name = 'exp1') - self._add_run('run3', experiment_name = 'exp2') - - [exp1, exp2] = self._get_json(self.db_based_server, '/data/experiments') - self.assertEqual(exp1.get('name'), 'exp1') - self.assertEqual(exp2.get('name'), 'exp2') - - exp_json = self._get_json(self.logdir_based_server, '/data/experiments') - self.assertEqual(exp_json, []) - - @test_util.run_v1_only('Uses tf.contrib when adding runs.') - def testExperimentRuns(self): - """Test the format of the /data/experiment_runs endpoint.""" - self._add_run('run1', experiment_name = 'exp1') - self._add_run('run2', experiment_name = 'exp1') - self._add_run('run3', experiment_name = 'exp2') - - [exp1, exp2] = self._get_json(self.db_based_server, '/data/experiments') - - exp1_runs = self._get_json(self.db_based_server, - '/experiment/%s/data/experiment_runs' % exp1.get('id')) - self.assertEqual(len(exp1_runs), 2); - self.assertEqual(exp1_runs[0].get('name'), 'run1'); - self.assertEqual(exp1_runs[1].get('name'), 'run2'); - self.assertEqual(len(exp1_runs[0].get('tags')), 1); - self.assertEqual(exp1_runs[0].get('tags')[0].get('name'), 'mytag'); - self.assertEqual(len(exp1_runs[1].get('tags')), 1); - self.assertEqual(exp1_runs[1].get('tags')[0].get('name'), 'mytag'); - - exp2_runs = self._get_json(self.db_based_server, - '/experiment/%s/data/experiment_runs' % exp2.get('id')) - self.assertEqual(len(exp2_runs), 1); - self.assertEqual(exp2_runs[0].get('name'), 'run3'); - - # TODO(stephanwlee): Write test on runs that do not have any tag. 
- - exp_json = self._get_json(self.logdir_based_server, '/data/experiments') - self.assertEqual(exp_json, []) - - @test_util.run_v1_only('Uses tf.contrib when adding runs.') - def testRunsAppendOnly(self): - """Test that new runs appear after old ones in /data/runs.""" - fake_wall_times = { - 'run1': 1234.0, - 'avocado': 2345.0, - 'zebra': 3456.0, - 'ox': 4567.0, - 'mysterious': None, - 'enigmatic': None, - } - - stubs = tf.compat.v1.test.StubOutForTesting() - def FirstEventTimestamp_stub(multiplexer_self, run_name): - del multiplexer_self - matches = [candidate_name - for candidate_name in fake_wall_times - if run_name.endswith(candidate_name)] - self.assertEqual(len(matches), 1, '%s (%s)' % (matches, run_name)) - wall_time = fake_wall_times[matches[0]] - if wall_time is None: - raise ValueError('No event timestamp could be found') - else: - return wall_time - - stubs.SmartSet(self.multiplexer, - 'FirstEventTimestamp', - FirstEventTimestamp_stub) - - # Start with a single run. - self._add_run('run1') - - # Add one run: it should come last. - self._add_run('avocado') - self.assertEqual(self._get_json(self.db_based_server, '/data/runs'), - ['run1', 'avocado']) - self.assertEqual(self._get_json(self.logdir_based_server, '/data/runs'), - ['run1', 'avocado']) - - # Add another run: it should come last, too. - self._add_run('zebra') - self.assertEqual(self._get_json(self.db_based_server, '/data/runs'), - ['run1', 'avocado', 'zebra']) - self.assertEqual(self._get_json(self.logdir_based_server, '/data/runs'), - ['run1', 'avocado', 'zebra']) - - # And maybe there's a run for which we somehow have no timestamp. 
- self._add_run('mysterious') - with self.db: - self.db.execute('UPDATE Runs SET started_time=NULL WHERE run_name=?', - ['mysterious']) - self.assertEqual(self._get_json(self.db_based_server, '/data/runs'), - ['run1', 'avocado', 'zebra', 'mysterious']) - self.assertEqual(self._get_json(self.logdir_based_server, '/data/runs'), - ['run1', 'avocado', 'zebra', 'mysterious']) - - # Add another timestamped run: it should come before the timestamp-less one. - self._add_run('ox') - self.assertEqual(self._get_json(self.db_based_server, '/data/runs'), - ['run1', 'avocado', 'zebra', 'ox', 'mysterious']) - self.assertEqual(self._get_json(self.logdir_based_server, '/data/runs'), - ['run1', 'avocado', 'zebra', 'ox', 'mysterious']) - - # Add another timestamp-less run, lexicographically before the other one: - # it should come after all timestamped runs but first among timestamp-less. - self._add_run('enigmatic') - with self.db: - self.db.execute('UPDATE Runs SET started_time=NULL WHERE run_name=?', - ['enigmatic']) - self.assertEqual( - self._get_json(self.db_based_server, '/data/runs'), - ['run1', 'avocado', 'zebra', 'ox', 'enigmatic', 'mysterious']) - self.assertEqual( - self._get_json(self.logdir_based_server, '/data/runs'), - ['run1', 'avocado', 'zebra', 'ox', 'enigmatic', 'mysterious']) - - stubs.CleanUp() - - def _start_logdir_based_server(self, temp_dir): - self.logdir = temp_dir - self.multiplexer = event_multiplexer.EventMultiplexer( - size_guidance=application.DEFAULT_SIZE_GUIDANCE, - purge_orphaned_data=True) - context = base_plugin.TBContext( - assets_zip_provider=get_test_assets_zip_provider(), - logdir=self.logdir, - multiplexer=self.multiplexer, - window_title='title foo') - self.logdir_based_plugin = core_plugin.CorePlugin(context) - app = application.TensorBoardWSGI([self.logdir_based_plugin]) - self.logdir_based_server = werkzeug_test.Client(app, wrappers.BaseResponse) - - def _start_db_based_server(self): - db_connection_provider = 
application.create_sqlite_connection_provider( - self.db_uri) - context = base_plugin.TBContext( - assets_zip_provider=get_test_assets_zip_provider(), - db_connection_provider=db_connection_provider, - db_uri=self.db_uri, - window_title='title foo') - self.db_based_plugin = core_plugin.CorePlugin(context) - app = application.TensorBoardWSGI([self.db_based_plugin]) - self.db_based_server = werkzeug_test.Client(app, wrappers.BaseResponse) - - def _add_run(self, run_name, experiment_name='experiment'): - self._generate_test_data(run_name, experiment_name) - self.multiplexer.AddRunsFromDirectory(self.logdir) - self.multiplexer.Reload() - - def _get_json(self, server, path): - response = server.get(path) - self.assertEqual(200, response.status_code) - return self._get_json_payload(response) - - def _get_json_payload(self, response): - self.assertStartsWith(response.headers.get('Content-Type'), - 'application/json') - return json.loads(response.get_data().decode('utf-8')) - - def _generate_test_data(self, run_name, experiment_name): - """Generates the test data directory. - - The test data has a single run of the given name, containing: - - a graph definition and metagraph definition - - Arguments: - run_name: The directory under self.logdir into which to write - events. - """ - run_path = os.path.join(self.logdir, run_name) - with test_util.FileWriterCache.get(run_path) as writer: - - # Add a simple graph event. - graph_def = graph_pb2.GraphDef() - node1 = graph_def.node.add() - node1.name = 'a' - node2 = graph_def.node.add() - node2.name = 'b' - node2.attr['very_large_attr'].s = b'a' * 2048 # 2 KB attribute - - meta_graph_def = meta_graph_pb2.MetaGraphDef(graph_def=graph_def) - - if self._only_use_meta_graph: - writer.add_meta_graph(meta_graph_def) - else: - writer.add_graph(graph=None, graph_def=graph_def) - - # Write data for the run to the database. - # TODO(nickfelt): Figure out why reseting the graph is necessary. 
- tf.compat.v1.reset_default_graph() - db_writer = tf.contrib.summary.create_db_writer( - db_uri=self.db_path, - experiment_name=experiment_name, - run_name=run_name, - user_name='user') - with db_writer.as_default(), tf.contrib.summary.always_record_summaries(): - tf.contrib.summary.scalar('mytag', 1) - - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - sess.run(tf.contrib.summary.summary_writer_initializer_op()) - sess.run(tf.contrib.summary.all_summary_ops()) + _only_use_meta_graph = False # Server data contains only a GraphDef + + def setUp(self): + super(CorePluginTest, self).setUp() + self.temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, self.temp_dir) + self.db_path = os.path.join(self.temp_dir, "db.db") + self.db = sqlite3.connect(self.db_path) + self.db_uri = "sqlite:" + self.db_path + self._start_logdir_based_server(self.temp_dir) + self._start_db_based_server() + + def testRoutesProvided(self): + """Tests that the plugin offers the correct routes.""" + routes = self.logdir_based_plugin.get_plugin_apps() + self.assertIsInstance(routes["/data/logdir"], collections.Callable) + self.assertIsInstance(routes["/data/runs"], collections.Callable) + + def testFlag(self): + loader = core_plugin.CorePluginLoader() + loader.fix_flags(FakeFlags(version_tb=True)) + loader.fix_flags(FakeFlags(inspect=True, logdir="/tmp")) + loader.fix_flags(FakeFlags(inspect=True, event_file="/tmp/event.out")) + loader.fix_flags(FakeFlags(inspect=False, logdir="/tmp")) + loader.fix_flags(FakeFlags(inspect=False, db="sqlite:foo")) + # User can pass both, although the behavior is not clearly defined. 
+ loader.fix_flags( + FakeFlags(inspect=False, logdir="/tmp", db="sqlite:foo") + ) + + logdir_or_db_req = r"A logdir or db must be specified" + one_of_event_or_logdir_req = ( + r"Must specify either --logdir.*but not both.$" + ) + event_or_logdir_req = r"Must specify either --logdir or --event_file.$" + + with six.assertRaisesRegex(self, ValueError, event_or_logdir_req): + loader.fix_flags(FakeFlags(inspect=True)) + with six.assertRaisesRegex(self, ValueError, event_or_logdir_req): + loader.fix_flags(FakeFlags(inspect=True, db="sqlite:~/db.sqlite")) + with six.assertRaisesRegex( + self, ValueError, one_of_event_or_logdir_req + ): + loader.fix_flags( + FakeFlags( + inspect=True, logdir="/tmp", event_file="/tmp/event.out" + ) + ) + with six.assertRaisesRegex(self, ValueError, logdir_or_db_req): + loader.fix_flags(FakeFlags(inspect=False)) + with six.assertRaisesRegex(self, ValueError, logdir_or_db_req): + loader.fix_flags( + FakeFlags(inspect=False, event_file="/tmp/event.out") + ) + + def testPathPrefix_stripsTrailingSlashes(self): + loader = core_plugin.CorePluginLoader() + for path_prefix in ("/hello", "/hello/", "/hello//", "/hello///"): + flag = FakeFlags( + inspect=False, logdir="/tmp", path_prefix=path_prefix + ) + loader.fix_flags(flag) + self.assertEqual( + flag.path_prefix, + "/hello", + "got %r (input %r)" % (flag.path_prefix, path_prefix), + ) + + def testPathPrefix_mustStartWithSlash(self): + loader = core_plugin.CorePluginLoader() + flag = FakeFlags(inspect=False, logdir="/tmp", path_prefix="noslash") + with self.assertRaises(base_plugin.FlagsError) as cm: + loader.fix_flags(flag) + msg = str(cm.exception) + self.assertIn("must start with slash", msg) + self.assertIn(repr("noslash"), msg) + + def testIndex_returnsActualHtml(self): + """Test the format of the /data/runs endpoint.""" + response = self.logdir_based_server.get("/") + self.assertEqual(200, response.status_code) + self.assertStartsWith(response.headers.get("Content-Type"), "text/html") + html 
= response.get_data() + self.assertEqual(html, FAKE_INDEX_HTML) + + def testDataPaths_disableAllCaching(self): + """Test the format of the /data/runs endpoint.""" + for path in ("/data/runs", "/data/logdir"): + response = self.logdir_based_server.get(path) + self.assertEqual(200, response.status_code, msg=path) + self.assertEqual("0", response.headers.get("Expires"), msg=path) + + def testEnvironmentForDbUri(self): + """Test that the environment route correctly returns the database + URI.""" + parsed_object = self._get_json( + self.db_based_server, "/data/environment" + ) + self.assertEqual(parsed_object["data_location"], self.db_uri) + + def testEnvironmentForLogdir(self): + """Test that the environment route correctly returns the logdir.""" + parsed_object = self._get_json( + self.logdir_based_server, "/data/environment" + ) + self.assertEqual(parsed_object["data_location"], self.logdir) + + def testEnvironmentForWindowTitle(self): + """Test that the environment route correctly returns the window + title.""" + parsed_object_db = self._get_json( + self.db_based_server, "/data/environment" + ) + parsed_object_logdir = self._get_json( + self.logdir_based_server, "/data/environment" + ) + self.assertEqual( + parsed_object_db["window_title"], + parsed_object_logdir["window_title"], + ) + self.assertEqual(parsed_object_db["window_title"], "title foo") + + def testLogdir(self): + """Test the format of the data/logdir endpoint.""" + parsed_object = self._get_json(self.logdir_based_server, "/data/logdir") + self.assertEqual(parsed_object, {"logdir": self.logdir}) + + @test_util.run_v1_only("Uses tf.contrib when adding runs.") + def testRuns(self): + """Test the format of the /data/runs endpoint.""" + self._add_run("run1") + run_json = self._get_json(self.db_based_server, "/data/runs") + self.assertEqual(run_json, ["run1"]) + run_json = self._get_json(self.logdir_based_server, "/data/runs") + self.assertEqual(run_json, ["run1"]) + + @test_util.run_v1_only("Uses tf.contrib 
when adding runs.") + def testExperiments(self): + """Test the format of the /data/experiments endpoint.""" + self._add_run("run1", experiment_name="exp1") + self._add_run("run2", experiment_name="exp1") + self._add_run("run3", experiment_name="exp2") + + [exp1, exp2] = self._get_json(self.db_based_server, "/data/experiments") + self.assertEqual(exp1.get("name"), "exp1") + self.assertEqual(exp2.get("name"), "exp2") + + exp_json = self._get_json(self.logdir_based_server, "/data/experiments") + self.assertEqual(exp_json, []) + + @test_util.run_v1_only("Uses tf.contrib when adding runs.") + def testExperimentRuns(self): + """Test the format of the /data/experiment_runs endpoint.""" + self._add_run("run1", experiment_name="exp1") + self._add_run("run2", experiment_name="exp1") + self._add_run("run3", experiment_name="exp2") + + [exp1, exp2] = self._get_json(self.db_based_server, "/data/experiments") + + exp1_runs = self._get_json( + self.db_based_server, + "/experiment/%s/data/experiment_runs" % exp1.get("id"), + ) + self.assertEqual(len(exp1_runs), 2) + self.assertEqual(exp1_runs[0].get("name"), "run1") + self.assertEqual(exp1_runs[1].get("name"), "run2") + self.assertEqual(len(exp1_runs[0].get("tags")), 1) + self.assertEqual(exp1_runs[0].get("tags")[0].get("name"), "mytag") + self.assertEqual(len(exp1_runs[1].get("tags")), 1) + self.assertEqual(exp1_runs[1].get("tags")[0].get("name"), "mytag") + + exp2_runs = self._get_json( + self.db_based_server, + "/experiment/%s/data/experiment_runs" % exp2.get("id"), + ) + self.assertEqual(len(exp2_runs), 1) + self.assertEqual(exp2_runs[0].get("name"), "run3") + + # TODO(stephanwlee): Write test on runs that do not have any tag. 
+ + exp_json = self._get_json(self.logdir_based_server, "/data/experiments") + self.assertEqual(exp_json, []) + + @test_util.run_v1_only("Uses tf.contrib when adding runs.") + def testRunsAppendOnly(self): + """Test that new runs appear after old ones in /data/runs.""" + fake_wall_times = { + "run1": 1234.0, + "avocado": 2345.0, + "zebra": 3456.0, + "ox": 4567.0, + "mysterious": None, + "enigmatic": None, + } + + stubs = tf.compat.v1.test.StubOutForTesting() + + def FirstEventTimestamp_stub(multiplexer_self, run_name): + del multiplexer_self + matches = [ + candidate_name + for candidate_name in fake_wall_times + if run_name.endswith(candidate_name) + ] + self.assertEqual(len(matches), 1, "%s (%s)" % (matches, run_name)) + wall_time = fake_wall_times[matches[0]] + if wall_time is None: + raise ValueError("No event timestamp could be found") + else: + return wall_time + + stubs.SmartSet( + self.multiplexer, "FirstEventTimestamp", FirstEventTimestamp_stub + ) + + # Start with a single run. + self._add_run("run1") + + # Add one run: it should come last. + self._add_run("avocado") + self.assertEqual( + self._get_json(self.db_based_server, "/data/runs"), + ["run1", "avocado"], + ) + self.assertEqual( + self._get_json(self.logdir_based_server, "/data/runs"), + ["run1", "avocado"], + ) + + # Add another run: it should come last, too. + self._add_run("zebra") + self.assertEqual( + self._get_json(self.db_based_server, "/data/runs"), + ["run1", "avocado", "zebra"], + ) + self.assertEqual( + self._get_json(self.logdir_based_server, "/data/runs"), + ["run1", "avocado", "zebra"], + ) + + # And maybe there's a run for which we somehow have no timestamp. 
+ self._add_run("mysterious") + with self.db: + self.db.execute( + "UPDATE Runs SET started_time=NULL WHERE run_name=?", + ["mysterious"], + ) + self.assertEqual( + self._get_json(self.db_based_server, "/data/runs"), + ["run1", "avocado", "zebra", "mysterious"], + ) + self.assertEqual( + self._get_json(self.logdir_based_server, "/data/runs"), + ["run1", "avocado", "zebra", "mysterious"], + ) + + # Add another timestamped run: it should come before the timestamp-less one. + self._add_run("ox") + self.assertEqual( + self._get_json(self.db_based_server, "/data/runs"), + ["run1", "avocado", "zebra", "ox", "mysterious"], + ) + self.assertEqual( + self._get_json(self.logdir_based_server, "/data/runs"), + ["run1", "avocado", "zebra", "ox", "mysterious"], + ) + + # Add another timestamp-less run, lexicographically before the other one: + # it should come after all timestamped runs but first among timestamp-less. + self._add_run("enigmatic") + with self.db: + self.db.execute( + "UPDATE Runs SET started_time=NULL WHERE run_name=?", + ["enigmatic"], + ) + self.assertEqual( + self._get_json(self.db_based_server, "/data/runs"), + ["run1", "avocado", "zebra", "ox", "enigmatic", "mysterious"], + ) + self.assertEqual( + self._get_json(self.logdir_based_server, "/data/runs"), + ["run1", "avocado", "zebra", "ox", "enigmatic", "mysterious"], + ) + + stubs.CleanUp() + + def _start_logdir_based_server(self, temp_dir): + self.logdir = temp_dir + self.multiplexer = event_multiplexer.EventMultiplexer( + size_guidance=application.DEFAULT_SIZE_GUIDANCE, + purge_orphaned_data=True, + ) + context = base_plugin.TBContext( + assets_zip_provider=get_test_assets_zip_provider(), + logdir=self.logdir, + multiplexer=self.multiplexer, + window_title="title foo", + ) + self.logdir_based_plugin = core_plugin.CorePlugin(context) + app = application.TensorBoardWSGI([self.logdir_based_plugin]) + self.logdir_based_server = werkzeug_test.Client( + app, wrappers.BaseResponse + ) + + def 
_start_db_based_server(self): + db_connection_provider = application.create_sqlite_connection_provider( + self.db_uri + ) + context = base_plugin.TBContext( + assets_zip_provider=get_test_assets_zip_provider(), + db_connection_provider=db_connection_provider, + db_uri=self.db_uri, + window_title="title foo", + ) + self.db_based_plugin = core_plugin.CorePlugin(context) + app = application.TensorBoardWSGI([self.db_based_plugin]) + self.db_based_server = werkzeug_test.Client(app, wrappers.BaseResponse) + + def _add_run(self, run_name, experiment_name="experiment"): + self._generate_test_data(run_name, experiment_name) + self.multiplexer.AddRunsFromDirectory(self.logdir) + self.multiplexer.Reload() + + def _get_json(self, server, path): + response = server.get(path) + self.assertEqual(200, response.status_code) + return self._get_json_payload(response) + + def _get_json_payload(self, response): + self.assertStartsWith( + response.headers.get("Content-Type"), "application/json" + ) + return json.loads(response.get_data().decode("utf-8")) + + def _generate_test_data(self, run_name, experiment_name): + """Generates the test data directory. + + The test data has a single run of the given name, containing: + - a graph definition and metagraph definition + + Arguments: + run_name: The directory under self.logdir into which to write + events. + """ + run_path = os.path.join(self.logdir, run_name) + with test_util.FileWriterCache.get(run_path) as writer: + + # Add a simple graph event. + graph_def = graph_pb2.GraphDef() + node1 = graph_def.node.add() + node1.name = "a" + node2 = graph_def.node.add() + node2.name = "b" + node2.attr["very_large_attr"].s = b"a" * 2048 # 2 KB attribute + + meta_graph_def = meta_graph_pb2.MetaGraphDef(graph_def=graph_def) + + if self._only_use_meta_graph: + writer.add_meta_graph(meta_graph_def) + else: + writer.add_graph(graph=None, graph_def=graph_def) + + # Write data for the run to the database. 
+ # TODO(nickfelt): Figure out why reseting the graph is necessary. + tf.compat.v1.reset_default_graph() + db_writer = tf.contrib.summary.create_db_writer( + db_uri=self.db_path, + experiment_name=experiment_name, + run_name=run_name, + user_name="user", + ) + with db_writer.as_default(), tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("mytag", 1) + + with tf.compat.v1.Session() as sess: + sess.run(tf.compat.v1.global_variables_initializer()) + sess.run(tf.contrib.summary.summary_writer_initializer_op()) + sess.run(tf.contrib.summary.all_summary_ops()) class CorePluginUsingMetagraphOnlyTest(CorePluginTest): - # Tests new ability to use only the MetaGraphDef - _only_use_meta_graph = True # Server data contains only a MetaGraphDef + # Tests new ability to use only the MetaGraphDef + _only_use_meta_graph = True # Server data contains only a MetaGraphDef def get_test_assets_zip_provider(): - memfile = six.BytesIO() - with zipfile.ZipFile(memfile, mode='w', compression=zipfile.ZIP_DEFLATED) as zf: - zf.writestr('index.html', FAKE_INDEX_HTML) - return lambda: contextlib.closing(six.BytesIO(memfile.getvalue())) + memfile = six.BytesIO() + with zipfile.ZipFile( + memfile, mode="w", compression=zipfile.ZIP_DEFLATED + ) as zf: + zf.writestr("index.html", FAKE_INDEX_HTML) + return lambda: contextlib.closing(six.BytesIO(memfile.getvalue())) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/custom_scalar/custom_scalar_demo.py b/tensorboard/plugins/custom_scalar/custom_scalar_demo.py index 8063e187fe..8e2166cb01 100644 --- a/tensorboard/plugins/custom_scalar/custom_scalar_demo.py +++ b/tensorboard/plugins/custom_scalar/custom_scalar_demo.py @@ -14,7 +14,8 @@ # ============================================================================== """Create sample PR curve summary data. -The logic below logs scalar data and then lays out the custom scalars dashboard. 
+The logic below logs scalar data and then lays out the custom scalars +dashboard. """ from __future__ import absolute_import @@ -29,85 +30,103 @@ from tensorboard.plugins.custom_scalar import layout_pb2 -LOGDIR = '/tmp/custom_scalar_demo' +LOGDIR = "/tmp/custom_scalar_demo" def run(): - """Run custom scalar demo and generate event files.""" - step = tf.compat.v1.placeholder(tf.float32, shape=[]) - - with tf.name_scope('loss'): - # Specify 2 different loss values, each tagged differently. - summary_lib.scalar('foo', tf.pow(0.9, step)) - summary_lib.scalar('bar', tf.pow(0.85, step + 2)) - - # Log metric baz as well as upper and lower bounds for a margin chart. - middle_baz_value = step + 4 * tf.random.uniform([]) - 2 - summary_lib.scalar('baz', middle_baz_value) - summary_lib.scalar('baz_lower', - middle_baz_value - 6.42 - tf.random.uniform([])) - summary_lib.scalar('baz_upper', - middle_baz_value + 6.42 + tf.random.uniform([])) - - with tf.name_scope('trigFunctions'): - summary_lib.scalar('cosine', tf.cos(step)) - summary_lib.scalar('sine', tf.sin(step)) - summary_lib.scalar('tangent', tf.tan(step)) - - merged_summary = tf.compat.v1.summary.merge_all() - - with tf.compat.v1.Session() as sess, tf.summary.FileWriter(LOGDIR) as writer: - # We only need to specify the layout once (instead of per step). 
- layout_summary = summary_lib.custom_scalar_pb( - layout_pb2.Layout(category=[ - layout_pb2.Category( - title='losses', - chart=[ - layout_pb2.Chart( - title='losses', - multiline=layout_pb2.MultilineChartContent( - tag=[r'loss(?!.*margin.*)'],)), - layout_pb2.Chart( - title='baz', - margin=layout_pb2.MarginChartContent( - series=[ - layout_pb2.MarginChartContent.Series( - value='loss/baz/scalar_summary', - lower='loss/baz_lower/scalar_summary', - upper='loss/baz_upper/scalar_summary' + """Run custom scalar demo and generate event files.""" + step = tf.compat.v1.placeholder(tf.float32, shape=[]) + + with tf.name_scope("loss"): + # Specify 2 different loss values, each tagged differently. + summary_lib.scalar("foo", tf.pow(0.9, step)) + summary_lib.scalar("bar", tf.pow(0.85, step + 2)) + + # Log metric baz as well as upper and lower bounds for a margin chart. + middle_baz_value = step + 4 * tf.random.uniform([]) - 2 + summary_lib.scalar("baz", middle_baz_value) + summary_lib.scalar( + "baz_lower", middle_baz_value - 6.42 - tf.random.uniform([]) + ) + summary_lib.scalar( + "baz_upper", middle_baz_value + 6.42 + tf.random.uniform([]) + ) + + with tf.name_scope("trigFunctions"): + summary_lib.scalar("cosine", tf.cos(step)) + summary_lib.scalar("sine", tf.sin(step)) + summary_lib.scalar("tangent", tf.tan(step)) + + merged_summary = tf.compat.v1.summary.merge_all() + + with tf.compat.v1.Session() as sess, tf.summary.FileWriter( + LOGDIR + ) as writer: + # We only need to specify the layout once (instead of per step). 
+ layout_summary = summary_lib.custom_scalar_pb( + layout_pb2.Layout( + category=[ + layout_pb2.Category( + title="losses", + chart=[ + layout_pb2.Chart( + title="losses", + multiline=layout_pb2.MultilineChartContent( + tag=[r"loss(?!.*margin.*)"], ), - ],)), - ]), - layout_pb2.Category( - title='trig functions', - chart=[ - layout_pb2.Chart( - title='wave trig functions', - multiline=layout_pb2.MultilineChartContent( - tag=[ - r'trigFunctions/cosine', r'trigFunctions/sine' - ],)), - # The range of tangent is different. Give it its own chart. - layout_pb2.Chart( - title='tan', - multiline=layout_pb2.MultilineChartContent( - tag=[r'trigFunctions/tangent'],)), - ], - # This category we care less about. Make it initially closed. - closed=True), - ])) - writer.add_summary(layout_summary) - - for i in xrange(42): - summary = sess.run(merged_summary, feed_dict={step: i}) - writer.add_summary(summary, global_step=i) + ), + layout_pb2.Chart( + title="baz", + margin=layout_pb2.MarginChartContent( + series=[ + layout_pb2.MarginChartContent.Series( + value="loss/baz/scalar_summary", + lower="loss/baz_lower/scalar_summary", + upper="loss/baz_upper/scalar_summary", + ), + ], + ), + ), + ], + ), + layout_pb2.Category( + title="trig functions", + chart=[ + layout_pb2.Chart( + title="wave trig functions", + multiline=layout_pb2.MultilineChartContent( + tag=[ + r"trigFunctions/cosine", + r"trigFunctions/sine", + ], + ), + ), + # The range of tangent is different. Give it its own chart. + layout_pb2.Chart( + title="tan", + multiline=layout_pb2.MultilineChartContent( + tag=[r"trigFunctions/tangent"], + ), + ), + ], + # This category we care less about. Make it initially closed. + closed=True, + ), + ] + ) + ) + writer.add_summary(layout_summary) + + for i in xrange(42): + summary = sess.run(merged_summary, feed_dict={step: i}) + writer.add_summary(summary, global_step=i) def main(unused_argv): - print('Saving output to %s.' % LOGDIR) - run() - print('Done. Output saved to %s.' 
% LOGDIR) + print("Saving output to %s." % LOGDIR) + run() + print("Done. Output saved to %s." % LOGDIR) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py b/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py index a1f421537f..9212482583 100644 --- a/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py +++ b/tensorboard/plugins/custom_scalar/custom_scalars_plugin.py @@ -43,11 +43,11 @@ # The name of the property in the response for whether the regex is valid. -_REGEX_VALID_PROPERTY = 'regex_valid' +_REGEX_VALID_PROPERTY = "regex_valid" # The name of the property in the response for the payload (tag to ScalarEvents # mapping). -_TAG_TO_EVENTS_PROPERTY = 'tag_to_events' +_TAG_TO_EVENTS_PROPERTY = "tag_to_events" # The number of seconds to wait in between checks for the config file specifying # layout. @@ -55,199 +55,216 @@ class CustomScalarsPlugin(base_plugin.TBPlugin): - """CustomScalars Plugin for TensorBoard.""" - - plugin_name = metadata.PLUGIN_NAME - - def __init__(self, context): - """Instantiates ScalarsPlugin via TensorBoard core. - - Args: - context: A base_plugin.TBContext instance. - """ - self._logdir = context.logdir - self._multiplexer = context.multiplexer - self._plugin_name_to_instance = context.plugin_name_to_instance - - def _get_scalars_plugin(self): - """Tries to get the scalars plugin. - - Returns: - The scalars plugin. Or None if it is not yet registered. - """ - if scalars_metadata.PLUGIN_NAME in self._plugin_name_to_instance: - # The plugin is registered. - return self._plugin_name_to_instance[scalars_metadata.PLUGIN_NAME] - # The plugin is not yet registered. - return None - - def get_plugin_apps(self): - return { - '/download_data': self.download_data_route, - '/layout': self.layout_route, - '/scalars': self.scalars_route, - } - - def is_active(self): - """This plugin is active if 2 conditions hold. - - 1. 
The scalars plugin is registered and active. - 2. There is a custom layout for the dashboard. - - Returns: A boolean. Whether the plugin is active. - """ - if not self._multiplexer: - return False - - scalars_plugin_instance = self._get_scalars_plugin() - if not (scalars_plugin_instance and - scalars_plugin_instance.is_active()): - return False - - # This plugin is active if any run has a layout. - return bool(self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME)) - - def frontend_metadata(self): - return base_plugin.FrontendMetadata( - element_name='tf-custom-scalar-dashboard', - tab_name='Custom Scalars', - ) - - @wrappers.Request.application - def download_data_route(self, request): - run = request.args.get('run') - tag = request.args.get('tag') - response_format = request.args.get('format') - try: - body, mime_type = self.download_data_impl(run, tag, response_format) - except ValueError as e: - return http_util.Respond( - request=request, - content=str(e), - content_type='text/plain', - code=400) - return http_util.Respond(request, body, mime_type) - - def download_data_impl(self, run, tag, response_format): - """Provides a response for downloading scalars data for a data series. - - Args: - run: The run. - tag: The specific tag. - response_format: A string. One of the values of the OutputFormat enum of - the scalar plugin. - - Raises: - ValueError: If the scalars plugin is not registered. - - Returns: - 2 entities: - - A JSON object response body. - - A mime type (string) for the response. - """ - scalars_plugin_instance = self._get_scalars_plugin() - if not scalars_plugin_instance: - raise ValueError(('Failed to respond to request for /download_data. ' - 'The scalars plugin is oddly not registered.')) - - body, mime_type = scalars_plugin_instance.scalars_impl( - tag, run, None, response_format) - return body, mime_type - - @wrappers.Request.application - def scalars_route(self, request): - """Given a tag regex and single run, return ScalarEvents. 
- - This route takes 2 GET params: - run: A run string to find tags for. - tag: A string that is a regex used to find matching tags. - The response is a JSON object: - { - // Whether the regular expression is valid. Also false if empty. - regexValid: boolean, - - // An object mapping tag name to a list of ScalarEvents. - payload: Object, - } - """ - tag_regex_string = request.args.get('tag') - run = request.args.get('run') - mime_type = 'application/json' - - try: - body = self.scalars_impl(run, tag_regex_string) - except ValueError as e: - return http_util.Respond( - request=request, - content=str(e), - content_type='text/plain', - code=400) - - # Produce the response. - return http_util.Respond(request, body, mime_type) - - def scalars_impl(self, run, tag_regex_string): - """Given a tag regex and single run, return ScalarEvents. - - Args: - run: A run string. - tag_regex_string: A regular expression that captures portions of tags. - - Raises: - ValueError: if the scalars plugin is not registered. - - Returns: - A dictionary that is the JSON-able response. - """ - if not tag_regex_string: - # The user provided no regex. - return { - _REGEX_VALID_PROPERTY: False, - _TAG_TO_EVENTS_PROPERTY: {}, - } - - # Construct the regex. - try: - regex = re.compile(tag_regex_string) - except re.error: - return { - _REGEX_VALID_PROPERTY: False, - _TAG_TO_EVENTS_PROPERTY: {}, - } - - # Fetch the tags for the run. Filter for tags that match the regex. - run_to_data = self._multiplexer.PluginRunToTagToContent( - scalars_metadata.PLUGIN_NAME) - - tag_to_data = None - try: - tag_to_data = run_to_data[run] - except KeyError: - # The run could not be found. Perhaps a configuration specified a run that - # TensorBoard has not read from disk yet. - payload = {} - - if tag_to_data: - scalars_plugin_instance = self._get_scalars_plugin() - if not scalars_plugin_instance: - raise ValueError(('Failed to respond to request for /scalars. 
' - 'The scalars plugin is oddly not registered.')) - - form = scalars_plugin.OutputFormat.JSON - payload = { - tag: scalars_plugin_instance.scalars_impl(tag, run, None, form)[0] - for tag in tag_to_data.keys() - if regex.match(tag) - } - - return { - _REGEX_VALID_PROPERTY: True, - _TAG_TO_EVENTS_PROPERTY: payload, - } - - @wrappers.Request.application - def layout_route(self, request): - r"""Fetches the custom layout specified by the config file in the logdir. + """CustomScalars Plugin for TensorBoard.""" + + plugin_name = metadata.PLUGIN_NAME + + def __init__(self, context): + """Instantiates ScalarsPlugin via TensorBoard core. + + Args: + context: A base_plugin.TBContext instance. + """ + self._logdir = context.logdir + self._multiplexer = context.multiplexer + self._plugin_name_to_instance = context.plugin_name_to_instance + + def _get_scalars_plugin(self): + """Tries to get the scalars plugin. + + Returns: + The scalars plugin. Or None if it is not yet registered. + """ + if scalars_metadata.PLUGIN_NAME in self._plugin_name_to_instance: + # The plugin is registered. + return self._plugin_name_to_instance[scalars_metadata.PLUGIN_NAME] + # The plugin is not yet registered. + return None + + def get_plugin_apps(self): + return { + "/download_data": self.download_data_route, + "/layout": self.layout_route, + "/scalars": self.scalars_route, + } + + def is_active(self): + """This plugin is active if 2 conditions hold. + + 1. The scalars plugin is registered and active. + 2. There is a custom layout for the dashboard. + + Returns: A boolean. Whether the plugin is active. + """ + if not self._multiplexer: + return False + + scalars_plugin_instance = self._get_scalars_plugin() + if not ( + scalars_plugin_instance and scalars_plugin_instance.is_active() + ): + return False + + # This plugin is active if any run has a layout. 
+ return bool( + self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) + ) + + def frontend_metadata(self): + return base_plugin.FrontendMetadata( + element_name="tf-custom-scalar-dashboard", + tab_name="Custom Scalars", + ) + + @wrappers.Request.application + def download_data_route(self, request): + run = request.args.get("run") + tag = request.args.get("tag") + response_format = request.args.get("format") + try: + body, mime_type = self.download_data_impl(run, tag, response_format) + except ValueError as e: + return http_util.Respond( + request=request, + content=str(e), + content_type="text/plain", + code=400, + ) + return http_util.Respond(request, body, mime_type) + + def download_data_impl(self, run, tag, response_format): + """Provides a response for downloading scalars data for a data series. + + Args: + run: The run. + tag: The specific tag. + response_format: A string. One of the values of the OutputFormat enum of + the scalar plugin. + + Raises: + ValueError: If the scalars plugin is not registered. + + Returns: + 2 entities: + - A JSON object response body. + - A mime type (string) for the response. + """ + scalars_plugin_instance = self._get_scalars_plugin() + if not scalars_plugin_instance: + raise ValueError( + ( + "Failed to respond to request for /download_data. " + "The scalars plugin is oddly not registered." + ) + ) + + body, mime_type = scalars_plugin_instance.scalars_impl( + tag, run, None, response_format + ) + return body, mime_type + + @wrappers.Request.application + def scalars_route(self, request): + """Given a tag regex and single run, return ScalarEvents. + + This route takes 2 GET params: + run: A run string to find tags for. + tag: A string that is a regex used to find matching tags. + The response is a JSON object: + { + // Whether the regular expression is valid. Also false if empty. + regexValid: boolean, + + // An object mapping tag name to a list of ScalarEvents. 
+ payload: Object, + } + """ + tag_regex_string = request.args.get("tag") + run = request.args.get("run") + mime_type = "application/json" + + try: + body = self.scalars_impl(run, tag_regex_string) + except ValueError as e: + return http_util.Respond( + request=request, + content=str(e), + content_type="text/plain", + code=400, + ) + + # Produce the response. + return http_util.Respond(request, body, mime_type) + + def scalars_impl(self, run, tag_regex_string): + """Given a tag regex and single run, return ScalarEvents. + + Args: + run: A run string. + tag_regex_string: A regular expression that captures portions of tags. + + Raises: + ValueError: if the scalars plugin is not registered. + + Returns: + A dictionary that is the JSON-able response. + """ + if not tag_regex_string: + # The user provided no regex. + return { + _REGEX_VALID_PROPERTY: False, + _TAG_TO_EVENTS_PROPERTY: {}, + } + + # Construct the regex. + try: + regex = re.compile(tag_regex_string) + except re.error: + return { + _REGEX_VALID_PROPERTY: False, + _TAG_TO_EVENTS_PROPERTY: {}, + } + + # Fetch the tags for the run. Filter for tags that match the regex. + run_to_data = self._multiplexer.PluginRunToTagToContent( + scalars_metadata.PLUGIN_NAME + ) + + tag_to_data = None + try: + tag_to_data = run_to_data[run] + except KeyError: + # The run could not be found. Perhaps a configuration specified a run that + # TensorBoard has not read from disk yet. + payload = {} + + if tag_to_data: + scalars_plugin_instance = self._get_scalars_plugin() + if not scalars_plugin_instance: + raise ValueError( + ( + "Failed to respond to request for /scalars. " + "The scalars plugin is oddly not registered." 
+ ) + ) + + form = scalars_plugin.OutputFormat.JSON + payload = { + tag: scalars_plugin_instance.scalars_impl(tag, run, None, form)[ + 0 + ] + for tag in tag_to_data.keys() + if regex.match(tag) + } + + return { + _REGEX_VALID_PROPERTY: True, + _TAG_TO_EVENTS_PROPERTY: payload, + } + + @wrappers.Request.application + def layout_route(self, request): + r"""Fetches the custom layout specified by the config file in the logdir. If more than 1 run contains a layout, this method merges the layouts by merging charts within individual categories. If 2 categories with the same @@ -259,50 +276,60 @@ def layout_route(self, request): The response is an empty object if no layout could be found. """ - body = self.layout_impl() - return http_util.Respond(request, body, 'application/json') - - def layout_impl(self): - # Keep a mapping between and category so we do not create duplicate - # categories. - title_to_category = {} - - merged_layout = None - runs = list(self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME)) - runs.sort() - for run in runs: - tensor_events = self._multiplexer.Tensors( - run, metadata.CONFIG_SUMMARY_TAG) - - # This run has a layout. Merge it with the ones currently found. - string_array = tensor_util.make_ndarray(tensor_events[0].tensor_proto) - content = np.asscalar(string_array) - layout_proto = layout_pb2.Layout() - layout_proto.ParseFromString(tf.compat.as_bytes(content)) - - if merged_layout: - # Append the categories within this layout to the merged layout. - for category in layout_proto.category: - if category.title in title_to_category: - # A category with this name has been seen before. Do not create a - # new one. Merge their charts, skipping any duplicates. - title_to_category[category.title].chart.extend([ - c for c in category.chart - if c not in title_to_category[category.title].chart - ]) - else: - # This category has not been seen before. 
- merged_layout.category.add().MergeFrom(category) - title_to_category[category.title] = category - else: - # This is the first layout encountered. - merged_layout = layout_proto - for category in layout_proto.category: - title_to_category[category.title] = category - - if merged_layout: - return json_format.MessageToJson( - merged_layout, including_default_value_fields=True) - else: - # No layout was found. - return {} + body = self.layout_impl() + return http_util.Respond(request, body, "application/json") + + def layout_impl(self): + # Keep a mapping between and category so we do not create duplicate + # categories. + title_to_category = {} + + merged_layout = None + runs = list( + self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) + ) + runs.sort() + for run in runs: + tensor_events = self._multiplexer.Tensors( + run, metadata.CONFIG_SUMMARY_TAG + ) + + # This run has a layout. Merge it with the ones currently found. + string_array = tensor_util.make_ndarray( + tensor_events[0].tensor_proto + ) + content = np.asscalar(string_array) + layout_proto = layout_pb2.Layout() + layout_proto.ParseFromString(tf.compat.as_bytes(content)) + + if merged_layout: + # Append the categories within this layout to the merged layout. + for category in layout_proto.category: + if category.title in title_to_category: + # A category with this name has been seen before. Do not create a + # new one. Merge their charts, skipping any duplicates. + title_to_category[category.title].chart.extend( + [ + c + for c in category.chart + if c + not in title_to_category[category.title].chart + ] + ) + else: + # This category has not been seen before. + merged_layout.category.add().MergeFrom(category) + title_to_category[category.title] = category + else: + # This is the first layout encountered. 
+ merged_layout = layout_proto + for category in layout_proto.category: + title_to_category[category.title] = category + + if merged_layout: + return json_format.MessageToJson( + merged_layout, including_default_value_fields=True + ) + else: + # No layout was found. + return {} diff --git a/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py b/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py index ca2e59664a..0f4dec9e7d 100644 --- a/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py +++ b/tensorboard/plugins/custom_scalar/custom_scalars_plugin_test.py @@ -25,7 +25,9 @@ import tensorflow as tf from google.protobuf import json_format -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.custom_scalar import custom_scalars_plugin from tensorboard.plugins.custom_scalar import layout_pb2 @@ -38,209 +40,244 @@ class CustomScalarsPluginTest(tf.test.TestCase): + def __init__(self, *args, **kwargs): + super(CustomScalarsPluginTest, self).__init__(*args, **kwargs) + self.logdir = os.path.join(self.get_temp_dir(), "logdir") + os.makedirs(self.logdir) - def __init__(self, *args, **kwargs): - super(CustomScalarsPluginTest, self).__init__(*args, **kwargs) - self.logdir = os.path.join(self.get_temp_dir(), 'logdir') - os.makedirs(self.logdir) - - self.logdir_layout = layout_pb2.Layout( - category=[ - layout_pb2.Category( - title='cross entropy', - chart=[ - layout_pb2.Chart( - title='cross entropy', - multiline=layout_pb2.MultilineChartContent( - tag=[r'cross entropy'], - )), - ], - closed=True) + self.logdir_layout = layout_pb2.Layout( + category=[ + layout_pb2.Category( + title="cross entropy", + chart=[ + layout_pb2.Chart( + title="cross entropy", + multiline=layout_pb2.MultilineChartContent( + tag=[r"cross entropy"], + ), + ), 
+ ], + closed=True, + ) ] - ) - self.foo_layout = layout_pb2.Layout( - category=[ - layout_pb2.Category( - title='mean biases', - chart=[ - layout_pb2.Chart( - title='mean layer biases', - multiline=layout_pb2.MultilineChartContent( - tag=[r'mean/layer0/biases', r'mean/layer1/biases'], - )), - ] - ), - layout_pb2.Category( - title='std weights', - chart=[ - layout_pb2.Chart( - title='stddev layer weights', - multiline=layout_pb2.MultilineChartContent( - tag=[r'stddev/layer\d+/weights'], - )), - ] - ), - # A category with this name is also present in a layout for a - # different run (the logdir run) and also contains a duplicate chart - layout_pb2.Category( - title='cross entropy', - chart=[ - layout_pb2.Chart( - title='cross entropy margin chart', - margin=layout_pb2.MarginChartContent( - series=[ - layout_pb2.MarginChartContent.Series( - value='cross entropy', - lower='cross entropy lower', - upper='cross entropy upper'), - ], - )), - layout_pb2.Chart( - title='cross entropy', - multiline=layout_pb2.MultilineChartContent( - tag=[r'cross entropy'], - )), - ] - ), - ] - ) - - # Generate test data. - with test_util.FileWriterCache.get(os.path.join(self.logdir, 'foo')) as writer: - writer.add_summary(summary.pb(self.foo_layout)) - for step in range(4): - writer.add_summary(scalar_summary.pb('squares', step * step), step) - - with test_util.FileWriterCache.get(os.path.join(self.logdir, 'bar')) as writer: - for step in range(3): - writer.add_summary(scalar_summary.pb('increments', step + 1), step) - - # The '.' run lacks scalar data but has a layout. 
- with test_util.FileWriterCache.get(self.logdir) as writer: - writer.add_summary(summary.pb(self.logdir_layout)) - - self.plugin = self.createPlugin(self.logdir) - - def createPlugin(self, logdir): - multiplexer = event_multiplexer.EventMultiplexer() - multiplexer.AddRunsFromDirectory(logdir) - multiplexer.Reload() - plugin_name_to_instance = {} - context = base_plugin.TBContext( - logdir=logdir, - multiplexer=multiplexer, - plugin_name_to_instance=plugin_name_to_instance) - scalars_plugin_instance = scalars_plugin.ScalarsPlugin(context) - custom_scalars_plugin_instance = custom_scalars_plugin.CustomScalarsPlugin( - context) - plugin_instances = [scalars_plugin_instance, custom_scalars_plugin_instance] - for plugin_instance in plugin_instances: - plugin_name_to_instance[plugin_instance.plugin_name] = plugin_instance - return custom_scalars_plugin_instance - - def testDownloadData(self): - body, mime_type = self.plugin.download_data_impl( - 'foo', 'squares/scalar_summary', 'json') - self.assertEqual('application/json', mime_type) - self.assertEqual(4, len(body)) - for step, entry in enumerate(body): - # The time stamp should be reasonable. - self.assertGreater(entry[0], 0) - self.assertEqual(step, entry[1]) - np.testing.assert_allclose(step * step, entry[2]) - - def testScalars(self): - body = self.plugin.scalars_impl('bar', 'increments') - self.assertTrue(body['regex_valid']) - self.assertItemsEqual( - ['increments/scalar_summary'], list(body['tag_to_events'].keys())) - data = body['tag_to_events']['increments/scalar_summary'] - for step, entry in enumerate(data): - # The time stamp should be reasonable. 
- self.assertGreater(entry[0], 0) - self.assertEqual(step, entry[1]) - np.testing.assert_allclose(step + 1, entry[2]) - - def testMergedLayout(self): - parsed_layout = layout_pb2.Layout() - json_format.Parse(self.plugin.layout_impl(), parsed_layout) - correct_layout = layout_pb2.Layout( - category=[ - # A category with this name is also present in a layout for a - # different run (the logdir run) - layout_pb2.Category( - title='cross entropy', - chart=[ - layout_pb2.Chart( - title='cross entropy', - multiline=layout_pb2.MultilineChartContent( - tag=[r'cross entropy'], - )), - layout_pb2.Chart( - title='cross entropy margin chart', - margin=layout_pb2.MarginChartContent( - series=[ - layout_pb2.MarginChartContent.Series( - value='cross entropy', - lower='cross entropy lower', - upper='cross entropy upper'), - ], - )), - ], - closed=True, - ), - layout_pb2.Category( - title='mean biases', - chart=[ - layout_pb2.Chart( - title='mean layer biases', - multiline=layout_pb2.MultilineChartContent( - tag=[r'mean/layer0/biases', r'mean/layer1/biases'], - )), - ] - ), - layout_pb2.Category( - title='std weights', - chart=[ - layout_pb2.Chart( - title='stddev layer weights', - multiline=layout_pb2.MultilineChartContent( - tag=[r'stddev/layer\d+/weights'], - )), - ] - ), + ) + self.foo_layout = layout_pb2.Layout( + category=[ + layout_pb2.Category( + title="mean biases", + chart=[ + layout_pb2.Chart( + title="mean layer biases", + multiline=layout_pb2.MultilineChartContent( + tag=[ + r"mean/layer0/biases", + r"mean/layer1/biases", + ], + ), + ), + ], + ), + layout_pb2.Category( + title="std weights", + chart=[ + layout_pb2.Chart( + title="stddev layer weights", + multiline=layout_pb2.MultilineChartContent( + tag=[r"stddev/layer\d+/weights"], + ), + ), + ], + ), + # A category with this name is also present in a layout for a + # different run (the logdir run) and also contains a duplicate chart + layout_pb2.Category( + title="cross entropy", + chart=[ + layout_pb2.Chart( + 
title="cross entropy margin chart", + margin=layout_pb2.MarginChartContent( + series=[ + layout_pb2.MarginChartContent.Series( + value="cross entropy", + lower="cross entropy lower", + upper="cross entropy upper", + ), + ], + ), + ), + layout_pb2.Chart( + title="cross entropy", + multiline=layout_pb2.MultilineChartContent( + tag=[r"cross entropy"], + ), + ), + ], + ), + ] + ) + + # Generate test data. + with test_util.FileWriterCache.get( + os.path.join(self.logdir, "foo") + ) as writer: + writer.add_summary(summary.pb(self.foo_layout)) + for step in range(4): + writer.add_summary( + scalar_summary.pb("squares", step * step), step + ) + + with test_util.FileWriterCache.get( + os.path.join(self.logdir, "bar") + ) as writer: + for step in range(3): + writer.add_summary( + scalar_summary.pb("increments", step + 1), step + ) + + # The '.' run lacks scalar data but has a layout. + with test_util.FileWriterCache.get(self.logdir) as writer: + writer.add_summary(summary.pb(self.logdir_layout)) + + self.plugin = self.createPlugin(self.logdir) + + def createPlugin(self, logdir): + multiplexer = event_multiplexer.EventMultiplexer() + multiplexer.AddRunsFromDirectory(logdir) + multiplexer.Reload() + plugin_name_to_instance = {} + context = base_plugin.TBContext( + logdir=logdir, + multiplexer=multiplexer, + plugin_name_to_instance=plugin_name_to_instance, + ) + scalars_plugin_instance = scalars_plugin.ScalarsPlugin(context) + custom_scalars_plugin_instance = custom_scalars_plugin.CustomScalarsPlugin( + context + ) + plugin_instances = [ + scalars_plugin_instance, + custom_scalars_plugin_instance, ] - ) - self.assertProtoEquals(correct_layout, parsed_layout) - - def testLayoutFromSingleRun(self): - # The foo directory contains 1 single layout. 
- local_plugin = self.createPlugin(os.path.join(self.logdir, 'foo')) - parsed_layout = layout_pb2.Layout() - json_format.Parse(local_plugin.layout_impl(), parsed_layout) - self.assertProtoEquals(self.foo_layout, parsed_layout) - - def testNoLayoutFound(self): - # The bar directory contains no layout. - local_plugin = self.createPlugin(os.path.join(self.logdir, 'bar')) - self.assertDictEqual({}, local_plugin.layout_impl()) - - def testIsActive(self): - self.assertTrue(self.plugin.is_active()) - - def testIsNotActiveDueToNoLayout(self): - # The bar directory contains scalar data but no layout. - local_plugin = self.createPlugin(os.path.join(self.logdir, 'bar')) - self.assertFalse(local_plugin.is_active()) - - def testIsNotActiveDueToNoScalarsData(self): - # Generate a directory with a layout but no scalars data. - directory = os.path.join(self.logdir, 'no_scalars') - with test_util.FileWriterCache.get(directory) as writer: - writer.add_summary(summary.pb(self.logdir_layout)) - - local_plugin = self.createPlugin(directory) - self.assertFalse(local_plugin.is_active()) + for plugin_instance in plugin_instances: + plugin_name_to_instance[ + plugin_instance.plugin_name + ] = plugin_instance + return custom_scalars_plugin_instance + + def testDownloadData(self): + body, mime_type = self.plugin.download_data_impl( + "foo", "squares/scalar_summary", "json" + ) + self.assertEqual("application/json", mime_type) + self.assertEqual(4, len(body)) + for step, entry in enumerate(body): + # The time stamp should be reasonable. 
+ self.assertGreater(entry[0], 0) + self.assertEqual(step, entry[1]) + np.testing.assert_allclose(step * step, entry[2]) + + def testScalars(self): + body = self.plugin.scalars_impl("bar", "increments") + self.assertTrue(body["regex_valid"]) + self.assertItemsEqual( + ["increments/scalar_summary"], list(body["tag_to_events"].keys()) + ) + data = body["tag_to_events"]["increments/scalar_summary"] + for step, entry in enumerate(data): + # The time stamp should be reasonable. + self.assertGreater(entry[0], 0) + self.assertEqual(step, entry[1]) + np.testing.assert_allclose(step + 1, entry[2]) + + def testMergedLayout(self): + parsed_layout = layout_pb2.Layout() + json_format.Parse(self.plugin.layout_impl(), parsed_layout) + correct_layout = layout_pb2.Layout( + category=[ + # A category with this name is also present in a layout for a + # different run (the logdir run) + layout_pb2.Category( + title="cross entropy", + chart=[ + layout_pb2.Chart( + title="cross entropy", + multiline=layout_pb2.MultilineChartContent( + tag=[r"cross entropy"], + ), + ), + layout_pb2.Chart( + title="cross entropy margin chart", + margin=layout_pb2.MarginChartContent( + series=[ + layout_pb2.MarginChartContent.Series( + value="cross entropy", + lower="cross entropy lower", + upper="cross entropy upper", + ), + ], + ), + ), + ], + closed=True, + ), + layout_pb2.Category( + title="mean biases", + chart=[ + layout_pb2.Chart( + title="mean layer biases", + multiline=layout_pb2.MultilineChartContent( + tag=[ + r"mean/layer0/biases", + r"mean/layer1/biases", + ], + ), + ), + ], + ), + layout_pb2.Category( + title="std weights", + chart=[ + layout_pb2.Chart( + title="stddev layer weights", + multiline=layout_pb2.MultilineChartContent( + tag=[r"stddev/layer\d+/weights"], + ), + ), + ], + ), + ] + ) + self.assertProtoEquals(correct_layout, parsed_layout) + + def testLayoutFromSingleRun(self): + # The foo directory contains 1 single layout. 
+ local_plugin = self.createPlugin(os.path.join(self.logdir, "foo")) + parsed_layout = layout_pb2.Layout() + json_format.Parse(local_plugin.layout_impl(), parsed_layout) + self.assertProtoEquals(self.foo_layout, parsed_layout) + + def testNoLayoutFound(self): + # The bar directory contains no layout. + local_plugin = self.createPlugin(os.path.join(self.logdir, "bar")) + self.assertDictEqual({}, local_plugin.layout_impl()) + + def testIsActive(self): + self.assertTrue(self.plugin.is_active()) + + def testIsNotActiveDueToNoLayout(self): + # The bar directory contains scalar data but no layout. + local_plugin = self.createPlugin(os.path.join(self.logdir, "bar")) + self.assertFalse(local_plugin.is_active()) + + def testIsNotActiveDueToNoScalarsData(self): + # Generate a directory with a layout but no scalars data. + directory = os.path.join(self.logdir, "no_scalars") + with test_util.FileWriterCache.get(directory) as writer: + writer.add_summary(summary.pb(self.logdir_layout)) + + local_plugin = self.createPlugin(directory) + self.assertFalse(local_plugin.is_active()) + if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/custom_scalar/metadata.py b/tensorboard/plugins/custom_scalar/metadata.py index 6fd6399dea..c71980537f 100644 --- a/tensorboard/plugins/custom_scalar/metadata.py +++ b/tensorboard/plugins/custom_scalar/metadata.py @@ -21,16 +21,19 @@ from tensorboard.compat.proto import summary_pb2 # A special tag named used for the summary that stores the layout. -CONFIG_SUMMARY_TAG = 'custom_scalars__config__' +CONFIG_SUMMARY_TAG = "custom_scalars__config__" + +PLUGIN_NAME = "custom_scalars" -PLUGIN_NAME = 'custom_scalars' def create_summary_metadata(): - """Create a `SummaryMetadata` proto for custom scalar plugin data. + """Create a `SummaryMetadata` proto for custom scalar plugin data. - Returns: - A `summary_pb2.SummaryMetadata` protobuf object. 
- """ - return summary_pb2.SummaryMetadata( - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name=PLUGIN_NAME)) + Returns: + A `summary_pb2.SummaryMetadata` protobuf object. + """ + return summary_pb2.SummaryMetadata( + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME + ) + ) diff --git a/tensorboard/plugins/custom_scalar/summary.py b/tensorboard/plugins/custom_scalar/summary.py index 7dea66b26a..230da905ab 100644 --- a/tensorboard/plugins/custom_scalar/summary.py +++ b/tensorboard/plugins/custom_scalar/summary.py @@ -1,4 +1,3 @@ - # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Contains summaries related to laying out the custom scalars dashboard. -""" +"""Contains summaries related to laying out the custom scalars dashboard.""" from __future__ import absolute_import from __future__ import division @@ -25,57 +23,61 @@ def op(scalars_layout, collections=None): - """Creates a summary that contains a layout. + """Creates a summary that contains a layout. - When users navigate to the custom scalars dashboard, they will see a layout - based on the proto provided to this function. + When users navigate to the custom scalars dashboard, they will see a layout + based on the proto provided to this function. - Args: - scalars_layout: The scalars_layout_pb2.Layout proto that specifies the - layout. - collections: Optional list of graph collections keys. The new - summary op is added to these collections. Defaults to - `[Graph Keys.SUMMARIES]`. + Args: + scalars_layout: The scalars_layout_pb2.Layout proto that specifies the + layout. + collections: Optional list of graph collections keys. The new + summary op is added to these collections. 
Defaults to + `[Graph Keys.SUMMARIES]`. - Returns: - A tensor summary op that writes the layout to disk. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf + Returns: + A tensor summary op that writes the layout to disk. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. + import tensorflow.compat.v1 as tf - assert isinstance(scalars_layout, layout_pb2.Layout) - summary_metadata = metadata.create_summary_metadata() - return tf.summary.tensor_summary(name=metadata.CONFIG_SUMMARY_TAG, - tensor=tf.constant( - scalars_layout.SerializeToString(), - dtype=tf.string), - collections=collections, - summary_metadata=summary_metadata) + assert isinstance(scalars_layout, layout_pb2.Layout) + summary_metadata = metadata.create_summary_metadata() + return tf.summary.tensor_summary( + name=metadata.CONFIG_SUMMARY_TAG, + tensor=tf.constant(scalars_layout.SerializeToString(), dtype=tf.string), + collections=collections, + summary_metadata=summary_metadata, + ) def pb(scalars_layout): - """Creates a summary that contains a layout. + """Creates a summary that contains a layout. - When users navigate to the custom scalars dashboard, they will see a layout - based on the proto provided to this function. + When users navigate to the custom scalars dashboard, they will see a layout + based on the proto provided to this function. - Args: - scalars_layout: The scalars_layout_pb2.Layout proto that specifies the - layout. + Args: + scalars_layout: The scalars_layout_pb2.Layout proto that specifies the + layout. - Returns: - A summary proto containing the layout. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf + Returns: + A summary proto containing the layout. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. 
+ import tensorflow.compat.v1 as tf - assert isinstance(scalars_layout, layout_pb2.Layout) - tensor = tf.make_tensor_proto( - scalars_layout.SerializeToString(), dtype=tf.string) - tf_summary_metadata = tf.SummaryMetadata.FromString( - metadata.create_summary_metadata().SerializeToString()) - summary = tf.Summary() - summary.value.add(tag=metadata.CONFIG_SUMMARY_TAG, - metadata=tf_summary_metadata, - tensor=tensor) - return summary + assert isinstance(scalars_layout, layout_pb2.Layout) + tensor = tf.make_tensor_proto( + scalars_layout.SerializeToString(), dtype=tf.string + ) + tf_summary_metadata = tf.SummaryMetadata.FromString( + metadata.create_summary_metadata().SerializeToString() + ) + summary = tf.Summary() + summary.value.add( + tag=metadata.CONFIG_SUMMARY_TAG, + metadata=tf_summary_metadata, + tensor=tensor, + ) + return summary diff --git a/tensorboard/plugins/custom_scalar/summary_test.py b/tensorboard/plugins/custom_scalar/summary_test.py index dbd725bb9e..5bd7990c2e 100644 --- a/tensorboard/plugins/custom_scalar/summary_test.py +++ b/tensorboard/plugins/custom_scalar/summary_test.py @@ -1,4 +1,3 @@ - # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,7 +21,9 @@ import numpy as np import tensorflow as tf -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins.custom_scalar import layout_pb2 from tensorboard.plugins.custom_scalar import metadata from tensorboard.plugins.custom_scalar import summary @@ -31,79 +32,90 @@ class LayoutTest(tf.test.TestCase): + def setUp(self): + super(LayoutTest, self).setUp() + self.logdir = self.get_temp_dir() - def setUp(self): - super(LayoutTest, self).setUp() - self.logdir = self.get_temp_dir() - - def testSetLayout(self): - layout_proto_to_write = layout_pb2.Layout( - category=[ - layout_pb2.Category( - title='mean biases', - chart=[ - layout_pb2.Chart( - title='mean layer biases', - multiline=layout_pb2.MultilineChartContent( - tag=[r'mean/layer\d+/biases'], - )), - ]), - layout_pb2.Category( - title='std weights', - chart=[ - layout_pb2.Chart( - title='stddev layer weights', - multiline=layout_pb2.MultilineChartContent( - tag=[r'stddev/layer\d+/weights'], - )), - ]), - layout_pb2.Category( - title='cross entropy ... 
and maybe some other values', - chart=[ - layout_pb2.Chart( - title='cross entropy', - multiline=layout_pb2.MultilineChartContent( - tag=[r'cross entropy'], - )), - layout_pb2.Chart( - title='accuracy', - margin=layout_pb2.MarginChartContent( - series=[ - layout_pb2.MarginChartContent.Series( - value='accuracy', - lower='accuracy_lower_margin', - upper='accuracy_upper_margin') - ] - )), - layout_pb2.Chart( - title='max layer weights', - multiline=layout_pb2.MultilineChartContent( - tag=[r'max/layer1/.*', r'max/layer2/.*'], - )), - ], - closed=True) - ]) + def testSetLayout(self): + layout_proto_to_write = layout_pb2.Layout( + category=[ + layout_pb2.Category( + title="mean biases", + chart=[ + layout_pb2.Chart( + title="mean layer biases", + multiline=layout_pb2.MultilineChartContent( + tag=[r"mean/layer\d+/biases"], + ), + ), + ], + ), + layout_pb2.Category( + title="std weights", + chart=[ + layout_pb2.Chart( + title="stddev layer weights", + multiline=layout_pb2.MultilineChartContent( + tag=[r"stddev/layer\d+/weights"], + ), + ), + ], + ), + layout_pb2.Category( + title="cross entropy ... and maybe some other values", + chart=[ + layout_pb2.Chart( + title="cross entropy", + multiline=layout_pb2.MultilineChartContent( + tag=[r"cross entropy"], + ), + ), + layout_pb2.Chart( + title="accuracy", + margin=layout_pb2.MarginChartContent( + series=[ + layout_pb2.MarginChartContent.Series( + value="accuracy", + lower="accuracy_lower_margin", + upper="accuracy_upper_margin", + ) + ] + ), + ), + layout_pb2.Chart( + title="max layer weights", + multiline=layout_pb2.MultilineChartContent( + tag=[r"max/layer1/.*", r"max/layer2/.*"], + ), + ), + ], + closed=True, + ), + ] + ) - # Write the data as a summary for the '.' run. - with tf.compat.v1.Session() as s, test_util.FileWriterCache.get(self.logdir) as writer: - writer.add_summary(s.run(summary.op(layout_proto_to_write))) + # Write the data as a summary for the '.' run. 
+ with tf.compat.v1.Session() as s, test_util.FileWriterCache.get( + self.logdir + ) as writer: + writer.add_summary(s.run(summary.op(layout_proto_to_write))) - # Read the data from disk. - multiplexer = event_multiplexer.EventMultiplexer() - multiplexer.AddRunsFromDirectory(self.logdir) - multiplexer.Reload() - tensor_events = multiplexer.Tensors('.', metadata.CONFIG_SUMMARY_TAG) - self.assertEqual(1, len(tensor_events)) + # Read the data from disk. + multiplexer = event_multiplexer.EventMultiplexer() + multiplexer.AddRunsFromDirectory(self.logdir) + multiplexer.Reload() + tensor_events = multiplexer.Tensors(".", metadata.CONFIG_SUMMARY_TAG) + self.assertEqual(1, len(tensor_events)) - # Parse the data. - string_array = tensor_util.make_ndarray(tensor_events[0].tensor_proto) - content = np.asscalar(string_array) - layout_proto_from_disk = layout_pb2.Layout() - layout_proto_from_disk.ParseFromString(tf.compat.as_bytes(content)) + # Parse the data. + string_array = tensor_util.make_ndarray(tensor_events[0].tensor_proto) + content = np.asscalar(string_array) + layout_proto_from_disk = layout_pb2.Layout() + layout_proto_from_disk.ParseFromString(tf.compat.as_bytes(content)) - # Verify the content. - self.assertProtoEquals(layout_proto_to_write, layout_proto_from_disk) + # Verify the content. + self.assertProtoEquals(layout_proto_to_write, layout_proto_from_disk) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/debugger/comm_channel.py b/tensorboard/plugins/debugger/comm_channel.py index d27aecbede..8def0a1552 100644 --- a/tensorboard/plugins/debugger/comm_channel.py +++ b/tensorboard/plugins/debugger/comm_channel.py @@ -24,80 +24,81 @@ class CommChannel(object): - """A class that handles the queueing of outgoing and incoming messages. - - CommChannel is a multi-consumer interface that serves the following purposes: - - 1) Keeps track of all the messages that it has received from the caller of - put_outgoing(). 
In the case of TDP, these are messages about the start of - Session.runs() and the pausing events at tensor breakpoints. These messages - are kept in the order they are received. These messages are organized in - memory by a serial index starting from 1. Since the messages are maintained - in the memory indefinitely, they ought to be small in size. - # TODO(cais): If the need arises, persist the messages. - 2) Allows the callers of get_outgoing() to retrieve any message by a serial - index (also referred to as "position") at anytime. Notice that we want to - support multiple callers because more than once browser sessions may need - to connect to the backend simultaneously. If a caller of get_outgoing() - requests a serial that has not been received from put_going() yet, the - get_ougoing() call will block until a message is received at that position. - """ - - def __init__(self): - self._outgoing = [] - self._outgoing_counter = 0 - self._outgoing_lock = threading.Lock() - self._outgoing_pending_queues = dict() - - def put(self, message): - """Put a message into the outgoing message stack. - - Outgoing message will be stored indefinitely to support multi-users. + """A class that handles the queueing of outgoing and incoming messages. + + CommChannel is a multi-consumer interface that serves the following purposes: + + 1) Keeps track of all the messages that it has received from the caller of + put_outgoing(). In the case of TDP, these are messages about the start of + Session.runs() and the pausing events at tensor breakpoints. These messages + are kept in the order they are received. These messages are organized in + memory by a serial index starting from 1. Since the messages are maintained + in the memory indefinitely, they ought to be small in size. + # TODO(cais): If the need arises, persist the messages. + 2) Allows the callers of get_outgoing() to retrieve any message by a serial + index (also referred to as "position") at anytime. 
Notice that we want to + support multiple callers because more than once browser sessions may need + to connect to the backend simultaneously. If a caller of get_outgoing() + requests a serial that has not been received from put_going() yet, the + get_ougoing() call will block until a message is received at that position. """ - with self._outgoing_lock: - self._outgoing.append(message) - self._outgoing_counter += 1 - - # Check to see if there are pending queues waiting for the item. - if self._outgoing_counter in self._outgoing_pending_queues: - for q in self._outgoing_pending_queues[self._outgoing_counter]: - q.put(message) - del self._outgoing_pending_queues[self._outgoing_counter] - - def get(self, pos): - """Get message(s) from the outgoing message stack. - - Blocks until an item at stack position pos becomes available. - This method is thread safe. - - Args: - pos: An int specifying the top position of the message stack to access. - For example, if the stack counter is at 3 and pos == 2, then the 2nd - item on the stack will be returned, together with an int that indicates - the current stack heigh (3 in this case). - - Returns: - 1. The item at stack position pos. - 2. The height of the stack when the retun values are generated. - - Raises: - ValueError: If input `pos` is zero or negative. - """ - if pos <= 0: - raise ValueError('Invalid pos %d: pos must be > 0' % pos) - with self._outgoing_lock: - if self._outgoing_counter >= pos: - # If the stack already has the requested position, return the value - # immediately. - return self._outgoing[pos - 1], self._outgoing_counter - else: - # If the stack has not reached the requested position yet, create a - # queue and block on get(). 
- if pos not in self._outgoing_pending_queues: - self._outgoing_pending_queues[pos] = [] - q = queue.Queue(maxsize=1) - self._outgoing_pending_queues[pos].append(q) - - value = q.get() - with self._outgoing_lock: - return value, self._outgoing_counter + + def __init__(self): + self._outgoing = [] + self._outgoing_counter = 0 + self._outgoing_lock = threading.Lock() + self._outgoing_pending_queues = dict() + + def put(self, message): + """Put a message into the outgoing message stack. + + Outgoing message will be stored indefinitely to support multi- + users. + """ + with self._outgoing_lock: + self._outgoing.append(message) + self._outgoing_counter += 1 + + # Check to see if there are pending queues waiting for the item. + if self._outgoing_counter in self._outgoing_pending_queues: + for q in self._outgoing_pending_queues[self._outgoing_counter]: + q.put(message) + del self._outgoing_pending_queues[self._outgoing_counter] + + def get(self, pos): + """Get message(s) from the outgoing message stack. + + Blocks until an item at stack position pos becomes available. + This method is thread safe. + + Args: + pos: An int specifying the top position of the message stack to access. + For example, if the stack counter is at 3 and pos == 2, then the 2nd + item on the stack will be returned, together with an int that indicates + the current stack heigh (3 in this case). + + Returns: + 1. The item at stack position pos. + 2. The height of the stack when the retun values are generated. + + Raises: + ValueError: If input `pos` is zero or negative. + """ + if pos <= 0: + raise ValueError("Invalid pos %d: pos must be > 0" % pos) + with self._outgoing_lock: + if self._outgoing_counter >= pos: + # If the stack already has the requested position, return the value + # immediately. + return self._outgoing[pos - 1], self._outgoing_counter + else: + # If the stack has not reached the requested position yet, create a + # queue and block on get(). 
+ if pos not in self._outgoing_pending_queues: + self._outgoing_pending_queues[pos] = [] + q = queue.Queue(maxsize=1) + self._outgoing_pending_queues[pos].append(q) + + value = q.get() + with self._outgoing_lock: + return value, self._outgoing_counter diff --git a/tensorboard/plugins/debugger/comm_channel_test.py b/tensorboard/plugins/debugger/comm_channel_test.py index d647b95e92..2ff8fb5641 100644 --- a/tensorboard/plugins/debugger/comm_channel_test.py +++ b/tensorboard/plugins/debugger/comm_channel_test.py @@ -26,74 +26,76 @@ class CommChannelTest(tf.test.TestCase): - - def testGetOutgoingWithInvalidPosLeadsToAssertionError(self): - channel = comm_channel.CommChannel() - with self.assertRaises(ValueError): - channel.get(0) - with self.assertRaises(ValueError): - channel.get(-1) - - def testOutgoingSerialPutOneAndGetOne(self): - channel = comm_channel.CommChannel() - channel.put('A') - self.assertEqual(('A', 1), channel.get(1)) - - def testOutgoingSerialPutTwoGetOne(self): - channel = comm_channel.CommChannel() - channel.put('A') - channel.put('B') - channel.put('C') - self.assertEqual(('A', 3), channel.get(1)) - self.assertEqual(('B', 3), channel.get(2)) - self.assertEqual(('C', 3), channel.get(3)) - - def testOutgoingConcurrentPutAndOneGetter(self): - channel = comm_channel.CommChannel() - - result = {'outgoing': []} - def get_two(): - result['outgoing'].append(channel.get(1)) - result['outgoing'].append(channel.get(2)) - - t = threading.Thread(target=get_two) - t.start() - channel.put('A') - channel.put('B') - t.join() - self.assertEqual('A', result['outgoing'][0][0]) - self.assertIn(result['outgoing'][0][1], [1, 2]) - self.assertEqual(('B', 2), result['outgoing'][1]) - - def testOutgoingConcurrentPutAndTwoGetters(self): - channel = comm_channel.CommChannel() - - result1 = {'outgoing': []} - result2 = {'outgoing': []} - def getter1(): - result1['outgoing'].append(channel.get(1)) - result1['outgoing'].append(channel.get(2)) - def getter2(): - 
result2['outgoing'].append(channel.get(1)) - result2['outgoing'].append(channel.get(2)) - - t1 = threading.Thread(target=getter1) - t1.start() - t2 = threading.Thread(target=getter2) - t2.start() - - channel.put('A') - channel.put('B') - t1.join() - t2.join() - - self.assertEqual('A', result1['outgoing'][0][0]) - self.assertIn(result1['outgoing'][0][1], [1, 2]) - self.assertEqual(('B', 2), result1['outgoing'][1]) - self.assertEqual('A', result2['outgoing'][0][0]) - self.assertIn(result2['outgoing'][0][1], [1, 2]) - self.assertEqual(('B', 2), result2['outgoing'][1]) - - -if __name__ == '__main__': - tf.test.main() + def testGetOutgoingWithInvalidPosLeadsToAssertionError(self): + channel = comm_channel.CommChannel() + with self.assertRaises(ValueError): + channel.get(0) + with self.assertRaises(ValueError): + channel.get(-1) + + def testOutgoingSerialPutOneAndGetOne(self): + channel = comm_channel.CommChannel() + channel.put("A") + self.assertEqual(("A", 1), channel.get(1)) + + def testOutgoingSerialPutTwoGetOne(self): + channel = comm_channel.CommChannel() + channel.put("A") + channel.put("B") + channel.put("C") + self.assertEqual(("A", 3), channel.get(1)) + self.assertEqual(("B", 3), channel.get(2)) + self.assertEqual(("C", 3), channel.get(3)) + + def testOutgoingConcurrentPutAndOneGetter(self): + channel = comm_channel.CommChannel() + + result = {"outgoing": []} + + def get_two(): + result["outgoing"].append(channel.get(1)) + result["outgoing"].append(channel.get(2)) + + t = threading.Thread(target=get_two) + t.start() + channel.put("A") + channel.put("B") + t.join() + self.assertEqual("A", result["outgoing"][0][0]) + self.assertIn(result["outgoing"][0][1], [1, 2]) + self.assertEqual(("B", 2), result["outgoing"][1]) + + def testOutgoingConcurrentPutAndTwoGetters(self): + channel = comm_channel.CommChannel() + + result1 = {"outgoing": []} + result2 = {"outgoing": []} + + def getter1(): + result1["outgoing"].append(channel.get(1)) + 
result1["outgoing"].append(channel.get(2)) + + def getter2(): + result2["outgoing"].append(channel.get(1)) + result2["outgoing"].append(channel.get(2)) + + t1 = threading.Thread(target=getter1) + t1.start() + t2 = threading.Thread(target=getter2) + t2.start() + + channel.put("A") + channel.put("B") + t1.join() + t2.join() + + self.assertEqual("A", result1["outgoing"][0][0]) + self.assertIn(result1["outgoing"][0][1], [1, 2]) + self.assertEqual(("B", 2), result1["outgoing"][1]) + self.assertEqual("A", result2["outgoing"][0][0]) + self.assertIn(result2["outgoing"][0][1], [1, 2]) + self.assertEqual(("B", 2), result2["outgoing"][1]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/debugger/debug_graphs_helper.py b/tensorboard/plugins/debugger/debug_graphs_helper.py index 728aa4a347..b9ef51058c 100644 --- a/tensorboard/plugins/debugger/debug_graphs_helper.py +++ b/tensorboard/plugins/debugger/debug_graphs_helper.py @@ -24,88 +24,110 @@ class DebugGraphWrapper(object): - """A wrapper for potentially debugger-decorated GraphDef.""" - - def __init__(self, graph_def): - self._graph_def = graph_def - # A map from debug op to list of debug-op-attached tensors. - self._grpc_gated_tensors = dict() - self._grpc_gated_lock = threading.Lock() - self._maybe_base_expanded_node_names = None - self._node_name_lock = threading.Lock() - - def get_gated_grpc_tensors(self, matching_debug_op=None): - """Extract all nodes with gated-gRPC debug ops attached. - - Uses cached values if available. - This method is thread-safe. - - Args: - graph_def: A tf.GraphDef proto. - matching_debug_op: Return tensors and nodes with only matching the - specified debug op name (optional). If `None`, will extract only - `DebugIdentity` debug ops. - - Returns: - A list of (node_name, op_type, output_slot, debug_op) tuples. 
- """ - with self._grpc_gated_lock: - matching_debug_op = matching_debug_op or 'DebugIdentity' - if matching_debug_op not in self._grpc_gated_tensors: - # First, construct a map from node name to op type. - node_name_to_op_type = dict( - (node.name, node.op) for node in self._graph_def.node) - - # Second, populate the output list. - gated = [] - for node in self._graph_def.node: - if node.op == matching_debug_op: - for attr_key in node.attr: - if attr_key == 'gated_grpc' and node.attr[attr_key].b: - node_name, output_slot, _, debug_op = ( - debug_graphs.parse_debug_node_name(node.name)) - gated.append( - (node_name, node_name_to_op_type[node_name], output_slot, - debug_op)) - break - self._grpc_gated_tensors[matching_debug_op] = gated - - return self._grpc_gated_tensors[matching_debug_op] - - def maybe_base_expanded_node_name(self, node_name): - """Expand the base name if there are node names nested under the node. - - For example, if there are two nodes in the graph, "a" and "a/read", then - calling this function on "a" will give "a/(a)", a form that points at - a leaf node in the nested TensorBoard graph. Calling this function on - "a/read" will just return "a/read", because there is no node nested under - it. - - This method is thread-safe. - - Args: - node_name: Name of the node. - graph_def: The `GraphDef` that the node is a part of. - - Returns: - Possibly base-expanded node name. - """ - with self._node_name_lock: - # Lazily populate the map from original node name to base-expanded ones. - if self._maybe_base_expanded_node_names is None: - self._maybe_base_expanded_node_names = dict() - # Sort all the node names. 
- sorted_names = sorted(node.name for node in self._graph_def.node) - for i, name in enumerate(sorted_names): - j = i + 1 - while j < len(sorted_names) and sorted_names[j].startswith(name): - if sorted_names[j].startswith(name + '/'): - self._maybe_base_expanded_node_names[name] = ( - name + '/(' + name.split('/')[-1] + ')') - break - j += 1 - return self._maybe_base_expanded_node_names.get(node_name, node_name) - - @property - def graph_def(self): - return self._graph_def + """A wrapper for potentially debugger-decorated GraphDef.""" + + def __init__(self, graph_def): + self._graph_def = graph_def + # A map from debug op to list of debug-op-attached tensors. + self._grpc_gated_tensors = dict() + self._grpc_gated_lock = threading.Lock() + self._maybe_base_expanded_node_names = None + self._node_name_lock = threading.Lock() + + def get_gated_grpc_tensors(self, matching_debug_op=None): + """Extract all nodes with gated-gRPC debug ops attached. + + Uses cached values if available. + This method is thread-safe. + + Args: + graph_def: A tf.GraphDef proto. + matching_debug_op: Return tensors and nodes with only matching the + specified debug op name (optional). If `None`, will extract only + `DebugIdentity` debug ops. + + Returns: + A list of (node_name, op_type, output_slot, debug_op) tuples. + """ + with self._grpc_gated_lock: + matching_debug_op = matching_debug_op or "DebugIdentity" + if matching_debug_op not in self._grpc_gated_tensors: + # First, construct a map from node name to op type. + node_name_to_op_type = dict( + (node.name, node.op) for node in self._graph_def.node + ) + + # Second, populate the output list. 
+ gated = [] + for node in self._graph_def.node: + if node.op == matching_debug_op: + for attr_key in node.attr: + if ( + attr_key == "gated_grpc" + and node.attr[attr_key].b + ): + ( + node_name, + output_slot, + _, + debug_op, + ) = debug_graphs.parse_debug_node_name( + node.name + ) + gated.append( + ( + node_name, + node_name_to_op_type[node_name], + output_slot, + debug_op, + ) + ) + break + self._grpc_gated_tensors[matching_debug_op] = gated + + return self._grpc_gated_tensors[matching_debug_op] + + def maybe_base_expanded_node_name(self, node_name): + """Expand the base name if there are node names nested under the node. + + For example, if there are two nodes in the graph, "a" and "a/read", then + calling this function on "a" will give "a/(a)", a form that points at + a leaf node in the nested TensorBoard graph. Calling this function on + "a/read" will just return "a/read", because there is no node nested under + it. + + This method is thread-safe. + + Args: + node_name: Name of the node. + graph_def: The `GraphDef` that the node is a part of. + + Returns: + Possibly base-expanded node name. + """ + with self._node_name_lock: + # Lazily populate the map from original node name to base-expanded ones. + if self._maybe_base_expanded_node_names is None: + self._maybe_base_expanded_node_names = dict() + # Sort all the node names. 
+ sorted_names = sorted( + node.name for node in self._graph_def.node + ) + for i, name in enumerate(sorted_names): + j = i + 1 + while j < len(sorted_names) and sorted_names[j].startswith( + name + ): + if sorted_names[j].startswith(name + "/"): + self._maybe_base_expanded_node_names[name] = ( + name + "/(" + name.split("/")[-1] + ")" + ) + break + j += 1 + return self._maybe_base_expanded_node_names.get( + node_name, node_name + ) + + @property + def graph_def(self): + return self._graph_def diff --git a/tensorboard/plugins/debugger/debug_graphs_helper_test.py b/tensorboard/plugins/debugger/debug_graphs_helper_test.py index 0c3c1f6b43..0db679498c 100644 --- a/tensorboard/plugins/debugger/debug_graphs_helper_test.py +++ b/tensorboard/plugins/debugger/debug_graphs_helper_test.py @@ -37,6 +37,7 @@ import tensorflow as tf from tensorflow.python import debug as tf_debug + # See discussion on issue #1996 for private module import justification. from tensorflow.python import tf2 as tensorflow_python_tf2 from tensorflow.python.debug.lib import grpc_debug_test_server @@ -50,133 +51,166 @@ class ExtractGatedGrpcDebugOpsTest(tf.test.TestCase): - - @classmethod - def setUpClass(cls): - (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread, - cls.debug_server - ) = grpc_debug_test_server.start_server_on_separate_thread( - dump_to_filesystem=False) - logger.info('debug server url: %s', cls.debug_server_url) - - @classmethod - def tearDownClass(cls): - cls.debug_server.stop_server().wait() - cls.debug_server_thread.join() - - def tearDown(self): - tf.compat.v1.reset_default_graph() - self.debug_server.clear_data() - - def _createTestGraphAndRunOptions(self, sess, gated_grpc=True): - a = tf.Variable([1.0], name='a') - b = tf.Variable([2.0], name='b') - c = tf.Variable([3.0], name='c') - d = tf.Variable([4.0], name='d') - x = tf.add(a, b, name='x') - y = tf.add(c, d, name='y') - z = tf.add(x, y, name='z') - - run_options = 
tf.compat.v1.RunOptions(output_partition_graphs=True) - debug_op = 'DebugIdentity' - if gated_grpc: - debug_op += '(gated_grpc=True)' - tf_debug.watch_graph(run_options, - sess.graph, - debug_ops=debug_op, - debug_urls=self.debug_server_url) - return z, run_options - - def testExtractGatedGrpcTensorsFoundGatedGrpcOps(self): - with tf.compat.v1.Session() as sess: - z, run_options = self._createTestGraphAndRunOptions(sess, gated_grpc=True) - - sess.run(tf.compat.v1.global_variables_initializer()) - run_metadata = config_pb2.RunMetadata() - self.assertAllClose( - [10.0], sess.run(z, options=run_options, run_metadata=run_metadata)) - - graph_wrapper = debug_graphs_helper.DebugGraphWrapper( - run_metadata.partition_graphs[0]) - gated_debug_ops = graph_wrapper.get_gated_grpc_tensors() - - # Verify that the op types are available. - for item in gated_debug_ops: - self.assertTrue(item[1]) - - # Strip out the op types before further checks, because op type names can - # change in the future (e.g., 'VariableV2' --> 'VariableV3'). 
- gated_debug_ops = [ - (item[0], item[2], item[3]) for item in gated_debug_ops] - - self.assertIn(('a', 0, 'DebugIdentity'), gated_debug_ops) - self.assertIn(('b', 0, 'DebugIdentity'), gated_debug_ops) - self.assertIn(('c', 0, 'DebugIdentity'), gated_debug_ops) - self.assertIn(('d', 0, 'DebugIdentity'), gated_debug_ops) - - self.assertIn(('x', 0, 'DebugIdentity'), gated_debug_ops) - self.assertIn(('y', 0, 'DebugIdentity'), gated_debug_ops) - self.assertIn(('z', 0, 'DebugIdentity'), gated_debug_ops) - - def testGraphDefProperty(self): - with tf.compat.v1.Session() as sess: - z, run_options = self._createTestGraphAndRunOptions(sess, gated_grpc=True) - - sess.run(tf.compat.v1.global_variables_initializer()) - run_metadata = config_pb2.RunMetadata() - self.assertAllClose( - [10.0], sess.run(z, options=run_options, run_metadata=run_metadata)) - - graph_wrapper = debug_graphs_helper.DebugGraphWrapper( - run_metadata.partition_graphs[0]) - self.assertProtoEquals( - run_metadata.partition_graphs[0], graph_wrapper.graph_def) - - def testExtractGatedGrpcTensorsFoundNoGatedGrpcOps(self): - with tf.compat.v1.Session() as sess: - z, run_options = self._createTestGraphAndRunOptions(sess, - gated_grpc=False) - - sess.run(tf.compat.v1.global_variables_initializer()) - run_metadata = config_pb2.RunMetadata() - self.assertAllClose( - [10.0], sess.run(z, options=run_options, run_metadata=run_metadata)) - - graph_wrapper = debug_graphs_helper.DebugGraphWrapper( - run_metadata.partition_graphs[0]) - gated_debug_ops = graph_wrapper.get_gated_grpc_tensors() - self.assertEqual([], gated_debug_ops) + @classmethod + def setUpClass(cls): + ( + cls.debug_server_port, + cls.debug_server_url, + _, + cls.debug_server_thread, + cls.debug_server, + ) = grpc_debug_test_server.start_server_on_separate_thread( + dump_to_filesystem=False + ) + logger.info("debug server url: %s", cls.debug_server_url) + + @classmethod + def tearDownClass(cls): + cls.debug_server.stop_server().wait() + 
cls.debug_server_thread.join() + + def tearDown(self): + tf.compat.v1.reset_default_graph() + self.debug_server.clear_data() + + def _createTestGraphAndRunOptions(self, sess, gated_grpc=True): + a = tf.Variable([1.0], name="a") + b = tf.Variable([2.0], name="b") + c = tf.Variable([3.0], name="c") + d = tf.Variable([4.0], name="d") + x = tf.add(a, b, name="x") + y = tf.add(c, d, name="y") + z = tf.add(x, y, name="z") + + run_options = tf.compat.v1.RunOptions(output_partition_graphs=True) + debug_op = "DebugIdentity" + if gated_grpc: + debug_op += "(gated_grpc=True)" + tf_debug.watch_graph( + run_options, + sess.graph, + debug_ops=debug_op, + debug_urls=self.debug_server_url, + ) + return z, run_options + + def testExtractGatedGrpcTensorsFoundGatedGrpcOps(self): + with tf.compat.v1.Session() as sess: + z, run_options = self._createTestGraphAndRunOptions( + sess, gated_grpc=True + ) + + sess.run(tf.compat.v1.global_variables_initializer()) + run_metadata = config_pb2.RunMetadata() + self.assertAllClose( + [10.0], + sess.run(z, options=run_options, run_metadata=run_metadata), + ) + + graph_wrapper = debug_graphs_helper.DebugGraphWrapper( + run_metadata.partition_graphs[0] + ) + gated_debug_ops = graph_wrapper.get_gated_grpc_tensors() + + # Verify that the op types are available. + for item in gated_debug_ops: + self.assertTrue(item[1]) + + # Strip out the op types before further checks, because op type names can + # change in the future (e.g., 'VariableV2' --> 'VariableV3'). 
+ gated_debug_ops = [ + (item[0], item[2], item[3]) for item in gated_debug_ops + ] + + self.assertIn(("a", 0, "DebugIdentity"), gated_debug_ops) + self.assertIn(("b", 0, "DebugIdentity"), gated_debug_ops) + self.assertIn(("c", 0, "DebugIdentity"), gated_debug_ops) + self.assertIn(("d", 0, "DebugIdentity"), gated_debug_ops) + + self.assertIn(("x", 0, "DebugIdentity"), gated_debug_ops) + self.assertIn(("y", 0, "DebugIdentity"), gated_debug_ops) + self.assertIn(("z", 0, "DebugIdentity"), gated_debug_ops) + + def testGraphDefProperty(self): + with tf.compat.v1.Session() as sess: + z, run_options = self._createTestGraphAndRunOptions( + sess, gated_grpc=True + ) + + sess.run(tf.compat.v1.global_variables_initializer()) + run_metadata = config_pb2.RunMetadata() + self.assertAllClose( + [10.0], + sess.run(z, options=run_options, run_metadata=run_metadata), + ) + + graph_wrapper = debug_graphs_helper.DebugGraphWrapper( + run_metadata.partition_graphs[0] + ) + self.assertProtoEquals( + run_metadata.partition_graphs[0], graph_wrapper.graph_def + ) + + def testExtractGatedGrpcTensorsFoundNoGatedGrpcOps(self): + with tf.compat.v1.Session() as sess: + z, run_options = self._createTestGraphAndRunOptions( + sess, gated_grpc=False + ) + + sess.run(tf.compat.v1.global_variables_initializer()) + run_metadata = config_pb2.RunMetadata() + self.assertAllClose( + [10.0], + sess.run(z, options=run_options, run_metadata=run_metadata), + ) + + graph_wrapper = debug_graphs_helper.DebugGraphWrapper( + run_metadata.partition_graphs[0] + ) + gated_debug_ops = graph_wrapper.get_gated_grpc_tensors() + self.assertEqual([], gated_debug_ops) class BaseExpandedNodeNameTest(tf.test.TestCase): - - def testMaybeBaseExpandedNodeName(self): - with tf.compat.v1.Session() as sess: - a = tf.Variable([1.0], name='foo/a') - b = tf.Variable([2.0], name='bar/b') - _ = tf.add(a, b, name='baz/c') - - graph_wrapper = debug_graphs_helper.DebugGraphWrapper(sess.graph_def) - - self.assertEqual( - 'foo/a/(a)', 
graph_wrapper.maybe_base_expanded_node_name('foo/a')) - self.assertEqual( - 'bar/b/(b)', graph_wrapper.maybe_base_expanded_node_name('bar/b')) - self.assertEqual( - 'foo/a/read', - graph_wrapper.maybe_base_expanded_node_name('foo/a/read')) - self.assertEqual( - 'bar/b/read', - graph_wrapper.maybe_base_expanded_node_name('bar/b/read')) - - if tensorflow_python_tf2.enabled(): - # NOTE(#1705): TF 2.0 tf.add creates nested nodes. - self.assertEqual( - 'baz/c/(c)', graph_wrapper.maybe_base_expanded_node_name('baz/c')) - else: - self.assertEqual( - 'baz/c', graph_wrapper.maybe_base_expanded_node_name('baz/c')) - - -if __name__ == '__main__': - tf.test.main() + def testMaybeBaseExpandedNodeName(self): + with tf.compat.v1.Session() as sess: + a = tf.Variable([1.0], name="foo/a") + b = tf.Variable([2.0], name="bar/b") + _ = tf.add(a, b, name="baz/c") + + graph_wrapper = debug_graphs_helper.DebugGraphWrapper( + sess.graph_def + ) + + self.assertEqual( + "foo/a/(a)", + graph_wrapper.maybe_base_expanded_node_name("foo/a"), + ) + self.assertEqual( + "bar/b/(b)", + graph_wrapper.maybe_base_expanded_node_name("bar/b"), + ) + self.assertEqual( + "foo/a/read", + graph_wrapper.maybe_base_expanded_node_name("foo/a/read"), + ) + self.assertEqual( + "bar/b/read", + graph_wrapper.maybe_base_expanded_node_name("bar/b/read"), + ) + + if tensorflow_python_tf2.enabled(): + # NOTE(#1705): TF 2.0 tf.add creates nested nodes. + self.assertEqual( + "baz/c/(c)", + graph_wrapper.maybe_base_expanded_node_name("baz/c"), + ) + else: + self.assertEqual( + "baz/c", + graph_wrapper.maybe_base_expanded_node_name("baz/c"), + ) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/debugger/debugger_plugin.py b/tensorboard/plugins/debugger/debugger_plugin.py index c9f65b5440..117b6532a1 100644 --- a/tensorboard/plugins/debugger/debugger_plugin.py +++ b/tensorboard/plugins/debugger/debugger_plugin.py @@ -40,493 +40,562 @@ logger = tb_logging.get_logger() # HTTP routes. 
-_HEALTH_PILLS_ROUTE = '/health_pills' -_NUMERICS_ALERT_REPORT_ROUTE = '/numerics_alert_report' +_HEALTH_PILLS_ROUTE = "/health_pills" +_NUMERICS_ALERT_REPORT_ROUTE = "/numerics_alert_report" # The POST key of HEALTH_PILLS_ROUTE for a JSON list of node names. -_NODE_NAMES_POST_KEY = 'node_names' +_NODE_NAMES_POST_KEY = "node_names" # The POST key of HEALTH_PILLS_ROUTE for the run to retrieve health pills for. -_RUN_POST_KEY = 'run' +_RUN_POST_KEY = "run" # The default run to retrieve health pills for. -_DEFAULT_RUN = '.' +_DEFAULT_RUN = "." # The POST key of HEALTH_PILLS_ROUTE for the specific step to retrieve health # pills for. -_STEP_POST_KEY = 'step' +_STEP_POST_KEY = "step" # A glob pattern for files containing debugger-related events. -_DEBUGGER_EVENTS_GLOB_PATTERN = 'events.debugger*' +_DEBUGGER_EVENTS_GLOB_PATTERN = "events.debugger*" # Encapsulates data for a single health pill. -HealthPillEvent = collections.namedtuple('HealthPillEvent', [ - 'wall_time', 'step', 'device_name', 'output_slot', 'node_name', 'dtype', - 'shape', 'value' -]) +HealthPillEvent = collections.namedtuple( + "HealthPillEvent", + [ + "wall_time", + "step", + "device_name", + "output_slot", + "node_name", + "dtype", + "shape", + "value", + ], +) class DebuggerPlugin(base_plugin.TBPlugin): - """TensorFlow Debugger plugin. Receives requests for debugger-related data. + """TensorFlow Debugger plugin. Receives requests for debugger-related data. - That data could include health pills, which unveil the status of tensor - values. - """ - - # This string field is used by TensorBoard to generate the paths for routes - # provided by this plugin. It must thus be URL-friendly. This field is also - # used to uniquely identify this plugin throughout TensorBoard. See BasePlugin - # for details. - plugin_name = constants.DEBUGGER_PLUGIN_NAME - - def __init__(self, context): - """Constructs a debugger plugin for TensorBoard. - - This plugin adds handlers for retrieving debugger-related data. 
The plugin - also starts a debugger data server once the log directory is passed to the - plugin via the call to get_plugin_apps. - - Args: - context: A base_plugin.TBContext instance. - """ - self._event_multiplexer = context.multiplexer - self._logdir = context.logdir - self._debugger_data_server = None - self._grpc_port = None - - def listen(self, grpc_port): - """Start listening on the given gRPC port. - - This method of an instance of DebuggerPlugin can be invoked at most once. - This method is not thread safe. - - Args: - grpc_port: port number to listen at. - - Raises: - ValueError: If this instance is already listening at a gRPC port. - """ - if self._grpc_port: - raise ValueError( - "This DebuggerPlugin instance is already listening at gRPC port %d" % - self._grpc_port) - self._grpc_port = grpc_port - - sys.stderr.write('Creating DebuggerDataServer at port %d and logdir %s\n' % - (self._grpc_port, self._logdir)) - sys.stderr.flush() - self._debugger_data_server = debugger_server_lib.DebuggerDataServer( - self._grpc_port, self._logdir) - - threading.Thread(target=self._debugger_data_server. - start_the_debugger_data_receiving_server).start() - - def get_plugin_apps(self): - """Obtains a mapping between routes and handlers. - - This function also starts a debugger data server on separate thread if the - plugin has not started one yet. - - Returns: - A mapping between routes and handlers (functions that respond to - requests). + That data could include health pills, which unveil the status of + tensor values. """ - return { - _HEALTH_PILLS_ROUTE: self._serve_health_pills_handler, - _NUMERICS_ALERT_REPORT_ROUTE: self._serve_numerics_alert_report_handler, - } - - def is_active(self): - """Determines whether this plugin is active. - - This plugin is active if any health pills information is present for any - run. - Returns: - A boolean. Whether this plugin is active. 
- """ - return bool( - self._grpc_port is not None and - self._event_multiplexer and - self._event_multiplexer.PluginRunToTagToContent( - constants.DEBUGGER_PLUGIN_NAME)) - - def frontend_metadata(self): - return base_plugin.FrontendMetadata(element_name='tf-debugger-dashboard') - - @wrappers.Request.application - def _serve_health_pills_handler(self, request): - """A (wrapped) werkzeug handler for serving health pills. - - Accepts POST requests and responds with health pills. The request accepts - several POST parameters: - - node_names: (required string) A JSON-ified list of node names for which - the client would like to request health pills. - run: (optional string) The run to retrieve health pills for. Defaults to - '.'. This data is sent via POST (not GET) since URL length is limited. - step: (optional integer): The session run step for which to - retrieve health pills. If provided, the handler reads the health pills - of that step from disk (which is slow) and produces a response with - only health pills at that step. If not provided, the handler returns a - response with health pills at all steps sampled by the event - multiplexer (the fast path). The motivation here is that, sometimes, - one desires to examine health pills at a specific step (to say find - the first step that causes a model to blow up with NaNs). - get_plugin_apps must be called before this slower feature is used - because that method passes the logdir (directory path) to this plugin. - - This handler responds with a JSON-ified object mapping from node names to a - list (of size 1) of health pill event objects, each of which has these - properties. - - { - 'wall_time': float, - 'step': int, - 'node_name': string, - 'output_slot': int, - # A list of 12 floats that summarizes the elements of the tensor. - 'value': float[], - } - - Node names for which there are no health pills to be found are excluded from - the mapping. - - Args: - request: The request issued by the client for health pills. 
- - Returns: - A werkzeug BaseResponse object. - """ - if request.method != 'POST': - return http_util.Respond( - request, - '%s requests are forbidden by the debugger plugin.' % request.method, - 'text/plain', - code=405 - ) - - if _NODE_NAMES_POST_KEY not in request.form: - return http_util.Respond(request, ( - 'The %r POST key was not found in the request for health pills.' % - _NODE_NAMES_POST_KEY), 'text/plain', code=400) - - jsonified_node_names = request.form[_NODE_NAMES_POST_KEY] - try: - node_names = json.loads(tf.compat.as_text(jsonified_node_names)) - except Exception as e: # pylint: disable=broad-except - # Different JSON libs raise different exceptions, so we just do a - # catch-all here. This problem is complicated by how Tensorboard might be - # run in many different environments, as it is open-source. - # TODO(@caisq, @chihuahua): Create platform-dependent adapter to catch - # specific types of exceptions, instead of the broad catching here. - response = ( - 'Could not decode node name JSON string %r: %s' - ) % (jsonified_node_names, e) - return http_util.Respond(request, response, 'text/plain', code=400) - - if not isinstance(node_names, list): - response = ( - '%r is not a JSON list of node names:' - ) % (jsonified_node_names) - return http_util.Respond(request, response, 'text/plain', code=400) - - run = request.form.get(_RUN_POST_KEY, _DEFAULT_RUN) - step_string = request.form.get(_STEP_POST_KEY, None) - if step_string is None: - # Use all steps sampled by the event multiplexer (Relatively fast). - mapping = self._obtain_sampled_health_pills(run, node_names) - else: - # Read disk to obtain the health pills for that step (Relatively slow). - # Make sure that the directory for the run exists. - # Determine the directory of events file to read. - events_directory = self._logdir - if run != _DEFAULT_RUN: - # Use the directory for the specific run. 
- events_directory = os.path.join(events_directory, run) - - step = int(step_string) - try: - mapping = self._obtain_health_pills_at_step( - events_directory, node_names, step) - except IOError as error: - response = 'Error retrieving health pills for step %d: %s' % (step, error) - return http_util.Respond(request, response, 'text/plain', code=404) - - # Convert event_accumulator.HealthPillEvents to JSON-able dicts. - jsonable_mapping = {} - for node_name, events in mapping.items(): - jsonable_mapping[node_name] = [e._asdict() for e in events] - return http_util.Respond(request, jsonable_mapping, 'application/json') - - def _obtain_sampled_health_pills(self, run, node_names): - """Obtains the health pills for a run sampled by the event multiplexer. - - This is much faster than the alternative path of reading health pills from - disk. - - Args: - run: The run to fetch health pills for. - node_names: A list of node names for which to retrieve health pills. - - Returns: - A dictionary mapping from node name to a list of - event_accumulator.HealthPillEvents. - """ - runs_to_tags_to_content = self._event_multiplexer.PluginRunToTagToContent( - constants.DEBUGGER_PLUGIN_NAME) - - if run not in runs_to_tags_to_content: - # The run lacks health pills. - return {} - - # This is also a mapping between node name and plugin content because this - # plugin tags by node name. - tags_to_content = runs_to_tags_to_content[run] - - mapping = {} - for node_name in node_names: - if node_name not in tags_to_content: - # This node lacks health pill data. - continue - - health_pills = [] - for tensor_event in self._event_multiplexer.Tensors(run, node_name): - json_string = tags_to_content[node_name] + # This string field is used by TensorBoard to generate the paths for routes + # provided by this plugin. It must thus be URL-friendly. This field is also + # used to uniquely identify this plugin throughout TensorBoard. See BasePlugin + # for details. 
+ plugin_name = constants.DEBUGGER_PLUGIN_NAME + + def __init__(self, context): + """Constructs a debugger plugin for TensorBoard. + + This plugin adds handlers for retrieving debugger-related data. The plugin + also starts a debugger data server once the log directory is passed to the + plugin via the call to get_plugin_apps. + + Args: + context: A base_plugin.TBContext instance. + """ + self._event_multiplexer = context.multiplexer + self._logdir = context.logdir + self._debugger_data_server = None + self._grpc_port = None + + def listen(self, grpc_port): + """Start listening on the given gRPC port. + + This method of an instance of DebuggerPlugin can be invoked at most once. + This method is not thread safe. + + Args: + grpc_port: port number to listen at. + + Raises: + ValueError: If this instance is already listening at a gRPC port. + """ + if self._grpc_port: + raise ValueError( + "This DebuggerPlugin instance is already listening at gRPC port %d" + % self._grpc_port + ) + self._grpc_port = grpc_port + + sys.stderr.write( + "Creating DebuggerDataServer at port %d and logdir %s\n" + % (self._grpc_port, self._logdir) + ) + sys.stderr.flush() + self._debugger_data_server = debugger_server_lib.DebuggerDataServer( + self._grpc_port, self._logdir + ) + + threading.Thread( + target=self._debugger_data_server.start_the_debugger_data_receiving_server + ).start() + + def get_plugin_apps(self): + """Obtains a mapping between routes and handlers. + + This function also starts a debugger data server on separate thread if the + plugin has not started one yet. + + Returns: + A mapping between routes and handlers (functions that respond to + requests). + """ + return { + _HEALTH_PILLS_ROUTE: self._serve_health_pills_handler, + _NUMERICS_ALERT_REPORT_ROUTE: self._serve_numerics_alert_report_handler, + } + + def is_active(self): + """Determines whether this plugin is active. + + This plugin is active if any health pills information is present for any + run. 
+ + Returns: + A boolean. Whether this plugin is active. + """ + return bool( + self._grpc_port is not None + and self._event_multiplexer + and self._event_multiplexer.PluginRunToTagToContent( + constants.DEBUGGER_PLUGIN_NAME + ) + ) + + def frontend_metadata(self): + return base_plugin.FrontendMetadata( + element_name="tf-debugger-dashboard" + ) + + @wrappers.Request.application + def _serve_health_pills_handler(self, request): + """A (wrapped) werkzeug handler for serving health pills. + + Accepts POST requests and responds with health pills. The request accepts + several POST parameters: + + node_names: (required string) A JSON-ified list of node names for which + the client would like to request health pills. + run: (optional string) The run to retrieve health pills for. Defaults to + '.'. This data is sent via POST (not GET) since URL length is limited. + step: (optional integer): The session run step for which to + retrieve health pills. If provided, the handler reads the health pills + of that step from disk (which is slow) and produces a response with + only health pills at that step. If not provided, the handler returns a + response with health pills at all steps sampled by the event + multiplexer (the fast path). The motivation here is that, sometimes, + one desires to examine health pills at a specific step (to say find + the first step that causes a model to blow up with NaNs). + get_plugin_apps must be called before this slower feature is used + because that method passes the logdir (directory path) to this plugin. + + This handler responds with a JSON-ified object mapping from node names to a + list (of size 1) of health pill event objects, each of which has these + properties. + + { + 'wall_time': float, + 'step': int, + 'node_name': string, + 'output_slot': int, + # A list of 12 floats that summarizes the elements of the tensor. + 'value': float[], + } + + Node names for which there are no health pills to be found are excluded from + the mapping. 
+ + Args: + request: The request issued by the client for health pills. + + Returns: + A werkzeug BaseResponse object. + """ + if request.method != "POST": + return http_util.Respond( + request, + "%s requests are forbidden by the debugger plugin." + % request.method, + "text/plain", + code=405, + ) + + if _NODE_NAMES_POST_KEY not in request.form: + return http_util.Respond( + request, + ( + "The %r POST key was not found in the request for health pills." + % _NODE_NAMES_POST_KEY + ), + "text/plain", + code=400, + ) + + jsonified_node_names = request.form[_NODE_NAMES_POST_KEY] try: - content_object = json.loads(tf.compat.as_text(json_string)) - device_name = content_object['device'] - output_slot = content_object['outputSlot'] - health_pills.append( - self._tensor_proto_to_health_pill(tensor_event, node_name, - device_name, output_slot)) - except (KeyError, ValueError) as e: - logger.error('Could not determine device from JSON string ' - '%r: %r', json_string, e) - - mapping[node_name] = health_pills - - return mapping - - def _tensor_proto_to_health_pill(self, tensor_event, node_name, device, - output_slot): - """Converts an event_accumulator.TensorEvent to a HealthPillEvent. - - Args: - tensor_event: The event_accumulator.TensorEvent to convert. - node_name: The name of the node (without the output slot). - device: The device. - output_slot: The integer output slot this health pill is relevant to. - - Returns: - A HealthPillEvent. - """ - return self._process_health_pill_value( - wall_time=tensor_event.wall_time, - step=tensor_event.step, - device_name=device, - output_slot=output_slot, - node_name=node_name, - tensor_proto=tensor_event.tensor_proto) - - def _obtain_health_pills_at_step(self, events_directory, node_names, step): - """Reads disk to obtain the health pills for a run at a specific step. - - This could be much slower than the alternative path of just returning all - health pills sampled by the event multiplexer. 
It could take tens of minutes - to complete this call for large graphs for big step values (in the - thousands). - - Args: - events_directory: The directory containing events for the desired run. - node_names: A list of node names for which to retrieve health pills. - step: The step to obtain health pills for. - - Returns: - A dictionary mapping from node name to a list of health pill objects (see - docs for _serve_health_pills_handler for properties of those objects). - - Raises: - IOError: If no files with health pill events could be found. - """ - # Obtain all files with debugger-related events. - pattern = os.path.join(events_directory, _DEBUGGER_EVENTS_GLOB_PATTERN) - file_paths = glob.glob(pattern) - - if not file_paths: - raise IOError( - 'No events files found that matches the pattern %r.' % pattern) - - # Sort by name (and thus by timestamp). - file_paths.sort() - - mapping = collections.defaultdict(list) - node_name_set = frozenset(node_names) - - for file_path in file_paths: - should_stop = self._process_health_pill_event( - node_name_set, mapping, step, file_path) - if should_stop: - break - - return mapping - - def _process_health_pill_event(self, node_name_set, mapping, target_step, - file_path): - """Creates health pills out of data in an event. - - Creates health pills out of the event and adds them to the mapping. - - Args: - node_name_set: A set of node names that are relevant. - mapping: The mapping from node name to HealthPillEvents. - This object may be destructively modified. - target_step: The target step at which to obtain health pills. - file_path: The path to the file with health pill events. - - Returns: - Whether we should stop reading events because future events are no longer - relevant. 
- """ - events_loader = event_file_loader.EventFileLoader(file_path) - for event in events_loader.Load(): - if not event.HasField('summary'): - logger.warn( - 'An event in a debugger events file lacks a summary.') - continue - - if event.step < target_step: - # This event is not of the relevant step. We perform this check - # first because the majority of events will be eliminated from - # consideration by this check. - continue - - if event.step > target_step: - # We have passed the relevant step. No need to read more events. - return True - - for value in event.summary.value: - # Obtain the device name from the metadata. - summary_metadata = value.metadata - plugin_data = summary_metadata.plugin_data - if plugin_data.plugin_name == constants.DEBUGGER_PLUGIN_NAME: - try: - content = json.loads( - tf.compat.as_text(summary_metadata.plugin_data.content)) - except ValueError as err: - logger.warn( - 'Could not parse the JSON string containing data for ' - 'the debugger plugin: %r, %r', content, err) - continue - device_name = content['device'] - output_slot = content['outputSlot'] + node_names = json.loads(tf.compat.as_text(jsonified_node_names)) + except Exception as e: # pylint: disable=broad-except + # Different JSON libs raise different exceptions, so we just do a + # catch-all here. This problem is complicated by how Tensorboard might be + # run in many different environments, as it is open-source. + # TODO(@caisq, @chihuahua): Create platform-dependent adapter to catch + # specific types of exceptions, instead of the broad catching here. 
+ response = ("Could not decode node name JSON string %r: %s") % ( + jsonified_node_names, + e, + ) + return http_util.Respond(request, response, "text/plain", code=400) + + if not isinstance(node_names, list): + response = ("%r is not a JSON list of node names:") % ( + jsonified_node_names + ) + return http_util.Respond(request, response, "text/plain", code=400) + + run = request.form.get(_RUN_POST_KEY, _DEFAULT_RUN) + step_string = request.form.get(_STEP_POST_KEY, None) + if step_string is None: + # Use all steps sampled by the event multiplexer (Relatively fast). + mapping = self._obtain_sampled_health_pills(run, node_names) else: - logger.error( - 'No debugger plugin data found for event with tag %s and node ' - 'name %s.', value.tag, value.node_name) - continue - - if not value.HasField('tensor'): - logger.warn( - 'An event in a debugger events file lacks a tensor value.') - continue - - match = re.match(r'^(.*):(\d+):DebugNumericSummary$', value.node_name) - if not match: - logger.warn( - ('A event with a health pill has an invalid watch, (i.e., an ' - 'unexpected debug op): %r'), value.node_name) - return None - - health_pill = self._process_health_pill_value( - wall_time=event.wall_time, - step=event.step, + # Read disk to obtain the health pills for that step (Relatively slow). + # Make sure that the directory for the run exists. + # Determine the directory of events file to read. + events_directory = self._logdir + if run != _DEFAULT_RUN: + # Use the directory for the specific run. + events_directory = os.path.join(events_directory, run) + + step = int(step_string) + try: + mapping = self._obtain_health_pills_at_step( + events_directory, node_names, step + ) + except IOError as error: + response = "Error retrieving health pills for step %d: %s" % ( + step, + error, + ) + return http_util.Respond( + request, response, "text/plain", code=404 + ) + + # Convert event_accumulator.HealthPillEvents to JSON-able dicts. 
+ jsonable_mapping = {} + for node_name, events in mapping.items(): + jsonable_mapping[node_name] = [e._asdict() for e in events] + return http_util.Respond(request, jsonable_mapping, "application/json") + + def _obtain_sampled_health_pills(self, run, node_names): + """Obtains the health pills for a run sampled by the event multiplexer. + + This is much faster than the alternative path of reading health pills from + disk. + + Args: + run: The run to fetch health pills for. + node_names: A list of node names for which to retrieve health pills. + + Returns: + A dictionary mapping from node name to a list of + event_accumulator.HealthPillEvents. + """ + runs_to_tags_to_content = self._event_multiplexer.PluginRunToTagToContent( + constants.DEBUGGER_PLUGIN_NAME + ) + + if run not in runs_to_tags_to_content: + # The run lacks health pills. + return {} + + # This is also a mapping between node name and plugin content because this + # plugin tags by node name. + tags_to_content = runs_to_tags_to_content[run] + + mapping = {} + for node_name in node_names: + if node_name not in tags_to_content: + # This node lacks health pill data. + continue + + health_pills = [] + for tensor_event in self._event_multiplexer.Tensors(run, node_name): + json_string = tags_to_content[node_name] + try: + content_object = json.loads(tf.compat.as_text(json_string)) + device_name = content_object["device"] + output_slot = content_object["outputSlot"] + health_pills.append( + self._tensor_proto_to_health_pill( + tensor_event, node_name, device_name, output_slot + ) + ) + except (KeyError, ValueError) as e: + logger.error( + "Could not determine device from JSON string " "%r: %r", + json_string, + e, + ) + + mapping[node_name] = health_pills + + return mapping + + def _tensor_proto_to_health_pill( + self, tensor_event, node_name, device, output_slot + ): + """Converts an event_accumulator.TensorEvent to a HealthPillEvent. + + Args: + tensor_event: The event_accumulator.TensorEvent to convert. 
+ node_name: The name of the node (without the output slot). + device: The device. + output_slot: The integer output slot this health pill is relevant to. + + Returns: + A HealthPillEvent. + """ + return self._process_health_pill_value( + wall_time=tensor_event.wall_time, + step=tensor_event.step, + device_name=device, + output_slot=output_slot, + node_name=node_name, + tensor_proto=tensor_event.tensor_proto, + ) + + def _obtain_health_pills_at_step(self, events_directory, node_names, step): + """Reads disk to obtain the health pills for a run at a specific step. + + This could be much slower than the alternative path of just returning all + health pills sampled by the event multiplexer. It could take tens of minutes + to complete this call for large graphs for big step values (in the + thousands). + + Args: + events_directory: The directory containing events for the desired run. + node_names: A list of node names for which to retrieve health pills. + step: The step to obtain health pills for. + + Returns: + A dictionary mapping from node name to a list of health pill objects (see + docs for _serve_health_pills_handler for properties of those objects). + + Raises: + IOError: If no files with health pill events could be found. + """ + # Obtain all files with debugger-related events. + pattern = os.path.join(events_directory, _DEBUGGER_EVENTS_GLOB_PATTERN) + file_paths = glob.glob(pattern) + + if not file_paths: + raise IOError( + "No events files found that matches the pattern %r." % pattern + ) + + # Sort by name (and thus by timestamp). + file_paths.sort() + + mapping = collections.defaultdict(list) + node_name_set = frozenset(node_names) + + for file_path in file_paths: + should_stop = self._process_health_pill_event( + node_name_set, mapping, step, file_path + ) + if should_stop: + break + + return mapping + + def _process_health_pill_event( + self, node_name_set, mapping, target_step, file_path + ): + """Creates health pills out of data in an event. 
+ + Creates health pills out of the event and adds them to the mapping. + + Args: + node_name_set: A set of node names that are relevant. + mapping: The mapping from node name to HealthPillEvents. + This object may be destructively modified. + target_step: The target step at which to obtain health pills. + file_path: The path to the file with health pill events. + + Returns: + Whether we should stop reading events because future events are no longer + relevant. + """ + events_loader = event_file_loader.EventFileLoader(file_path) + for event in events_loader.Load(): + if not event.HasField("summary"): + logger.warn( + "An event in a debugger events file lacks a summary." + ) + continue + + if event.step < target_step: + # This event is not of the relevant step. We perform this check + # first because the majority of events will be eliminated from + # consideration by this check. + continue + + if event.step > target_step: + # We have passed the relevant step. No need to read more events. + return True + + for value in event.summary.value: + # Obtain the device name from the metadata. + summary_metadata = value.metadata + plugin_data = summary_metadata.plugin_data + if plugin_data.plugin_name == constants.DEBUGGER_PLUGIN_NAME: + try: + content = json.loads( + tf.compat.as_text( + summary_metadata.plugin_data.content + ) + ) + except ValueError as err: + logger.warn( + "Could not parse the JSON string containing data for " + "the debugger plugin: %r, %r", + content, + err, + ) + continue + device_name = content["device"] + output_slot = content["outputSlot"] + else: + logger.error( + "No debugger plugin data found for event with tag %s and node " + "name %s.", + value.tag, + value.node_name, + ) + continue + + if not value.HasField("tensor"): + logger.warn( + "An event in a debugger events file lacks a tensor value." 
+ ) + continue + + match = re.match( + r"^(.*):(\d+):DebugNumericSummary$", value.node_name + ) + if not match: + logger.warn( + ( + "A event with a health pill has an invalid watch, (i.e., an " + "unexpected debug op): %r" + ), + value.node_name, + ) + return None + + health_pill = self._process_health_pill_value( + wall_time=event.wall_time, + step=event.step, + device_name=device_name, + output_slot=output_slot, + node_name=match.group(1), + tensor_proto=value.tensor, + node_name_set=node_name_set, + ) + if not health_pill: + continue + mapping[health_pill.node_name].append(health_pill) + + # Keep reading events. + return False + + def _process_health_pill_value( + self, + wall_time, + step, + device_name, + output_slot, + node_name, + tensor_proto, + node_name_set=None, + ): + """Creates a HealthPillEvent containing various properties of a health + pill. + + Args: + wall_time: The wall time in seconds. + step: The session run step of the event. + device_name: The name of the node's device. + output_slot: The numeric output slot. + node_name: The name of the node (without the output slot). + tensor_proto: A tensor proto of data. + node_name_set: An optional set of node names that are relevant. If not + provided, no filtering by relevance occurs. + + Returns: + An event_accumulator.HealthPillEvent. Or None if one could not be created. + """ + if node_name_set and node_name not in node_name_set: + # This event is not relevant. + return None + + # Since we seek health pills for a specific step, this function + # returns 1 health pill per node per step. The wall time is the + # seconds since the epoch. + elements = list(tensor_util.make_ndarray(tensor_proto)) + return HealthPillEvent( + wall_time=wall_time, + step=step, device_name=device_name, output_slot=output_slot, - node_name=match.group(1), - tensor_proto=value.tensor, - node_name_set=node_name_set) - if not health_pill: - continue - mapping[health_pill.node_name].append(health_pill) - - # Keep reading events. 
- return False - - def _process_health_pill_value(self, - wall_time, - step, - device_name, - output_slot, - node_name, - tensor_proto, - node_name_set=None): - """Creates a HealthPillEvent containing various properties of a health pill. - - Args: - wall_time: The wall time in seconds. - step: The session run step of the event. - device_name: The name of the node's device. - output_slot: The numeric output slot. - node_name: The name of the node (without the output slot). - tensor_proto: A tensor proto of data. - node_name_set: An optional set of node names that are relevant. If not - provided, no filtering by relevance occurs. - - Returns: - An event_accumulator.HealthPillEvent. Or None if one could not be created. - """ - if node_name_set and node_name not in node_name_set: - # This event is not relevant. - return None - - # Since we seek health pills for a specific step, this function - # returns 1 health pill per node per step. The wall time is the - # seconds since the epoch. - elements = list(tensor_util.make_ndarray(tensor_proto)) - return HealthPillEvent( - wall_time=wall_time, - step=step, - device_name=device_name, - output_slot=output_slot, - node_name=node_name, - dtype=repr(tf.as_dtype(elements[12])), - shape=elements[14:], - value=elements) - - @wrappers.Request.application - def _serve_numerics_alert_report_handler(self, request): - """A (wrapped) werkzeug handler for serving numerics alert report. - - Accepts GET requests and responds with an array of JSON-ified - NumericsAlertReportRow. - - Each JSON-ified NumericsAlertReportRow object has the following format: - { - 'device_name': string, - 'tensor_name': string, - 'first_timestamp': float, - 'nan_event_count': int, - 'neg_inf_event_count': int, - 'pos_inf_event_count': int - } - - These objects are sorted by ascending order of first_timestamp in the - response array. - - Args: - request: The request, currently assumed to be empty. - - Returns: - A werkzeug BaseResponse object. 
- """ - if request.method != 'GET': - response = ( - '%s requests are forbidden by the debugger plugin.' % request.method) - return http_util.Respond(request, response, 'text/plain', code=405) - - report = self._debugger_data_server.numerics_alert_report() - - # Convert the named tuples to dictionaries so we JSON them into objects. - response = [r._asdict() for r in report] # pylint: disable=protected-access - return http_util.Respond(request, response, 'application/json') + node_name=node_name, + dtype=repr(tf.as_dtype(elements[12])), + shape=elements[14:], + value=elements, + ) + + @wrappers.Request.application + def _serve_numerics_alert_report_handler(self, request): + """A (wrapped) werkzeug handler for serving numerics alert report. + + Accepts GET requests and responds with an array of JSON-ified + NumericsAlertReportRow. + + Each JSON-ified NumericsAlertReportRow object has the following format: + { + 'device_name': string, + 'tensor_name': string, + 'first_timestamp': float, + 'nan_event_count': int, + 'neg_inf_event_count': int, + 'pos_inf_event_count': int + } + + These objects are sorted by ascending order of first_timestamp in the + response array. + + Args: + request: The request, currently assumed to be empty. + + Returns: + A werkzeug BaseResponse object. + """ + if request.method != "GET": + response = ( + "%s requests are forbidden by the debugger plugin." + % request.method + ) + return http_util.Respond(request, response, "text/plain", code=405) + + report = self._debugger_data_server.numerics_alert_report() + + # Convert the named tuples to dictionaries so we JSON them into objects. 
+ response = [ + r._asdict() for r in report + ] # pylint: disable=protected-access + return http_util.Respond(request, response, "application/json") diff --git a/tensorboard/plugins/debugger/debugger_plugin_loader.py b/tensorboard/plugins/debugger/debugger_plugin_loader.py index fd1ea05377..3dc5c5170f 100644 --- a/tensorboard/plugins/debugger/debugger_plugin_loader.py +++ b/tensorboard/plugins/debugger/debugger_plugin_loader.py @@ -32,51 +32,51 @@ class InactiveDebuggerPlugin(base_plugin.TBPlugin): - """A placeholder debugger plugin used when no grpc port is specified.""" + """A placeholder debugger plugin used when no grpc port is specified.""" - plugin_name = constants.DEBUGGER_PLUGIN_NAME + plugin_name = constants.DEBUGGER_PLUGIN_NAME - def __init__(self): - pass + def __init__(self): + pass - def is_active(self): - return False + def is_active(self): + return False - def frontend_metadata(self): - return base_plugin.FrontendMetadata(element_name='tf-debugger-dashboard') + def frontend_metadata(self): + return base_plugin.FrontendMetadata( + element_name="tf-debugger-dashboard" + ) - def get_plugin_apps(self): - return { - '/debugger_grpc_host_port': self._serve_debugger_grpc_host_port, - } - - @wrappers.Request.application - def _serve_debugger_grpc_host_port(self, request): - # Respond with a -1 port number to indicate the debugger plugin is - # inactive. - return http_util.Respond( - request, - {'host': None, 'port': -1}, - 'application/json') + def get_plugin_apps(self): + return { + "/debugger_grpc_host_port": self._serve_debugger_grpc_host_port, + } + @wrappers.Request.application + def _serve_debugger_grpc_host_port(self, request): + # Respond with a -1 port number to indicate the debugger plugin is + # inactive. + return http_util.Respond( + request, {"host": None, "port": -1}, "application/json" + ) class DebuggerPluginLoader(base_plugin.TBLoader): - """DebuggerPlugin factory factory. 
- - This class determines which debugger plugin to load, based on custom - flags. It also checks for the `grpcio` PyPi dependency. - """ - - def define_flags(self, parser): - """Adds DebuggerPlugin CLI flags to parser.""" - group = parser.add_argument_group('debugger plugin') - group.add_argument( - '--debugger_data_server_grpc_port', - metavar='PORT', - type=int, - default=-1, - help='''\ + """DebuggerPlugin factory factory. + + This class determines which debugger plugin to load, based on custom + flags. It also checks for the `grpcio` PyPi dependency. + """ + + def define_flags(self, parser): + """Adds DebuggerPlugin CLI flags to parser.""" + group = parser.add_argument_group("debugger plugin") + group.add_argument( + "--debugger_data_server_grpc_port", + metavar="PORT", + type=int, + default=-1, + help="""\ The port at which the non-interactive debugger data server should receive debugging data via gRPC from one or more debugger-enabled TensorFlow runtimes. No debugger plugin or debugger data server will be @@ -84,13 +84,14 @@ def define_flags(self, parser): `--debugger_port` flag in that it starts a non-interactive mode. It is for use with the "health pills" feature of the Graph Dashboard. This flag is mutually exclusive with `--debugger_port`.\ -''') - group.add_argument( - '--debugger_port', - metavar='PORT', - type=int, - default=-1, - help='''\ +""", + ) + group.add_argument( + "--debugger_port", + metavar="PORT", + type=int, + default=-1, + help="""\ The port at which the interactive debugger data server (to be started by the debugger plugin) should receive debugging data via gRPC from one or more debugger-enabled TensorFlow runtimes. No debugger plugin or @@ -100,63 +101,76 @@ def define_flags(self, parser): inside a TensorFlow Graph or between Session.runs. It is for use with the interactive Debugger Dashboard. 
This flag is mutually exclusive with `--debugger_data_server_grpc_port`.\ -''') - - def fix_flags(self, flags): - """Fixes Debugger related flags. - - Raises: - ValueError: If both the `debugger_data_server_grpc_port` and - `debugger_port` flags are specified as >= 0. - """ - # Check that not both grpc port flags are specified. - if flags.debugger_data_server_grpc_port > 0 and flags.debugger_port > 0: - raise base_plugin.FlagsError( - '--debugger_data_server_grpc_port and --debugger_port are mutually ' - 'exclusive. Do not use both of them at the same time.') - - def load(self, context): - """Returns the debugger plugin, if possible. - - Args: - context: The TBContext flags including `add_arguments`. - - Returns: - A DebuggerPlugin instance or None if it couldn't be loaded. - """ - flags = context.flags - if flags.debugger_data_server_grpc_port > 0 or flags.debugger_port > 0: - # Verify that the required Python packages are installed. - try: - # pylint: disable=unused-import - import tensorflow - except ImportError: - raise ImportError( - 'To use the debugger plugin, you need to have TensorFlow installed:\n' - ' pip install tensorflow') - - if flags.debugger_data_server_grpc_port > 0: - from tensorboard.plugins.debugger import debugger_plugin as debugger_plugin_lib - - # debugger_data_server_grpc opens the non-interactive Debugger Plugin, - # which appears as health pills in the Graph Plugin. 
- noninteractive_plugin = debugger_plugin_lib.DebuggerPlugin(context) - logger.info('Starting Non-interactive Debugger Plugin at gRPC port %d', - flags.debugger_data_server_grpc_port) - noninteractive_plugin.listen(flags.debugger_data_server_grpc_port) - return noninteractive_plugin - elif flags.debugger_port > 0: - from tensorboard.plugins.debugger import interactive_debugger_plugin as interactive_debugger_plugin_lib - interactive_plugin = ( - interactive_debugger_plugin_lib.InteractiveDebuggerPlugin(context)) - logger.info('Starting Interactive Debugger Plugin at gRPC port %d', - flags.debugger_data_server_grpc_port) - interactive_plugin.listen(flags.debugger_port) - return interactive_plugin - else: - # If neither the debugger_data_server_grpc_port flag or the grpc_port - # flag is specified, we instantiate a dummy plugin as a placeholder for - # the frontend. The dummy plugin will display a message indicating that - # the plugin is not active. It'll also display a command snippet to - # illustrate how to activate the interactive Debugger Plugin. - return InactiveDebuggerPlugin() +""", + ) + + def fix_flags(self, flags): + """Fixes Debugger related flags. + + Raises: + ValueError: If both the `debugger_data_server_grpc_port` and + `debugger_port` flags are specified as >= 0. + """ + # Check that not both grpc port flags are specified. + if flags.debugger_data_server_grpc_port > 0 and flags.debugger_port > 0: + raise base_plugin.FlagsError( + "--debugger_data_server_grpc_port and --debugger_port are mutually " + "exclusive. Do not use both of them at the same time." + ) + + def load(self, context): + """Returns the debugger plugin, if possible. + + Args: + context: The TBContext flags including `add_arguments`. + + Returns: + A DebuggerPlugin instance or None if it couldn't be loaded. + """ + flags = context.flags + if flags.debugger_data_server_grpc_port > 0 or flags.debugger_port > 0: + # Verify that the required Python packages are installed. 
+ try: + # pylint: disable=unused-import + import tensorflow + except ImportError: + raise ImportError( + "To use the debugger plugin, you need to have TensorFlow installed:\n" + " pip install tensorflow" + ) + + if flags.debugger_data_server_grpc_port > 0: + from tensorboard.plugins.debugger import ( + debugger_plugin as debugger_plugin_lib, + ) + + # debugger_data_server_grpc opens the non-interactive Debugger Plugin, + # which appears as health pills in the Graph Plugin. + noninteractive_plugin = debugger_plugin_lib.DebuggerPlugin(context) + logger.info( + "Starting Non-interactive Debugger Plugin at gRPC port %d", + flags.debugger_data_server_grpc_port, + ) + noninteractive_plugin.listen(flags.debugger_data_server_grpc_port) + return noninteractive_plugin + elif flags.debugger_port > 0: + from tensorboard.plugins.debugger import ( + interactive_debugger_plugin as interactive_debugger_plugin_lib, + ) + + interactive_plugin = interactive_debugger_plugin_lib.InteractiveDebuggerPlugin( + context + ) + logger.info( + "Starting Interactive Debugger Plugin at gRPC port %d", + flags.debugger_data_server_grpc_port, + ) + interactive_plugin.listen(flags.debugger_port) + return interactive_plugin + else: + # If neither the debugger_data_server_grpc_port flag or the grpc_port + # flag is specified, we instantiate a dummy plugin as a placeholder for + # the frontend. The dummy plugin will display a message indicating that + # the plugin is not active. It'll also display a command snippet to + # illustrate how to activate the interactive Debugger Plugin. 
+ return InactiveDebuggerPlugin() diff --git a/tensorboard/plugins/debugger/debugger_plugin_test.py b/tensorboard/plugins/debugger/debugger_plugin_test.py index 834d769ec6..8dfeeb4c1f 100644 --- a/tensorboard/plugins/debugger/debugger_plugin_test.py +++ b/tensorboard/plugins/debugger/debugger_plugin_test.py @@ -23,241 +23,295 @@ import tensorflow as tf -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.debugger import debugger_plugin_testlib from tensorboard.plugins.debugger import numerics_alert class DebuggerPluginTest(debugger_plugin_testlib.DebuggerPluginTestBase): + def testHealthPillsRouteProvided(self): + """Tests that the plugin offers the route for requesting health + pills.""" + apps = self.plugin.get_plugin_apps() + self.assertIn("/health_pills", apps) + self.assertIsInstance(apps["/health_pills"], collections.Callable) - def testHealthPillsRouteProvided(self): - """Tests that the plugin offers the route for requesting health pills.""" - apps = self.plugin.get_plugin_apps() - self.assertIn('/health_pills', apps) - self.assertIsInstance(apps['/health_pills'], collections.Callable) + def testHealthPillsPluginIsActive(self): + # The multiplexer has sampled health pills. + self.assertTrue(self.plugin.is_active()) - def testHealthPillsPluginIsActive(self): - # The multiplexer has sampled health pills. 
- self.assertTrue(self.plugin.is_active()) + def testHealthPillsPluginIsInactive(self): + plugin = self.debugger_plugin_module.DebuggerPlugin( + base_plugin.TBContext( + logdir=self.log_dir, + multiplexer=event_multiplexer.EventMultiplexer({}), + ) + ) + plugin.listen(self.debugger_data_server_grpc_port) - def testHealthPillsPluginIsInactive(self): - plugin = self.debugger_plugin_module.DebuggerPlugin( - base_plugin.TBContext( - logdir=self.log_dir, - multiplexer=event_multiplexer.EventMultiplexer({}))) - plugin.listen(self.debugger_data_server_grpc_port) + # The multiplexer lacks sampled health pills. + self.assertFalse(plugin.is_active()) - # The multiplexer lacks sampled health pills. - self.assertFalse(plugin.is_active()) - - def testRequestHealthPillsForRunFoo(self): - """Tests that the plugin produces health pills for a specified run.""" - response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': json.dumps(['layers/Variable', 'unavailable_node']), - 'run': 'run_foo', - }) - self.assertEqual(200, response.status_code) - self.assertDictEqual({ - 'layers/Variable': [{ - 'wall_time': 4242, - 'step': 42, - 'device_name': '/job:localhost/replica:0/task:0/cpu:0', - 'node_name': 'layers/Variable', - 'output_slot': 0, - 'dtype': 'tf.int16', - 'shape': [8.0], - 'value': list(range(12)) + [tf.int16.as_datatype_enum, 1.0, 8.0], - }], - }, self._DeserializeResponse(response.get_data())) - - def testRequestHealthPillsForDefaultRun(self): - """Tests that the plugin produces health pills for the default '.' run.""" - # Do not provide a 'run' parameter in POST data. - response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': json.dumps(['logits/Add', 'unavailable_node']), - }) - self.assertEqual(200, response.status_code) - # The health pills for 'layers/Matmul' should not be included since the - # request excluded that node name. 
- self.assertDictEqual({ - 'logits/Add': [ + def testRequestHealthPillsForRunFoo(self): + """Tests that the plugin produces health pills for a specified run.""" + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={ + "node_names": json.dumps( + ["layers/Variable", "unavailable_node"] + ), + "run": "run_foo", + }, + ) + self.assertEqual(200, response.status_code) + self.assertDictEqual( { - 'wall_time': 1337, - 'step': 7, - 'device_name': '/job:localhost/replica:0/task:0/cpu:0', - 'node_name': 'logits/Add', - 'output_slot': 0, - 'dtype': 'tf.int32', - 'shape': [3.0, 3.0], - 'value': (list(range(12)) + - [float(tf.int32.as_datatype_enum), 2.0, 3.0, 3.0]), + "layers/Variable": [ + { + "wall_time": 4242, + "step": 42, + "device_name": "/job:localhost/replica:0/task:0/cpu:0", + "node_name": "layers/Variable", + "output_slot": 0, + "dtype": "tf.int16", + "shape": [8.0], + "value": list(range(12)) + + [tf.int16.as_datatype_enum, 1.0, 8.0], + } + ], }, + self._DeserializeResponse(response.get_data()), + ) + + def testRequestHealthPillsForDefaultRun(self): + """Tests that the plugin produces health pills for the default '.' + run.""" + # Do not provide a 'run' parameter in POST data. + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={ + "node_names": json.dumps(["logits/Add", "unavailable_node"]), + }, + ) + self.assertEqual(200, response.status_code) + # The health pills for 'layers/Matmul' should not be included since the + # request excluded that node name. 
+ self.assertDictEqual( { - 'wall_time': 1338, - 'step': 8, - 'device_name': '/job:localhost/replica:0/task:0/cpu:0', - 'node_name': 'logits/Add', - 'output_slot': 0, - 'dtype': 'tf.int16', - 'shape': [], - 'value': (list(range(12)) + - [float(tf.int16.as_datatype_enum), 0.0]), + "logits/Add": [ + { + "wall_time": 1337, + "step": 7, + "device_name": "/job:localhost/replica:0/task:0/cpu:0", + "node_name": "logits/Add", + "output_slot": 0, + "dtype": "tf.int32", + "shape": [3.0, 3.0], + "value": ( + list(range(12)) + + [float(tf.int32.as_datatype_enum), 2.0, 3.0, 3.0] + ), + }, + { + "wall_time": 1338, + "step": 8, + "device_name": "/job:localhost/replica:0/task:0/cpu:0", + "node_name": "logits/Add", + "output_slot": 0, + "dtype": "tf.int16", + "shape": [], + "value": ( + list(range(12)) + + [float(tf.int16.as_datatype_enum), 0.0] + ), + }, + ], }, - ], - }, self._DeserializeResponse(response.get_data())) + self._DeserializeResponse(response.get_data()), + ) - def testRequestHealthPillsForEmptyRun(self): - """Tests that the plugin responds with an empty dictionary.""" - response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': json.dumps(['layers/Variable']), - 'run': 'run_with_no_health_pills', - }) - self.assertEqual(200, response.status_code) - self.assertDictEqual({}, self._DeserializeResponse(response.get_data())) + def testRequestHealthPillsForEmptyRun(self): + """Tests that the plugin responds with an empty dictionary.""" + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={ + "node_names": json.dumps(["layers/Variable"]), + "run": "run_with_no_health_pills", + }, + ) + self.assertEqual(200, response.status_code) + self.assertDictEqual({}, self._DeserializeResponse(response.get_data())) - def testGetRequestsUnsupported(self): - """Tests that GET requests are unsupported.""" - response = self.server.get('/data/plugin/debugger/health_pills') - self.assertEqual(405, response.status_code) + def 
testGetRequestsUnsupported(self): + """Tests that GET requests are unsupported.""" + response = self.server.get("/data/plugin/debugger/health_pills") + self.assertEqual(405, response.status_code) - def testRequestsWithoutProperPostKeyUnsupported(self): - """Tests that requests lacking the node_names POST key are unsupported.""" - response = self.server.post('/data/plugin/debugger/health_pills') - self.assertEqual(400, response.status_code) + def testRequestsWithoutProperPostKeyUnsupported(self): + """Tests that requests lacking the node_names POST key are + unsupported.""" + response = self.server.post("/data/plugin/debugger/health_pills") + self.assertEqual(400, response.status_code) - def testRequestsWithBadJsonUnsupported(self): - """Tests that requests with undecodable JSON are unsupported.""" - response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': 'some obviously non JSON text', - }) - self.assertEqual(400, response.status_code) + def testRequestsWithBadJsonUnsupported(self): + """Tests that requests with undecodable JSON are unsupported.""" + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={"node_names": "some obviously non JSON text",}, + ) + self.assertEqual(400, response.status_code) - def testRequestsWithNonListPostDataUnsupported(self): - """Tests that requests with loads lacking lists of ops are unsupported.""" - response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': json.dumps({ - 'this is a dict': 'and not a list.' 
- }), - }) - self.assertEqual(400, response.status_code) + def testRequestsWithNonListPostDataUnsupported(self): + """Tests that requests with loads lacking lists of ops are + unsupported.""" + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={ + "node_names": json.dumps({"this is a dict": "and not a list."}), + }, + ) + self.assertEqual(400, response.status_code) - def testFetchHealthPillsForSpecificStep(self): - """Tests that requesting health pills at a specific steps works. + def testFetchHealthPillsForSpecificStep(self): + """Tests that requesting health pills at a specific steps works. - This path may be slow in real life because it reads from disk. - """ - # Request health pills for these nodes at step 7 specifically. - response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': json.dumps(['logits/Add', 'layers/Matmul']), - 'step': 7 - }) - self.assertEqual(200, response.status_code) - # The response should only include health pills at step 7. - self.assertDictEqual({ - 'logits/Add': [ - { - 'wall_time': 1337, - 'step': 7, - 'device_name': '/job:localhost/replica:0/task:0/cpu:0', - 'node_name': 'logits/Add', - 'output_slot': 0, - 'dtype': 'tf.int32', - 'shape': [3.0, 3.0], - 'value': (list(range(12)) + - [float(tf.int32.as_datatype_enum), 2.0, 3.0, 3.0]), + This path may be slow in real life because it reads from disk. + """ + # Request health pills for these nodes at step 7 specifically. + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={ + "node_names": json.dumps(["logits/Add", "layers/Matmul"]), + "step": 7, }, - ], - 'layers/Matmul': [ + ) + self.assertEqual(200, response.status_code) + # The response should only include health pills at step 7. 
+ self.assertDictEqual( { - 'wall_time': 43, - 'step': 7, - 'device_name': '/job:localhost/replica:0/task:0/cpu:0', - 'node_name': 'layers/Matmul', - 'output_slot': 1, - 'dtype': 'tf.float64', - 'shape': [3.0, 3.0], - 'value': (list(range(12)) + - [float(tf.float64.as_datatype_enum), 2.0, 3.0, 3.0]), + "logits/Add": [ + { + "wall_time": 1337, + "step": 7, + "device_name": "/job:localhost/replica:0/task:0/cpu:0", + "node_name": "logits/Add", + "output_slot": 0, + "dtype": "tf.int32", + "shape": [3.0, 3.0], + "value": ( + list(range(12)) + + [float(tf.int32.as_datatype_enum), 2.0, 3.0, 3.0] + ), + }, + ], + "layers/Matmul": [ + { + "wall_time": 43, + "step": 7, + "device_name": "/job:localhost/replica:0/task:0/cpu:0", + "node_name": "layers/Matmul", + "output_slot": 1, + "dtype": "tf.float64", + "shape": [3.0, 3.0], + "value": ( + list(range(12)) + + [ + float(tf.float64.as_datatype_enum), + 2.0, + 3.0, + 3.0, + ] + ), + }, + ], }, - ], - }, self._DeserializeResponse(response.get_data())) + self._DeserializeResponse(response.get_data()), + ) - def testNoHealthPillsForSpecificStep(self): - """Tests that an empty mapping is returned for no health pills at a step.""" - response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': json.dumps(['some/clearly/non-existent/op']), - 'step': 7 - }) - self.assertEqual(200, response.status_code) - self.assertDictEqual({}, self._DeserializeResponse(response.get_data())) + def testNoHealthPillsForSpecificStep(self): + """Tests that an empty mapping is returned for no health pills at a + step.""" + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={ + "node_names": json.dumps(["some/clearly/non-existent/op"]), + "step": 7, + }, + ) + self.assertEqual(200, response.status_code) + self.assertDictEqual({}, self._DeserializeResponse(response.get_data())) - def testNoHealthPillsForOutOfRangeStep(self): - """Tests that an empty mapping is returned for an out of range step.""" - 
response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': json.dumps(['logits/Add', 'layers/Matmul']), - # This step higher than that of any event written to disk. - 'step': 42424242 - }) - self.assertEqual(200, response.status_code) - self.assertDictEqual({}, self._DeserializeResponse(response.get_data())) + def testNoHealthPillsForOutOfRangeStep(self): + """Tests that an empty mapping is returned for an out of range step.""" + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={ + "node_names": json.dumps(["logits/Add", "layers/Matmul"]), + # This step higher than that of any event written to disk. + "step": 42424242, + }, + ) + self.assertEqual(200, response.status_code) + self.assertDictEqual({}, self._DeserializeResponse(response.get_data())) - def testNumericsAlertReportResponse(self): - """Tests that reports of bad values are returned.""" - alerts = [ - numerics_alert.NumericsAlertReportRow('cpu0', 'MatMul', 123, 2, 3, 4), - numerics_alert.NumericsAlertReportRow('cpu1', 'Add', 124, 5, 6, 7), - ] - self.mock_debugger_data_server.numerics_alert_report.return_value = alerts - response = self.server.get('/data/plugin/debugger/numerics_alert_report') - self.assertEqual(200, response.status_code) + def testNumericsAlertReportResponse(self): + """Tests that reports of bad values are returned.""" + alerts = [ + numerics_alert.NumericsAlertReportRow( + "cpu0", "MatMul", 123, 2, 3, 4 + ), + numerics_alert.NumericsAlertReportRow("cpu1", "Add", 124, 5, 6, 7), + ] + self.mock_debugger_data_server.numerics_alert_report.return_value = ( + alerts + ) + response = self.server.get( + "/data/plugin/debugger/numerics_alert_report" + ) + self.assertEqual(200, response.status_code) - retrieved_alerts = self._DeserializeResponse(response.get_data()) - self.assertEqual(2, len(retrieved_alerts)) - self.assertDictEqual({ - 'device_name': 'cpu0', - 'tensor_name': 'MatMul', - 'first_timestamp': 123, - 'nan_event_count': 2, - 
'neg_inf_event_count': 3, - 'pos_inf_event_count': 4, - }, retrieved_alerts[0]) - self.assertDictEqual({ - 'device_name': 'cpu1', - 'tensor_name': 'Add', - 'first_timestamp': 124, - 'nan_event_count': 5, - 'neg_inf_event_count': 6, - 'pos_inf_event_count': 7, - }, retrieved_alerts[1]) + retrieved_alerts = self._DeserializeResponse(response.get_data()) + self.assertEqual(2, len(retrieved_alerts)) + self.assertDictEqual( + { + "device_name": "cpu0", + "tensor_name": "MatMul", + "first_timestamp": 123, + "nan_event_count": 2, + "neg_inf_event_count": 3, + "pos_inf_event_count": 4, + }, + retrieved_alerts[0], + ) + self.assertDictEqual( + { + "device_name": "cpu1", + "tensor_name": "Add", + "first_timestamp": 124, + "nan_event_count": 5, + "neg_inf_event_count": 6, + "pos_inf_event_count": 7, + }, + retrieved_alerts[1], + ) - def testDebuggerDataServerNotStartedWhenListenIsNotCalled(self): - """Tests that the plugin starts no debugger data server if port is None.""" - self.mock_debugger_data_server_class.reset_mock() + def testDebuggerDataServerNotStartedWhenListenIsNotCalled(self): + """Tests that the plugin starts no debugger data server if port is + None.""" + self.mock_debugger_data_server_class.reset_mock() - # Initialize a debugger plugin with no GRPC port provided. - self.debugger_plugin_module.DebuggerPlugin(self.context).get_plugin_apps() + # Initialize a debugger plugin with no GRPC port provided. + self.debugger_plugin_module.DebuggerPlugin( + self.context + ).get_plugin_apps() - # No debugger data server should have been started. - # assert_not_called is not available in Python 3.4. - self.assertFalse(self.mock_debugger_data_server_class.called) + # No debugger data server should have been started. + # assert_not_called is not available in Python 3.4. 
+ self.assertFalse(self.mock_debugger_data_server_class.called) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/debugger/debugger_plugin_testlib.py b/tensorboard/plugins/debugger/debugger_plugin_testlib.py index e8a3de85b5..c0036399a5 100644 --- a/tensorboard/plugins/debugger/debugger_plugin_testlib.py +++ b/tensorboard/plugins/debugger/debugger_plugin_testlib.py @@ -28,183 +28,220 @@ # To keep compatibility with both 1.x and 2.x try: - from tensorflow.python import _pywrap_events_writer as tf_events_writer + from tensorflow.python import _pywrap_events_writer as tf_events_writer except ImportError: - from tensorflow.python import pywrap_tensorflow as tf_events_writer + from tensorflow.python import pywrap_tensorflow as tf_events_writer from werkzeug import wrappers from werkzeug import test as werkzeug_test from google.protobuf import json_format from tensorboard.backend import application -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.debugger import constants from tensorflow.core.debug import debugger_event_metadata_pb2 + # pylint: enable=ungrouped-imports, wrong-import-order class DebuggerPluginTestBase(tf.test.TestCase): - - def __init__(self, *args, **kwargs): - super(DebuggerPluginTestBase, self).__init__(*args, **kwargs) - self.debugger_plugin_module = None - - def setUp(self): - super(DebuggerPluginTestBase, self).setUp() - # Importing the debugger_plugin can sometimes unfortunately produce errors. 
- try: - - from tensorboard.plugins.debugger import debugger_plugin - from tensorboard.plugins.debugger import debugger_server_lib - - except Exception as e: # pylint: disable=broad-except - raise self.skipTest( - 'Skipping test because importing some modules failed: %r' % e) - self.debugger_plugin_module = debugger_plugin - - # Populate the log directory with debugger event for run '.'. - self.log_dir = self.get_temp_dir() - file_prefix = tf.compat.as_bytes( - os.path.join(self.log_dir, 'events.debugger')) - writer = tf_events_writer.EventsWriter(file_prefix) - device_name = '/job:localhost/replica:0/task:0/cpu:0' - writer.WriteEvent( - self._CreateEventWithDebugNumericSummary( - device_name=device_name, - op_name='layers/Matmul', - output_slot=0, - wall_time=42, - step=2, - list_of_values=(list(range(12)) + - [float(tf.float32.as_datatype_enum), 1.0, 3.0]))) - writer.WriteEvent( - self._CreateEventWithDebugNumericSummary( - device_name=device_name, - op_name='layers/Matmul', - output_slot=1, - wall_time=43, - step=7, - list_of_values=( - list(range(12)) + - [float(tf.float64.as_datatype_enum), 2.0, 3.0, 3.0]))) - writer.WriteEvent( - self._CreateEventWithDebugNumericSummary( - device_name=device_name, - op_name='logits/Add', - output_slot=0, - wall_time=1337, - step=7, - list_of_values=(list(range(12)) + - [float(tf.int32.as_datatype_enum), 2.0, 3.0, 3.0]))) - writer.WriteEvent( - self._CreateEventWithDebugNumericSummary( - device_name=device_name, - op_name='logits/Add', - output_slot=0, - wall_time=1338, - step=8, - list_of_values=(list(range(12)) + - [float(tf.int16.as_datatype_enum), 0.0]))) - writer.Close() - - # Populate the log directory with debugger event for run 'run_foo'. 
- run_foo_directory = os.path.join(self.log_dir, 'run_foo') - os.mkdir(run_foo_directory) - file_prefix = tf.compat.as_bytes( - os.path.join(run_foo_directory, 'events.debugger')) - writer = tf_events_writer.EventsWriter(file_prefix) - writer.WriteEvent( - self._CreateEventWithDebugNumericSummary( - device_name=device_name, - op_name='layers/Variable', - output_slot=0, - wall_time=4242, - step=42, - list_of_values=(list(range(12)) + - [float(tf.int16.as_datatype_enum), 1.0, 8.0]))) - writer.Close() - - # Start a server that will receive requests and respond with health pills. - multiplexer = event_multiplexer.EventMultiplexer({ - '.': self.log_dir, - 'run_foo': run_foo_directory, - }) - multiplexer.Reload() - self.debugger_data_server_grpc_port = portpicker.pick_unused_port() - - # Fake threading behavior so that threads are synchronous. - tf.compat.v1.test.mock.patch('threading.Thread.start', threading.Thread.run).start() - - self.mock_debugger_data_server = tf.compat.v1.test.mock.Mock( - debugger_server_lib.DebuggerDataServer) - self.mock_debugger_data_server_class = tf.compat.v1.test.mock.Mock( - debugger_server_lib.DebuggerDataServer, - return_value=self.mock_debugger_data_server) - - tf.compat.v1.test.mock.patch.object( - debugger_server_lib, - 'DebuggerDataServer', - self.mock_debugger_data_server_class).start() - - self.context = base_plugin.TBContext( - logdir=self.log_dir, multiplexer=multiplexer) - self.plugin = debugger_plugin.DebuggerPlugin(self.context) - self.plugin.listen(self.debugger_data_server_grpc_port) - wsgi_app = application.TensorBoardWSGI([self.plugin]) - self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) - - # The debugger data server should be started at the correct port. 
- self.mock_debugger_data_server_class.assert_called_once_with( - self.debugger_data_server_grpc_port, self.log_dir) - - mock_debugger_data_server = self.mock_debugger_data_server - start = mock_debugger_data_server.start_the_debugger_data_receiving_server - self.assertEqual(1, start.call_count) - - def tearDown(self): - # Remove the directory with debugger-related events files. - tf.compat.v1.test.mock.patch.stopall() - - def _CreateEventWithDebugNumericSummary( - self, device_name, op_name, output_slot, wall_time, step, list_of_values): - """Creates event with a health pill summary. - - Note the debugger plugin only works with TensorFlow and, thus, uses TF - protos and TF EventsWriter. - - Args: - device_name: The name of the op's device. - op_name: The name of the op to which a DebugNumericSummary was attached. - output_slot: The numeric output slot for the tensor. - wall_time: The numeric wall time of the event. - step: The step of the event. - list_of_values: A python list of values within the tensor. - - Returns: - A `tf.Event` with a health pill summary. - """ - event = tf.compat.v1.Event(step=step, wall_time=wall_time) - tensor = tf.compat.v1.make_tensor_proto( - list_of_values, dtype=tf.float64, shape=[len(list_of_values)]) - value = event.summary.value.add( - tag=op_name, - node_name='%s:%d:DebugNumericSummary' % (op_name, output_slot), - tensor=tensor) - content_proto = debugger_event_metadata_pb2.DebuggerEventMetadata( - device=device_name, output_slot=output_slot) - value.metadata.plugin_data.plugin_name = constants.DEBUGGER_PLUGIN_NAME - value.metadata.plugin_data.content = tf.compat.as_bytes( - json_format.MessageToJson( - content_proto, including_default_value_fields=True)) - return event - - def _DeserializeResponse(self, byte_content): - """Deserializes byte content that is a JSON encoding. - - Args: - byte_content: The byte content of a JSON response. - - Returns: - The deserialized python object decoded from JSON. 
- """ - return json.loads(byte_content.decode('utf-8')) + def __init__(self, *args, **kwargs): + super(DebuggerPluginTestBase, self).__init__(*args, **kwargs) + self.debugger_plugin_module = None + + def setUp(self): + super(DebuggerPluginTestBase, self).setUp() + # Importing the debugger_plugin can sometimes unfortunately produce errors. + try: + + from tensorboard.plugins.debugger import debugger_plugin + from tensorboard.plugins.debugger import debugger_server_lib + + except Exception as e: # pylint: disable=broad-except + raise self.skipTest( + "Skipping test because importing some modules failed: %r" % e + ) + self.debugger_plugin_module = debugger_plugin + + # Populate the log directory with debugger event for run '.'. + self.log_dir = self.get_temp_dir() + file_prefix = tf.compat.as_bytes( + os.path.join(self.log_dir, "events.debugger") + ) + writer = tf_events_writer.EventsWriter(file_prefix) + device_name = "/job:localhost/replica:0/task:0/cpu:0" + writer.WriteEvent( + self._CreateEventWithDebugNumericSummary( + device_name=device_name, + op_name="layers/Matmul", + output_slot=0, + wall_time=42, + step=2, + list_of_values=( + list(range(12)) + + [float(tf.float32.as_datatype_enum), 1.0, 3.0] + ), + ) + ) + writer.WriteEvent( + self._CreateEventWithDebugNumericSummary( + device_name=device_name, + op_name="layers/Matmul", + output_slot=1, + wall_time=43, + step=7, + list_of_values=( + list(range(12)) + + [float(tf.float64.as_datatype_enum), 2.0, 3.0, 3.0] + ), + ) + ) + writer.WriteEvent( + self._CreateEventWithDebugNumericSummary( + device_name=device_name, + op_name="logits/Add", + output_slot=0, + wall_time=1337, + step=7, + list_of_values=( + list(range(12)) + + [float(tf.int32.as_datatype_enum), 2.0, 3.0, 3.0] + ), + ) + ) + writer.WriteEvent( + self._CreateEventWithDebugNumericSummary( + device_name=device_name, + op_name="logits/Add", + output_slot=0, + wall_time=1338, + step=8, + list_of_values=( + list(range(12)) + 
[float(tf.int16.as_datatype_enum), 0.0] + ), + ) + ) + writer.Close() + + # Populate the log directory with debugger event for run 'run_foo'. + run_foo_directory = os.path.join(self.log_dir, "run_foo") + os.mkdir(run_foo_directory) + file_prefix = tf.compat.as_bytes( + os.path.join(run_foo_directory, "events.debugger") + ) + writer = tf_events_writer.EventsWriter(file_prefix) + writer.WriteEvent( + self._CreateEventWithDebugNumericSummary( + device_name=device_name, + op_name="layers/Variable", + output_slot=0, + wall_time=4242, + step=42, + list_of_values=( + list(range(12)) + + [float(tf.int16.as_datatype_enum), 1.0, 8.0] + ), + ) + ) + writer.Close() + + # Start a server that will receive requests and respond with health pills. + multiplexer = event_multiplexer.EventMultiplexer( + {".": self.log_dir, "run_foo": run_foo_directory,} + ) + multiplexer.Reload() + self.debugger_data_server_grpc_port = portpicker.pick_unused_port() + + # Fake threading behavior so that threads are synchronous. + tf.compat.v1.test.mock.patch( + "threading.Thread.start", threading.Thread.run + ).start() + + self.mock_debugger_data_server = tf.compat.v1.test.mock.Mock( + debugger_server_lib.DebuggerDataServer + ) + self.mock_debugger_data_server_class = tf.compat.v1.test.mock.Mock( + debugger_server_lib.DebuggerDataServer, + return_value=self.mock_debugger_data_server, + ) + + tf.compat.v1.test.mock.patch.object( + debugger_server_lib, + "DebuggerDataServer", + self.mock_debugger_data_server_class, + ).start() + + self.context = base_plugin.TBContext( + logdir=self.log_dir, multiplexer=multiplexer + ) + self.plugin = debugger_plugin.DebuggerPlugin(self.context) + self.plugin.listen(self.debugger_data_server_grpc_port) + wsgi_app = application.TensorBoardWSGI([self.plugin]) + self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) + + # The debugger data server should be started at the correct port. 
+ self.mock_debugger_data_server_class.assert_called_once_with( + self.debugger_data_server_grpc_port, self.log_dir + ) + + mock_debugger_data_server = self.mock_debugger_data_server + start = ( + mock_debugger_data_server.start_the_debugger_data_receiving_server + ) + self.assertEqual(1, start.call_count) + + def tearDown(self): + # Remove the directory with debugger-related events files. + tf.compat.v1.test.mock.patch.stopall() + + def _CreateEventWithDebugNumericSummary( + self, device_name, op_name, output_slot, wall_time, step, list_of_values + ): + """Creates event with a health pill summary. + + Note the debugger plugin only works with TensorFlow and, thus, uses TF + protos and TF EventsWriter. + + Args: + device_name: The name of the op's device. + op_name: The name of the op to which a DebugNumericSummary was attached. + output_slot: The numeric output slot for the tensor. + wall_time: The numeric wall time of the event. + step: The step of the event. + list_of_values: A python list of values within the tensor. + + Returns: + A `tf.Event` with a health pill summary. + """ + event = tf.compat.v1.Event(step=step, wall_time=wall_time) + tensor = tf.compat.v1.make_tensor_proto( + list_of_values, dtype=tf.float64, shape=[len(list_of_values)] + ) + value = event.summary.value.add( + tag=op_name, + node_name="%s:%d:DebugNumericSummary" % (op_name, output_slot), + tensor=tensor, + ) + content_proto = debugger_event_metadata_pb2.DebuggerEventMetadata( + device=device_name, output_slot=output_slot + ) + value.metadata.plugin_data.plugin_name = constants.DEBUGGER_PLUGIN_NAME + value.metadata.plugin_data.content = tf.compat.as_bytes( + json_format.MessageToJson( + content_proto, including_default_value_fields=True + ) + ) + return event + + def _DeserializeResponse(self, byte_content): + """Deserializes byte content that is a JSON encoding. + + Args: + byte_content: The byte content of a JSON response. + + Returns: + The deserialized python object decoded from JSON. 
+ """ + return json.loads(byte_content.decode("utf-8")) diff --git a/tensorboard/plugins/debugger/debugger_server_lib.py b/tensorboard/plugins/debugger/debugger_server_lib.py index 1176e40a49..18cda5fa5a 100644 --- a/tensorboard/plugins/debugger/debugger_server_lib.py +++ b/tensorboard/plugins/debugger/debugger_server_lib.py @@ -14,8 +14,8 @@ # ============================================================================== """Receives data from a TensorFlow debugger. Writes event summaries. -This listener server writes debugging-related events into a logdir directory, -from which a TensorBoard instance can read. +This listener server writes debugging-related events into a logdir +directory, from which a TensorBoard instance can read. """ from __future__ import absolute_import @@ -33,7 +33,9 @@ from tensorflow.python.debug.lib import grpc_debug_server from tensorboard.plugins.debugger import constants -from tensorboard.plugins.debugger import events_writer_manager as events_writer_manager_lib +from tensorboard.plugins.debugger import ( + events_writer_manager as events_writer_manager_lib, +) from tensorboard.plugins.debugger import numerics_alert from tensorboard.util import tb_logging from tensorboard.util import tensor_util @@ -42,276 +44,306 @@ class DebuggerDataStreamHandler( - grpc_debug_server.EventListenerBaseStreamHandler): - """Implementation of stream handler for debugger data. - - Each instance of this class is created by a DebuggerDataServer upon a - gRPC stream established between the debugged Session::Run() invocation in - TensorFlow core runtime and the DebuggerDataServer instance. - - Each instance of this class does the following: - 1) receives a core metadata Event proto during its constructor call. - 2) receives GraphDef Event proto(s) through its on_graph_event method. - 3) receives tensor value Event proto(s) through its on_value_event method. 
- """ - - def __init__(self, - events_writer_manager, - numerics_alert_callback=None): - """Constructor of DebuggerDataStreamHandler. - - Args: - events_writer_manager: Manages writing events to disk. - numerics_alert_callback: An optional callback run every time a health pill - event with bad values (Nan, -Inf, or +Inf) is received. The callback - takes the event as a parameter. + grpc_debug_server.EventListenerBaseStreamHandler +): + """Implementation of stream handler for debugger data. + + Each instance of this class is created by a DebuggerDataServer upon a + gRPC stream established between the debugged Session::Run() invocation in + TensorFlow core runtime and the DebuggerDataServer instance. + + Each instance of this class does the following: + 1) receives a core metadata Event proto during its constructor call. + 2) receives GraphDef Event proto(s) through its on_graph_event method. + 3) receives tensor value Event proto(s) through its on_value_event method. """ - super(DebuggerDataStreamHandler, self).__init__() - self._events_writer_manager = events_writer_manager - self._numerics_alert_callback = numerics_alert_callback - - # We use session_run_index as the "step" value for debugger events because - # it is unique across all runs. It is not specific to a set of feeds and - # fetches. - self._session_run_index = -1 - - def on_core_metadata_event(self, event): - """Implementation of the core metadata-carrying Event proto callback. - - Args: - event: An Event proto that contains core metadata about the debugged - Session::Run() in its log_message.message field, as a JSON string. - See the doc string of debug_data.DebugDumpDir.core_metadata for details. - """ - self._session_run_index = self._parse_session_run_index(event) - - def on_graph_def(self, graph_def, device_name, wall_time): - """Implementation of the GraphDef-carrying Event proto callback. - - Args: - graph_def: A GraphDef proto. 
N.B.: The GraphDef is from - the core runtime of a debugged Session::Run() call, after graph - partition. Therefore it may differ from the GraphDef available to - the general TensorBoard. For example, the GraphDef in general - TensorBoard may get partitioned for multiple devices (CPUs and GPUs), - each of which will generate a GraphDef event proto sent to this - method. - device_name: Name of the device on which the graph was created. - wall_time: An epoch timestamp (in microseconds) for the graph. - """ - # For now, we do nothing with the graph def. However, we must define this - # method to satisfy the handler's interface. Furthermore, we may use the - # graph in the future (for instance to provide a graph if there is no graph - # provided otherwise). - del device_name - del wall_time - del graph_def - - def on_value_event(self, event): - """Records the summary values based on an updated message from the debugger. - Logs an error message if writing the event to disk fails. + def __init__(self, events_writer_manager, numerics_alert_callback=None): + """Constructor of DebuggerDataStreamHandler. + + Args: + events_writer_manager: Manages writing events to disk. + numerics_alert_callback: An optional callback run every time a health pill + event with bad values (Nan, -Inf, or +Inf) is received. The callback + takes the event as a parameter. + """ + super(DebuggerDataStreamHandler, self).__init__() + self._events_writer_manager = events_writer_manager + self._numerics_alert_callback = numerics_alert_callback + + # We use session_run_index as the "step" value for debugger events because + # it is unique across all runs. It is not specific to a set of feeds and + # fetches. + self._session_run_index = -1 + + def on_core_metadata_event(self, event): + """Implementation of the core metadata-carrying Event proto callback. + + Args: + event: An Event proto that contains core metadata about the debugged + Session::Run() in its log_message.message field, as a JSON string. 
+ See the doc string of debug_data.DebugDumpDir.core_metadata for details. + """ + self._session_run_index = self._parse_session_run_index(event) + + def on_graph_def(self, graph_def, device_name, wall_time): + """Implementation of the GraphDef-carrying Event proto callback. + + Args: + graph_def: A GraphDef proto. N.B.: The GraphDef is from + the core runtime of a debugged Session::Run() call, after graph + partition. Therefore it may differ from the GraphDef available to + the general TensorBoard. For example, the GraphDef in general + TensorBoard may get partitioned for multiple devices (CPUs and GPUs), + each of which will generate a GraphDef event proto sent to this + method. + device_name: Name of the device on which the graph was created. + wall_time: An epoch timestamp (in microseconds) for the graph. + """ + # For now, we do nothing with the graph def. However, we must define this + # method to satisfy the handler's interface. Furthermore, we may use the + # graph in the future (for instance to provide a graph if there is no graph + # provided otherwise). + del device_name + del wall_time + del graph_def + + def on_value_event(self, event): + """Records the summary values based on an updated message from the + debugger. + + Logs an error message if writing the event to disk fails. + + Args: + event: The Event proto to be processed. + """ + if not event.summary.value: + logger.warn("The summary of the event lacks a value.") + return + + # The node name property is actually a watch key, which is a concatenation + # of several pieces of data. + watch_key = event.summary.value[0].node_name + if not watch_key.endswith(constants.DEBUG_NUMERIC_SUMMARY_SUFFIX): + # Ignore events that lack a DebugNumericSummary. + # NOTE(@chihuahua): We may later handle other types of debug ops. + return + + # We remove the constants.DEBUG_NUMERIC_SUMMARY_SUFFIX from the end of the + # watch name because it is not distinguishing: every health pill entry ends + # with it. 
+ node_name_and_output_slot = watch_key[ + : -len(constants.DEBUG_NUMERIC_SUMMARY_SUFFIX) + ] + + shape = tensor_util.make_ndarray(event.summary.value[0].tensor).shape + if ( + len(shape) != 1 + or shape[0] < constants.MIN_DEBUG_NUMERIC_SUMMARY_TENSOR_LENGTH + ): + logger.warn( + "Health-pill tensor either lacks a dimension or is " + "shaped incorrectly: %s" % shape + ) + return + + match = re.match(r"^(.*):(\d+)$", node_name_and_output_slot) + if not match: + logger.warn( + ( + "A event with a health pill has an invalid node name and output " + "slot combination, (i.e., an unexpected debug op): %r" + ), + node_name_and_output_slot, + ) + return + + if self._session_run_index >= 0: + event.step = self._session_run_index + else: + # Data from parameter servers (or any graphs without a master) do not + # contain core metadata. So the session run count is missing. Set its + # value to a microsecond epoch timestamp. + event.step = int(time.time() * 1e6) + + # Write this event to the events file designated for data from the + # debugger. + self._events_writer_manager.write_event(event) + + alert = numerics_alert.extract_numerics_alert(event) + if self._numerics_alert_callback and alert: + self._numerics_alert_callback(alert) + + def _parse_session_run_index(self, event): + """Parses the session_run_index value from the event proto. + + Args: + event: The event with metadata that contains the session_run_index. + + Returns: + The int session_run_index value. Or + constants.SENTINEL_FOR_UNDETERMINED_STEP if it could not be determined. + """ + metadata_string = event.log_message.message + try: + metadata = json.loads(metadata_string) + except ValueError as e: + logger.error( + "Could not decode metadata string '%s' for step value: %s", + metadata_string, + e, + ) + return constants.SENTINEL_FOR_UNDETERMINED_STEP - Args: - event: The Event proto to be processed. 
- """ - if not event.summary.value: - logger.warn("The summary of the event lacks a value.") - return - - # The node name property is actually a watch key, which is a concatenation - # of several pieces of data. - watch_key = event.summary.value[0].node_name - if not watch_key.endswith(constants.DEBUG_NUMERIC_SUMMARY_SUFFIX): - # Ignore events that lack a DebugNumericSummary. - # NOTE(@chihuahua): We may later handle other types of debug ops. - return - - # We remove the constants.DEBUG_NUMERIC_SUMMARY_SUFFIX from the end of the - # watch name because it is not distinguishing: every health pill entry ends - # with it. - node_name_and_output_slot = watch_key[ - :-len(constants.DEBUG_NUMERIC_SUMMARY_SUFFIX)] - - shape = tensor_util.make_ndarray(event.summary.value[0].tensor).shape - if (len(shape) != 1 or - shape[0] < constants.MIN_DEBUG_NUMERIC_SUMMARY_TENSOR_LENGTH): - logger.warn("Health-pill tensor either lacks a dimension or is " - "shaped incorrectly: %s" % shape) - return - - match = re.match(r"^(.*):(\d+)$", node_name_and_output_slot) - if not match: - logger.warn( - ("A event with a health pill has an invalid node name and output " - "slot combination, (i.e., an unexpected debug op): %r"), - node_name_and_output_slot) - return - - if self._session_run_index >= 0: - event.step = self._session_run_index - else: - # Data from parameter servers (or any graphs without a master) do not - # contain core metadata. So the session run count is missing. Set its - # value to a microsecond epoch timestamp. - event.step = int(time.time() * 1e6) - - # Write this event to the events file designated for data from the - # debugger. - self._events_writer_manager.write_event(event) - - alert = numerics_alert.extract_numerics_alert(event) - if self._numerics_alert_callback and alert: - self._numerics_alert_callback(alert) - - def _parse_session_run_index(self, event): - """Parses the session_run_index value from the event proto. 
- - Args: - event: The event with metadata that contains the session_run_index. - - Returns: - The int session_run_index value. Or - constants.SENTINEL_FOR_UNDETERMINED_STEP if it could not be determined. - """ - metadata_string = event.log_message.message - try: - metadata = json.loads(metadata_string) - except ValueError as e: - logger.error( - "Could not decode metadata string '%s' for step value: %s", - metadata_string, e) - return constants.SENTINEL_FOR_UNDETERMINED_STEP - - try: - return metadata["session_run_index"] - except KeyError: - logger.error( - "The session_run_index is missing from the metadata: %s", - metadata_string) - return constants.SENTINEL_FOR_UNDETERMINED_STEP + try: + return metadata["session_run_index"] + except KeyError: + logger.error( + "The session_run_index is missing from the metadata: %s", + metadata_string, + ) + return constants.SENTINEL_FOR_UNDETERMINED_STEP class DebuggerDataServer(grpc_debug_server.EventListenerBaseServicer): - """A service that receives and writes debugger data such as health pills. - """ - - def __init__(self, - receive_port, - logdir, - always_flush=False): - """Receives health pills from a debugger and writes them to disk. - - Args: - receive_port: The port at which to receive health pills from the - TensorFlow debugger. - logdir: The directory in which to write events files that TensorBoard will - read. - always_flush: A boolean indicating whether the EventsWriter will be - flushed after every write. Can be used for testing. - """ - # We create a special directory within logdir to store debugger-related - # events (if that directory does not already exist). This is necessary - # because for each directory within logdir, TensorBoard only reads through - # each events file once. There may be other non-debugger events files being - # written to at the same time. Without this special directory, TensorBoard - # may stop surfacing health pills after some arbitrary step value. 
- debugger_directory = os.path.join( - os.path.expanduser(logdir), constants.DEBUGGER_DATA_DIRECTORY_NAME) - - if not tf.io.gfile.exists(debugger_directory): - try: - tf.io.gfile.makedirs(debugger_directory) - logger.info("Created directory for debugger data: %s", - debugger_directory) - except tf.errors.OpError as e: - logger.fatal( - "Could not make directory for debugger data: %s. Error: %s", - debugger_directory, e) - - self._events_writer_manager = events_writer_manager_lib.EventsWriterManager( - events_directory=debugger_directory, - always_flush=always_flush) - - # Write an event with a file version as the first event within the events - # file. If the event version is 2, TensorBoard uses a path for purging - # events that does not depend on step. This is important because debugger - # events use a notion of step that differs from that of the rest of - # TensorBoard. - try: - self._events_writer_manager.write_event( - tf.compat.v1.Event( - wall_time=0, step=0, file_version=constants.EVENTS_VERSION)) - except IOError as e: - logger.error( - "Writing to %s failed: %s", - self._events_writer_manager.get_current_file_name(), e) - - # See if a backup file exists. If so, use it to initialize the registry. - self._registry_backup_file_path = os.path.join( - debugger_directory, constants.ALERT_REGISTRY_BACKUP_FILE_NAME) - initial_data = None - - if tf.io.gfile.exists(self._registry_backup_file_path): - # A backup file exists. Read its contents to use for initialization. - with tf.io.gfile.GFile(self._registry_backup_file_path, "r") as backup_file: + """A service that receives and writes debugger data such as health + pills.""" + + def __init__(self, receive_port, logdir, always_flush=False): + """Receives health pills from a debugger and writes them to disk. + + Args: + receive_port: The port at which to receive health pills from the + TensorFlow debugger. + logdir: The directory in which to write events files that TensorBoard will + read. 
+ always_flush: A boolean indicating whether the EventsWriter will be + flushed after every write. Can be used for testing. + """ + # We create a special directory within logdir to store debugger-related + # events (if that directory does not already exist). This is necessary + # because for each directory within logdir, TensorBoard only reads through + # each events file once. There may be other non-debugger events files being + # written to at the same time. Without this special directory, TensorBoard + # may stop surfacing health pills after some arbitrary step value. + debugger_directory = os.path.join( + os.path.expanduser(logdir), constants.DEBUGGER_DATA_DIRECTORY_NAME + ) + + if not tf.io.gfile.exists(debugger_directory): + try: + tf.io.gfile.makedirs(debugger_directory) + logger.info( + "Created directory for debugger data: %s", + debugger_directory, + ) + except tf.errors.OpError as e: + logger.fatal( + "Could not make directory for debugger data: %s. Error: %s", + debugger_directory, + e, + ) + + self._events_writer_manager = events_writer_manager_lib.EventsWriterManager( + events_directory=debugger_directory, always_flush=always_flush + ) + + # Write an event with a file version as the first event within the events + # file. If the event version is 2, TensorBoard uses a path for purging + # events that does not depend on step. This is important because debugger + # events use a notion of step that differs from that of the rest of + # TensorBoard. try: - # Use the data to initialize the registry. - initial_data = json.load(backup_file) - except ValueError as err: - # Could not parse the data. No backup data obtained. 
- logger.error( - "Could not parse contents of %s: %s", - self._registry_backup_file_path, err) - - self._numerics_alert_registry = numerics_alert.NumericsAlertRegistry( - initialization_list=initial_data) - - self._numerics_alert_lock = threading.Lock() - curried_handler_constructor = functools.partial( - DebuggerDataStreamHandler, - self._events_writer_manager, - self._numerics_alert_callback) - grpc_debug_server.EventListenerBaseServicer.__init__( - self, receive_port, curried_handler_constructor) - - def start_the_debugger_data_receiving_server(self): - """Starts the HTTP server for receiving health pills at `receive_port`. - - After this method is called, health pills issued to host:receive_port - will be stored by this object. Calling this method also creates a file - within the log directory for storing health pill summary events. - """ - self.run_server() - - def get_events_file_name(self): - """Gets the name of the debugger events file currently being written to. - - Returns: - The string name of the debugger events file currently being written to. - This is just the name of that file, not the full path to that file. - """ - return self._events_writer_manager.get_current_file_name() - - def _numerics_alert_callback(self, alert): - """Handles the case in which we receive a bad value (NaN, -/+ Inf). - - Args: - alert: The alert to be registered. - """ - with self._numerics_alert_lock: - self._numerics_alert_registry.register(alert) - - def numerics_alert_report(self): - """Get a report of the numerics alerts that have occurred. - - Returns: - A list of `numerics_alert.NumericsAlertReportRow`, sorted in ascending - order of first_timestamp. - """ - with self._numerics_alert_lock: - return self._numerics_alert_registry.report() - - def dispose(self): - """Disposes of this object. 
Call only after this is done being used.""" - self._events_writer_manager.dispose() + self._events_writer_manager.write_event( + tf.compat.v1.Event( + wall_time=0, step=0, file_version=constants.EVENTS_VERSION + ) + ) + except IOError as e: + logger.error( + "Writing to %s failed: %s", + self._events_writer_manager.get_current_file_name(), + e, + ) + + # See if a backup file exists. If so, use it to initialize the registry. + self._registry_backup_file_path = os.path.join( + debugger_directory, constants.ALERT_REGISTRY_BACKUP_FILE_NAME + ) + initial_data = None + + if tf.io.gfile.exists(self._registry_backup_file_path): + # A backup file exists. Read its contents to use for initialization. + with tf.io.gfile.GFile( + self._registry_backup_file_path, "r" + ) as backup_file: + try: + # Use the data to initialize the registry. + initial_data = json.load(backup_file) + except ValueError as err: + # Could not parse the data. No backup data obtained. + logger.error( + "Could not parse contents of %s: %s", + self._registry_backup_file_path, + err, + ) + + self._numerics_alert_registry = numerics_alert.NumericsAlertRegistry( + initialization_list=initial_data + ) + + self._numerics_alert_lock = threading.Lock() + curried_handler_constructor = functools.partial( + DebuggerDataStreamHandler, + self._events_writer_manager, + self._numerics_alert_callback, + ) + grpc_debug_server.EventListenerBaseServicer.__init__( + self, receive_port, curried_handler_constructor + ) + + def start_the_debugger_data_receiving_server(self): + """Starts the HTTP server for receiving health pills at `receive_port`. + + After this method is called, health pills issued to + host:receive_port will be stored by this object. Calling this + method also creates a file within the log directory for storing + health pill summary events. + """ + self.run_server() + + def get_events_file_name(self): + """Gets the name of the debugger events file currently being written + to. 
+ + Returns: + The string name of the debugger events file currently being written to. + This is just the name of that file, not the full path to that file. + """ + return self._events_writer_manager.get_current_file_name() + + def _numerics_alert_callback(self, alert): + """Handles the case in which we receive a bad value (NaN, -/+ Inf). + + Args: + alert: The alert to be registered. + """ + with self._numerics_alert_lock: + self._numerics_alert_registry.register(alert) + + def numerics_alert_report(self): + """Get a report of the numerics alerts that have occurred. + + Returns: + A list of `numerics_alert.NumericsAlertReportRow`, sorted in ascending + order of first_timestamp. + """ + with self._numerics_alert_lock: + return self._numerics_alert_registry.report() + + def dispose(self): + """Disposes of this object. + + Call only after this is done being used. + """ + self._events_writer_manager.dispose() diff --git a/tensorboard/plugins/debugger/debugger_server_test.py b/tensorboard/plugins/debugger/debugger_server_test.py index b95f90ff2d..03a817af61 100644 --- a/tensorboard/plugins/debugger/debugger_server_test.py +++ b/tensorboard/plugins/debugger/debugger_server_test.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests the debugger data server, which receives and writes debugger events.""" +"""Tests the debugger data server, which receives and writes debugger +events.""" from __future__ import absolute_import from __future__ import division @@ -31,228 +32,293 @@ from tensorboard.plugins.debugger import debugger_server_lib from tensorboard.plugins.debugger import numerics_alert from tensorboard.util import tensor_util + # pylint: enable=ungrouped-imports, wrong-import-order class FakeEventsWriterManager(object): - """An events writer manager that tracks events that would be written. 
+ """An events writer manager that tracks events that would be written. - During normal usage, the debugger data server would write events to disk. - Unfortunately, this test cannot depend on TensorFlow's record reader due to - GRPC library conflicts (b/35006065). Hence, we use a fake EventsWriter that - keeps track of events that would be written to disk. - """ + During normal usage, the debugger data server would write events to + disk. Unfortunately, this test cannot depend on TensorFlow's record + reader due to GRPC library conflicts (b/35006065). Hence, we use a + fake EventsWriter that keeps track of events that would be written + to disk. + """ - def __init__(self, events_output_list): - """Constructs a fake events writer, which appends events to a list. + def __init__(self, events_output_list): + """Constructs a fake events writer, which appends events to a list. - Args: - events_output_list: The list to append events that would be written to - disk. - """ - self.events_written = events_output_list + Args: + events_output_list: The list to append events that would be written to + disk. + """ + self.events_written = events_output_list - def dispose(self): - """Does nothing. This implementation creates no file.""" + def dispose(self): + """Does nothing. - def write_event(self, event): - """Pretends to write an event to disk. + This implementation creates no file. + """ - Args: - event: The event proto. - """ - self.events_written.append(event) + def write_event(self, event): + """Pretends to write an event to disk. + Args: + event: The event proto. 
+ """ + self.events_written.append(event) -class DebuggerDataServerTest(tf.test.TestCase): - def setUp(self): - self.events_written = [] - - events_writer_manager = FakeEventsWriterManager(self.events_written) - self.stream_handler = debugger_server_lib.DebuggerDataStreamHandler( - events_writer_manager=events_writer_manager) - self.stream_handler.on_core_metadata_event(event_pb2.Event()) - - def tearDown(self): - tf.compat.v1.test.mock.patch.stopall() - - def _create_event_with_float_tensor(self, node_name, output_slot, debug_op, - list_of_values): - """Creates event with float64 (double) tensors. - - Args: - node_name: The string name of the op. This lacks both the output slot as - well as the name of the debug op. - output_slot: The number that is the output slot. - debug_op: The name of the debug op to use. - list_of_values: A python list of values within the tensor. - Returns: - A `Event` with a summary containing that node name and a float64 - tensor with those values. - """ - event = event_pb2.Event() - value = event.summary.value.add( - tag=node_name, - node_name="%s:%d:%s" % (node_name, output_slot, debug_op), - tensor=tensor_util.make_tensor_proto( - list_of_values, dtype=tf.float64, shape=[len(list_of_values)])) - plugin_content = debugger_event_metadata_pb2.DebuggerEventMetadata( - device="/job:localhost/replica:0/task:0/cpu:0", output_slot=output_slot) - value.metadata.plugin_data.plugin_name = constants.DEBUGGER_PLUGIN_NAME - value.metadata.plugin_data.content = tf.compat.as_bytes( - json_format.MessageToJson( - plugin_content, including_default_value_fields=True)) - return event - - def _verify_event_lists_have_same_tensor_values(self, expected, gotten): - """Checks that two lists of events have the same tensor values. - - Args: - expected: The expected list of events. - gotten: The list of events we actually got. - """ - self.assertEqual(len(expected), len(gotten)) - - # Compare the events one at a time. 
- for expected_event, gotten_event in zip(expected, gotten): - self.assertEqual(expected_event.summary.value[0].node_name, - gotten_event.summary.value[0].node_name) - self.assertAllClose( - tensor_util.make_ndarray(expected_event.summary.value[0].tensor), - tensor_util.make_ndarray(gotten_event.summary.value[0].tensor)) - self.assertEqual(expected_event.summary.value[0].tag, - gotten_event.summary.value[0].tag) - - def testOnValueEventWritesHealthPill(self): - """Tests that the stream handler writes health pills in order.""" - # The debugger stream handler receives 2 health pill events. - received_events = [ - self._create_event_with_float_tensor( - "MatMul", 0, "DebugNumericSummary", list(range(1, 15))), - self._create_event_with_float_tensor( - "add", 0, "DebugNumericSummary", [x * x for x in range(1, 15)]), - self._create_event_with_float_tensor( - "MatMul", 0, "DebugNumericSummary", [x + 42 for x in range(1, 15)]), - ] - - for event in received_events: - self.stream_handler.on_value_event(event) - - # Verify that the stream handler wrote them to disk in order. - self._verify_event_lists_have_same_tensor_values(received_events, - self.events_written) - - def testOnValueEventIgnoresIrrelevantOps(self): - """Tests that non-DebugNumericSummary ops are ignored.""" - # Receive a DebugNumericSummary event. - numeric_summary_event = self._create_event_with_float_tensor( - "MatMul", 42, "DebugNumericSummary", list(range(1, 15))) - self.stream_handler.on_value_event(numeric_summary_event) - - # Receive a non-DebugNumericSummary event. - self.stream_handler.on_value_event( - self._create_event_with_float_tensor("add", 0, "DebugIdentity", - list(range(1, 15)))) - - # The stream handler should have only written the DebugNumericSummary event - # to disk. 
- self._verify_event_lists_have_same_tensor_values([numeric_summary_event], - self.events_written) - - def testCorrectStepIsWritten(self): - events_written = [] - metadata_event = event_pb2.Event() - metadata_event.log_message.message = json.dumps({"session_run_index": 42}) - stream_handler = debugger_server_lib.DebuggerDataStreamHandler( - events_writer_manager=FakeEventsWriterManager(events_written)) - stream_handler.on_core_metadata_event(metadata_event) - - # The server receives 2 events. It should assign both the correct step. - stream_handler.on_value_event( - self._create_event_with_float_tensor("MatMul", 0, "DebugNumericSummary", - list(range(1, 15)))) - stream_handler.on_value_event( - self._create_event_with_float_tensor("Add", 0, "DebugNumericSummary", - list(range(2, 16)))) - self.assertEqual(42, events_written[0].step) - self.assertEqual(42, events_written[1].step) - - def testSentinelStepValueAssignedWhenExecutorStepCountKeyIsMissing(self): - events_written = [] - metadata_event = event_pb2.Event() - metadata_event.log_message.message = json.dumps({}) - stream_handler = debugger_server_lib.DebuggerDataStreamHandler( - events_writer_manager=FakeEventsWriterManager(events_written)) - stream_handler.on_core_metadata_event(metadata_event) - health_pill_event = self._create_event_with_float_tensor( - "MatMul", 0, "DebugNumericSummary", list(range(1, 15))) - stream_handler.on_value_event(health_pill_event) - self.assertGreater(events_written[0].step, 0) - - def testSentinelStepValueAssignedWhenMetadataJsonIsInvalid(self): - events_written = [] - metadata_event = event_pb2.Event() - metadata_event.log_message.message = "some invalid JSON string" - stream_handler = debugger_server_lib.DebuggerDataStreamHandler( - events_writer_manager=FakeEventsWriterManager(events_written)) - stream_handler.on_core_metadata_event(metadata_event) - health_pill_event = self._create_event_with_float_tensor( - "MatMul", 0, "DebugNumericSummary", list(range(1, 15))) - 
stream_handler.on_value_event(health_pill_event) - self.assertGreater(events_written[0].step, 0) - - def testAlertingEventCallback(self): - numerics_alert_callback = tf.compat.v1.test.mock.Mock() - stream_handler = debugger_server_lib.DebuggerDataStreamHandler( - events_writer_manager=FakeEventsWriterManager( - self.events_written), - numerics_alert_callback=numerics_alert_callback) - stream_handler.on_core_metadata_event(event_pb2.Event()) - - # The stream handler receives 1 good event and 1 with an NaN value. - stream_handler.on_value_event( - self._create_event_with_float_tensor("Add", 0, "DebugNumericSummary", - [0] * 14)) - stream_handler.on_value_event( - self._create_event_with_float_tensor("Add", 0, "DebugNumericSummary", [ - 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - ])) - - # The second event should have triggered the callback. - numerics_alert_callback.assert_called_once_with( - numerics_alert.NumericsAlert("/job:localhost/replica:0/task:0/cpu:0", - "Add:0", 0, 1, 0, 0)) - - # The stream handler receives an event with a -Inf value. - numerics_alert_callback.reset_mock() - stream_handler.on_value_event( - self._create_event_with_float_tensor("Add", 0, "DebugNumericSummary", [ - 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - ])) - numerics_alert_callback.assert_called_once_with( - numerics_alert.NumericsAlert("/job:localhost/replica:0/task:0/cpu:0", - "Add:0", 0, 0, 1, 0)) - - # The stream handler receives an event with a +Inf value. - numerics_alert_callback.reset_mock() - stream_handler.on_value_event( - self._create_event_with_float_tensor("Add", 0, "DebugNumericSummary", [ - 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 - ])) - numerics_alert_callback.assert_called_once_with( - numerics_alert.NumericsAlert("/job:localhost/replica:0/task:0/cpu:0", - "Add:0", 0, 0, 0, 1)) - - # The stream handler receives an event without any pathetic values. 
- numerics_alert_callback.reset_mock() - stream_handler.on_value_event( - self._create_event_with_float_tensor("Add", 0, "DebugNumericSummary", [ - 0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0 - ])) - # assert_not_called is not available in Python 3.4. - self.assertFalse(numerics_alert_callback.called) +class DebuggerDataServerTest(tf.test.TestCase): + def setUp(self): + self.events_written = [] + + events_writer_manager = FakeEventsWriterManager(self.events_written) + self.stream_handler = debugger_server_lib.DebuggerDataStreamHandler( + events_writer_manager=events_writer_manager + ) + self.stream_handler.on_core_metadata_event(event_pb2.Event()) + + def tearDown(self): + tf.compat.v1.test.mock.patch.stopall() + + def _create_event_with_float_tensor( + self, node_name, output_slot, debug_op, list_of_values + ): + """Creates event with float64 (double) tensors. + + Args: + node_name: The string name of the op. This lacks both the output slot as + well as the name of the debug op. + output_slot: The number that is the output slot. + debug_op: The name of the debug op to use. + list_of_values: A python list of values within the tensor. + Returns: + A `Event` with a summary containing that node name and a float64 + tensor with those values. 
+ """ + event = event_pb2.Event() + value = event.summary.value.add( + tag=node_name, + node_name="%s:%d:%s" % (node_name, output_slot, debug_op), + tensor=tensor_util.make_tensor_proto( + list_of_values, dtype=tf.float64, shape=[len(list_of_values)] + ), + ) + plugin_content = debugger_event_metadata_pb2.DebuggerEventMetadata( + device="/job:localhost/replica:0/task:0/cpu:0", + output_slot=output_slot, + ) + value.metadata.plugin_data.plugin_name = constants.DEBUGGER_PLUGIN_NAME + value.metadata.plugin_data.content = tf.compat.as_bytes( + json_format.MessageToJson( + plugin_content, including_default_value_fields=True + ) + ) + return event + + def _verify_event_lists_have_same_tensor_values(self, expected, gotten): + """Checks that two lists of events have the same tensor values. + + Args: + expected: The expected list of events. + gotten: The list of events we actually got. + """ + self.assertEqual(len(expected), len(gotten)) + + # Compare the events one at a time. + for expected_event, gotten_event in zip(expected, gotten): + self.assertEqual( + expected_event.summary.value[0].node_name, + gotten_event.summary.value[0].node_name, + ) + self.assertAllClose( + tensor_util.make_ndarray( + expected_event.summary.value[0].tensor + ), + tensor_util.make_ndarray(gotten_event.summary.value[0].tensor), + ) + self.assertEqual( + expected_event.summary.value[0].tag, + gotten_event.summary.value[0].tag, + ) + + def testOnValueEventWritesHealthPill(self): + """Tests that the stream handler writes health pills in order.""" + # The debugger stream handler receives 2 health pill events. 
+ received_events = [ + self._create_event_with_float_tensor( + "MatMul", 0, "DebugNumericSummary", list(range(1, 15)) + ), + self._create_event_with_float_tensor( + "add", 0, "DebugNumericSummary", [x * x for x in range(1, 15)] + ), + self._create_event_with_float_tensor( + "MatMul", + 0, + "DebugNumericSummary", + [x + 42 for x in range(1, 15)], + ), + ] + + for event in received_events: + self.stream_handler.on_value_event(event) + + # Verify that the stream handler wrote them to disk in order. + self._verify_event_lists_have_same_tensor_values( + received_events, self.events_written + ) + + def testOnValueEventIgnoresIrrelevantOps(self): + """Tests that non-DebugNumericSummary ops are ignored.""" + # Receive a DebugNumericSummary event. + numeric_summary_event = self._create_event_with_float_tensor( + "MatMul", 42, "DebugNumericSummary", list(range(1, 15)) + ) + self.stream_handler.on_value_event(numeric_summary_event) + + # Receive a non-DebugNumericSummary event. + self.stream_handler.on_value_event( + self._create_event_with_float_tensor( + "add", 0, "DebugIdentity", list(range(1, 15)) + ) + ) + + # The stream handler should have only written the DebugNumericSummary event + # to disk. + self._verify_event_lists_have_same_tensor_values( + [numeric_summary_event], self.events_written + ) + + def testCorrectStepIsWritten(self): + events_written = [] + metadata_event = event_pb2.Event() + metadata_event.log_message.message = json.dumps( + {"session_run_index": 42} + ) + stream_handler = debugger_server_lib.DebuggerDataStreamHandler( + events_writer_manager=FakeEventsWriterManager(events_written) + ) + stream_handler.on_core_metadata_event(metadata_event) + + # The server receives 2 events. It should assign both the correct step. 
+ stream_handler.on_value_event( + self._create_event_with_float_tensor( + "MatMul", 0, "DebugNumericSummary", list(range(1, 15)) + ) + ) + stream_handler.on_value_event( + self._create_event_with_float_tensor( + "Add", 0, "DebugNumericSummary", list(range(2, 16)) + ) + ) + self.assertEqual(42, events_written[0].step) + self.assertEqual(42, events_written[1].step) + + def testSentinelStepValueAssignedWhenExecutorStepCountKeyIsMissing(self): + events_written = [] + metadata_event = event_pb2.Event() + metadata_event.log_message.message = json.dumps({}) + stream_handler = debugger_server_lib.DebuggerDataStreamHandler( + events_writer_manager=FakeEventsWriterManager(events_written) + ) + stream_handler.on_core_metadata_event(metadata_event) + health_pill_event = self._create_event_with_float_tensor( + "MatMul", 0, "DebugNumericSummary", list(range(1, 15)) + ) + stream_handler.on_value_event(health_pill_event) + self.assertGreater(events_written[0].step, 0) + + def testSentinelStepValueAssignedWhenMetadataJsonIsInvalid(self): + events_written = [] + metadata_event = event_pb2.Event() + metadata_event.log_message.message = "some invalid JSON string" + stream_handler = debugger_server_lib.DebuggerDataStreamHandler( + events_writer_manager=FakeEventsWriterManager(events_written) + ) + stream_handler.on_core_metadata_event(metadata_event) + health_pill_event = self._create_event_with_float_tensor( + "MatMul", 0, "DebugNumericSummary", list(range(1, 15)) + ) + stream_handler.on_value_event(health_pill_event) + self.assertGreater(events_written[0].step, 0) + + def testAlertingEventCallback(self): + numerics_alert_callback = tf.compat.v1.test.mock.Mock() + stream_handler = debugger_server_lib.DebuggerDataStreamHandler( + events_writer_manager=FakeEventsWriterManager(self.events_written), + numerics_alert_callback=numerics_alert_callback, + ) + stream_handler.on_core_metadata_event(event_pb2.Event()) + + # The stream handler receives 1 good event and 1 with an NaN value. 
+ stream_handler.on_value_event( + self._create_event_with_float_tensor( + "Add", 0, "DebugNumericSummary", [0] * 14 + ) + ) + stream_handler.on_value_event( + self._create_event_with_float_tensor( + "Add", + 0, + "DebugNumericSummary", + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ) + ) + + # The second event should have triggered the callback. + numerics_alert_callback.assert_called_once_with( + numerics_alert.NumericsAlert( + "/job:localhost/replica:0/task:0/cpu:0", "Add:0", 0, 1, 0, 0 + ) + ) + + # The stream handler receives an event with a -Inf value. + numerics_alert_callback.reset_mock() + stream_handler.on_value_event( + self._create_event_with_float_tensor( + "Add", + 0, + "DebugNumericSummary", + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ) + ) + numerics_alert_callback.assert_called_once_with( + numerics_alert.NumericsAlert( + "/job:localhost/replica:0/task:0/cpu:0", "Add:0", 0, 0, 1, 0 + ) + ) + + # The stream handler receives an event with a +Inf value. + numerics_alert_callback.reset_mock() + stream_handler.on_value_event( + self._create_event_with_float_tensor( + "Add", + 0, + "DebugNumericSummary", + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + ) + ) + numerics_alert_callback.assert_called_once_with( + numerics_alert.NumericsAlert( + "/job:localhost/replica:0/task:0/cpu:0", "Add:0", 0, 0, 0, 1 + ) + ) + + # The stream handler receives an event without any pathetic values. + numerics_alert_callback.reset_mock() + stream_handler.on_value_event( + self._create_event_with_float_tensor( + "Add", + 0, + "DebugNumericSummary", + [0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0], + ) + ) + # assert_not_called is not available in Python 3.4. 
+ self.assertFalse(numerics_alert_callback.called) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/debugger/events_writer_manager.py b/tensorboard/plugins/debugger/events_writer_manager.py index ca9707e5b3..472c26fa35 100644 --- a/tensorboard/plugins/debugger/events_writer_manager.py +++ b/tensorboard/plugins/debugger/events_writer_manager.py @@ -30,9 +30,9 @@ # To keep compatibility with both 1.x and 2.x try: - from tensorflow.python import _pywrap_events_writer as tf_events_writer + from tensorflow.python import _pywrap_events_writer as tf_events_writer except ImportError: - from tensorflow.python import pywrap_tensorflow as tf_events_writer + from tensorflow.python import pywrap_tensorflow as tf_events_writer from tensorboard.util import tb_logging logger = tb_logging.get_logger() @@ -42,7 +42,8 @@ # A regex used to match the names of debugger-related events files. _DEBUGGER_EVENTS_FILE_NAME_REGEX = re.compile( - r"^" + re.escape(DEBUGGER_EVENTS_FILE_STARTING_TEXT) + r"\.\d+\.(\d+)") + r"^" + re.escape(DEBUGGER_EVENTS_FILE_STARTING_TEXT) + r"\.\d+\.(\d+)" +) # The default events file size cap (in bytes). _DEFAULT_EVENTS_FILE_SIZE_CAP_BYTES = 500 * 1024 * 1024 @@ -57,191 +58,208 @@ class EventsWriterManager(object): - """Manages writing debugger-related events to disk. - - Creates new events writers if files get too big. - """ - - def __init__(self, - events_directory, - single_file_size_cap_bytes=_DEFAULT_EVENTS_FILE_SIZE_CAP_BYTES, - check_this_often=_DEFAULT_CHECK_EVENT_FILES_SIZE_CAP_EVERY, - total_file_size_cap_bytes=_DEFAULT_TOTAL_SIZE_CAP_BYTES, - always_flush=False): - """Constructs an EventsWriterManager. - - Args: - events_directory: (`string`) The log directory in which debugger events - reside. - single_file_size_cap_bytes: (`int`) A number of bytes. During a check, if - the manager determines that the events file being written to exceeds - this size, it creates a new events file to write to. 
Note that events - file may still exceed this size - the events writer manager just creates - a new events file if it finds that the current file exceeds this size. - check_this_often: (`int`) The manager performs a file size check every - this many events. We want to avoid checking upon every event for - performance reasons. If provided, must be greater than 1. - total_file_size_cap_bytes: A cap on the total number of bytes occupied by - all events. When a new events writer is created, the least recently - created events file will be deleted if the total size occupied by - debugger-related events on disk exceeds this cap. Note that the total - size could now and then be larger than this cap because the events - writer manager only checks when it creates a new events file. - always_flush: (`bool`) Whether to flush to disk after every write. Useful - for testing. - """ - self._events_directory = events_directory - self._single_file_size_cap_bytes = single_file_size_cap_bytes - self.total_file_size_cap_bytes = total_file_size_cap_bytes - self._check_this_often = check_this_often - self._always_flush = always_flush - - # Each events file gets a unique file count within its file name. This value - # increments every time a new events file is created. - self._events_file_count = 0 - - # If there are existing event files, assign the events file count to be - # greater than the last existing one. - events_file_names = self._fetch_events_files_on_disk() - if events_file_names: - self._events_file_count = self._obtain_file_index( - events_file_names[-1]) + 1 - - self._event_count = 0 - self._lock = threading.Lock() - self._events_writer = self._create_events_writer(events_directory) - - def write_event(self, event): - """Writes an event proto to disk. - - This method is threadsafe with respect to invocations of itself. - - Args: - event: The event proto. + """Manages writing debugger-related events to disk. - Raises: - IOError: If writing the event proto to disk fails. 
+ Creates new events writers if files get too big. """ - self._lock.acquire() - try: - self._events_writer.WriteEvent(event) - self._event_count += 1 - if self._always_flush: - # We flush on every event within the integration test. - self._events_writer.Flush() - - if self._event_count == self._check_this_often: - # Every so often, we check whether the size of the file is too big. - self._event_count = 0 - # Flush to get an accurate size check. - self._events_writer.Flush() - - file_path = os.path.join(self._events_directory, - self.get_current_file_name()) - if not tf.io.gfile.exists(file_path): - # The events file does not exist. Perhaps the user had manually - # deleted it after training began. Create a new one. - self._events_writer.Close() - self._events_writer = self._create_events_writer( - self._events_directory) - elif tf.io.gfile.stat(file_path).length > self._single_file_size_cap_bytes: - # The current events file has gotten too big. Close the previous - # events writer. Make a new one. - self._events_writer.Close() - self._events_writer = self._create_events_writer( - self._events_directory) - except IOError as err: - logger.error( - "Writing to %s failed: %s", self.get_current_file_name(), err) - self._lock.release() - - def get_current_file_name(self): - """Gets the name of the events file currently being written to. - - Returns: - The name of the events file being written to. - """ - return tf.compat.as_text(self._events_writer.FileName()) + def __init__( + self, + events_directory, + single_file_size_cap_bytes=_DEFAULT_EVENTS_FILE_SIZE_CAP_BYTES, + check_this_often=_DEFAULT_CHECK_EVENT_FILES_SIZE_CAP_EVERY, + total_file_size_cap_bytes=_DEFAULT_TOTAL_SIZE_CAP_BYTES, + always_flush=False, + ): + """Constructs an EventsWriterManager. + + Args: + events_directory: (`string`) The log directory in which debugger events + reside. + single_file_size_cap_bytes: (`int`) A number of bytes. 
During a check, if + the manager determines that the events file being written to exceeds + this size, it creates a new events file to write to. Note that events + file may still exceed this size - the events writer manager just creates + a new events file if it finds that the current file exceeds this size. + check_this_often: (`int`) The manager performs a file size check every + this many events. We want to avoid checking upon every event for + performance reasons. If provided, must be greater than 1. + total_file_size_cap_bytes: A cap on the total number of bytes occupied by + all events. When a new events writer is created, the least recently + created events file will be deleted if the total size occupied by + debugger-related events on disk exceeds this cap. Note that the total + size could now and then be larger than this cap because the events + writer manager only checks when it creates a new events file. + always_flush: (`bool`) Whether to flush to disk after every write. Useful + for testing. + """ + self._events_directory = events_directory + self._single_file_size_cap_bytes = single_file_size_cap_bytes + self.total_file_size_cap_bytes = total_file_size_cap_bytes + self._check_this_often = check_this_often + self._always_flush = always_flush + + # Each events file gets a unique file count within its file name. This value + # increments every time a new events file is created. + self._events_file_count = 0 + + # If there are existing event files, assign the events file count to be + # greater than the last existing one. + events_file_names = self._fetch_events_files_on_disk() + if events_file_names: + self._events_file_count = ( + self._obtain_file_index(events_file_names[-1]) + 1 + ) - def dispose(self): - """Disposes of this events writer manager, making it no longer usable. 
+ self._event_count = 0 + self._lock = threading.Lock() + self._events_writer = self._create_events_writer(events_directory) - Call this method when this object is done being used in order to clean up - resources and handlers. This method should ever only be called once. - """ - self._lock.acquire() - self._events_writer.Close() - self._events_writer = None - self._lock.release() + def write_event(self, event): + """Writes an event proto to disk. - def _create_events_writer(self, directory): - """Creates a new events writer. + This method is threadsafe with respect to invocations of itself. - Args: - directory: The directory in which to write files containing events. + Args: + event: The event proto. - Returns: - A new events writer, which corresponds to a new events file. - """ - total_size = 0 - events_files = self._fetch_events_files_on_disk() - for file_name in events_files: - file_path = os.path.join(self._events_directory, file_name) - total_size += tf.io.gfile.stat(file_path).length - - if total_size >= self.total_file_size_cap_bytes: - # The total size written to disk is too big. Delete events files until - # the size is below the cap. - for file_name in events_files: - if total_size < self.total_file_size_cap_bytes: - break - - file_path = os.path.join(self._events_directory, file_name) - file_size = tf.io.gfile.stat(file_path).length + Raises: + IOError: If writing the event proto to disk fails. + """ + self._lock.acquire() try: - tf.io.gfile.remove(file_path) - total_size -= file_size - logger.info( - "Deleted %s because events files take up over %d bytes", - file_path, self.total_file_size_cap_bytes) + self._events_writer.WriteEvent(event) + self._event_count += 1 + if self._always_flush: + # We flush on every event within the integration test. + self._events_writer.Flush() + + if self._event_count == self._check_this_often: + # Every so often, we check whether the size of the file is too big. 
+ self._event_count = 0 + + # Flush to get an accurate size check. + self._events_writer.Flush() + + file_path = os.path.join( + self._events_directory, self.get_current_file_name() + ) + if not tf.io.gfile.exists(file_path): + # The events file does not exist. Perhaps the user had manually + # deleted it after training began. Create a new one. + self._events_writer.Close() + self._events_writer = self._create_events_writer( + self._events_directory + ) + elif ( + tf.io.gfile.stat(file_path).length + > self._single_file_size_cap_bytes + ): + # The current events file has gotten too big. Close the previous + # events writer. Make a new one. + self._events_writer.Close() + self._events_writer = self._create_events_writer( + self._events_directory + ) except IOError as err: - logger.error("Deleting %s failed: %s", file_path, err) - - # We increment this index because each events writer must differ in prefix. - self._events_file_count += 1 - file_path = "%s.%d.%d" % ( - os.path.join(directory, DEBUGGER_EVENTS_FILE_STARTING_TEXT), - time.time(), self._events_file_count) - logger.info("Creating events file %s", file_path) - return tf_events_writer.EventsWriter(tf.compat.as_bytes(file_path)) - - def _fetch_events_files_on_disk(self): - """Obtains the names of debugger-related events files within the directory. - - Returns: - The names of the debugger-related events files written to disk. The names - are sorted in increasing events file index. - """ - all_files = tf.io.gfile.listdir(self._events_directory) - relevant_files = [ - file_name for file_name in all_files - if _DEBUGGER_EVENTS_FILE_NAME_REGEX.match(file_name) - ] - return sorted(relevant_files, key=self._obtain_file_index) - - def _obtain_file_index(self, file_name): - """Obtains the file index associated with an events file. - - The index is stored within a file name and is incremented every time a new - events file is created. Assumes that the file name is a valid debugger - events file name. 
- - Args: - file_name: The name of the debugger-related events file. The file index is - stored within the file name. - - Returns: - The integer events file index. - """ - return int(_DEBUGGER_EVENTS_FILE_NAME_REGEX.match(file_name).group(1)) + logger.error( + "Writing to %s failed: %s", self.get_current_file_name(), err + ) + self._lock.release() + + def get_current_file_name(self): + """Gets the name of the events file currently being written to. + + Returns: + The name of the events file being written to. + """ + return tf.compat.as_text(self._events_writer.FileName()) + + def dispose(self): + """Disposes of this events writer manager, making it no longer usable. + + Call this method when this object is done being used in order to + clean up resources and handlers. This method should ever only be + called once. + """ + self._lock.acquire() + self._events_writer.Close() + self._events_writer = None + self._lock.release() + + def _create_events_writer(self, directory): + """Creates a new events writer. + + Args: + directory: The directory in which to write files containing events. + + Returns: + A new events writer, which corresponds to a new events file. + """ + total_size = 0 + events_files = self._fetch_events_files_on_disk() + for file_name in events_files: + file_path = os.path.join(self._events_directory, file_name) + total_size += tf.io.gfile.stat(file_path).length + + if total_size >= self.total_file_size_cap_bytes: + # The total size written to disk is too big. Delete events files until + # the size is below the cap. 
+ for file_name in events_files: + if total_size < self.total_file_size_cap_bytes: + break + + file_path = os.path.join(self._events_directory, file_name) + file_size = tf.io.gfile.stat(file_path).length + try: + tf.io.gfile.remove(file_path) + total_size -= file_size + logger.info( + "Deleted %s because events files take up over %d bytes", + file_path, + self.total_file_size_cap_bytes, + ) + except IOError as err: + logger.error("Deleting %s failed: %s", file_path, err) + + # We increment this index because each events writer must differ in prefix. + self._events_file_count += 1 + file_path = "%s.%d.%d" % ( + os.path.join(directory, DEBUGGER_EVENTS_FILE_STARTING_TEXT), + time.time(), + self._events_file_count, + ) + logger.info("Creating events file %s", file_path) + return tf_events_writer.EventsWriter(tf.compat.as_bytes(file_path)) + + def _fetch_events_files_on_disk(self): + """Obtains the names of debugger-related events files within the + directory. + + Returns: + The names of the debugger-related events files written to disk. The names + are sorted in increasing events file index. + """ + all_files = tf.io.gfile.listdir(self._events_directory) + relevant_files = [ + file_name + for file_name in all_files + if _DEBUGGER_EVENTS_FILE_NAME_REGEX.match(file_name) + ] + return sorted(relevant_files, key=self._obtain_file_index) + + def _obtain_file_index(self, file_name): + """Obtains the file index associated with an events file. + + The index is stored within a file name and is incremented every time a new + events file is created. Assumes that the file name is a valid debugger + events file name. + + Args: + file_name: The name of the debugger-related events file. The file index is + stored within the file name. + + Returns: + The integer events file index. 
+ """ + return int(_DEBUGGER_EVENTS_FILE_NAME_REGEX.match(file_name).group(1)) diff --git a/tensorboard/plugins/debugger/events_writer_manager_test.py b/tensorboard/plugins/debugger/events_writer_manager_test.py index d7b7530cdf..f39c086667 100644 --- a/tensorboard/plugins/debugger/events_writer_manager_test.py +++ b/tensorboard/plugins/debugger/events_writer_manager_test.py @@ -23,241 +23,295 @@ import tensorflow as tf -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.debugger import debugger_plugin_testlib from tensorboard.plugins.debugger import numerics_alert class DebuggerPluginTest(debugger_plugin_testlib.DebuggerPluginTestBase): + def testHealthPillsRouteProvided(self): + """Tests that the plugin offers the route for requesting health + pills.""" + apps = self.plugin.get_plugin_apps() + self.assertIn("/health_pills", apps) + self.assertIsInstance(apps["/health_pills"], collections.Callable) - def testHealthPillsRouteProvided(self): - """Tests that the plugin offers the route for requesting health pills.""" - apps = self.plugin.get_plugin_apps() - self.assertIn('/health_pills', apps) - self.assertIsInstance(apps['/health_pills'], collections.Callable) + def testHealthPillsPluginIsActive(self): + # The multiplexer has sampled health pills. + self.assertTrue(self.plugin.is_active()) - def testHealthPillsPluginIsActive(self): - # The multiplexer has sampled health pills. 
- self.assertTrue(self.plugin.is_active()) + def testHealthPillsPluginIsInactive(self): + plugin = self.debugger_plugin_module.DebuggerPlugin( + base_plugin.TBContext( + logdir=self.log_dir, + multiplexer=event_multiplexer.EventMultiplexer({}), + ) + ) + plugin.listen(self.debugger_data_server_grpc_port) - def testHealthPillsPluginIsInactive(self): - plugin = self.debugger_plugin_module.DebuggerPlugin( - base_plugin.TBContext( - logdir=self.log_dir, - multiplexer=event_multiplexer.EventMultiplexer({}))) - plugin.listen(self.debugger_data_server_grpc_port) + # The multiplexer lacks sampled health pills. + self.assertFalse(plugin.is_active()) - # The multiplexer lacks sampled health pills. - self.assertFalse(plugin.is_active()) - - def testRequestHealthPillsForRunFoo(self): - """Tests that the plugin produces health pills for a specified run.""" - response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': json.dumps(['layers/Variable', 'unavailable_node']), - 'run': 'run_foo', - }) - self.assertEqual(200, response.status_code) - self.assertDictEqual({ - 'layers/Variable': [{ - 'wall_time': 4242, - 'step': 42, - 'device_name': '/job:localhost/replica:0/task:0/cpu:0', - 'node_name': 'layers/Variable', - 'output_slot': 0, - 'dtype': 'tf.int16', - 'shape': [8.0], - 'value': list(range(12)) + [tf.int16.as_datatype_enum, 1.0, 8.0], - }], - }, self._DeserializeResponse(response.get_data())) - - def testRequestHealthPillsForDefaultRun(self): - """Tests that the plugin produces health pills for the default '.' run.""" - # Do not provide a 'run' parameter in POST data. - response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': json.dumps(['logits/Add', 'unavailable_node']), - }) - self.assertEqual(200, response.status_code) - # The health pills for 'layers/Matmul' should not be included since the - # request excluded that node name. 
- self.assertDictEqual({ - 'logits/Add': [ + def testRequestHealthPillsForRunFoo(self): + """Tests that the plugin produces health pills for a specified run.""" + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={ + "node_names": json.dumps( + ["layers/Variable", "unavailable_node"] + ), + "run": "run_foo", + }, + ) + self.assertEqual(200, response.status_code) + self.assertDictEqual( { - 'wall_time': 1337, - 'step': 7, - 'device_name': '/job:localhost/replica:0/task:0/cpu:0', - 'node_name': 'logits/Add', - 'output_slot': 0, - 'dtype': 'tf.int32', - 'shape': [3.0, 3.0], - 'value': (list(range(12)) + - [float(tf.int32.as_datatype_enum), 2.0, 3.0, 3.0]), + "layers/Variable": [ + { + "wall_time": 4242, + "step": 42, + "device_name": "/job:localhost/replica:0/task:0/cpu:0", + "node_name": "layers/Variable", + "output_slot": 0, + "dtype": "tf.int16", + "shape": [8.0], + "value": list(range(12)) + + [tf.int16.as_datatype_enum, 1.0, 8.0], + } + ], }, + self._DeserializeResponse(response.get_data()), + ) + + def testRequestHealthPillsForDefaultRun(self): + """Tests that the plugin produces health pills for the default '.' + run.""" + # Do not provide a 'run' parameter in POST data. + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={ + "node_names": json.dumps(["logits/Add", "unavailable_node"]), + }, + ) + self.assertEqual(200, response.status_code) + # The health pills for 'layers/Matmul' should not be included since the + # request excluded that node name. 
+ self.assertDictEqual( { - 'wall_time': 1338, - 'step': 8, - 'device_name': '/job:localhost/replica:0/task:0/cpu:0', - 'node_name': 'logits/Add', - 'output_slot': 0, - 'dtype': 'tf.int16', - 'shape': [], - 'value': (list(range(12)) + - [float(tf.int16.as_datatype_enum), 0.0]), + "logits/Add": [ + { + "wall_time": 1337, + "step": 7, + "device_name": "/job:localhost/replica:0/task:0/cpu:0", + "node_name": "logits/Add", + "output_slot": 0, + "dtype": "tf.int32", + "shape": [3.0, 3.0], + "value": ( + list(range(12)) + + [float(tf.int32.as_datatype_enum), 2.0, 3.0, 3.0] + ), + }, + { + "wall_time": 1338, + "step": 8, + "device_name": "/job:localhost/replica:0/task:0/cpu:0", + "node_name": "logits/Add", + "output_slot": 0, + "dtype": "tf.int16", + "shape": [], + "value": ( + list(range(12)) + + [float(tf.int16.as_datatype_enum), 0.0] + ), + }, + ], }, - ], - }, self._DeserializeResponse(response.get_data())) + self._DeserializeResponse(response.get_data()), + ) - def testRequestHealthPillsForEmptyRun(self): - """Tests that the plugin responds with an empty dictionary.""" - response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': json.dumps(['layers/Variable']), - 'run': 'run_with_no_health_pills', - }) - self.assertEqual(200, response.status_code) - self.assertDictEqual({}, self._DeserializeResponse(response.get_data())) + def testRequestHealthPillsForEmptyRun(self): + """Tests that the plugin responds with an empty dictionary.""" + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={ + "node_names": json.dumps(["layers/Variable"]), + "run": "run_with_no_health_pills", + }, + ) + self.assertEqual(200, response.status_code) + self.assertDictEqual({}, self._DeserializeResponse(response.get_data())) - def testGetRequestsUnsupported(self): - """Tests that GET requests are unsupported.""" - response = self.server.get('/data/plugin/debugger/health_pills') - self.assertEqual(405, response.status_code) + def 
testGetRequestsUnsupported(self): + """Tests that GET requests are unsupported.""" + response = self.server.get("/data/plugin/debugger/health_pills") + self.assertEqual(405, response.status_code) - def testRequestsWithoutProperPostKeyUnsupported(self): - """Tests that requests lacking the node_names POST key are unsupported.""" - response = self.server.post('/data/plugin/debugger/health_pills') - self.assertEqual(400, response.status_code) + def testRequestsWithoutProperPostKeyUnsupported(self): + """Tests that requests lacking the node_names POST key are + unsupported.""" + response = self.server.post("/data/plugin/debugger/health_pills") + self.assertEqual(400, response.status_code) - def testRequestsWithBadJsonUnsupported(self): - """Tests that requests with undecodable JSON are unsupported.""" - response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': 'some obviously non JSON text', - }) - self.assertEqual(400, response.status_code) + def testRequestsWithBadJsonUnsupported(self): + """Tests that requests with undecodable JSON are unsupported.""" + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={"node_names": "some obviously non JSON text",}, + ) + self.assertEqual(400, response.status_code) - def testRequestsWithNonListPostDataUnsupported(self): - """Tests that requests with loads lacking lists of ops are unsupported.""" - response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': json.dumps({ - 'this is a dict': 'and not a list.' 
- }), - }) - self.assertEqual(400, response.status_code) + def testRequestsWithNonListPostDataUnsupported(self): + """Tests that requests with loads lacking lists of ops are + unsupported.""" + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={ + "node_names": json.dumps({"this is a dict": "and not a list."}), + }, + ) + self.assertEqual(400, response.status_code) - def testFetchHealthPillsForSpecificStep(self): - """Tests that requesting health pills at a specific steps works. + def testFetchHealthPillsForSpecificStep(self): + """Tests that requesting health pills at a specific steps works. - This path may be slow in real life because it reads from disk. - """ - # Request health pills for these nodes at step 7 specifically. - response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': json.dumps(['logits/Add', 'layers/Matmul']), - 'step': 7 - }) - self.assertEqual(200, response.status_code) - # The response should only include health pills at step 7. - self.assertDictEqual({ - 'logits/Add': [ - { - 'wall_time': 1337, - 'step': 7, - 'device_name': '/job:localhost/replica:0/task:0/cpu:0', - 'node_name': 'logits/Add', - 'output_slot': 0, - 'dtype': 'tf.int32', - 'shape': [3.0, 3.0], - 'value': (list(range(12)) + - [float(tf.int32.as_datatype_enum), 2.0, 3.0, 3.0]), + This path may be slow in real life because it reads from disk. + """ + # Request health pills for these nodes at step 7 specifically. + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={ + "node_names": json.dumps(["logits/Add", "layers/Matmul"]), + "step": 7, }, - ], - 'layers/Matmul': [ + ) + self.assertEqual(200, response.status_code) + # The response should only include health pills at step 7. 
+ self.assertDictEqual( { - 'wall_time': 43, - 'step': 7, - 'device_name': '/job:localhost/replica:0/task:0/cpu:0', - 'node_name': 'layers/Matmul', - 'output_slot': 1, - 'dtype': 'tf.float64', - 'shape': [3.0, 3.0], - 'value': (list(range(12)) + - [float(tf.float64.as_datatype_enum), 2.0, 3.0, 3.0]), + "logits/Add": [ + { + "wall_time": 1337, + "step": 7, + "device_name": "/job:localhost/replica:0/task:0/cpu:0", + "node_name": "logits/Add", + "output_slot": 0, + "dtype": "tf.int32", + "shape": [3.0, 3.0], + "value": ( + list(range(12)) + + [float(tf.int32.as_datatype_enum), 2.0, 3.0, 3.0] + ), + }, + ], + "layers/Matmul": [ + { + "wall_time": 43, + "step": 7, + "device_name": "/job:localhost/replica:0/task:0/cpu:0", + "node_name": "layers/Matmul", + "output_slot": 1, + "dtype": "tf.float64", + "shape": [3.0, 3.0], + "value": ( + list(range(12)) + + [ + float(tf.float64.as_datatype_enum), + 2.0, + 3.0, + 3.0, + ] + ), + }, + ], }, - ], - }, self._DeserializeResponse(response.get_data())) + self._DeserializeResponse(response.get_data()), + ) - def testNoHealthPillsForSpecificStep(self): - """Tests that an empty mapping is returned for no health pills at a step.""" - response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': json.dumps(['some/clearly/non-existent/op']), - 'step': 7 - }) - self.assertEqual(200, response.status_code) - self.assertDictEqual({}, self._DeserializeResponse(response.get_data())) + def testNoHealthPillsForSpecificStep(self): + """Tests that an empty mapping is returned for no health pills at a + step.""" + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={ + "node_names": json.dumps(["some/clearly/non-existent/op"]), + "step": 7, + }, + ) + self.assertEqual(200, response.status_code) + self.assertDictEqual({}, self._DeserializeResponse(response.get_data())) - def testNoHealthPillsForOutOfRangeStep(self): - """Tests that an empty mapping is returned for an out of range step.""" - 
response = self.server.post( - '/data/plugin/debugger/health_pills', - data={ - 'node_names': json.dumps(['logits/Add', 'layers/Matmul']), - # This step higher than that of any event written to disk. - 'step': 42424242 - }) - self.assertEqual(200, response.status_code) - self.assertDictEqual({}, self._DeserializeResponse(response.get_data())) + def testNoHealthPillsForOutOfRangeStep(self): + """Tests that an empty mapping is returned for an out of range step.""" + response = self.server.post( + "/data/plugin/debugger/health_pills", + data={ + "node_names": json.dumps(["logits/Add", "layers/Matmul"]), + # This step higher than that of any event written to disk. + "step": 42424242, + }, + ) + self.assertEqual(200, response.status_code) + self.assertDictEqual({}, self._DeserializeResponse(response.get_data())) - def testNumericsAlertReportResponse(self): - """Tests that reports of bad values are returned.""" - alerts = [ - numerics_alert.NumericsAlertReportRow('cpu0', 'MatMul', 123, 2, 3, 4), - numerics_alert.NumericsAlertReportRow('cpu1', 'Add', 124, 5, 6, 7), - ] - self.mock_debugger_data_server.numerics_alert_report.return_value = alerts - response = self.server.get('/data/plugin/debugger/numerics_alert_report') - self.assertEqual(200, response.status_code) + def testNumericsAlertReportResponse(self): + """Tests that reports of bad values are returned.""" + alerts = [ + numerics_alert.NumericsAlertReportRow( + "cpu0", "MatMul", 123, 2, 3, 4 + ), + numerics_alert.NumericsAlertReportRow("cpu1", "Add", 124, 5, 6, 7), + ] + self.mock_debugger_data_server.numerics_alert_report.return_value = ( + alerts + ) + response = self.server.get( + "/data/plugin/debugger/numerics_alert_report" + ) + self.assertEqual(200, response.status_code) - retrieved_alerts = self._DeserializeResponse(response.get_data()) - self.assertEqual(2, len(retrieved_alerts)) - self.assertDictEqual({ - 'device_name': 'cpu0', - 'tensor_name': 'MatMul', - 'first_timestamp': 123, - 'nan_event_count': 2, - 
'neg_inf_event_count': 3, - 'pos_inf_event_count': 4, - }, retrieved_alerts[0]) - self.assertDictEqual({ - 'device_name': 'cpu1', - 'tensor_name': 'Add', - 'first_timestamp': 124, - 'nan_event_count': 5, - 'neg_inf_event_count': 6, - 'pos_inf_event_count': 7, - }, retrieved_alerts[1]) + retrieved_alerts = self._DeserializeResponse(response.get_data()) + self.assertEqual(2, len(retrieved_alerts)) + self.assertDictEqual( + { + "device_name": "cpu0", + "tensor_name": "MatMul", + "first_timestamp": 123, + "nan_event_count": 2, + "neg_inf_event_count": 3, + "pos_inf_event_count": 4, + }, + retrieved_alerts[0], + ) + self.assertDictEqual( + { + "device_name": "cpu1", + "tensor_name": "Add", + "first_timestamp": 124, + "nan_event_count": 5, + "neg_inf_event_count": 6, + "pos_inf_event_count": 7, + }, + retrieved_alerts[1], + ) - def testDebuggerDataServerNotStartedWhenPortIsNone(self): - """Tests that the plugin starts no debugger data server if port is None.""" - self.mock_debugger_data_server_class.reset_mock() + def testDebuggerDataServerNotStartedWhenPortIsNone(self): + """Tests that the plugin starts no debugger data server if port is + None.""" + self.mock_debugger_data_server_class.reset_mock() - # Initialize a debugger plugin with no GRPC port provided. - self.debugger_plugin_module.DebuggerPlugin(self.context).get_plugin_apps() + # Initialize a debugger plugin with no GRPC port provided. + self.debugger_plugin_module.DebuggerPlugin( + self.context + ).get_plugin_apps() - # No debugger data server should have been started. - # assert_not_called is not available in Python 3.4. - self.assertFalse(self.mock_debugger_data_server_class.called) + # No debugger data server should have been started. + # assert_not_called is not available in Python 3.4. 
+ self.assertFalse(self.mock_debugger_data_server_class.called) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/debugger/health_pill_calc.py b/tensorboard/plugins/debugger/health_pill_calc.py index 55bec2dd2a..4574993921 100644 --- a/tensorboard/plugins/debugger/health_pill_calc.py +++ b/tensorboard/plugins/debugger/health_pill_calc.py @@ -32,87 +32,92 @@ def calc_health_pill(tensor): - """Calculate health pill of a tensor. + """Calculate health pill of a tensor. - Args: - tensor: An instance of `np.array` (for initialized tensors) or - `tensorflow.python.debug.lib.debug_data.InconvertibleTensorProto` - (for unininitialized tensors). + Args: + tensor: An instance of `np.array` (for initialized tensors) or + `tensorflow.python.debug.lib.debug_data.InconvertibleTensorProto` + (for unininitialized tensors). - Returns: - If `tensor` is an initialized tensor of numeric or boolean types: - the calculated health pill, as a `list` of `float`s. - Else if `tensor` is an initialized tensor with `string`, `resource` or any - other non-numeric types: - `None`. - Else (i.e., if `tensor` is uninitialized): An all-zero `list`, with the - first element signifying that the tensor is uninitialized. - """ - health_pill = [0.0] * 14 + Returns: + If `tensor` is an initialized tensor of numeric or boolean types: + the calculated health pill, as a `list` of `float`s. + Else if `tensor` is an initialized tensor with `string`, `resource` or any + other non-numeric types: + `None`. + Else (i.e., if `tensor` is uninitialized): An all-zero `list`, with the + first element signifying that the tensor is uninitialized. + """ + health_pill = [0.0] * 14 - # TODO(cais): Add unit test for this method that compares results with - # DebugNumericSummary output. + # TODO(cais): Add unit test for this method that compares results with + # DebugNumericSummary output. - # Is tensor initialized. 
- if not isinstance(tensor, np.ndarray): - return health_pill - health_pill[0] = 1.0 + # Is tensor initialized. + if not isinstance(tensor, np.ndarray): + return health_pill + health_pill[0] = 1.0 - if not (np.issubdtype(tensor.dtype, np.float) or - np.issubdtype(tensor.dtype, np.complex) or - np.issubdtype(tensor.dtype, np.integer) or - tensor.dtype == np.bool): - return None + if not ( + np.issubdtype(tensor.dtype, np.float) + or np.issubdtype(tensor.dtype, np.complex) + or np.issubdtype(tensor.dtype, np.integer) + or tensor.dtype == np.bool + ): + return None - # Total number of elements. - health_pill[1] = float(np.size(tensor)) + # Total number of elements. + health_pill[1] = float(np.size(tensor)) - # TODO(cais): Further performance optimization? - nan_mask = np.isnan(tensor) - inf_mask = np.isinf(tensor) - # Number of NaN elements. - health_pill[2] = float(np.sum(nan_mask)) - # Number of -Inf elements. - health_pill[3] = float(np.sum(tensor == -np.inf)) - # Number of finite negative elements. - health_pill[4] = float(np.sum( - np.logical_and(np.logical_not(inf_mask), tensor < 0.0))) - # Number of zero elements. - health_pill[5] = float(np.sum(tensor == 0.0)) - # Number finite positive elements. - health_pill[6] = float(np.sum( - np.logical_and(np.logical_not(inf_mask), tensor > 0.0))) - # Number of +Inf elements. - health_pill[7] = float(np.sum(tensor == np.inf)) + # TODO(cais): Further performance optimization? + nan_mask = np.isnan(tensor) + inf_mask = np.isinf(tensor) + # Number of NaN elements. + health_pill[2] = float(np.sum(nan_mask)) + # Number of -Inf elements. + health_pill[3] = float(np.sum(tensor == -np.inf)) + # Number of finite negative elements. + health_pill[4] = float( + np.sum(np.logical_and(np.logical_not(inf_mask), tensor < 0.0)) + ) + # Number of zero elements. + health_pill[5] = float(np.sum(tensor == 0.0)) + # Number finite positive elements. 
+ health_pill[6] = float( + np.sum(np.logical_and(np.logical_not(inf_mask), tensor > 0.0)) + ) + # Number of +Inf elements. + health_pill[7] = float(np.sum(tensor == np.inf)) - finite_subset = tensor[ - np.logical_and(np.logical_not(nan_mask), np.logical_not(inf_mask))] - if np.size(finite_subset): - # Finite subset is not empty. - # Minimum of the non-NaN non-Inf elements. - health_pill[8] = float(np.min(finite_subset)) - # Maximum of the non-NaN non-Inf elements. - health_pill[9] = float(np.max(finite_subset)) - # Mean of the non-NaN non-Inf elements. - health_pill[10] = float(np.mean(finite_subset)) - # Variance of the non-NaN non-Inf elements. - health_pill[11] = float(np.var(finite_subset)) - else: - # If no finite element exists: - # Set minimum to +inf. - health_pill[8] = np.inf - # Set maximum to -inf. - health_pill[9] = -np.inf - # Set mean to NaN. - health_pill[10] = np.nan - # Set variance to NaN. - health_pill[11] = np.nan + finite_subset = tensor[ + np.logical_and(np.logical_not(nan_mask), np.logical_not(inf_mask)) + ] + if np.size(finite_subset): + # Finite subset is not empty. + # Minimum of the non-NaN non-Inf elements. + health_pill[8] = float(np.min(finite_subset)) + # Maximum of the non-NaN non-Inf elements. + health_pill[9] = float(np.max(finite_subset)) + # Mean of the non-NaN non-Inf elements. + health_pill[10] = float(np.mean(finite_subset)) + # Variance of the non-NaN non-Inf elements. + health_pill[11] = float(np.var(finite_subset)) + else: + # If no finite element exists: + # Set minimum to +inf. + health_pill[8] = np.inf + # Set maximum to -inf. + health_pill[9] = -np.inf + # Set mean to NaN. + health_pill[10] = np.nan + # Set variance to NaN. + health_pill[11] = np.nan - # DType encoded as a number. - # TODO(cais): Convert numpy dtype to corresponding tensorflow dtype enum. - health_pill[12] = -1.0 - # ndims. - health_pill[13] = float(len(tensor.shape)) - # Size of the dimensions. 
- health_pill.extend([float(x) for x in tensor.shape]) - return health_pill + # DType encoded as a number. + # TODO(cais): Convert numpy dtype to corresponding tensorflow dtype enum. + health_pill[12] = -1.0 + # ndims. + health_pill[13] = float(len(tensor.shape)) + # Size of the dimensions. + health_pill.extend([float(x) for x in tensor.shape]) + return health_pill diff --git a/tensorboard/plugins/debugger/health_pill_calc_test.py b/tensorboard/plugins/debugger/health_pill_calc_test.py index ef026d7450..5d96c1b7c0 100644 --- a/tensorboard/plugins/debugger/health_pill_calc_test.py +++ b/tensorboard/plugins/debugger/health_pill_calc_test.py @@ -25,104 +25,103 @@ class HealthPillCalcTest(tf.test.TestCase): + def testInfOnlyArray(self): + x = np.array([[np.inf, -np.inf], [np.inf, np.inf]]) + health_pill = health_pill_calc.calc_health_pill(x) + self.assertEqual(16, len(health_pill)) + self.assertEqual(1.0, health_pill[0]) # Is initialized. + self.assertEqual(4.0, health_pill[1]) # numel. + self.assertEqual(0, health_pill[2]) # NaN count. + self.assertEqual(1, health_pill[3]) # -Infinity count. + self.assertEqual(0, health_pill[4]) # Finite negative count. + self.assertEqual(0, health_pill[5]) # Zero count. + self.assertEqual(0, health_pill[6]) # Finite positive count. + self.assertEqual(3, health_pill[7]) # +Infinity count. + self.assertEqual(np.inf, health_pill[8]) + self.assertEqual(-np.inf, health_pill[9]) + self.assertTrue(np.isnan(health_pill[10])) + self.assertTrue(np.isnan(health_pill[11])) + self.assertEqual(2, health_pill[13]) # Number of dimensions. + self.assertEqual(2, health_pill[14]) # Size is (2, 2). + self.assertEqual(2, health_pill[15]) - def testInfOnlyArray(self): - x = np.array([[np.inf, -np.inf], [np.inf, np.inf]]) - health_pill = health_pill_calc.calc_health_pill(x) - self.assertEqual(16, len(health_pill)) - self.assertEqual(1.0, health_pill[0]) # Is initialized. - self.assertEqual(4.0, health_pill[1]) # numel. 
- self.assertEqual(0, health_pill[2]) # NaN count. - self.assertEqual(1, health_pill[3]) # -Infinity count. - self.assertEqual(0, health_pill[4]) # Finite negative count. - self.assertEqual(0, health_pill[5]) # Zero count. - self.assertEqual(0, health_pill[6]) # Finite positive count. - self.assertEqual(3, health_pill[7]) # +Infinity count. - self.assertEqual(np.inf, health_pill[8]) - self.assertEqual(-np.inf, health_pill[9]) - self.assertTrue(np.isnan(health_pill[10])) - self.assertTrue(np.isnan(health_pill[11])) - self.assertEqual(2, health_pill[13]) # Number of dimensions. - self.assertEqual(2, health_pill[14]) # Size is (2, 2). - self.assertEqual(2, health_pill[15]) + def testNanOnlyArray(self): + x = np.array([[np.nan, np.nan, np.nan]]) + health_pill = health_pill_calc.calc_health_pill(x) + self.assertEqual(16, len(health_pill)) + self.assertEqual(1, health_pill[0]) # Is initialized. + self.assertEqual(3, health_pill[1]) # numel. + self.assertEqual(3, health_pill[2]) # NaN count. + self.assertEqual(0, health_pill[3]) # -Infinity count. + self.assertEqual(0, health_pill[4]) # Finite negative count. + self.assertEqual(0, health_pill[5]) # Zero count. + self.assertEqual(0, health_pill[6]) # Finite positive count. + self.assertEqual(0, health_pill[7]) # +Infinity count. + self.assertEqual(np.inf, health_pill[8]) + self.assertEqual(-np.inf, health_pill[9]) + self.assertTrue(np.isnan(health_pill[10])) + self.assertTrue(np.isnan(health_pill[11])) + self.assertEqual(2, health_pill[13]) # Number of dimensions. + self.assertEqual(1, health_pill[14]) # Size is (1, 3) + self.assertEqual(3, health_pill[15]) - def testNanOnlyArray(self): - x = np.array([[np.nan, np.nan, np.nan]]) - health_pill = health_pill_calc.calc_health_pill(x) - self.assertEqual(16, len(health_pill)) - self.assertEqual(1, health_pill[0]) # Is initialized. - self.assertEqual(3, health_pill[1]) # numel. - self.assertEqual(3, health_pill[2]) # NaN count. 
- self.assertEqual(0, health_pill[3]) # -Infinity count. - self.assertEqual(0, health_pill[4]) # Finite negative count. - self.assertEqual(0, health_pill[5]) # Zero count. - self.assertEqual(0, health_pill[6]) # Finite positive count. - self.assertEqual(0, health_pill[7]) # +Infinity count. - self.assertEqual(np.inf, health_pill[8]) - self.assertEqual(-np.inf, health_pill[9]) - self.assertTrue(np.isnan(health_pill[10])) - self.assertTrue(np.isnan(health_pill[11])) - self.assertEqual(2, health_pill[13]) # Number of dimensions. - self.assertEqual(1, health_pill[14]) # Size is (1, 3) - self.assertEqual(3, health_pill[15]) + def testInfAndNanOnlyArray(self): + x = np.array([np.inf, -np.inf, np.nan]) + health_pill = health_pill_calc.calc_health_pill(x) + self.assertEqual(15, len(health_pill)) + self.assertEqual(1, health_pill[0]) # Is initialized. + self.assertEqual(3, health_pill[1]) # numel. + self.assertEqual(1, health_pill[2]) # NaN count. + self.assertEqual(1, health_pill[3]) # -Infinity count. + self.assertEqual(0, health_pill[4]) # Finite negative count. + self.assertEqual(0, health_pill[5]) # Zero count. + self.assertEqual(0, health_pill[6]) # Finite positive count. + self.assertEqual(1, health_pill[7]) # +Infinity count. + self.assertEqual(np.inf, health_pill[8]) + self.assertEqual(-np.inf, health_pill[9]) + self.assertTrue(np.isnan(health_pill[10])) + self.assertTrue(np.isnan(health_pill[11])) + self.assertEqual(1, health_pill[13]) # Number of dimensions. + self.assertEqual(3, health_pill[14]) # Size is (3,). - def testInfAndNanOnlyArray(self): - x = np.array([np.inf, -np.inf, np.nan]) - health_pill = health_pill_calc.calc_health_pill(x) - self.assertEqual(15, len(health_pill)) - self.assertEqual(1, health_pill[0]) # Is initialized. - self.assertEqual(3, health_pill[1]) # numel. - self.assertEqual(1, health_pill[2]) # NaN count. - self.assertEqual(1, health_pill[3]) # -Infinity count. - self.assertEqual(0, health_pill[4]) # Finite negative count. 
- self.assertEqual(0, health_pill[5]) # Zero count. - self.assertEqual(0, health_pill[6]) # Finite positive count. - self.assertEqual(1, health_pill[7]) # +Infinity count. - self.assertEqual(np.inf, health_pill[8]) - self.assertEqual(-np.inf, health_pill[9]) - self.assertTrue(np.isnan(health_pill[10])) - self.assertTrue(np.isnan(health_pill[11])) - self.assertEqual(1, health_pill[13]) # Number of dimensions. - self.assertEqual(3, health_pill[14]) # Size is (3,). + def testEmptyArray(self): + x = np.array([[], []]) + health_pill = health_pill_calc.calc_health_pill(x) + self.assertEqual(16, len(health_pill)) + self.assertEqual(1, health_pill[0]) # Is initialized. + self.assertEqual(0, health_pill[1]) # numel. + self.assertEqual(0, health_pill[2]) # NaN count. + self.assertEqual(0, health_pill[3]) # -Infinity count. + self.assertEqual(0, health_pill[4]) # Finite negative count. + self.assertEqual(0, health_pill[5]) # Zero count. + self.assertEqual(0, health_pill[6]) # Finite positive count. + self.assertEqual(0, health_pill[7]) # +Infinity count. + self.assertEqual(np.inf, health_pill[8]) + self.assertEqual(-np.inf, health_pill[9]) + self.assertTrue(np.isnan(health_pill[10])) + self.assertTrue(np.isnan(health_pill[11])) + self.assertEqual(2, health_pill[13]) # Number of dimensions. + self.assertEqual(2, health_pill[14]) # Size is (2, 0). + self.assertEqual(0, health_pill[15]) - def testEmptyArray(self): - x = np.array([[], []]) - health_pill = health_pill_calc.calc_health_pill(x) - self.assertEqual(16, len(health_pill)) - self.assertEqual(1, health_pill[0]) # Is initialized. - self.assertEqual(0, health_pill[1]) # numel. - self.assertEqual(0, health_pill[2]) # NaN count. - self.assertEqual(0, health_pill[3]) # -Infinity count. - self.assertEqual(0, health_pill[4]) # Finite negative count. - self.assertEqual(0, health_pill[5]) # Zero count. - self.assertEqual(0, health_pill[6]) # Finite positive count. - self.assertEqual(0, health_pill[7]) # +Infinity count. 
- self.assertEqual(np.inf, health_pill[8]) - self.assertEqual(-np.inf, health_pill[9]) - self.assertTrue(np.isnan(health_pill[10])) - self.assertTrue(np.isnan(health_pill[11])) - self.assertEqual(2, health_pill[13]) # Number of dimensions. - self.assertEqual(2, health_pill[14]) # Size is (2, 0). - self.assertEqual(0, health_pill[15]) - - def testScalar(self): - x = np.array(-1337.0) - health_pill = health_pill_calc.calc_health_pill(x) - self.assertEqual(14, len(health_pill)) - self.assertEqual(1, health_pill[0]) # Is initialized. - self.assertEqual(1, health_pill[1]) # numel. - self.assertEqual(0, health_pill[2]) # NaN count. - self.assertEqual(0, health_pill[3]) # -Infinity count. - self.assertEqual(1, health_pill[4]) # Finite negative count. - self.assertEqual(0, health_pill[5]) # Zero count. - self.assertEqual(0, health_pill[6]) # Finite positive count. - self.assertEqual(0, health_pill[7]) # +Infinity count. - self.assertEqual(-1337.0, health_pill[8]) - self.assertEqual(-1337.0, health_pill[9]) - self.assertEqual(-1337.0, health_pill[10]) - self.assertEqual(0, health_pill[11]) - self.assertEqual(0, health_pill[13]) # Number of dimensions. + def testScalar(self): + x = np.array(-1337.0) + health_pill = health_pill_calc.calc_health_pill(x) + self.assertEqual(14, len(health_pill)) + self.assertEqual(1, health_pill[0]) # Is initialized. + self.assertEqual(1, health_pill[1]) # numel. + self.assertEqual(0, health_pill[2]) # NaN count. + self.assertEqual(0, health_pill[3]) # -Infinity count. + self.assertEqual(1, health_pill[4]) # Finite negative count. + self.assertEqual(0, health_pill[5]) # Zero count. + self.assertEqual(0, health_pill[6]) # Finite positive count. + self.assertEqual(0, health_pill[7]) # +Infinity count. + self.assertEqual(-1337.0, health_pill[8]) + self.assertEqual(-1337.0, health_pill[9]) + self.assertEqual(-1337.0, health_pill[10]) + self.assertEqual(0, health_pill[11]) + self.assertEqual(0, health_pill[13]) # Number of dimensions. 
if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/debugger/interactive_debugger_plugin.py b/tensorboard/plugins/debugger/interactive_debugger_plugin.py index 5e7ba8c69f..10b890ac28 100644 --- a/tensorboard/plugins/debugger/interactive_debugger_plugin.py +++ b/tensorboard/plugins/debugger/interactive_debugger_plugin.py @@ -37,301 +37,347 @@ logger = tb_logging.get_logger() # HTTP routes. -_ACK_ROUTE = '/ack' -_COMM_ROUTE = '/comm' -_DEBUGGER_GRAPH_ROUTE = '/debugger_graph' -_DEBUGGER_GRPC_HOST_PORT_ROUTE = '/debugger_grpc_host_port' -_GATED_GRPC_ROUTE = '/gated_grpc' -_TENSOR_DATA_ROUTE = '/tensor_data' -_SOURCE_CODE_ROUTE = '/source_code' +_ACK_ROUTE = "/ack" +_COMM_ROUTE = "/comm" +_DEBUGGER_GRAPH_ROUTE = "/debugger_graph" +_DEBUGGER_GRPC_HOST_PORT_ROUTE = "/debugger_grpc_host_port" +_GATED_GRPC_ROUTE = "/gated_grpc" +_TENSOR_DATA_ROUTE = "/tensor_data" +_SOURCE_CODE_ROUTE = "/source_code" class InteractiveDebuggerPlugin(base_plugin.TBPlugin): - """Interactive TensorFlow Debugger plugin. + """Interactive TensorFlow Debugger plugin. - This underlies the interactive Debugger Dashboard. + This underlies the interactive Debugger Dashboard. - This is different from the non-interactive `DebuggerPlugin` in module - `debugger_plugin`. The latter is for the "health pills" feature in the Graph - Dashboard. - """ - - # This string field is used by TensorBoard to generate the paths for routes - # provided by this plugin. It must thus be URL-friendly. This field is also - # used to uniquely identify this plugin throughout TensorBoard. See BasePlugin - # for details. - plugin_name = constants.DEBUGGER_PLUGIN_NAME - - def __init__(self, context): - """Constructs a debugger plugin for TensorBoard. - - This plugin adds handlers for retrieving debugger-related data. The plugin - also starts a debugger data server once the log directory is passed to the - plugin via the call to get_plugin_apps. 
- - Args: - context: A base_plugin.TBContext instance. + This is different from the non-interactive `DebuggerPlugin` in module + `debugger_plugin`. The latter is for the "health pills" feature in the Graph + Dashboard. """ - del context # Unused. - self._debugger_data_server = None - self._server_thread = None - self._grpc_port = None - - def listen(self, grpc_port): - """Start listening on the given gRPC port. - This method of an instance of InteractiveDebuggerPlugin can be invoked at - most once. This method is not thread safe. - - Args: - grpc_port: port number to listen at. - - Raises: - ValueError: If this instance is already listening at a gRPC port. - """ - if self._grpc_port: - raise ValueError( - 'This InteractiveDebuggerPlugin instance is already listening at ' - 'gRPC port %d' % self._grpc_port) - self._grpc_port = grpc_port - - sys.stderr.write('Creating InteractiveDebuggerPlugin at port %d\n' % - self._grpc_port) - sys.stderr.flush() - self._debugger_data_server = ( - interactive_debugger_server_lib.InteractiveDebuggerDataServer( - self._grpc_port)) - - self._server_thread = threading.Thread( - target=self._debugger_data_server.run_server) - self._server_thread.start() - - signal.signal(signal.SIGINT, self.signal_handler) - # Note: this is required because of a wontfix issue in grpc/python 2.7: - # https://github.com/grpc/grpc/issues/3820 - - def signal_handler(self, unused_signal, unused_frame): - if self._debugger_data_server and self._server_thread: - print('Stopping InteractiveDebuggerPlugin...') - # Enqueue a number of messages to the incoming message queue to try to - # let the debugged tensorflow runtime proceed past the current Session.run - # in the C++ layer and return to the Python layer, so the SIGINT handler - # registered there may be triggered. - for _ in xrange(len(self._debugger_data_server.breakpoints) + 1): + # This string field is used by TensorBoard to generate the paths for routes + # provided by this plugin. 
It must thus be URL-friendly. This field is also + # used to uniquely identify this plugin throughout TensorBoard. See BasePlugin + # for details. + plugin_name = constants.DEBUGGER_PLUGIN_NAME + + def __init__(self, context): + """Constructs a debugger plugin for TensorBoard. + + This plugin adds handlers for retrieving debugger-related data. The plugin + also starts a debugger data server once the log directory is passed to the + plugin via the call to get_plugin_apps. + + Args: + context: A base_plugin.TBContext instance. + """ + del context # Unused. + self._debugger_data_server = None + self._server_thread = None + self._grpc_port = None + + def listen(self, grpc_port): + """Start listening on the given gRPC port. + + This method of an instance of InteractiveDebuggerPlugin can be invoked at + most once. This method is not thread safe. + + Args: + grpc_port: port number to listen at. + + Raises: + ValueError: If this instance is already listening at a gRPC port. + """ + if self._grpc_port: + raise ValueError( + "This InteractiveDebuggerPlugin instance is already listening at " + "gRPC port %d" % self._grpc_port + ) + self._grpc_port = grpc_port + + sys.stderr.write( + "Creating InteractiveDebuggerPlugin at port %d\n" % self._grpc_port + ) + sys.stderr.flush() + self._debugger_data_server = interactive_debugger_server_lib.InteractiveDebuggerDataServer( + self._grpc_port + ) + + self._server_thread = threading.Thread( + target=self._debugger_data_server.run_server + ) + self._server_thread.start() + + signal.signal(signal.SIGINT, self.signal_handler) + # Note: this is required because of a wontfix issue in grpc/python 2.7: + # https://github.com/grpc/grpc/issues/3820 + + def signal_handler(self, unused_signal, unused_frame): + if self._debugger_data_server and self._server_thread: + print("Stopping InteractiveDebuggerPlugin...") + # Enqueue a number of messages to the incoming message queue to try to + # let the debugged tensorflow runtime proceed past the 
current Session.run + # in the C++ layer and return to the Python layer, so the SIGINT handler + # registered there may be triggered. + for _ in xrange(len(self._debugger_data_server.breakpoints) + 1): + self._debugger_data_server.put_incoming_message(True) + try: + self._debugger_data_server.stop_server() + except ValueError: + # In case the server has already stopped running. + pass + self._server_thread.join() + print("InteractiveDebuggerPlugin stopped.") + sys.exit(0) + + def get_plugin_apps(self): + """Obtains a mapping between routes and handlers. + + This function also starts a debugger data server on separate thread if the + plugin has not started one yet. + + Returns: + A mapping between routes and handlers (functions that respond to + requests). + """ + return { + _ACK_ROUTE: self._serve_ack, + _COMM_ROUTE: self._serve_comm, + _DEBUGGER_GRPC_HOST_PORT_ROUTE: self._serve_debugger_grpc_host_port, + _DEBUGGER_GRAPH_ROUTE: self._serve_debugger_graph, + _GATED_GRPC_ROUTE: self._serve_gated_grpc, + _TENSOR_DATA_ROUTE: self._serve_tensor_data, + _SOURCE_CODE_ROUTE: self._serve_source_code, + } + + def is_active(self): + """Determines whether this plugin is active. + + This plugin is active if any health pills information is present for any + run. + + Returns: + A boolean. Whether this plugin is active. + """ + return self._grpc_port is not None + + def frontend_metadata(self): + # TODO(#2338): Keep this in sync with the `registerDashboard` call + # on the frontend until that call is removed. + return base_plugin.FrontendMetadata( + element_name="tf-debugger-dashboard" + ) + + @wrappers.Request.application + def _serve_ack(self, request): + # Send client acknowledgement. `True` is just used as a dummy value. self._debugger_data_server.put_incoming_message(True) - try: - self._debugger_data_server.stop_server() - except ValueError: - # In case the server has already stopped running. 
- pass - self._server_thread.join() - print('InteractiveDebuggerPlugin stopped.') - sys.exit(0) - - def get_plugin_apps(self): - """Obtains a mapping between routes and handlers. - - This function also starts a debugger data server on separate thread if the - plugin has not started one yet. - - Returns: - A mapping between routes and handlers (functions that respond to - requests). - """ - return { - _ACK_ROUTE: self._serve_ack, - _COMM_ROUTE: self._serve_comm, - _DEBUGGER_GRPC_HOST_PORT_ROUTE: self._serve_debugger_grpc_host_port, - _DEBUGGER_GRAPH_ROUTE: self._serve_debugger_graph, - _GATED_GRPC_ROUTE: self._serve_gated_grpc, - _TENSOR_DATA_ROUTE: self._serve_tensor_data, - _SOURCE_CODE_ROUTE: self._serve_source_code, - } - - def is_active(self): - """Determines whether this plugin is active. - - This plugin is active if any health pills information is present for any - run. - - Returns: - A boolean. Whether this plugin is active. - """ - return self._grpc_port is not None - - def frontend_metadata(self): - # TODO(#2338): Keep this in sync with the `registerDashboard` call - # on the frontend until that call is removed. - return base_plugin.FrontendMetadata(element_name='tf-debugger-dashboard') - - @wrappers.Request.application - def _serve_ack(self, request): - # Send client acknowledgement. `True` is just used as a dummy value. - self._debugger_data_server.put_incoming_message(True) - return http_util.Respond(request, {}, 'application/json') - - @wrappers.Request.application - def _serve_comm(self, request): - # comm_channel.get() blocks until an item is put into the queue (by - # self._debugger_data_server). This is how the HTTP long polling ends. 
- pos = int(request.args.get("pos")) - comm_data = self._debugger_data_server.get_outgoing_message(pos) - return http_util.Respond(request, comm_data, 'application/json') - - @wrappers.Request.application - def _serve_debugger_graph(self, request): - device_name = request.args.get('device_name') - if not device_name or device_name == 'null': - return http_util.Respond(request, str(None), 'text/x-protobuf') - - run_key = interactive_debugger_server_lib.RunKey( - *json.loads(request.args.get('run_key'))) - graph_def = self._debugger_data_server.get_graph(run_key, device_name) - logger.debug( - '_serve_debugger_graph(): device_name = %s, run_key = %s, ' - 'type(graph_def) = %s', device_name, run_key, type(graph_def)) - # TODO(cais): Sending text proto may be slow in Python. Investigate whether - # there are ways to optimize it. - return http_util.Respond(request, str(graph_def), 'text/x-protobuf') - - def _error_response(self, request, error_msg): - logger.error(error_msg) - return http_util.Respond( - request, {'error': error_msg}, 'application/json', 400) - - @wrappers.Request.application - def _serve_gated_grpc(self, request): - mode = request.args.get('mode') - if mode == 'retrieve_all' or mode == 'retrieve_device_names': - # 'retrieve_all': Retrieve all gated-gRPC debug tensors and currently - # enabled breakpoints associated with the given run_key. - # 'retrieve_device_names': Retrieve all device names associated with the - # given run key. - run_key = interactive_debugger_server_lib.RunKey( - *json.loads(request.args.get('run_key'))) - # debug_graph_defs is a map from device_name to GraphDef. 
- debug_graph_defs = self._debugger_data_server.get_graphs(run_key, - debug=True) - if mode == 'retrieve_device_names': - return http_util.Respond(request, { - 'device_names': list(debug_graph_defs.keys()), - }, 'application/json') - - gated = {} - for device_name in debug_graph_defs: - gated[device_name] = self._debugger_data_server.get_gated_grpc_tensors( - run_key, device_name) - - # Both gated and self._debugger_data_server.breakpoints are lists whose - # items are (node_name, output_slot, debug_op_name). - return http_util.Respond(request, { - 'gated_grpc_tensors': gated, - 'breakpoints': self._debugger_data_server.breakpoints, - 'device_names': list(debug_graph_defs.keys()), - }, 'application/json') - elif mode == 'breakpoints': - # Retrieve currently enabled breakpoints. - return http_util.Respond( - request, self._debugger_data_server.breakpoints, 'application/json') - elif mode == 'set_state': - # Set the state of gated-gRPC debug tensors, e.g., disable, enable - # breakpoint. - node_name = request.args.get('node_name') - output_slot = int(request.args.get('output_slot')) - debug_op = request.args.get('debug_op') - state = request.args.get('state') - logger.debug('Setting state of %s:%d:%s to: %s' % - (node_name, output_slot, debug_op, state)) - if state == 'disable': - self._debugger_data_server.request_unwatch( - node_name, output_slot, debug_op) - elif state == 'watch': - self._debugger_data_server.request_watch( - node_name, output_slot, debug_op, breakpoint=False) - elif state == 'break': - self._debugger_data_server.request_watch( - node_name, output_slot, debug_op, breakpoint=True) - else: - return self._error_response( - request, 'Unrecognized new state for %s:%d:%s: %s' % (node_name, - output_slot, - debug_op, - state)) - return http_util.Respond( - request, - {'node_name': node_name, - 'output_slot': output_slot, - 'debug_op': debug_op, - 'state': state}, - 'application/json') - else: - return self._error_response( - request, 'Unrecognized mode 
for the gated_grpc route: %s' % mode) - - @wrappers.Request.application - def _serve_debugger_grpc_host_port(self, request): - return http_util.Respond( - request, - {'host': platform.node(), 'port': self._grpc_port}, 'application/json') - - @wrappers.Request.application - def _serve_tensor_data(self, request): - response_encoding = 'application/json' - watch_key = request.args.get('watch_key') - time_indices = request.args.get('time_indices') - mapping = request.args.get('mapping') - slicing = request.args.get('slicing') - - try: - sliced_tensor_data = self._debugger_data_server.query_tensor_store( - watch_key, time_indices=time_indices, slicing=slicing, - mapping=mapping) - response = { - 'tensor_data': sliced_tensor_data, - 'error': None - } - status_code = 200 - except (IndexError, ValueError) as e: - response = { - 'tensor_data': None, - 'error': { - 'type': type(e).__name__, - }, - } - # TODO(cais): Provide safe and succinct error messages for common error - # conditions, such as index out of bound, or invalid mapping for given - # tensor ranks. - status_code = 500 - return http_util.Respond(request, response, response_encoding, status_code) - - @wrappers.Request.application - def _serve_source_code(self, request): - response_encoding = 'application/json' - - mode = request.args.get('mode') - if mode == 'paths': - # Retrieve all file paths. - response = {'paths': self._debugger_data_server.query_source_file_paths()} - return http_util.Respond(request, response, response_encoding) - elif mode == 'content': - # Retrieve the content of a source file. - file_path = request.args.get('file_path') - response = { - 'content': { - file_path: self._debugger_data_server.query_source_file_content( - file_path)}, - 'lineno_to_op_name_and_stack_pos': - self._debugger_data_server.query_file_tracebacks(file_path)} - return http_util.Respond(request, response, response_encoding) - elif mode == 'op_traceback': - # Retrieve the traceback of a graph op by name of the op. 
- op_name = request.args.get('op_name') - response = { - 'op_traceback': { - op_name: self._debugger_data_server.query_op_traceback(op_name) - } - } - return http_util.Respond(request, response, response_encoding) - else: - response = {'error': 'Invalid mode for source_code endpoint: %s' % mode} - return http_util.Respond(request, response, response_encoding, 500) + return http_util.Respond(request, {}, "application/json") + + @wrappers.Request.application + def _serve_comm(self, request): + # comm_channel.get() blocks until an item is put into the queue (by + # self._debugger_data_server). This is how the HTTP long polling ends. + pos = int(request.args.get("pos")) + comm_data = self._debugger_data_server.get_outgoing_message(pos) + return http_util.Respond(request, comm_data, "application/json") + + @wrappers.Request.application + def _serve_debugger_graph(self, request): + device_name = request.args.get("device_name") + if not device_name or device_name == "null": + return http_util.Respond(request, str(None), "text/x-protobuf") + + run_key = interactive_debugger_server_lib.RunKey( + *json.loads(request.args.get("run_key")) + ) + graph_def = self._debugger_data_server.get_graph(run_key, device_name) + logger.debug( + "_serve_debugger_graph(): device_name = %s, run_key = %s, " + "type(graph_def) = %s", + device_name, + run_key, + type(graph_def), + ) + # TODO(cais): Sending text proto may be slow in Python. Investigate whether + # there are ways to optimize it. 
+ return http_util.Respond(request, str(graph_def), "text/x-protobuf") + + def _error_response(self, request, error_msg): + logger.error(error_msg) + return http_util.Respond( + request, {"error": error_msg}, "application/json", 400 + ) + + @wrappers.Request.application + def _serve_gated_grpc(self, request): + mode = request.args.get("mode") + if mode == "retrieve_all" or mode == "retrieve_device_names": + # 'retrieve_all': Retrieve all gated-gRPC debug tensors and currently + # enabled breakpoints associated with the given run_key. + # 'retrieve_device_names': Retrieve all device names associated with the + # given run key. + run_key = interactive_debugger_server_lib.RunKey( + *json.loads(request.args.get("run_key")) + ) + # debug_graph_defs is a map from device_name to GraphDef. + debug_graph_defs = self._debugger_data_server.get_graphs( + run_key, debug=True + ) + if mode == "retrieve_device_names": + return http_util.Respond( + request, + {"device_names": list(debug_graph_defs.keys()),}, + "application/json", + ) + + gated = {} + for device_name in debug_graph_defs: + gated[ + device_name + ] = self._debugger_data_server.get_gated_grpc_tensors( + run_key, device_name + ) + + # Both gated and self._debugger_data_server.breakpoints are lists whose + # items are (node_name, output_slot, debug_op_name). + return http_util.Respond( + request, + { + "gated_grpc_tensors": gated, + "breakpoints": self._debugger_data_server.breakpoints, + "device_names": list(debug_graph_defs.keys()), + }, + "application/json", + ) + elif mode == "breakpoints": + # Retrieve currently enabled breakpoints. + return http_util.Respond( + request, + self._debugger_data_server.breakpoints, + "application/json", + ) + elif mode == "set_state": + # Set the state of gated-gRPC debug tensors, e.g., disable, enable + # breakpoint. 
+ node_name = request.args.get("node_name") + output_slot = int(request.args.get("output_slot")) + debug_op = request.args.get("debug_op") + state = request.args.get("state") + logger.debug( + "Setting state of %s:%d:%s to: %s" + % (node_name, output_slot, debug_op, state) + ) + if state == "disable": + self._debugger_data_server.request_unwatch( + node_name, output_slot, debug_op + ) + elif state == "watch": + self._debugger_data_server.request_watch( + node_name, output_slot, debug_op, breakpoint=False + ) + elif state == "break": + self._debugger_data_server.request_watch( + node_name, output_slot, debug_op, breakpoint=True + ) + else: + return self._error_response( + request, + "Unrecognized new state for %s:%d:%s: %s" + % (node_name, output_slot, debug_op, state), + ) + return http_util.Respond( + request, + { + "node_name": node_name, + "output_slot": output_slot, + "debug_op": debug_op, + "state": state, + }, + "application/json", + ) + else: + return self._error_response( + request, "Unrecognized mode for the gated_grpc route: %s" % mode + ) + + @wrappers.Request.application + def _serve_debugger_grpc_host_port(self, request): + return http_util.Respond( + request, + {"host": platform.node(), "port": self._grpc_port}, + "application/json", + ) + + @wrappers.Request.application + def _serve_tensor_data(self, request): + response_encoding = "application/json" + watch_key = request.args.get("watch_key") + time_indices = request.args.get("time_indices") + mapping = request.args.get("mapping") + slicing = request.args.get("slicing") + + try: + sliced_tensor_data = self._debugger_data_server.query_tensor_store( + watch_key, + time_indices=time_indices, + slicing=slicing, + mapping=mapping, + ) + response = {"tensor_data": sliced_tensor_data, "error": None} + status_code = 200 + except (IndexError, ValueError) as e: + response = { + "tensor_data": None, + "error": {"type": type(e).__name__,}, + } + # TODO(cais): Provide safe and succinct error messages for common 
error + # conditions, such as index out of bound, or invalid mapping for given + # tensor ranks. + status_code = 500 + return http_util.Respond( + request, response, response_encoding, status_code + ) + + @wrappers.Request.application + def _serve_source_code(self, request): + response_encoding = "application/json" + + mode = request.args.get("mode") + if mode == "paths": + # Retrieve all file paths. + response = { + "paths": self._debugger_data_server.query_source_file_paths() + } + return http_util.Respond(request, response, response_encoding) + elif mode == "content": + # Retrieve the content of a source file. + file_path = request.args.get("file_path") + response = { + "content": { + file_path: self._debugger_data_server.query_source_file_content( + file_path + ) + }, + "lineno_to_op_name_and_stack_pos": self._debugger_data_server.query_file_tracebacks( + file_path + ), + } + return http_util.Respond(request, response, response_encoding) + elif mode == "op_traceback": + # Retrieve the traceback of a graph op by name of the op. + op_name = request.args.get("op_name") + response = { + "op_traceback": { + op_name: self._debugger_data_server.query_op_traceback( + op_name + ) + } + } + return http_util.Respond(request, response, response_encoding) + else: + response = { + "error": "Invalid mode for source_code endpoint: %s" % mode + } + return http_util.Respond(request, response, response_encoding, 500) diff --git a/tensorboard/plugins/debugger/interactive_debugger_plugin_test.py b/tensorboard/plugins/debugger/interactive_debugger_plugin_test.py index 53ba831d84..422d8a97cb 100644 --- a/tensorboard/plugins/debugger/interactive_debugger_plugin_test.py +++ b/tensorboard/plugins/debugger/interactive_debugger_plugin_test.py @@ -14,10 +14,11 @@ # ============================================================================== """Tests end-to-end debugger interactive data server behavior. -This test launches an instance InteractiveDebuggerPlugin as a separate thread. 
-The test then calls Session.run() using RunOptions pointing to the grpc:// debug -URL of the debugger data server. It then sends HTTP requests to the TensorBoard -backend endpoints to query and control the state of the Sessoin.run(). +This test launches an instance InteractiveDebuggerPlugin as a separate +thread. The test then calls Session.run() using RunOptions pointing to +the grpc:// debug URL of the debugger data server. It then sends HTTP +requests to the TensorBoard backend endpoints to query and control the +state of the Sessoin.run(). """ from __future__ import absolute_import @@ -34,12 +35,16 @@ import portpicker # pylint: disable=import-error from six.moves import urllib # pylint: disable=wrong-import-order import tensorflow.compat.v1 as tf # pylint: disable=wrong-import-order -from tensorflow.python import debug as tf_debug # pylint: disable=wrong-import-order +from tensorflow.python import ( + debug as tf_debug, +) # pylint: disable=wrong-import-order from werkzeug import test as werkzeug_test # pylint: disable=wrong-import-order from werkzeug import wrappers # pylint: disable=wrong-import-order from tensorboard.backend import application -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.debugger import interactive_debugger_plugin from tensorboard.util import test_util @@ -49,910 +54,1168 @@ tf.disable_v2_behavior() -_SERVER_URL_PREFIX = '/data/plugin/debugger/' +_SERVER_URL_PREFIX = "/data/plugin/debugger/" class InteractiveDebuggerPluginTest(tf.test.TestCase): - - def setUp(self): - super(InteractiveDebuggerPluginTest, self).setUp() - - self._dummy_logdir = tempfile.mkdtemp() - dummy_multiplexer = event_multiplexer.EventMultiplexer({}) - self._debugger_port = portpicker.pick_unused_port() - self._debugger_url = 'grpc://localhost:%d' % 
self._debugger_port - context = base_plugin.TBContext(logdir=self._dummy_logdir, - multiplexer=dummy_multiplexer) - self._debugger_plugin = ( - interactive_debugger_plugin.InteractiveDebuggerPlugin(context)) - self._debugger_plugin.listen(self._debugger_port) - - wsgi_app = application.TensorBoardWSGI([self._debugger_plugin]) - self._server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) - - def tearDown(self): - # In some cases (e.g., an empty test method body), the stop_server() method - # may get called before the server is started, leading to a ValueError. - while True: - try: - self._debugger_plugin._debugger_data_server.stop_server() - break - except ValueError: - pass - shutil.rmtree(self._dummy_logdir, ignore_errors=True) - super(InteractiveDebuggerPluginTest, self).tearDown() - - def _serverGet(self, path, params=None, expected_status_code=200): - """Send the serve a GET request and obtain the response. - - Args: - path: URL path (excluding the prefix), without parameters encoded. - params: Query parameters to be encoded in the URL, as a dict. - expected_status_code: Expected status code. - - Returns: - Response from server. - """ - url = _SERVER_URL_PREFIX + path - if params: - url += '?' + urllib.parse.urlencode(params) - response = self._server.get(url) - self.assertEqual(expected_status_code, response.status_code) - return response - - def _deserializeResponse(self, response): - """Deserializes byte content that is a JSON encoding. - - Args: - response: A response object. - - Returns: - The deserialized python object decoded from JSON. 
- """ - return json.loads(response.get_data().decode("utf-8")) - - def _runSimpleAddMultiplyGraph(self, variable_size=1): - session_run_results = [] - def session_run_job(): - with tf.Session() as sess: - a = tf.Variable([10.0] * variable_size, name='a') - b = tf.Variable([20.0] * variable_size, name='b') - c = tf.Variable([30.0] * variable_size, name='c') - x = tf.multiply(a, b, name="x") - y = tf.add(x, c, name="y") - - sess.run(tf.global_variables_initializer()) - - sess = tf_debug.TensorBoardDebugWrapperSession(sess, self._debugger_url) - session_run_results.append(sess.run(y)) - session_run_thread = threading.Thread(target=session_run_job) - session_run_thread.start() - return session_run_thread, session_run_results - - def _runMultiStepAssignAddGraph(self, steps): - session_run_results = [] - def session_run_job(): - with tf.Session() as sess: - a = tf.Variable(10, dtype=tf.int32, name='a') - b = tf.Variable(1, dtype=tf.int32, name='b') - inc_a = tf.assign_add(a, b, name='inc_a') - - sess.run(tf.global_variables_initializer()) - - sess = tf_debug.TensorBoardDebugWrapperSession(sess, self._debugger_url) - for _ in range(steps): - session_run_results.append(sess.run(inc_a)) - session_run_thread = threading.Thread(target=session_run_job) - session_run_thread.start() - return session_run_thread, session_run_results - - def _runTfGroupGraph(self): - session_run_results = [] - def session_run_job(): - with tf.Session() as sess: - a = tf.Variable(10, dtype=tf.int32, name='a') - b = tf.Variable(20, dtype=tf.int32, name='b') - d = tf.constant(1, dtype=tf.int32, name='d') - inc_a = tf.assign_add(a, d, name='inc_a') - inc_b = tf.assign_add(b, d, name='inc_b') - inc_ab = tf.group([inc_a, inc_b], name="inc_ab") - - sess.run(tf.global_variables_initializer()) - - sess = tf_debug.TensorBoardDebugWrapperSession(sess, self._debugger_url) - session_run_results.append(sess.run(inc_ab)) - session_run_thread = threading.Thread(target=session_run_job) - session_run_thread.start() 
- return session_run_thread, session_run_results - - def testCommAndAckWithoutBreakpoints(self): - session_run_thread, session_run_results = self._runSimpleAddMultiplyGraph() - - comm_response = self._serverGet('comm', {'pos': 1}) - response_data = self._deserializeResponse(comm_response) - self.assertGreater(response_data['timestamp'], 0) - self.assertEqual('meta', response_data['type']) - self.assertEqual({'run_key': ['', 'y:0', '']}, response_data['data']) - - self._serverGet('ack') - - session_run_thread.join() - self.assertAllClose([[230.0]], session_run_results) - - def testGetDeviceNamesAndDebuggerGraph(self): - session_run_thread, session_run_results = self._runSimpleAddMultiplyGraph() - - comm_response = self._serverGet('comm', {'pos': 1}) - response_data = self._deserializeResponse(comm_response) - run_key = json.dumps(response_data['data']['run_key']) - - device_names_response = self._serverGet( - 'gated_grpc', {'mode': 'retrieve_device_names', 'run_key': run_key}) - device_names_data = self._deserializeResponse(device_names_response) - self.assertEqual(1, len(device_names_data['device_names'])) - device_name = device_names_data['device_names'][0] - - graph_response = self._serverGet( - 'debugger_graph', {'run_key': run_key, 'device_name': device_name}) - self.assertTrue(graph_response.get_data()) - - self._serverGet('ack') - - session_run_thread.join() - self.assertAllClose([[230.0]], session_run_results) - - def testRetrieveAllGatedGrpcTensors(self): - session_run_thread, session_run_results = self._runSimpleAddMultiplyGraph() - - comm_response = self._serverGet('comm', {'pos': 1}) - response_data = self._deserializeResponse(comm_response) - run_key = json.dumps(response_data['data']['run_key']) - - retrieve_all_response = self._serverGet( - 'gated_grpc', {'mode': 'retrieve_all', 'run_key': run_key}) - retrieve_all_data = self._deserializeResponse(retrieve_all_response) - self.assertTrue(retrieve_all_data['device_names']) - # No breakpoints have been 
activated. - self.assertEqual([], retrieve_all_data['breakpoints']) - device_name = retrieve_all_data['device_names'][0] - tensor_names = [item[0] for item - in retrieve_all_data['gated_grpc_tensors'][device_name]] - self.assertItemsEqual( - ['a', 'a/read', 'b', 'b/read', 'x', 'c', 'c/read', 'y'], tensor_names) - - self._serverGet('ack') - - session_run_thread.join() - self.assertAllClose([[230.0]], session_run_results) - - def testActivateOneBreakpoint(self): - session_run_thread, session_run_results = self._runSimpleAddMultiplyGraph() - - comm_response = self._serverGet('comm', {'pos': 1}) - - # Activate breakpoint for x:0. - self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'x', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - - # Proceed to breakpoint x:0. - self._serverGet('ack') - comm_response = self._serverGet('comm', {'pos': 2}) - comm_data = self._deserializeResponse(comm_response) - self.assertGreater(comm_data['timestamp'], 0) - self.assertEqual('tensor', comm_data['type']) - self.assertEqual('float32', comm_data['data']['dtype']) - self.assertEqual([1], comm_data['data']['shape']) - self.assertEqual('x', comm_data['data']['node_name']) - self.assertEqual(0, comm_data['data']['output_slot']) - self.assertEqual('DebugIdentity', comm_data['data']['debug_op']) - self.assertAllClose([200.0], comm_data['data']['values']) - - # Proceed to the end of the Session.run(). - self._serverGet('ack') - - session_run_thread.join() - self.assertAllClose([[230.0]], session_run_results) - - # Verify that the activated breakpoint is remembered. 
- breakpoints_response = self._serverGet( - 'gated_grpc', {'mode': 'breakpoints'}) - breakpoints_data = self._deserializeResponse(breakpoints_response) - self.assertEqual([['x', 0, 'DebugIdentity']], breakpoints_data) - - def testActivateAndDeactivateOneBreakpoint(self): - session_run_thread, session_run_results = self._runSimpleAddMultiplyGraph() - - self._serverGet('comm', {'pos': 1}) - - # Activate breakpoint for x:0. - self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'x', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - - # Deactivate the breakpoint right away. - self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'x', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'disable'}) - - # Proceed to the end of the Session.run(). - self._serverGet('ack') - - session_run_thread.join() - self.assertAllClose([[230.0]], session_run_results) - - # Verify that there is no breakpoint activated. - breakpoints_response = self._serverGet( - 'gated_grpc', {'mode': 'breakpoints'}) - breakpoints_data = self._deserializeResponse(breakpoints_response) - self.assertEqual([], breakpoints_data) - - def testActivateTwoBreakpoints(self): - session_run_thread, session_run_results = self._runSimpleAddMultiplyGraph() - - comm_response = self._serverGet('comm', {'pos': 1}) - - # Activate breakpoint for x:0. - self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'x', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - # Activate breakpoint for y:0. - self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'y', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - - # Proceed to breakpoint x:0. 
- self._serverGet('ack') - comm_response = self._serverGet('comm', {'pos': 2}) - comm_data = self._deserializeResponse(comm_response) - self.assertGreater(comm_data['timestamp'], 0) - self.assertEqual('tensor', comm_data['type']) - self.assertEqual('float32', comm_data['data']['dtype']) - self.assertEqual([1], comm_data['data']['shape']) - self.assertEqual('x', comm_data['data']['node_name']) - self.assertEqual(0, comm_data['data']['output_slot']) - self.assertEqual('DebugIdentity', comm_data['data']['debug_op']) - self.assertAllClose([200.0], comm_data['data']['values']) - - # Proceed to breakpoint y:0. - self._serverGet('ack') - comm_response = self._serverGet('comm', {'pos': 3}) - comm_data = self._deserializeResponse(comm_response) - self.assertGreater(comm_data['timestamp'], 0) - self.assertEqual('tensor', comm_data['type']) - self.assertEqual('float32', comm_data['data']['dtype']) - self.assertEqual([1], comm_data['data']['shape']) - self.assertEqual('y', comm_data['data']['node_name']) - self.assertEqual(0, comm_data['data']['output_slot']) - self.assertEqual('DebugIdentity', comm_data['data']['debug_op']) - self.assertAllClose([230.0], comm_data['data']['values']) - - # Proceed to the end of the Session.run(). - self._serverGet('ack') - - session_run_thread.join() - self.assertAllClose([[230.0]], session_run_results) - - # Verify that the activated breakpoints are remembered. 
- breakpoints_response = self._serverGet( - 'gated_grpc', {'mode': 'breakpoints'}) - breakpoints_data = self._deserializeResponse(breakpoints_response) - self.assertItemsEqual( - [['x', 0, 'DebugIdentity'], ['y', 0, 'DebugIdentity']], - breakpoints_data) - - def testCommResponseOmitsLargeSizedTensorValues(self): - session_run_thread, session_run_results = ( - self._runSimpleAddMultiplyGraph(10)) - - comm_response = self._serverGet('comm', {'pos': 1}) - comm_data = self._deserializeResponse(comm_response) - self.assertGreater(comm_data['timestamp'], 0) - self.assertEqual('meta', comm_data['type']) - self.assertEqual({'run_key': ['', 'y:0', '']}, comm_data['data']) - - # Activate breakpoint for inc_a:0. - self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'x', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - - # Continue to the breakpiont at x:0. - self._serverGet('ack') - comm_response = self._serverGet('comm', {'pos': 2}) - comm_data = self._deserializeResponse(comm_response) - self.assertEqual('tensor', comm_data['type']) - self.assertEqual('float32', comm_data['data']['dtype']) - self.assertEqual([10], comm_data['data']['shape']) - self.assertEqual('x', comm_data['data']['node_name']) - self.assertEqual(0, comm_data['data']['output_slot']) - self.assertEqual('DebugIdentity', comm_data['data']['debug_op']) - # Verify that the large-sized tensor gets omitted in the comm response. - self.assertEqual(None, comm_data['data']['values']) - - # Use the /tensor_data endpoint to obtain the full value of x:0. - tensor_response = self._serverGet( - 'tensor_data', - {'watch_key': 'x:0:DebugIdentity', - 'time_indices': '-1', - 'mapping': '', - 'slicing': ''}) - tensor_data = self._deserializeResponse(tensor_response) - self.assertEqual(None, tensor_data['error']) - self.assertAllClose([[200.0] * 10], tensor_data['tensor_data']) - - # Use the /tensor_data endpoint to obtain the sliced value of x:0. 
- tensor_response = self._serverGet( - 'tensor_data', - {'watch_key': 'x:0:DebugIdentity', - 'time_indices': '-1', - 'mapping': '', - 'slicing': '[:5]'}) - tensor_data = self._deserializeResponse(tensor_response) - self.assertEqual(None, tensor_data['error']) - self.assertAllClose([[200.0] * 5], tensor_data['tensor_data']) - - # Continue to the end. - self._serverGet('ack') - - session_run_thread.join() - self.assertAllClose([[230.0] * 10], session_run_results) - - def testMultipleSessionRunsTensorValueFullHistory(self): - session_run_thread, session_run_results = ( - self._runMultiStepAssignAddGraph(2)) - - comm_response = self._serverGet('comm', {'pos': 1}) - comm_data = self._deserializeResponse(comm_response) - self.assertGreater(comm_data['timestamp'], 0) - self.assertEqual('meta', comm_data['type']) - self.assertEqual({'run_key': ['', 'inc_a:0', '']}, comm_data['data']) - - # Activate breakpoint for inc_a:0. - self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'inc_a', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - - # Continue to inc_a:0 for the 1st time. - self._serverGet('ack') - comm_response = self._serverGet('comm', {'pos': 2}) - comm_data = self._deserializeResponse(comm_response) - self.assertEqual('tensor', comm_data['type']) - self.assertEqual('int32', comm_data['data']['dtype']) - self.assertEqual([], comm_data['data']['shape']) - self.assertEqual('inc_a', comm_data['data']['node_name']) - self.assertEqual(0, comm_data['data']['output_slot']) - self.assertEqual('DebugIdentity', comm_data['data']['debug_op']) - self.assertAllClose(11.0, comm_data['data']['values']) - - # Call /tensor_data to get the full history of the inc_a tensor (so far). 
- tensor_response = self._serverGet( - 'tensor_data', - {'watch_key': 'inc_a:0:DebugIdentity', - 'time_indices': ':', - 'mapping': '', - 'slicing': ''}) - tensor_data = self._deserializeResponse(tensor_response) - self.assertEqual({'tensor_data': [11], 'error': None}, tensor_data) - - # Continue to the beginning of the 2nd session.run. - self._serverGet('ack') - comm_response = self._serverGet('comm', {'pos': 3}) - comm_data = self._deserializeResponse(comm_response) - self.assertGreater(comm_data['timestamp'], 0) - self.assertEqual('meta', comm_data['type']) - self.assertEqual({'run_key': ['', 'inc_a:0', '']}, comm_data['data']) - - # Continue to inc_a:0 for the 2nd time. - self._serverGet('ack') - comm_response = self._serverGet('comm', {'pos': 4}) - comm_data = self._deserializeResponse(comm_response) - self.assertEqual('tensor', comm_data['type']) - self.assertEqual('int32', comm_data['data']['dtype']) - self.assertEqual([], comm_data['data']['shape']) - self.assertEqual('inc_a', comm_data['data']['node_name']) - self.assertEqual(0, comm_data['data']['output_slot']) - self.assertEqual('DebugIdentity', comm_data['data']['debug_op']) - self.assertAllClose(12.0, comm_data['data']['values']) - - # Call /tensor_data to get the full history of the inc_a tensor. - tensor_response = self._serverGet( - 'tensor_data', - {'watch_key': 'inc_a:0:DebugIdentity', - 'time_indices': ':', - 'mapping': '', - 'slicing': ''}) - tensor_data = self._deserializeResponse(tensor_response) - self.assertEqual({'tensor_data': [11, 12], 'error': None}, tensor_data) - - # Call /tensor_data to get the latst time index of the inc_a tensor. - tensor_response = self._serverGet( - 'tensor_data', - {'watch_key': 'inc_a:0:DebugIdentity', - 'time_indices': '-1', - 'mapping': '', - 'slicing': ''}) - tensor_data = self._deserializeResponse(tensor_response) - self.assertEqual({'tensor_data': [12], 'error': None}, tensor_data) - - # Continue to the end. 
- self._serverGet('ack') - - session_run_thread.join() - self.assertAllClose([11.0, 12.0], session_run_results) - - def testSetBreakpointOnNoTensorOp(self): - session_run_thread, session_run_results = self._runTfGroupGraph() - - comm_response = self._serverGet('comm', {'pos': 1}) - comm_data = self._deserializeResponse(comm_response) - self.assertGreater(comm_data['timestamp'], 0) - self.assertEqual('meta', comm_data['type']) - self.assertEqual({'run_key': ['', '', 'inc_ab']}, comm_data['data']) - - # Activate breakpoint for inc_a:0. - self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'inc_a', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - - # Activate breakpoint for inc_ab. - self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'inc_ab', 'output_slot': -1, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - - # Continue to inc_a:0. - self._serverGet('ack') - comm_response = self._serverGet('comm', {'pos': 2}) - comm_data = self._deserializeResponse(comm_response) - self.assertEqual('tensor', comm_data['type']) - self.assertEqual('int32', comm_data['data']['dtype']) - self.assertEqual([], comm_data['data']['shape']) - self.assertEqual('inc_a', comm_data['data']['node_name']) - self.assertEqual(0, comm_data['data']['output_slot']) - self.assertEqual('DebugIdentity', comm_data['data']['debug_op']) - self.assertAllClose(11.0, comm_data['data']['values']) - - # Continue to the end. The breakpoint at inc_ab should not have blocked - # the execution, due to the fact that inc_ab is a tf.group op that produces - # no output. 
- self._serverGet('ack') - session_run_thread.join() - self.assertEqual([None], session_run_results) - - breakpoints_response = self._serverGet( - 'gated_grpc', {'mode': 'breakpoints'}) - breakpoints_data = self._deserializeResponse(breakpoints_response) - self.assertItemsEqual( - [['inc_a', 0, 'DebugIdentity'], ['inc_ab', -1, 'DebugIdentity']], - breakpoints_data) - - def testCommDataCanBeServedToMultipleClients(self): - session_run_thread, session_run_results = self._runSimpleAddMultiplyGraph() - - comm_response = self._serverGet('comm', {'pos': 1}) - comm_data_1 = self._deserializeResponse(comm_response) - - # Activate breakpoint for x:0. - self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'x', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - # Activate breakpoint for y:0. - self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'y', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - - # Proceed to breakpoint x:0. - self._serverGet('ack') - comm_response = self._serverGet('comm', {'pos': 2}) - comm_data_2 = self._deserializeResponse(comm_response) - self.assertGreater(comm_data_2['timestamp'], 0) - self.assertEqual('tensor', comm_data_2['type']) - self.assertEqual('float32', comm_data_2['data']['dtype']) - self.assertEqual([1], comm_data_2['data']['shape']) - self.assertEqual('x', comm_data_2['data']['node_name']) - self.assertEqual(0, comm_data_2['data']['output_slot']) - self.assertEqual('DebugIdentity', comm_data_2['data']['debug_op']) - self.assertAllClose([200.0], comm_data_2['data']['values']) - - # Proceed to breakpoint y:0. 
- self._serverGet('ack') - comm_response = self._serverGet('comm', {'pos': 3}) - comm_data_3 = self._deserializeResponse(comm_response) - self.assertGreater(comm_data_3['timestamp'], 0) - self.assertEqual('tensor', comm_data_3['type']) - self.assertEqual('float32', comm_data_3['data']['dtype']) - self.assertEqual([1], comm_data_3['data']['shape']) - self.assertEqual('y', comm_data_3['data']['node_name']) - self.assertEqual(0, comm_data_3['data']['output_slot']) - self.assertEqual('DebugIdentity', comm_data_3['data']['debug_op']) - self.assertAllClose([230.0], comm_data_3['data']['values']) - - # Proceed to the end of the Session.run(). - self._serverGet('ack') - - session_run_thread.join() - self.assertAllClose([[230.0]], session_run_results) - - # A 2nd client requests for comm data at positions 1, 2 and 3 again. - comm_response = self._serverGet('comm', {'pos': 1}) - self.assertEqual(comm_data_1, self._deserializeResponse(comm_response)) - comm_response = self._serverGet('comm', {'pos': 2}) - self.assertEqual(comm_data_2, self._deserializeResponse(comm_response)) - comm_response = self._serverGet('comm', {'pos': 3}) - self.assertEqual(comm_data_3, self._deserializeResponse(comm_response)) - - def testInvalidBreakpointStateLeadsTo400Response(self): - session_run_thread, session_run_results = self._runSimpleAddMultiplyGraph() - self._serverGet('comm', {'pos': 1}) - - # Use an invalid state ('bad_state') when setting a breakpoint state. 
- response = self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'x', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'bad_state'}, - expected_status_code=400) - data = self._deserializeResponse(response) - self.assertEqual('Unrecognized new state for x:0:DebugIdentity: bad_state', - data['error']) - - self._serverGet('ack') - session_run_thread.join() - self.assertAllClose([[230.0]], session_run_results) - - def testInvalidModeArgForGatedGrpcRouteLeadsTo400Response(self): - session_run_thread, session_run_results = self._runSimpleAddMultiplyGraph() - self._serverGet('comm', {'pos': 1}) - - # Use an invalid mode argument ('bad_mode') when calling the 'gated_grpc' - # endpoint. - response = self._serverGet( - 'gated_grpc', - {'mode': 'bad_mode', 'node_name': 'x', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}, - expected_status_code=400) - data = self._deserializeResponse(response) - self.assertEqual('Unrecognized mode for the gated_grpc route: bad_mode', - data['error']) - - self._serverGet('ack') - session_run_thread.join() - self.assertAllClose([[230.0]], session_run_results) - - def testDebuggerHostAndGrpcPortEndpoint(self): - response = self._serverGet('debugger_grpc_host_port') - response_data = self._deserializeResponse(response) - self.assertTrue(response_data['host']) - self.assertEqual(self._debugger_port, response_data['port']) - - def testGetSourceFilePaths(self): - session_run_thread, session_run_results = self._runSimpleAddMultiplyGraph() - self._serverGet('comm', {'pos': 1}) - - source_paths_response = self._serverGet('source_code', {'mode': 'paths'}) - response_data = self._deserializeResponse(source_paths_response) - self.assertIn(__file__, response_data['paths']) - - self._serverGet('ack') - session_run_thread.join() - self.assertAllClose([[230.0]], session_run_results) - - def testGetSourceFileContentWithValidFilePath(self): - session_run_thread, session_run_results = 
self._runSimpleAddMultiplyGraph() - self._serverGet('comm', {'pos': 1}) - - file_content_response = self._serverGet( - 'source_code', {'mode': 'content', 'file_path': __file__}) - response_data = self._deserializeResponse(file_content_response) - # Verify that the content of this file is included. - self.assertTrue(response_data['content'][__file__]) - # Verify that for the lines of the file that create TensorFlow ops, the list - # of op names and their stack heights are included. - op_linenos = collections.defaultdict(set) - for lineno in response_data['lineno_to_op_name_and_stack_pos']: - self.assertGreater(int(lineno), 0) - for op_name, stack_pos in response_data[ - 'lineno_to_op_name_and_stack_pos'][lineno]: - op_linenos[op_name].add(lineno) - self.assertGreaterEqual(stack_pos, 0) - self.assertTrue(op_linenos['a']) - self.assertTrue(op_linenos['a/Assign']) - self.assertTrue(op_linenos['a/initial_value']) - self.assertTrue(op_linenos['a/read']) - self.assertTrue(op_linenos['b']) - self.assertTrue(op_linenos['b/Assign']) - self.assertTrue(op_linenos['b/initial_value']) - self.assertTrue(op_linenos['b/read']) - self.assertTrue(op_linenos['c']) - self.assertTrue(op_linenos['c/Assign']) - self.assertTrue(op_linenos['c/initial_value']) - self.assertTrue(op_linenos['c/read']) - self.assertTrue(op_linenos['x']) - self.assertTrue(op_linenos['y']) - - self._serverGet('ack') - session_run_thread.join() - self.assertAllClose([[230.0]], session_run_results) - - def testGetSourceOpTraceback(self): - session_run_thread, session_run_results = self._runSimpleAddMultiplyGraph() - self._serverGet('comm', {'pos': 1}) - - for op_name in ('a', 'b', 'c', 'x', 'y'): - op_traceback_reponse = self._serverGet( - 'source_code', {'mode': 'op_traceback', 'op_name': op_name}) - response_data = self._deserializeResponse(op_traceback_reponse) - found_current_file = False - for file_path, lineno in response_data['op_traceback'][op_name]: - self.assertGreater(lineno, 0) - if file_path == 
__file__: - found_current_file = True - break - self.assertTrue(found_current_file) - - self._serverGet('ack') - session_run_thread.join() - self.assertAllClose([[230.0]], session_run_results) - - def _runInitializer(self): - session_run_results = [] - def session_run_job(): - with tf.Session() as sess: - a = tf.Variable([10.0] * 10, name='a') - sess = tf_debug.TensorBoardDebugWrapperSession(sess, self._debugger_url) - # Run the initializer with a debugger-wrapped tf.Session. - session_run_results.append(sess.run(a.initializer)) - session_run_results.append(sess.run(a)) - session_run_thread = threading.Thread(target=session_run_job) - session_run_thread.start() - return session_run_thread, session_run_results - - def testTensorDataForUnitializedTensorIsHandledCorrectly(self): - session_run_thread, session_run_results = self._runInitializer() - # Activate breakpoint for a:0. - self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'a', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - self._serverGet('ack') - self._serverGet('ack') - self._serverGet('ack') - self._serverGet('ack') - session_run_thread.join() - self.assertEqual(2, len(session_run_results)) - self.assertIsNone(session_run_results[0]) - self.assertAllClose([10.0] * 10, session_run_results[1]) - - # Get tensor data without slicing. - tensor_response = self._serverGet( - 'tensor_data', - {'watch_key': 'a:0:DebugIdentity', - 'time_indices': ':', - 'mapping': '', - 'slicing': ''}) - tensor_data = self._deserializeResponse(tensor_response) - self.assertIsNone(tensor_data['error']) - tensor_data = tensor_data['tensor_data'] - self.assertEqual(2, len(tensor_data)) - self.assertIsNone(tensor_data[0]) - self.assertAllClose([10.0] * 10, tensor_data[1]) - - # Get tensor data with slicing. 
- tensor_response = self._serverGet( - 'tensor_data', - {'watch_key': 'a:0:DebugIdentity', - 'time_indices': ':', - 'mapping': '', - 'slicing': '[:5]'}) - tensor_data = self._deserializeResponse(tensor_response) - self.assertIsNone(tensor_data['error']) - tensor_data = tensor_data['tensor_data'] - self.assertEqual(2, len(tensor_data)) - self.assertIsNone(tensor_data[0]) - self.assertAllClose([10.0] * 5, tensor_data[1]) - - def testCommDataForUninitializedTensorIsHandledCorrectly(self): - session_run_thread, _ = self._runInitializer() - # Activate breakpoint for a:0. - self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'a', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - self._serverGet('ack') - comm_response = self._serverGet('comm', {'pos': 2}) - comm_data = self._deserializeResponse(comm_response) - self.assertEqual('tensor', comm_data['type']) - self.assertEqual('Uninitialized', comm_data['data']['dtype']) - self.assertEqual('Uninitialized', comm_data['data']['shape']) - self.assertEqual('N/A', comm_data['data']['values']) - self.assertEqual( - 'a/(a)', comm_data['data']['maybe_base_expanded_node_name']) - self._serverGet('ack') - self._serverGet('ack') - self._serverGet('ack') - session_run_thread.join() - - def _runHealthPillNetwork(self): - session_run_results = [] - def session_run_job(): - with tf.Session() as sess: - a = tf.Variable( - [np.nan, np.inf, np.inf, -np.inf, -np.inf, -np.inf, 10, 20, 30], - dtype=tf.float32, name='a') - session_run_results.append(sess.run(a.initializer)) - sess = tf_debug.TensorBoardDebugWrapperSession(sess, self._debugger_url) - session_run_results.append(sess.run(a)) - session_run_thread = threading.Thread(target=session_run_job) - session_run_thread.start() - return session_run_thread, session_run_results - - def testHealthPill(self): - session_run_thread, _ = self._runHealthPillNetwork() - # Activate breakpoint for a:0. 
- self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'a', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - self._serverGet('ack') - self._serverGet('ack') - session_run_thread.join() - tensor_response = self._serverGet( - 'tensor_data', - {'watch_key': 'a:0:DebugIdentity', - 'time_indices': '-1', - 'mapping': 'health-pill', - 'slicing': ''}) - tensor_data = self._deserializeResponse(tensor_response) - self.assertIsNone(tensor_data['error']) - tensor_data = tensor_data['tensor_data'][0] - self.assertAllClose(1.0, tensor_data[0]) # IsInitialized. - self.assertAllClose(9.0, tensor_data[1]) # Total count. - self.assertAllClose(1.0, tensor_data[2]) # NaN count. - self.assertAllClose(3.0, tensor_data[3]) # -Infinity count. - self.assertAllClose(0.0, tensor_data[4]) # Finite negative count. - self.assertAllClose(0.0, tensor_data[5]) # Zero count. - self.assertAllClose(3.0, tensor_data[6]) # Positive count. - self.assertAllClose(2.0, tensor_data[7]) # +Infinity count. - self.assertAllClose(10.0, tensor_data[8]) # Min. - self.assertAllClose(30.0, tensor_data[9]) # Max. - self.assertAllClose(20.0, tensor_data[10]) # Mean. - self.assertAllClose( - np.var([10.0, 20.0, 30.0]), tensor_data[11]) # Variance. - - def _runAsciiStringNetwork(self): - session_run_results = [] - def session_run_job(): - with tf.Session() as sess: - str1 = tf.Variable('abc', name='str1') - str2 = tf.Variable('def', name='str2') - str_concat = tf.add(str1, str2, name='str_concat') - sess.run(tf.global_variables_initializer()) - sess = tf_debug.TensorBoardDebugWrapperSession(sess, self._debugger_url) - session_run_results.append(sess.run(str_concat)) - session_run_thread = threading.Thread(target=session_run_job) - session_run_thread.start() - return session_run_thread, session_run_results - - def testAsciiStringTensorIsHandledCorrectly(self): - session_run_thread, session_run_results = self._runAsciiStringNetwork() - # Activate breakpoint for str1:0. 
- self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'str1', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - self._serverGet('ack') - self._serverGet('ack') - comm_response = self._serverGet('comm', {'pos': 2}) - comm_data = self._deserializeResponse(comm_response) - self.assertEqual('tensor', comm_data['type']) - self.assertEqual('string', comm_data['data']['dtype']) - self.assertEqual([], comm_data['data']['shape']) - self.assertEqual('abc', comm_data['data']['values']) - self.assertEqual( - 'str1/(str1)', comm_data['data']['maybe_base_expanded_node_name']) - session_run_thread.join() - self.assertEqual(1, len(session_run_results)) - self.assertEqual(b"abcdef", session_run_results[0]) - - # Get the value of a tensor without mapping. - tensor_response = self._serverGet( - 'tensor_data', - {'watch_key': 'str1:0:DebugIdentity', - 'time_indices': '-1', - 'mapping': '', - 'slicing': ''}) - tensor_data = self._deserializeResponse(tensor_response) - self.assertEqual(None, tensor_data['error']) - self.assertEqual(['abc'], tensor_data['tensor_data']) - - # Get the health pill of a string tensor. 
- tensor_response = self._serverGet( - 'tensor_data', - {'watch_key': 'str1:0:DebugIdentity', - 'time_indices': '-1', - 'mapping': 'health-pill', - 'slicing': ''}) - tensor_data = self._deserializeResponse(tensor_response) - self.assertEqual(None, tensor_data['error']) - self.assertEqual([None], tensor_data['tensor_data']) - - def _runBinaryStringNetwork(self): - session_run_results = [] - def session_run_job(): - with tf.Session() as sess: - str1 = tf.Variable([b'\x01' * 3, b'\x02' * 3], name='str1') - str2 = tf.Variable([b'\x03' * 3, b'\x04' * 3], name='str2') - str_concat = tf.add(str1, str2, name='str_concat') - sess.run(tf.global_variables_initializer()) - sess = tf_debug.TensorBoardDebugWrapperSession(sess, self._debugger_url) - session_run_results.append(sess.run(str_concat)) - session_run_thread = threading.Thread(target=session_run_job) - session_run_thread.start() - return session_run_thread, session_run_results - - def testBinaryStringTensorIsHandledCorrectly(self): - session_run_thread, session_run_results = self._runBinaryStringNetwork() - # Activate breakpoint for str1:0. 
- self._serverGet( - 'gated_grpc', - {'mode': 'set_state', 'node_name': 'str1', 'output_slot': 0, - 'debug_op': 'DebugIdentity', 'state': 'break'}) - self._serverGet('ack') - self._serverGet('ack') - comm_response = self._serverGet('comm', {'pos': 2}) - comm_data = self._deserializeResponse(comm_response) - self.assertEqual('tensor', comm_data['type']) - self.assertEqual('string', comm_data['data']['dtype']) - self.assertEqual([2], comm_data['data']['shape']) - self.assertEqual(2, len(comm_data['data']['values'])) - self.assertEqual( - b'=01' * 3, tf.compat.as_bytes(comm_data['data']['values'][0])) - self.assertEqual( - b'=02' * 3, tf.compat.as_bytes(comm_data['data']['values'][1])) - self.assertEqual( - 'str1/(str1)', comm_data['data']['maybe_base_expanded_node_name']) - session_run_thread.join() - self.assertEqual(1, len(session_run_results)) - self.assertAllEqual( - np.array([b'\x01\x01\x01\x03\x03\x03', b'\x02\x02\x02\x04\x04\x04'], - dtype=np.object), - session_run_results[0]) - - # Get the value of a tensor without mapping. - tensor_response = self._serverGet( - 'tensor_data', - {'watch_key': 'str1:0:DebugIdentity', - 'time_indices': '-1', - 'mapping': '', - 'slicing': ''}) - tensor_data = self._deserializeResponse(tensor_response) - self.assertEqual(None, tensor_data['error']) - self.assertEqual(2, len(tensor_data['tensor_data'][0])) - self.assertEqual( - b'=01=01=01', tf.compat.as_bytes(tensor_data['tensor_data'][0][0])) - self.assertEqual( - b'=02=02=02', tf.compat.as_bytes(tensor_data['tensor_data'][0][1])) - - # Get the health pill of a string tensor. 
- tensor_response = self._serverGet( - 'tensor_data', - {'watch_key': 'str1:0:DebugIdentity', - 'time_indices': '-1', - 'mapping': 'health-pill', - 'slicing': ''}) - tensor_data = self._deserializeResponse(tensor_response) - self.assertEqual(None, tensor_data['error']) - self.assertEqual([None], tensor_data['tensor_data']) + def setUp(self): + super(InteractiveDebuggerPluginTest, self).setUp() + + self._dummy_logdir = tempfile.mkdtemp() + dummy_multiplexer = event_multiplexer.EventMultiplexer({}) + self._debugger_port = portpicker.pick_unused_port() + self._debugger_url = "grpc://localhost:%d" % self._debugger_port + context = base_plugin.TBContext( + logdir=self._dummy_logdir, multiplexer=dummy_multiplexer + ) + self._debugger_plugin = interactive_debugger_plugin.InteractiveDebuggerPlugin( + context + ) + self._debugger_plugin.listen(self._debugger_port) + + wsgi_app = application.TensorBoardWSGI([self._debugger_plugin]) + self._server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) + + def tearDown(self): + # In some cases (e.g., an empty test method body), the stop_server() method + # may get called before the server is started, leading to a ValueError. + while True: + try: + self._debugger_plugin._debugger_data_server.stop_server() + break + except ValueError: + pass + shutil.rmtree(self._dummy_logdir, ignore_errors=True) + super(InteractiveDebuggerPluginTest, self).tearDown() + + def _serverGet(self, path, params=None, expected_status_code=200): + """Send the server a GET request and obtain the response. + + Args: + path: URL path (excluding the prefix), without parameters encoded. + params: Query parameters to be encoded in the URL, as a dict. + expected_status_code: Expected status code. + + Returns: + Response from server. + """ + url = _SERVER_URL_PREFIX + path + if params: + url += "?" 
+ urllib.parse.urlencode(params) + response = self._server.get(url) + self.assertEqual(expected_status_code, response.status_code) + return response + + def _deserializeResponse(self, response): + """Deserializes byte content that is a JSON encoding. + + Args: + response: A response object. + + Returns: + The deserialized python object decoded from JSON. + """ + return json.loads(response.get_data().decode("utf-8")) + + def _runSimpleAddMultiplyGraph(self, variable_size=1): + session_run_results = [] + + def session_run_job(): + with tf.Session() as sess: + a = tf.Variable([10.0] * variable_size, name="a") + b = tf.Variable([20.0] * variable_size, name="b") + c = tf.Variable([30.0] * variable_size, name="c") + x = tf.multiply(a, b, name="x") + y = tf.add(x, c, name="y") + + sess.run(tf.global_variables_initializer()) + + sess = tf_debug.TensorBoardDebugWrapperSession( + sess, self._debugger_url + ) + session_run_results.append(sess.run(y)) + + session_run_thread = threading.Thread(target=session_run_job) + session_run_thread.start() + return session_run_thread, session_run_results + + def _runMultiStepAssignAddGraph(self, steps): + session_run_results = [] + + def session_run_job(): + with tf.Session() as sess: + a = tf.Variable(10, dtype=tf.int32, name="a") + b = tf.Variable(1, dtype=tf.int32, name="b") + inc_a = tf.assign_add(a, b, name="inc_a") + + sess.run(tf.global_variables_initializer()) + + sess = tf_debug.TensorBoardDebugWrapperSession( + sess, self._debugger_url + ) + for _ in range(steps): + session_run_results.append(sess.run(inc_a)) + + session_run_thread = threading.Thread(target=session_run_job) + session_run_thread.start() + return session_run_thread, session_run_results + + def _runTfGroupGraph(self): + session_run_results = [] + + def session_run_job(): + with tf.Session() as sess: + a = tf.Variable(10, dtype=tf.int32, name="a") + b = tf.Variable(20, dtype=tf.int32, name="b") + d = tf.constant(1, dtype=tf.int32, name="d") + inc_a = 
tf.assign_add(a, d, name="inc_a") + inc_b = tf.assign_add(b, d, name="inc_b") + inc_ab = tf.group([inc_a, inc_b], name="inc_ab") + + sess.run(tf.global_variables_initializer()) + + sess = tf_debug.TensorBoardDebugWrapperSession( + sess, self._debugger_url + ) + session_run_results.append(sess.run(inc_ab)) + + session_run_thread = threading.Thread(target=session_run_job) + session_run_thread.start() + return session_run_thread, session_run_results + + def testCommAndAckWithoutBreakpoints(self): + ( + session_run_thread, + session_run_results, + ) = self._runSimpleAddMultiplyGraph() + + comm_response = self._serverGet("comm", {"pos": 1}) + response_data = self._deserializeResponse(comm_response) + self.assertGreater(response_data["timestamp"], 0) + self.assertEqual("meta", response_data["type"]) + self.assertEqual({"run_key": ["", "y:0", ""]}, response_data["data"]) + + self._serverGet("ack") + + session_run_thread.join() + self.assertAllClose([[230.0]], session_run_results) + + def testGetDeviceNamesAndDebuggerGraph(self): + ( + session_run_thread, + session_run_results, + ) = self._runSimpleAddMultiplyGraph() + + comm_response = self._serverGet("comm", {"pos": 1}) + response_data = self._deserializeResponse(comm_response) + run_key = json.dumps(response_data["data"]["run_key"]) + + device_names_response = self._serverGet( + "gated_grpc", {"mode": "retrieve_device_names", "run_key": run_key} + ) + device_names_data = self._deserializeResponse(device_names_response) + self.assertEqual(1, len(device_names_data["device_names"])) + device_name = device_names_data["device_names"][0] + + graph_response = self._serverGet( + "debugger_graph", {"run_key": run_key, "device_name": device_name} + ) + self.assertTrue(graph_response.get_data()) + + self._serverGet("ack") + + session_run_thread.join() + self.assertAllClose([[230.0]], session_run_results) + + def testRetrieveAllGatedGrpcTensors(self): + ( + session_run_thread, + session_run_results, + ) = 
self._runSimpleAddMultiplyGraph() + + comm_response = self._serverGet("comm", {"pos": 1}) + response_data = self._deserializeResponse(comm_response) + run_key = json.dumps(response_data["data"]["run_key"]) + + retrieve_all_response = self._serverGet( + "gated_grpc", {"mode": "retrieve_all", "run_key": run_key} + ) + retrieve_all_data = self._deserializeResponse(retrieve_all_response) + self.assertTrue(retrieve_all_data["device_names"]) + # No breakpoints have been activated. + self.assertEqual([], retrieve_all_data["breakpoints"]) + device_name = retrieve_all_data["device_names"][0] + tensor_names = [ + item[0] + for item in retrieve_all_data["gated_grpc_tensors"][device_name] + ] + self.assertItemsEqual( + ["a", "a/read", "b", "b/read", "x", "c", "c/read", "y"], + tensor_names, + ) + + self._serverGet("ack") + + session_run_thread.join() + self.assertAllClose([[230.0]], session_run_results) + + def testActivateOneBreakpoint(self): + ( + session_run_thread, + session_run_results, + ) = self._runSimpleAddMultiplyGraph() + + comm_response = self._serverGet("comm", {"pos": 1}) + + # Activate breakpoint for x:0. + self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "x", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "break", + }, + ) + + # Proceed to breakpoint x:0. + self._serverGet("ack") + comm_response = self._serverGet("comm", {"pos": 2}) + comm_data = self._deserializeResponse(comm_response) + self.assertGreater(comm_data["timestamp"], 0) + self.assertEqual("tensor", comm_data["type"]) + self.assertEqual("float32", comm_data["data"]["dtype"]) + self.assertEqual([1], comm_data["data"]["shape"]) + self.assertEqual("x", comm_data["data"]["node_name"]) + self.assertEqual(0, comm_data["data"]["output_slot"]) + self.assertEqual("DebugIdentity", comm_data["data"]["debug_op"]) + self.assertAllClose([200.0], comm_data["data"]["values"]) + + # Proceed to the end of the Session.run(). 
+ self._serverGet("ack") + + session_run_thread.join() + self.assertAllClose([[230.0]], session_run_results) + + # Verify that the activated breakpoint is remembered. + breakpoints_response = self._serverGet( + "gated_grpc", {"mode": "breakpoints"} + ) + breakpoints_data = self._deserializeResponse(breakpoints_response) + self.assertEqual([["x", 0, "DebugIdentity"]], breakpoints_data) + + def testActivateAndDeactivateOneBreakpoint(self): + ( + session_run_thread, + session_run_results, + ) = self._runSimpleAddMultiplyGraph() + + self._serverGet("comm", {"pos": 1}) + + # Activate breakpoint for x:0. + self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "x", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "break", + }, + ) + + # Deactivate the breakpoint right away. + self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "x", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "disable", + }, + ) + + # Proceed to the end of the Session.run(). + self._serverGet("ack") + + session_run_thread.join() + self.assertAllClose([[230.0]], session_run_results) + + # Verify that there is no breakpoint activated. + breakpoints_response = self._serverGet( + "gated_grpc", {"mode": "breakpoints"} + ) + breakpoints_data = self._deserializeResponse(breakpoints_response) + self.assertEqual([], breakpoints_data) + + def testActivateTwoBreakpoints(self): + ( + session_run_thread, + session_run_results, + ) = self._runSimpleAddMultiplyGraph() + + comm_response = self._serverGet("comm", {"pos": 1}) + + # Activate breakpoint for x:0. + self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "x", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "break", + }, + ) + # Activate breakpoint for y:0. + self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "y", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "break", + }, + ) + + # Proceed to breakpoint x:0. 
+ self._serverGet("ack") + comm_response = self._serverGet("comm", {"pos": 2}) + comm_data = self._deserializeResponse(comm_response) + self.assertGreater(comm_data["timestamp"], 0) + self.assertEqual("tensor", comm_data["type"]) + self.assertEqual("float32", comm_data["data"]["dtype"]) + self.assertEqual([1], comm_data["data"]["shape"]) + self.assertEqual("x", comm_data["data"]["node_name"]) + self.assertEqual(0, comm_data["data"]["output_slot"]) + self.assertEqual("DebugIdentity", comm_data["data"]["debug_op"]) + self.assertAllClose([200.0], comm_data["data"]["values"]) + + # Proceed to breakpoint y:0. + self._serverGet("ack") + comm_response = self._serverGet("comm", {"pos": 3}) + comm_data = self._deserializeResponse(comm_response) + self.assertGreater(comm_data["timestamp"], 0) + self.assertEqual("tensor", comm_data["type"]) + self.assertEqual("float32", comm_data["data"]["dtype"]) + self.assertEqual([1], comm_data["data"]["shape"]) + self.assertEqual("y", comm_data["data"]["node_name"]) + self.assertEqual(0, comm_data["data"]["output_slot"]) + self.assertEqual("DebugIdentity", comm_data["data"]["debug_op"]) + self.assertAllClose([230.0], comm_data["data"]["values"]) + + # Proceed to the end of the Session.run(). + self._serverGet("ack") + + session_run_thread.join() + self.assertAllClose([[230.0]], session_run_results) + + # Verify that the activated breakpoints are remembered. 
+        breakpoints_response = self._serverGet(
+            "gated_grpc", {"mode": "breakpoints"}
+        )
+        breakpoints_data = self._deserializeResponse(breakpoints_response)
+        self.assertItemsEqual(
+            [["x", 0, "DebugIdentity"], ["y", 0, "DebugIdentity"]],
+            breakpoints_data,
+        )
+
+    def testCommResponseOmitsLargeSizedTensorValues(self):
+        (
+            session_run_thread,
+            session_run_results,
+        ) = self._runSimpleAddMultiplyGraph(10)
+
+        comm_response = self._serverGet("comm", {"pos": 1})
+        comm_data = self._deserializeResponse(comm_response)
+        self.assertGreater(comm_data["timestamp"], 0)
+        self.assertEqual("meta", comm_data["type"])
+        self.assertEqual({"run_key": ["", "y:0", ""]}, comm_data["data"])
+
+        # Activate breakpoint for x:0.
+        self._serverGet(
+            "gated_grpc",
+            {
+                "mode": "set_state",
+                "node_name": "x",
+                "output_slot": 0,
+                "debug_op": "DebugIdentity",
+                "state": "break",
+            },
+        )
+
+        # Continue to the breakpoint at x:0.
+        self._serverGet("ack")
+        comm_response = self._serverGet("comm", {"pos": 2})
+        comm_data = self._deserializeResponse(comm_response)
+        self.assertEqual("tensor", comm_data["type"])
+        self.assertEqual("float32", comm_data["data"]["dtype"])
+        self.assertEqual([10], comm_data["data"]["shape"])
+        self.assertEqual("x", comm_data["data"]["node_name"])
+        self.assertEqual(0, comm_data["data"]["output_slot"])
+        self.assertEqual("DebugIdentity", comm_data["data"]["debug_op"])
+        # Verify that the large-sized tensor gets omitted in the comm response.
+        self.assertEqual(None, comm_data["data"]["values"])
+
+        # Use the /tensor_data endpoint to obtain the full value of x:0.
+        tensor_response = self._serverGet(
+            "tensor_data",
+            {
+                "watch_key": "x:0:DebugIdentity",
+                "time_indices": "-1",
+                "mapping": "",
+                "slicing": "",
+            },
+        )
+        tensor_data = self._deserializeResponse(tensor_response)
+        self.assertEqual(None, tensor_data["error"])
+        self.assertAllClose([[200.0] * 10], tensor_data["tensor_data"])
+
+        # Use the /tensor_data endpoint to obtain the sliced value of x:0.
+ tensor_response = self._serverGet( + "tensor_data", + { + "watch_key": "x:0:DebugIdentity", + "time_indices": "-1", + "mapping": "", + "slicing": "[:5]", + }, + ) + tensor_data = self._deserializeResponse(tensor_response) + self.assertEqual(None, tensor_data["error"]) + self.assertAllClose([[200.0] * 5], tensor_data["tensor_data"]) + + # Continue to the end. + self._serverGet("ack") + + session_run_thread.join() + self.assertAllClose([[230.0] * 10], session_run_results) + + def testMultipleSessionRunsTensorValueFullHistory(self): + ( + session_run_thread, + session_run_results, + ) = self._runMultiStepAssignAddGraph(2) + + comm_response = self._serverGet("comm", {"pos": 1}) + comm_data = self._deserializeResponse(comm_response) + self.assertGreater(comm_data["timestamp"], 0) + self.assertEqual("meta", comm_data["type"]) + self.assertEqual({"run_key": ["", "inc_a:0", ""]}, comm_data["data"]) + + # Activate breakpoint for inc_a:0. + self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "inc_a", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "break", + }, + ) + + # Continue to inc_a:0 for the 1st time. + self._serverGet("ack") + comm_response = self._serverGet("comm", {"pos": 2}) + comm_data = self._deserializeResponse(comm_response) + self.assertEqual("tensor", comm_data["type"]) + self.assertEqual("int32", comm_data["data"]["dtype"]) + self.assertEqual([], comm_data["data"]["shape"]) + self.assertEqual("inc_a", comm_data["data"]["node_name"]) + self.assertEqual(0, comm_data["data"]["output_slot"]) + self.assertEqual("DebugIdentity", comm_data["data"]["debug_op"]) + self.assertAllClose(11.0, comm_data["data"]["values"]) + + # Call /tensor_data to get the full history of the inc_a tensor (so far). 
+        tensor_response = self._serverGet(
+            "tensor_data",
+            {
+                "watch_key": "inc_a:0:DebugIdentity",
+                "time_indices": ":",
+                "mapping": "",
+                "slicing": "",
+            },
+        )
+        tensor_data = self._deserializeResponse(tensor_response)
+        self.assertEqual({"tensor_data": [11], "error": None}, tensor_data)
+
+        # Continue to the beginning of the 2nd session.run.
+        self._serverGet("ack")
+        comm_response = self._serverGet("comm", {"pos": 3})
+        comm_data = self._deserializeResponse(comm_response)
+        self.assertGreater(comm_data["timestamp"], 0)
+        self.assertEqual("meta", comm_data["type"])
+        self.assertEqual({"run_key": ["", "inc_a:0", ""]}, comm_data["data"])
+
+        # Continue to inc_a:0 for the 2nd time.
+        self._serverGet("ack")
+        comm_response = self._serverGet("comm", {"pos": 4})
+        comm_data = self._deserializeResponse(comm_response)
+        self.assertEqual("tensor", comm_data["type"])
+        self.assertEqual("int32", comm_data["data"]["dtype"])
+        self.assertEqual([], comm_data["data"]["shape"])
+        self.assertEqual("inc_a", comm_data["data"]["node_name"])
+        self.assertEqual(0, comm_data["data"]["output_slot"])
+        self.assertEqual("DebugIdentity", comm_data["data"]["debug_op"])
+        self.assertAllClose(12.0, comm_data["data"]["values"])
+
+        # Call /tensor_data to get the full history of the inc_a tensor.
+        tensor_response = self._serverGet(
+            "tensor_data",
+            {
+                "watch_key": "inc_a:0:DebugIdentity",
+                "time_indices": ":",
+                "mapping": "",
+                "slicing": "",
+            },
+        )
+        tensor_data = self._deserializeResponse(tensor_response)
+        self.assertEqual({"tensor_data": [11, 12], "error": None}, tensor_data)
+
+        # Call /tensor_data to get the latest time index of the inc_a tensor.
+        tensor_response = self._serverGet(
+            "tensor_data",
+            {
+                "watch_key": "inc_a:0:DebugIdentity",
+                "time_indices": "-1",
+                "mapping": "",
+                "slicing": "",
+            },
+        )
+        tensor_data = self._deserializeResponse(tensor_response)
+        self.assertEqual({"tensor_data": [12], "error": None}, tensor_data)
+
+        # Continue to the end.
+ self._serverGet("ack") + + session_run_thread.join() + self.assertAllClose([11.0, 12.0], session_run_results) + + def testSetBreakpointOnNoTensorOp(self): + session_run_thread, session_run_results = self._runTfGroupGraph() + + comm_response = self._serverGet("comm", {"pos": 1}) + comm_data = self._deserializeResponse(comm_response) + self.assertGreater(comm_data["timestamp"], 0) + self.assertEqual("meta", comm_data["type"]) + self.assertEqual({"run_key": ["", "", "inc_ab"]}, comm_data["data"]) + + # Activate breakpoint for inc_a:0. + self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "inc_a", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "break", + }, + ) + + # Activate breakpoint for inc_ab. + self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "inc_ab", + "output_slot": -1, + "debug_op": "DebugIdentity", + "state": "break", + }, + ) + + # Continue to inc_a:0. + self._serverGet("ack") + comm_response = self._serverGet("comm", {"pos": 2}) + comm_data = self._deserializeResponse(comm_response) + self.assertEqual("tensor", comm_data["type"]) + self.assertEqual("int32", comm_data["data"]["dtype"]) + self.assertEqual([], comm_data["data"]["shape"]) + self.assertEqual("inc_a", comm_data["data"]["node_name"]) + self.assertEqual(0, comm_data["data"]["output_slot"]) + self.assertEqual("DebugIdentity", comm_data["data"]["debug_op"]) + self.assertAllClose(11.0, comm_data["data"]["values"]) + + # Continue to the end. The breakpoint at inc_ab should not have blocked + # the execution, due to the fact that inc_ab is a tf.group op that produces + # no output. 
+ self._serverGet("ack") + session_run_thread.join() + self.assertEqual([None], session_run_results) + + breakpoints_response = self._serverGet( + "gated_grpc", {"mode": "breakpoints"} + ) + breakpoints_data = self._deserializeResponse(breakpoints_response) + self.assertItemsEqual( + [["inc_a", 0, "DebugIdentity"], ["inc_ab", -1, "DebugIdentity"]], + breakpoints_data, + ) + + def testCommDataCanBeServedToMultipleClients(self): + ( + session_run_thread, + session_run_results, + ) = self._runSimpleAddMultiplyGraph() + + comm_response = self._serverGet("comm", {"pos": 1}) + comm_data_1 = self._deserializeResponse(comm_response) + + # Activate breakpoint for x:0. + self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "x", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "break", + }, + ) + # Activate breakpoint for y:0. + self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "y", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "break", + }, + ) + + # Proceed to breakpoint x:0. + self._serverGet("ack") + comm_response = self._serverGet("comm", {"pos": 2}) + comm_data_2 = self._deserializeResponse(comm_response) + self.assertGreater(comm_data_2["timestamp"], 0) + self.assertEqual("tensor", comm_data_2["type"]) + self.assertEqual("float32", comm_data_2["data"]["dtype"]) + self.assertEqual([1], comm_data_2["data"]["shape"]) + self.assertEqual("x", comm_data_2["data"]["node_name"]) + self.assertEqual(0, comm_data_2["data"]["output_slot"]) + self.assertEqual("DebugIdentity", comm_data_2["data"]["debug_op"]) + self.assertAllClose([200.0], comm_data_2["data"]["values"]) + + # Proceed to breakpoint y:0. 
+ self._serverGet("ack") + comm_response = self._serverGet("comm", {"pos": 3}) + comm_data_3 = self._deserializeResponse(comm_response) + self.assertGreater(comm_data_3["timestamp"], 0) + self.assertEqual("tensor", comm_data_3["type"]) + self.assertEqual("float32", comm_data_3["data"]["dtype"]) + self.assertEqual([1], comm_data_3["data"]["shape"]) + self.assertEqual("y", comm_data_3["data"]["node_name"]) + self.assertEqual(0, comm_data_3["data"]["output_slot"]) + self.assertEqual("DebugIdentity", comm_data_3["data"]["debug_op"]) + self.assertAllClose([230.0], comm_data_3["data"]["values"]) + + # Proceed to the end of the Session.run(). + self._serverGet("ack") + + session_run_thread.join() + self.assertAllClose([[230.0]], session_run_results) + + # A 2nd client requests for comm data at positions 1, 2 and 3 again. + comm_response = self._serverGet("comm", {"pos": 1}) + self.assertEqual(comm_data_1, self._deserializeResponse(comm_response)) + comm_response = self._serverGet("comm", {"pos": 2}) + self.assertEqual(comm_data_2, self._deserializeResponse(comm_response)) + comm_response = self._serverGet("comm", {"pos": 3}) + self.assertEqual(comm_data_3, self._deserializeResponse(comm_response)) + + def testInvalidBreakpointStateLeadsTo400Response(self): + ( + session_run_thread, + session_run_results, + ) = self._runSimpleAddMultiplyGraph() + self._serverGet("comm", {"pos": 1}) + + # Use an invalid state ('bad_state') when setting a breakpoint state. 
+ response = self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "x", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "bad_state", + }, + expected_status_code=400, + ) + data = self._deserializeResponse(response) + self.assertEqual( + "Unrecognized new state for x:0:DebugIdentity: bad_state", + data["error"], + ) + + self._serverGet("ack") + session_run_thread.join() + self.assertAllClose([[230.0]], session_run_results) + + def testInvalidModeArgForGatedGrpcRouteLeadsTo400Response(self): + ( + session_run_thread, + session_run_results, + ) = self._runSimpleAddMultiplyGraph() + self._serverGet("comm", {"pos": 1}) + + # Use an invalid mode argument ('bad_mode') when calling the 'gated_grpc' + # endpoint. + response = self._serverGet( + "gated_grpc", + { + "mode": "bad_mode", + "node_name": "x", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "break", + }, + expected_status_code=400, + ) + data = self._deserializeResponse(response) + self.assertEqual( + "Unrecognized mode for the gated_grpc route: bad_mode", + data["error"], + ) + + self._serverGet("ack") + session_run_thread.join() + self.assertAllClose([[230.0]], session_run_results) + + def testDebuggerHostAndGrpcPortEndpoint(self): + response = self._serverGet("debugger_grpc_host_port") + response_data = self._deserializeResponse(response) + self.assertTrue(response_data["host"]) + self.assertEqual(self._debugger_port, response_data["port"]) + + def testGetSourceFilePaths(self): + ( + session_run_thread, + session_run_results, + ) = self._runSimpleAddMultiplyGraph() + self._serverGet("comm", {"pos": 1}) + + source_paths_response = self._serverGet( + "source_code", {"mode": "paths"} + ) + response_data = self._deserializeResponse(source_paths_response) + self.assertIn(__file__, response_data["paths"]) + + self._serverGet("ack") + session_run_thread.join() + self.assertAllClose([[230.0]], session_run_results) + + def testGetSourceFileContentWithValidFilePath(self): 
+ ( + session_run_thread, + session_run_results, + ) = self._runSimpleAddMultiplyGraph() + self._serverGet("comm", {"pos": 1}) + + file_content_response = self._serverGet( + "source_code", {"mode": "content", "file_path": __file__} + ) + response_data = self._deserializeResponse(file_content_response) + # Verify that the content of this file is included. + self.assertTrue(response_data["content"][__file__]) + # Verify that for the lines of the file that create TensorFlow ops, the list + # of op names and their stack heights are included. + op_linenos = collections.defaultdict(set) + for lineno in response_data["lineno_to_op_name_and_stack_pos"]: + self.assertGreater(int(lineno), 0) + for op_name, stack_pos in response_data[ + "lineno_to_op_name_and_stack_pos" + ][lineno]: + op_linenos[op_name].add(lineno) + self.assertGreaterEqual(stack_pos, 0) + self.assertTrue(op_linenos["a"]) + self.assertTrue(op_linenos["a/Assign"]) + self.assertTrue(op_linenos["a/initial_value"]) + self.assertTrue(op_linenos["a/read"]) + self.assertTrue(op_linenos["b"]) + self.assertTrue(op_linenos["b/Assign"]) + self.assertTrue(op_linenos["b/initial_value"]) + self.assertTrue(op_linenos["b/read"]) + self.assertTrue(op_linenos["c"]) + self.assertTrue(op_linenos["c/Assign"]) + self.assertTrue(op_linenos["c/initial_value"]) + self.assertTrue(op_linenos["c/read"]) + self.assertTrue(op_linenos["x"]) + self.assertTrue(op_linenos["y"]) + + self._serverGet("ack") + session_run_thread.join() + self.assertAllClose([[230.0]], session_run_results) + + def testGetSourceOpTraceback(self): + ( + session_run_thread, + session_run_results, + ) = self._runSimpleAddMultiplyGraph() + self._serverGet("comm", {"pos": 1}) + + for op_name in ("a", "b", "c", "x", "y"): + op_traceback_reponse = self._serverGet( + "source_code", {"mode": "op_traceback", "op_name": op_name} + ) + response_data = self._deserializeResponse(op_traceback_reponse) + found_current_file = False + for file_path, lineno in 
response_data["op_traceback"][op_name]: + self.assertGreater(lineno, 0) + if file_path == __file__: + found_current_file = True + break + self.assertTrue(found_current_file) + + self._serverGet("ack") + session_run_thread.join() + self.assertAllClose([[230.0]], session_run_results) + + def _runInitializer(self): + session_run_results = [] + + def session_run_job(): + with tf.Session() as sess: + a = tf.Variable([10.0] * 10, name="a") + sess = tf_debug.TensorBoardDebugWrapperSession( + sess, self._debugger_url + ) + # Run the initializer with a debugger-wrapped tf.Session. + session_run_results.append(sess.run(a.initializer)) + session_run_results.append(sess.run(a)) + + session_run_thread = threading.Thread(target=session_run_job) + session_run_thread.start() + return session_run_thread, session_run_results + + def testTensorDataForUnitializedTensorIsHandledCorrectly(self): + session_run_thread, session_run_results = self._runInitializer() + # Activate breakpoint for a:0. + self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "a", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "break", + }, + ) + self._serverGet("ack") + self._serverGet("ack") + self._serverGet("ack") + self._serverGet("ack") + session_run_thread.join() + self.assertEqual(2, len(session_run_results)) + self.assertIsNone(session_run_results[0]) + self.assertAllClose([10.0] * 10, session_run_results[1]) + + # Get tensor data without slicing. + tensor_response = self._serverGet( + "tensor_data", + { + "watch_key": "a:0:DebugIdentity", + "time_indices": ":", + "mapping": "", + "slicing": "", + }, + ) + tensor_data = self._deserializeResponse(tensor_response) + self.assertIsNone(tensor_data["error"]) + tensor_data = tensor_data["tensor_data"] + self.assertEqual(2, len(tensor_data)) + self.assertIsNone(tensor_data[0]) + self.assertAllClose([10.0] * 10, tensor_data[1]) + + # Get tensor data with slicing. 
+ tensor_response = self._serverGet( + "tensor_data", + { + "watch_key": "a:0:DebugIdentity", + "time_indices": ":", + "mapping": "", + "slicing": "[:5]", + }, + ) + tensor_data = self._deserializeResponse(tensor_response) + self.assertIsNone(tensor_data["error"]) + tensor_data = tensor_data["tensor_data"] + self.assertEqual(2, len(tensor_data)) + self.assertIsNone(tensor_data[0]) + self.assertAllClose([10.0] * 5, tensor_data[1]) + + def testCommDataForUninitializedTensorIsHandledCorrectly(self): + session_run_thread, _ = self._runInitializer() + # Activate breakpoint for a:0. + self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "a", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "break", + }, + ) + self._serverGet("ack") + comm_response = self._serverGet("comm", {"pos": 2}) + comm_data = self._deserializeResponse(comm_response) + self.assertEqual("tensor", comm_data["type"]) + self.assertEqual("Uninitialized", comm_data["data"]["dtype"]) + self.assertEqual("Uninitialized", comm_data["data"]["shape"]) + self.assertEqual("N/A", comm_data["data"]["values"]) + self.assertEqual( + "a/(a)", comm_data["data"]["maybe_base_expanded_node_name"] + ) + self._serverGet("ack") + self._serverGet("ack") + self._serverGet("ack") + session_run_thread.join() + + def _runHealthPillNetwork(self): + session_run_results = [] + + def session_run_job(): + with tf.Session() as sess: + a = tf.Variable( + [ + np.nan, + np.inf, + np.inf, + -np.inf, + -np.inf, + -np.inf, + 10, + 20, + 30, + ], + dtype=tf.float32, + name="a", + ) + session_run_results.append(sess.run(a.initializer)) + sess = tf_debug.TensorBoardDebugWrapperSession( + sess, self._debugger_url + ) + session_run_results.append(sess.run(a)) + + session_run_thread = threading.Thread(target=session_run_job) + session_run_thread.start() + return session_run_thread, session_run_results + + def testHealthPill(self): + session_run_thread, _ = self._runHealthPillNetwork() + # Activate breakpoint for 
a:0. + self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "a", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "break", + }, + ) + self._serverGet("ack") + self._serverGet("ack") + session_run_thread.join() + tensor_response = self._serverGet( + "tensor_data", + { + "watch_key": "a:0:DebugIdentity", + "time_indices": "-1", + "mapping": "health-pill", + "slicing": "", + }, + ) + tensor_data = self._deserializeResponse(tensor_response) + self.assertIsNone(tensor_data["error"]) + tensor_data = tensor_data["tensor_data"][0] + self.assertAllClose(1.0, tensor_data[0]) # IsInitialized. + self.assertAllClose(9.0, tensor_data[1]) # Total count. + self.assertAllClose(1.0, tensor_data[2]) # NaN count. + self.assertAllClose(3.0, tensor_data[3]) # -Infinity count. + self.assertAllClose(0.0, tensor_data[4]) # Finite negative count. + self.assertAllClose(0.0, tensor_data[5]) # Zero count. + self.assertAllClose(3.0, tensor_data[6]) # Positive count. + self.assertAllClose(2.0, tensor_data[7]) # +Infinity count. + self.assertAllClose(10.0, tensor_data[8]) # Min. + self.assertAllClose(30.0, tensor_data[9]) # Max. + self.assertAllClose(20.0, tensor_data[10]) # Mean. + self.assertAllClose( + np.var([10.0, 20.0, 30.0]), tensor_data[11] + ) # Variance. 
+ + def _runAsciiStringNetwork(self): + session_run_results = [] + + def session_run_job(): + with tf.Session() as sess: + str1 = tf.Variable("abc", name="str1") + str2 = tf.Variable("def", name="str2") + str_concat = tf.add(str1, str2, name="str_concat") + sess.run(tf.global_variables_initializer()) + sess = tf_debug.TensorBoardDebugWrapperSession( + sess, self._debugger_url + ) + session_run_results.append(sess.run(str_concat)) + + session_run_thread = threading.Thread(target=session_run_job) + session_run_thread.start() + return session_run_thread, session_run_results + + def testAsciiStringTensorIsHandledCorrectly(self): + session_run_thread, session_run_results = self._runAsciiStringNetwork() + # Activate breakpoint for str1:0. + self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "str1", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "break", + }, + ) + self._serverGet("ack") + self._serverGet("ack") + comm_response = self._serverGet("comm", {"pos": 2}) + comm_data = self._deserializeResponse(comm_response) + self.assertEqual("tensor", comm_data["type"]) + self.assertEqual("string", comm_data["data"]["dtype"]) + self.assertEqual([], comm_data["data"]["shape"]) + self.assertEqual("abc", comm_data["data"]["values"]) + self.assertEqual( + "str1/(str1)", comm_data["data"]["maybe_base_expanded_node_name"] + ) + session_run_thread.join() + self.assertEqual(1, len(session_run_results)) + self.assertEqual(b"abcdef", session_run_results[0]) + + # Get the value of a tensor without mapping. + tensor_response = self._serverGet( + "tensor_data", + { + "watch_key": "str1:0:DebugIdentity", + "time_indices": "-1", + "mapping": "", + "slicing": "", + }, + ) + tensor_data = self._deserializeResponse(tensor_response) + self.assertEqual(None, tensor_data["error"]) + self.assertEqual(["abc"], tensor_data["tensor_data"]) + + # Get the health pill of a string tensor. 
+ tensor_response = self._serverGet( + "tensor_data", + { + "watch_key": "str1:0:DebugIdentity", + "time_indices": "-1", + "mapping": "health-pill", + "slicing": "", + }, + ) + tensor_data = self._deserializeResponse(tensor_response) + self.assertEqual(None, tensor_data["error"]) + self.assertEqual([None], tensor_data["tensor_data"]) + + def _runBinaryStringNetwork(self): + session_run_results = [] + + def session_run_job(): + with tf.Session() as sess: + str1 = tf.Variable([b"\x01" * 3, b"\x02" * 3], name="str1") + str2 = tf.Variable([b"\x03" * 3, b"\x04" * 3], name="str2") + str_concat = tf.add(str1, str2, name="str_concat") + sess.run(tf.global_variables_initializer()) + sess = tf_debug.TensorBoardDebugWrapperSession( + sess, self._debugger_url + ) + session_run_results.append(sess.run(str_concat)) + + session_run_thread = threading.Thread(target=session_run_job) + session_run_thread.start() + return session_run_thread, session_run_results + + def testBinaryStringTensorIsHandledCorrectly(self): + session_run_thread, session_run_results = self._runBinaryStringNetwork() + # Activate breakpoint for str1:0. 
+ self._serverGet( + "gated_grpc", + { + "mode": "set_state", + "node_name": "str1", + "output_slot": 0, + "debug_op": "DebugIdentity", + "state": "break", + }, + ) + self._serverGet("ack") + self._serverGet("ack") + comm_response = self._serverGet("comm", {"pos": 2}) + comm_data = self._deserializeResponse(comm_response) + self.assertEqual("tensor", comm_data["type"]) + self.assertEqual("string", comm_data["data"]["dtype"]) + self.assertEqual([2], comm_data["data"]["shape"]) + self.assertEqual(2, len(comm_data["data"]["values"])) + self.assertEqual( + b"=01" * 3, tf.compat.as_bytes(comm_data["data"]["values"][0]) + ) + self.assertEqual( + b"=02" * 3, tf.compat.as_bytes(comm_data["data"]["values"][1]) + ) + self.assertEqual( + "str1/(str1)", comm_data["data"]["maybe_base_expanded_node_name"] + ) + session_run_thread.join() + self.assertEqual(1, len(session_run_results)) + self.assertAllEqual( + np.array( + [b"\x01\x01\x01\x03\x03\x03", b"\x02\x02\x02\x04\x04\x04"], + dtype=np.object, + ), + session_run_results[0], + ) + + # Get the value of a tensor without mapping. + tensor_response = self._serverGet( + "tensor_data", + { + "watch_key": "str1:0:DebugIdentity", + "time_indices": "-1", + "mapping": "", + "slicing": "", + }, + ) + tensor_data = self._deserializeResponse(tensor_response) + self.assertEqual(None, tensor_data["error"]) + self.assertEqual(2, len(tensor_data["tensor_data"][0])) + self.assertEqual( + b"=01=01=01", tf.compat.as_bytes(tensor_data["tensor_data"][0][0]) + ) + self.assertEqual( + b"=02=02=02", tf.compat.as_bytes(tensor_data["tensor_data"][0][1]) + ) + + # Get the health pill of a string tensor. 
+ tensor_response = self._serverGet( + "tensor_data", + { + "watch_key": "str1:0:DebugIdentity", + "time_indices": "-1", + "mapping": "health-pill", + "slicing": "", + }, + ) + tensor_data = self._deserializeResponse(tensor_response) + self.assertEqual(None, tensor_data["error"]) + self.assertEqual([None], tensor_data["tensor_data"]) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/debugger/interactive_debugger_server_lib.py b/tensorboard/plugins/debugger/interactive_debugger_server_lib.py index 551ad7d53b..ba39d2d215 100644 --- a/tensorboard/plugins/debugger/interactive_debugger_server_lib.py +++ b/tensorboard/plugins/debugger/interactive_debugger_server_lib.py @@ -14,8 +14,8 @@ # ============================================================================== """Receives data from a TensorFlow debugger. Writes event summaries. -This listener server writes debugging-related events into a logdir directory, -from which a TensorBoard instance can read. +This listener server writes debugging-related events into a logdir +directory, from which a TensorBoard instance can read. 
""" from __future__ import absolute_import @@ -42,596 +42,655 @@ logger = tb_logging.get_logger() RunKey = collections.namedtuple( - 'RunKey', ['input_names', 'output_names', 'target_nodes']) + "RunKey", ["input_names", "output_names", "target_nodes"] +) def _extract_device_name_from_event(event): - """Extract device name from a tf.Event proto carrying tensor value.""" - plugin_data_content = json.loads( - tf.compat.as_str(event.summary.value[0].metadata.plugin_data.content)) - return plugin_data_content['device'] + """Extract device name from a tf.Event proto carrying tensor value.""" + plugin_data_content = json.loads( + tf.compat.as_str(event.summary.value[0].metadata.plugin_data.content) + ) + return plugin_data_content["device"] def _comm_metadata(run_key, timestamp): - return { - 'type': 'meta', - 'timestamp': timestamp, - 'data': { - 'run_key': run_key, - } - } + return { + "type": "meta", + "timestamp": timestamp, + "data": {"run_key": run_key,}, + } -UNINITIALIZED_TAG = 'Uninitialized' -UNSUPPORTED_TAG = 'Unsupported' -NA_TAG = 'N/A' +UNINITIALIZED_TAG = "Uninitialized" +UNSUPPORTED_TAG = "Unsupported" +NA_TAG = "N/A" STRING_ELEMENT_MAX_LEN = 40 -def _comm_tensor_data(device_name, - node_name, - maybe_base_expanded_node_name, - output_slot, - debug_op, - tensor_value, - wall_time): - """Create a dict() as the outgoing data in the tensor data comm route. - - Note: The tensor data in the comm route does not include the value of the - tensor in its entirety in general. Only if a tensor satisfies the following - conditions will its entire value be included in the return value of this - method: - 1. Has a numeric data type (e.g., float32, int32) and has fewer than 5 - elements. - 2. Is a string tensor and has fewer than 5 elements. Each string element is - up to 40 bytes. - - Args: - device_name: Name of the device that the tensor is on. - node_name: (Original) name of the node that produces the tensor. 
- maybe_base_expanded_node_name: Possbily base-expanded node name. - output_slot: Output slot number. - debug_op: Name of the debug op. - tensor_value: Value of the tensor, as a numpy.ndarray. - wall_time: Wall timestamp for the tensor. - - Returns: - A dict representing the tensor data. - """ - output_slot = int(output_slot) - logger.info( - 'Recording tensor value: %s, %d, %s', node_name, output_slot, debug_op) - tensor_values = None - if isinstance(tensor_value, debug_data.InconvertibleTensorProto): - if not tensor_value.initialized: - tensor_dtype = UNINITIALIZED_TAG - tensor_shape = UNINITIALIZED_TAG - else: - tensor_dtype = UNSUPPORTED_TAG - tensor_shape = UNSUPPORTED_TAG - tensor_values = NA_TAG - else: - tensor_dtype = tensor_helper.translate_dtype(tensor_value.dtype) - tensor_shape = tensor_value.shape - - # The /comm endpoint should respond with tensor values only if the tensor is - # small enough. Otherwise, the detailed values sould be queried through a - # dedicated tensor_data that supports slicing. - if tensor_helper.numel(tensor_shape) < 5: - _, _, tensor_values = tensor_helper.array_view(tensor_value) - if tensor_dtype == 'string' and tensor_value is not None: - tensor_values = tensor_helper.process_buffers_for_display( - tensor_values, limit=STRING_ELEMENT_MAX_LEN) - - return { - 'type': 'tensor', - 'timestamp': wall_time, - 'data': { - 'device_name': device_name, - 'node_name': node_name, - 'maybe_base_expanded_node_name': maybe_base_expanded_node_name, - 'output_slot': output_slot, - 'debug_op': debug_op, - 'dtype': tensor_dtype, - 'shape': tensor_shape, - 'values': tensor_values, - }, - } - - -class RunStates(object): - """A class that keeps track of state of debugged Session.run() calls.""" - - def __init__(self, breakpoints_func=None): - """Constructor of RunStates. 
+def _comm_tensor_data( + device_name, + node_name, + maybe_base_expanded_node_name, + output_slot, + debug_op, + tensor_value, + wall_time, +): + """Create a dict() as the outgoing data in the tensor data comm route. + + Note: The tensor data in the comm route does not include the value of the + tensor in its entirety in general. Only if a tensor satisfies the following + conditions will its entire value be included in the return value of this + method: + 1. Has a numeric data type (e.g., float32, int32) and has fewer than 5 + elements. + 2. Is a string tensor and has fewer than 5 elements. Each string element is + up to 40 bytes. Args: - breakpoint_func: A callable of the signatuer: - def breakpoint_func(): - which returns all the currently activated breakpoints. - """ - # Maps from run key to debug_graphs_helper.DebugGraphWrapper instance. - self._run_key_to_original_graphs = dict() - self._run_key_to_debug_graphs = dict() - - if breakpoints_func: - assert callable(breakpoints_func) - self._breakpoints_func = breakpoints_func - - def add_graph(self, run_key, device_name, graph_def, debug=False): - """Add a GraphDef. - - Args: - run_key: A key for the run, containing information about the feeds, - fetches, and targets. - device_name: The name of the device that the `GraphDef` is for. - graph_def: An instance of the `GraphDef` proto. - debug: Whether `graph_def` consists of the debug ops. - """ - graph_dict = (self._run_key_to_debug_graphs if debug else - self._run_key_to_original_graphs) - if not run_key in graph_dict: - graph_dict[run_key] = dict() # Mapping device_name to GraphDef. - graph_dict[run_key][tf.compat.as_str(device_name)] = ( - debug_graphs_helper.DebugGraphWrapper(graph_def)) - - def get_graphs(self, run_key, debug=False): - """Get the runtime GraphDef protos associated with a run key. - - Args: - run_key: A Session.run kay. - debug: Whether the debugger-decoratedgraph is to be retrieved. 
-
-    Returns:
-      A `dict` mapping device name to `GraphDef` protos.
-    """
-    graph_dict = (self._run_key_to_debug_graphs if debug else
-                  self._run_key_to_original_graphs)
-    graph_wrappers = graph_dict.get(run_key, {})
-    graph_defs = dict()
-    for device_name, wrapper in graph_wrappers.items():
-      graph_defs[device_name] = wrapper.graph_def
-    return graph_defs
-
-  def get_graph(self, run_key, device_name, debug=False):
-    """Get the runtime GraphDef proto associated with a run key and a device.
-
-    Args:
-      run_key: A Session.run kay.
-      device_name: Name of the device in question.
-      debug: Whether the debugger-decoratedgraph is to be retrieved.
+        device_name: Name of the device that the tensor is on.
+        node_name: (Original) name of the node that produces the tensor.
+        maybe_base_expanded_node_name: Possibly base-expanded node name.
+        output_slot: Output slot number.
+        debug_op: Name of the debug op.
+        tensor_value: Value of the tensor, as a numpy.ndarray.
+        wall_time: Wall timestamp for the tensor.
 
     Returns:
- - Raises: - ValueError: If `run_key` and/or `device_name` do not exist in the record. - """ - device_name = tf.compat.as_str(device_name) - if run_key not in self._run_key_to_original_graphs: - raise ValueError('Unknown run_key: %s' % run_key) - if device_name not in self._run_key_to_original_graphs[run_key]: - raise ValueError( - 'Unknown device for run key "%s": %s' % (run_key, device_name)) - return self._run_key_to_original_graphs[ - run_key][device_name].maybe_base_expanded_node_name(node_name) - - -class InteractiveDebuggerDataStreamHandler( - grpc_debug_server.EventListenerBaseStreamHandler): - """Implementation of stream handler for debugger data. - - Each instance of this class is created by a InteractiveDebuggerDataServer - upon a gRPC stream established between the debugged Session::Run() invocation - in TensorFlow core runtime and the InteractiveDebuggerDataServer instance. - - Each instance of this class does the following: - 1) receives a core metadata Event proto during its constructor call. - 2) receives GraphDef Event proto(s) through its on_graph_def method. - 3) receives tensor value Event proto(s) through its on_value_event method. - """ - - def __init__( - self, incoming_channel, outgoing_channel, run_states, tensor_store): - """Constructor of InteractiveDebuggerDataStreamHandler. - - Args: - incoming_channel: An instance of FIFO queue, which manages incoming data, - e.g., ACK signals from the client side unblock breakpoints. - outgoing_channel: An instance of `CommChannel`, which manages outgoing - data, i.e., data regarding the starting of Session.runs and hitting of - tensor breakpoint.s - run_states: An instance of `RunStates`, which keeps track of the states - (graphs and breakpoints) of debugged Session.run() calls. - tensor_store: An instance of `TensorStore`, which stores Tensor values - from debugged Session.run() calls. + A dict representing the tensor data. 
""" - super(InteractiveDebuggerDataStreamHandler, self).__init__() - - self._incoming_channel = incoming_channel - self._outgoing_channel = outgoing_channel - self._run_states = run_states - self._tensor_store = tensor_store - - self._run_key = None - self._graph_defs = dict() # A dict mapping device name to GraphDef. - self._graph_defs_arrive_first = True - - def on_core_metadata_event(self, event): - """Implementation of the core metadata-carrying Event proto callback. - - Args: - event: An Event proto that contains core metadata about the debugged - Session::Run() in its log_message.message field, as a JSON string. - See the doc string of debug_data.DebugDumpDir.core_metadata for details. - """ - core_metadata = json.loads(event.log_message.message) - input_names = ','.join(core_metadata['input_names']) - output_names = ','.join(core_metadata['output_names']) - target_nodes = ','.join(core_metadata['target_nodes']) - - self._run_key = RunKey(input_names, output_names, target_nodes) - if not self._graph_defs: - self._graph_defs_arrive_first = False + output_slot = int(output_slot) + logger.info( + "Recording tensor value: %s, %d, %s", node_name, output_slot, debug_op + ) + tensor_values = None + if isinstance(tensor_value, debug_data.InconvertibleTensorProto): + if not tensor_value.initialized: + tensor_dtype = UNINITIALIZED_TAG + tensor_shape = UNINITIALIZED_TAG + else: + tensor_dtype = UNSUPPORTED_TAG + tensor_shape = UNSUPPORTED_TAG + tensor_values = NA_TAG else: - for device_name in self._graph_defs: - self._add_graph_def(device_name, self._graph_defs[device_name]) - - self._outgoing_channel.put(_comm_metadata(self._run_key, event.wall_time)) - - # Wait for acknowledgement from client. Blocks until an item is got. 
- logger.info('on_core_metadata_event() waiting for client ack (meta)...') - self._incoming_channel.get() - logger.info('on_core_metadata_event() client ack received (meta).') + tensor_dtype = tensor_helper.translate_dtype(tensor_value.dtype) + tensor_shape = tensor_value.shape + + # The /comm endpoint should respond with tensor values only if the tensor is + # small enough. Otherwise, the detailed values sould be queried through a + # dedicated tensor_data that supports slicing. + if tensor_helper.numel(tensor_shape) < 5: + _, _, tensor_values = tensor_helper.array_view(tensor_value) + if tensor_dtype == "string" and tensor_value is not None: + tensor_values = tensor_helper.process_buffers_for_display( + tensor_values, limit=STRING_ELEMENT_MAX_LEN + ) + + return { + "type": "tensor", + "timestamp": wall_time, + "data": { + "device_name": device_name, + "node_name": node_name, + "maybe_base_expanded_node_name": maybe_base_expanded_node_name, + "output_slot": output_slot, + "debug_op": debug_op, + "dtype": tensor_dtype, + "shape": tensor_shape, + "values": tensor_values, + }, + } - # TODO(cais): If eager mode, this should return something to yield. - def _add_graph_def(self, device_name, graph_def): - self._run_states.add_graph( - self._run_key, device_name, - tf_debug.reconstruct_non_debug_graph_def(graph_def)) - self._run_states.add_graph( - self._run_key, device_name, graph_def, debug=True) +class RunStates(object): + """A class that keeps track of state of debugged Session.run() calls.""" + + def __init__(self, breakpoints_func=None): + """Constructor of RunStates. + + Args: + breakpoint_func: A callable of the signatuer: + def breakpoint_func(): + which returns all the currently activated breakpoints. + """ + # Maps from run key to debug_graphs_helper.DebugGraphWrapper instance. 
+ self._run_key_to_original_graphs = dict() + self._run_key_to_debug_graphs = dict() + + if breakpoints_func: + assert callable(breakpoints_func) + self._breakpoints_func = breakpoints_func + + def add_graph(self, run_key, device_name, graph_def, debug=False): + """Add a GraphDef. + + Args: + run_key: A key for the run, containing information about the feeds, + fetches, and targets. + device_name: The name of the device that the `GraphDef` is for. + graph_def: An instance of the `GraphDef` proto. + debug: Whether `graph_def` consists of the debug ops. + """ + graph_dict = ( + self._run_key_to_debug_graphs + if debug + else self._run_key_to_original_graphs + ) + if not run_key in graph_dict: + graph_dict[run_key] = dict() # Mapping device_name to GraphDef. + graph_dict[run_key][ + tf.compat.as_str(device_name) + ] = debug_graphs_helper.DebugGraphWrapper(graph_def) + + def get_graphs(self, run_key, debug=False): + """Get the runtime GraphDef protos associated with a run key. + + Args: + run_key: A Session.run kay. + debug: Whether the debugger-decoratedgraph is to be retrieved. + + Returns: + A `dict` mapping device name to `GraphDef` protos. + """ + graph_dict = ( + self._run_key_to_debug_graphs + if debug + else self._run_key_to_original_graphs + ) + graph_wrappers = graph_dict.get(run_key, {}) + graph_defs = dict() + for device_name, wrapper in graph_wrappers.items(): + graph_defs[device_name] = wrapper.graph_def + return graph_defs + + def get_graph(self, run_key, device_name, debug=False): + """Get the runtime GraphDef proto associated with a run key and a + device. + + Args: + run_key: A Session.run kay. + device_name: Name of the device in question. + debug: Whether the debugger-decoratedgraph is to be retrieved. + + Returns: + A `GraphDef` proto. 
+ """ + return self.get_graphs(run_key, debug=debug).get(device_name, None) + + def get_breakpoints(self): + """Obtain all the currently activated breakpoints.""" + return self._breakpoints_func() + + def get_gated_grpc_tensors(self, run_key, device_name): + return self._run_key_to_debug_graphs[run_key][ + device_name + ].get_gated_grpc_tensors() + + def get_maybe_base_expanded_node_name( + self, node_name, run_key, device_name + ): + """Obtain possibly base-expanded node name. + + Base-expansion is the transformation of a node name which happens to be the + name scope of other nodes in the same graph. For example, if two nodes, + called 'a/b' and 'a/b/read' in a graph, the name of the first node will + be base-expanded to 'a/b/(b)'. + + This method uses caching to avoid unnecessary recomputation. + + Args: + node_name: Name of the node. + run_key: The run key to which the node belongs. + graph_def: GraphDef to which the node belongs. + + Raises: + ValueError: If `run_key` and/or `device_name` do not exist in the record. + """ + device_name = tf.compat.as_str(device_name) + if run_key not in self._run_key_to_original_graphs: + raise ValueError("Unknown run_key: %s" % run_key) + if device_name not in self._run_key_to_original_graphs[run_key]: + raise ValueError( + 'Unknown device for run key "%s": %s' % (run_key, device_name) + ) + return self._run_key_to_original_graphs[run_key][ + device_name + ].maybe_base_expanded_node_name(node_name) - def on_graph_def(self, graph_def, device_name, wall_time): - """Implementation of the GraphDef-carrying Event proto callback. - Args: - graph_def: A GraphDef proto. N.B.: The GraphDef is from - the core runtime of a debugged Session::Run() call, after graph - partition. Therefore it may differ from the GraphDef available to - the general TensorBoard. 
For example, the GraphDef in general - TensorBoard may get partitioned for multiple devices (CPUs and GPUs), - each of which will generate a GraphDef event proto sent to this - method. - device_name: Name of the device on which the graph was created. - wall_time: An epoch timestamp (in microseconds) for the graph. +class InteractiveDebuggerDataStreamHandler( + grpc_debug_server.EventListenerBaseStreamHandler +): + """Implementation of stream handler for debugger data. + + Each instance of this class is created by a InteractiveDebuggerDataServer + upon a gRPC stream established between the debugged Session::Run() invocation + in TensorFlow core runtime and the InteractiveDebuggerDataServer instance. + + Each instance of this class does the following: + 1) receives a core metadata Event proto during its constructor call. + 2) receives GraphDef Event proto(s) through its on_graph_def method. + 3) receives tensor value Event proto(s) through its on_value_event method. """ - # For now, we do nothing with the graph def. However, we must define this - # method to satisfy the handler's interface. Furthermore, we may use the - # graph in the future (for instance to provide a graph if there is no graph - # provided otherwise). - del wall_time - self._graph_defs[device_name] = graph_def - if not self._graph_defs_arrive_first: - self._add_graph_def(device_name, graph_def) - self._incoming_channel.get() - - def on_value_event(self, event): - """Records the summary values based on an updated message from the debugger. - - Logs an error message if writing the event to disk fails. - - Args: - event: The Event proto to be processed. - """ - if not event.summary.value: - logger.info('The summary of the event lacks a value.') - return None - - # The node name property in the event proto is actually a watch key, which - # is a concatenation of several pieces of data. 
- watch_key = event.summary.value[0].node_name - tensor_value = debug_data.load_tensor_from_event(event) - device_name = _extract_device_name_from_event(event) - node_name, output_slot, debug_op = ( - event.summary.value[0].node_name.split(':')) - maybe_base_expanded_node_name = ( - self._run_states.get_maybe_base_expanded_node_name(node_name, - self._run_key, - device_name)) - self._tensor_store.add(watch_key, tensor_value) - self._outgoing_channel.put(_comm_tensor_data( - device_name, node_name, maybe_base_expanded_node_name, output_slot, - debug_op, tensor_value, event.wall_time)) - - logger.info('on_value_event(): waiting for client ack (tensors)...') - self._incoming_channel.get() - logger.info('on_value_event(): client ack received (tensor).') - - # Determine if the particular debug watch key is in the current list of - # breakpoints. If it is, send an EventReply() to unblock the debug op. - if self._is_debug_node_in_breakpoints(event.summary.value[0].node_name): - logger.info('Sending empty EventReply for breakpoint: %s', - event.summary.value[0].node_name) - # TODO(cais): Support receiving and sending tensor value from front-end. - return debug_service_pb2.EventReply() - return None - - def _is_debug_node_in_breakpoints(self, debug_node_key): - node_name, output_slot, debug_op = debug_node_key.split(':') - output_slot = int(output_slot) - return (node_name, output_slot, - debug_op) in self._run_states.get_breakpoints() + def __init__( + self, incoming_channel, outgoing_channel, run_states, tensor_store + ): + """Constructor of InteractiveDebuggerDataStreamHandler. + + Args: + incoming_channel: An instance of FIFO queue, which manages incoming data, + e.g., ACK signals from the client side unblock breakpoints. 
+ outgoing_channel: An instance of `CommChannel`, which manages outgoing + data, i.e., data regarding the starting of Session.runs and hitting of + tensor breakpoint.s + run_states: An instance of `RunStates`, which keeps track of the states + (graphs and breakpoints) of debugged Session.run() calls. + tensor_store: An instance of `TensorStore`, which stores Tensor values + from debugged Session.run() calls. + """ + super(InteractiveDebuggerDataStreamHandler, self).__init__() + + self._incoming_channel = incoming_channel + self._outgoing_channel = outgoing_channel + self._run_states = run_states + self._tensor_store = tensor_store + + self._run_key = None + self._graph_defs = dict() # A dict mapping device name to GraphDef. + self._graph_defs_arrive_first = True + + def on_core_metadata_event(self, event): + """Implementation of the core metadata-carrying Event proto callback. + + Args: + event: An Event proto that contains core metadata about the debugged + Session::Run() in its log_message.message field, as a JSON string. + See the doc string of debug_data.DebugDumpDir.core_metadata for details. + """ + core_metadata = json.loads(event.log_message.message) + input_names = ",".join(core_metadata["input_names"]) + output_names = ",".join(core_metadata["output_names"]) + target_nodes = ",".join(core_metadata["target_nodes"]) + + self._run_key = RunKey(input_names, output_names, target_nodes) + if not self._graph_defs: + self._graph_defs_arrive_first = False + else: + for device_name in self._graph_defs: + self._add_graph_def(device_name, self._graph_defs[device_name]) + + self._outgoing_channel.put( + _comm_metadata(self._run_key, event.wall_time) + ) + + # Wait for acknowledgement from client. Blocks until an item is got. + logger.info("on_core_metadata_event() waiting for client ack (meta)...") + self._incoming_channel.get() + logger.info("on_core_metadata_event() client ack received (meta).") + + # TODO(cais): If eager mode, this should return something to yield. 
+ + def _add_graph_def(self, device_name, graph_def): + self._run_states.add_graph( + self._run_key, + device_name, + tf_debug.reconstruct_non_debug_graph_def(graph_def), + ) + self._run_states.add_graph( + self._run_key, device_name, graph_def, debug=True + ) + + def on_graph_def(self, graph_def, device_name, wall_time): + """Implementation of the GraphDef-carrying Event proto callback. + + Args: + graph_def: A GraphDef proto. N.B.: The GraphDef is from + the core runtime of a debugged Session::Run() call, after graph + partition. Therefore it may differ from the GraphDef available to + the general TensorBoard. For example, the GraphDef in general + TensorBoard may get partitioned for multiple devices (CPUs and GPUs), + each of which will generate a GraphDef event proto sent to this + method. + device_name: Name of the device on which the graph was created. + wall_time: An epoch timestamp (in microseconds) for the graph. + """ + # For now, we do nothing with the graph def. However, we must define this + # method to satisfy the handler's interface. Furthermore, we may use the + # graph in the future (for instance to provide a graph if there is no graph + # provided otherwise). + del wall_time + self._graph_defs[device_name] = graph_def + + if not self._graph_defs_arrive_first: + self._add_graph_def(device_name, graph_def) + self._incoming_channel.get() + + def on_value_event(self, event): + """Records the summary values based on an updated message from the + debugger. + + Logs an error message if writing the event to disk fails. + + Args: + event: The Event proto to be processed. + """ + if not event.summary.value: + logger.info("The summary of the event lacks a value.") + return None + + # The node name property in the event proto is actually a watch key, which + # is a concatenation of several pieces of data. 
+ watch_key = event.summary.value[0].node_name + tensor_value = debug_data.load_tensor_from_event(event) + device_name = _extract_device_name_from_event(event) + node_name, output_slot, debug_op = event.summary.value[ + 0 + ].node_name.split(":") + maybe_base_expanded_node_name = self._run_states.get_maybe_base_expanded_node_name( + node_name, self._run_key, device_name + ) + self._tensor_store.add(watch_key, tensor_value) + self._outgoing_channel.put( + _comm_tensor_data( + device_name, + node_name, + maybe_base_expanded_node_name, + output_slot, + debug_op, + tensor_value, + event.wall_time, + ) + ) + + logger.info("on_value_event(): waiting for client ack (tensors)...") + self._incoming_channel.get() + logger.info("on_value_event(): client ack received (tensor).") + + # Determine if the particular debug watch key is in the current list of + # breakpoints. If it is, send an EventReply() to unblock the debug op. + if self._is_debug_node_in_breakpoints(event.summary.value[0].node_name): + logger.info( + "Sending empty EventReply for breakpoint: %s", + event.summary.value[0].node_name, + ) + # TODO(cais): Support receiving and sending tensor value from front-end. + return debug_service_pb2.EventReply() + return None + + def _is_debug_node_in_breakpoints(self, debug_node_key): + node_name, output_slot, debug_op = debug_node_key.split(":") + output_slot = int(output_slot) + return ( + node_name, + output_slot, + debug_op, + ) in self._run_states.get_breakpoints() # TODO(cais): Consider moving to a seperate python module. class SourceManager(object): - """Manages source files and tracebacks involved in the debugged TF program. - - """ - - def __init__(self): - # A dict mapping file path to file content as a list of strings. - self._source_file_content = dict() - # A dict mapping file path to host name. - self._source_file_host = dict() - # A dict mapping file path to last modified timestamp. 
- self._source_file_last_modified = dict() - # A dict mapping file path to size in bytes. - self._source_file_bytes = dict() - # Keeps track f the traceback of the latest graph version. - self._graph_traceback = None - self._graph_version = -1 - - def add_debugged_source_file(self, debugged_source_file): - """Add a DebuggedSourceFile proto.""" - # TODO(cais): Should the key include a host name, for certain distributed - # cases? - key = debugged_source_file.file_path - self._source_file_host[key] = debugged_source_file.host - self._source_file_last_modified[key] = debugged_source_file.last_modified - self._source_file_bytes[key] = debugged_source_file.bytes - self._source_file_content[key] = debugged_source_file.lines - - def add_graph_traceback(self, graph_version, graph_traceback): - if graph_version > self._graph_version: - self._graph_traceback = graph_traceback - self._graph_version = graph_version - - def get_paths(self): - """Get the paths to all available source files.""" - return list(self._source_file_content.keys()) - - def get_content(self, file_path): - """Get the content of a source file. - - # TODO(cais): Maybe support getting a range of lines by line number. - - Args: - file_path: Path to the source file. - """ - return self._source_file_content[file_path] - - def get_op_traceback(self, op_name): - """Get the traceback of an op in the latest version of the TF graph. - - Args: - op_name: Name of the op. - - Returns: - Creation traceback of the op, in the form of a list of 2-tuples: - (file_path, lineno) - - Raises: - ValueError: If the op with the given name cannot be found in the latest - version of the graph that this SourceManager instance has received, or - if this SourceManager instance has not received any graph traceback yet. 
- """ - if not self._graph_traceback: - raise ValueError('No graph traceback has been received yet.') - for op_log_entry in self._graph_traceback.log_entries: - if op_log_entry.name == op_name: - return self._code_def_to_traceback_list(op_log_entry.code_def) - raise ValueError( - 'No op named "%s" can be found in the graph of the latest version ' - ' (%d).' % (op_name, self._graph_version)) - - def get_file_tracebacks(self, file_path): - """Get the lists of ops created at lines of a specified source file. - - Args: - file_path: Path to the source file. - - Returns: - A dict mapping line number to a list of 2-tuples, - `(op_name, stack_position)` - `op_name` is the name of the name of the op whose creation traceback - includes the line. - `stack_position` is the position of the line in the op's creation - traceback, represented as a 0-based integer. - - Raises: - ValueError: If `file_path` does not point to a source file that has been - received by this instance of `SourceManager`. - """ - if file_path not in self._source_file_content: - raise ValueError( - 'Source file of path "%s" has not been received by this instance of ' - 'SourceManager.' 
% file_path) - - lineno_to_op_names_and_stack_position = dict() - for op_log_entry in self._graph_traceback.log_entries: - for stack_pos, trace in enumerate(op_log_entry.code_def.traces): - if self._graph_traceback.id_to_string[trace.file_id] == file_path: - if trace.lineno not in lineno_to_op_names_and_stack_position: - lineno_to_op_names_and_stack_position[trace.lineno] = [] - lineno_to_op_names_and_stack_position[trace.lineno].append( - (op_log_entry.name, stack_pos)) - return lineno_to_op_names_and_stack_position - - def _code_def_to_traceback_list(self, code_def): - return [ - (self._graph_traceback.id_to_string[trace.file_id], trace.lineno) - for trace in code_def.traces] + """Manages source files and tracebacks involved in the debugged TF + program.""" + + def __init__(self): + # A dict mapping file path to file content as a list of strings. + self._source_file_content = dict() + # A dict mapping file path to host name. + self._source_file_host = dict() + # A dict mapping file path to last modified timestamp. + self._source_file_last_modified = dict() + # A dict mapping file path to size in bytes. + self._source_file_bytes = dict() + # Keeps track f the traceback of the latest graph version. + self._graph_traceback = None + self._graph_version = -1 + + def add_debugged_source_file(self, debugged_source_file): + """Add a DebuggedSourceFile proto.""" + # TODO(cais): Should the key include a host name, for certain distributed + # cases? 
+ key = debugged_source_file.file_path + self._source_file_host[key] = debugged_source_file.host + self._source_file_last_modified[ + key + ] = debugged_source_file.last_modified + self._source_file_bytes[key] = debugged_source_file.bytes + self._source_file_content[key] = debugged_source_file.lines + + def add_graph_traceback(self, graph_version, graph_traceback): + if graph_version > self._graph_version: + self._graph_traceback = graph_traceback + self._graph_version = graph_version + + def get_paths(self): + """Get the paths to all available source files.""" + return list(self._source_file_content.keys()) + + def get_content(self, file_path): + """Get the content of a source file. + + # TODO(cais): Maybe support getting a range of lines by line number. + + Args: + file_path: Path to the source file. + """ + return self._source_file_content[file_path] + + def get_op_traceback(self, op_name): + """Get the traceback of an op in the latest version of the TF graph. + + Args: + op_name: Name of the op. + + Returns: + Creation traceback of the op, in the form of a list of 2-tuples: + (file_path, lineno) + + Raises: + ValueError: If the op with the given name cannot be found in the latest + version of the graph that this SourceManager instance has received, or + if this SourceManager instance has not received any graph traceback yet. + """ + if not self._graph_traceback: + raise ValueError("No graph traceback has been received yet.") + for op_log_entry in self._graph_traceback.log_entries: + if op_log_entry.name == op_name: + return self._code_def_to_traceback_list(op_log_entry.code_def) + raise ValueError( + 'No op named "%s" can be found in the graph of the latest version ' + " (%d)." % (op_name, self._graph_version) + ) + + def get_file_tracebacks(self, file_path): + """Get the lists of ops created at lines of a specified source file. + + Args: + file_path: Path to the source file. 
+ + Returns: + A dict mapping line number to a list of 2-tuples, + `(op_name, stack_position)` + `op_name` is the name of the name of the op whose creation traceback + includes the line. + `stack_position` is the position of the line in the op's creation + traceback, represented as a 0-based integer. + + Raises: + ValueError: If `file_path` does not point to a source file that has been + received by this instance of `SourceManager`. + """ + if file_path not in self._source_file_content: + raise ValueError( + 'Source file of path "%s" has not been received by this instance of ' + "SourceManager." % file_path + ) + + lineno_to_op_names_and_stack_position = dict() + for op_log_entry in self._graph_traceback.log_entries: + for stack_pos, trace in enumerate(op_log_entry.code_def.traces): + if ( + self._graph_traceback.id_to_string[trace.file_id] + == file_path + ): + if ( + trace.lineno + not in lineno_to_op_names_and_stack_position + ): + lineno_to_op_names_and_stack_position[trace.lineno] = [] + lineno_to_op_names_and_stack_position[trace.lineno].append( + (op_log_entry.name, stack_pos) + ) + return lineno_to_op_names_and_stack_position + + def _code_def_to_traceback_list(self, code_def): + return [ + (self._graph_traceback.id_to_string[trace.file_id], trace.lineno) + for trace in code_def.traces + ] class InteractiveDebuggerDataServer( - grpc_debug_server.EventListenerBaseServicer): - """A service that receives and writes debugger data such as health pills. - """ - - def __init__(self, receive_port): - """Receives health pills from a debugger and writes them to disk. - - Args: - receive_port: The port at which to receive health pills from the - TensorFlow debugger. - always_flush: A boolean indicating whether the EventsWriter will be - flushed after every write. Can be used for testing. 
- """ - super(InteractiveDebuggerDataServer, self).__init__( - receive_port, InteractiveDebuggerDataStreamHandler) - - self._incoming_channel = queue.Queue() - self._outgoing_channel = comm_channel_lib.CommChannel() - self._run_states = RunStates(breakpoints_func=lambda: self.breakpoints) - self._tensor_store = tensor_store_lib.TensorStore() - self._source_manager = SourceManager() - - curried_handler_constructor = functools.partial( - InteractiveDebuggerDataStreamHandler, - self._incoming_channel, self._outgoing_channel, self._run_states, - self._tensor_store) - grpc_debug_server.EventListenerBaseServicer.__init__( - self, receive_port, curried_handler_constructor) - - def SendTracebacks(self, request, context): - self._source_manager.add_graph_traceback(request.graph_version, - request.graph_traceback) - return debug_service_pb2.EventReply() - - def SendSourceFiles(self, request, context): - # TODO(cais): Handle case in which the size of the request is greater than - # the 4-MB gRPC limit. - for source_file in request.source_files: - self._source_manager.add_debugged_source_file(source_file) - return debug_service_pb2.EventReply() - - def get_graphs(self, run_key, debug=False): - return self._run_states.get_graphs(run_key, debug=debug) - - def get_graph(self, run_key, device_name, debug=False): - return self._run_states.get_graph(run_key, device_name, debug=debug) - - def get_gated_grpc_tensors(self, run_key, device_name): - return self._run_states.get_gated_grpc_tensors(run_key, device_name) - - def get_outgoing_message(self, pos): - msg, _ = self._outgoing_channel.get(pos) - return msg - - def put_incoming_message(self, message): - return self._incoming_channel.put(message) - - def query_tensor_store(self, - watch_key, - time_indices=None, - slicing=None, - mapping=None): - """Query tensor store for a given debugged tensor value. - - Args: - watch_key: The watch key of the debugged tensor being sought. Format: - :: - E.g., Dense_1/MatMul:0:DebugIdentity. 
- time_indices: Optional time indices string By default, the lastest time - index ('-1') is returned. - slicing: Optional slicing string. - mapping: Optional mapping string, e.g., 'image/png'. - - Returns: - If mapping is `None`, the possibly sliced values as a nested list of - values or its mapped format. A `list` of nested `list` of values, - If mapping is not `None`, the format of the return value will depend on - the mapping. - """ - return self._tensor_store.query(watch_key, - time_indices=time_indices, - slicing=slicing, - mapping=mapping) - - def query_source_file_paths(self): - """Query the source files involved in the current debugged TF program. - - Returns: - A `list` of file paths. The files that belong to the TensorFlow Python - library itself are *not* included. - """ - return self._source_manager.get_paths() - - def query_source_file_content(self, file_path): - """Query the content of a given source file. - - # TODO(cais): Allow query only a range of the source lines. - - Returns: - The source lines as a list of `str`. - """ - return list(self._source_manager.get_content(file_path)) - - def query_op_traceback(self, op_name): - """Query the tracebacks of ops in a TensorFlow graph. - - Returns: - TODO(cais): - """ - return self._source_manager.get_op_traceback(op_name) - - def query_file_tracebacks(self, file_path): - """Query the lists of ops created at lines of a given source file. - - Args: - file_path: Path to the source file to get the tracebacks for. - - Returns: - A `dict` mapping line number in the specified source file to a list of - 2-tuples: - `(op_name, stack_position)`. - `op_name` is the name of the name of the op whose creation traceback - includes the line. - `stack_position` is the position of the line in the op's creation - traceback, represented as a 0-based integer. - """ - return self._source_manager.get_file_tracebacks(file_path) - - def dispose(self): - """Disposes of this object. 
Call only after this is done being used.""" - self._tensor_store.dispose() + grpc_debug_server.EventListenerBaseServicer +): + """A service that receives and writes debugger data such as health + pills.""" + + def __init__(self, receive_port): + """Receives health pills from a debugger and writes them to disk. + + Args: + receive_port: The port at which to receive health pills from the + TensorFlow debugger. + always_flush: A boolean indicating whether the EventsWriter will be + flushed after every write. Can be used for testing. + """ + super(InteractiveDebuggerDataServer, self).__init__( + receive_port, InteractiveDebuggerDataStreamHandler + ) + + self._incoming_channel = queue.Queue() + self._outgoing_channel = comm_channel_lib.CommChannel() + self._run_states = RunStates(breakpoints_func=lambda: self.breakpoints) + self._tensor_store = tensor_store_lib.TensorStore() + self._source_manager = SourceManager() + + curried_handler_constructor = functools.partial( + InteractiveDebuggerDataStreamHandler, + self._incoming_channel, + self._outgoing_channel, + self._run_states, + self._tensor_store, + ) + grpc_debug_server.EventListenerBaseServicer.__init__( + self, receive_port, curried_handler_constructor + ) + + def SendTracebacks(self, request, context): + self._source_manager.add_graph_traceback( + request.graph_version, request.graph_traceback + ) + return debug_service_pb2.EventReply() + + def SendSourceFiles(self, request, context): + # TODO(cais): Handle case in which the size of the request is greater than + # the 4-MB gRPC limit. 
+ for source_file in request.source_files: + self._source_manager.add_debugged_source_file(source_file) + return debug_service_pb2.EventReply() + + def get_graphs(self, run_key, debug=False): + return self._run_states.get_graphs(run_key, debug=debug) + + def get_graph(self, run_key, device_name, debug=False): + return self._run_states.get_graph(run_key, device_name, debug=debug) + + def get_gated_grpc_tensors(self, run_key, device_name): + return self._run_states.get_gated_grpc_tensors(run_key, device_name) + + def get_outgoing_message(self, pos): + msg, _ = self._outgoing_channel.get(pos) + return msg + + def put_incoming_message(self, message): + return self._incoming_channel.put(message) + + def query_tensor_store( + self, watch_key, time_indices=None, slicing=None, mapping=None + ): + """Query tensor store for a given debugged tensor value. + + Args: + watch_key: The watch key of the debugged tensor being sought. Format: + :: + E.g., Dense_1/MatMul:0:DebugIdentity. + time_indices: Optional time indices string By default, the lastest time + index ('-1') is returned. + slicing: Optional slicing string. + mapping: Optional mapping string, e.g., 'image/png'. + + Returns: + If mapping is `None`, the possibly sliced values as a nested list of + values or its mapped format. A `list` of nested `list` of values, + If mapping is not `None`, the format of the return value will depend on + the mapping. + """ + return self._tensor_store.query( + watch_key, + time_indices=time_indices, + slicing=slicing, + mapping=mapping, + ) + + def query_source_file_paths(self): + """Query the source files involved in the current debugged TF program. + + Returns: + A `list` of file paths. The files that belong to the TensorFlow Python + library itself are *not* included. + """ + return self._source_manager.get_paths() + + def query_source_file_content(self, file_path): + """Query the content of a given source file. + + # TODO(cais): Allow query only a range of the source lines. 
+ + Returns: + The source lines as a list of `str`. + """ + return list(self._source_manager.get_content(file_path)) + + def query_op_traceback(self, op_name): + """Query the tracebacks of ops in a TensorFlow graph. + + Returns: + TODO(cais): + """ + return self._source_manager.get_op_traceback(op_name) + + def query_file_tracebacks(self, file_path): + """Query the lists of ops created at lines of a given source file. + + Args: + file_path: Path to the source file to get the tracebacks for. + + Returns: + A `dict` mapping line number in the specified source file to a list of + 2-tuples: + `(op_name, stack_position)`. + `op_name` is the name of the name of the op whose creation traceback + includes the line. + `stack_position` is the position of the line in the op's creation + traceback, represented as a 0-based integer. + """ + return self._source_manager.get_file_tracebacks(file_path) + + def dispose(self): + """Disposes of this object. + + Call only after this is done being used. + """ + self._tensor_store.dispose() diff --git a/tensorboard/plugins/debugger/numerics_alert.py b/tensorboard/plugins/debugger/numerics_alert.py index 6b01ea370b..0ac29dc448 100644 --- a/tensorboard/plugins/debugger/numerics_alert.py +++ b/tensorboard/plugins/debugger/numerics_alert.py @@ -40,303 +40,355 @@ # `NumericsAlert` events of the corresponding categories. 
NumericsAlert = collections.namedtuple( "NumericsAlert", - ["device_name", "tensor_name", "timestamp", "nan_count", "neg_inf_count", - "pos_inf_count"]) + [ + "device_name", + "tensor_name", + "timestamp", + "nan_count", + "neg_inf_count", + "pos_inf_count", + ], +) NumericsAlertReportRow = collections.namedtuple( "NumericsAlertReportRow", - ["device_name", "tensor_name", "first_timestamp", "nan_event_count", - "neg_inf_event_count", "pos_inf_event_count"]) + [ + "device_name", + "tensor_name", + "first_timestamp", + "nan_event_count", + "neg_inf_event_count", + "pos_inf_event_count", + ], +) # Used to reconstruct an _EventTracker from data read from disk. When updating # this named tuple, make sure to keep the properties of _EventTracker in sync. EventTrackerDescription = collections.namedtuple( "EventTrackerDescription", - ["event_count", "first_timestamp", "last_timestamp"]) + ["event_count", "first_timestamp", "last_timestamp"], +) # Used to reconstruct NumericsAlertHistory. HistoryTriplet = collections.namedtuple( - "HistoryTriplet", - ["device", "tensor", "jsonable_history"]) + "HistoryTriplet", ["device", "tensor", "jsonable_history"] +) class _EventTracker(object): - """Track events for a single category of values (NaN, -Inf, or +Inf).""" - - def __init__(self, event_count=0, first_timestamp=-1, last_timestamp=-1): - """Tracks events for a single category of values. - - Args: - event_count: The initial event count to use. - first_timestamp: The timestamp of the first event with this value. - last_timestamp: The timestamp of the last event with this category of - values. - """ - - # When updating the properties of this class, make sure to keep - # EventTrackerDescription in sync so that data can be written to and from - # disk correctly. 
- self.event_count = event_count - self.first_timestamp = first_timestamp - self.last_timestamp = last_timestamp - - def add(self, timestamp): - if self.event_count == 0: - self.first_timestamp = timestamp - self.last_timestamp = timestamp - else: - if timestamp < self.first_timestamp: - self.first_timestamp = timestamp - if timestamp > self.last_timestamp: - self.last_timestamp = timestamp - self.event_count += 1 - - def get_description(self): - return EventTrackerDescription( - self.event_count, self.first_timestamp, self.last_timestamp) + """Track events for a single category of values (NaN, -Inf, or +Inf).""" + + def __init__(self, event_count=0, first_timestamp=-1, last_timestamp=-1): + """Tracks events for a single category of values. + + Args: + event_count: The initial event count to use. + first_timestamp: The timestamp of the first event with this value. + last_timestamp: The timestamp of the last event with this category of + values. + """ + + # When updating the properties of this class, make sure to keep + # EventTrackerDescription in sync so that data can be written to and from + # disk correctly. + self.event_count = event_count + self.first_timestamp = first_timestamp + self.last_timestamp = last_timestamp + + def add(self, timestamp): + if self.event_count == 0: + self.first_timestamp = timestamp + self.last_timestamp = timestamp + else: + if timestamp < self.first_timestamp: + self.first_timestamp = timestamp + if timestamp > self.last_timestamp: + self.last_timestamp = timestamp + self.event_count += 1 + + def get_description(self): + return EventTrackerDescription( + self.event_count, self.first_timestamp, self.last_timestamp + ) class NumericsAlertHistory(object): - """History of numerics alerts.""" - - def __init__(self, initialization_list=None): - """Stores alert history for a single device, tensor pair. - - Args: - initialization_list: (`list`) An optional list parsed from JSON read - from disk. 
That entity is used to initialize this NumericsAlertHistory. - Use the create_jsonable_object method of this class to create such an - object. - """ - if initialization_list: - # Use data to initialize this NumericsAlertHistory. - self._trackers = {} - for value_category_key, description_list in initialization_list.items(): - description = EventTrackerDescription._make(description_list) - self._trackers[value_category_key] = _EventTracker( - event_count=description.event_count, - first_timestamp=description.first_timestamp, - last_timestamp=description.last_timestamp) - else: - # Start cleanly. With no prior data. - self._trackers = { - constants.NAN_KEY: _EventTracker(), - constants.NEG_INF_KEY: _EventTracker(), - constants.POS_INF_KEY: _EventTracker(), - } - - def add(self, numerics_alert): - if numerics_alert.nan_count: - self._trackers[constants.NAN_KEY].add(numerics_alert.timestamp) - if numerics_alert.neg_inf_count: - self._trackers[constants.NEG_INF_KEY].add(numerics_alert.timestamp) - if numerics_alert.pos_inf_count: - self._trackers[constants.POS_INF_KEY].add(numerics_alert.timestamp) - - def first_timestamp(self, event_key=None): - """Obtain the first timestamp. - - Args: - event_key: the type key of the sought events (e.g., constants.NAN_KEY). - If None, includes all event type keys. - - Returns: - First (earliest) timestamp of all the events of the given type (or all - event types if event_key is None). - """ - if event_key is None: - timestamps = [self._trackers[key].first_timestamp - for key in self._trackers] - return min(timestamp for timestamp in timestamps if timestamp >= 0) - else: - return self._trackers[event_key].first_timestamp - - def last_timestamp(self, event_key=None): - """Obtain the last timestamp. - - Args: - event_key: the type key of the sought events (e.g., constants.NAN_KEY). If - None, includes all event type keys. 
- - Returns: - Last (latest) timestamp of all the events of the given type (or all - event types if event_key is None). - """ - if event_key is None: - timestamps = [self._trackers[key].first_timestamp - for key in self._trackers] - return max(timestamp for timestamp in timestamps if timestamp >= 0) - else: - return self._trackers[event_key].last_timestamp - - def event_count(self, event_key): - """Obtain event count. - - Args: - event_key: the type key of the sought events (e.g., constants.NAN_KEY). If - None, includes all event type keys. - - Returns: - If event_key is None, return the sum of the event_count of all event - types. Otherwise, return the event_count of the specified event type. - """ - return self._trackers[event_key].event_count - - def create_jsonable_history(self): - """Creates a JSON-able representation of this object. - - Returns: - A dictionary mapping key to EventTrackerDescription (which can be used to - create event trackers). - """ - return {value_category_key: tracker.get_description() - for (value_category_key, tracker) in self._trackers.items()} + """History of numerics alerts.""" + + def __init__(self, initialization_list=None): + """Stores alert history for a single device, tensor pair. + + Args: + initialization_list: (`list`) An optional list parsed from JSON read + from disk. That entity is used to initialize this NumericsAlertHistory. + Use the create_jsonable_object method of this class to create such an + object. + """ + if initialization_list: + # Use data to initialize this NumericsAlertHistory. + self._trackers = {} + for ( + value_category_key, + description_list, + ) in initialization_list.items(): + description = EventTrackerDescription._make(description_list) + self._trackers[value_category_key] = _EventTracker( + event_count=description.event_count, + first_timestamp=description.first_timestamp, + last_timestamp=description.last_timestamp, + ) + else: + # Start cleanly. With no prior data. 
+ self._trackers = { + constants.NAN_KEY: _EventTracker(), + constants.NEG_INF_KEY: _EventTracker(), + constants.POS_INF_KEY: _EventTracker(), + } + + def add(self, numerics_alert): + if numerics_alert.nan_count: + self._trackers[constants.NAN_KEY].add(numerics_alert.timestamp) + if numerics_alert.neg_inf_count: + self._trackers[constants.NEG_INF_KEY].add(numerics_alert.timestamp) + if numerics_alert.pos_inf_count: + self._trackers[constants.POS_INF_KEY].add(numerics_alert.timestamp) + + def first_timestamp(self, event_key=None): + """Obtain the first timestamp. + + Args: + event_key: the type key of the sought events (e.g., constants.NAN_KEY). + If None, includes all event type keys. + + Returns: + First (earliest) timestamp of all the events of the given type (or all + event types if event_key is None). + """ + if event_key is None: + timestamps = [ + self._trackers[key].first_timestamp for key in self._trackers + ] + return min(timestamp for timestamp in timestamps if timestamp >= 0) + else: + return self._trackers[event_key].first_timestamp + + def last_timestamp(self, event_key=None): + """Obtain the last timestamp. + + Args: + event_key: the type key of the sought events (e.g., constants.NAN_KEY). If + None, includes all event type keys. + + Returns: + Last (latest) timestamp of all the events of the given type (or all + event types if event_key is None). + """ + if event_key is None: + timestamps = [ + self._trackers[key].first_timestamp for key in self._trackers + ] + return max(timestamp for timestamp in timestamps if timestamp >= 0) + else: + return self._trackers[event_key].last_timestamp + + def event_count(self, event_key): + """Obtain event count. + + Args: + event_key: the type key of the sought events (e.g., constants.NAN_KEY). If + None, includes all event type keys. + + Returns: + If event_key is None, return the sum of the event_count of all event + types. Otherwise, return the event_count of the specified event type. 
+ """ + return self._trackers[event_key].event_count + + def create_jsonable_history(self): + """Creates a JSON-able representation of this object. + + Returns: + A dictionary mapping key to EventTrackerDescription (which can be used to + create event trackers). + """ + return { + value_category_key: tracker.get_description() + for (value_category_key, tracker) in self._trackers.items() + } class NumericsAlertRegistry(object): - """A registry for alerts on numerics (e.g., due to NaNs and infinities).""" - - def __init__(self, capacity=100, initialization_list=None): - """Constructor. + """A registry for alerts on numerics (e.g., due to NaNs and infinities).""" + + def __init__(self, capacity=100, initialization_list=None): + """Constructor. + + Args: + capacity: (`int`) maximum number of device-tensor keys to store. + initialization_list: (`list`) An optional list (parsed from JSON) that + is used to initialize the data within this registry. Use the + create_jsonable_registry method of NumericsAlertRegistry to create such + a list. + """ + self._capacity = capacity + + # A map from device-tensor key to a the TensorAlertRecord namedtuple. + # The device-tensor key is a 2-tuple of the format (device_name, node_name). + # E.g., ("/job:worker/replica:0/task:1/gpu:0", "cross_entropy/Log:0"). + self._data = dict() + + if initialization_list: + # Initialize the alert registry using the data passed in. This might be + # backup data used to restore the registry after say a borg pre-emption. + for entry in initialization_list: + triplet = HistoryTriplet._make(entry) + self._data[ + (triplet.device, triplet.tensor) + ] = NumericsAlertHistory( + initialization_list=triplet.jsonable_history + ) + + def register(self, numerics_alert): + """Register an alerting numeric event. + + Args: + numerics_alert: An instance of `NumericsAlert`. 
+ """ + key = (numerics_alert.device_name, numerics_alert.tensor_name) + if key in self._data: + self._data[key].add(numerics_alert) + else: + if len(self._data) < self._capacity: + history = NumericsAlertHistory() + history.add(numerics_alert) + self._data[key] = history + + def report(self, device_name_filter=None, tensor_name_filter=None): + """Get a report of offending device/tensor names. + + The report includes information about the device name, tensor name, first + (earliest) timestamp of the alerting events from the tensor, in addition to + counts of nan, positive inf and negative inf events. + + Args: + device_name_filter: regex filter for device name, or None (not filtered). + tensor_name_filter: regex filter for tensor name, or None (not filtered). + + Returns: + A list of NumericsAlertReportRow, sorted by the first_timestamp in + asecnding order. + """ + report = [] + for key in self._data: + device_name, tensor_name = key + history = self._data[key] + report.append( + NumericsAlertReportRow( + device_name=device_name, + tensor_name=tensor_name, + first_timestamp=history.first_timestamp(), + nan_event_count=history.event_count(constants.NAN_KEY), + neg_inf_event_count=history.event_count( + constants.NEG_INF_KEY + ), + pos_inf_event_count=history.event_count( + constants.POS_INF_KEY + ), + ) + ) + + if device_name_filter: + device_name_pattern = re.compile(device_name_filter) + report = [ + item + for item in report + if device_name_pattern.match(item.device_name) + ] + if tensor_name_filter: + tensor_name_pattern = re.compile(tensor_name_filter) + report = [ + item + for item in report + if tensor_name_pattern.match(item.tensor_name) + ] + # Sort results chronologically. + return sorted(report, key=lambda x: x.first_timestamp) + + def create_jsonable_registry(self): + """Creates a JSON-able representation of this object. + + Returns: + A dictionary mapping (device, tensor name) to JSON-able object + representations of NumericsAlertHistory. 
+ """ + # JSON does not support tuples as keys. Only strings. Therefore, we store + # the device name, tensor name, and dictionary data within a 3-item list. + return [ + HistoryTriplet(pair[0], pair[1], history.create_jsonable_history()) + for (pair, history) in self._data.items() + ] - Args: - capacity: (`int`) maximum number of device-tensor keys to store. - initialization_list: (`list`) An optional list (parsed from JSON) that - is used to initialize the data within this registry. Use the - create_jsonable_registry method of NumericsAlertRegistry to create such - a list. - """ - self._capacity = capacity - # A map from device-tensor key to a the TensorAlertRecord namedtuple. - # The device-tensor key is a 2-tuple of the format (device_name, node_name). - # E.g., ("/job:worker/replica:0/task:1/gpu:0", "cross_entropy/Log:0"). - self._data = dict() - - if initialization_list: - # Initialize the alert registry using the data passed in. This might be - # backup data used to restore the registry after say a borg pre-emption. - for entry in initialization_list: - triplet = HistoryTriplet._make(entry) - self._data[(triplet.device, triplet.tensor)] = NumericsAlertHistory( - initialization_list=triplet.jsonable_history) - - def register(self, numerics_alert): - """Register an alerting numeric event. +def extract_numerics_alert(event): + """Determines whether a health pill event contains bad values. - Args: - numerics_alert: An instance of `NumericsAlert`. - """ - key = (numerics_alert.device_name, numerics_alert.tensor_name) - if key in self._data: - self._data[key].add(numerics_alert) - else: - if len(self._data) < self._capacity: - history = NumericsAlertHistory() - history.add(numerics_alert) - self._data[key] = history - - def report(self, device_name_filter=None, tensor_name_filter=None): - """Get a report of offending device/tensor names. 
- - The report includes information about the device name, tensor name, first - (earliest) timestamp of the alerting events from the tensor, in addition to - counts of nan, positive inf and negative inf events. + A bad value is one of NaN, -Inf, or +Inf. Args: - device_name_filter: regex filter for device name, or None (not filtered). - tensor_name_filter: regex filter for tensor name, or None (not filtered). + event: (`Event`) A `tensorflow.Event` proto from `DebugNumericSummary` + ops. Returns: - A list of NumericsAlertReportRow, sorted by the first_timestamp in - asecnding order. - """ - report = [] - for key in self._data: - device_name, tensor_name = key - history = self._data[key] - report.append( - NumericsAlertReportRow( - device_name=device_name, - tensor_name=tensor_name, - first_timestamp=history.first_timestamp(), - nan_event_count=history.event_count(constants.NAN_KEY), - neg_inf_event_count=history.event_count(constants.NEG_INF_KEY), - pos_inf_event_count=history.event_count(constants.POS_INF_KEY))) - - if device_name_filter: - device_name_pattern = re.compile(device_name_filter) - report = [item for item in report - if device_name_pattern.match(item.device_name)] - if tensor_name_filter: - tensor_name_pattern = re.compile(tensor_name_filter) - report = [item for item in report - if tensor_name_pattern.match(item.tensor_name)] - # Sort results chronologically. - return sorted(report, key=lambda x: x.first_timestamp) - - def create_jsonable_registry(self): - """Creates a JSON-able representation of this object. + An instance of `NumericsAlert`, if bad values are found. + `None`, if no bad values are found. - Returns: - A dictionary mapping (device, tensor name) to JSON-able object - representations of NumericsAlertHistory. + Raises: + ValueError: if the event does not have the expected tag prefix or the + debug op name is not the expected debug op name suffix. """ - # JSON does not support tuples as keys. Only strings. 
Therefore, we store - # the device name, tensor name, and dictionary data within a 3-item list. - return [HistoryTriplet(pair[0], pair[1], history.create_jsonable_history()) - for (pair, history) in self._data.items()] - - -def extract_numerics_alert(event): - """Determines whether a health pill event contains bad values. - - A bad value is one of NaN, -Inf, or +Inf. - - Args: - event: (`Event`) A `tensorflow.Event` proto from `DebugNumericSummary` - ops. - - Returns: - An instance of `NumericsAlert`, if bad values are found. - `None`, if no bad values are found. - - Raises: - ValueError: if the event does not have the expected tag prefix or the - debug op name is not the expected debug op name suffix. - """ - value = event.summary.value[0] - debugger_plugin_metadata_content = None - if value.HasField("metadata"): - plugin_data = value.metadata.plugin_data - if plugin_data.plugin_name == constants.DEBUGGER_PLUGIN_NAME: - debugger_plugin_metadata_content = plugin_data.content - - if not debugger_plugin_metadata_content: - raise ValueError("Event proto input lacks debugger plugin SummaryMetadata.") - - debugger_plugin_metadata_content = tf.compat.as_text( - debugger_plugin_metadata_content) - try: - content_object = json.loads(debugger_plugin_metadata_content) - device_name = content_object["device"] - except (KeyError, ValueError) as e: - raise ValueError("Could not determine device from JSON string %r, %r" % - (debugger_plugin_metadata_content, e)) - - debug_op_suffix = ":DebugNumericSummary" - if not value.node_name.endswith(debug_op_suffix): - raise ValueError( - "Event proto input does not have the expected debug op suffix %s" % - debug_op_suffix) - tensor_name = value.node_name[:-len(debug_op_suffix)] - - elements = tf_debug.load_tensor_from_event(event) - nan_count = elements[constants.NAN_NUMERIC_SUMMARY_OP_INDEX] - neg_inf_count = elements[constants.NEG_INF_NUMERIC_SUMMARY_OP_INDEX] - pos_inf_count = elements[constants.POS_INF_NUMERIC_SUMMARY_OP_INDEX] - if 
nan_count > 0 or neg_inf_count > 0 or pos_inf_count > 0: - return NumericsAlert( - device_name, tensor_name, event.wall_time, nan_count, neg_inf_count, - pos_inf_count) - return None + value = event.summary.value[0] + debugger_plugin_metadata_content = None + if value.HasField("metadata"): + plugin_data = value.metadata.plugin_data + if plugin_data.plugin_name == constants.DEBUGGER_PLUGIN_NAME: + debugger_plugin_metadata_content = plugin_data.content + + if not debugger_plugin_metadata_content: + raise ValueError( + "Event proto input lacks debugger plugin SummaryMetadata." + ) + + debugger_plugin_metadata_content = tf.compat.as_text( + debugger_plugin_metadata_content + ) + try: + content_object = json.loads(debugger_plugin_metadata_content) + device_name = content_object["device"] + except (KeyError, ValueError) as e: + raise ValueError( + "Could not determine device from JSON string %r, %r" + % (debugger_plugin_metadata_content, e) + ) + + debug_op_suffix = ":DebugNumericSummary" + if not value.node_name.endswith(debug_op_suffix): + raise ValueError( + "Event proto input does not have the expected debug op suffix %s" + % debug_op_suffix + ) + tensor_name = value.node_name[: -len(debug_op_suffix)] + + elements = tf_debug.load_tensor_from_event(event) + nan_count = elements[constants.NAN_NUMERIC_SUMMARY_OP_INDEX] + neg_inf_count = elements[constants.NEG_INF_NUMERIC_SUMMARY_OP_INDEX] + pos_inf_count = elements[constants.POS_INF_NUMERIC_SUMMARY_OP_INDEX] + if nan_count > 0 or neg_inf_count > 0 or pos_inf_count > 0: + return NumericsAlert( + device_name, + tensor_name, + event.wall_time, + nan_count, + neg_inf_count, + pos_inf_count, + ) + return None diff --git a/tensorboard/plugins/debugger/numerics_alert_test.py b/tensorboard/plugins/debugger/numerics_alert_test.py index 50d2946e7d..bcf5c639e5 100644 --- a/tensorboard/plugins/debugger/numerics_alert_test.py +++ b/tensorboard/plugins/debugger/numerics_alert_test.py @@ -25,217 +25,367 @@ class 
NumericAlertHistoryTest(tf.test.TestCase): + def testConstructFromOneAlert(self): + alert = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 10, 0, 10 + ) + history = numerics_alert.NumericsAlertHistory() + history.add(alert) + self.assertEqual(1234, history.first_timestamp()) + self.assertEqual(1234, history.last_timestamp()) + self.assertEqual(1, history.event_count(constants.NAN_KEY)) + self.assertEqual(0, history.event_count(constants.NEG_INF_KEY)) + self.assertEqual(1, history.event_count(constants.POS_INF_KEY)) - def testConstructFromOneAlert(self): - alert = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 10, 0, 10) - history = numerics_alert.NumericsAlertHistory() - history.add(alert) - self.assertEqual(1234, history.first_timestamp()) - self.assertEqual(1234, history.last_timestamp()) - self.assertEqual(1, history.event_count(constants.NAN_KEY)) - self.assertEqual(0, history.event_count(constants.NEG_INF_KEY)) - self.assertEqual(1, history.event_count(constants.POS_INF_KEY)) - - def testAddAlertInChronologicalOrder(self): - history = numerics_alert.NumericsAlertHistory() - alert_1 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 10, 0, 10) - history.add(alert_1) - - alert_2 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1240, 20, 20, 0) - history.add(alert_2) - - self.assertEqual(1234, history.first_timestamp()) - self.assertEqual(1240, history.last_timestamp()) - self.assertEqual(2, history.event_count(constants.NAN_KEY)) - self.assertEqual(1, history.event_count(constants.NEG_INF_KEY)) - self.assertEqual(1, history.event_count(constants.POS_INF_KEY)) - - def testAddAlertInReverseChronologicalOrder(self): - history = numerics_alert.NumericsAlertHistory() - alert_1 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 10, 0, 10) - history.add(alert_1) - - 
alert_2 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1220, 20, 20, 0) - history.add(alert_2) - - self.assertEqual(1220, history.first_timestamp()) - self.assertEqual(1234, history.last_timestamp()) - self.assertEqual(2, history.event_count(constants.NAN_KEY)) - self.assertEqual(1, history.event_count(constants.NEG_INF_KEY)) - self.assertEqual(1, history.event_count(constants.POS_INF_KEY)) + def testAddAlertInChronologicalOrder(self): + history = numerics_alert.NumericsAlertHistory() + alert_1 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 10, 0, 10 + ) + history.add(alert_1) + + alert_2 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1240, 20, 20, 0 + ) + history.add(alert_2) + + self.assertEqual(1234, history.first_timestamp()) + self.assertEqual(1240, history.last_timestamp()) + self.assertEqual(2, history.event_count(constants.NAN_KEY)) + self.assertEqual(1, history.event_count(constants.NEG_INF_KEY)) + self.assertEqual(1, history.event_count(constants.POS_INF_KEY)) + + def testAddAlertInReverseChronologicalOrder(self): + history = numerics_alert.NumericsAlertHistory() + alert_1 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 10, 0, 10 + ) + history.add(alert_1) + + alert_2 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1220, 20, 20, 0 + ) + history.add(alert_2) + + self.assertEqual(1220, history.first_timestamp()) + self.assertEqual(1234, history.last_timestamp()) + self.assertEqual(2, history.event_count(constants.NAN_KEY)) + self.assertEqual(1, history.event_count(constants.NEG_INF_KEY)) + self.assertEqual(1, history.event_count(constants.POS_INF_KEY)) class NumericsAlertRegistryTest(tf.test.TestCase): + def testNoAlert(self): + registry = numerics_alert.NumericsAlertRegistry() + self.assertEqual([], registry.report()) + + def testSingleAlert(self): + 
alert = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 10, 10 + ) + registry = numerics_alert.NumericsAlertRegistry() + registry.register(alert) + self.assertEqual( + [ + numerics_alert.NumericsAlertReportRow( + "/job:worker/replica:0/task:0/gpu:0", + "xent/Log:0", + 1234, + 0, + 1, + 1, + ) + ], + registry.report(), + ) + + def testMultipleEventsFromSameDeviceAndSameTensor(self): + alert_1 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 10, 10 + ) + alert_2 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1634, 5, 5, 5 + ) + registry = numerics_alert.NumericsAlertRegistry() + registry.register(alert_1) + registry.register(alert_2) + self.assertEqual( + [ + numerics_alert.NumericsAlertReportRow( + "/job:worker/replica:0/task:0/gpu:0", + "xent/Log:0", + 1234, + 1, + 2, + 2, + ) + ], + registry.report(), + ) + + def testMultipleEventsFromSameDeviceAndDifferentTensor(self): + alert_1 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "div:0", 1434, 0, 1, 1 + ) + alert_2 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 2, 2 + ) + alert_3 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1634, 3, 3, 3 + ) + registry = numerics_alert.NumericsAlertRegistry() + registry.register(alert_1) + registry.register(alert_2) + registry.register(alert_3) + self.assertEqual( + [ + numerics_alert.NumericsAlertReportRow( + "/job:worker/replica:0/task:0/gpu:0", + "xent/Log:0", + 1234, + 1, + 2, + 2, + ), + numerics_alert.NumericsAlertReportRow( + "/job:worker/replica:0/task:0/gpu:0", "div:0", 1434, 0, 1, 1 + ), + ], + registry.report(), + ) + + def testMultipleEventsFromDifferentDevicesAndSameTensorName(self): + alert_1 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:1/gpu:0", "xent/Log:0", 1434, 0, 1, 1 + ) + alert_2 = 
numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 1, 1 + ) + alert_3 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1634, 1, 1, 1 + ) + registry = numerics_alert.NumericsAlertRegistry() + registry.register(alert_1) + registry.register(alert_2) + registry.register(alert_3) + self.assertEqual( + [ + numerics_alert.NumericsAlertReportRow( + "/job:worker/replica:0/task:0/gpu:0", + "xent/Log:0", + 1234, + 1, + 2, + 2, + ), + numerics_alert.NumericsAlertReportRow( + "/job:worker/replica:0/task:1/gpu:0", + "xent/Log:0", + 1434, + 0, + 1, + 1, + ), + ], + registry.report(), + ) + + def testFilterReport(self): + alert_1 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:1/gpu:0", "xent/Log:0", 1434, 0, 1, 1 + ) + alert_2 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 1, 1 + ) + alert_3 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Mean:0", 1634, 1, 1, 1 + ) + registry = numerics_alert.NumericsAlertRegistry() + registry.register(alert_1) + registry.register(alert_2) + registry.register(alert_3) + self.assertEqual( + [ + numerics_alert.NumericsAlertReportRow( + "/job:worker/replica:0/task:1/gpu:0", + "xent/Log:0", + 1434, + 0, + 1, + 1, + ) + ], + registry.report(device_name_filter=r".*\/task:1\/.*"), + ) + self.assertEqual( + [ + numerics_alert.NumericsAlertReportRow( + "/job:worker/replica:0/task:0/gpu:0", + "xent/Mean:0", + 1634, + 1, + 1, + 1, + ) + ], + registry.report(tensor_name_filter=r".*Mean.*"), + ) + + def testRegisterBeyondCapacityObeysCapacity(self): + alert_1 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:1/gpu:0", "xent/Log:0", 1434, 0, 1, 1 + ) + alert_2 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 1, 1 + ) + alert_3 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:2/gpu:0", "xent/Log:0", 1634, 0, 1, 1 + ) + 
alert_4 = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1834, 1, 1, 1 + ) + registry = numerics_alert.NumericsAlertRegistry(capacity=2) + registry.register(alert_1) + registry.register(alert_2) + self.assertEqual( + [ + numerics_alert.NumericsAlertReportRow( + "/job:worker/replica:0/task:0/gpu:0", + "xent/Log:0", + 1234, + 0, + 1, + 1, + ), + numerics_alert.NumericsAlertReportRow( + "/job:worker/replica:0/task:1/gpu:0", + "xent/Log:0", + 1434, + 0, + 1, + 1, + ), + ], + registry.report(), + ) + + registry.register(alert_3) + self.assertEqual( + [ + numerics_alert.NumericsAlertReportRow( + "/job:worker/replica:0/task:0/gpu:0", + "xent/Log:0", + 1234, + 0, + 1, + 1, + ), + numerics_alert.NumericsAlertReportRow( + "/job:worker/replica:0/task:1/gpu:0", + "xent/Log:0", + 1434, + 0, + 1, + 1, + ), + ], + registry.report(), + ) + + registry.register(alert_4) + self.assertEqual( + [ + numerics_alert.NumericsAlertReportRow( + "/job:worker/replica:0/task:0/gpu:0", + "xent/Log:0", + 1234, + 1, + 2, + 2, + ), + numerics_alert.NumericsAlertReportRow( + "/job:worker/replica:0/task:1/gpu:0", + "xent/Log:0", + 1434, + 0, + 1, + 1, + ), + ], + registry.report(), + ) + + def testCreateJsonableRegistry(self): + alert = numerics_alert.NumericsAlert( + "/job:worker/replica:0/task:1/gpu:0", "xent/Log:0", 1434, 0, 1, 1 + ) + registry = numerics_alert.NumericsAlertRegistry() + registry.register(alert) + + triplet_list = registry.create_jsonable_registry() + self.assertEqual(1, len(triplet_list)) + + triplet = triplet_list[0] + self.assertEqual("/job:worker/replica:0/task:1/gpu:0", triplet.device) + self.assertEqual("xent/Log:0", triplet.tensor) + self.assertListEqual([0, -1, -1], list(triplet.jsonable_history["nan"])) + self.assertListEqual( + [1, 1434, 1434], list(triplet.jsonable_history["neg_inf"]) + ) + self.assertListEqual( + [1, 1434, 1434], list(triplet.jsonable_history["pos_inf"]) + ) + + def testLoadFromJson(self): + registry = 
numerics_alert.NumericsAlertRegistry( + initialization_list=[ + [ + "/job:localhost/replica:0/task:0/cpu:0", + "MatMul:0", + { + "pos_inf": [0, -1, -1], + "nan": [1624, 1496818651573005, 1496818690371163], + "neg_inf": [0, -1, -1], + }, + ], + [ + "/job:localhost/replica:0/task:0/cpu:0", + "weight/Adagrad:0", + { + "pos_inf": [0, -1, -1], + "nan": [1621, 1496818651607234, 1496818690370891], + "neg_inf": [0, -1, -1], + }, + ], + ] + ) + self.assertEqual( + [ + numerics_alert.NumericsAlertReportRow( + "/job:localhost/replica:0/task:0/cpu:0", + "MatMul:0", + 1496818651573005, + 1624, + 0, + 0, + ), + numerics_alert.NumericsAlertReportRow( + "/job:localhost/replica:0/task:0/cpu:0", + "weight/Adagrad:0", + 1496818651607234, + 1621, + 0, + 0, + ), + ], + registry.report(), + ) - def testNoAlert(self): - registry = numerics_alert.NumericsAlertRegistry() - self.assertEqual([], registry.report()) - - def testSingleAlert(self): - alert = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 10, 10) - registry = numerics_alert.NumericsAlertRegistry() - registry.register(alert) - self.assertEqual( - [numerics_alert.NumericsAlertReportRow( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 1, 1)], - registry.report()) - - def testMultipleEventsFromSameDeviceAndSameTensor(self): - alert_1 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 10, 10) - alert_2 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1634, 5, 5, 5) - registry = numerics_alert.NumericsAlertRegistry() - registry.register(alert_1) - registry.register(alert_2) - self.assertEqual( - [numerics_alert.NumericsAlertReportRow( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 1, 2, 2)], - registry.report()) - - def testMultipleEventsFromSameDeviceAndDifferentTensor(self): - alert_1 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "div:0", 1434, 0, 1, 1) - 
alert_2 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 2, 2) - alert_3 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1634, 3, 3, 3) - registry = numerics_alert.NumericsAlertRegistry() - registry.register(alert_1) - registry.register(alert_2) - registry.register(alert_3) - self.assertEqual( - [numerics_alert.NumericsAlertReportRow( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 1, 2, 2), - numerics_alert.NumericsAlertReportRow( - "/job:worker/replica:0/task:0/gpu:0", "div:0", 1434, 0, 1, 1)], - registry.report()) - - def testMultipleEventsFromDifferentDevicesAndSameTensorName(self): - alert_1 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:1/gpu:0", "xent/Log:0", 1434, 0, 1, 1) - alert_2 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 1, 1) - alert_3 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1634, 1, 1, 1) - registry = numerics_alert.NumericsAlertRegistry() - registry.register(alert_1) - registry.register(alert_2) - registry.register(alert_3) - self.assertEqual( - [numerics_alert.NumericsAlertReportRow( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 1, 2, 2), - numerics_alert.NumericsAlertReportRow( - "/job:worker/replica:0/task:1/gpu:0", "xent/Log:0", 1434, 0, 1, - 1)], registry.report()) - - def testFilterReport(self): - alert_1 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:1/gpu:0", "xent/Log:0", 1434, 0, 1, 1) - alert_2 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 1, 1) - alert_3 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Mean:0", 1634, 1, 1, 1) - registry = numerics_alert.NumericsAlertRegistry() - registry.register(alert_1) - registry.register(alert_2) - registry.register(alert_3) - self.assertEqual( - [numerics_alert.NumericsAlertReportRow( - 
"/job:worker/replica:0/task:1/gpu:0", "xent/Log:0", 1434, 0, 1, 1)], - registry.report(device_name_filter=r".*\/task:1\/.*")) - self.assertEqual( - [numerics_alert.NumericsAlertReportRow( - "/job:worker/replica:0/task:0/gpu:0", "xent/Mean:0", 1634, 1, 1, - 1)], registry.report(tensor_name_filter=r".*Mean.*")) - - def testRegisterBeyondCapacityObeysCapacity(self): - alert_1 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:1/gpu:0", "xent/Log:0", 1434, 0, 1, 1) - alert_2 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 1, 1) - alert_3 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:2/gpu:0", "xent/Log:0", 1634, 0, 1, 1) - alert_4 = numerics_alert.NumericsAlert( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1834, 1, 1, 1) - registry = numerics_alert.NumericsAlertRegistry(capacity=2) - registry.register(alert_1) - registry.register(alert_2) - self.assertEqual( - [numerics_alert.NumericsAlertReportRow( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 1, 1), - numerics_alert.NumericsAlertReportRow( - "/job:worker/replica:0/task:1/gpu:0", "xent/Log:0", 1434, 0, 1, - 1)], registry.report()) - - registry.register(alert_3) - self.assertEqual( - [numerics_alert.NumericsAlertReportRow( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 1, 1), - numerics_alert.NumericsAlertReportRow( - "/job:worker/replica:0/task:1/gpu:0", "xent/Log:0", 1434, 0, 1, - 1)], registry.report()) - - registry.register(alert_4) - self.assertEqual( - [numerics_alert.NumericsAlertReportRow( - "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 1, 2, 2), - numerics_alert.NumericsAlertReportRow( - "/job:worker/replica:0/task:1/gpu:0", "xent/Log:0", 1434, 0, 1, - 1)], registry.report()) - - def testCreateJsonableRegistry(self): - alert = numerics_alert.NumericsAlert("/job:worker/replica:0/task:1/gpu:0", - "xent/Log:0", 1434, 0, 1, 1) - registry = numerics_alert.NumericsAlertRegistry() - 
registry.register(alert) - - triplet_list = registry.create_jsonable_registry() - self.assertEqual(1, len(triplet_list)) - - triplet = triplet_list[0] - self.assertEqual("/job:worker/replica:0/task:1/gpu:0", triplet.device) - self.assertEqual("xent/Log:0", triplet.tensor) - self.assertListEqual([0, -1, -1], list(triplet.jsonable_history["nan"])) - self.assertListEqual([1, 1434, 1434], - list(triplet.jsonable_history["neg_inf"])) - self.assertListEqual([1, 1434, 1434], - list(triplet.jsonable_history["pos_inf"])) - - def testLoadFromJson(self): - registry = numerics_alert.NumericsAlertRegistry(initialization_list=[[ - "/job:localhost/replica:0/task:0/cpu:0", "MatMul:0", { - "pos_inf": [0, -1, -1], - "nan": [1624, 1496818651573005, 1496818690371163], - "neg_inf": [0, -1, -1] - } - ], [ - "/job:localhost/replica:0/task:0/cpu:0", "weight/Adagrad:0", { - "pos_inf": [0, -1, -1], - "nan": [1621, 1496818651607234, 1496818690370891], - "neg_inf": [0, -1, -1] - } - ]]) - self.assertEqual([ - numerics_alert.NumericsAlertReportRow( - "/job:localhost/replica:0/task:0/cpu:0", "MatMul:0", - 1496818651573005, 1624, 0, 0), - numerics_alert.NumericsAlertReportRow( - "/job:localhost/replica:0/task:0/cpu:0", "weight/Adagrad:0", - 1496818651607234, 1621, 0, 0) - ], registry.report()) - - def testCreateEmptyJsonableRegistry(self): - """Tests that an empty registry yields an empty report.""" - registry = numerics_alert.NumericsAlertRegistry() - self.assertEqual([], registry.report()) + def testCreateEmptyJsonableRegistry(self): + """Tests that an empty registry yields an empty report.""" + registry = numerics_alert.NumericsAlertRegistry() + self.assertEqual([], registry.report()) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/debugger/session_debug_test.py b/tensorboard/plugins/debugger/session_debug_test.py index e9ada5204f..79c9e855ad 100644 --- a/tensorboard/plugins/debugger/session_debug_test.py +++ 
b/tensorboard/plugins/debugger/session_debug_test.py @@ -15,10 +15,10 @@ """Tests end-to-end debugger data server behavior by starting TensorBoard. This test launches an instance of TensorBoard as a subprocess. In turn, -TensorBoard (specifically its debugger plugin) starts a debugger data server. -The test then calls Session.run() using RunOptions pointing to the grpc:// debug -URL of the debugger data server. It then checks the correctness of the Event -proto file created by the debugger data server. +TensorBoard (specifically its debugger plugin) starts a debugger data +server. The test then calls Session.run() using RunOptions pointing to +the grpc:// debug URL of the debugger data server. It then checks the +correctness of the Event proto file created by the debugger data server. """ from __future__ import absolute_import @@ -36,7 +36,9 @@ import numpy as np import portpicker # pylint: disable=import-error import tensorflow.compat.v1 as tf # pylint: disable=wrong-import-order -from tensorflow.python import debug as tf_debug # pylint: disable=wrong-import-order +from tensorflow.python import ( + debug as tf_debug, +) # pylint: disable=wrong-import-order from tensorboard.plugins.debugger import constants from tensorboard.plugins.debugger import debugger_server_lib @@ -48,286 +50,327 @@ class SessionDebugTestBase(tf.test.TestCase): + def setUp(self): + self._debugger_data_server_grpc_port = portpicker.pick_unused_port() + self._debug_url = ( + "grpc://localhost:%d" % self._debugger_data_server_grpc_port + ) + self._logdir = tempfile.mkdtemp(prefix="tensorboard_dds_") + + self._debug_data_server = debugger_server_lib.DebuggerDataServer( + self._debugger_data_server_grpc_port, + self._logdir, + always_flush=True, + ) + self._server_thread = threading.Thread( + target=self._debug_data_server.start_the_debugger_data_receiving_server + ) + self._server_thread.start() + + self.assertTrue(self._poll_server_till_success(50, 0.2)) + + def tearDown(self): + 
self._debug_data_server.stop_server() + self._server_thread.join() + + if os.path.isdir(self._logdir): + shutil.rmtree(self._logdir) + + tf.reset_default_graph() + + def _poll_server_till_success(self, max_tries, poll_interval_seconds): + for _ in range(max_tries): + try: + with tf.Session() as sess: + a_init_val = np.array([42.0]) + a_init = tf.constant(a_init_val, shape=[1], name="a_init") + a = tf.Variable(a_init, name="a") + + run_options = tf.RunOptions(output_partition_graphs=True) + tf_debug.watch_graph( + run_options, + sess.graph, + debug_ops=["DebugNumericSummary"], + debug_urls=[self._debug_url], + ) + + sess.run(a.initializer, options=run_options) + return True + except tf.errors.FailedPreconditionError as exc: + time.sleep(poll_interval_seconds) + + return False + + def _compute_health_pill(self, x): + x_clean = x[ + np.where( + np.logical_and( + np.logical_not(np.isnan(x)), np.logical_not(np.isinf(x)) + ) + ) + ] + if np.size(x_clean): + x_min = np.min(x_clean) + x_max = np.max(x_clean) + x_mean = np.mean(x_clean) + x_var = np.var(x_clean) + else: + x_min = np.inf + x_max = -np.inf + x_mean = np.nan + x_var = np.nan + + return np.array( + [ + 1.0, # Assume is initialized. + np.size(x), + np.sum(np.isnan(x)), + np.sum(x == -np.inf), + np.sum(np.logical_and(x < 0.0, x != -np.inf)), + np.sum(x == 0.0), + np.sum(np.logical_and(x > 0.0, x != np.inf)), + np.sum(x == np.inf), + x_min, + x_max, + x_mean, + x_var, + float(tf.as_dtype(x.dtype).as_datatype_enum), + float(len(x.shape)), + ] + + list(x.shape) + ) + + def _check_health_pills_in_events_file( + self, events_file_path, debug_key_to_tensors + ): + reader = tf.python_io.tf_record_iterator(events_file_path) + event_read = tf.Event() + + # The first event in the file should contain the events version, which is + # important because without it, TensorBoard may purge health pill events. 
+ event_read.ParseFromString(next(reader)) + self.assertEqual("brain.Event:2", event_read.file_version) + + health_pills = {} + while True: + next_event = next(reader, None) + if not next_event: + break + event_read.ParseFromString(next_event) + values = event_read.summary.value + if values: + if ( + values[0].metadata.plugin_data.plugin_name + == constants.DEBUGGER_PLUGIN_NAME + ): + debug_key = values[0].node_name + if debug_key not in health_pills: + health_pills[debug_key] = [ + tf_debug.load_tensor_from_event(event_read) + ] + else: + health_pills[debug_key].append( + tf_debug.load_tensor_from_event(event_read) + ) + + for debug_key in debug_key_to_tensors: + tensors = debug_key_to_tensors[debug_key] + for i, tensor in enumerate(tensors): + self.assertAllClose( + self._compute_health_pill(tensor), + health_pills[debug_key][i], + ) + + def testRunSimpleNetworkoWithInfAndNaNWorks(self): + with tf.Session() as sess: + x_init_val = np.array([[2.0], [-1.0]]) + y_init_val = np.array([[0.0], [-0.25]]) + z_init_val = np.array([[0.0, 3.0], [-1.0, 0.0]]) + + x_init = tf.constant(x_init_val, shape=[2, 1], name="x_init") + x = tf.Variable(x_init, name="x") + y_init = tf.constant(y_init_val, shape=[2, 1]) + y = tf.Variable(y_init, name="y") + z_init = tf.constant(z_init_val, shape=[2, 2]) + z = tf.Variable(z_init, name="z") + + u = tf.div(x, y, name="u") # Produces an Inf. + v = tf.matmul(z, u, name="v") # Produces NaN and Inf. + + sess.run(x.initializer) + sess.run(y.initializer) + sess.run(z.initializer) + + run_options = tf.RunOptions(output_partition_graphs=True) + tf_debug.watch_graph( + run_options, + sess.graph, + debug_ops=["DebugNumericSummary"], + debug_urls=[self._debug_url], + ) + + result = sess.run(v, options=run_options) + self.assertTrue(np.isnan(result[0, 0])) + self.assertEqual(-np.inf, result[1, 0]) + + # Debugger data is stored within a special directory within logdir. 
+ event_files = glob.glob( + os.path.join( + self._logdir, + constants.DEBUGGER_DATA_DIRECTORY_NAME, + "events.debugger*", + ) + ) + self.assertEqual(1, len(event_files)) + + self._check_health_pills_in_events_file( + event_files[0], + { + "x:0:DebugNumericSummary": [x_init_val], + "y:0:DebugNumericSummary": [y_init_val], + "z:0:DebugNumericSummary": [z_init_val], + "u:0:DebugNumericSummary": [x_init_val / y_init_val], + "v:0:DebugNumericSummary": [ + np.matmul(z_init_val, x_init_val / y_init_val) + ], + }, + ) + + report = self._debug_data_server.numerics_alert_report() + self.assertEqual(2, len(report)) + self.assertTrue(report[0].device_name.lower().endswith("cpu:0")) + self.assertEqual("u:0", report[0].tensor_name) + self.assertGreater(report[0].first_timestamp, 0) + self.assertEqual(0, report[0].nan_event_count) + self.assertEqual(0, report[0].neg_inf_event_count) + self.assertEqual(1, report[0].pos_inf_event_count) + self.assertTrue(report[1].device_name.lower().endswith("cpu:0")) + self.assertEqual("u:0", report[0].tensor_name) + self.assertGreaterEqual( + report[1].first_timestamp, report[0].first_timestamp + ) + self.assertEqual(1, report[1].nan_event_count) + self.assertEqual(1, report[1].neg_inf_event_count) + self.assertEqual(0, report[1].pos_inf_event_count) + + def testMultipleInt32ValuesOverMultipleRunsAreRecorded(self): + with tf.Session() as sess: + x_init_val = np.array([10], dtype=np.int32) + x_init = tf.constant(x_init_val, shape=[1], name="x_init") + x = tf.Variable(x_init, name="x") + + x_inc_val = np.array([2], dtype=np.int32) + x_inc = tf.constant(x_inc_val, name="x_inc") + inc_x = tf.assign_add(x, x_inc, name="inc_x") + + sess.run(x.initializer) + + run_options = tf.RunOptions(output_partition_graphs=True) + tf_debug.watch_graph( + run_options, + sess.graph, + debug_ops=["DebugNumericSummary"], + debug_urls=[self._debug_url], + ) + + # Increase three times. 
+ for _ in range(3): + sess.run(inc_x, options=run_options) + + # Debugger data is stored within a special directory within logdir. + event_files = glob.glob( + os.path.join( + self._logdir, + constants.DEBUGGER_DATA_DIRECTORY_NAME, + "events.debugger*", + ) + ) + self.assertEqual(1, len(event_files)) + + self._check_health_pills_in_events_file( + event_files[0], + { + "x_inc:0:DebugNumericSummary": [x_inc_val] * 3, + "x:0:DebugNumericSummary": [ + x_init_val, + x_init_val + x_inc_val, + x_init_val + 2 * x_inc_val, + ], + }, + ) + + def testConcurrentNumericsAlertsAreRegisteredCorrectly(self): + num_threads = 3 + num_runs_per_thread = 2 + total_num_runs = num_threads * num_runs_per_thread + + # Before any Session runs, the report ought to be empty. + self.assertEqual([], self._debug_data_server.numerics_alert_report()) - def setUp(self): - self._debugger_data_server_grpc_port = portpicker.pick_unused_port() - self._debug_url = ( - "grpc://localhost:%d" % self._debugger_data_server_grpc_port) - self._logdir = tempfile.mkdtemp(prefix="tensorboard_dds_") - - self._debug_data_server = debugger_server_lib.DebuggerDataServer( - self._debugger_data_server_grpc_port, self._logdir, always_flush=True) - self._server_thread = threading.Thread( - target=self._debug_data_server.start_the_debugger_data_receiving_server) - self._server_thread.start() - - self.assertTrue(self._poll_server_till_success(50, 0.2)) - - def tearDown(self): - self._debug_data_server.stop_server() - self._server_thread.join() - - if os.path.isdir(self._logdir): - shutil.rmtree(self._logdir) - - tf.reset_default_graph() - - def _poll_server_till_success(self, max_tries, poll_interval_seconds): - for _ in range(max_tries): - try: with tf.Session() as sess: - a_init_val = np.array([42.0]) - a_init = tf.constant(a_init_val, shape=[1], name="a_init") - a = tf.Variable(a_init, name="a") - - run_options = tf.RunOptions(output_partition_graphs=True) - tf_debug.watch_graph(run_options, - sess.graph, - 
debug_ops=["DebugNumericSummary"], - debug_urls=[self._debug_url]) - - sess.run(a.initializer, options=run_options) - return True - except tf.errors.FailedPreconditionError as exc: - time.sleep(poll_interval_seconds) - - return False - - def _compute_health_pill(self, x): - x_clean = x[np.where( - np.logical_and( - np.logical_not(np.isnan(x)), np.logical_not(np.isinf(x))))] - if np.size(x_clean): - x_min = np.min(x_clean) - x_max = np.max(x_clean) - x_mean = np.mean(x_clean) - x_var = np.var(x_clean) - else: - x_min = np.inf - x_max = -np.inf - x_mean = np.nan - x_var = np.nan - - return np.array([ - 1.0, # Assume is initialized. - np.size(x), - np.sum(np.isnan(x)), - np.sum(x == -np.inf), - np.sum(np.logical_and(x < 0.0, x != -np.inf)), - np.sum(x == 0.0), - np.sum(np.logical_and(x > 0.0, x != np.inf)), - np.sum(x == np.inf), - x_min, - x_max, - x_mean, - x_var, - float(tf.as_dtype(x.dtype).as_datatype_enum), - float(len(x.shape)), - ] + list(x.shape)) - - def _check_health_pills_in_events_file(self, - events_file_path, - debug_key_to_tensors): - reader = tf.python_io.tf_record_iterator(events_file_path) - event_read = tf.Event() - - # The first event in the file should contain the events version, which is - # important because without it, TensorBoard may purge health pill events. 
- event_read.ParseFromString(next(reader)) - self.assertEqual("brain.Event:2", event_read.file_version) - - health_pills = {} - while True: - next_event = next(reader, None) - if not next_event: - break - event_read.ParseFromString(next_event) - values = event_read.summary.value - if values: - if (values[0].metadata.plugin_data.plugin_name == - constants.DEBUGGER_PLUGIN_NAME): - debug_key = values[0].node_name - if debug_key not in health_pills: - health_pills[debug_key] = [ - tf_debug.load_tensor_from_event(event_read)] - else: - health_pills[debug_key].append( - tf_debug.load_tensor_from_event(event_read)) - - for debug_key in debug_key_to_tensors: - tensors = debug_key_to_tensors[debug_key] - for i, tensor in enumerate(tensors): - self.assertAllClose( - self._compute_health_pill(tensor), - health_pills[debug_key][i]) - - def testRunSimpleNetworkoWithInfAndNaNWorks(self): - with tf.Session() as sess: - x_init_val = np.array([[2.0], [-1.0]]) - y_init_val = np.array([[0.0], [-0.25]]) - z_init_val = np.array([[0.0, 3.0], [-1.0, 0.0]]) - - x_init = tf.constant(x_init_val, shape=[2, 1], name="x_init") - x = tf.Variable(x_init, name="x") - y_init = tf.constant(y_init_val, shape=[2, 1]) - y = tf.Variable(y_init, name="y") - z_init = tf.constant(z_init_val, shape=[2, 2]) - z = tf.Variable(z_init, name="z") - - u = tf.div(x, y, name="u") # Produces an Inf. - v = tf.matmul(z, u, name="v") # Produces NaN and Inf. - - sess.run(x.initializer) - sess.run(y.initializer) - sess.run(z.initializer) - - run_options = tf.RunOptions(output_partition_graphs=True) - tf_debug.watch_graph(run_options, - sess.graph, - debug_ops=["DebugNumericSummary"], - debug_urls=[self._debug_url]) - - result = sess.run(v, options=run_options) - self.assertTrue(np.isnan(result[0, 0])) - self.assertEqual(-np.inf, result[1, 0]) - - # Debugger data is stored within a special directory within logdir. 
- event_files = glob.glob( - os.path.join(self._logdir, constants.DEBUGGER_DATA_DIRECTORY_NAME, - "events.debugger*")) - self.assertEqual(1, len(event_files)) - - self._check_health_pills_in_events_file(event_files[0], { - "x:0:DebugNumericSummary": [x_init_val], - "y:0:DebugNumericSummary": [y_init_val], - "z:0:DebugNumericSummary": [z_init_val], - "u:0:DebugNumericSummary": [x_init_val / y_init_val], - "v:0:DebugNumericSummary": [ - np.matmul(z_init_val, x_init_val / y_init_val) - ], - }) - - report = self._debug_data_server.numerics_alert_report() - self.assertEqual(2, len(report)) - self.assertTrue(report[0].device_name.lower().endswith("cpu:0")) - self.assertEqual("u:0", report[0].tensor_name) - self.assertGreater(report[0].first_timestamp, 0) - self.assertEqual(0, report[0].nan_event_count) - self.assertEqual(0, report[0].neg_inf_event_count) - self.assertEqual(1, report[0].pos_inf_event_count) - self.assertTrue(report[1].device_name.lower().endswith("cpu:0")) - self.assertEqual("u:0", report[0].tensor_name) - self.assertGreaterEqual(report[1].first_timestamp, - report[0].first_timestamp) - self.assertEqual(1, report[1].nan_event_count) - self.assertEqual(1, report[1].neg_inf_event_count) - self.assertEqual(0, report[1].pos_inf_event_count) - - def testMultipleInt32ValuesOverMultipleRunsAreRecorded(self): - with tf.Session() as sess: - x_init_val = np.array([10], dtype=np.int32) - x_init = tf.constant(x_init_val, shape=[1], name="x_init") - x = tf.Variable(x_init, name="x") - - x_inc_val = np.array([2], dtype=np.int32) - x_inc = tf.constant(x_inc_val, name="x_inc") - inc_x = tf.assign_add(x, x_inc, name="inc_x") - - sess.run(x.initializer) - - run_options = tf.RunOptions(output_partition_graphs=True) - tf_debug.watch_graph(run_options, - sess.graph, - debug_ops=["DebugNumericSummary"], - debug_urls=[self._debug_url]) - - # Increase three times. 
- for _ in range(3): - sess.run(inc_x, options=run_options) - - # Debugger data is stored within a special directory within logdir. - event_files = glob.glob( - os.path.join(self._logdir, constants.DEBUGGER_DATA_DIRECTORY_NAME, - "events.debugger*")) - self.assertEqual(1, len(event_files)) - - self._check_health_pills_in_events_file( - event_files[0], - { - "x_inc:0:DebugNumericSummary": [x_inc_val] * 3, - "x:0:DebugNumericSummary": [ - x_init_val, - x_init_val + x_inc_val, - x_init_val + 2 * x_inc_val], - }) - - def testConcurrentNumericsAlertsAreRegisteredCorrectly(self): - num_threads = 3 - num_runs_per_thread = 2 - total_num_runs = num_threads * num_runs_per_thread - - # Before any Session runs, the report ought to be empty. - self.assertEqual([], self._debug_data_server.numerics_alert_report()) - - with tf.Session() as sess: - x_init_val = np.array([[2.0], [-1.0]]) - y_init_val = np.array([[0.0], [-0.25]]) - z_init_val = np.array([[0.0, 3.0], [-1.0, 0.0]]) - - x_init = tf.constant(x_init_val, shape=[2, 1], name="x_init") - x = tf.Variable(x_init, name="x") - y_init = tf.constant(y_init_val, shape=[2, 1]) - y = tf.Variable(y_init, name="y") - z_init = tf.constant(z_init_val, shape=[2, 2]) - z = tf.Variable(z_init, name="z") - - u = tf.div(x, y, name="u") # Produces an Inf. - v = tf.matmul(z, u, name="v") # Produces NaN and Inf. - - sess.run(x.initializer) - sess.run(y.initializer) - sess.run(z.initializer) - - run_options_list = [] - for i in range(num_threads): - run_options = tf.RunOptions(output_partition_graphs=True) - # Use different grpc:// URL paths so that each thread opens a separate - # gRPC stream to the debug data server, simulating multi-worker setting. 
- tf_debug.watch_graph(run_options, - sess.graph, - debug_ops=["DebugNumericSummary"], - debug_urls=[self._debug_url + "/thread%d" % i]) - run_options_list.append(run_options) - - def run_v(thread_id): - for _ in range(num_runs_per_thread): - sess.run(v, options=run_options_list[thread_id]) - - run_threads = [] - for thread_id in range(num_threads): - thread = threading.Thread(target=functools.partial(run_v, thread_id)) - thread.start() - run_threads.append(thread) - - for thread in run_threads: - thread.join() - - report = self._debug_data_server.numerics_alert_report() - self.assertEqual(2, len(report)) - self.assertTrue(report[0].device_name.lower().endswith("cpu:0")) - self.assertEqual("u:0", report[0].tensor_name) - self.assertGreater(report[0].first_timestamp, 0) - self.assertEqual(0, report[0].nan_event_count) - self.assertEqual(0, report[0].neg_inf_event_count) - self.assertEqual(total_num_runs, report[0].pos_inf_event_count) - self.assertTrue(report[1].device_name.lower().endswith("cpu:0")) - self.assertEqual("u:0", report[0].tensor_name) - self.assertGreaterEqual(report[1].first_timestamp, - report[0].first_timestamp) - self.assertEqual(total_num_runs, report[1].nan_event_count) - self.assertEqual(total_num_runs, report[1].neg_inf_event_count) - self.assertEqual(0, report[1].pos_inf_event_count) + x_init_val = np.array([[2.0], [-1.0]]) + y_init_val = np.array([[0.0], [-0.25]]) + z_init_val = np.array([[0.0, 3.0], [-1.0, 0.0]]) + + x_init = tf.constant(x_init_val, shape=[2, 1], name="x_init") + x = tf.Variable(x_init, name="x") + y_init = tf.constant(y_init_val, shape=[2, 1]) + y = tf.Variable(y_init, name="y") + z_init = tf.constant(z_init_val, shape=[2, 2]) + z = tf.Variable(z_init, name="z") + + u = tf.div(x, y, name="u") # Produces an Inf. + v = tf.matmul(z, u, name="v") # Produces NaN and Inf. 
+ + sess.run(x.initializer) + sess.run(y.initializer) + sess.run(z.initializer) + + run_options_list = [] + for i in range(num_threads): + run_options = tf.RunOptions(output_partition_graphs=True) + # Use different grpc:// URL paths so that each thread opens a separate + # gRPC stream to the debug data server, simulating multi-worker setting. + tf_debug.watch_graph( + run_options, + sess.graph, + debug_ops=["DebugNumericSummary"], + debug_urls=[self._debug_url + "/thread%d" % i], + ) + run_options_list.append(run_options) + + def run_v(thread_id): + for _ in range(num_runs_per_thread): + sess.run(v, options=run_options_list[thread_id]) + + run_threads = [] + for thread_id in range(num_threads): + thread = threading.Thread( + target=functools.partial(run_v, thread_id) + ) + thread.start() + run_threads.append(thread) + + for thread in run_threads: + thread.join() + + report = self._debug_data_server.numerics_alert_report() + self.assertEqual(2, len(report)) + self.assertTrue(report[0].device_name.lower().endswith("cpu:0")) + self.assertEqual("u:0", report[0].tensor_name) + self.assertGreater(report[0].first_timestamp, 0) + self.assertEqual(0, report[0].nan_event_count) + self.assertEqual(0, report[0].neg_inf_event_count) + self.assertEqual(total_num_runs, report[0].pos_inf_event_count) + self.assertTrue(report[1].device_name.lower().endswith("cpu:0")) + self.assertEqual("u:0", report[0].tensor_name) + self.assertGreaterEqual( + report[1].first_timestamp, report[0].first_timestamp + ) + self.assertEqual(total_num_runs, report[1].nan_event_count) + self.assertEqual(total_num_runs, report[1].neg_inf_event_count) + self.assertEqual(0, report[1].pos_inf_event_count) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/debugger/tensor_helper.py b/tensorboard/plugins/debugger/tensor_helper.py index 1be66445f6..c9048ee6d1 100644 --- a/tensorboard/plugins/debugger/tensor_helper.py +++ b/tensorboard/plugins/debugger/tensor_helper.py @@ 
-29,140 +29,147 @@ def numel(shape): - """Obtain total number of elements from a tensor (ndarray) shape. + """Obtain total number of elements from a tensor (ndarray) shape. - Args: - shape: A list or tuple represenitng a tensor (ndarray) shape. - """ - output = 1 - for dim in shape: - output *= dim - return output + Args: + shape: A list or tuple represenitng a tensor (ndarray) shape. + """ + output = 1 + for dim in shape: + output *= dim + return output def parse_time_indices(s): - """Parse a string as time indices. - - Args: - s: A valid slicing string for time indices. E.g., '-1', '[:]', ':', '2:10' - - Returns: - A slice object. - - Raises: - ValueError: If `s` does not represent valid time indices. - """ - if not s.startswith('['): - s = '[' + s + ']' - parsed = command_parser._parse_slices(s) - if len(parsed) != 1: - raise ValueError( - 'Invalid number of slicing objects in time indices (%d)' % len(parsed)) - else: - return parsed[0] + """Parse a string as time indices. + + Args: + s: A valid slicing string for time indices. E.g., '-1', '[:]', ':', '2:10' + + Returns: + A slice object. + + Raises: + ValueError: If `s` does not represent valid time indices. + """ + if not s.startswith("["): + s = "[" + s + "]" + parsed = command_parser._parse_slices(s) + if len(parsed) != 1: + raise ValueError( + "Invalid number of slicing objects in time indices (%d)" + % len(parsed) + ) + else: + return parsed[0] def translate_dtype(dtype): - """Translate numpy dtype into a string. + """Translate numpy dtype into a string. - The 'object' type is understood as a TensorFlow string and translated into - 'string'. + The 'object' type is understood as a TensorFlow string and translated into + 'string'. - Args: - dtype: A numpy dtype object. + Args: + dtype: A numpy dtype object. - Returns: - A string representing the data type. - """ - out = str(dtype) - # String-type TensorFlow Tensors are represented as object-type arrays in - # numpy. 
We map the type name back to 'string' for clarity. - return 'string' if out == 'object' else out + Returns: + A string representing the data type. + """ + out = str(dtype) + # String-type TensorFlow Tensors are represented as object-type arrays in + # numpy. We map the type name back to 'string' for clarity. + return "string" if out == "object" else out def process_buffers_for_display(s, limit=40): - """Process a buffer for human-readable display. - - This function performs the following operation on each of the buffers in `s`. - 1. Truncate input buffer if the length of the buffer is greater than - `limit`, to prevent large strings from overloading the frontend. - 2. Apply `binascii.b2a_qp` on the truncated buffer to make the buffer - printable and convertible to JSON. - 3. If truncation happened (in step 1), append a string at the end - describing the original length and the truncation. - - Args: - s: The buffer to be processed, either a single buffer or a nested array of - them. - limit: Length limit for each buffer, beyond which truncation will occur. - - Return: - A single processed buffer or a nested array of processed buffers. - """ - if isinstance(s, (list, tuple)): - return [process_buffers_for_display(elem, limit=limit) for elem in s] - else: - length = len(s) - if length > limit: - return (binascii.b2a_qp(s[:limit]) + - b' (length-%d truncated at %d bytes)' % (length, limit)) + """Process a buffer for human-readable display. + + This function performs the following operation on each of the buffers in `s`. + 1. Truncate input buffer if the length of the buffer is greater than + `limit`, to prevent large strings from overloading the frontend. + 2. Apply `binascii.b2a_qp` on the truncated buffer to make the buffer + printable and convertible to JSON. + 3. If truncation happened (in step 1), append a string at the end + describing the original length and the truncation. 
+ + Args: + s: The buffer to be processed, either a single buffer or a nested array of + them. + limit: Length limit for each buffer, beyond which truncation will occur. + + Return: + A single processed buffer or a nested array of processed buffers. + """ + if isinstance(s, (list, tuple)): + return [process_buffers_for_display(elem, limit=limit) for elem in s] else: - return binascii.b2a_qp(s) + length = len(s) + if length > limit: + return binascii.b2a_qp( + s[:limit] + ) + b" (length-%d truncated at %d bytes)" % (length, limit) + else: + return binascii.b2a_qp(s) def array_view(array, slicing=None, mapping=None): - """View a slice or the entirety of an ndarray. - - Args: - array: The input array, as an numpy.ndarray. - slicing: Optional slicing string, e.g., "[:, 1:3, :]". - mapping: Optional mapping string. Supported mappings: - `None` or case-insensitive `'None'`: Unmapped nested list. - `'image/png'`: Image encoding of a 2D sliced array or 3D sliced array - with 3 as the last dimension. If the sliced array is not 2D or 3D with - 3 as the last dimension, a `ValueError` will be thrown. - `health-pill`: A succinct summary of the numeric values of a tensor. - See documentation in [`health_pill_calc.py`] for more details. - - Returns: - 1. dtype as a `str`. - 2. shape of the sliced array, as a tuple of `int`s. - 3. the potentially sliced values, as a nested `list`. - """ - - dtype = translate_dtype(array.dtype) - sliced_array = (array[command_parser._parse_slices(slicing)] if slicing - else array) - - if np.isscalar(sliced_array) and str(dtype) == 'string': - # When a string Tensor (for which dtype is 'object') is sliced down to only - # one element, it becomes a string, instead of an numpy array. - # We preserve the dimensionality of original array in the returned shape - # and slice. 
- ndims = len(array.shape) - slice_shape = [] - for _ in range(ndims): - sliced_array = [sliced_array] - slice_shape.append(1) - return dtype, tuple(slice_shape), sliced_array - else: - shape = sliced_array.shape - if mapping == "image/png": - if len(sliced_array.shape) == 2: - return dtype, shape, array_to_base64_png(sliced_array) - elif len(sliced_array.shape) == 3: - raise NotImplementedError( - "image/png mapping for 3D array has not been implemented") - else: - raise ValueError("Invalid rank for image/png mapping: %d" % - len(sliced_array.shape)) - elif mapping == 'health-pill': - health_pill = health_pill_calc.calc_health_pill(array) - return dtype, shape, health_pill - elif mapping is None or mapping == '' or mapping.lower() == 'none': - return dtype, shape, sliced_array.tolist() + """View a slice or the entirety of an ndarray. + + Args: + array: The input array, as an numpy.ndarray. + slicing: Optional slicing string, e.g., "[:, 1:3, :]". + mapping: Optional mapping string. Supported mappings: + `None` or case-insensitive `'None'`: Unmapped nested list. + `'image/png'`: Image encoding of a 2D sliced array or 3D sliced array + with 3 as the last dimension. If the sliced array is not 2D or 3D with + 3 as the last dimension, a `ValueError` will be thrown. + `health-pill`: A succinct summary of the numeric values of a tensor. + See documentation in [`health_pill_calc.py`] for more details. + + Returns: + 1. dtype as a `str`. + 2. shape of the sliced array, as a tuple of `int`s. + 3. the potentially sliced values, as a nested `list`. + """ + + dtype = translate_dtype(array.dtype) + sliced_array = ( + array[command_parser._parse_slices(slicing)] if slicing else array + ) + + if np.isscalar(sliced_array) and str(dtype) == "string": + # When a string Tensor (for which dtype is 'object') is sliced down to only + # one element, it becomes a string, instead of an numpy array. + # We preserve the dimensionality of original array in the returned shape + # and slice. 
+ ndims = len(array.shape) + slice_shape = [] + for _ in range(ndims): + sliced_array = [sliced_array] + slice_shape.append(1) + return dtype, tuple(slice_shape), sliced_array else: - raise ValueError("Invalid mapping: %s" % mapping) + shape = sliced_array.shape + if mapping == "image/png": + if len(sliced_array.shape) == 2: + return dtype, shape, array_to_base64_png(sliced_array) + elif len(sliced_array.shape) == 3: + raise NotImplementedError( + "image/png mapping for 3D array has not been implemented" + ) + else: + raise ValueError( + "Invalid rank for image/png mapping: %d" + % len(sliced_array.shape) + ) + elif mapping == "health-pill": + health_pill = health_pill_calc.calc_health_pill(array) + return dtype, shape, health_pill + elif mapping is None or mapping == "" or mapping.lower() == "none": + return dtype, shape, sliced_array.tolist() + else: + raise ValueError("Invalid mapping: %s" % mapping) IMAGE_COLOR_CHANNELS = 3 @@ -172,52 +179,59 @@ def array_view(array, slicing=None, mapping=None): def array_to_base64_png(array): - """Convert an array into base64-enoded PNG image. - - Args: - array: A 2D np.ndarray or nested list of items. - - Returns: - A base64-encoded string the image. The image is grayscale if the array is - 2D. The image is RGB color if the image is 3D with lsat dimension equal to - 3. - - Raises: - ValueError: If the input `array` is not rank-2, or if the rank-2 `array` is - empty. - """ - # TODO(cais): Deal with 3D case. - # TODO(cais): If there are None values in here, replace them with all NaNs. - array = np.array(array, dtype=np.float32) - if len(array.shape) != 2: - raise ValueError( - "Expected rank-2 array; received rank-%d array." % len(array.shape)) - if not np.size(array): - raise ValueError( - "Cannot encode an empty array (size: %s) as image." 
% (array.shape,)) - - is_infinity = np.isinf(array) - is_positive = array > 0.0 - is_positive_infinity = np.logical_and(is_infinity, is_positive) - is_negative_infinity = np.logical_and(is_infinity, - np.logical_not(is_positive)) - is_nan = np.isnan(array) - finite_indices = np.where(np.logical_and(np.logical_not(is_infinity), - np.logical_not(is_nan))) - if np.size(finite_indices): - # Finite subset is not empty. - minval = np.min(array[finite_indices]) - maxval = np.max(array[finite_indices]) - scaled = np.array((array - minval) / (maxval - minval) * 255, - dtype=np.uint8) - rgb = np.repeat(np.expand_dims(scaled, -1), IMAGE_COLOR_CHANNELS, axis=-1) - else: - rgb = np.zeros(array.shape + (IMAGE_COLOR_CHANNELS,), dtype=np.uint8) - - # Color-code pixels that correspond to infinities and nans. - rgb[is_positive_infinity] = POSITIVE_INFINITY_RGB - rgb[is_negative_infinity] = NEGATIVE_INFINITY_RGB - rgb[is_nan] = NAN_RGB - - image_encoded = base64.b64encode(encoder.encode_png(rgb)) - return image_encoded + """Convert an array into base64-enoded PNG image. + + Args: + array: A 2D np.ndarray or nested list of items. + + Returns: + A base64-encoded string the image. The image is grayscale if the array is + 2D. The image is RGB color if the image is 3D with lsat dimension equal to + 3. + + Raises: + ValueError: If the input `array` is not rank-2, or if the rank-2 `array` is + empty. + """ + # TODO(cais): Deal with 3D case. + # TODO(cais): If there are None values in here, replace them with all NaNs. + array = np.array(array, dtype=np.float32) + if len(array.shape) != 2: + raise ValueError( + "Expected rank-2 array; received rank-%d array." % len(array.shape) + ) + if not np.size(array): + raise ValueError( + "Cannot encode an empty array (size: %s) as image." 
% (array.shape,) + ) + + is_infinity = np.isinf(array) + is_positive = array > 0.0 + is_positive_infinity = np.logical_and(is_infinity, is_positive) + is_negative_infinity = np.logical_and( + is_infinity, np.logical_not(is_positive) + ) + is_nan = np.isnan(array) + finite_indices = np.where( + np.logical_and(np.logical_not(is_infinity), np.logical_not(is_nan)) + ) + if np.size(finite_indices): + # Finite subset is not empty. + minval = np.min(array[finite_indices]) + maxval = np.max(array[finite_indices]) + scaled = np.array( + (array - minval) / (maxval - minval) * 255, dtype=np.uint8 + ) + rgb = np.repeat( + np.expand_dims(scaled, -1), IMAGE_COLOR_CHANNELS, axis=-1 + ) + else: + rgb = np.zeros(array.shape + (IMAGE_COLOR_CHANNELS,), dtype=np.uint8) + + # Color-code pixels that correspond to infinities and nans. + rgb[is_positive_infinity] = POSITIVE_INFINITY_RGB + rgb[is_negative_infinity] = NEGATIVE_INFINITY_RGB + rgb[is_nan] = NAN_RGB + + image_encoded = base64.b64encode(encoder.encode_png(rgb)) + return image_encoded diff --git a/tensorboard/plugins/debugger/tensor_helper_test.py b/tensorboard/plugins/debugger/tensor_helper_test.py index cf34b6e43b..d895c2d197 100644 --- a/tensorboard/plugins/debugger/tensor_helper_test.py +++ b/tensorboard/plugins/debugger/tensor_helper_test.py @@ -30,259 +30,280 @@ class TranslateDTypeTest(tf.test.TestCase): + def testTranslateNumericDTypes(self): + x = np.zeros([2, 2], dtype=np.float32) + self.assertEqual("float32", tensor_helper.translate_dtype(x.dtype)) + x = np.zeros([2], dtype=np.int16) + self.assertEqual("int16", tensor_helper.translate_dtype(x.dtype)) + x = np.zeros([], dtype=np.uint8) + self.assertEqual("uint8", tensor_helper.translate_dtype(x.dtype)) - def testTranslateNumericDTypes(self): - x = np.zeros([2, 2], dtype=np.float32) - self.assertEqual('float32', tensor_helper.translate_dtype(x.dtype)) - x = np.zeros([2], dtype=np.int16) - self.assertEqual('int16', tensor_helper.translate_dtype(x.dtype)) - x = 
np.zeros([], dtype=np.uint8) - self.assertEqual('uint8', tensor_helper.translate_dtype(x.dtype)) + def testTranslateBooleanDType(self): + x = np.zeros([2, 2], dtype=np.bool) + self.assertEqual("bool", tensor_helper.translate_dtype(x.dtype)) - def testTranslateBooleanDType(self): - x = np.zeros([2, 2], dtype=np.bool) - self.assertEqual('bool', tensor_helper.translate_dtype(x.dtype)) - - def testTranslateStringDType(self): - x = np.array(['abc'], dtype=np.object) - self.assertEqual('string', tensor_helper.translate_dtype(x.dtype)) + def testTranslateStringDType(self): + x = np.array(["abc"], dtype=np.object) + self.assertEqual("string", tensor_helper.translate_dtype(x.dtype)) class ProcessBuffersForDisplayTest(tf.test.TestCase): - - def testBinaryScalarBelowLimit(self): - x = b'\x01\x02\x03' - self.assertEqual(binascii.b2a_qp(x), - tensor_helper.process_buffers_for_display(x, 10)) - - def testAsciiScalarBelowLimit(self): - x = b'foo_bar' - self.assertEqual(b'foo_bar', - tensor_helper.process_buffers_for_display(x, 10)) - - def testBinaryScalarAboveLimit(self): - x = b'\x01\x02\x03' - self.assertEqual( - binascii.b2a_qp(x[:2]) + b' (length-3 truncated at 2 bytes)', - tensor_helper.process_buffers_for_display(x, 2)) - - def testAsciiScalarAboveLimit(self): - x = b'foo_bar' - self.assertEqual(b'foo_ (length-7 truncated at 4 bytes)', - tensor_helper.process_buffers_for_display(x, 4)) - - def testNestedArrayMixed(self): - x = [[b'\x01\x02\x03', b'foo_bar'], [b'\x01', b'f']] - self.assertEqual( - [[b'=01=02 (length-3 truncated at 2 bytes)', - b'fo (length-7 truncated at 2 bytes)'], - [b'=01', b'f']], tensor_helper.process_buffers_for_display(x, 2)) + def testBinaryScalarBelowLimit(self): + x = b"\x01\x02\x03" + self.assertEqual( + binascii.b2a_qp(x), tensor_helper.process_buffers_for_display(x, 10) + ) + + def testAsciiScalarBelowLimit(self): + x = b"foo_bar" + self.assertEqual( + b"foo_bar", tensor_helper.process_buffers_for_display(x, 10) + ) + + def 
testBinaryScalarAboveLimit(self): + x = b"\x01\x02\x03" + self.assertEqual( + binascii.b2a_qp(x[:2]) + b" (length-3 truncated at 2 bytes)", + tensor_helper.process_buffers_for_display(x, 2), + ) + + def testAsciiScalarAboveLimit(self): + x = b"foo_bar" + self.assertEqual( + b"foo_ (length-7 truncated at 4 bytes)", + tensor_helper.process_buffers_for_display(x, 4), + ) + + def testNestedArrayMixed(self): + x = [[b"\x01\x02\x03", b"foo_bar"], [b"\x01", b"f"]] + self.assertEqual( + [ + [ + b"=01=02 (length-3 truncated at 2 bytes)", + b"fo (length-7 truncated at 2 bytes)", + ], + [b"=01", b"f"], + ], + tensor_helper.process_buffers_for_display(x, 2), + ) class TensorHelperTest(tf.test.TestCase): - - def testArrayViewFloat2DNoSlicing(self): - float_array = np.ones([3, 3], dtype=np.float32) - dtype, shape, values = tensor_helper.array_view(float_array) - self.assertEqual("float32", dtype) - self.assertEqual((3, 3), shape) - self.assertEqual(float_array.tolist(), values) - - def testArrayViewFloat2DWithSlicing(self): - x = np.ones([4, 4], dtype=np.float64) - y = np.zeros([4, 4], dtype=np.float64) - float_array = np.concatenate((x, y), axis=1) - - dtype, shape, values = tensor_helper.array_view( - float_array, slicing="[2:, :]") - self.assertEqual("float64", dtype) - self.assertEqual((2, 8), shape) - self.assertAllClose( - [[1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0]], values) - - def testArrayViewInt3DWithSlicing(self): - x = np.ones([4, 4], dtype=np.int32) - int_array = np.zeros([3, 4, 4], dtype=np.int32) - int_array[0, ...] = x - int_array[1, ...] = 2 * x - int_array[2, ...] 
= 3 * x - - dtype, shape, values = tensor_helper.array_view( - int_array, slicing="[:, :, 2]") - self.assertEqual("int32", dtype) - self.assertEqual((3, 4), shape) - self.assertEqual([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]], values) - - def testArrayView2DWithSlicingAndImagePngMapping(self): - x = np.ones([15, 16], dtype=np.int32) - dtype, shape, data = tensor_helper.array_view( - x, slicing="[:15:3, :16:2]", mapping="image/png") - self.assertEqual("int32", dtype) - self.assertEqual((5, 8), shape) - decoded_x = im_util.decode_png(base64.b64decode(data)) - self.assertEqual((5, 8, 3), decoded_x.shape) - self.assertEqual(np.uint8, decoded_x.dtype) - self.assertAllClose(np.zeros([5, 8, 3]), decoded_x) - - def testImagePngMappingWorksForArrayWithOnlyOneElement(self): - x = np.array([[-42]], dtype=np.int16) - dtype, shape, data = tensor_helper.array_view(x, mapping="image/png") - self.assertEqual("int16", dtype) - self.assertEqual((1, 1), shape) - decoded_x = im_util.decode_png(base64.b64decode(data)) - self.assertEqual((1, 1, 3), decoded_x.shape) - self.assertEqual(np.uint8, decoded_x.dtype) - self.assertAllClose(np.zeros([1, 1, 3]), decoded_x) - - def testImagePngMappingWorksForArrayWithInfAndNaN(self): - x = np.array([[1.1, 2.2, np.inf], [-np.inf, 3.3, np.nan]], dtype=np.float32) - dtype, shape, data = tensor_helper.array_view(x, mapping="image/png") - self.assertEqual("float32", dtype) - self.assertEqual((2, 3), shape) - decoded_x = im_util.decode_png(base64.b64decode(data)) - self.assertEqual((2, 3, 3), decoded_x.shape) - self.assertEqual(np.uint8, decoded_x.dtype) - self.assertAllClose([0, 0, 0], decoded_x[0, 0, :]) # 1.1. - self.assertAllClose([127, 127, 127], decoded_x[0, 1, :]) # 2.2. - self.assertAllClose(tensor_helper.POSITIVE_INFINITY_RGB, - decoded_x[0, 2, :]) # +infinity. - self.assertAllClose(tensor_helper.NEGATIVE_INFINITY_RGB, - decoded_x[1, 0, :]) # -infinity. - self.assertAllClose([255, 255, 255], decoded_x[1, 1, :]) # 3.3. 
- self.assertAllClose(tensor_helper.NAN_RGB, decoded_x[1, 2, :]) # nan. - - def testArrayViewSlicingDownNumericTensorToOneElement(self): - x = np.array([[1.1, 2.2, np.inf], [-np.inf, 3.3, np.nan]], dtype=np.float32) - dtype, shape, data = tensor_helper.array_view(x, slicing='[0,0]') - self.assertEqual('float32', dtype) - self.assertEqual(tuple(), shape) - self.assertTrue(np.allclose(1.1, data)) - - def testArrayViewSlicingStringTensorToNonScalarSubArray(self): - # Construct a numpy array that corresponds to a TensorFlow string tensor - # value. - x = np.array([['foo', 'bar', 'qux'], ['baz', 'corge', 'grault']], - dtype=np.object) - dtype, shape, data = tensor_helper.array_view(x, slicing='[:2, :2]') - self.assertEqual('string', dtype) - self.assertEqual((2, 2), shape) - self.assertEqual([['foo', 'bar'], ['baz', 'corge']], data) - - def testArrayViewSlicingStringTensorToScalar(self): - # Construct a numpy array that corresponds to a TensorFlow string tensor - # value. - x = np.array([['foo', 'bar', 'qux'], ['baz', 'corge', 'grault']], - dtype=np.object) - dtype, shape, data = tensor_helper.array_view(x, slicing='[1, 1]') - self.assertEqual('string', dtype) - self.assertEqual((1, 1), shape) - self.assertEqual([['corge']], data) - - def testArrayViewOnScalarString(self): - # Construct a numpy scalar that corresponds to a TensorFlow string tensor - # value. 
- x = np.array('foo', dtype=np.object) - dtype, shape, data = tensor_helper.array_view(x) - self.assertEqual('string', dtype) - self.assertEqual(tuple(), shape) - self.assertEqual('foo', data) - - def testImagePngMappingWorksForArrayWithOnlyInfAndNaN(self): - x = np.array([[np.nan, -np.inf], [np.inf, np.nan]], dtype=np.float32) - dtype, shape, data = tensor_helper.array_view(x, mapping="image/png") - self.assertEqual("float32", dtype) - self.assertEqual((2, 2), shape) - decoded_x = im_util.decode_png(base64.b64decode(data)) - self.assertEqual((2, 2, 3), decoded_x.shape) - self.assertEqual(np.uint8, decoded_x.dtype) - self.assertAllClose(tensor_helper.NAN_RGB, decoded_x[0, 0, :]) # nan. - self.assertAllClose(tensor_helper.NEGATIVE_INFINITY_RGB, - decoded_x[0, 1, :]) # -infinity. - self.assertAllClose(tensor_helper.POSITIVE_INFINITY_RGB, - decoded_x[1, 0, :]) # +infinity. - self.assertAllClose(tensor_helper.NAN_RGB, decoded_x[1, 1, :]) # nan. - - def testImagePngMappingRaisesExceptionForEmptyArray(self): - x = np.zeros([0, 0]) - with six.assertRaisesRegex( - self, ValueError, r"Cannot encode an empty array .* \(0, 0\)"): - tensor_helper.array_view(x, mapping="image/png") - - def testImagePngMappingRaisesExceptionForNonRank2Array(self): - x = np.ones([2, 2, 2]) - with six.assertRaisesRegex( - self, ValueError, r"Expected rank-2 array; received rank-3 array"): - tensor_helper.array_to_base64_png(x) + def testArrayViewFloat2DNoSlicing(self): + float_array = np.ones([3, 3], dtype=np.float32) + dtype, shape, values = tensor_helper.array_view(float_array) + self.assertEqual("float32", dtype) + self.assertEqual((3, 3), shape) + self.assertEqual(float_array.tolist(), values) + + def testArrayViewFloat2DWithSlicing(self): + x = np.ones([4, 4], dtype=np.float64) + y = np.zeros([4, 4], dtype=np.float64) + float_array = np.concatenate((x, y), axis=1) + + dtype, shape, values = tensor_helper.array_view( + float_array, slicing="[2:, :]" + ) + self.assertEqual("float64", dtype) + 
self.assertEqual((2, 8), shape) + self.assertAllClose( + [[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0]], values + ) + + def testArrayViewInt3DWithSlicing(self): + x = np.ones([4, 4], dtype=np.int32) + int_array = np.zeros([3, 4, 4], dtype=np.int32) + int_array[0, ...] = x + int_array[1, ...] = 2 * x + int_array[2, ...] = 3 * x + + dtype, shape, values = tensor_helper.array_view( + int_array, slicing="[:, :, 2]" + ) + self.assertEqual("int32", dtype) + self.assertEqual((3, 4), shape) + self.assertEqual([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]], values) + + def testArrayView2DWithSlicingAndImagePngMapping(self): + x = np.ones([15, 16], dtype=np.int32) + dtype, shape, data = tensor_helper.array_view( + x, slicing="[:15:3, :16:2]", mapping="image/png" + ) + self.assertEqual("int32", dtype) + self.assertEqual((5, 8), shape) + decoded_x = im_util.decode_png(base64.b64decode(data)) + self.assertEqual((5, 8, 3), decoded_x.shape) + self.assertEqual(np.uint8, decoded_x.dtype) + self.assertAllClose(np.zeros([5, 8, 3]), decoded_x) + + def testImagePngMappingWorksForArrayWithOnlyOneElement(self): + x = np.array([[-42]], dtype=np.int16) + dtype, shape, data = tensor_helper.array_view(x, mapping="image/png") + self.assertEqual("int16", dtype) + self.assertEqual((1, 1), shape) + decoded_x = im_util.decode_png(base64.b64decode(data)) + self.assertEqual((1, 1, 3), decoded_x.shape) + self.assertEqual(np.uint8, decoded_x.dtype) + self.assertAllClose(np.zeros([1, 1, 3]), decoded_x) + + def testImagePngMappingWorksForArrayWithInfAndNaN(self): + x = np.array( + [[1.1, 2.2, np.inf], [-np.inf, 3.3, np.nan]], dtype=np.float32 + ) + dtype, shape, data = tensor_helper.array_view(x, mapping="image/png") + self.assertEqual("float32", dtype) + self.assertEqual((2, 3), shape) + decoded_x = im_util.decode_png(base64.b64decode(data)) + self.assertEqual((2, 3, 3), decoded_x.shape) + self.assertEqual(np.uint8, decoded_x.dtype) + self.assertAllClose([0, 0, 0], decoded_x[0, 0, :]) # 1.1. 
+ self.assertAllClose([127, 127, 127], decoded_x[0, 1, :]) # 2.2. + self.assertAllClose( + tensor_helper.POSITIVE_INFINITY_RGB, decoded_x[0, 2, :] + ) # +infinity. + self.assertAllClose( + tensor_helper.NEGATIVE_INFINITY_RGB, decoded_x[1, 0, :] + ) # -infinity. + self.assertAllClose([255, 255, 255], decoded_x[1, 1, :]) # 3.3. + self.assertAllClose(tensor_helper.NAN_RGB, decoded_x[1, 2, :]) # nan. + + def testArrayViewSlicingDownNumericTensorToOneElement(self): + x = np.array( + [[1.1, 2.2, np.inf], [-np.inf, 3.3, np.nan]], dtype=np.float32 + ) + dtype, shape, data = tensor_helper.array_view(x, slicing="[0,0]") + self.assertEqual("float32", dtype) + self.assertEqual(tuple(), shape) + self.assertTrue(np.allclose(1.1, data)) + + def testArrayViewSlicingStringTensorToNonScalarSubArray(self): + # Construct a numpy array that corresponds to a TensorFlow string tensor + # value. + x = np.array( + [["foo", "bar", "qux"], ["baz", "corge", "grault"]], dtype=np.object + ) + dtype, shape, data = tensor_helper.array_view(x, slicing="[:2, :2]") + self.assertEqual("string", dtype) + self.assertEqual((2, 2), shape) + self.assertEqual([["foo", "bar"], ["baz", "corge"]], data) + + def testArrayViewSlicingStringTensorToScalar(self): + # Construct a numpy array that corresponds to a TensorFlow string tensor + # value. + x = np.array( + [["foo", "bar", "qux"], ["baz", "corge", "grault"]], dtype=np.object + ) + dtype, shape, data = tensor_helper.array_view(x, slicing="[1, 1]") + self.assertEqual("string", dtype) + self.assertEqual((1, 1), shape) + self.assertEqual([["corge"]], data) + + def testArrayViewOnScalarString(self): + # Construct a numpy scalar that corresponds to a TensorFlow string tensor + # value. 
+ x = np.array("foo", dtype=np.object) + dtype, shape, data = tensor_helper.array_view(x) + self.assertEqual("string", dtype) + self.assertEqual(tuple(), shape) + self.assertEqual("foo", data) + + def testImagePngMappingWorksForArrayWithOnlyInfAndNaN(self): + x = np.array([[np.nan, -np.inf], [np.inf, np.nan]], dtype=np.float32) + dtype, shape, data = tensor_helper.array_view(x, mapping="image/png") + self.assertEqual("float32", dtype) + self.assertEqual((2, 2), shape) + decoded_x = im_util.decode_png(base64.b64decode(data)) + self.assertEqual((2, 2, 3), decoded_x.shape) + self.assertEqual(np.uint8, decoded_x.dtype) + self.assertAllClose(tensor_helper.NAN_RGB, decoded_x[0, 0, :]) # nan. + self.assertAllClose( + tensor_helper.NEGATIVE_INFINITY_RGB, decoded_x[0, 1, :] + ) # -infinity. + self.assertAllClose( + tensor_helper.POSITIVE_INFINITY_RGB, decoded_x[1, 0, :] + ) # +infinity. + self.assertAllClose(tensor_helper.NAN_RGB, decoded_x[1, 1, :]) # nan. + + def testImagePngMappingRaisesExceptionForEmptyArray(self): + x = np.zeros([0, 0]) + with six.assertRaisesRegex( + self, ValueError, r"Cannot encode an empty array .* \(0, 0\)" + ): + tensor_helper.array_view(x, mapping="image/png") + + def testImagePngMappingRaisesExceptionForNonRank2Array(self): + x = np.ones([2, 2, 2]) + with six.assertRaisesRegex( + self, ValueError, r"Expected rank-2 array; received rank-3 array" + ): + tensor_helper.array_to_base64_png(x) class ArrayToBase64PNGTest(tf.test.TestCase): - - def testConvertHealthy2DArray(self): - x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - encoded_x = tensor_helper.array_to_base64_png(x) - decoded_x = im_util.decode_png(base64.b64decode(encoded_x)) - self.assertEqual((3, 3, 3), decoded_x.shape) - decoded_flat = decoded_x.flatten() - self.assertEqual(0, np.min(decoded_flat)) - self.assertEqual(255, np.max(decoded_flat)) - - def testConvertHealthy2DNestedList(self): - x = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] - encoded_x = 
tensor_helper.array_to_base64_png(x) - decoded_x = im_util.decode_png(base64.b64decode(encoded_x)) - self.assertEqual((4, 4, 3), decoded_x.shape) - decoded_flat = decoded_x.flatten() - self.assertEqual(0, np.min(decoded_flat)) - self.assertEqual(255, np.max(decoded_flat)) + def testConvertHealthy2DArray(self): + x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + encoded_x = tensor_helper.array_to_base64_png(x) + decoded_x = im_util.decode_png(base64.b64decode(encoded_x)) + self.assertEqual((3, 3, 3), decoded_x.shape) + decoded_flat = decoded_x.flatten() + self.assertEqual(0, np.min(decoded_flat)) + self.assertEqual(255, np.max(decoded_flat)) + + def testConvertHealthy2DNestedList(self): + x = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] + encoded_x = tensor_helper.array_to_base64_png(x) + decoded_x = im_util.decode_png(base64.b64decode(encoded_x)) + self.assertEqual((4, 4, 3), decoded_x.shape) + decoded_flat = decoded_x.flatten() + self.assertEqual(0, np.min(decoded_flat)) + self.assertEqual(255, np.max(decoded_flat)) class ParseTimeIndicesTest(tf.test.TestCase): - - def testParseSingleIntegerMinusOne(self): - slicing = tensor_helper.parse_time_indices('-1') - self.assertEqual(-1, slicing) - - def testParseSingleIntegerMinusOneWithBrackets(self): - slicing = tensor_helper.parse_time_indices('[-1]') - self.assertEqual(-1, slicing) - - def testParseSlicingWithStartAndStop(self): - slicing = tensor_helper.parse_time_indices('[0:3]') - self.assertEqual(slice(0, 3, None), slicing) - slicing = tensor_helper.parse_time_indices('0:3') - self.assertEqual(slice(0, 3, None), slicing) - - def testParseSlicingWithStep(self): - slicing = tensor_helper.parse_time_indices('[::2]') - self.assertEqual(slice(None, None, 2), slicing) - slicing = tensor_helper.parse_time_indices('::2') - self.assertEqual(slice(None, None, 2), slicing) - - def testParseSlicingWithOnlyStart(self): - slicing = tensor_helper.parse_time_indices('[3:]') - self.assertEqual(slice(3, None, 
None), slicing) - slicing = tensor_helper.parse_time_indices('3:') - self.assertEqual(slice(3, None, None), slicing) - - def testParseSlicingWithMinusOneStop(self): - slicing = tensor_helper.parse_time_indices('[3:-1]') - self.assertEqual(slice(3, -1, None), slicing) - slicing = tensor_helper.parse_time_indices('3:-1') - self.assertEqual(slice(3, -1, None), slicing) - - def testParseSlicingWithOnlyStop(self): - slicing = tensor_helper.parse_time_indices('[:-2]') - self.assertEqual(slice(None, -2, None), slicing) - slicing = tensor_helper.parse_time_indices(':-2') - self.assertEqual(slice(None, -2, None), slicing) - - def test2DSlicingLeadsToError(self): - with self.assertRaises(ValueError): - tensor_helper.parse_time_indices('[1:2, 3:4]') - with self.assertRaises(ValueError): - tensor_helper.parse_time_indices('1:2,3:4') - - -if __name__ == '__main__': - tf.test.main() + def testParseSingleIntegerMinusOne(self): + slicing = tensor_helper.parse_time_indices("-1") + self.assertEqual(-1, slicing) + + def testParseSingleIntegerMinusOneWithBrackets(self): + slicing = tensor_helper.parse_time_indices("[-1]") + self.assertEqual(-1, slicing) + + def testParseSlicingWithStartAndStop(self): + slicing = tensor_helper.parse_time_indices("[0:3]") + self.assertEqual(slice(0, 3, None), slicing) + slicing = tensor_helper.parse_time_indices("0:3") + self.assertEqual(slice(0, 3, None), slicing) + + def testParseSlicingWithStep(self): + slicing = tensor_helper.parse_time_indices("[::2]") + self.assertEqual(slice(None, None, 2), slicing) + slicing = tensor_helper.parse_time_indices("::2") + self.assertEqual(slice(None, None, 2), slicing) + + def testParseSlicingWithOnlyStart(self): + slicing = tensor_helper.parse_time_indices("[3:]") + self.assertEqual(slice(3, None, None), slicing) + slicing = tensor_helper.parse_time_indices("3:") + self.assertEqual(slice(3, None, None), slicing) + + def testParseSlicingWithMinusOneStop(self): + slicing = tensor_helper.parse_time_indices("[3:-1]") + 
self.assertEqual(slice(3, -1, None), slicing) + slicing = tensor_helper.parse_time_indices("3:-1") + self.assertEqual(slice(3, -1, None), slicing) + + def testParseSlicingWithOnlyStop(self): + slicing = tensor_helper.parse_time_indices("[:-2]") + self.assertEqual(slice(None, -2, None), slicing) + slicing = tensor_helper.parse_time_indices(":-2") + self.assertEqual(slice(None, -2, None), slicing) + + def test2DSlicingLeadsToError(self): + with self.assertRaises(ValueError): + tensor_helper.parse_time_indices("[1:2, 3:4]") + with self.assertRaises(ValueError): + tensor_helper.parse_time_indices("1:2,3:4") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/debugger/tensor_store.py b/tensorboard/plugins/debugger/tensor_store.py index 9552498cc5..1f7881bebb 100644 --- a/tensorboard/plugins/debugger/tensor_store.py +++ b/tensorboard/plugins/debugger/tensor_store.py @@ -31,231 +31,236 @@ class _TensorValueDiscarded(object): + def __init__(self, watch_key, time_index): + self._watch_key = watch_key + self._time_index = time_index - def __init__(self, watch_key, time_index): - self._watch_key = watch_key - self._time_index = time_index + @property + def watch_key(self): + return self._watch_key - @property - def watch_key(self): - return self._watch_key + @property + def time_index(self): + return self._time_index - @property - def time_index(self): - return self._time_index - - @property - def nbytes(self): - return 0 + @property + def nbytes(self): + return 0 class _WatchStore(object): - """The store for a single debug tensor watch. - - Discards data according to pre-set byte limit. - """ - - def __init__(self, - watch_key, - mem_bytes_limit=10e6): - """Constructor of _WatchStore. - - The overflowing works as follows: - The most recent tensor values are stored in memory, up to `mem_bytes_limit` - bytes. But at least one (the most recent) value is always stored in memory. - For older tensors exceeding that limit, they are discarded. 
- - Args: - watch_key: A string representing the debugger tensor watch, with th - format: - :: - e.g., - 'Dense_1/BiasAdd:0:DebugIdentity'. - mem_bytes_limit: Limit on number of bytes to store in memory. - """ + """The store for a single debug tensor watch. - self._watch_key = watch_key - self._mem_bytes_limit = mem_bytes_limit - self._in_mem_bytes = 0 - self._disposed = False - self._data = [] # A map from index to tensor value. - - def add(self, value): - """Add a tensor the watch store.""" - if self._disposed: - raise ValueError( - 'Cannot add value: this _WatchStore instance is already disposed') - self._data.append(value) - if hasattr(value, 'nbytes'): - self._in_mem_bytes += value.nbytes - self._ensure_bytes_limits() - - def _ensure_bytes_limits(self): - # TODO(cais): Thread safety? - if self._in_mem_bytes <= self._mem_bytes_limit: - return - - i = len(self._data) - 1 - cum_mem_size = 0 - while i >= 0: - if hasattr(self._data[i], 'nbytes'): - cum_mem_size += self._data[i].nbytes - if i < len(self._data) - 1 and cum_mem_size > self._mem_bytes_limit: - # Always keep at least one time index in the memory. - break - i -= 1 - # i is now the last time index to discard. - - # Mark remaining ones as discarded. 
- while i >= 0: - if not isinstance(self._data[i], _TensorValueDiscarded): - self._data[i] = _TensorValueDiscarded(self._watch_key, i) - i -= 1 - - def num_total(self): - """Get the total number of values.""" - return len(self._data) - - def num_in_memory(self): - """Get number of values in memory.""" - n = len(self._data) - 1 - while n >= 0: - if isinstance(self._data[n], _TensorValueDiscarded): - break - n -= 1 - return len(self._data) - 1 - n - - def num_discarded(self): - """Get the number of values discarded due to exceeding both limits.""" - if not self._data: - return 0 - n = 0 - while n < len(self._data): - if not isinstance(self._data[n], _TensorValueDiscarded): - break - n += 1 - return n - - def query(self, time_indices): - """Query the values at given time indices. - - Args: - time_indices: 0-based time indices to query, as a `list` of `int`. - - Returns: - Values as a list of `numpy.ndarray` (for time indices in memory) or - `None` (for time indices discarded). + Discards data according to pre-set byte limit. """ - if self._disposed: - raise ValueError( - 'Cannot query: this _WatchStore instance is already disposed') - if not isinstance(time_indices, (tuple, list)): - time_indices = [time_indices] - output = [] - for time_index in time_indices: - if isinstance(self._data[time_index], _TensorValueDiscarded): - output.append(None) - else: - data_item = self._data[time_index] - if (hasattr(data_item, 'dtype') and - tensor_helper.translate_dtype(data_item.dtype) == 'string'): - _, _, data_item = tensor_helper.array_view(data_item) - data_item = np.array( - tensor_helper.process_buffers_for_display(data_item), - dtype=np.object) - output.append(data_item) - - return output - - def dispose(self): - self._disposed = True - - -class TensorStore(object): - def __init__(self, watch_mem_bytes_limit=10e6): - """Constructor of TensorStore. + def __init__(self, watch_key, mem_bytes_limit=10e6): + """Constructor of _WatchStore. 
+ + The overflowing works as follows: + The most recent tensor values are stored in memory, up to `mem_bytes_limit` + bytes. But at least one (the most recent) value is always stored in memory. + For older tensors exceeding that limit, they are discarded. + + Args: + watch_key: A string representing the debugger tensor watch, with th + format: + :: + e.g., + 'Dense_1/BiasAdd:0:DebugIdentity'. + mem_bytes_limit: Limit on number of bytes to store in memory. + """ + + self._watch_key = watch_key + self._mem_bytes_limit = mem_bytes_limit + self._in_mem_bytes = 0 + self._disposed = False + self._data = [] # A map from index to tensor value. + + def add(self, value): + """Add a tensor the watch store.""" + if self._disposed: + raise ValueError( + "Cannot add value: this _WatchStore instance is already disposed" + ) + self._data.append(value) + if hasattr(value, "nbytes"): + self._in_mem_bytes += value.nbytes + self._ensure_bytes_limits() + + def _ensure_bytes_limits(self): + # TODO(cais): Thread safety? + if self._in_mem_bytes <= self._mem_bytes_limit: + return + + i = len(self._data) - 1 + cum_mem_size = 0 + while i >= 0: + if hasattr(self._data[i], "nbytes"): + cum_mem_size += self._data[i].nbytes + if i < len(self._data) - 1 and cum_mem_size > self._mem_bytes_limit: + # Always keep at least one time index in the memory. + break + i -= 1 + # i is now the last time index to discard. + + # Mark remaining ones as discarded. 
+ while i >= 0: + if not isinstance(self._data[i], _TensorValueDiscarded): + self._data[i] = _TensorValueDiscarded(self._watch_key, i) + i -= 1 + + def num_total(self): + """Get the total number of values.""" + return len(self._data) + + def num_in_memory(self): + """Get number of values in memory.""" + n = len(self._data) - 1 + while n >= 0: + if isinstance(self._data[n], _TensorValueDiscarded): + break + n -= 1 + return len(self._data) - 1 - n + + def num_discarded(self): + """Get the number of values discarded due to exceeding both limits.""" + if not self._data: + return 0 + n = 0 + while n < len(self._data): + if not isinstance(self._data[n], _TensorValueDiscarded): + break + n += 1 + return n + + def query(self, time_indices): + """Query the values at given time indices. + + Args: + time_indices: 0-based time indices to query, as a `list` of `int`. + + Returns: + Values as a list of `numpy.ndarray` (for time indices in memory) or + `None` (for time indices discarded). + """ + if self._disposed: + raise ValueError( + "Cannot query: this _WatchStore instance is already disposed" + ) + if not isinstance(time_indices, (tuple, list)): + time_indices = [time_indices] + output = [] + for time_index in time_indices: + if isinstance(self._data[time_index], _TensorValueDiscarded): + output.append(None) + else: + data_item = self._data[time_index] + if ( + hasattr(data_item, "dtype") + and tensor_helper.translate_dtype(data_item.dtype) + == "string" + ): + _, _, data_item = tensor_helper.array_view(data_item) + data_item = np.array( + tensor_helper.process_buffers_for_display(data_item), + dtype=np.object, + ) + output.append(data_item) + + return output + + def dispose(self): + self._disposed = True - Args: - watch_mem_bytes_limit: Limit on number of bytes to store in memory for - each watch key. - """ - self._watch_mem_bytes_limit = watch_mem_bytes_limit - self._tensor_data = dict() # A map from watch key to _WatchStore instances. 
- - def add(self, watch_key, tensor_value): - """Add a tensor value. - Args: - watch_key: A string representing the debugger tensor watch, e.g., - 'Dense_1/BiasAdd:0:DebugIdentity'. - tensor_value: The value of the tensor as a numpy.ndarray. - """ - if watch_key not in self._tensor_data: - self._tensor_data[watch_key] = _WatchStore( - watch_key, - mem_bytes_limit=self._watch_mem_bytes_limit) - self._tensor_data[watch_key].add(tensor_value) - - def query(self, - watch_key, - time_indices=None, - slicing=None, - mapping=None): - """Query tensor store for a given watch_key. - - Args: - watch_key: The watch key to query. - time_indices: A numpy-style slicing string for time indices. E.g., - `-1`, `:-2`, `[::2]`. If not provided (`None`), will use -1. - slicing: A numpy-style slicing string for individual time steps. - mapping: An mapping string or a list of them. Supported mappings: - `{None, 'image/png', 'health-pill'}`. - - Returns: - The potentially sliced values as a nested list of values or its mapped - format. A `list` of nested `list` of values. - - Raises: - ValueError: If the shape of the sliced array is incompatible with mapping - mode. Or if the mapping type is invalid. 
- """ - if watch_key not in self._tensor_data: - raise KeyError("watch_key not found: %s" % watch_key) - - if time_indices is None: - time_indices = '-1' - time_slicing = tensor_helper.parse_time_indices(time_indices) - all_time_indices = list(range(self._tensor_data[watch_key].num_total())) - sliced_time_indices = all_time_indices[time_slicing] - if not isinstance(sliced_time_indices, list): - sliced_time_indices = [sliced_time_indices] - - recombine_and_map = False - step_mapping = mapping - if len(sliced_time_indices) > 1 and mapping not in (None, ): - recombine_and_map = True - step_mapping = None - - output = [] - for index in sliced_time_indices: - value = self._tensor_data[watch_key].query(index)[0] - if (value is not None and - not isinstance(value, debug_data.InconvertibleTensorProto)): - output.append(tensor_helper.array_view( - value, slicing=slicing, mapping=step_mapping)[2]) - else: - output.append(None) - - if recombine_and_map: - if mapping == 'image/png': - output = tensor_helper.array_to_base64_png(output) - elif mapping and mapping != 'none': - logger.warn( - 'Unsupported mapping mode after recomining time steps: %s', - mapping) - return output - - def dispose(self): - for watch_key in self._tensor_data: - self._tensor_data[watch_key].dispose() +class TensorStore(object): + def __init__(self, watch_mem_bytes_limit=10e6): + """Constructor of TensorStore. + + Args: + watch_mem_bytes_limit: Limit on number of bytes to store in memory for + each watch key. + """ + self._watch_mem_bytes_limit = watch_mem_bytes_limit + self._tensor_data = ( + dict() + ) # A map from watch key to _WatchStore instances. + + def add(self, watch_key, tensor_value): + """Add a tensor value. + + Args: + watch_key: A string representing the debugger tensor watch, e.g., + 'Dense_1/BiasAdd:0:DebugIdentity'. + tensor_value: The value of the tensor as a numpy.ndarray. 
+ """ + if watch_key not in self._tensor_data: + self._tensor_data[watch_key] = _WatchStore( + watch_key, mem_bytes_limit=self._watch_mem_bytes_limit + ) + self._tensor_data[watch_key].add(tensor_value) + + def query(self, watch_key, time_indices=None, slicing=None, mapping=None): + """Query tensor store for a given watch_key. + + Args: + watch_key: The watch key to query. + time_indices: A numpy-style slicing string for time indices. E.g., + `-1`, `:-2`, `[::2]`. If not provided (`None`), will use -1. + slicing: A numpy-style slicing string for individual time steps. + mapping: An mapping string or a list of them. Supported mappings: + `{None, 'image/png', 'health-pill'}`. + + Returns: + The potentially sliced values as a nested list of values or its mapped + format. A `list` of nested `list` of values. + + Raises: + ValueError: If the shape of the sliced array is incompatible with mapping + mode. Or if the mapping type is invalid. + """ + if watch_key not in self._tensor_data: + raise KeyError("watch_key not found: %s" % watch_key) + + if time_indices is None: + time_indices = "-1" + time_slicing = tensor_helper.parse_time_indices(time_indices) + all_time_indices = list(range(self._tensor_data[watch_key].num_total())) + sliced_time_indices = all_time_indices[time_slicing] + if not isinstance(sliced_time_indices, list): + sliced_time_indices = [sliced_time_indices] + + recombine_and_map = False + step_mapping = mapping + if len(sliced_time_indices) > 1 and mapping not in (None,): + recombine_and_map = True + step_mapping = None + + output = [] + for index in sliced_time_indices: + value = self._tensor_data[watch_key].query(index)[0] + if value is not None and not isinstance( + value, debug_data.InconvertibleTensorProto + ): + output.append( + tensor_helper.array_view( + value, slicing=slicing, mapping=step_mapping + )[2] + ) + else: + output.append(None) + + if recombine_and_map: + if mapping == "image/png": + output = tensor_helper.array_to_base64_png(output) + 
elif mapping and mapping != "none": + logger.warn( + "Unsupported mapping mode after recomining time steps: %s", + mapping, + ) + return output + + def dispose(self): + for watch_key in self._tensor_data: + self._tensor_data[watch_key].dispose() diff --git a/tensorboard/plugins/debugger/tensor_store_test.py b/tensorboard/plugins/debugger/tensor_store_test.py index cb2e6eebf2..5a8c552f95 100644 --- a/tensorboard/plugins/debugger/tensor_store_test.py +++ b/tensorboard/plugins/debugger/tensor_store_test.py @@ -29,223 +29,229 @@ class WatchStoreTest(tf.test.TestCase): - - def testAlwaysKeepsOneValueInMemory(self): - watch_key = 'Dense/BiasAdd:0:DebugIdentity' - watch_store = tensor_store._WatchStore(watch_key, mem_bytes_limit=50) - - value = np.eye(3, dtype=np.float64) - self.assertEqual(72, value.nbytes) - self.assertEqual(0, watch_store.num_total()) - self.assertEqual(0, watch_store.num_in_memory()) - self.assertEqual(0, watch_store.num_discarded()) - - watch_store.add(value) - self.assertEqual(1, watch_store.num_total()) - self.assertEqual(1, watch_store.num_in_memory()) - self.assertEqual(0, watch_store.num_discarded()) - self.assertAllEqual([value], watch_store.query(0)) - self.assertAllEqual([value], watch_store.query([0])) - with self.assertRaises(IndexError): - watch_store.query([1]) - - watch_store.add(value * 2) - self.assertEqual(2, watch_store.num_total()) - self.assertEqual(1, watch_store.num_in_memory()) - self.assertEqual(1, watch_store.num_discarded()) - self.assertEqual([None], watch_store.query([0])) - self.assertIsNone(watch_store.query([0, 1])[0]) - self.assertAllEqual(value * 2, watch_store.query([0, 1])[-1]) - with self.assertRaises(IndexError): - watch_store.query(2) - - def testDiscarding(self): - watch_key = 'Dense/BiasAdd:0:DebugIdentity' - watch_store = tensor_store._WatchStore(watch_key, mem_bytes_limit=150) - - value = np.eye(3, dtype=np.float64) - self.assertEqual(72, value.nbytes) - - watch_store.add(value) - watch_store.add(value * 2) - 
self.assertEqual(2, watch_store.num_total()) - self.assertEqual(2, watch_store.num_in_memory()) - self.assertEqual(0, watch_store.num_discarded()) - self.assertAllEqual([value], watch_store.query([0])) - self.assertAllEqual([value, value * 2], watch_store.query([0, 1])) - with self.assertRaises(IndexError): - watch_store.query(2) - - watch_store.add(value * 3) - self.assertEqual(3, watch_store.num_total()) - self.assertEqual(2, watch_store.num_in_memory()) - self.assertEqual(1, watch_store.num_discarded()) - self.assertEqual([None], watch_store.query([0])) - result = watch_store.query([0, 1]) - self.assertIsNone(result[0]) - self.assertAllEqual(value * 2, result[1]) - result = watch_store.query([0, 1, 2]) - self.assertIsNone(result[0]) - self.assertAllEqual([value * 2, value * 3], result[1:]) - with self.assertRaises(IndexError): - watch_store.query(3) - - watch_store.add(value * 4) - self.assertEqual(4, watch_store.num_total()) - self.assertEqual(2, watch_store.num_in_memory()) - self.assertEqual(2, watch_store.num_discarded()) - self.assertEqual([None], watch_store.query([0])) - result = watch_store.query([0, 1]) - self.assertIsNone(result[0]) - self.assertIsNone(result[1]) - result = watch_store.query([0, 1, 2]) - self.assertIsNone(result[0]) - self.assertIsNone(result[1]) - self.assertAllEqual(value * 3, result[2]) - result = watch_store.query([0, 1, 2, 3]) - self.assertIsNone(result[0]) - self.assertIsNone(result[1]) - self.assertAllEqual(value * 3, result[2]) - self.assertAllEqual(value * 4, result[3]) - with self.assertRaises(IndexError): - watch_store.query(4) - - def testAddAndQueryUnitializedTensor(self): - watch_key = 'Dense/Bias:0:DebugIdentity' - watch_store = tensor_store._WatchStore(watch_key, mem_bytes_limit=50) - uninitialized_value = debug_data.InconvertibleTensorProto( - None, initialized=False) - watch_store.add(uninitialized_value) - initialized_value = np.zeros([3], dtype=np.float64) - watch_store.add(initialized_value) - result = 
watch_store.query([0, 1]) - self.assertEqual(2, len(result)) - self.assertIsInstance(result[0], debug_data.InconvertibleTensorProto) - self.assertAllClose(initialized_value, result[1]) + def testAlwaysKeepsOneValueInMemory(self): + watch_key = "Dense/BiasAdd:0:DebugIdentity" + watch_store = tensor_store._WatchStore(watch_key, mem_bytes_limit=50) + + value = np.eye(3, dtype=np.float64) + self.assertEqual(72, value.nbytes) + self.assertEqual(0, watch_store.num_total()) + self.assertEqual(0, watch_store.num_in_memory()) + self.assertEqual(0, watch_store.num_discarded()) + + watch_store.add(value) + self.assertEqual(1, watch_store.num_total()) + self.assertEqual(1, watch_store.num_in_memory()) + self.assertEqual(0, watch_store.num_discarded()) + self.assertAllEqual([value], watch_store.query(0)) + self.assertAllEqual([value], watch_store.query([0])) + with self.assertRaises(IndexError): + watch_store.query([1]) + + watch_store.add(value * 2) + self.assertEqual(2, watch_store.num_total()) + self.assertEqual(1, watch_store.num_in_memory()) + self.assertEqual(1, watch_store.num_discarded()) + self.assertEqual([None], watch_store.query([0])) + self.assertIsNone(watch_store.query([0, 1])[0]) + self.assertAllEqual(value * 2, watch_store.query([0, 1])[-1]) + with self.assertRaises(IndexError): + watch_store.query(2) + + def testDiscarding(self): + watch_key = "Dense/BiasAdd:0:DebugIdentity" + watch_store = tensor_store._WatchStore(watch_key, mem_bytes_limit=150) + + value = np.eye(3, dtype=np.float64) + self.assertEqual(72, value.nbytes) + + watch_store.add(value) + watch_store.add(value * 2) + self.assertEqual(2, watch_store.num_total()) + self.assertEqual(2, watch_store.num_in_memory()) + self.assertEqual(0, watch_store.num_discarded()) + self.assertAllEqual([value], watch_store.query([0])) + self.assertAllEqual([value, value * 2], watch_store.query([0, 1])) + with self.assertRaises(IndexError): + watch_store.query(2) + + watch_store.add(value * 3) + self.assertEqual(3, 
watch_store.num_total()) + self.assertEqual(2, watch_store.num_in_memory()) + self.assertEqual(1, watch_store.num_discarded()) + self.assertEqual([None], watch_store.query([0])) + result = watch_store.query([0, 1]) + self.assertIsNone(result[0]) + self.assertAllEqual(value * 2, result[1]) + result = watch_store.query([0, 1, 2]) + self.assertIsNone(result[0]) + self.assertAllEqual([value * 2, value * 3], result[1:]) + with self.assertRaises(IndexError): + watch_store.query(3) + + watch_store.add(value * 4) + self.assertEqual(4, watch_store.num_total()) + self.assertEqual(2, watch_store.num_in_memory()) + self.assertEqual(2, watch_store.num_discarded()) + self.assertEqual([None], watch_store.query([0])) + result = watch_store.query([0, 1]) + self.assertIsNone(result[0]) + self.assertIsNone(result[1]) + result = watch_store.query([0, 1, 2]) + self.assertIsNone(result[0]) + self.assertIsNone(result[1]) + self.assertAllEqual(value * 3, result[2]) + result = watch_store.query([0, 1, 2, 3]) + self.assertIsNone(result[0]) + self.assertIsNone(result[1]) + self.assertAllEqual(value * 3, result[2]) + self.assertAllEqual(value * 4, result[3]) + with self.assertRaises(IndexError): + watch_store.query(4) + + def testAddAndQueryUnitializedTensor(self): + watch_key = "Dense/Bias:0:DebugIdentity" + watch_store = tensor_store._WatchStore(watch_key, mem_bytes_limit=50) + uninitialized_value = debug_data.InconvertibleTensorProto( + None, initialized=False + ) + watch_store.add(uninitialized_value) + initialized_value = np.zeros([3], dtype=np.float64) + watch_store.add(initialized_value) + result = watch_store.query([0, 1]) + self.assertEqual(2, len(result)) + self.assertIsInstance(result[0], debug_data.InconvertibleTensorProto) + self.assertAllClose(initialized_value, result[1]) class TensorHelperTest(tf.test.TestCase): - - def testAddAndQuerySingleTensor(self): - store = tensor_store.TensorStore() - watch_key = "A:0:DebugIdentity" - data = np.array([[1, 2], [3, 4]]) - 
store.add(watch_key, data) - self.assertAllClose([data], store.query(watch_key)) - - def testAddAndQuerySingleTensorWithSlicing(self): - store = tensor_store.TensorStore() - watch_key = "A:0:DebugIdentity" - data = np.array([[1, 2], [3, 4]]) - store.add(watch_key, data) - self.assertAllClose([[2, 4]], store.query(watch_key, slicing="[:, 1]")) - - def testAddAndQueryMultipleTensorForSameWatchKey(self): - store = tensor_store.TensorStore() - watch_key = "A:0:DebugIdentity" - data1 = np.array([[1, 2], [3, 4]]) - data2 = np.array([[-1, -2], [-3, -4]]) - store.add(watch_key, data1) - store.add(watch_key, data2) - - self.assertAllClose([data2], store.query(watch_key)) - self.assertAllClose([data1], store.query(watch_key, time_indices='0')) - self.assertAllClose([data2], store.query(watch_key, time_indices='1')) - self.assertAllClose([data2], store.query(watch_key, time_indices='-1')) - - def testAddAndQueryMultipleTensorForSameWatchKeyWithSlicing(self): - store = tensor_store.TensorStore() - watch_key = "A:0:DebugIdentity" - data1 = np.array([[1, 2], [3, 4]]) - data2 = np.array([[-1, -2], [-3, -4]]) - store.add(watch_key, data1) - store.add(watch_key, data2) - - self.assertAllClose( - [[2, 4], [-2, -4]], - store.query(watch_key, time_indices='0:2', slicing="[:,1]")) - - def testQueryMultipleTensorsAtOnce(self): - store = tensor_store.TensorStore() - watch_key = "A:0:DebugIdentity" - data1 = np.array([[1, 2], [3, 4]]) - data2 = np.array([[-1, -2], [-3, -4]]) - store.add(watch_key, data1) - store.add(watch_key, data2) - - self.assertAllClose( - [[[1, 2], [3, 4]], [[-1, -2], [-3, -4]]], - store.query(watch_key, time_indices='[0:2]')) - - def testQueryNonexistentWatchKey(self): - store = tensor_store.TensorStore() - watch_key = "A:0:DebugIdentity" - data = np.array([[1, 2], [3, 4]]) - store.add(watch_key, data) - with self.assertRaises(KeyError): - store.query("B:0:DebugIdentity") - - def testQueryInvalidTimeIndex(self): - store = tensor_store.TensorStore() - watch_key = 
"A:0:DebugIdentity" - data = np.array([[1, 2], [3, 4]]) - store.add(watch_key, data) - with self.assertRaises(IndexError): - store.query("A:0:DebugIdentity", time_indices='10') - - def testQeuryWithTimeIndicesStop(self): - store = tensor_store.TensorStore() - watch_key = "A:0:DebugIdentity" - store.add(watch_key, np.array(1)) - store.add(watch_key, np.array(3)) - store.add(watch_key, np.array(3)) - store.add(watch_key, np.array(7)) - self.assertAllClose([1, 3, 3], store.query(watch_key, time_indices=':3:')) - - def testQeuryWithTimeIndicesStopAndStep(self): - store = tensor_store.TensorStore() - watch_key = "A:0:DebugIdentity" - store.add(watch_key, np.array(1)) - store.add(watch_key, np.array(3)) - store.add(watch_key, np.array(3)) - store.add(watch_key, np.array(7)) - self.assertAllClose([3, 7], store.query(watch_key, time_indices='1::2')) - - def testQeuryWithTimeIndicesAllRange(self): - store = tensor_store.TensorStore() - watch_key = "A:0:DebugIdentity" - store.add(watch_key, np.array(1)) - store.add(watch_key, np.array(3)) - store.add(watch_key, np.array(3)) - store.add(watch_key, np.array(7)) - self.assertAllClose([1, 3, 3, 7], store.query(watch_key, time_indices=':')) - - def testQuery1DTensorHistoryWithImagePngMapping(self): - store = tensor_store.TensorStore() - watch_key = "A:0:DebugIdentity" - store.add(watch_key, np.array([0, 2, 4, 6, 8])) - store.add(watch_key, np.array([1, 3, 5, 7, 9])) - output = store.query(watch_key, time_indices=':', mapping='image/png') - decoded = im_util.decode_png(base64.b64decode(output)) - self.assertEqual((2, 5, 3), decoded.shape) - - def testTensorValuesExceedingMemBytesLimitAreDiscarded(self): - store = tensor_store.TensorStore(watch_mem_bytes_limit=150) - watch_key = "A:0:DebugIdentity" - value = np.eye(3, dtype=np.float64) - self.assertEqual(72, value.nbytes) - store.add(watch_key, value) - self.assertAllEqual([value], store.query(watch_key, time_indices=':')) - - store.add(watch_key, value * 2) - 
self.assertAllEqual([value, value * 2], - store.query(watch_key, time_indices=':')) - - store.add(watch_key, value * 3) - result = store.query(watch_key, time_indices=':') - self.assertIsNone(result[0]) - self.assertAllEqual([value * 2, value * 3], result[1:]) - - -if __name__ == '__main__': - tf.test.main() + def testAddAndQuerySingleTensor(self): + store = tensor_store.TensorStore() + watch_key = "A:0:DebugIdentity" + data = np.array([[1, 2], [3, 4]]) + store.add(watch_key, data) + self.assertAllClose([data], store.query(watch_key)) + + def testAddAndQuerySingleTensorWithSlicing(self): + store = tensor_store.TensorStore() + watch_key = "A:0:DebugIdentity" + data = np.array([[1, 2], [3, 4]]) + store.add(watch_key, data) + self.assertAllClose([[2, 4]], store.query(watch_key, slicing="[:, 1]")) + + def testAddAndQueryMultipleTensorForSameWatchKey(self): + store = tensor_store.TensorStore() + watch_key = "A:0:DebugIdentity" + data1 = np.array([[1, 2], [3, 4]]) + data2 = np.array([[-1, -2], [-3, -4]]) + store.add(watch_key, data1) + store.add(watch_key, data2) + + self.assertAllClose([data2], store.query(watch_key)) + self.assertAllClose([data1], store.query(watch_key, time_indices="0")) + self.assertAllClose([data2], store.query(watch_key, time_indices="1")) + self.assertAllClose([data2], store.query(watch_key, time_indices="-1")) + + def testAddAndQueryMultipleTensorForSameWatchKeyWithSlicing(self): + store = tensor_store.TensorStore() + watch_key = "A:0:DebugIdentity" + data1 = np.array([[1, 2], [3, 4]]) + data2 = np.array([[-1, -2], [-3, -4]]) + store.add(watch_key, data1) + store.add(watch_key, data2) + + self.assertAllClose( + [[2, 4], [-2, -4]], + store.query(watch_key, time_indices="0:2", slicing="[:,1]"), + ) + + def testQueryMultipleTensorsAtOnce(self): + store = tensor_store.TensorStore() + watch_key = "A:0:DebugIdentity" + data1 = np.array([[1, 2], [3, 4]]) + data2 = np.array([[-1, -2], [-3, -4]]) + store.add(watch_key, data1) + store.add(watch_key, data2) 
+ + self.assertAllClose( + [[[1, 2], [3, 4]], [[-1, -2], [-3, -4]]], + store.query(watch_key, time_indices="[0:2]"), + ) + + def testQueryNonexistentWatchKey(self): + store = tensor_store.TensorStore() + watch_key = "A:0:DebugIdentity" + data = np.array([[1, 2], [3, 4]]) + store.add(watch_key, data) + with self.assertRaises(KeyError): + store.query("B:0:DebugIdentity") + + def testQueryInvalidTimeIndex(self): + store = tensor_store.TensorStore() + watch_key = "A:0:DebugIdentity" + data = np.array([[1, 2], [3, 4]]) + store.add(watch_key, data) + with self.assertRaises(IndexError): + store.query("A:0:DebugIdentity", time_indices="10") + + def testQeuryWithTimeIndicesStop(self): + store = tensor_store.TensorStore() + watch_key = "A:0:DebugIdentity" + store.add(watch_key, np.array(1)) + store.add(watch_key, np.array(3)) + store.add(watch_key, np.array(3)) + store.add(watch_key, np.array(7)) + self.assertAllClose( + [1, 3, 3], store.query(watch_key, time_indices=":3:") + ) + + def testQeuryWithTimeIndicesStopAndStep(self): + store = tensor_store.TensorStore() + watch_key = "A:0:DebugIdentity" + store.add(watch_key, np.array(1)) + store.add(watch_key, np.array(3)) + store.add(watch_key, np.array(3)) + store.add(watch_key, np.array(7)) + self.assertAllClose([3, 7], store.query(watch_key, time_indices="1::2")) + + def testQeuryWithTimeIndicesAllRange(self): + store = tensor_store.TensorStore() + watch_key = "A:0:DebugIdentity" + store.add(watch_key, np.array(1)) + store.add(watch_key, np.array(3)) + store.add(watch_key, np.array(3)) + store.add(watch_key, np.array(7)) + self.assertAllClose( + [1, 3, 3, 7], store.query(watch_key, time_indices=":") + ) + + def testQuery1DTensorHistoryWithImagePngMapping(self): + store = tensor_store.TensorStore() + watch_key = "A:0:DebugIdentity" + store.add(watch_key, np.array([0, 2, 4, 6, 8])) + store.add(watch_key, np.array([1, 3, 5, 7, 9])) + output = store.query(watch_key, time_indices=":", mapping="image/png") + decoded = 
im_util.decode_png(base64.b64decode(output)) + self.assertEqual((2, 5, 3), decoded.shape) + + def testTensorValuesExceedingMemBytesLimitAreDiscarded(self): + store = tensor_store.TensorStore(watch_mem_bytes_limit=150) + watch_key = "A:0:DebugIdentity" + value = np.eye(3, dtype=np.float64) + self.assertEqual(72, value.nbytes) + store.add(watch_key, value) + self.assertAllEqual([value], store.query(watch_key, time_indices=":")) + + store.add(watch_key, value * 2) + self.assertAllEqual( + [value, value * 2], store.query(watch_key, time_indices=":") + ) + + store.add(watch_key, value * 3) + result = store.query(watch_key, time_indices=":") + self.assertIsNone(result[0]) + self.assertAllEqual([value * 2, value * 3], result[1:]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/debugger_v2/debugger_v2_plugin.py b/tensorboard/plugins/debugger_v2/debugger_v2_plugin.py index 75d37a5344..cf651c9701 100644 --- a/tensorboard/plugins/debugger_v2/debugger_v2_plugin.py +++ b/tensorboard/plugins/debugger_v2/debugger_v2_plugin.py @@ -22,36 +22,36 @@ class DebuggerV2Plugin(base_plugin.TBPlugin): - """Debugger V2 Plugin for TensorBoard.""" - - plugin_name = "debugger-v2" - - def __init__(self, context): - """Instantiates Debugger V2 Plugin via TensorBoard core. - Args: - context: A base_plugin.TBContext instance. - """ - super(DebuggerV2Plugin, self).__init__(context) - - def get_plugin_apps(self): - # TODO(cais): Add routes as they are implemented. - return {} - - def is_active(self): - """Check whether the Debugger V2 Plugin is always active. - - When no data in the tfdbg v2 format is available, a custom information - screen is displayed to instruct the user on how to generate such data - to be able to use the plugin. - - Returns: - `True` if and only if data in tfdbg v2's DebugEvent format is available. - """ - # TODO(cais): Implement logic. 
- return False - - def frontend_metadata(self): - return base_plugin.FrontendMetadata( - is_ng_component=True, - tab_name='Debugger V2', - disable_reload=True) + """Debugger V2 Plugin for TensorBoard.""" + + plugin_name = "debugger-v2" + + def __init__(self, context): + """Instantiates Debugger V2 Plugin via TensorBoard core. + + Args: + context: A base_plugin.TBContext instance. + """ + super(DebuggerV2Plugin, self).__init__(context) + + def get_plugin_apps(self): + # TODO(cais): Add routes as they are implemented. + return {} + + def is_active(self): + """Check whether the Debugger V2 Plugin is always active. + + When no data in the tfdbg v2 format is available, a custom information + screen is displayed to instruct the user on how to generate such data + to be able to use the plugin. + + Returns: + `True` if and only if data in tfdbg v2's DebugEvent format is available. + """ + # TODO(cais): Implement logic. + return False + + def frontend_metadata(self): + return base_plugin.FrontendMetadata( + is_ng_component=True, tab_name="Debugger V2", disable_reload=True + ) diff --git a/tensorboard/plugins/debugger_v2/debugger_v2_plugin_test.py b/tensorboard/plugins/debugger_v2/debugger_v2_plugin_test.py index beaaa7202e..7686881401 100644 --- a/tensorboard/plugins/debugger_v2/debugger_v2_plugin_test.py +++ b/tensorboard/plugins/debugger_v2/debugger_v2_plugin_test.py @@ -27,19 +27,18 @@ class DebuggerV2PluginTest(tf.test.TestCase): + def testInstantiatePlugin(self): + dummy_logdir = tempfile.mkdtemp() + context = base_plugin.TBContext(logdir=dummy_logdir) + plugin = debugger_v2_plugin.DebuggerV2Plugin(context) + self.assertTrue(plugin) - def testInstantiatePlugin(self): - dummy_logdir = tempfile.mkdtemp() - context = base_plugin.TBContext(logdir=dummy_logdir) - plugin = debugger_v2_plugin.DebuggerV2Plugin(context) - self.assertTrue(plugin) - - def testPluginIsNotActiveByDefault(self): - dummy_logdir = tempfile.mkdtemp() - context = 
base_plugin.TBContext(logdir=dummy_logdir) - plugin = debugger_v2_plugin.DebuggerV2Plugin(context) - self.assertFalse(plugin.is_active()) + def testPluginIsNotActiveByDefault(self): + dummy_logdir = tempfile.mkdtemp() + context = base_plugin.TBContext(logdir=dummy_logdir) + plugin = debugger_v2_plugin.DebuggerV2Plugin(context) + self.assertFalse(plugin.is_active()) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/distribution/compressor.py b/tensorboard/plugins/distribution/compressor.py index 7288a7027c..131296a932 100644 --- a/tensorboard/plugins/distribution/compressor.py +++ b/tensorboard/plugins/distribution/compressor.py @@ -28,116 +28,117 @@ NORMAL_HISTOGRAM_BPS = (0, 668, 1587, 3085, 5000, 6915, 8413, 9332, 10000) -CompressedHistogramValue = collections.namedtuple('CompressedHistogramValue', - ['basis_point', 'value']) +CompressedHistogramValue = collections.namedtuple( + "CompressedHistogramValue", ["basis_point", "value"] +) # TODO(@jart): Unfork these methods. def compress_histogram_proto(histo, bps=NORMAL_HISTOGRAM_BPS): - """Creates fixed size histogram by adding compression to accumulated state. - - This routine transforms a histogram at a particular step by interpolating its - variable number of buckets to represent their cumulative weight at a constant - number of compression points. This significantly reduces the size of the - histogram and makes it suitable for a two-dimensional area plot where the - output of this routine constitutes the ranges for a single x coordinate. - - Args: - histo: A HistogramProto object. - bps: Compression points represented in basis points, 1/100ths of a percent. - Defaults to normal distribution. - - Returns: - List of values for each basis point. 
- """ - # See also: Histogram::Percentile() in core/lib/histogram/histogram.cc - if not histo.num: - return [CompressedHistogramValue(b, 0.0) for b in bps] - bucket = np.array(histo.bucket) - bucket_limit = list(histo.bucket_limit) - weights = (bucket * bps[-1] / (bucket.sum() or 1.0)).cumsum() - values = [] - j = 0 - while j < len(bps): - i = np.searchsorted(weights, bps[j], side='right') - while i < len(weights): - cumsum = weights[i] - cumsum_prev = weights[i - 1] if i > 0 else 0.0 - if cumsum == cumsum_prev: # prevent lerp divide by zero - i += 1 - continue - if not i or not cumsum_prev: - lhs = histo.min - else: - lhs = max(bucket_limit[i - 1], histo.min) - rhs = min(bucket_limit[i], histo.max) - weight = _lerp(bps[j], cumsum_prev, cumsum, lhs, rhs) - values.append(CompressedHistogramValue(bps[j], weight)) - j += 1 - break - else: - break - while j < len(bps): - values.append(CompressedHistogramValue(bps[j], histo.max)) - j += 1 - return values + """Creates fixed size histogram by adding compression to accumulated state. + + This routine transforms a histogram at a particular step by interpolating its + variable number of buckets to represent their cumulative weight at a constant + number of compression points. This significantly reduces the size of the + histogram and makes it suitable for a two-dimensional area plot where the + output of this routine constitutes the ranges for a single x coordinate. + + Args: + histo: A HistogramProto object. + bps: Compression points represented in basis points, 1/100ths of a percent. + Defaults to normal distribution. + + Returns: + List of values for each basis point. 
+ """ + # See also: Histogram::Percentile() in core/lib/histogram/histogram.cc + if not histo.num: + return [CompressedHistogramValue(b, 0.0) for b in bps] + bucket = np.array(histo.bucket) + bucket_limit = list(histo.bucket_limit) + weights = (bucket * bps[-1] / (bucket.sum() or 1.0)).cumsum() + values = [] + j = 0 + while j < len(bps): + i = np.searchsorted(weights, bps[j], side="right") + while i < len(weights): + cumsum = weights[i] + cumsum_prev = weights[i - 1] if i > 0 else 0.0 + if cumsum == cumsum_prev: # prevent lerp divide by zero + i += 1 + continue + if not i or not cumsum_prev: + lhs = histo.min + else: + lhs = max(bucket_limit[i - 1], histo.min) + rhs = min(bucket_limit[i], histo.max) + weight = _lerp(bps[j], cumsum_prev, cumsum, lhs, rhs) + values.append(CompressedHistogramValue(bps[j], weight)) + j += 1 + break + else: + break + while j < len(bps): + values.append(CompressedHistogramValue(bps[j], histo.max)) + j += 1 + return values def compress_histogram(buckets, bps=NORMAL_HISTOGRAM_BPS): - """Creates fixed size histogram by adding compression to accumulated state. - - This routine transforms a histogram at a particular step by linearly - interpolating its variable number of buckets to represent their cumulative - weight at a constant number of compression points. This significantly reduces - the size of the histogram and makes it suitable for a two-dimensional area - plot where the output of this routine constitutes the ranges for a single x - coordinate. - - Args: - buckets: A list of buckets, each of which is a 3-tuple of the form - `(min, max, count)`. - bps: Compression points represented in basis points, 1/100ths of a percent. - Defaults to normal distribution. - - Returns: - List of values for each basis point. 
- """ - # See also: Histogram::Percentile() in core/lib/histogram/histogram.cc - buckets = np.array(buckets) - if not buckets.size: - return [CompressedHistogramValue(b, 0.0) for b in bps] - (minmin, maxmax) = (buckets[0][0], buckets[-1][1]) - counts = buckets[:, 2] - right_edges = list(buckets[:, 1]) - weights = (counts * bps[-1] / (counts.sum() or 1.0)).cumsum() - - result = [] - bp_index = 0 - while bp_index < len(bps): - i = np.searchsorted(weights, bps[bp_index], side='right') - while i < len(weights): - cumsum = weights[i] - cumsum_prev = weights[i - 1] if i > 0 else 0.0 - if cumsum == cumsum_prev: # prevent division-by-zero in `_lerp` - i += 1 - continue - if not i or not cumsum_prev: - lhs = minmin - else: - lhs = max(right_edges[i - 1], minmin) - rhs = min(right_edges[i], maxmax) - weight = _lerp(bps[bp_index], cumsum_prev, cumsum, lhs, rhs) - result.append(CompressedHistogramValue(bps[bp_index], weight)) - bp_index += 1 - break - else: - break - while bp_index < len(bps): - result.append(CompressedHistogramValue(bps[bp_index], maxmax)) - bp_index += 1 - return result + """Creates fixed size histogram by adding compression to accumulated state. + + This routine transforms a histogram at a particular step by linearly + interpolating its variable number of buckets to represent their cumulative + weight at a constant number of compression points. This significantly reduces + the size of the histogram and makes it suitable for a two-dimensional area + plot where the output of this routine constitutes the ranges for a single x + coordinate. + + Args: + buckets: A list of buckets, each of which is a 3-tuple of the form + `(min, max, count)`. + bps: Compression points represented in basis points, 1/100ths of a percent. + Defaults to normal distribution. + + Returns: + List of values for each basis point. 
+ """ + # See also: Histogram::Percentile() in core/lib/histogram/histogram.cc + buckets = np.array(buckets) + if not buckets.size: + return [CompressedHistogramValue(b, 0.0) for b in bps] + (minmin, maxmax) = (buckets[0][0], buckets[-1][1]) + counts = buckets[:, 2] + right_edges = list(buckets[:, 1]) + weights = (counts * bps[-1] / (counts.sum() or 1.0)).cumsum() + + result = [] + bp_index = 0 + while bp_index < len(bps): + i = np.searchsorted(weights, bps[bp_index], side="right") + while i < len(weights): + cumsum = weights[i] + cumsum_prev = weights[i - 1] if i > 0 else 0.0 + if cumsum == cumsum_prev: # prevent division-by-zero in `_lerp` + i += 1 + continue + if not i or not cumsum_prev: + lhs = minmin + else: + lhs = max(right_edges[i - 1], minmin) + rhs = min(right_edges[i], maxmax) + weight = _lerp(bps[bp_index], cumsum_prev, cumsum, lhs, rhs) + result.append(CompressedHistogramValue(bps[bp_index], weight)) + bp_index += 1 + break + else: + break + while bp_index < len(bps): + result.append(CompressedHistogramValue(bps[bp_index], maxmax)) + bp_index += 1 + return result def _lerp(x, x0, x1, y0, y1): - """Affinely map from [x0, x1] onto [y0, y1].""" - return y0 + (x - x0) * float(y1 - y0) / (x1 - x0) + """Affinely map from [x0, x1] onto [y0, y1].""" + return y0 + (x - x0) * float(y1 - y0) / (x1 - x0) diff --git a/tensorboard/plugins/distribution/compressor_test.py b/tensorboard/plugins/distribution/compressor_test.py index 3a98147123..00ab2b3c28 100644 --- a/tensorboard/plugins/distribution/compressor_test.py +++ b/tensorboard/plugins/distribution/compressor_test.py @@ -23,74 +23,70 @@ def _make_expected_value(*values): - return [compressor.CompressedHistogramValue(bp, val) for bp, val in values] + return [compressor.CompressedHistogramValue(bp, val) for bp, val in values] class CompressorTest(tf.test.TestCase): + def test_example(self): + bps = (0, 2500, 5000, 7500, 10000) + buckets = [[0, 1, 0], [1, 2, 3], [2, 3, 0]] + self.assertEqual( + 
_make_expected_value( + (0, 0.0), (2500, 0.5), (5000, 1.0), (7500, 1.5), (10000, 3.0), + ), + compressor.compress_histogram(buckets, bps), + ) - def test_example(self): - bps = (0, 2500, 5000, 7500, 10000) - buckets = [[0, 1, 0], [1, 2, 3], [2, 3, 0]] - self.assertEqual( - _make_expected_value( - (0, 0.0), - (2500, 0.5), - (5000, 1.0), - (7500, 1.5), - (10000, 3.0), - ), - compressor.compress_histogram(buckets, bps)) + def test_another_example(self): + bps = (0, 2500, 5000, 7500, 10000) + buckets = [[1, 2, 1], [2, 3, 3], [3, 4, 0]] + self.assertEqual( + _make_expected_value( + (0, 1.0), + (2500, 2.0), + (5000, 2.0 + 1 / 3), + (7500, 2.0 + 2 / 3), + (10000, 4.0), + ), + compressor.compress_histogram(buckets, bps), + ) - def test_another_example(self): - bps = (0, 2500, 5000, 7500, 10000) - buckets = [[1, 2, 1], [2, 3, 3], [3, 4, 0]] - self.assertEqual( - _make_expected_value( - (0, 1.0), - (2500, 2.0), - (5000, 2.0 + 1/3), - (7500, 2.0 + 2/3), - (10000, 4.0) - ), - compressor.compress_histogram(buckets, bps)) + def test_empty(self): + bps = (0, 2500, 5000, 7500, 10000) + buckets = [[0, 1, 0], [1, 2, 0], [2, 3, 0]] + self.assertEqual( + _make_expected_value( + (0, 3.0), (2500, 3.0), (5000, 3.0), (7500, 3.0), (10000, 3.0), + ), + compressor.compress_histogram(buckets, bps), + ) - def test_empty(self): - bps = (0, 2500, 5000, 7500, 10000) - buckets = [[0, 1, 0], [1, 2, 0], [2, 3, 0]] - self.assertEqual( - _make_expected_value( - (0, 3.0), - (2500, 3.0), - (5000, 3.0), - (7500, 3.0), - (10000, 3.0), - ), - compressor.compress_histogram(buckets, bps)) + def test_ugly(self): + bps = (0, 668, 1587, 3085, 5000, 6915, 8413, 9332, 10000) + bucket_limits = [ + -1.0, + 0.0, + 0.917246389039776, + 1.0089710279437536, + 1.7976931348623157e308, + ] + bucket_counts = [0.0, 896.0, 0.0, 64.0] + assert len(bucket_counts) == len(bucket_limits) - 1 + buckets = list( + zip(*[bucket_limits[:-1], bucket_limits[1:], bucket_counts]) + ) + vals = compressor.compress_histogram(buckets, bps) + 
self.assertEqual(tuple(v.basis_point for v in vals), bps) + self.assertAlmostEqual(vals[0].value, -1.0) + self.assertAlmostEqual(vals[1].value, -0.86277993701301037) + self.assertAlmostEqual(vals[2].value, -0.67399964077791519) + self.assertAlmostEqual(vals[3].value, -0.36628159533703131) + self.assertAlmostEqual(vals[4].value, 0.027096279842737214) + self.assertAlmostEqual(vals[5].value, 0.42047415502250551) + self.assertAlmostEqual(vals[6].value, 0.72819220046338917) + self.assertAlmostEqual(vals[7].value, 0.91697249669848446) + self.assertAlmostEqual(vals[8].value, 1.7976931348623157e308) - def test_ugly(self): - bps = (0, 668, 1587, 3085, 5000, 6915, 8413, 9332, 10000) - bucket_limits = [-1.0, - 0.0, - 0.917246389039776, - 1.0089710279437536, - 1.7976931348623157e+308] - bucket_counts = [0.0, 896.0, 0.0, 64.0] - assert len(bucket_counts) == len(bucket_limits) - 1 - buckets = list(zip(*[bucket_limits[:-1], - bucket_limits[1:], - bucket_counts])) - vals = compressor.compress_histogram(buckets, bps) - self.assertEqual(tuple(v.basis_point for v in vals), bps) - self.assertAlmostEqual(vals[0].value, -1.0) - self.assertAlmostEqual(vals[1].value, -0.86277993701301037) - self.assertAlmostEqual(vals[2].value, -0.67399964077791519) - self.assertAlmostEqual(vals[3].value, -0.36628159533703131) - self.assertAlmostEqual(vals[4].value, 0.027096279842737214) - self.assertAlmostEqual(vals[5].value, 0.42047415502250551) - self.assertAlmostEqual(vals[6].value, 0.72819220046338917) - self.assertAlmostEqual(vals[7].value, 0.91697249669848446) - self.assertAlmostEqual(vals[8].value, 1.7976931348623157e+308) - -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/distribution/distributions_plugin.py b/tensorboard/plugins/distribution/distributions_plugin.py index 0052de5029..464ef9b534 100644 --- a/tensorboard/plugins/distribution/distributions_plugin.py +++ 
b/tensorboard/plugins/distribution/distributions_plugin.py @@ -32,77 +32,83 @@ class DistributionsPlugin(base_plugin.TBPlugin): - """Distributions Plugin for TensorBoard. + """Distributions Plugin for TensorBoard. - This supports both old-style summaries (created with TensorFlow ops - that output directly to the `histo` field of the proto) and new-style - summaries (as created by the `tensorboard.plugins.histogram.summary` - module). - """ - - plugin_name = 'distributions' - - # Use a round number + 1 since sampling includes both start and end steps, - # so N+1 samples corresponds to dividing the step sequence into N intervals. - SAMPLE_SIZE = 501 - - def __init__(self, context): - """Instantiates DistributionsPlugin via TensorBoard core. - - Args: - context: A base_plugin.TBContext instance. + This supports both old-style summaries (created with TensorFlow ops + that output directly to the `histo` field of the proto) and new- + style summaries (as created by the + `tensorboard.plugins.histogram.summary` module). """ - self._histograms_plugin = histograms_plugin.HistogramsPlugin(context) - - def get_plugin_apps(self): - return { - '/distributions': self.distributions_route, - '/tags': self.tags_route, - } - - def is_active(self): - """This plugin is active iff any run has at least one histogram tag. - (The distributions plugin uses the same data source as the histogram - plugin.) - """ - return self._histograms_plugin.is_active() - - def frontend_metadata(self): - return base_plugin.FrontendMetadata( - element_name='tf-distribution-dashboard', - ) - - def distributions_impl(self, tag, run, experiment): - """Result of the form `(body, mime_type)`. - - Raises: - tensorboard.errors.PublicError: On invalid request. 
- """ - (histograms, mime_type) = self._histograms_plugin.histograms_impl( - tag, run, experiment=experiment, downsample_to=self.SAMPLE_SIZE) - return ([self._compress(histogram) for histogram in histograms], - mime_type) - - def _compress(self, histogram): - (wall_time, step, buckets) = histogram - converted_buckets = compressor.compress_histogram(buckets) - return [wall_time, step, converted_buckets] - - def index_impl(self, experiment): - return self._histograms_plugin.index_impl(experiment=experiment) - - @wrappers.Request.application - def tags_route(self, request): - experiment = plugin_util.experiment_id(request.environ) - index = self.index_impl(experiment=experiment) - return http_util.Respond(request, index, 'application/json') - - @wrappers.Request.application - def distributions_route(self, request): - """Given a tag and single run, return an array of compressed histograms.""" - experiment = plugin_util.experiment_id(request.environ) - tag = request.args.get('tag') - run = request.args.get('run') - (body, mime_type) = self.distributions_impl(tag, run, experiment=experiment) - return http_util.Respond(request, body, mime_type) + plugin_name = "distributions" + + # Use a round number + 1 since sampling includes both start and end steps, + # so N+1 samples corresponds to dividing the step sequence into N intervals. + SAMPLE_SIZE = 501 + + def __init__(self, context): + """Instantiates DistributionsPlugin via TensorBoard core. + + Args: + context: A base_plugin.TBContext instance. + """ + self._histograms_plugin = histograms_plugin.HistogramsPlugin(context) + + def get_plugin_apps(self): + return { + "/distributions": self.distributions_route, + "/tags": self.tags_route, + } + + def is_active(self): + """This plugin is active iff any run has at least one histogram tag. + + (The distributions plugin uses the same data source as the + histogram plugin.) 
+ """ + return self._histograms_plugin.is_active() + + def frontend_metadata(self): + return base_plugin.FrontendMetadata( + element_name="tf-distribution-dashboard", + ) + + def distributions_impl(self, tag, run, experiment): + """Result of the form `(body, mime_type)`. + + Raises: + tensorboard.errors.PublicError: On invalid request. + """ + (histograms, mime_type) = self._histograms_plugin.histograms_impl( + tag, run, experiment=experiment, downsample_to=self.SAMPLE_SIZE + ) + return ( + [self._compress(histogram) for histogram in histograms], + mime_type, + ) + + def _compress(self, histogram): + (wall_time, step, buckets) = histogram + converted_buckets = compressor.compress_histogram(buckets) + return [wall_time, step, converted_buckets] + + def index_impl(self, experiment): + return self._histograms_plugin.index_impl(experiment=experiment) + + @wrappers.Request.application + def tags_route(self, request): + experiment = plugin_util.experiment_id(request.environ) + index = self.index_impl(experiment=experiment) + return http_util.Respond(request, index, "application/json") + + @wrappers.Request.application + def distributions_route(self, request): + """Given a tag and single run, return an array of compressed + histograms.""" + experiment = plugin_util.experiment_id(request.environ) + tag = request.args.get("tag") + run = request.args.get("run") + (body, mime_type) = self.distributions_impl( + tag, run, experiment=experiment + ) + return http_util.Respond(request, body, mime_type) diff --git a/tensorboard/plugins/distribution/distributions_plugin_test.py b/tensorboard/plugins/distribution/distributions_plugin_test.py index 876ae6b796..9284ede035 100644 --- a/tensorboard/plugins/distribution/distributions_plugin_test.py +++ b/tensorboard/plugins/distribution/distributions_plugin_test.py @@ -26,8 +26,12 @@ import tensorflow as tf from tensorboard import errors -from tensorboard.backend.event_processing import plugin_event_accumulator as event_accumulator -from 
tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_accumulator as event_accumulator, +) +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.distribution import compressor from tensorboard.plugins.distribution import distributions_plugin @@ -39,136 +43,164 @@ class DistributionsPluginTest(tf.test.TestCase): - _STEPS = 99 - - _LEGACY_DISTRIBUTION_TAG = 'my-ancient-distribution' - _DISTRIBUTION_TAG = 'my-favorite-distribution' - _SCALAR_TAG = 'my-boring-scalars' - - _DISPLAY_NAME = 'Very important production statistics' - _DESCRIPTION = 'quod *erat* dispertiendum' - _HTML_DESCRIPTION = '

quod erat dispertiendum

' - - _RUN_WITH_LEGACY_DISTRIBUTION = '_RUN_WITH_LEGACY_DISTRIBUTION' - _RUN_WITH_DISTRIBUTION = '_RUN_WITH_DISTRIBUTION' - _RUN_WITH_SCALARS = '_RUN_WITH_SCALARS' - - def __init__(self, *args, **kwargs): - super(DistributionsPluginTest, self).__init__(*args, **kwargs) - self.logdir = None - self.plugin = None - - def set_up_with_runs(self, run_names): - self.logdir = self.get_temp_dir() - for run_name in run_names: - self.generate_run(run_name) - multiplexer = event_multiplexer.EventMultiplexer(size_guidance={ - # don't truncate my test data, please - event_accumulator.TENSORS: self._STEPS, - }) - multiplexer.AddRunsFromDirectory(self.logdir) - multiplexer.Reload() - context = base_plugin.TBContext(logdir=self.logdir, multiplexer=multiplexer) - self.plugin = distributions_plugin.DistributionsPlugin(context) - - def generate_run(self, run_name): - tf.compat.v1.reset_default_graph() - sess = tf.compat.v1.Session() - placeholder = tf.compat.v1.placeholder(tf.float32, shape=[3]) - - if run_name == self._RUN_WITH_LEGACY_DISTRIBUTION: - tf.compat.v1.summary.histogram(self._LEGACY_DISTRIBUTION_TAG, placeholder) - elif run_name == self._RUN_WITH_DISTRIBUTION: - summary.op(self._DISTRIBUTION_TAG, placeholder, - display_name=self._DISPLAY_NAME, - description=self._DESCRIPTION) - elif run_name == self._RUN_WITH_SCALARS: - tf.compat.v1.summary.scalar(self._SCALAR_TAG, tf.reduce_mean(input_tensor=placeholder)) - else: - assert False, 'Invalid run name: %r' % run_name - summ = tf.compat.v1.summary.merge_all() - - subdir = os.path.join(self.logdir, run_name) - with test_util.FileWriterCache.get(subdir) as writer: - writer.add_graph(sess.graph) - for step in xrange(self._STEPS): - feed_dict = {placeholder: [1 + step, 2 + step, 3 + step]} - s = sess.run(summ, feed_dict=feed_dict) - writer.add_summary(s, global_step=step) - - def test_routes_provided(self): - """Tests that the plugin offers the correct routes.""" - self.set_up_with_runs([self._RUN_WITH_SCALARS]) - routes = 
self.plugin.get_plugin_apps() - self.assertIsInstance(routes['/distributions'], collections.Callable) - self.assertIsInstance(routes['/tags'], collections.Callable) - - def test_index(self): - self.set_up_with_runs([self._RUN_WITH_SCALARS, - self._RUN_WITH_LEGACY_DISTRIBUTION, - self._RUN_WITH_DISTRIBUTION]) - self.assertEqual({ - # _RUN_WITH_SCALARS omitted: No distribution data. - self._RUN_WITH_LEGACY_DISTRIBUTION: { - self._LEGACY_DISTRIBUTION_TAG: { - 'displayName': self._LEGACY_DISTRIBUTION_TAG, - 'description': '', - }, - }, - self._RUN_WITH_DISTRIBUTION: { - '%s/histogram_summary' % self._DISTRIBUTION_TAG: { - 'displayName': self._DISPLAY_NAME, - 'description': self._HTML_DESCRIPTION, + _STEPS = 99 + + _LEGACY_DISTRIBUTION_TAG = "my-ancient-distribution" + _DISTRIBUTION_TAG = "my-favorite-distribution" + _SCALAR_TAG = "my-boring-scalars" + + _DISPLAY_NAME = "Very important production statistics" + _DESCRIPTION = "quod *erat* dispertiendum" + _HTML_DESCRIPTION = "

quod erat dispertiendum

" + + _RUN_WITH_LEGACY_DISTRIBUTION = "_RUN_WITH_LEGACY_DISTRIBUTION" + _RUN_WITH_DISTRIBUTION = "_RUN_WITH_DISTRIBUTION" + _RUN_WITH_SCALARS = "_RUN_WITH_SCALARS" + + def __init__(self, *args, **kwargs): + super(DistributionsPluginTest, self).__init__(*args, **kwargs) + self.logdir = None + self.plugin = None + + def set_up_with_runs(self, run_names): + self.logdir = self.get_temp_dir() + for run_name in run_names: + self.generate_run(run_name) + multiplexer = event_multiplexer.EventMultiplexer( + size_guidance={ + # don't truncate my test data, please + event_accumulator.TENSORS: self._STEPS, + } + ) + multiplexer.AddRunsFromDirectory(self.logdir) + multiplexer.Reload() + context = base_plugin.TBContext( + logdir=self.logdir, multiplexer=multiplexer + ) + self.plugin = distributions_plugin.DistributionsPlugin(context) + + def generate_run(self, run_name): + tf.compat.v1.reset_default_graph() + sess = tf.compat.v1.Session() + placeholder = tf.compat.v1.placeholder(tf.float32, shape=[3]) + + if run_name == self._RUN_WITH_LEGACY_DISTRIBUTION: + tf.compat.v1.summary.histogram( + self._LEGACY_DISTRIBUTION_TAG, placeholder + ) + elif run_name == self._RUN_WITH_DISTRIBUTION: + summary.op( + self._DISTRIBUTION_TAG, + placeholder, + display_name=self._DISPLAY_NAME, + description=self._DESCRIPTION, + ) + elif run_name == self._RUN_WITH_SCALARS: + tf.compat.v1.summary.scalar( + self._SCALAR_TAG, tf.reduce_mean(input_tensor=placeholder) + ) + else: + assert False, "Invalid run name: %r" % run_name + summ = tf.compat.v1.summary.merge_all() + + subdir = os.path.join(self.logdir, run_name) + with test_util.FileWriterCache.get(subdir) as writer: + writer.add_graph(sess.graph) + for step in xrange(self._STEPS): + feed_dict = {placeholder: [1 + step, 2 + step, 3 + step]} + s = sess.run(summ, feed_dict=feed_dict) + writer.add_summary(s, global_step=step) + + def test_routes_provided(self): + """Tests that the plugin offers the correct routes.""" + 
self.set_up_with_runs([self._RUN_WITH_SCALARS]) + routes = self.plugin.get_plugin_apps() + self.assertIsInstance(routes["/distributions"], collections.Callable) + self.assertIsInstance(routes["/tags"], collections.Callable) + + def test_index(self): + self.set_up_with_runs( + [ + self._RUN_WITH_SCALARS, + self._RUN_WITH_LEGACY_DISTRIBUTION, + self._RUN_WITH_DISTRIBUTION, + ] + ) + self.assertEqual( + { + # _RUN_WITH_SCALARS omitted: No distribution data. + self._RUN_WITH_LEGACY_DISTRIBUTION: { + self._LEGACY_DISTRIBUTION_TAG: { + "displayName": self._LEGACY_DISTRIBUTION_TAG, + "description": "", + }, + }, + self._RUN_WITH_DISTRIBUTION: { + "%s/histogram_summary" + % self._DISTRIBUTION_TAG: { + "displayName": self._DISPLAY_NAME, + "description": self._HTML_DESCRIPTION, + }, + }, }, - }, - }, self.plugin.index_impl(experiment='exp')) - - def _test_distributions(self, run_name, tag_name, should_work=True): - self.set_up_with_runs([self._RUN_WITH_SCALARS, - self._RUN_WITH_LEGACY_DISTRIBUTION, - self._RUN_WITH_DISTRIBUTION]) - if should_work: - (data, mime_type) = self.plugin.distributions_impl( - tag_name, run_name, experiment='exp' - ) - self.assertEqual('application/json', mime_type) - self.assertEqual(len(data), self._STEPS) - for i in xrange(self._STEPS): - [_unused_wall_time, step, bps_and_icdfs] = data[i] - self.assertEqual(i, step) - (bps, _unused_icdfs) = zip(*bps_and_icdfs) - self.assertEqual(bps, compressor.NORMAL_HISTOGRAM_BPS) - else: - with self.assertRaises(errors.NotFoundError): - self.plugin.distributions_impl( - self._DISTRIBUTION_TAG, run_name, experiment='exp' + self.plugin.index_impl(experiment="exp"), ) - def test_distributions_with_scalars(self): - self._test_distributions(self._RUN_WITH_SCALARS, self._DISTRIBUTION_TAG, - should_work=False) + def _test_distributions(self, run_name, tag_name, should_work=True): + self.set_up_with_runs( + [ + self._RUN_WITH_SCALARS, + self._RUN_WITH_LEGACY_DISTRIBUTION, + self._RUN_WITH_DISTRIBUTION, + ] + ) + if 
should_work: + (data, mime_type) = self.plugin.distributions_impl( + tag_name, run_name, experiment="exp" + ) + self.assertEqual("application/json", mime_type) + self.assertEqual(len(data), self._STEPS) + for i in xrange(self._STEPS): + [_unused_wall_time, step, bps_and_icdfs] = data[i] + self.assertEqual(i, step) + (bps, _unused_icdfs) = zip(*bps_and_icdfs) + self.assertEqual(bps, compressor.NORMAL_HISTOGRAM_BPS) + else: + with self.assertRaises(errors.NotFoundError): + self.plugin.distributions_impl( + self._DISTRIBUTION_TAG, run_name, experiment="exp" + ) + + def test_distributions_with_scalars(self): + self._test_distributions( + self._RUN_WITH_SCALARS, self._DISTRIBUTION_TAG, should_work=False + ) - def test_distributions_with_legacy_distribution(self): - self._test_distributions(self._RUN_WITH_LEGACY_DISTRIBUTION, - self._LEGACY_DISTRIBUTION_TAG) + def test_distributions_with_legacy_distribution(self): + self._test_distributions( + self._RUN_WITH_LEGACY_DISTRIBUTION, self._LEGACY_DISTRIBUTION_TAG + ) - def test_distributions_with_distribution(self): - self._test_distributions(self._RUN_WITH_DISTRIBUTION, - '%s/histogram_summary' % self._DISTRIBUTION_TAG) + def test_distributions_with_distribution(self): + self._test_distributions( + self._RUN_WITH_DISTRIBUTION, + "%s/histogram_summary" % self._DISTRIBUTION_TAG, + ) - def test_active_with_distribution(self): - self.set_up_with_runs([self._RUN_WITH_DISTRIBUTION]) - self.assertTrue(self.plugin.is_active()) + def test_active_with_distribution(self): + self.set_up_with_runs([self._RUN_WITH_DISTRIBUTION]) + self.assertTrue(self.plugin.is_active()) - def test_active_with_scalars(self): - self.set_up_with_runs([self._RUN_WITH_SCALARS]) - self.assertFalse(self.plugin.is_active()) + def test_active_with_scalars(self): + self.set_up_with_runs([self._RUN_WITH_SCALARS]) + self.assertFalse(self.plugin.is_active()) - def test_active_with_both(self): - self.set_up_with_runs([self._RUN_WITH_DISTRIBUTION, - 
self._RUN_WITH_SCALARS]) - self.assertTrue(self.plugin.is_active()) + def test_active_with_both(self): + self.set_up_with_runs( + [self._RUN_WITH_DISTRIBUTION, self._RUN_WITH_SCALARS] + ) + self.assertTrue(self.plugin.is_active()) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/graph/graph_util.py b/tensorboard/plugins/graph/graph_util.py index db477674b8..c820752808 100644 --- a/tensorboard/plugins/graph/graph_util.py +++ b/tensorboard/plugins/graph/graph_util.py @@ -17,134 +17,150 @@ class _ProtoListDuplicateKeyError(Exception): - pass + pass class _SameKeyDiffContentError(Exception): - pass + pass def _safe_copy_proto_list_values(dst_proto_list, src_proto_list, get_key): - """Safely merge values from `src_proto_list` into `dst_proto_list`. - - Each element in `dst_proto_list` must be mapped by `get_key` to a key - value that is unique within that list; likewise for `src_proto_list`. - If an element of `src_proto_list` has the same key as an existing - element in `dst_proto_list`, then the elements must also be equal. - - Args: - dst_proto_list: A `RepeatedCompositeContainer` or - `RepeatedScalarContainer` into which values should be copied. - src_proto_list: A container holding the same kind of values as in - `dst_proto_list` from which values should be copied. - get_key: A function that takes an element of `dst_proto_list` or - `src_proto_list` and returns a key, such that if two elements have - the same key then it is required that they be deep-equal. For - instance, if `dst_proto_list` is a list of nodes, then `get_key` - might be `lambda node: node.name` to indicate that if two nodes - have the same name then they must be the same node. All keys must - be hashable. - - Raises: - _ProtoListDuplicateKeyError: A proto_list contains items with duplicate - keys. - _SameKeyDiffContentError: An item with the same key has different contents. 
- """ - - def _assert_proto_container_unique_keys(proto_list, get_key): - """Asserts proto_list to only contains unique keys. + """Safely merge values from `src_proto_list` into `dst_proto_list`. + + Each element in `dst_proto_list` must be mapped by `get_key` to a key + value that is unique within that list; likewise for `src_proto_list`. + If an element of `src_proto_list` has the same key as an existing + element in `dst_proto_list`, then the elements must also be equal. Args: - proto_list: A `RepeatedCompositeContainer` or `RepeatedScalarContainer`. - get_key: A function that takes an element of `proto_list` and returns a - hashable key. + dst_proto_list: A `RepeatedCompositeContainer` or + `RepeatedScalarContainer` into which values should be copied. + src_proto_list: A container holding the same kind of values as in + `dst_proto_list` from which values should be copied. + get_key: A function that takes an element of `dst_proto_list` or + `src_proto_list` and returns a key, such that if two elements have + the same key then it is required that they be deep-equal. For + instance, if `dst_proto_list` is a list of nodes, then `get_key` + might be `lambda node: node.name` to indicate that if two nodes + have the same name then they must be the same node. All keys must + be hashable. Raises: _ProtoListDuplicateKeyError: A proto_list contains items with duplicate keys. + _SameKeyDiffContentError: An item with the same key has different contents. 
""" - keys = set() - for item in proto_list: - key = get_key(item) - if key in keys: - raise _ProtoListDuplicateKeyError(key) - keys.add(key) - - _assert_proto_container_unique_keys(dst_proto_list, get_key) - _assert_proto_container_unique_keys(src_proto_list, get_key) - - key_to_proto = {} - for proto in dst_proto_list: - key = get_key(proto) - key_to_proto[key] = proto - - for proto in src_proto_list: - key = get_key(proto) - if key in key_to_proto: - if proto != key_to_proto.get(key): - raise _SameKeyDiffContentError(key) - else: - dst_proto_list.add().CopyFrom(proto) + + def _assert_proto_container_unique_keys(proto_list, get_key): + """Asserts proto_list to only contains unique keys. + + Args: + proto_list: A `RepeatedCompositeContainer` or `RepeatedScalarContainer`. + get_key: A function that takes an element of `proto_list` and returns a + hashable key. + + Raises: + _ProtoListDuplicateKeyError: A proto_list contains items with duplicate + keys. + """ + keys = set() + for item in proto_list: + key = get_key(item) + if key in keys: + raise _ProtoListDuplicateKeyError(key) + keys.add(key) + + _assert_proto_container_unique_keys(dst_proto_list, get_key) + _assert_proto_container_unique_keys(src_proto_list, get_key) + + key_to_proto = {} + for proto in dst_proto_list: + key = get_key(proto) + key_to_proto[key] = proto + + for proto in src_proto_list: + key = get_key(proto) + if key in key_to_proto: + if proto != key_to_proto.get(key): + raise _SameKeyDiffContentError(key) + else: + dst_proto_list.add().CopyFrom(proto) def combine_graph_defs(to_proto, from_proto): - """Combines two GraphDefs by adding nodes from from_proto into to_proto. - - All GraphDefs are expected to be of TensorBoard's. - It assumes node names are unique across GraphDefs if contents differ. The - names can be the same if the NodeDef content are exactly the same. - - Args: - to_proto: A destination TensorBoard GraphDef. - from_proto: A TensorBoard GraphDef to copy contents from. 
- - Returns: - to_proto - - Raises: - ValueError in case any assumption about GraphDef is violated: A - GraphDef should have unique node, function, and gradient function - names. Also, when merging GraphDefs, they should have not have nodes, - functions, or gradient function mappings that share the name but details - do not match. - """ - if from_proto.version != to_proto.version: - raise ValueError('Cannot combine GraphDefs of different versions.') - - try: - _safe_copy_proto_list_values( - to_proto.node, - from_proto.node, - lambda n: n.name) - except _ProtoListDuplicateKeyError as exc: - raise ValueError('A GraphDef contains non-unique node names: %s' % exc) - except _SameKeyDiffContentError as exc: - raise ValueError( - ('Cannot combine GraphDefs because nodes share a name ' - 'but contents are different: %s') % exc) - try: - _safe_copy_proto_list_values( - to_proto.library.function, - from_proto.library.function, - lambda n: n.signature.name) - except _ProtoListDuplicateKeyError as exc: - raise ValueError('A GraphDef contains non-unique function names: %s' % exc) - except _SameKeyDiffContentError as exc: - raise ValueError( - ('Cannot combine GraphDefs because functions share a name ' - 'but are different: %s') % exc) - - try: - _safe_copy_proto_list_values( - to_proto.library.gradient, - from_proto.library.gradient, - lambda g: g.gradient_func) - except _ProtoListDuplicateKeyError as exc: - raise ValueError( - 'A GraphDef contains non-unique gradient function names: %s' % exc) - except _SameKeyDiffContentError as exc: - raise ValueError( - ('Cannot combine GraphDefs because gradients share a gradient_func name ' - 'but map to different functions: %s') % exc) - - return to_proto + """Combines two GraphDefs by adding nodes from from_proto into to_proto. + + All GraphDefs are expected to be of TensorBoard's. + It assumes node names are unique across GraphDefs if contents differ. The + names can be the same if the NodeDef content are exactly the same. 
+ + Args: + to_proto: A destination TensorBoard GraphDef. + from_proto: A TensorBoard GraphDef to copy contents from. + + Returns: + to_proto + + Raises: + ValueError in case any assumption about GraphDef is violated: A + GraphDef should have unique node, function, and gradient function + names. Also, when merging GraphDefs, they should have not have nodes, + functions, or gradient function mappings that share the name but details + do not match. + """ + if from_proto.version != to_proto.version: + raise ValueError("Cannot combine GraphDefs of different versions.") + + try: + _safe_copy_proto_list_values( + to_proto.node, from_proto.node, lambda n: n.name + ) + except _ProtoListDuplicateKeyError as exc: + raise ValueError("A GraphDef contains non-unique node names: %s" % exc) + except _SameKeyDiffContentError as exc: + raise ValueError( + ( + "Cannot combine GraphDefs because nodes share a name " + "but contents are different: %s" + ) + % exc + ) + try: + _safe_copy_proto_list_values( + to_proto.library.function, + from_proto.library.function, + lambda n: n.signature.name, + ) + except _ProtoListDuplicateKeyError as exc: + raise ValueError( + "A GraphDef contains non-unique function names: %s" % exc + ) + except _SameKeyDiffContentError as exc: + raise ValueError( + ( + "Cannot combine GraphDefs because functions share a name " + "but are different: %s" + ) + % exc + ) + + try: + _safe_copy_proto_list_values( + to_proto.library.gradient, + from_proto.library.gradient, + lambda g: g.gradient_func, + ) + except _ProtoListDuplicateKeyError as exc: + raise ValueError( + "A GraphDef contains non-unique gradient function names: %s" % exc + ) + except _SameKeyDiffContentError as exc: + raise ValueError( + ( + "Cannot combine GraphDefs because gradients share a gradient_func name " + "but map to different functions: %s" + ) + % exc + ) + + return to_proto diff --git a/tensorboard/plugins/graph/graph_util_test.py b/tensorboard/plugins/graph/graph_util_test.py index 
124f629bdc..c456727b63 100644 --- a/tensorboard/plugins/graph/graph_util_test.py +++ b/tensorboard/plugins/graph/graph_util_test.py @@ -21,9 +21,8 @@ class GraphUtilTest(tf.test.TestCase): - - def test_combine_graph_defs(self): - expected_proto = ''' + def test_combine_graph_defs(self): + expected_proto = """ node { name: "X" op: "Input" @@ -55,10 +54,11 @@ def test_combine_graph_defs(self): versions { producer: 21 } - ''' + """ - graph_def_a = GraphDef() - text_format.Merge(''' + graph_def_a = GraphDef() + text_format.Merge( + """ node { name: "X" op: "Input" @@ -76,10 +76,13 @@ def test_combine_graph_defs(self): versions { producer: 21 } - ''', graph_def_a) + """, + graph_def_a, + ) - graph_def_b = GraphDef() - text_format.Merge(''' + graph_def_b = GraphDef() + text_format.Merge( + """ node { name: "A" op: "Input" @@ -97,14 +100,17 @@ def test_combine_graph_defs(self): versions { producer: 21 } - ''', graph_def_b) + """, + graph_def_b, + ) - self.assertProtoEquals( - expected_proto, - graph_util.combine_graph_defs(graph_def_a, graph_def_b)) + self.assertProtoEquals( + expected_proto, + graph_util.combine_graph_defs(graph_def_a, graph_def_b), + ) - def test_combine_graph_defs_name_collided_but_same_content(self): - expected_proto = ''' + def test_combine_graph_defs_name_collided_but_same_content(self): + expected_proto = """ node { name: "X" op: "Input" @@ -126,10 +132,11 @@ def test_combine_graph_defs_name_collided_but_same_content(self): versions { producer: 21 } - ''' + """ - graph_def_a = GraphDef() - text_format.Merge(''' + graph_def_a = GraphDef() + text_format.Merge( + """ node { name: "X" op: "Input" @@ -147,10 +154,13 @@ def test_combine_graph_defs_name_collided_but_same_content(self): versions { producer: 21 } - ''', graph_def_a) + """, + graph_def_a, + ) - graph_def_b = GraphDef() - text_format.Merge(''' + graph_def_b = GraphDef() + text_format.Merge( + """ node { name: "X" op: "Input" @@ -162,15 +172,19 @@ def 
test_combine_graph_defs_name_collided_but_same_content(self): versions { producer: 21 } - ''', graph_def_b) + """, + graph_def_b, + ) - self.assertProtoEquals( - expected_proto, - graph_util.combine_graph_defs(graph_def_a, graph_def_b)) + self.assertProtoEquals( + expected_proto, + graph_util.combine_graph_defs(graph_def_a, graph_def_b), + ) - def test_combine_graph_defs_name_collided_different_content(self): - graph_def_a = GraphDef() - text_format.Merge(''' + def test_combine_graph_defs_name_collided_different_content(self): + graph_def_a = GraphDef() + text_format.Merge( + """ node { name: "X" op: "Input" @@ -188,10 +202,13 @@ def test_combine_graph_defs_name_collided_different_content(self): versions { producer: 21 } - ''', graph_def_a) + """, + graph_def_a, + ) - graph_def_b = GraphDef() - text_format.Merge(''' + graph_def_b = GraphDef() + text_format.Merge( + """ node { name: "X" op: "Input" @@ -210,18 +227,24 @@ def test_combine_graph_defs_name_collided_different_content(self): versions { producer: 21 } - ''', graph_def_b) - - with six.assertRaisesRegex( - self, - ValueError, - ('Cannot combine GraphDefs because nodes share a name but ' - 'contents are different: X')): - graph_util.combine_graph_defs(graph_def_a, graph_def_b) - - def test_combine_graph_defs_dst_nodes_duplicate_keys(self): - graph_def_a = GraphDef() - text_format.Merge(''' + """, + graph_def_b, + ) + + with six.assertRaisesRegex( + self, + ValueError, + ( + "Cannot combine GraphDefs because nodes share a name but " + "contents are different: X" + ), + ): + graph_util.combine_graph_defs(graph_def_a, graph_def_b) + + def test_combine_graph_defs_dst_nodes_duplicate_keys(self): + graph_def_a = GraphDef() + text_format.Merge( + """ node { name: "X" op: "Input" @@ -233,10 +256,13 @@ def test_combine_graph_defs_dst_nodes_duplicate_keys(self): versions { producer: 21 } - ''', graph_def_a) + """, + graph_def_a, + ) - graph_def_b = GraphDef() - text_format.Merge(''' + graph_def_b = GraphDef() + 
text_format.Merge( + """ node { name: "X" op: "Input" @@ -248,17 +274,19 @@ def test_combine_graph_defs_dst_nodes_duplicate_keys(self): versions { producer: 21 } - ''', graph_def_b) + """, + graph_def_b, + ) - with six.assertRaisesRegex( - self, - ValueError, - 'A GraphDef contains non-unique node names: X'): - graph_util.combine_graph_defs(graph_def_a, graph_def_b) + with six.assertRaisesRegex( + self, ValueError, "A GraphDef contains non-unique node names: X" + ): + graph_util.combine_graph_defs(graph_def_a, graph_def_b) - def test_combine_graph_defs_src_nodes_duplicate_keys(self): - graph_def_a = GraphDef() - text_format.Merge(''' + def test_combine_graph_defs_src_nodes_duplicate_keys(self): + graph_def_a = GraphDef() + text_format.Merge( + """ node { name: "X" op: "Input" @@ -270,10 +298,13 @@ def test_combine_graph_defs_src_nodes_duplicate_keys(self): versions { producer: 21 } - ''', graph_def_a) + """, + graph_def_a, + ) - graph_def_b = GraphDef() - text_format.Merge(''' + graph_def_b = GraphDef() + text_format.Merge( + """ node { name: "W" op: "Input" @@ -286,16 +317,17 @@ def test_combine_graph_defs_src_nodes_duplicate_keys(self): versions { producer: 21 } - ''', graph_def_b) + """, + graph_def_b, + ) - with six.assertRaisesRegex( - self, - ValueError, - 'A GraphDef contains non-unique node names: W'): - graph_util.combine_graph_defs(graph_def_a, graph_def_b) + with six.assertRaisesRegex( + self, ValueError, "A GraphDef contains non-unique node names: W" + ): + graph_util.combine_graph_defs(graph_def_a, graph_def_b) - def test_combine_graph_defs_function(self): - expected_proto = ''' + def test_combine_graph_defs_function(self): + expected_proto = """ library { function { signature { @@ -336,10 +368,11 @@ def test_combine_graph_defs_function(self): } } } - ''' + """ - graph_def_a = GraphDef() - text_format.Merge(''' + graph_def_a = GraphDef() + text_format.Merge( + """ library { function { signature { @@ -361,10 +394,13 @@ def 
test_combine_graph_defs_function(self): } } } - ''', graph_def_a) + """, + graph_def_a, + ) - graph_def_b = GraphDef() - text_format.Merge(''' + graph_def_b = GraphDef() + text_format.Merge( + """ library { function { signature { @@ -405,15 +441,19 @@ def test_combine_graph_defs_function(self): } } } - ''', graph_def_b) + """, + graph_def_b, + ) - self.assertProtoEquals( - expected_proto, - graph_util.combine_graph_defs(graph_def_a, graph_def_b)) + self.assertProtoEquals( + expected_proto, + graph_util.combine_graph_defs(graph_def_a, graph_def_b), + ) - def test_combine_graph_defs_function_collison(self): - graph_def_a = GraphDef() - text_format.Merge(''' + def test_combine_graph_defs_function_collison(self): + graph_def_a = GraphDef() + text_format.Merge( + """ library { function { signature { @@ -435,10 +475,13 @@ def test_combine_graph_defs_function_collison(self): } } } - ''', graph_def_a) + """, + graph_def_a, + ) - graph_def_b = GraphDef() - text_format.Merge(''' + graph_def_b = GraphDef() + text_format.Merge( + """ library { function { signature { @@ -479,18 +522,24 @@ def test_combine_graph_defs_function_collison(self): } } } - ''', graph_def_b) - - with six.assertRaisesRegex( - self, - ValueError, - ('Cannot combine GraphDefs because functions share a name but ' - 'are different: foo')): - graph_util.combine_graph_defs(graph_def_a, graph_def_b) - - def test_combine_graph_defs_dst_function_duplicate_keys(self): - graph_def_a = GraphDef() - text_format.Merge(''' + """, + graph_def_b, + ) + + with six.assertRaisesRegex( + self, + ValueError, + ( + "Cannot combine GraphDefs because functions share a name but " + "are different: foo" + ), + ): + graph_util.combine_graph_defs(graph_def_a, graph_def_b) + + def test_combine_graph_defs_dst_function_duplicate_keys(self): + graph_def_a = GraphDef() + text_format.Merge( + """ library { function { signature { @@ -525,10 +574,13 @@ def test_combine_graph_defs_dst_function_duplicate_keys(self): } } } - ''', graph_def_a) 
+ """, + graph_def_a, + ) - graph_def_b = GraphDef() - text_format.Merge(''' + graph_def_b = GraphDef() + text_format.Merge( + """ library { function { signature { @@ -550,17 +602,21 @@ def test_combine_graph_defs_dst_function_duplicate_keys(self): } } } - ''', graph_def_b) - - with six.assertRaisesRegex( - self, - ValueError, - ('A GraphDef contains non-unique function names: foo')): - graph_util.combine_graph_defs(graph_def_a, graph_def_b) - - def test_combine_graph_defs_src_function_duplicate_keys(self): - graph_def_a = GraphDef() - text_format.Merge(''' + """, + graph_def_b, + ) + + with six.assertRaisesRegex( + self, + ValueError, + ("A GraphDef contains non-unique function names: foo"), + ): + graph_util.combine_graph_defs(graph_def_a, graph_def_b) + + def test_combine_graph_defs_src_function_duplicate_keys(self): + graph_def_a = GraphDef() + text_format.Merge( + """ library { function { signature { @@ -582,10 +638,13 @@ def test_combine_graph_defs_src_function_duplicate_keys(self): } } } - ''', graph_def_a) + """, + graph_def_a, + ) - graph_def_b = GraphDef() - text_format.Merge(''' + graph_def_b = GraphDef() + text_format.Merge( + """ library { function { signature { @@ -614,16 +673,19 @@ def test_combine_graph_defs_src_function_duplicate_keys(self): } } } - ''', graph_def_b) + """, + graph_def_b, + ) - with six.assertRaisesRegex( - self, - ValueError, - 'A GraphDef contains non-unique function names: bar'): - graph_util.combine_graph_defs(graph_def_a, graph_def_b) + with six.assertRaisesRegex( + self, + ValueError, + "A GraphDef contains non-unique function names: bar", + ): + graph_util.combine_graph_defs(graph_def_a, graph_def_b) - def test_combine_graph_defs_gradient(self): - expected_proto = ''' + def test_combine_graph_defs_gradient(self): + expected_proto = """ library { gradient { function_name: "foo" @@ -634,20 +696,24 @@ def test_combine_graph_defs_gradient(self): gradient_func: "bar_grad" } } - ''' + """ - graph_def_a = GraphDef() - 
text_format.Merge(''' + graph_def_a = GraphDef() + text_format.Merge( + """ library { gradient { function_name: "foo" gradient_func: "foo_grad" } } - ''', graph_def_a) + """, + graph_def_a, + ) - graph_def_b = GraphDef() - text_format.Merge(''' + graph_def_b = GraphDef() + text_format.Merge( + """ library { gradient { function_name: "foo" @@ -658,25 +724,32 @@ def test_combine_graph_defs_gradient(self): gradient_func: "bar_grad" } } - ''', graph_def_b) + """, + graph_def_b, + ) - self.assertProtoEquals( - expected_proto, - graph_util.combine_graph_defs(graph_def_a, graph_def_b)) + self.assertProtoEquals( + expected_proto, + graph_util.combine_graph_defs(graph_def_a, graph_def_b), + ) - def test_combine_graph_defs_gradient_collison(self): - graph_def_a = GraphDef() - text_format.Merge(''' + def test_combine_graph_defs_gradient_collison(self): + graph_def_a = GraphDef() + text_format.Merge( + """ library { gradient { function_name: "foo" gradient_func: "foo_grad" } } - ''', graph_def_a) + """, + graph_def_a, + ) - graph_def_b = GraphDef() - text_format.Merge(''' + graph_def_b = GraphDef() + text_format.Merge( + """ library { gradient { function_name: "bar" @@ -687,18 +760,24 @@ def test_combine_graph_defs_gradient_collison(self): gradient_func: "foo_grad" } } - ''', graph_def_b) - - with six.assertRaisesRegex( - self, - ValueError, - ('share a gradient_func name but map to different functions: ' - 'foo_grad')): - graph_util.combine_graph_defs(graph_def_a, graph_def_b) - - def test_combine_graph_defs_dst_gradient_func_non_unique(self): - graph_def_a = GraphDef() - text_format.Merge(''' + """, + graph_def_b, + ) + + with six.assertRaisesRegex( + self, + ValueError, + ( + "share a gradient_func name but map to different functions: " + "foo_grad" + ), + ): + graph_util.combine_graph_defs(graph_def_a, graph_def_b) + + def test_combine_graph_defs_dst_gradient_func_non_unique(self): + graph_def_a = GraphDef() + text_format.Merge( + """ library { gradient { function_name: 
"foo" @@ -709,37 +788,47 @@ def test_combine_graph_defs_dst_gradient_func_non_unique(self): gradient_func: "foo_grad" } } - ''', graph_def_a) + """, + graph_def_a, + ) - graph_def_b = GraphDef() - text_format.Merge(''' + graph_def_b = GraphDef() + text_format.Merge( + """ library { gradient { function_name: "bar" gradient_func: "bar_grad" } } - ''', graph_def_b) - - with six.assertRaisesRegex( - self, - ValueError, - 'A GraphDef contains non-unique gradient function names: foo_grad'): - graph_util.combine_graph_defs(graph_def_a, graph_def_b) - - def test_combine_graph_defs_src_gradient_func_non_unique(self): - graph_def_a = GraphDef() - text_format.Merge(''' + """, + graph_def_b, + ) + + with six.assertRaisesRegex( + self, + ValueError, + "A GraphDef contains non-unique gradient function names: foo_grad", + ): + graph_util.combine_graph_defs(graph_def_a, graph_def_b) + + def test_combine_graph_defs_src_gradient_func_non_unique(self): + graph_def_a = GraphDef() + text_format.Merge( + """ library { gradient { function_name: "foo" gradient_func: "foo_grad" } } - ''', graph_def_a) + """, + graph_def_a, + ) - graph_def_b = GraphDef() - text_format.Merge(''' + graph_def_b = GraphDef() + text_format.Merge( + """ library { gradient { function_name: "bar" @@ -750,14 +839,17 @@ def test_combine_graph_defs_src_gradient_func_non_unique(self): gradient_func: "bar_grad" } } - ''', graph_def_b) + """, + graph_def_b, + ) - with six.assertRaisesRegex( - self, - ValueError, - 'A GraphDef contains non-unique gradient function names: bar_grad'): - graph_util.combine_graph_defs(graph_def_a, graph_def_b) + with six.assertRaisesRegex( + self, + ValueError, + "A GraphDef contains non-unique gradient function names: bar_grad", + ): + graph_util.combine_graph_defs(graph_def_a, graph_def_b) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/graph/graphs_plugin.py b/tensorboard/plugins/graph/graphs_plugin.py index 
ccbbf6b3b5..427c7d614c 100644 --- a/tensorboard/plugins/graph/graphs_plugin.py +++ b/tensorboard/plugins/graph/graphs_plugin.py @@ -25,7 +25,9 @@ from tensorboard import plugin_util from tensorboard.backend import http_util from tensorboard.backend import process_graph -from tensorboard.backend.event_processing import plugin_event_accumulator as event_accumulator +from tensorboard.backend.event_processing import ( + plugin_event_accumulator as event_accumulator, +) from tensorboard.compat.proto import config_pb2 from tensorboard.compat.proto import graph_pb2 from tensorboard.data import provider @@ -41,260 +43,314 @@ # As a result, this SummaryMetadata is a bit unconventional and uses non-public # hardcoded name as the plugin name. Please refer to link below for the summary ops. # https://github.com/tensorflow/tensorflow/blob/11f4ecb54708865ec757ca64e4805957b05d7570/tensorflow/python/ops/summary_ops_v2.py#L757 -_PLUGIN_NAME_RUN_METADATA = 'graph_run_metadata' +_PLUGIN_NAME_RUN_METADATA = "graph_run_metadata" # https://github.com/tensorflow/tensorflow/blob/11f4ecb54708865ec757ca64e4805957b05d7570/tensorflow/python/ops/summary_ops_v2.py#L788 -_PLUGIN_NAME_RUN_METADATA_WITH_GRAPH = 'graph_run_metadata_graph' +_PLUGIN_NAME_RUN_METADATA_WITH_GRAPH = "graph_run_metadata_graph" # https://github.com/tensorflow/tensorflow/blob/565952cc2f17fdfd995e25171cf07be0f6f06180/tensorflow/python/ops/summary_ops_v2.py#L825 -_PLUGIN_NAME_KERAS_MODEL = 'graph_keras_model' +_PLUGIN_NAME_KERAS_MODEL = "graph_keras_model" class GraphsPlugin(base_plugin.TBPlugin): - """Graphs Plugin for TensorBoard.""" - - plugin_name = metadata.PLUGIN_NAME - - def __init__(self, context): - """Instantiates GraphsPlugin via TensorBoard core. - - Args: - context: A base_plugin.TBContext instance. 
- """ - self._multiplexer = context.multiplexer - if context.flags and context.flags.generic_data == 'true': - self._data_provider = context.data_provider - else: - self._data_provider = None - - def get_plugin_apps(self): - return { - '/graph': self.graph_route, - '/info': self.info_route, - '/run_metadata': self.run_metadata_route, - } - - def is_active(self): - """The graphs plugin is active iff any run has a graph or metadata.""" - if self._data_provider: - # We don't have an experiment ID, and modifying the backend core - # to provide one would break backward compatibility. Hack for now. - return True - - return bool(self.info_impl()) - - def frontend_metadata(self): - return base_plugin.FrontendMetadata( - element_name='tf-graph-dashboard', - # TODO(@chihuahua): Reconcile this setting with Health Pills. - disable_reload=True, - ) - - def info_impl(self, experiment=None): - """Returns a dict of all runs and their data availabilities.""" - result = {} - def add_row_item(run, tag=None): - run_item = result.setdefault(run, { - 'run': run, - 'tags': {}, - # A run-wide GraphDef of ops. - 'run_graph': False}) - - tag_item = None - if tag: - tag_item = run_item.get('tags').setdefault(tag, { - 'tag': tag, - 'conceptual_graph': False, - # A tagged GraphDef of ops. 
- 'op_graph': False, - 'profile': False}) - return (run_item, tag_item) - - if self._data_provider: - mapping = self._data_provider.list_blob_sequences( - experiment_id=experiment, - plugin_name=metadata.PLUGIN_NAME, - ) - for (run_name, tag_to_time_series) in six.iteritems(mapping): - for tag in tag_to_time_series: - (run_item, tag_item) = add_row_item(run_name, tag) - run_item['run_graph'] = True - if tag_item: - tag_item['op_graph'] = True - return result - - mapping = self._multiplexer.PluginRunToTagToContent( - _PLUGIN_NAME_RUN_METADATA_WITH_GRAPH) - for run_name, tag_to_content in six.iteritems(mapping): - for (tag, content) in six.iteritems(tag_to_content): - # The Summary op is defined in TensorFlow and does not use a stringified proto - # as a content of plugin data. It contains single string that denotes a version. - # https://github.com/tensorflow/tensorflow/blob/11f4ecb54708865ec757ca64e4805957b05d7570/tensorflow/python/ops/summary_ops_v2.py#L789-L790 - if content != b'1': - logger.warn('Ignoring unrecognizable version of RunMetadata.') - continue - (_, tag_item) = add_row_item(run_name, tag) - tag_item['op_graph'] = True - - # Tensors associated with plugin name _PLUGIN_NAME_RUN_METADATA contain - # both op graph and profile information. - mapping = self._multiplexer.PluginRunToTagToContent( - _PLUGIN_NAME_RUN_METADATA) - for run_name, tag_to_content in six.iteritems(mapping): - for (tag, content) in six.iteritems(tag_to_content): - if content != b'1': - logger.warn('Ignoring unrecognizable version of RunMetadata.') - continue - (_, tag_item) = add_row_item(run_name, tag) - tag_item['profile'] = True - tag_item['op_graph'] = True - - # Tensors associated with plugin name _PLUGIN_NAME_KERAS_MODEL contain - # serialized Keras model in JSON format. 
- mapping = self._multiplexer.PluginRunToTagToContent( - _PLUGIN_NAME_KERAS_MODEL) - for run_name, tag_to_content in six.iteritems(mapping): - for (tag, content) in six.iteritems(tag_to_content): - if content != b'1': - logger.warn('Ignoring unrecognizable version of RunMetadata.') - continue - (_, tag_item) = add_row_item(run_name, tag) - tag_item['conceptual_graph'] = True - - for (run_name, run_data) in six.iteritems(self._multiplexer.Runs()): - if run_data.get(event_accumulator.GRAPH): - (run_item, _) = add_row_item(run_name, None) - run_item['run_graph'] = True - - for (run_name, run_data) in six.iteritems(self._multiplexer.Runs()): - if event_accumulator.RUN_METADATA in run_data: - for tag in run_data[event_accumulator.RUN_METADATA]: - (_, tag_item) = add_row_item(run_name, tag) - tag_item['profile'] = True - - return result - - def graph_impl(self, run, tag, is_conceptual, experiment=None, limit_attr_size=None, large_attrs_key=None): - """Result of the form `(body, mime_type)`, or `None` if no graph exists.""" - if self._data_provider: - graph_blob_sequences = self._data_provider.read_blob_sequences( - experiment_id=experiment, - plugin_name=metadata.PLUGIN_NAME, - run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]), - ) - blob_datum_list = graph_blob_sequences.get(run, {}).get(tag, ()) - try: - blob_ref = blob_datum_list[0].values[0] - except IndexError: - return None - # Always use the blob_key approach for now, even if there is a direct url. - graph_raw = self._data_provider.read_blob(blob_ref.blob_key) - # This method ultimately returns pbtxt, but we have to deserialize and - # later reserialize this anyway, because a) this way we accept binary - # protobufs too, and b) below we run `prepare_graph_for_ui` on the graph. - graph = graph_pb2.GraphDef.FromString(graph_raw) - - elif is_conceptual: - tensor_events = self._multiplexer.Tensors(run, tag) - # Take the first event if there are multiple events written from different - # steps. 
- keras_model_config = json.loads(tensor_events[0].tensor_proto.string_val[0]) - graph = keras_util.keras_model_to_graph_def(keras_model_config) - - elif tag: - tensor_events = self._multiplexer.Tensors(run, tag) - # Take the first event if there are multiple events written from different - # steps. - run_metadata = config_pb2.RunMetadata.FromString( - tensor_events[0].tensor_proto.string_val[0]) - graph = graph_pb2.GraphDef() - - for func_graph in run_metadata.function_graphs: - graph_util.combine_graph_defs(graph, func_graph.pre_optimization_graph) - else: - graph = self._multiplexer.Graph(run) - - # This next line might raise a ValueError if the limit parameters - # are invalid (size is negative, size present but key absent, etc.). - process_graph.prepare_graph_for_ui(graph, limit_attr_size, large_attrs_key) - return (str(graph), 'text/x-protobuf') # pbtxt - - def run_metadata_impl(self, run, tag): - """Result of the form `(body, mime_type)`, or `None` if no data exists.""" - if self._data_provider: - # TODO(davidsoergel, wchargin): Consider plumbing run metadata through data providers. - return None - try: - run_metadata = self._multiplexer.RunMetadata(run, tag) - except ValueError: - # TODO(stephanwlee): Should include whether FE is fetching for v1 or v2 RunMetadata - # so we can remove this try/except. - tensor_events = self._multiplexer.Tensors(run, tag) - if tensor_events is None: - return None - # Take the first event if there are multiple events written from different - # steps. 
- run_metadata = config_pb2.RunMetadata.FromString( - tensor_events[0].tensor_proto.string_val[0]) - if run_metadata is None: - return None - return (str(run_metadata), 'text/x-protobuf') # pbtxt - - @wrappers.Request.application - def info_route(self, request): - experiment = plugin_util.experiment_id(request.environ) - info = self.info_impl(experiment) - return http_util.Respond(request, info, 'application/json') - - @wrappers.Request.application - def graph_route(self, request): - """Given a single run, return the graph definition in protobuf format.""" - experiment = plugin_util.experiment_id(request.environ) - run = request.args.get('run') - tag = request.args.get('tag') - conceptual_arg = request.args.get('conceptual', False) - is_conceptual = True if conceptual_arg == 'true' else False - - if run is None: - return http_util.Respond( - request, 'query parameter "run" is required', 'text/plain', 400) - - limit_attr_size = request.args.get('limit_attr_size', None) - if limit_attr_size is not None: - try: - limit_attr_size = int(limit_attr_size) - except ValueError: - return http_util.Respond( - request, 'query parameter `limit_attr_size` must be an integer', - 'text/plain', 400) - - large_attrs_key = request.args.get('large_attrs_key', None) - - try: - result = self.graph_impl(run, tag, is_conceptual, experiment, limit_attr_size, large_attrs_key) - except ValueError as e: - return http_util.Respond(request, e.message, 'text/plain', code=400) - else: - if result is not None: - (body, mime_type) = result # pylint: disable=unpacking-non-sequence - return http_util.Respond(request, body, mime_type) - else: - return http_util.Respond(request, '404 Not Found', 'text/plain', - code=404) - - @wrappers.Request.application - def run_metadata_route(self, request): - """Given a tag and a run, return the session.run() metadata.""" - tag = request.args.get('tag') - run = request.args.get('run') - if tag is None: - return http_util.Respond( - request, 'query parameter "tag" 
is required', 'text/plain', 400) - if run is None: - return http_util.Respond( - request, 'query parameter "run" is required', 'text/plain', 400) - result = self.run_metadata_impl(run, tag) - if result is not None: - (body, mime_type) = result # pylint: disable=unpacking-non-sequence - return http_util.Respond(request, body, mime_type) - else: - return http_util.Respond(request, '404 Not Found', 'text/plain', - code=404) + """Graphs Plugin for TensorBoard.""" + + plugin_name = metadata.PLUGIN_NAME + + def __init__(self, context): + """Instantiates GraphsPlugin via TensorBoard core. + + Args: + context: A base_plugin.TBContext instance. + """ + self._multiplexer = context.multiplexer + if context.flags and context.flags.generic_data == "true": + self._data_provider = context.data_provider + else: + self._data_provider = None + + def get_plugin_apps(self): + return { + "/graph": self.graph_route, + "/info": self.info_route, + "/run_metadata": self.run_metadata_route, + } + + def is_active(self): + """The graphs plugin is active iff any run has a graph or metadata.""" + if self._data_provider: + # We don't have an experiment ID, and modifying the backend core + # to provide one would break backward compatibility. Hack for now. + return True + + return bool(self.info_impl()) + + def frontend_metadata(self): + return base_plugin.FrontendMetadata( + element_name="tf-graph-dashboard", + # TODO(@chihuahua): Reconcile this setting with Health Pills. + disable_reload=True, + ) + + def info_impl(self, experiment=None): + """Returns a dict of all runs and their data availabilities.""" + result = {} + + def add_row_item(run, tag=None): + run_item = result.setdefault( + run, + { + "run": run, + "tags": {}, + # A run-wide GraphDef of ops. + "run_graph": False, + }, + ) + + tag_item = None + if tag: + tag_item = run_item.get("tags").setdefault( + tag, + { + "tag": tag, + "conceptual_graph": False, + # A tagged GraphDef of ops. 
+ "op_graph": False, + "profile": False, + }, + ) + return (run_item, tag_item) + + if self._data_provider: + mapping = self._data_provider.list_blob_sequences( + experiment_id=experiment, plugin_name=metadata.PLUGIN_NAME, + ) + for (run_name, tag_to_time_series) in six.iteritems(mapping): + for tag in tag_to_time_series: + (run_item, tag_item) = add_row_item(run_name, tag) + run_item["run_graph"] = True + if tag_item: + tag_item["op_graph"] = True + return result + + mapping = self._multiplexer.PluginRunToTagToContent( + _PLUGIN_NAME_RUN_METADATA_WITH_GRAPH + ) + for run_name, tag_to_content in six.iteritems(mapping): + for (tag, content) in six.iteritems(tag_to_content): + # The Summary op is defined in TensorFlow and does not use a stringified proto + # as a content of plugin data. It contains single string that denotes a version. + # https://github.com/tensorflow/tensorflow/blob/11f4ecb54708865ec757ca64e4805957b05d7570/tensorflow/python/ops/summary_ops_v2.py#L789-L790 + if content != b"1": + logger.warn( + "Ignoring unrecognizable version of RunMetadata." + ) + continue + (_, tag_item) = add_row_item(run_name, tag) + tag_item["op_graph"] = True + + # Tensors associated with plugin name _PLUGIN_NAME_RUN_METADATA contain + # both op graph and profile information. + mapping = self._multiplexer.PluginRunToTagToContent( + _PLUGIN_NAME_RUN_METADATA + ) + for run_name, tag_to_content in six.iteritems(mapping): + for (tag, content) in six.iteritems(tag_to_content): + if content != b"1": + logger.warn( + "Ignoring unrecognizable version of RunMetadata." + ) + continue + (_, tag_item) = add_row_item(run_name, tag) + tag_item["profile"] = True + tag_item["op_graph"] = True + + # Tensors associated with plugin name _PLUGIN_NAME_KERAS_MODEL contain + # serialized Keras model in JSON format. 
+ mapping = self._multiplexer.PluginRunToTagToContent( + _PLUGIN_NAME_KERAS_MODEL + ) + for run_name, tag_to_content in six.iteritems(mapping): + for (tag, content) in six.iteritems(tag_to_content): + if content != b"1": + logger.warn( + "Ignoring unrecognizable version of RunMetadata." + ) + continue + (_, tag_item) = add_row_item(run_name, tag) + tag_item["conceptual_graph"] = True + + for (run_name, run_data) in six.iteritems(self._multiplexer.Runs()): + if run_data.get(event_accumulator.GRAPH): + (run_item, _) = add_row_item(run_name, None) + run_item["run_graph"] = True + + for (run_name, run_data) in six.iteritems(self._multiplexer.Runs()): + if event_accumulator.RUN_METADATA in run_data: + for tag in run_data[event_accumulator.RUN_METADATA]: + (_, tag_item) = add_row_item(run_name, tag) + tag_item["profile"] = True + + return result + + def graph_impl( + self, + run, + tag, + is_conceptual, + experiment=None, + limit_attr_size=None, + large_attrs_key=None, + ): + """Result of the form `(body, mime_type)`, or `None` if no graph + exists.""" + if self._data_provider: + graph_blob_sequences = self._data_provider.read_blob_sequences( + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, + run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]), + ) + blob_datum_list = graph_blob_sequences.get(run, {}).get(tag, ()) + try: + blob_ref = blob_datum_list[0].values[0] + except IndexError: + return None + # Always use the blob_key approach for now, even if there is a direct url. + graph_raw = self._data_provider.read_blob(blob_ref.blob_key) + # This method ultimately returns pbtxt, but we have to deserialize and + # later reserialize this anyway, because a) this way we accept binary + # protobufs too, and b) below we run `prepare_graph_for_ui` on the graph. 
+ graph = graph_pb2.GraphDef.FromString(graph_raw) + + elif is_conceptual: + tensor_events = self._multiplexer.Tensors(run, tag) + # Take the first event if there are multiple events written from different + # steps. + keras_model_config = json.loads( + tensor_events[0].tensor_proto.string_val[0] + ) + graph = keras_util.keras_model_to_graph_def(keras_model_config) + + elif tag: + tensor_events = self._multiplexer.Tensors(run, tag) + # Take the first event if there are multiple events written from different + # steps. + run_metadata = config_pb2.RunMetadata.FromString( + tensor_events[0].tensor_proto.string_val[0] + ) + graph = graph_pb2.GraphDef() + + for func_graph in run_metadata.function_graphs: + graph_util.combine_graph_defs( + graph, func_graph.pre_optimization_graph + ) + else: + graph = self._multiplexer.Graph(run) + + # This next line might raise a ValueError if the limit parameters + # are invalid (size is negative, size present but key absent, etc.). + process_graph.prepare_graph_for_ui( + graph, limit_attr_size, large_attrs_key + ) + return (str(graph), "text/x-protobuf") # pbtxt + + def run_metadata_impl(self, run, tag): + """Result of the form `(body, mime_type)`, or `None` if no data + exists.""" + if self._data_provider: + # TODO(davidsoergel, wchargin): Consider plumbing run metadata through data providers. + return None + try: + run_metadata = self._multiplexer.RunMetadata(run, tag) + except ValueError: + # TODO(stephanwlee): Should include whether FE is fetching for v1 or v2 RunMetadata + # so we can remove this try/except. + tensor_events = self._multiplexer.Tensors(run, tag) + if tensor_events is None: + return None + # Take the first event if there are multiple events written from different + # steps. 
+ run_metadata = config_pb2.RunMetadata.FromString( + tensor_events[0].tensor_proto.string_val[0] + ) + if run_metadata is None: + return None + return (str(run_metadata), "text/x-protobuf") # pbtxt + + @wrappers.Request.application + def info_route(self, request): + experiment = plugin_util.experiment_id(request.environ) + info = self.info_impl(experiment) + return http_util.Respond(request, info, "application/json") + + @wrappers.Request.application + def graph_route(self, request): + """Given a single run, return the graph definition in protobuf + format.""" + experiment = plugin_util.experiment_id(request.environ) + run = request.args.get("run") + tag = request.args.get("tag") + conceptual_arg = request.args.get("conceptual", False) + is_conceptual = True if conceptual_arg == "true" else False + + if run is None: + return http_util.Respond( + request, 'query parameter "run" is required', "text/plain", 400 + ) + + limit_attr_size = request.args.get("limit_attr_size", None) + if limit_attr_size is not None: + try: + limit_attr_size = int(limit_attr_size) + except ValueError: + return http_util.Respond( + request, + "query parameter `limit_attr_size` must be an integer", + "text/plain", + 400, + ) + + large_attrs_key = request.args.get("large_attrs_key", None) + + try: + result = self.graph_impl( + run, + tag, + is_conceptual, + experiment, + limit_attr_size, + large_attrs_key, + ) + except ValueError as e: + return http_util.Respond(request, e.message, "text/plain", code=400) + else: + if result is not None: + ( + body, + mime_type, + ) = result # pylint: disable=unpacking-non-sequence + return http_util.Respond(request, body, mime_type) + else: + return http_util.Respond( + request, "404 Not Found", "text/plain", code=404 + ) + + @wrappers.Request.application + def run_metadata_route(self, request): + """Given a tag and a run, return the session.run() metadata.""" + tag = request.args.get("tag") + run = request.args.get("run") + if tag is None: + return 
http_util.Respond( + request, 'query parameter "tag" is required', "text/plain", 400 + ) + if run is None: + return http_util.Respond( + request, 'query parameter "run" is required', "text/plain", 400 + ) + result = self.run_metadata_impl(run, tag) + if result is not None: + (body, mime_type) = result # pylint: disable=unpacking-non-sequence + return http_util.Respond(request, body, mime_type) + else: + return http_util.Respond( + request, "404 Not Found", "text/plain", code=404 + ) diff --git a/tensorboard/plugins/graph/graphs_plugin_test.py b/tensorboard/plugins/graph/graphs_plugin_test.py index 70dbb51f0b..a573c412f8 100644 --- a/tensorboard/plugins/graph/graphs_plugin_test.py +++ b/tensorboard/plugins/graph/graphs_plugin_test.py @@ -29,7 +29,9 @@ from google.protobuf import text_format from tensorboard.backend.event_processing import data_provider -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.compat.proto import config_pb2 from tensorboard.plugins import base_plugin from tensorboard.plugins.graph import graphs_plugin @@ -42,227 +44,282 @@ # can write graph and metadata with a TF public API. 
-_RUN_WITH_GRAPH_WITH_METADATA = ('_RUN_WITH_GRAPH_WITH_METADATA', True, True) -_RUN_WITHOUT_GRAPH_WITH_METADATA = ('_RUN_WITHOUT_GRAPH_WITH_METADATA', False, True) -_RUN_WITH_GRAPH_WITHOUT_METADATA = ('_RUN_WITH_GRAPH_WITHOUT_METADATA', True, False) -_RUN_WITHOUT_GRAPH_WITHOUT_METADATA = ('_RUN_WITHOUT_GRAPH_WITHOUT_METADATA', False, False) +_RUN_WITH_GRAPH_WITH_METADATA = ("_RUN_WITH_GRAPH_WITH_METADATA", True, True) +_RUN_WITHOUT_GRAPH_WITH_METADATA = ( + "_RUN_WITHOUT_GRAPH_WITH_METADATA", + False, + True, +) +_RUN_WITH_GRAPH_WITHOUT_METADATA = ( + "_RUN_WITH_GRAPH_WITHOUT_METADATA", + True, + False, +) +_RUN_WITHOUT_GRAPH_WITHOUT_METADATA = ( + "_RUN_WITHOUT_GRAPH_WITHOUT_METADATA", + False, + False, +) + def with_runs(run_specs): - """Run a test with a bare multiplexer and with a `data_provider`. - - The decorated function will receive an initialized `GraphsPlugin` - object as its first positional argument. - - The receiver argument of the decorated function must be a `TestCase` instance - that also provides `load_runs`.` - """ - def decorator(fn): - @functools.wraps(fn) - def wrapper(self, *args, **kwargs): - (logdir, multiplexer) = self.load_runs(run_specs) - with self.subTest('bare multiplexer'): - ctx = base_plugin.TBContext(logdir=logdir, multiplexer=multiplexer) - fn(self, graphs_plugin.GraphsPlugin(ctx), *args, **kwargs) - with self.subTest('generic data provider'): - flags = argparse.Namespace(generic_data='true') - provider = data_provider.MultiplexerDataProvider(multiplexer, logdir) - ctx = base_plugin.TBContext( - flags=flags, - logdir=logdir, - multiplexer=multiplexer, - data_provider=provider, - ) - fn(self, graphs_plugin.GraphsPlugin(ctx), *args, **kwargs) - return wrapper - return decorator + """Run a test with a bare multiplexer and with a `data_provider`. + + The decorated function will receive an initialized `GraphsPlugin` + object as its first positional argument. 
+ + The receiver argument of the decorated function must be a `TestCase` instance + that also provides `load_runs`.` + """ + + def decorator(fn): + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + (logdir, multiplexer) = self.load_runs(run_specs) + with self.subTest("bare multiplexer"): + ctx = base_plugin.TBContext( + logdir=logdir, multiplexer=multiplexer + ) + fn(self, graphs_plugin.GraphsPlugin(ctx), *args, **kwargs) + with self.subTest("generic data provider"): + flags = argparse.Namespace(generic_data="true") + provider = data_provider.MultiplexerDataProvider( + multiplexer, logdir + ) + ctx = base_plugin.TBContext( + flags=flags, + logdir=logdir, + multiplexer=multiplexer, + data_provider=provider, + ) + fn(self, graphs_plugin.GraphsPlugin(ctx), *args, **kwargs) + + return wrapper + + return decorator + class GraphsPluginBaseTest(object): - _METADATA_TAG = 'secret-stats' - _MESSAGE_PREFIX_LENGTH_LOWER_BOUND = 1024 + _METADATA_TAG = "secret-stats" + _MESSAGE_PREFIX_LENGTH_LOWER_BOUND = 1024 + + def __init__(self, *args, **kwargs): + super(GraphsPluginBaseTest, self).__init__(*args, **kwargs) + self.plugin = None - def __init__(self, *args, **kwargs): - super(GraphsPluginBaseTest, self).__init__(*args, **kwargs) - self.plugin = None + def setUp(self): + super(GraphsPluginBaseTest, self).setUp() - def setUp(self): - super(GraphsPluginBaseTest, self).setUp() + def generate_run( + self, logdir, run_name, include_graph, include_run_metadata + ): + """Create a run.""" + raise NotImplementedError("Please implement generate_run") - def generate_run(self, logdir, run_name, include_graph, include_run_metadata): - """Create a run""" - raise NotImplementedError('Please implement generate_run') + def load_runs(self, run_specs): + logdir = self.get_temp_dir() + for run_spec in run_specs: + self.generate_run(logdir, *run_spec) + return self.bootstrap_plugin(logdir) - def load_runs(self, run_specs): - logdir = self.get_temp_dir() - for run_spec in run_specs: - 
self.generate_run(logdir, *run_spec) - return self.bootstrap_plugin(logdir) + def bootstrap_plugin(self, logdir): + multiplexer = event_multiplexer.EventMultiplexer() + multiplexer.AddRunsFromDirectory(logdir) + multiplexer.Reload() + return (logdir, multiplexer) - def bootstrap_plugin(self, logdir): - multiplexer = event_multiplexer.EventMultiplexer() - multiplexer.AddRunsFromDirectory(logdir) - multiplexer.Reload() - return (logdir, multiplexer) + @with_runs( + [_RUN_WITH_GRAPH_WITH_METADATA, _RUN_WITHOUT_GRAPH_WITH_METADATA] + ) + def testRoutesProvided(self, plugin): + """Tests that the plugin offers the correct routes.""" + routes = plugin.get_plugin_apps() + self.assertIsInstance(routes["/graph"], collections.Callable) + self.assertIsInstance(routes["/run_metadata"], collections.Callable) + self.assertIsInstance(routes["/info"], collections.Callable) - @with_runs([_RUN_WITH_GRAPH_WITH_METADATA, _RUN_WITHOUT_GRAPH_WITH_METADATA]) - def testRoutesProvided(self, plugin): - """Tests that the plugin offers the correct routes.""" - routes = plugin.get_plugin_apps() - self.assertIsInstance(routes['/graph'], collections.Callable) - self.assertIsInstance(routes['/run_metadata'], collections.Callable) - self.assertIsInstance(routes['/info'], collections.Callable) class GraphsPluginV1Test(GraphsPluginBaseTest, tf.test.TestCase): + def generate_run( + self, logdir, run_name, include_graph, include_run_metadata + ): + """Create a run with a text summary, metadata, and optionally a + graph.""" + tf.compat.v1.reset_default_graph() + k1 = tf.constant(math.pi, name="k1") + k2 = tf.constant(math.e, name="k2") + result = (k1 ** k2) - k1 + expected = tf.constant(20.0, name="expected") + error = tf.abs(result - expected, name="error") + message_prefix_value = "error " * 1000 + true_length = len(message_prefix_value) + assert ( + true_length > self._MESSAGE_PREFIX_LENGTH_LOWER_BOUND + ), true_length + message_prefix = tf.constant( + message_prefix_value, name="message_prefix" + ) 
+ error_message = tf.strings.join( + [message_prefix, tf.as_string(error, name="error_string")], + name="error_message", + ) + summary_message = tf.compat.v1.summary.text( + "summary_message", error_message + ) - def generate_run(self, logdir, run_name, include_graph, include_run_metadata): - """Create a run with a text summary, metadata, and optionally a graph.""" - tf.compat.v1.reset_default_graph() - k1 = tf.constant(math.pi, name='k1') - k2 = tf.constant(math.e, name='k2') - result = (k1 ** k2) - k1 - expected = tf.constant(20.0, name='expected') - error = tf.abs(result - expected, name='error') - message_prefix_value = 'error ' * 1000 - true_length = len(message_prefix_value) - assert true_length > self._MESSAGE_PREFIX_LENGTH_LOWER_BOUND, true_length - message_prefix = tf.constant(message_prefix_value, name='message_prefix') - error_message = tf.strings.join([message_prefix, - tf.as_string(error, name='error_string')], - name='error_message') - summary_message = tf.compat.v1.summary.text('summary_message', error_message) - - sess = tf.compat.v1.Session() - writer = test_util.FileWriter(os.path.join(logdir, run_name)) - if include_graph: - writer.add_graph(sess.graph) - options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE) - run_metadata = config_pb2.RunMetadata() - s = sess.run(summary_message, options=options, run_metadata=run_metadata) - writer.add_summary(s) - if include_run_metadata: - writer.add_run_metadata(run_metadata, self._METADATA_TAG) - writer.close() - - def _get_graph(self, plugin, *args, **kwargs): - """Set up runs, then fetch and return the graph as a proto.""" - (graph_pbtxt, mime_type) = plugin.graph_impl( - _RUN_WITH_GRAPH_WITH_METADATA[0], *args, **kwargs) - self.assertEqual(mime_type, 'text/x-protobuf') - return text_format.Parse(graph_pbtxt, tf.compat.v1.GraphDef()) - - @with_runs([ - _RUN_WITH_GRAPH_WITH_METADATA, - _RUN_WITH_GRAPH_WITHOUT_METADATA, - _RUN_WITHOUT_GRAPH_WITH_METADATA, - 
_RUN_WITHOUT_GRAPH_WITHOUT_METADATA]) - def test_info(self, plugin): - expected = { - '_RUN_WITH_GRAPH_WITH_METADATA': { - 'run': '_RUN_WITH_GRAPH_WITH_METADATA', - 'run_graph': True, - 'tags': { - 'secret-stats': { - 'conceptual_graph': False, - 'profile': True, - 'tag': 'secret-stats', - 'op_graph': False, - }, - }, - }, - '_RUN_WITH_GRAPH_WITHOUT_METADATA': { - 'run': '_RUN_WITH_GRAPH_WITHOUT_METADATA', - 'run_graph': True, - 'tags': {}, - }, - '_RUN_WITHOUT_GRAPH_WITH_METADATA': { - 'run': '_RUN_WITHOUT_GRAPH_WITH_METADATA', - 'run_graph': False, - 'tags': { - 'secret-stats': { - 'conceptual_graph': False, - 'profile': True, - 'tag': 'secret-stats', - 'op_graph': False, - }, - }, - }, - } - - if plugin._data_provider: - # Hack, for now. - # Data providers don't yet pass RunMetadata, so this entry excludes it. - expected['_RUN_WITH_GRAPH_WITH_METADATA']['tags'] = {} - # Data providers don't yet pass RunMetadata, so this entry is completely omitted. - del expected['_RUN_WITHOUT_GRAPH_WITH_METADATA'] - - actual = plugin.info_impl('eid') - self.assertEqual(expected, actual) - - @with_runs([_RUN_WITH_GRAPH_WITH_METADATA]) - def test_graph_simple(self, plugin): - graph = self._get_graph( - plugin, - tag=None, - is_conceptual=False, - experiment='eid', + sess = tf.compat.v1.Session() + writer = test_util.FileWriter(os.path.join(logdir, run_name)) + if include_graph: + writer.add_graph(sess.graph) + options = tf.compat.v1.RunOptions( + trace_level=tf.compat.v1.RunOptions.FULL_TRACE + ) + run_metadata = config_pb2.RunMetadata() + s = sess.run( + summary_message, options=options, run_metadata=run_metadata + ) + writer.add_summary(s) + if include_run_metadata: + writer.add_run_metadata(run_metadata, self._METADATA_TAG) + writer.close() + + def _get_graph(self, plugin, *args, **kwargs): + """Set up runs, then fetch and return the graph as a proto.""" + (graph_pbtxt, mime_type) = plugin.graph_impl( + _RUN_WITH_GRAPH_WITH_METADATA[0], *args, **kwargs + ) + 
self.assertEqual(mime_type, "text/x-protobuf") + return text_format.Parse(graph_pbtxt, tf.compat.v1.GraphDef()) + + @with_runs( + [ + _RUN_WITH_GRAPH_WITH_METADATA, + _RUN_WITH_GRAPH_WITHOUT_METADATA, + _RUN_WITHOUT_GRAPH_WITH_METADATA, + _RUN_WITHOUT_GRAPH_WITHOUT_METADATA, + ] ) - node_names = set(node.name for node in graph.node) - self.assertEqual({ - 'k1', 'k2', 'pow', 'sub', 'expected', 'sub_1', 'error', - 'message_prefix', 'error_string', 'error_message', 'summary_message', - 'summary_message/tag', 'summary_message/serialized_summary_metadata', - }, node_names) - - @with_runs([_RUN_WITH_GRAPH_WITH_METADATA]) - def test_graph_large_attrs(self, plugin): - key = 'o---;;-;' - graph = self._get_graph( - plugin, - tag=None, - is_conceptual=False, - experiment='eid', - limit_attr_size=self._MESSAGE_PREFIX_LENGTH_LOWER_BOUND, - large_attrs_key=key) - large_attrs = { - node.name: list(node.attr[key].list.s) - for node in graph.node - if key in node.attr - } - self.assertEqual({'message_prefix': [b'value']}, - large_attrs) - - @with_runs([_RUN_WITH_GRAPH_WITH_METADATA]) - def test_run_metadata(self, plugin): - result = plugin.run_metadata_impl( - _RUN_WITH_GRAPH_WITH_METADATA[0], self._METADATA_TAG) - if plugin._data_provider: - # Hack, for now - self.assertEqual(result, None) - else: - (metadata_pbtxt, mime_type) = result - self.assertEqual(mime_type, 'text/x-protobuf') - text_format.Parse(metadata_pbtxt, config_pb2.RunMetadata()) - # If it parses, we're happy. 
- - @with_runs([_RUN_WITH_GRAPH_WITHOUT_METADATA]) - def test_is_active_with_graph_without_run_metadata(self, plugin): - self.assertTrue(plugin.is_active()) - - @with_runs([_RUN_WITHOUT_GRAPH_WITH_METADATA]) - def test_is_active_without_graph_with_run_metadata(self, plugin): - self.assertTrue(plugin.is_active()) - - @with_runs([_RUN_WITH_GRAPH_WITH_METADATA]) - def test_is_active_with_both(self, plugin): - self.assertTrue(plugin.is_active()) - - @with_runs([_RUN_WITHOUT_GRAPH_WITHOUT_METADATA]) - def test_is_inactive_without_both(self, plugin): - if plugin._data_provider: - # Hack, for now. - self.assertTrue(plugin.is_active()) - else: - self.assertFalse(plugin.is_active()) - -if __name__ == '__main__': - tf.test.main() + def test_info(self, plugin): + expected = { + "_RUN_WITH_GRAPH_WITH_METADATA": { + "run": "_RUN_WITH_GRAPH_WITH_METADATA", + "run_graph": True, + "tags": { + "secret-stats": { + "conceptual_graph": False, + "profile": True, + "tag": "secret-stats", + "op_graph": False, + }, + }, + }, + "_RUN_WITH_GRAPH_WITHOUT_METADATA": { + "run": "_RUN_WITH_GRAPH_WITHOUT_METADATA", + "run_graph": True, + "tags": {}, + }, + "_RUN_WITHOUT_GRAPH_WITH_METADATA": { + "run": "_RUN_WITHOUT_GRAPH_WITH_METADATA", + "run_graph": False, + "tags": { + "secret-stats": { + "conceptual_graph": False, + "profile": True, + "tag": "secret-stats", + "op_graph": False, + }, + }, + }, + } + + if plugin._data_provider: + # Hack, for now. + # Data providers don't yet pass RunMetadata, so this entry excludes it. + expected["_RUN_WITH_GRAPH_WITH_METADATA"]["tags"] = {} + # Data providers don't yet pass RunMetadata, so this entry is completely omitted. 
+ del expected["_RUN_WITHOUT_GRAPH_WITH_METADATA"] + + actual = plugin.info_impl("eid") + self.assertEqual(expected, actual) + + @with_runs([_RUN_WITH_GRAPH_WITH_METADATA]) + def test_graph_simple(self, plugin): + graph = self._get_graph( + plugin, tag=None, is_conceptual=False, experiment="eid", + ) + node_names = set(node.name for node in graph.node) + self.assertEqual( + { + "k1", + "k2", + "pow", + "sub", + "expected", + "sub_1", + "error", + "message_prefix", + "error_string", + "error_message", + "summary_message", + "summary_message/tag", + "summary_message/serialized_summary_metadata", + }, + node_names, + ) + + @with_runs([_RUN_WITH_GRAPH_WITH_METADATA]) + def test_graph_large_attrs(self, plugin): + key = "o---;;-;" + graph = self._get_graph( + plugin, + tag=None, + is_conceptual=False, + experiment="eid", + limit_attr_size=self._MESSAGE_PREFIX_LENGTH_LOWER_BOUND, + large_attrs_key=key, + ) + large_attrs = { + node.name: list(node.attr[key].list.s) + for node in graph.node + if key in node.attr + } + self.assertEqual({"message_prefix": [b"value"]}, large_attrs) + + @with_runs([_RUN_WITH_GRAPH_WITH_METADATA]) + def test_run_metadata(self, plugin): + result = plugin.run_metadata_impl( + _RUN_WITH_GRAPH_WITH_METADATA[0], self._METADATA_TAG + ) + if plugin._data_provider: + # Hack, for now + self.assertEqual(result, None) + else: + (metadata_pbtxt, mime_type) = result + self.assertEqual(mime_type, "text/x-protobuf") + text_format.Parse(metadata_pbtxt, config_pb2.RunMetadata()) + # If it parses, we're happy. 
+ + @with_runs([_RUN_WITH_GRAPH_WITHOUT_METADATA]) + def test_is_active_with_graph_without_run_metadata(self, plugin): + self.assertTrue(plugin.is_active()) + + @with_runs([_RUN_WITHOUT_GRAPH_WITH_METADATA]) + def test_is_active_without_graph_with_run_metadata(self, plugin): + self.assertTrue(plugin.is_active()) + + @with_runs([_RUN_WITH_GRAPH_WITH_METADATA]) + def test_is_active_with_both(self, plugin): + self.assertTrue(plugin.is_active()) + + @with_runs([_RUN_WITHOUT_GRAPH_WITHOUT_METADATA]) + def test_is_inactive_without_both(self, plugin): + if plugin._data_provider: + # Hack, for now. + self.assertTrue(plugin.is_active()) + else: + self.assertFalse(plugin.is_active()) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/graph/graphs_plugin_v2_test.py b/tensorboard/plugins/graph/graphs_plugin_v2_test.py index ca08955820..782a738c62 100644 --- a/tensorboard/plugins/graph/graphs_plugin_v2_test.py +++ b/tensorboard/plugins/graph/graphs_plugin_v2_test.py @@ -28,76 +28,95 @@ from tensorboard.plugins.graph import graphs_plugin_test -class GraphsPluginV2Test(graphs_plugin_test.GraphsPluginBaseTest, tf.test.TestCase): - - def generate_run(self, logdir, run_name, include_graph, include_run_metadata): - x, y = np.ones((10, 10)), np.ones((10, 1)) - val_x, val_y = np.ones((4, 10)), np.ones((4, 1)) - - model = tf.keras.Sequential([ - tf.keras.layers.Dense(10, activation='relu'), - tf.keras.layers.Dense(1, activation='sigmoid')]) - model.compile('rmsprop', 'binary_crossentropy') - - model.fit( - x, - y, - validation_data=(val_x, val_y), - batch_size=2, - epochs=1, - callbacks=[tf.compat.v2.keras.callbacks.TensorBoard( - log_dir=os.path.join(logdir, run_name), - write_graph=include_graph)]) - - def _get_graph(self, plugin, *args, **kwargs): - """Fetch and return the graph as a proto.""" - (graph_pbtxt, mime_type) = plugin.graph_impl(*args, **kwargs) - self.assertEqual(mime_type, 'text/x-protobuf') - return text_format.Parse(graph_pbtxt, 
graph_pb2.GraphDef()) - - @graphs_plugin_test.with_runs([ - graphs_plugin_test._RUN_WITH_GRAPH_WITH_METADATA, - graphs_plugin_test._RUN_WITHOUT_GRAPH_WITH_METADATA]) - def test_info(self, plugin): - raise self.skipTest('TODO: enable this after tf-nightly writes a conceptual graph.') - - expected = { - 'w_graph_wo_meta': { - 'run': 'w_graph_wo_meta', - 'run_graph': True, - 'tags': { - 'keras': { - 'conceptual_graph': True, - 'profile': False, - 'tag': 'keras', - 'op_graph': False, - }, - }, - }, - } - - self.generate_run('w_graph_wo_meta', - include_graph=True, - include_run_metadata=False) - self.generate_run('wo_graph_wo_meta', - include_graph=False, - include_run_metadata=False) - self.bootstrap_plugin() - - self.assertEqual(expected, plugin.info_impl()) - - def test_graph_conceptual_graph(self): - raise self.skipTest('TODO: enable this after tf-nightly writes a conceptual graph.') - - self.generate_run(self._RUN_WITH_GRAPH, - include_graph=True, - include_run_metadata=False) - self.bootstrap_plugin() - - graph = self._get_graph(self._RUN_WITH_GRAPH, tag='keras', is_conceptual=True) - node_names = set(node.name for node in graph.node) - self.assertEqual({'sequential/dense', 'sequential/dense_1'}, node_names) - - -if __name__ == '__main__': - tf.test.main() +class GraphsPluginV2Test( + graphs_plugin_test.GraphsPluginBaseTest, tf.test.TestCase +): + def generate_run( + self, logdir, run_name, include_graph, include_run_metadata + ): + x, y = np.ones((10, 10)), np.ones((10, 1)) + val_x, val_y = np.ones((4, 10)), np.ones((4, 1)) + + model = tf.keras.Sequential( + [ + tf.keras.layers.Dense(10, activation="relu"), + tf.keras.layers.Dense(1, activation="sigmoid"), + ] + ) + model.compile("rmsprop", "binary_crossentropy") + + model.fit( + x, + y, + validation_data=(val_x, val_y), + batch_size=2, + epochs=1, + callbacks=[ + tf.compat.v2.keras.callbacks.TensorBoard( + log_dir=os.path.join(logdir, run_name), + write_graph=include_graph, + ) + ], + ) + + def _get_graph(self, 
plugin, *args, **kwargs): + """Fetch and return the graph as a proto.""" + (graph_pbtxt, mime_type) = plugin.graph_impl(*args, **kwargs) + self.assertEqual(mime_type, "text/x-protobuf") + return text_format.Parse(graph_pbtxt, graph_pb2.GraphDef()) + + @graphs_plugin_test.with_runs( + [ + graphs_plugin_test._RUN_WITH_GRAPH_WITH_METADATA, + graphs_plugin_test._RUN_WITHOUT_GRAPH_WITH_METADATA, + ] + ) + def test_info(self, plugin): + raise self.skipTest( + "TODO: enable this after tf-nightly writes a conceptual graph." + ) + + expected = { + "w_graph_wo_meta": { + "run": "w_graph_wo_meta", + "run_graph": True, + "tags": { + "keras": { + "conceptual_graph": True, + "profile": False, + "tag": "keras", + "op_graph": False, + }, + }, + }, + } + + self.generate_run( + "w_graph_wo_meta", include_graph=True, include_run_metadata=False + ) + self.generate_run( + "wo_graph_wo_meta", include_graph=False, include_run_metadata=False + ) + self.bootstrap_plugin() + + self.assertEqual(expected, plugin.info_impl()) + + def test_graph_conceptual_graph(self): + raise self.skipTest( + "TODO: enable this after tf-nightly writes a conceptual graph." + ) + + self.generate_run( + self._RUN_WITH_GRAPH, include_graph=True, include_run_metadata=False + ) + self.bootstrap_plugin() + + graph = self._get_graph( + self._RUN_WITH_GRAPH, tag="keras", is_conceptual=True + ) + node_names = set(node.name for node in graph.node) + self.assertEqual({"sequential/dense", "sequential/dense_1"}, node_names) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/graph/keras_util.py b/tensorboard/plugins/graph/keras_util.py index dc720c2831..646569361e 100644 --- a/tensorboard/plugins/graph/keras_util.py +++ b/tensorboard/plugins/graph/keras_util.py @@ -46,7 +46,7 @@ def _walk_layers(keras_layer): - """Walks the nested keras layer configuration in preorder. + """Walks the nested keras layer configuration in preorder. Args: keras_layer: Keras configuration from model.to_json. 
@@ -55,185 +55,207 @@ def _walk_layers(keras_layer): name_scope: a string representing a scope name, similar to that of tf.name_scope. layer_config: a dict representing a Keras layer configuration. """ - yield ('', keras_layer) - if keras_layer.get('config').get('layers'): - name_scope = keras_layer.get('config').get('name') - for layer in keras_layer.get('config').get('layers'): - for (sub_name_scope, sublayer) in _walk_layers(layer): - sub_name_scope = '%s/%s' % ( - name_scope, sub_name_scope) if sub_name_scope else name_scope - yield (sub_name_scope, sublayer) + yield ("", keras_layer) + if keras_layer.get("config").get("layers"): + name_scope = keras_layer.get("config").get("name") + for layer in keras_layer.get("config").get("layers"): + for (sub_name_scope, sublayer) in _walk_layers(layer): + sub_name_scope = ( + "%s/%s" % (name_scope, sub_name_scope) + if sub_name_scope + else name_scope + ) + yield (sub_name_scope, sublayer) def _scoped_name(name_scope, node_name): - """Returns scoped name for a node as a string in the form '/'. + """Returns scoped name for a node as a string in the form '/'. - Args: - name_scope: a string representing a scope name, similar to that of tf.name_scope. - node_name: a string representing the current node name. + Args: + name_scope: a string representing a scope name, similar to that of tf.name_scope. + node_name: a string representing the current node name. - Returns - A string representing a scoped name. - """ - if name_scope: - return '%s/%s' % (name_scope, node_name) - return node_name + Returns + A string representing a scoped name. + """ + if name_scope: + return "%s/%s" % (name_scope, node_name) + return node_name def _is_model(layer): - """Returns True if layer is a model. + """Returns True if layer is a model. - Args: - layer: a dict representing a Keras model configuration. + Args: + layer: a dict representing a Keras model configuration. - Returns: - bool: True if layer is a model. 
- """ - return layer.get('config').get('layers') is not None + Returns: + bool: True if layer is a model. + """ + return layer.get("config").get("layers") is not None def _norm_to_list_of_layers(maybe_layers): - """Normalizes to a list of layers. - - Args: - maybe_layers: A list of data[1] or a list of list of data. - - Returns: - List of list of data. - - [1]: A Functional model has fields 'inbound_nodes' and 'output_layers' which can - look like below: - - ['in_layer_name', 0, 0] - - [['in_layer_is_model', 1, 0], ['in_layer_is_model', 1, 1]] - The data inside the list seems to describe [name, size, index]. - """ - return (maybe_layers if isinstance(maybe_layers[0], (list,)) - else [maybe_layers]) - - -def _update_dicts(name_scope, - model_layer, - input_to_in_layer, - model_name_to_output, - prev_node_name): - """Updates input_to_in_layer, model_name_to_output, and prev_node_name - based on the model_layer. - - Args: - name_scope: a string representing a scope name, similar to that of tf.name_scope. - model_layer: a dict representing a Keras model configuration. - input_to_in_layer: a dict mapping Keras.layers.Input to inbound layer. - model_name_to_output: a dict mapping Keras Model name to output layer of the model. - prev_node_name: a string representing a previous, in sequential model layout, - node name. - - Returns: - A tuple of (input_to_in_layer, model_name_to_output, prev_node_name). - input_to_in_layer: a dict mapping Keras.layers.Input to inbound layer. - model_name_to_output: a dict mapping Keras Model name to output layer of the model. - prev_node_name: a string representing a previous, in sequential model layout, - node name. 
- """ - layer_config = model_layer.get('config') - if not layer_config.get('layers'): - raise ValueError('layer is not a model.') - - node_name = _scoped_name(name_scope, layer_config.get('name')) - input_layers = layer_config.get('input_layers') - output_layers = layer_config.get('output_layers') - inbound_nodes = model_layer.get('inbound_nodes') - - is_functional_model = bool(input_layers and output_layers) - # In case of [1] and the parent model is functional, current layer - # will have the 'inbound_nodes' property. - is_parent_functional_model = bool(inbound_nodes) - - if is_parent_functional_model and is_functional_model: - for (input_layer, inbound_node) in zip(input_layers, inbound_nodes): - input_layer_name = _scoped_name(node_name, input_layer) - inbound_node_name = _scoped_name(name_scope, inbound_node[0]) - input_to_in_layer[input_layer_name] = inbound_node_name - elif is_parent_functional_model and not is_functional_model: - # Sequential model can take only one input. Make sure inbound to the - # model is linked to the first layer in the Sequential model. - prev_node_name = _scoped_name(name_scope, inbound_nodes[0][0][0]) - elif not is_parent_functional_model and prev_node_name and is_functional_model: - assert len(input_layers) == 1, ( - 'Cannot have multi-input Functional model when parent model ' - 'is not Functional. 
Number of input layers: %d' % len(input_layer)) - input_layer = input_layers[0] - input_layer_name = _scoped_name(node_name, input_layer) - input_to_in_layer[input_layer_name] = prev_node_name - - if is_functional_model and output_layers: - layers = _norm_to_list_of_layers(output_layers) - layer_names = [_scoped_name(node_name, layer[0]) for layer in layers] - model_name_to_output[node_name] = layer_names - else: - last_layer = layer_config.get('layers')[-1] - last_layer_name = last_layer.get('config').get('name') - output_node = _scoped_name(node_name, last_layer_name) - model_name_to_output[node_name] = [output_node] - return (input_to_in_layer, model_name_to_output, prev_node_name) + """Normalizes to a list of layers. + + Args: + maybe_layers: A list of data[1] or a list of list of data. + + Returns: + List of list of data. + + [1]: A Functional model has fields 'inbound_nodes' and 'output_layers' which can + look like below: + - ['in_layer_name', 0, 0] + - [['in_layer_is_model', 1, 0], ['in_layer_is_model', 1, 1]] + The data inside the list seems to describe [name, size, index]. + """ + return ( + maybe_layers if isinstance(maybe_layers[0], (list,)) else [maybe_layers] + ) + + +def _update_dicts( + name_scope, + model_layer, + input_to_in_layer, + model_name_to_output, + prev_node_name, +): + """Updates input_to_in_layer, model_name_to_output, and prev_node_name + based on the model_layer. + + Args: + name_scope: a string representing a scope name, similar to that of tf.name_scope. + model_layer: a dict representing a Keras model configuration. + input_to_in_layer: a dict mapping Keras.layers.Input to inbound layer. + model_name_to_output: a dict mapping Keras Model name to output layer of the model. + prev_node_name: a string representing a previous, in sequential model layout, + node name. + + Returns: + A tuple of (input_to_in_layer, model_name_to_output, prev_node_name). + input_to_in_layer: a dict mapping Keras.layers.Input to inbound layer. 
+ model_name_to_output: a dict mapping Keras Model name to output layer of the model. + prev_node_name: a string representing a previous, in sequential model layout, + node name. + """ + layer_config = model_layer.get("config") + if not layer_config.get("layers"): + raise ValueError("layer is not a model.") + + node_name = _scoped_name(name_scope, layer_config.get("name")) + input_layers = layer_config.get("input_layers") + output_layers = layer_config.get("output_layers") + inbound_nodes = model_layer.get("inbound_nodes") + + is_functional_model = bool(input_layers and output_layers) + # In case of [1] and the parent model is functional, current layer + # will have the 'inbound_nodes' property. + is_parent_functional_model = bool(inbound_nodes) + + if is_parent_functional_model and is_functional_model: + for (input_layer, inbound_node) in zip(input_layers, inbound_nodes): + input_layer_name = _scoped_name(node_name, input_layer) + inbound_node_name = _scoped_name(name_scope, inbound_node[0]) + input_to_in_layer[input_layer_name] = inbound_node_name + elif is_parent_functional_model and not is_functional_model: + # Sequential model can take only one input. Make sure inbound to the + # model is linked to the first layer in the Sequential model. + prev_node_name = _scoped_name(name_scope, inbound_nodes[0][0][0]) + elif ( + not is_parent_functional_model + and prev_node_name + and is_functional_model + ): + assert len(input_layers) == 1, ( + "Cannot have multi-input Functional model when parent model " + "is not Functional. 
Number of input layers: %d" % len(input_layer) + ) + input_layer = input_layers[0] + input_layer_name = _scoped_name(node_name, input_layer) + input_to_in_layer[input_layer_name] = prev_node_name + + if is_functional_model and output_layers: + layers = _norm_to_list_of_layers(output_layers) + layer_names = [_scoped_name(node_name, layer[0]) for layer in layers] + model_name_to_output[node_name] = layer_names + else: + last_layer = layer_config.get("layers")[-1] + last_layer_name = last_layer.get("config").get("name") + output_node = _scoped_name(node_name, last_layer_name) + model_name_to_output[node_name] = [output_node] + return (input_to_in_layer, model_name_to_output, prev_node_name) def keras_model_to_graph_def(keras_layer): - """Returns a GraphDef representation of the Keras model in a dict form. - - Note that it only supports models that implemented to_json(). - - Args: - keras_layer: A dict from Keras model.to_json(). - - Returns: - A GraphDef representation of the layers in the model. - """ - input_to_layer = {} - model_name_to_output = {} - g = GraphDef() - - # Sequential model layers do not have a field "inbound_nodes" but - # instead are defined implicitly via order of layers. 
- prev_node_name = None - - for (name_scope, layer) in _walk_layers(keras_layer): - if _is_model(layer): - (input_to_layer, model_name_to_output, prev_node_name) = _update_dicts( - name_scope, layer, input_to_layer, model_name_to_output, prev_node_name) - continue - - layer_config = layer.get('config') - node_name = _scoped_name(name_scope, layer_config.get('name')) - - node_def = g.node.add() - node_def.name = node_name - - if layer.get('class_name') is not None: - keras_cls_name = layer.get('class_name').encode('ascii') - node_def.attr['keras_class'].s = keras_cls_name - - if layer_config.get('dtype') is not None: - tf_dtype = dtypes.as_dtype(layer_config.get('dtype')) - node_def.attr['dtype'].type = tf_dtype.as_datatype_enum - - if layer.get('inbound_nodes') is not None: - for maybe_inbound_node in layer.get('inbound_nodes'): - inbound_nodes = _norm_to_list_of_layers(maybe_inbound_node) - for [name, size, index, _] in inbound_nodes: - inbound_name = _scoped_name(name_scope, name) - # An input to a layer can be output from a model. In that case, the name - # of inbound_nodes to a layer is a name of a model. Remap the name of the - # model to output layer of the model. Also, since there can be multiple - # outputs in a model, make sure we pick the right output_layer from the model. - inbound_node_names = model_name_to_output.get( - inbound_name, [inbound_name]) - node_def.input.append(inbound_node_names[index]) - elif prev_node_name is not None: - node_def.input.append(prev_node_name) - - if node_name in input_to_layer: - node_def.input.append(input_to_layer.get(node_name)) - - prev_node_name = node_def.name - - return g + """Returns a GraphDef representation of the Keras model in a dict form. + + Note that it only supports models that implemented to_json(). + + Args: + keras_layer: A dict from Keras model.to_json(). + + Returns: + A GraphDef representation of the layers in the model. 
+ """ + input_to_layer = {} + model_name_to_output = {} + g = GraphDef() + + # Sequential model layers do not have a field "inbound_nodes" but + # instead are defined implicitly via order of layers. + prev_node_name = None + + for (name_scope, layer) in _walk_layers(keras_layer): + if _is_model(layer): + ( + input_to_layer, + model_name_to_output, + prev_node_name, + ) = _update_dicts( + name_scope, + layer, + input_to_layer, + model_name_to_output, + prev_node_name, + ) + continue + + layer_config = layer.get("config") + node_name = _scoped_name(name_scope, layer_config.get("name")) + + node_def = g.node.add() + node_def.name = node_name + + if layer.get("class_name") is not None: + keras_cls_name = layer.get("class_name").encode("ascii") + node_def.attr["keras_class"].s = keras_cls_name + + if layer_config.get("dtype") is not None: + tf_dtype = dtypes.as_dtype(layer_config.get("dtype")) + node_def.attr["dtype"].type = tf_dtype.as_datatype_enum + + if layer.get("inbound_nodes") is not None: + for maybe_inbound_node in layer.get("inbound_nodes"): + inbound_nodes = _norm_to_list_of_layers(maybe_inbound_node) + for [name, size, index, _] in inbound_nodes: + inbound_name = _scoped_name(name_scope, name) + # An input to a layer can be output from a model. In that case, the name + # of inbound_nodes to a layer is a name of a model. Remap the name of the + # model to output layer of the model. Also, since there can be multiple + # outputs in a model, make sure we pick the right output_layer from the model. 
+ inbound_node_names = model_name_to_output.get( + inbound_name, [inbound_name] + ) + node_def.input.append(inbound_node_names[index]) + elif prev_node_name is not None: + node_def.input.append(prev_node_name) + + if node_name in input_to_layer: + node_def.input.append(input_to_layer.get(node_name)) + + prev_node_name = node_def.name + + return g diff --git a/tensorboard/plugins/graph/keras_util_test.py b/tensorboard/plugins/graph/keras_util_test.py index f0565f8f7a..669fafed81 100644 --- a/tensorboard/plugins/graph/keras_util_test.py +++ b/tensorboard/plugins/graph/keras_util_test.py @@ -28,15 +28,15 @@ class KerasUtilTest(tf.test.TestCase): + def assertGraphDefToModel(self, expected_proto, model): + model_config = json.loads(model.to_json()) - def assertGraphDefToModel(self, expected_proto, model): - model_config = json.loads(model.to_json()) + self.assertProtoEquals( + expected_proto, keras_util.keras_model_to_graph_def(model_config) + ) - self.assertProtoEquals( - expected_proto, keras_util.keras_model_to_graph_def(model_config)) - - def test_keras_model_to_graph_def_sequential_model(self): - expected_proto = """ + def test_keras_model_to_graph_def_sequential_model(self): + expected_proto = """ node { name: "sequential/dense" attr { @@ -101,16 +101,18 @@ def test_keras_model_to_graph_def_sequential_model(self): } } """ - model = tf.keras.models.Sequential([ - tf.keras.layers.Dense(32, input_shape=(784,)), - tf.keras.layers.Activation('relu', name='my_relu'), - tf.keras.layers.Dense(10), - tf.keras.layers.Activation('softmax'), - ]) - self.assertGraphDefToModel(expected_proto, model) - - def test_keras_model_to_graph_def_functional_model(self): - expected_proto = """ + model = tf.keras.models.Sequential( + [ + tf.keras.layers.Dense(32, input_shape=(784,)), + tf.keras.layers.Activation("relu", name="my_relu"), + tf.keras.layers.Dense(10), + tf.keras.layers.Activation("softmax"), + ] + ) + self.assertGraphDefToModel(expected_proto, model) + + def 
test_keras_model_to_graph_def_functional_model(self): + expected_proto = """ node { name: "model/functional_input" attr { @@ -175,16 +177,16 @@ def test_keras_model_to_graph_def_functional_model(self): } } """ - inputs = tf.keras.layers.Input(shape=(784,), name='functional_input') - d0 = tf.keras.layers.Dense(64, activation='relu') - d1 = tf.keras.layers.Dense(64, activation='relu') - d2 = tf.keras.layers.Dense(64, activation='relu') + inputs = tf.keras.layers.Input(shape=(784,), name="functional_input") + d0 = tf.keras.layers.Dense(64, activation="relu") + d1 = tf.keras.layers.Dense(64, activation="relu") + d2 = tf.keras.layers.Dense(64, activation="relu") - model = tf.keras.models.Model(inputs=inputs, outputs=d2(d1(d0(inputs)))) - self.assertGraphDefToModel(expected_proto, model) + model = tf.keras.models.Model(inputs=inputs, outputs=d2(d1(d0(inputs)))) + self.assertGraphDefToModel(expected_proto, model) - def test_keras_model_to_graph_def_functional_model_with_cycle(self): - expected_proto = """ + def test_keras_model_to_graph_def_functional_model_with_cycle(self): + expected_proto = """ node { name: "model/cycle_input" attr { @@ -250,16 +252,18 @@ def test_keras_model_to_graph_def_functional_model_with_cycle(self): } } """ - inputs = tf.keras.layers.Input(shape=(784,), name='cycle_input') - d0 = tf.keras.layers.Dense(64, activation='relu') - d1 = tf.keras.layers.Dense(64, activation='relu') - d2 = tf.keras.layers.Dense(64, activation='relu') + inputs = tf.keras.layers.Input(shape=(784,), name="cycle_input") + d0 = tf.keras.layers.Dense(64, activation="relu") + d1 = tf.keras.layers.Dense(64, activation="relu") + d2 = tf.keras.layers.Dense(64, activation="relu") - model = tf.keras.models.Model(inputs=inputs, outputs=d1(d2(d1(d0(inputs))))) - self.assertGraphDefToModel(expected_proto, model) + model = tf.keras.models.Model( + inputs=inputs, outputs=d1(d2(d1(d0(inputs)))) + ) + self.assertGraphDefToModel(expected_proto, model) - def 
test_keras_model_to_graph_def_lstm_model(self): - expected_proto = """ + def test_keras_model_to_graph_def_lstm_model(self): + expected_proto = """ node { name: "model/lstm_input" attr { @@ -292,14 +296,14 @@ def test_keras_model_to_graph_def_lstm_model(self): } } """ - inputs = tf.keras.layers.Input(shape=(None, 5), name='lstm_input') - encoder = tf.keras.layers.SimpleRNN(256) + inputs = tf.keras.layers.Input(shape=(None, 5), name="lstm_input") + encoder = tf.keras.layers.SimpleRNN(256) - model = tf.keras.models.Model(inputs=inputs, outputs=encoder(inputs)) - self.assertGraphDefToModel(expected_proto, model) + model = tf.keras.models.Model(inputs=inputs, outputs=encoder(inputs)) + self.assertGraphDefToModel(expected_proto, model) - def test_keras_model_to_graph_def_nested_sequential_model(self): - expected_proto = """ + def test_keras_model_to_graph_def_nested_sequential_model(self): + expected_proto = """ node { name: "sequential_2/sequential_1/sequential/dense" attr { @@ -380,26 +384,29 @@ def test_keras_model_to_graph_def_nested_sequential_model(self): } } """ - sub_sub_model = tf.keras.models.Sequential([ - tf.keras.layers.Dense(32, input_shape=(784,)), - tf.keras.layers.Activation('relu'), - ]) - - sub_model = tf.keras.models.Sequential([ - sub_sub_model, - tf.keras.layers.Activation('relu', name='my_relu'), - ]) - - model = tf.keras.models.Sequential([ - sub_model, - tf.keras.layers.Dense(10), - tf.keras.layers.Activation('softmax'), - ]) - - self.assertGraphDefToModel(expected_proto, model) - - def test_keras_model_to_graph_def_functional_multi_inputs(self): - expected_proto = """ + sub_sub_model = tf.keras.models.Sequential( + [ + tf.keras.layers.Dense(32, input_shape=(784,)), + tf.keras.layers.Activation("relu"), + ] + ) + + sub_model = tf.keras.models.Sequential( + [sub_sub_model, tf.keras.layers.Activation("relu", name="my_relu"),] + ) + + model = tf.keras.models.Sequential( + [ + sub_model, + tf.keras.layers.Dense(10), + 
tf.keras.layers.Activation("softmax"), + ] + ) + + self.assertGraphDefToModel(expected_proto, model) + + def test_keras_model_to_graph_def_functional_multi_inputs(self): + expected_proto = """ node { name: "model/main_input" attr { @@ -528,28 +535,35 @@ def test_keras_model_to_graph_def_functional_multi_inputs(self): } } """ - main_input = tf.keras.layers.Input(shape=(100,), dtype='int32', name='main_input') - x = tf.keras.layers.Embedding( - output_dim=512, input_dim=10000, input_length=100)(main_input) - rnn_out = tf.keras.layers.SimpleRNN(32)(x) + main_input = tf.keras.layers.Input( + shape=(100,), dtype="int32", name="main_input" + ) + x = tf.keras.layers.Embedding( + output_dim=512, input_dim=10000, input_length=100 + )(main_input) + rnn_out = tf.keras.layers.SimpleRNN(32)(x) - auxiliary_output = tf.keras.layers.Dense( - 1, activation='sigmoid', name='aux_output')(rnn_out) - auxiliary_input = tf.keras.layers.Input(shape=(5,), name='aux_input') + auxiliary_output = tf.keras.layers.Dense( + 1, activation="sigmoid", name="aux_output" + )(rnn_out) + auxiliary_input = tf.keras.layers.Input(shape=(5,), name="aux_input") - x = tf.keras.layers.concatenate([rnn_out, auxiliary_input]) - x = tf.keras.layers.Dense(64, activation='relu')(x) + x = tf.keras.layers.concatenate([rnn_out, auxiliary_input]) + x = tf.keras.layers.Dense(64, activation="relu")(x) - main_output = tf.keras.layers.Dense(1, activation='sigmoid', name='main_output')(x) + main_output = tf.keras.layers.Dense( + 1, activation="sigmoid", name="main_output" + )(x) - model = tf.keras.models.Model( - inputs=[main_input, auxiliary_input], - outputs=[main_output, auxiliary_output]) + model = tf.keras.models.Model( + inputs=[main_input, auxiliary_input], + outputs=[main_output, auxiliary_output], + ) - self.assertGraphDefToModel(expected_proto, model) + self.assertGraphDefToModel(expected_proto, model) - def test_keras_model_to_graph_def_functional_model_as_layer(self): - expected_proto = """ + def 
test_keras_model_to_graph_def_functional_model_as_layer(self): + expected_proto = """ node { name: "model_1/sub_func_input_2" attr { @@ -676,23 +690,27 @@ def test_keras_model_to_graph_def_functional_model_as_layer(self): } } """ - inputs1 = tf.keras.layers.Input(shape=(784,), name='sub_func_input_1') - inputs2 = tf.keras.layers.Input(shape=(784,), name='sub_func_input_2') - d0 = tf.keras.layers.Dense(64, activation='relu') - d1 = tf.keras.layers.Dense(64, activation='relu') - d2 = tf.keras.layers.Dense(64, activation='relu') + inputs1 = tf.keras.layers.Input(shape=(784,), name="sub_func_input_1") + inputs2 = tf.keras.layers.Input(shape=(784,), name="sub_func_input_2") + d0 = tf.keras.layers.Dense(64, activation="relu") + d1 = tf.keras.layers.Dense(64, activation="relu") + d2 = tf.keras.layers.Dense(64, activation="relu") - sub_model = tf.keras.models.Model(inputs=[inputs2, inputs1], - outputs=[d0(inputs1), d1(inputs2)]) + sub_model = tf.keras.models.Model( + inputs=[inputs2, inputs1], outputs=[d0(inputs1), d1(inputs2)] + ) - main_outputs = d2(tf.keras.layers.concatenate(sub_model([inputs2, inputs1]))) - model = tf.keras.models.Model( - inputs=[inputs2, inputs1], outputs=main_outputs) + main_outputs = d2( + tf.keras.layers.concatenate(sub_model([inputs2, inputs1])) + ) + model = tf.keras.models.Model( + inputs=[inputs2, inputs1], outputs=main_outputs + ) - self.assertGraphDefToModel(expected_proto, model) + self.assertGraphDefToModel(expected_proto, model) - def test_keras_model_to_graph_def_functional_sequential_model(self): - expected_proto = """ + def test_keras_model_to_graph_def_functional_sequential_model(self): + expected_proto = """ node { name: "model/func_seq_input" attr { @@ -757,19 +775,23 @@ def test_keras_model_to_graph_def_functional_sequential_model(self): } } """ - inputs = tf.keras.layers.Input(shape=(784,), name='func_seq_input') - sub_model = tf.keras.models.Sequential([ - tf.keras.layers.Dense(32, input_shape=(784,)), - 
tf.keras.layers.Activation('relu', name='my_relu'), - ]) - dense = tf.keras.layers.Dense(64, activation='relu') + inputs = tf.keras.layers.Input(shape=(784,), name="func_seq_input") + sub_model = tf.keras.models.Sequential( + [ + tf.keras.layers.Dense(32, input_shape=(784,)), + tf.keras.layers.Activation("relu", name="my_relu"), + ] + ) + dense = tf.keras.layers.Dense(64, activation="relu") - model = tf.keras.models.Model(inputs=inputs, outputs=dense(sub_model(inputs))) + model = tf.keras.models.Model( + inputs=inputs, outputs=dense(sub_model(inputs)) + ) - self.assertGraphDefToModel(expected_proto, model) + self.assertGraphDefToModel(expected_proto, model) - def test_keras_model_to_graph_def_sequential_functional_model(self): - expected_proto = """ + def test_keras_model_to_graph_def_sequential_functional_model(self): + expected_proto = """ node { name: "sequential/model/func_seq_input" attr { @@ -834,18 +856,20 @@ def test_keras_model_to_graph_def_sequential_functional_model(self): } } """ - inputs = tf.keras.layers.Input(shape=(784,), name='func_seq_input') - dense = tf.keras.layers.Dense(64, activation='relu') + inputs = tf.keras.layers.Input(shape=(784,), name="func_seq_input") + dense = tf.keras.layers.Dense(64, activation="relu") - sub_model = tf.keras.models.Model(inputs=inputs, outputs=dense(inputs)) - model = tf.keras.models.Sequential([ - sub_model, - tf.keras.layers.Dense(32, input_shape=(784,)), - tf.keras.layers.Activation('relu', name='my_relu'), - ]) + sub_model = tf.keras.models.Model(inputs=inputs, outputs=dense(inputs)) + model = tf.keras.models.Sequential( + [ + sub_model, + tf.keras.layers.Dense(32, input_shape=(784,)), + tf.keras.layers.Activation("relu", name="my_relu"), + ] + ) - self.assertGraphDefToModel(expected_proto, model) + self.assertGraphDefToModel(expected_proto, model) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/graph/metadata.py 
b/tensorboard/plugins/graph/metadata.py index e0ef486502..1172ab41d5 100644 --- a/tensorboard/plugins/graph/metadata.py +++ b/tensorboard/plugins/graph/metadata.py @@ -23,4 +23,4 @@ # Note however that different 'plugin names' are used in the context of # graph Summaries. # See `graphs_plugin.py` for details. -PLUGIN_NAME = 'graphs' +PLUGIN_NAME = "graphs" diff --git a/tensorboard/plugins/histogram/histograms_demo.py b/tensorboard/plugins/histogram/histograms_demo.py index f4f29a3301..7b9ab4997b 100644 --- a/tensorboard/plugins/histogram/histograms_demo.py +++ b/tensorboard/plugins/histogram/histograms_demo.py @@ -26,88 +26,110 @@ from tensorboard.plugins.histogram import summary as histogram_summary # Directory into which to write tensorboard data. -LOGDIR = '/tmp/histograms_demo' +LOGDIR = "/tmp/histograms_demo" def run_all(logdir, verbose=False, num_summaries=400): - """Generate a bunch of histogram data, and write it to logdir.""" - del verbose - - tf.compat.v1.set_random_seed(0) - - k = tf.compat.v1.placeholder(tf.float32) - - # Make a normal distribution, with a shifting mean - mean_moving_normal = tf.random.normal(shape=[1000], mean=(5*k), stddev=1) - # Record that distribution into a histogram summary - histogram_summary.op("normal/moving_mean", - mean_moving_normal, - description="A normal distribution whose mean changes " - "over time.") - - # Make a normal distribution with shrinking variance - shrinking_normal = tf.random.normal(shape=[1000], mean=0, stddev=1-(k)) - # Record that distribution too - histogram_summary.op("normal/shrinking_variance", shrinking_normal, - description="A normal distribution whose variance " - "shrinks over time.") - - # Let's combine both of those distributions into one dataset - normal_combined = tf.concat([mean_moving_normal, shrinking_normal], 0) - # We add another histogram summary to record the combined distribution - histogram_summary.op("normal/bimodal", normal_combined, - description="A combination of two normal 
distributions, " - "one with a moving mean and one with " - "shrinking variance. The result is a " - "distribution that starts as unimodal and " - "becomes more and more bimodal over time.") - - # Add a gamma distribution - gamma = tf.random.gamma(shape=[1000], alpha=k) - histogram_summary.op("gamma", gamma, - description="A gamma distribution whose shape " - "parameter, α, changes over time.") - - # And a poisson distribution - poisson = tf.compat.v1.random_poisson(shape=[1000], lam=k) - histogram_summary.op("poisson", poisson, - description="A Poisson distribution, which only " - "takes on integer values.") - - # And a uniform distribution - uniform = tf.random.uniform(shape=[1000], maxval=k*10) - histogram_summary.op("uniform", uniform, - description="A simple uniform distribution.") - - # Finally, combine everything together! - all_distributions = [mean_moving_normal, shrinking_normal, - gamma, poisson, uniform] - all_combined = tf.concat(all_distributions, 0) - histogram_summary.op("all_combined", all_combined, - description="An amalgamation of five distributions: a " - "uniform distribution, a gamma " - "distribution, a Poisson distribution, and " - "two normal distributions.") - - summaries = tf.compat.v1.summary.merge_all() - - # Setup a session and summary writer - sess = tf.compat.v1.Session() - writer = tf.summary.FileWriter(logdir) - - # Setup a loop and write the summaries to disk - N = num_summaries - for step in xrange(N): - k_val = step/float(N) - summ = sess.run(summaries, feed_dict={k: k_val}) - writer.add_summary(summ, global_step=step) + """Generate a bunch of histogram data, and write it to logdir.""" + del verbose + + tf.compat.v1.set_random_seed(0) + + k = tf.compat.v1.placeholder(tf.float32) + + # Make a normal distribution, with a shifting mean + mean_moving_normal = tf.random.normal(shape=[1000], mean=(5 * k), stddev=1) + # Record that distribution into a histogram summary + histogram_summary.op( + "normal/moving_mean", + 
mean_moving_normal, + description="A normal distribution whose mean changes " "over time.", + ) + + # Make a normal distribution with shrinking variance + shrinking_normal = tf.random.normal(shape=[1000], mean=0, stddev=1 - (k)) + # Record that distribution too + histogram_summary.op( + "normal/shrinking_variance", + shrinking_normal, + description="A normal distribution whose variance " + "shrinks over time.", + ) + + # Let's combine both of those distributions into one dataset + normal_combined = tf.concat([mean_moving_normal, shrinking_normal], 0) + # We add another histogram summary to record the combined distribution + histogram_summary.op( + "normal/bimodal", + normal_combined, + description="A combination of two normal distributions, " + "one with a moving mean and one with " + "shrinking variance. The result is a " + "distribution that starts as unimodal and " + "becomes more and more bimodal over time.", + ) + + # Add a gamma distribution + gamma = tf.random.gamma(shape=[1000], alpha=k) + histogram_summary.op( + "gamma", + gamma, + description="A gamma distribution whose shape " + "parameter, α, changes over time.", + ) + + # And a poisson distribution + poisson = tf.compat.v1.random_poisson(shape=[1000], lam=k) + histogram_summary.op( + "poisson", + poisson, + description="A Poisson distribution, which only " + "takes on integer values.", + ) + + # And a uniform distribution + uniform = tf.random.uniform(shape=[1000], maxval=k * 10) + histogram_summary.op( + "uniform", uniform, description="A simple uniform distribution." + ) + + # Finally, combine everything together! 
+ all_distributions = [ + mean_moving_normal, + shrinking_normal, + gamma, + poisson, + uniform, + ] + all_combined = tf.concat(all_distributions, 0) + histogram_summary.op( + "all_combined", + all_combined, + description="An amalgamation of five distributions: a " + "uniform distribution, a gamma " + "distribution, a Poisson distribution, and " + "two normal distributions.", + ) + + summaries = tf.compat.v1.summary.merge_all() + + # Setup a session and summary writer + sess = tf.compat.v1.Session() + writer = tf.summary.FileWriter(logdir) + + # Setup a loop and write the summaries to disk + N = num_summaries + for step in xrange(N): + k_val = step / float(N) + summ = sess.run(summaries, feed_dict={k: k_val}) + writer.add_summary(summ, global_step=step) def main(unused_argv): - print('Running histograms demo. Output saving to %s.' % LOGDIR) - run_all(LOGDIR) - print('Done. Output saved to %s.' % LOGDIR) + print("Running histograms demo. Output saving to %s." % LOGDIR) + run_all(LOGDIR) + print("Done. Output saved to %s." % LOGDIR) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/tensorboard/plugins/histogram/histograms_plugin.py b/tensorboard/plugins/histogram/histograms_plugin.py index eff11fddb1..201521a30e 100644 --- a/tensorboard/plugins/histogram/histograms_plugin.py +++ b/tensorboard/plugins/histogram/histograms_plugin.py @@ -40,84 +40,91 @@ class HistogramsPlugin(base_plugin.TBPlugin): - """Histograms Plugin for TensorBoard. + """Histograms Plugin for TensorBoard. - This supports both old-style summaries (created with TensorFlow ops - that output directly to the `histo` field of the proto) and new-style - summaries (as created by the `tensorboard.plugins.histogram.summary` - module). - """ - - plugin_name = metadata.PLUGIN_NAME - - # Use a round number + 1 since sampling includes both start and end steps, - # so N+1 samples corresponds to dividing the step sequence into N intervals. 
- SAMPLE_SIZE = 51 + This supports both old-style summaries (created with TensorFlow ops + that output directly to the `histo` field of the proto) and new- + style summaries (as created by the + `tensorboard.plugins.histogram.summary` module). + """ - def __init__(self, context): - """Instantiates HistogramsPlugin via TensorBoard core. + plugin_name = metadata.PLUGIN_NAME + + # Use a round number + 1 since sampling includes both start and end steps, + # so N+1 samples corresponds to dividing the step sequence into N intervals. + SAMPLE_SIZE = 51 + + def __init__(self, context): + """Instantiates HistogramsPlugin via TensorBoard core. + + Args: + context: A base_plugin.TBContext instance. + """ + self._multiplexer = context.multiplexer + self._db_connection_provider = context.db_connection_provider + if context.flags and context.flags.generic_data == "true": + self._data_provider = context.data_provider + else: + self._data_provider = None + + def get_plugin_apps(self): + return { + "/histograms": self.histograms_route, + "/tags": self.tags_route, + } - Args: - context: A base_plugin.TBContext instance. - """ - self._multiplexer = context.multiplexer - self._db_connection_provider = context.db_connection_provider - if context.flags and context.flags.generic_data == 'true': - self._data_provider = context.data_provider - else: - self._data_provider = None - - def get_plugin_apps(self): - return { - '/histograms': self.histograms_route, - '/tags': self.tags_route, - } - - def is_active(self): - """This plugin is active iff any run has at least one histograms tag.""" - if self._data_provider: - # We don't have an experiment ID, and modifying the backend core - # to provide one would break backward compatibility. Hack for now. - return True - - if self._db_connection_provider: - # The plugin is active if one relevant tag can be found in the database. 
- db = self._db_connection_provider() - cursor = db.execute(''' + def is_active(self): + """This plugin is active iff any run has at least one histograms + tag.""" + if self._data_provider: + # We don't have an experiment ID, and modifying the backend core + # to provide one would break backward compatibility. Hack for now. + return True + + if self._db_connection_provider: + # The plugin is active if one relevant tag can be found in the database. + db = self._db_connection_provider() + cursor = db.execute( + """ SELECT 1 FROM Tags WHERE Tags.plugin_name = ? LIMIT 1 - ''', (metadata.PLUGIN_NAME,)) - return bool(list(cursor)) - - if self._multiplexer: - return any(self.index_impl(experiment='').values()) - - return False - - def index_impl(self, experiment): - """Return {runName: {tagName: {displayName: ..., description: ...}}}.""" - if self._data_provider: - mapping = self._data_provider.list_tensors( - experiment_id=experiment, - plugin_name=metadata.PLUGIN_NAME, - ) - result = {run: {} for run in mapping} - for (run, tag_to_content) in six.iteritems(mapping): - for (tag, metadatum) in six.iteritems(tag_to_content): - description = plugin_util.markdown_to_safe_html(metadatum.description) - result[run][tag] = { - 'displayName': metadatum.display_name, - 'description': description, - } - return result - - if self._db_connection_provider: - # Read tags from the database. 
- db = self._db_connection_provider() - cursor = db.execute(''' + """, + (metadata.PLUGIN_NAME,), + ) + return bool(list(cursor)) + + if self._multiplexer: + return any(self.index_impl(experiment="").values()) + + return False + + def index_impl(self, experiment): + """Return {runName: {tagName: {displayName: ..., description: + ...}}}.""" + if self._data_provider: + mapping = self._data_provider.list_tensors( + experiment_id=experiment, plugin_name=metadata.PLUGIN_NAME, + ) + result = {run: {} for run in mapping} + for (run, tag_to_content) in six.iteritems(mapping): + for (tag, metadatum) in six.iteritems(tag_to_content): + description = plugin_util.markdown_to_safe_html( + metadatum.description + ) + result[run][tag] = { + "displayName": metadatum.display_name, + "description": description, + } + return result + + if self._db_connection_provider: + # Read tags from the database. + db = self._db_connection_provider() + cursor = db.execute( + """ SELECT Tags.tag_name, Tags.display_name, @@ -127,77 +134,85 @@ def index_impl(self, experiment): ON Tags.run_id = Runs.run_id WHERE Tags.plugin_name = ? - ''', (metadata.PLUGIN_NAME,)) - result = collections.defaultdict(dict) - for row in cursor: - tag_name, display_name, run_name = row - result[run_name][tag_name] = { - 'displayName': display_name, - # TODO(chihuahua): Populate the description. Currently, the tags - # table does not link with the description table. 
- 'description': '', - } - return result - - runs = self._multiplexer.Runs() - result = collections.defaultdict(lambda: {}) - - mapping = self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) - for (run, tag_to_content) in six.iteritems(mapping): - for (tag, content) in six.iteritems(tag_to_content): - content = metadata.parse_plugin_metadata(content) - summary_metadata = self._multiplexer.SummaryMetadata(run, tag) - result[run][tag] = {'displayName': summary_metadata.display_name, - 'description': plugin_util.markdown_to_safe_html( - summary_metadata.summary_description)} - - return result - - def frontend_metadata(self): - return base_plugin.FrontendMetadata(element_name='tf-histogram-dashboard') - - def histograms_impl(self, tag, run, experiment, downsample_to=None): - """Result of the form `(body, mime_type)`. - - At most `downsample_to` events will be returned. If this value is - `None`, then no downsampling will be performed. - - Raises: - tensorboard.errors.PublicError: On invalid request. - """ - if self._data_provider: - # Downsample reads to 500 histograms per time series, which is - # the default size guidance for histograms under the multiplexer - # loading logic. - SAMPLE_COUNT = downsample_to if downsample_to is not None else 500 - all_histograms = self._data_provider.read_tensors( - experiment_id=experiment, - plugin_name=metadata.PLUGIN_NAME, - downsample=SAMPLE_COUNT, - run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]), - ) - histograms = all_histograms.get(run, {}).get(tag, None) - if histograms is None: - raise errors.NotFoundError( - "No histogram tag %r for run %r" % (tag, run) + """, + (metadata.PLUGIN_NAME,), + ) + result = collections.defaultdict(dict) + for row in cursor: + tag_name, display_name, run_name = row + result[run_name][tag_name] = { + "displayName": display_name, + # TODO(chihuahua): Populate the description. Currently, the tags + # table does not link with the description table. 
+ "description": "", + } + return result + + runs = self._multiplexer.Runs() + result = collections.defaultdict(lambda: {}) + + mapping = self._multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME + ) + for (run, tag_to_content) in six.iteritems(mapping): + for (tag, content) in six.iteritems(tag_to_content): + content = metadata.parse_plugin_metadata(content) + summary_metadata = self._multiplexer.SummaryMetadata(run, tag) + result[run][tag] = { + "displayName": summary_metadata.display_name, + "description": plugin_util.markdown_to_safe_html( + summary_metadata.summary_description + ), + } + + return result + + def frontend_metadata(self): + return base_plugin.FrontendMetadata( + element_name="tf-histogram-dashboard" ) - # Downsample again, even though the data provider is supposed to, - # because the multiplexer provider currently doesn't. (For - # well-behaved data providers, this is a no-op.) - if downsample_to is not None: - rng = random.Random(0) - histograms = _downsample(rng, histograms, downsample_to) - events = [ - (e.wall_time, e.step, e.numpy.tolist()) - for e in histograms - ] - elif self._db_connection_provider: - # Serve data from the database. - db = self._db_connection_provider() - cursor = db.cursor() - # Prefetch the tag ID matching this run and tag. - cursor.execute( - ''' + + def histograms_impl(self, tag, run, experiment, downsample_to=None): + """Result of the form `(body, mime_type)`. + + At most `downsample_to` events will be returned. If this value is + `None`, then no downsampling will be performed. + + Raises: + tensorboard.errors.PublicError: On invalid request. + """ + if self._data_provider: + # Downsample reads to 500 histograms per time series, which is + # the default size guidance for histograms under the multiplexer + # loading logic. 
+ SAMPLE_COUNT = downsample_to if downsample_to is not None else 500 + all_histograms = self._data_provider.read_tensors( + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, + downsample=SAMPLE_COUNT, + run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]), + ) + histograms = all_histograms.get(run, {}).get(tag, None) + if histograms is None: + raise errors.NotFoundError( + "No histogram tag %r for run %r" % (tag, run) + ) + # Downsample again, even though the data provider is supposed to, + # because the multiplexer provider currently doesn't. (For + # well-behaved data providers, this is a no-op.) + if downsample_to is not None: + rng = random.Random(0) + histograms = _downsample(rng, histograms, downsample_to) + events = [ + (e.wall_time, e.step, e.numpy.tolist()) for e in histograms + ] + elif self._db_connection_provider: + # Serve data from the database. + db = self._db_connection_provider() + cursor = db.cursor() + # Prefetch the tag ID matching this run and tag. + cursor.execute( + """ SELECT tag_id FROM Tags @@ -206,24 +221,25 @@ def histograms_impl(self, tag, run, experiment, downsample_to=None): Runs.run_name = :run AND Tags.tag_name = :tag AND Tags.plugin_name = :plugin - ''', - {'run': run, 'tag': tag, 'plugin': metadata.PLUGIN_NAME}) - row = cursor.fetchone() - if not row: - raise errors.NotFoundError( - 'No histogram tag %r for run %r' % (tag, run) - ) - (tag_id,) = row - # Fetch tensor values, optionally with linear-spaced sampling by step. - # For steps ranging from s_min to s_max and sample size k, this query - # divides the range into k - 1 equal-sized intervals and returns the - # lowest step at or above each of the k interval boundaries (which always - # includes s_min and s_max, and may be fewer than k results if there are - # intervals where no steps are present). 
For contiguous steps the results - # can be formally expressed as the following: - # [s_min + math.ceil(i / k * (s_max - s_min)) for i in range(0, k + 1)] - cursor.execute( - ''' + """, + {"run": run, "tag": tag, "plugin": metadata.PLUGIN_NAME}, + ) + row = cursor.fetchone() + if not row: + raise errors.NotFoundError( + "No histogram tag %r for run %r" % (tag, run) + ) + (tag_id,) = row + # Fetch tensor values, optionally with linear-spaced sampling by step. + # For steps ranging from s_min to s_max and sample size k, this query + # divides the range into k - 1 equal-sized intervals and returns the + # lowest step at or above each of the k interval boundaries (which always + # includes s_min and s_max, and may be fewer than k results if there are + # intervals where no steps are present). For contiguous steps the results + # can be formally expressed as the following: + # [s_min + math.ceil(i / k * (s_max - s_min)) for i in range(0, k + 1)] + cursor.execute( + """ SELECT MIN(step) AS step, computed_time, @@ -247,75 +263,88 @@ def histograms_impl(self, tag, run, experiment, downsample_to=None): IFNULL(:sample_size - 1, max_step - min_step) * (step - min_step) / (max_step - min_step) ORDER BY step - ''', - {'tag_id': tag_id, 'sample_size': downsample_to}) - events = [(computed_time, step, self._get_values(data, dtype, shape)) - for step, computed_time, data, dtype, shape in cursor] - else: - # Serve data from events files. - try: - tensor_events = self._multiplexer.Tensors(run, tag) - except KeyError: - raise errors.NotFoundError( - 'No histogram tag %r for run %r' % (tag, run) + """, + {"tag_id": tag_id, "sample_size": downsample_to}, + ) + events = [ + (computed_time, step, self._get_values(data, dtype, shape)) + for step, computed_time, data, dtype, shape in cursor + ] + else: + # Serve data from events files. 
+ try: + tensor_events = self._multiplexer.Tensors(run, tag) + except KeyError: + raise errors.NotFoundError( + "No histogram tag %r for run %r" % (tag, run) + ) + if downsample_to is not None: + rng = random.Random(0) + tensor_events = _downsample(rng, tensor_events, downsample_to) + events = [ + [ + e.wall_time, + e.step, + tensor_util.make_ndarray(e.tensor_proto).tolist(), + ] + for e in tensor_events + ] + return (events, "application/json") + + def _get_values(self, data_blob, dtype_enum, shape_string): + """Obtains values for histogram data given blob and dtype enum. + + Args: + data_blob: The blob obtained from the database. + dtype_enum: The enum representing the dtype. + shape_string: A comma-separated string of numbers denoting shape. + Returns: + The histogram values as a list served to the frontend. + """ + buf = np.frombuffer( + data_blob, dtype=tf.DType(dtype_enum).as_numpy_dtype ) - if downsample_to is not None: - rng = random.Random(0) - tensor_events = _downsample(rng, tensor_events, downsample_to) - events = [[e.wall_time, e.step, tensor_util.make_ndarray(e.tensor_proto).tolist()] - for e in tensor_events] - return (events, 'application/json') - - def _get_values(self, data_blob, dtype_enum, shape_string): - """Obtains values for histogram data given blob and dtype enum. - Args: - data_blob: The blob obtained from the database. - dtype_enum: The enum representing the dtype. - shape_string: A comma-separated string of numbers denoting shape. - Returns: - The histogram values as a list served to the frontend. 
- """ - buf = np.frombuffer(data_blob, dtype=tf.DType(dtype_enum).as_numpy_dtype) - return buf.reshape([int(i) for i in shape_string.split(',')]).tolist() - - @wrappers.Request.application - def tags_route(self, request): - experiment = plugin_util.experiment_id(request.environ) - index = self.index_impl(experiment=experiment) - return http_util.Respond(request, index, 'application/json') - - @wrappers.Request.application - def histograms_route(self, request): - """Given a tag and single run, return array of histogram values.""" - experiment = plugin_util.experiment_id(request.environ) - tag = request.args.get('tag') - run = request.args.get('run') - (body, mime_type) = self.histograms_impl( - tag, run, experiment=experiment, downsample_to=self.SAMPLE_SIZE) - return http_util.Respond(request, body, mime_type) + return buf.reshape([int(i) for i in shape_string.split(",")]).tolist() + + @wrappers.Request.application + def tags_route(self, request): + experiment = plugin_util.experiment_id(request.environ) + index = self.index_impl(experiment=experiment) + return http_util.Respond(request, index, "application/json") + + @wrappers.Request.application + def histograms_route(self, request): + """Given a tag and single run, return array of histogram values.""" + experiment = plugin_util.experiment_id(request.environ) + tag = request.args.get("tag") + run = request.args.get("run") + (body, mime_type) = self.histograms_impl( + tag, run, experiment=experiment, downsample_to=self.SAMPLE_SIZE + ) + return http_util.Respond(request, body, mime_type) def _downsample(rng, xs, k): - """Uniformly choose a maximal at-most-`k`-subsequence of `xs`. + """Uniformly choose a maximal at-most-`k`-subsequence of `xs`. - If `k` is larger than `xs`, then the contents of `xs` itself will be - returned. + If `k` is larger than `xs`, then the contents of `xs` itself will be + returned. 
- This differs from `random.sample` in that it returns a subsequence - (i.e., order is preserved) and that it permits `k > len(xs)`. + This differs from `random.sample` in that it returns a subsequence + (i.e., order is preserved) and that it permits `k > len(xs)`. - Args: - rng: A `random` interface. - xs: A sequence (`collections.abc.Sequence`). - k: A non-negative integer. + Args: + rng: A `random` interface. + xs: A sequence (`collections.abc.Sequence`). + k: A non-negative integer. - Returns: - A new list whose elements are a subsequence of `xs` of length - `min(k, len(xs))`, uniformly selected among such subsequences. - """ + Returns: + A new list whose elements are a subsequence of `xs` of length + `min(k, len(xs))`, uniformly selected among such subsequences. + """ - if k > len(xs): - return list(xs) - indices = rng.sample(six.moves.xrange(len(xs)), k) - indices.sort() - return [xs[i] for i in indices] + if k > len(xs): + return list(xs) + indices = rng.sample(six.moves.xrange(len(xs)), k) + indices.sort() + return [xs[i] for i in indices] diff --git a/tensorboard/plugins/histogram/histograms_plugin_test.py b/tensorboard/plugins/histogram/histograms_plugin_test.py index 818f9d4a88..b9f635da27 100644 --- a/tensorboard/plugins/histogram/histograms_plugin_test.py +++ b/tensorboard/plugins/histogram/histograms_plugin_test.py @@ -29,8 +29,12 @@ from tensorboard import errors from tensorboard.backend.event_processing import data_provider -from tensorboard.backend.event_processing import plugin_event_accumulator as event_accumulator -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_accumulator as event_accumulator, +) +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.histogram import histograms_plugin from 
tensorboard.plugins.histogram import summary @@ -41,194 +45,232 @@ class HistogramsPluginTest(tf.test.TestCase): - _STEPS = 99 - - _LEGACY_HISTOGRAM_TAG = 'my-ancient-histogram' - _HISTOGRAM_TAG = 'my-favorite-histogram' - _SCALAR_TAG = 'my-boring-scalars' - - _DISPLAY_NAME = 'Important production statistics' - _DESCRIPTION = 'quod *erat* scribendum' - _HTML_DESCRIPTION = '

quod erat scribendum

' - - _RUN_WITH_LEGACY_HISTOGRAM = '_RUN_WITH_LEGACY_HISTOGRAM' - _RUN_WITH_HISTOGRAM = '_RUN_WITH_HISTOGRAM' - _RUN_WITH_SCALARS = '_RUN_WITH_SCALARS' - - def __init__(self, *args, **kwargs): - super(HistogramsPluginTest, self).__init__(*args, **kwargs) - self.logdir = None - - def load_runs(self, run_names): - logdir = self.get_temp_dir() - for run_name in run_names: - self.generate_run(logdir, run_name) - multiplexer = event_multiplexer.EventMultiplexer(size_guidance={ - # don't truncate my test data, please - event_accumulator.TENSORS: self._STEPS, - }) - multiplexer.AddRunsFromDirectory(logdir) - multiplexer.Reload() - return (logdir, multiplexer) - - def with_runs(run_names): - """Run a test with a bare multiplexer and with a `data_provider`. - - The decorated function will receive an initialized `HistogramsPlugin` - object as its first positional argument. - """ - def decorator(fn): - @functools.wraps(fn) - def wrapper(self, *args, **kwargs): - (logdir, multiplexer) = self.load_runs(run_names) - with self.subTest('bare multiplexer'): - ctx = base_plugin.TBContext(logdir=logdir, multiplexer=multiplexer) - fn(self, histograms_plugin.HistogramsPlugin(ctx), *args, **kwargs) - with self.subTest('generic data provider'): - flags = argparse.Namespace(generic_data='true') - provider = data_provider.MultiplexerDataProvider(multiplexer, logdir) - ctx = base_plugin.TBContext( - flags=flags, - logdir=logdir, - multiplexer=multiplexer, - data_provider=provider, - ) - fn(self, histograms_plugin.HistogramsPlugin(ctx), *args, **kwargs) - return wrapper - return decorator - - def generate_run(self, logdir, run_name): - tf.compat.v1.reset_default_graph() - sess = tf.compat.v1.Session() - placeholder = tf.compat.v1.placeholder(tf.float32, shape=[3]) - - if run_name == self._RUN_WITH_LEGACY_HISTOGRAM: - tf.compat.v1.summary.histogram(self._LEGACY_HISTOGRAM_TAG, placeholder) - elif run_name == self._RUN_WITH_HISTOGRAM: - summary.op(self._HISTOGRAM_TAG, placeholder, - 
display_name=self._DISPLAY_NAME, - description=self._DESCRIPTION) - elif run_name == self._RUN_WITH_SCALARS: - tf.compat.v1.summary.scalar(self._SCALAR_TAG, tf.reduce_mean(input_tensor=placeholder)) - else: - assert False, 'Invalid run name: %r' % run_name - summ = tf.compat.v1.summary.merge_all() - - subdir = os.path.join(logdir, run_name) - with test_util.FileWriterCache.get(subdir) as writer: - writer.add_graph(sess.graph) - for step in xrange(self._STEPS): - feed_dict = {placeholder: [1 + step, 2 + step, 3 + step]} - s = sess.run(summ, feed_dict=feed_dict) - writer.add_summary(s, global_step=step) - - @with_runs([_RUN_WITH_SCALARS]) - def test_routes_provided(self, plugin): - """Tests that the plugin offers the correct routes.""" - routes = plugin.get_plugin_apps() - self.assertIsInstance(routes['/histograms'], collections.Callable) - self.assertIsInstance(routes['/tags'], collections.Callable) - - @with_runs([ - _RUN_WITH_SCALARS, - _RUN_WITH_LEGACY_HISTOGRAM, - _RUN_WITH_HISTOGRAM, - ]) - def test_index(self, plugin): - self.assertEqual({ - # _RUN_WITH_SCALARS omitted: No histogram data. - self._RUN_WITH_LEGACY_HISTOGRAM: { - self._LEGACY_HISTOGRAM_TAG: { - 'displayName': self._LEGACY_HISTOGRAM_TAG, - 'description': '', - }, - }, - self._RUN_WITH_HISTOGRAM: { - '%s/histogram_summary' % self._HISTOGRAM_TAG: { - 'displayName': self._DISPLAY_NAME, - 'description': self._HTML_DESCRIPTION, + _STEPS = 99 + + _LEGACY_HISTOGRAM_TAG = "my-ancient-histogram" + _HISTOGRAM_TAG = "my-favorite-histogram" + _SCALAR_TAG = "my-boring-scalars" + + _DISPLAY_NAME = "Important production statistics" + _DESCRIPTION = "quod *erat* scribendum" + _HTML_DESCRIPTION = "

quod erat scribendum

" + + _RUN_WITH_LEGACY_HISTOGRAM = "_RUN_WITH_LEGACY_HISTOGRAM" + _RUN_WITH_HISTOGRAM = "_RUN_WITH_HISTOGRAM" + _RUN_WITH_SCALARS = "_RUN_WITH_SCALARS" + + def __init__(self, *args, **kwargs): + super(HistogramsPluginTest, self).__init__(*args, **kwargs) + self.logdir = None + + def load_runs(self, run_names): + logdir = self.get_temp_dir() + for run_name in run_names: + self.generate_run(logdir, run_name) + multiplexer = event_multiplexer.EventMultiplexer( + size_guidance={ + # don't truncate my test data, please + event_accumulator.TENSORS: self._STEPS, + } + ) + multiplexer.AddRunsFromDirectory(logdir) + multiplexer.Reload() + return (logdir, multiplexer) + + def with_runs(run_names): + """Run a test with a bare multiplexer and with a `data_provider`. + + The decorated function will receive an initialized + `HistogramsPlugin` object as its first positional argument. + """ + + def decorator(fn): + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + (logdir, multiplexer) = self.load_runs(run_names) + with self.subTest("bare multiplexer"): + ctx = base_plugin.TBContext( + logdir=logdir, multiplexer=multiplexer + ) + fn( + self, + histograms_plugin.HistogramsPlugin(ctx), + *args, + **kwargs + ) + with self.subTest("generic data provider"): + flags = argparse.Namespace(generic_data="true") + provider = data_provider.MultiplexerDataProvider( + multiplexer, logdir + ) + ctx = base_plugin.TBContext( + flags=flags, + logdir=logdir, + multiplexer=multiplexer, + data_provider=provider, + ) + fn( + self, + histograms_plugin.HistogramsPlugin(ctx), + *args, + **kwargs + ) + + return wrapper + + return decorator + + def generate_run(self, logdir, run_name): + tf.compat.v1.reset_default_graph() + sess = tf.compat.v1.Session() + placeholder = tf.compat.v1.placeholder(tf.float32, shape=[3]) + + if run_name == self._RUN_WITH_LEGACY_HISTOGRAM: + tf.compat.v1.summary.histogram( + self._LEGACY_HISTOGRAM_TAG, placeholder + ) + elif run_name == self._RUN_WITH_HISTOGRAM: + 
summary.op( + self._HISTOGRAM_TAG, + placeholder, + display_name=self._DISPLAY_NAME, + description=self._DESCRIPTION, + ) + elif run_name == self._RUN_WITH_SCALARS: + tf.compat.v1.summary.scalar( + self._SCALAR_TAG, tf.reduce_mean(input_tensor=placeholder) + ) + else: + assert False, "Invalid run name: %r" % run_name + summ = tf.compat.v1.summary.merge_all() + + subdir = os.path.join(logdir, run_name) + with test_util.FileWriterCache.get(subdir) as writer: + writer.add_graph(sess.graph) + for step in xrange(self._STEPS): + feed_dict = {placeholder: [1 + step, 2 + step, 3 + step]} + s = sess.run(summ, feed_dict=feed_dict) + writer.add_summary(s, global_step=step) + + @with_runs([_RUN_WITH_SCALARS]) + def test_routes_provided(self, plugin): + """Tests that the plugin offers the correct routes.""" + routes = plugin.get_plugin_apps() + self.assertIsInstance(routes["/histograms"], collections.Callable) + self.assertIsInstance(routes["/tags"], collections.Callable) + + @with_runs( + [_RUN_WITH_SCALARS, _RUN_WITH_LEGACY_HISTOGRAM, _RUN_WITH_HISTOGRAM,] + ) + def test_index(self, plugin): + self.assertEqual( + { + # _RUN_WITH_SCALARS omitted: No histogram data. 
+ self._RUN_WITH_LEGACY_HISTOGRAM: { + self._LEGACY_HISTOGRAM_TAG: { + "displayName": self._LEGACY_HISTOGRAM_TAG, + "description": "", + }, + }, + self._RUN_WITH_HISTOGRAM: { + "%s/histogram_summary" + % self._HISTOGRAM_TAG: { + "displayName": self._DISPLAY_NAME, + "description": self._HTML_DESCRIPTION, + }, + }, }, - }, - }, plugin.index_impl(experiment='exp')) - - @with_runs([ - _RUN_WITH_SCALARS, - _RUN_WITH_LEGACY_HISTOGRAM, - _RUN_WITH_HISTOGRAM, - ]) - def _test_histograms(self, plugin, run_name, tag_name, should_work=True): - if should_work: - self._check_histograms_result(plugin, tag_name, run_name, downsample=False) - self._check_histograms_result(plugin, tag_name, run_name, downsample=True) - else: - with self.assertRaises(errors.NotFoundError): - plugin.histograms_impl(self._HISTOGRAM_TAG, run_name, experiment='exp') - - def _check_histograms_result(self, plugin, tag_name, run_name, downsample): - if downsample: - downsample_to = 50 - expected_length = 50 - else: - downsample_to = None - expected_length = self._STEPS - - (data, mime_type) = plugin.histograms_impl( - tag_name, run_name, experiment='exp', downsample_to=downsample_to + plugin.index_impl(experiment="exp"), + ) + + @with_runs( + [_RUN_WITH_SCALARS, _RUN_WITH_LEGACY_HISTOGRAM, _RUN_WITH_HISTOGRAM,] ) - self.assertEqual('application/json', mime_type) - self.assertEqual(expected_length, len(data), - 'expected %r, got %r (downsample=%r)' - % (expected_length, len(data), downsample)) - last_step_seen = None - for (i, datum) in enumerate(data): - [_unused_wall_time, step, buckets] = datum - if last_step_seen is not None: - self.assertGreater(step, last_step_seen) - last_step_seen = step - if not downsample: - self.assertEqual(i, step) - self.assertEqual(1 + step, buckets[0][0]) # first left-edge - self.assertEqual(3 + step, buckets[-1][1]) # last right-edge - self.assertAlmostEqual( - 3, # three items across all buckets - sum(bucket[2] for bucket in buckets)) - - def 
test_histograms_with_scalars(self): - self._test_histograms(self._RUN_WITH_SCALARS, self._HISTOGRAM_TAG, - should_work=False) - - def test_histograms_with_legacy_histogram(self): - self._test_histograms(self._RUN_WITH_LEGACY_HISTOGRAM, - self._LEGACY_HISTOGRAM_TAG) - - def test_histograms_with_histogram(self): - self._test_histograms(self._RUN_WITH_HISTOGRAM, - '%s/histogram_summary' % self._HISTOGRAM_TAG) - - @with_runs([_RUN_WITH_LEGACY_HISTOGRAM]) - def test_active_with_legacy_histogram(self, plugin): - self.assertTrue(plugin.is_active()) - - @with_runs([_RUN_WITH_HISTOGRAM]) - def test_active_with_histogram(self, plugin): - self.assertTrue(plugin.is_active()) - - @with_runs([_RUN_WITH_SCALARS]) - def test_active_with_scalars(self, plugin): - if plugin._data_provider: - # Hack, for now. - self.assertTrue(plugin.is_active()) - else: - self.assertFalse(plugin.is_active()) - - @with_runs([ - _RUN_WITH_SCALARS, - _RUN_WITH_LEGACY_HISTOGRAM, - _RUN_WITH_HISTOGRAM, - ]) - def test_active_with_all(self, plugin): - self.assertTrue(plugin.is_active()) - - -if __name__ == '__main__': - tf.test.main() + def _test_histograms(self, plugin, run_name, tag_name, should_work=True): + if should_work: + self._check_histograms_result( + plugin, tag_name, run_name, downsample=False + ) + self._check_histograms_result( + plugin, tag_name, run_name, downsample=True + ) + else: + with self.assertRaises(errors.NotFoundError): + plugin.histograms_impl( + self._HISTOGRAM_TAG, run_name, experiment="exp" + ) + + def _check_histograms_result(self, plugin, tag_name, run_name, downsample): + if downsample: + downsample_to = 50 + expected_length = 50 + else: + downsample_to = None + expected_length = self._STEPS + + (data, mime_type) = plugin.histograms_impl( + tag_name, run_name, experiment="exp", downsample_to=downsample_to + ) + self.assertEqual("application/json", mime_type) + self.assertEqual( + expected_length, + len(data), + "expected %r, got %r (downsample=%r)" + % (expected_length, 
len(data), downsample), + ) + last_step_seen = None + for (i, datum) in enumerate(data): + [_unused_wall_time, step, buckets] = datum + if last_step_seen is not None: + self.assertGreater(step, last_step_seen) + last_step_seen = step + if not downsample: + self.assertEqual(i, step) + self.assertEqual(1 + step, buckets[0][0]) # first left-edge + self.assertEqual(3 + step, buckets[-1][1]) # last right-edge + self.assertAlmostEqual( + 3, # three items across all buckets + sum(bucket[2] for bucket in buckets), + ) + + def test_histograms_with_scalars(self): + self._test_histograms( + self._RUN_WITH_SCALARS, self._HISTOGRAM_TAG, should_work=False + ) + + def test_histograms_with_legacy_histogram(self): + self._test_histograms( + self._RUN_WITH_LEGACY_HISTOGRAM, self._LEGACY_HISTOGRAM_TAG + ) + + def test_histograms_with_histogram(self): + self._test_histograms( + self._RUN_WITH_HISTOGRAM, + "%s/histogram_summary" % self._HISTOGRAM_TAG, + ) + + @with_runs([_RUN_WITH_LEGACY_HISTOGRAM]) + def test_active_with_legacy_histogram(self, plugin): + self.assertTrue(plugin.is_active()) + + @with_runs([_RUN_WITH_HISTOGRAM]) + def test_active_with_histogram(self, plugin): + self.assertTrue(plugin.is_active()) + + @with_runs([_RUN_WITH_SCALARS]) + def test_active_with_scalars(self, plugin): + if plugin._data_provider: + # Hack, for now. 
+ self.assertTrue(plugin.is_active()) + else: + self.assertFalse(plugin.is_active()) + + @with_runs( + [_RUN_WITH_SCALARS, _RUN_WITH_LEGACY_HISTOGRAM, _RUN_WITH_HISTOGRAM,] + ) + def test_active_with_all(self, plugin): + self.assertTrue(plugin.is_active()) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/histogram/metadata.py b/tensorboard/plugins/histogram/metadata.py index 456d372179..a3ce604fde 100644 --- a/tensorboard/plugins/histogram/metadata.py +++ b/tensorboard/plugins/histogram/metadata.py @@ -24,7 +24,7 @@ logger = tb_logging.get_logger() -PLUGIN_NAME = 'histograms' +PLUGIN_NAME = "histograms" # The most recent value for the `version` field of the # `HistogramPluginData` proto. @@ -32,42 +32,46 @@ def create_summary_metadata(display_name, description): - """Create a `summary_pb2.SummaryMetadata` proto for histogram plugin data. + """Create a `summary_pb2.SummaryMetadata` proto for histogram plugin data. - Returns: - A `summary_pb2.SummaryMetadata` protobuf object. - """ - content = plugin_data_pb2.HistogramPluginData(version=PROTO_VERSION) - return summary_pb2.SummaryMetadata( - display_name=display_name, - summary_description=description, - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name=PLUGIN_NAME, - content=content.SerializeToString())) + Returns: + A `summary_pb2.SummaryMetadata` protobuf object. + """ + content = plugin_data_pb2.HistogramPluginData(version=PROTO_VERSION) + return summary_pb2.SummaryMetadata( + display_name=display_name, + summary_description=description, + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, content=content.SerializeToString() + ), + ) def parse_plugin_metadata(content): - """Parse summary metadata to a Python object. + """Parse summary metadata to a Python object. - Arguments: - content: The `content` field of a `SummaryMetadata` proto - corresponding to the histogram plugin. 
+ Arguments: + content: The `content` field of a `SummaryMetadata` proto + corresponding to the histogram plugin. - Returns: - A `HistogramPluginData` protobuf object. - """ - if not isinstance(content, bytes): - raise TypeError('Content type must be bytes') - if content == b'{}': - # Old-style JSON format. Equivalent to an all-default proto. - return plugin_data_pb2.HistogramPluginData() - else: - result = plugin_data_pb2.HistogramPluginData.FromString(content) - if result.version == 0: - return result + Returns: + A `HistogramPluginData` protobuf object. + """ + if not isinstance(content, bytes): + raise TypeError("Content type must be bytes") + if content == b"{}": + # Old-style JSON format. Equivalent to an all-default proto. + return plugin_data_pb2.HistogramPluginData() else: - logger.warn( - 'Unknown metadata version: %s. The latest version known to ' - 'this build of TensorBoard is %s; perhaps a newer build is ' - 'available?', result.version, PROTO_VERSION) - return result + result = plugin_data_pb2.HistogramPluginData.FromString(content) + if result.version == 0: + return result + else: + logger.warn( + "Unknown metadata version: %s. The latest version known to " + "this build of TensorBoard is %s; perhaps a newer build is " + "available?", + result.version, + PROTO_VERSION, + ) + return result diff --git a/tensorboard/plugins/histogram/summary.py b/tensorboard/plugins/histogram/summary.py index 3e6beac03f..36c144ded9 100644 --- a/tensorboard/plugins/histogram/summary.py +++ b/tensorboard/plugins/histogram/summary.py @@ -44,167 +44,191 @@ def _buckets(data, bucket_count=None): - """Create a TensorFlow op to group data into histogram buckets. - - Arguments: - data: A `Tensor` of any shape. Must be castable to `float64`. - bucket_count: Optional positive `int` or scalar `int32` `Tensor`. - Returns: - A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is - a triple `[left_edge, right_edge, count]` for a single bucket. 
- The value of `k` is either `bucket_count` or `1` or `0`. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - if bucket_count is None: - bucket_count = summary_v2.DEFAULT_BUCKET_COUNT - with tf.name_scope('buckets', values=[data, bucket_count]), \ - tf.control_dependencies([tf.assert_scalar(bucket_count), - tf.assert_type(bucket_count, tf.int32)]): - data = tf.reshape(data, shape=[-1]) # flatten - data = tf.cast(data, tf.float64) - is_empty = tf.equal(tf.size(input=data), 0) - - def when_empty(): - return tf.constant([], shape=(0, 3), dtype=tf.float64) - - def when_nonempty(): - min_ = tf.reduce_min(input_tensor=data) - max_ = tf.reduce_max(input_tensor=data) - range_ = max_ - min_ - is_singular = tf.equal(range_, 0) - - def when_nonsingular(): - bucket_width = range_ / tf.cast(bucket_count, tf.float64) - offsets = data - min_ - bucket_indices = tf.cast(tf.floor(offsets / bucket_width), - dtype=tf.int32) - clamped_indices = tf.minimum(bucket_indices, bucket_count - 1) - one_hots = tf.one_hot(clamped_indices, depth=bucket_count) - bucket_counts = tf.cast(tf.reduce_sum(input_tensor=one_hots, axis=0), - dtype=tf.float64) - edges = tf.linspace(min_, max_, bucket_count + 1) - left_edges = edges[:-1] - right_edges = edges[1:] - return tf.transpose(a=tf.stack( - [left_edges, right_edges, bucket_counts])) - - def when_singular(): - center = min_ - bucket_starts = tf.stack([center - 0.5]) - bucket_ends = tf.stack([center + 0.5]) - bucket_counts = tf.stack([tf.cast(tf.size(input=data), tf.float64)]) - return tf.transpose( - a=tf.stack([bucket_starts, bucket_ends, bucket_counts])) - - return tf.cond(is_singular, when_singular, when_nonsingular) - - return tf.cond(is_empty, when_empty, when_nonempty) - - -def op(name, - data, - bucket_count=None, - display_name=None, - description=None, - collections=None): - """Create a legacy histogram summary op. 
- - Arguments: - name: A unique name for the generated summary node. - data: A `Tensor` of any shape. Must be castable to `float64`. - bucket_count: Optional positive `int`. The output will have this - many buckets, except in two edge cases. If there is no data, then - there are no buckets. If there is data but all points have the - same value, then there is one bucket whose left and right - endpoints are the same. - display_name: Optional name for this summary in TensorBoard, as a - constant `str`. Defaults to `name`. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. - collections: Optional list of graph collections keys. The new - summary op is added to these collections. Defaults to - `[Graph Keys.SUMMARIES]`. - - Returns: - A TensorFlow summary op. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - - if display_name is None: - display_name = name - summary_metadata = metadata.create_summary_metadata( - display_name=display_name, description=description) - with tf.name_scope(name): - tensor = _buckets(data, bucket_count=bucket_count) - return tf.summary.tensor_summary(name='histogram_summary', - tensor=tensor, - collections=collections, - summary_metadata=summary_metadata) + """Create a TensorFlow op to group data into histogram buckets. + + Arguments: + data: A `Tensor` of any shape. Must be castable to `float64`. + bucket_count: Optional positive `int` or scalar `int32` `Tensor`. + Returns: + A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is + a triple `[left_edge, right_edge, count]` for a single bucket. + The value of `k` is either `bucket_count` or `1` or `0`. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. 
+ import tensorflow.compat.v1 as tf + + if bucket_count is None: + bucket_count = summary_v2.DEFAULT_BUCKET_COUNT + with tf.name_scope( + "buckets", values=[data, bucket_count] + ), tf.control_dependencies( + [tf.assert_scalar(bucket_count), tf.assert_type(bucket_count, tf.int32)] + ): + data = tf.reshape(data, shape=[-1]) # flatten + data = tf.cast(data, tf.float64) + is_empty = tf.equal(tf.size(input=data), 0) + + def when_empty(): + return tf.constant([], shape=(0, 3), dtype=tf.float64) + + def when_nonempty(): + min_ = tf.reduce_min(input_tensor=data) + max_ = tf.reduce_max(input_tensor=data) + range_ = max_ - min_ + is_singular = tf.equal(range_, 0) + + def when_nonsingular(): + bucket_width = range_ / tf.cast(bucket_count, tf.float64) + offsets = data - min_ + bucket_indices = tf.cast( + tf.floor(offsets / bucket_width), dtype=tf.int32 + ) + clamped_indices = tf.minimum(bucket_indices, bucket_count - 1) + one_hots = tf.one_hot(clamped_indices, depth=bucket_count) + bucket_counts = tf.cast( + tf.reduce_sum(input_tensor=one_hots, axis=0), + dtype=tf.float64, + ) + edges = tf.linspace(min_, max_, bucket_count + 1) + left_edges = edges[:-1] + right_edges = edges[1:] + return tf.transpose( + a=tf.stack([left_edges, right_edges, bucket_counts]) + ) + + def when_singular(): + center = min_ + bucket_starts = tf.stack([center - 0.5]) + bucket_ends = tf.stack([center + 0.5]) + bucket_counts = tf.stack( + [tf.cast(tf.size(input=data), tf.float64)] + ) + return tf.transpose( + a=tf.stack([bucket_starts, bucket_ends, bucket_counts]) + ) + + return tf.cond(is_singular, when_singular, when_nonsingular) + + return tf.cond(is_empty, when_empty, when_nonempty) + + +def op( + name, + data, + bucket_count=None, + display_name=None, + description=None, + collections=None, +): + """Create a legacy histogram summary op. + + Arguments: + name: A unique name for the generated summary node. + data: A `Tensor` of any shape. Must be castable to `float64`. 
+ bucket_count: Optional positive `int`. The output will have this + many buckets, except in two edge cases. If there is no data, then + there are no buckets. If there is data but all points have the + same value, then there is one bucket whose left and right + endpoints are the same. + display_name: Optional name for this summary in TensorBoard, as a + constant `str`. Defaults to `name`. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. + collections: Optional list of graph collections keys. The new + summary op is added to these collections. Defaults to + `[Graph Keys.SUMMARIES]`. + + Returns: + A TensorFlow summary op. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. + import tensorflow.compat.v1 as tf + + if display_name is None: + display_name = name + summary_metadata = metadata.create_summary_metadata( + display_name=display_name, description=description + ) + with tf.name_scope(name): + tensor = _buckets(data, bucket_count=bucket_count) + return tf.summary.tensor_summary( + name="histogram_summary", + tensor=tensor, + collections=collections, + summary_metadata=summary_metadata, + ) def pb(name, data, bucket_count=None, display_name=None, description=None): - """Create a legacy histogram summary protobuf. - - Arguments: - name: A unique name for the generated summary, including any desired - name scopes. - data: A `np.array` or array-like form of any shape. Must have type - castable to `float`. - bucket_count: Optional positive `int`. The output will have this - many buckets, except in two edge cases. If there is no data, then - there are no buckets. If there is data but all points have the - same value, then there is one bucket whose left and right - endpoints are the same. - display_name: Optional name for this summary in TensorBoard, as a - `str`. Defaults to `name`. - description: Optional long-form description for this summary, as a - `str`. 
Markdown is supported. Defaults to empty. - - Returns: - A `tf.Summary` protobuf object. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - - if bucket_count is None: - bucket_count = summary_v2.DEFAULT_BUCKET_COUNT - data = np.array(data).flatten().astype(float) - if data.size == 0: - buckets = np.array([]).reshape((0, 3)) - else: - min_ = np.min(data) - max_ = np.max(data) - range_ = max_ - min_ - if range_ == 0: - center = min_ - buckets = np.array([[center - 0.5, center + 0.5, float(data.size)]]) + """Create a legacy histogram summary protobuf. + + Arguments: + name: A unique name for the generated summary, including any desired + name scopes. + data: A `np.array` or array-like form of any shape. Must have type + castable to `float`. + bucket_count: Optional positive `int`. The output will have this + many buckets, except in two edge cases. If there is no data, then + there are no buckets. If there is data but all points have the + same value, then there is one bucket whose left and right + endpoints are the same. + display_name: Optional name for this summary in TensorBoard, as a + `str`. Defaults to `name`. + description: Optional long-form description for this summary, as a + `str`. Markdown is supported. Defaults to empty. + + Returns: + A `tf.Summary` protobuf object. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. 
+ import tensorflow.compat.v1 as tf + + if bucket_count is None: + bucket_count = summary_v2.DEFAULT_BUCKET_COUNT + data = np.array(data).flatten().astype(float) + if data.size == 0: + buckets = np.array([]).reshape((0, 3)) else: - bucket_width = range_ / bucket_count - offsets = data - min_ - bucket_indices = np.floor(offsets / bucket_width).astype(int) - clamped_indices = np.minimum(bucket_indices, bucket_count - 1) - one_hots = (np.array([clamped_indices]).transpose() - == np.arange(0, bucket_count)) # broadcast - assert one_hots.shape == (data.size, bucket_count), ( - one_hots.shape, (data.size, bucket_count)) - bucket_counts = np.sum(one_hots, axis=0) - edges = np.linspace(min_, max_, bucket_count + 1) - left_edges = edges[:-1] - right_edges = edges[1:] - buckets = np.array([left_edges, right_edges, bucket_counts]).transpose() - tensor = tf.make_tensor_proto(buckets, dtype=tf.float64) - - if display_name is None: - display_name = name - summary_metadata = metadata.create_summary_metadata( - display_name=display_name, description=description) - tf_summary_metadata = tf.SummaryMetadata.FromString( - summary_metadata.SerializeToString()) - - summary = tf.Summary() - summary.value.add(tag='%s/histogram_summary' % name, - metadata=tf_summary_metadata, - tensor=tensor) - return summary + min_ = np.min(data) + max_ = np.max(data) + range_ = max_ - min_ + if range_ == 0: + center = min_ + buckets = np.array([[center - 0.5, center + 0.5, float(data.size)]]) + else: + bucket_width = range_ / bucket_count + offsets = data - min_ + bucket_indices = np.floor(offsets / bucket_width).astype(int) + clamped_indices = np.minimum(bucket_indices, bucket_count - 1) + one_hots = np.array([clamped_indices]).transpose() == np.arange( + 0, bucket_count + ) # broadcast + assert one_hots.shape == (data.size, bucket_count), ( + one_hots.shape, + (data.size, bucket_count), + ) + bucket_counts = np.sum(one_hots, axis=0) + edges = np.linspace(min_, max_, bucket_count + 1) + left_edges = 
edges[:-1] + right_edges = edges[1:] + buckets = np.array( + [left_edges, right_edges, bucket_counts] + ).transpose() + tensor = tf.make_tensor_proto(buckets, dtype=tf.float64) + + if display_name is None: + display_name = name + summary_metadata = metadata.create_summary_metadata( + display_name=display_name, description=description + ) + tf_summary_metadata = tf.SummaryMetadata.FromString( + summary_metadata.SerializeToString() + ) + + summary = tf.Summary() + summary.value.add( + tag="%s/histogram_summary" % name, + metadata=tf_summary_metadata, + tensor=tensor, + ) + return summary diff --git a/tensorboard/plugins/histogram/summary_test.py b/tensorboard/plugins/histogram/summary_test.py index 312ed413ad..b6b4d50640 100644 --- a/tensorboard/plugins/histogram/summary_test.py +++ b/tensorboard/plugins/histogram/summary_test.py @@ -33,198 +33,211 @@ try: - tf2.__version__ # Force lazy import to resolve + tf2.__version__ # Force lazy import to resolve except ImportError: - tf2 = None + tf2 = None try: - tf.compat.v1.enable_eager_execution() + tf.compat.v1.enable_eager_execution() except AttributeError: - # TF 2.0 doesn't have this symbol because eager is the default. - pass + # TF 2.0 doesn't have this symbol because eager is the default. 
+ pass class SummaryBaseTest(object): - - def setUp(self): - super(SummaryBaseTest, self).setUp() - np.random.seed(0) - self.gaussian = np.random.normal(size=[100]) - - def histogram(self, *args, **kwargs): - raise NotImplementedError() - - def test_metadata(self): - pb = self.histogram('h', [], description='foo') - self.assertEqual(len(pb.value), 1) - summary_metadata = pb.value[0].metadata - self.assertEqual(summary_metadata.summary_description, 'foo') - plugin_data = summary_metadata.plugin_data - self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME) - parsed = metadata.parse_plugin_metadata(plugin_data.content) - self.assertEqual(metadata.PROTO_VERSION, parsed.version) - - def test_empty_input(self): - pb = self.histogram('empty', []) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([]).reshape((0, 3))) - - def test_empty_input_of_high_rank(self): - pb = self.histogram('empty_but_fancy', [[[], []], [[], []]]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([]).reshape((0, 3))) - - def test_singleton_input(self): - pb = self.histogram('twelve', [12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 1]])) - - def test_input_with_all_same_values(self): - pb = self.histogram('twelven', [12, 12, 12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) - - def test_fixed_input(self): - pass # TODO: test a small fixed input - - def test_normal_distribution_input(self): - bucket_count = 44 - pb = self.histogram( - 'normal', data=self.gaussian.reshape((5, -1)), buckets=bucket_count) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - self.assertEqual(buckets[:, 0].min(), self.gaussian.min()) - # Assert near, not equal, since TF's linspace op introduces floating point - # error in the upper bound of the 
result. - self.assertNear(buckets[:, 1].max(), self.gaussian.max(), 1.0**-10) - self.assertEqual(buckets[:, 2].sum(), self.gaussian.size) - np.testing.assert_allclose(buckets[1:, 0], buckets[:-1, 1]) - - def test_when_shape_not_statically_known(self): - self.skipTest('TODO: figure out how to test this') - placeholder = tf.compat.v1.placeholder(tf.float64, shape=None) - reshaped = self.gaussian.reshape((25, -1)) - self.histogram(data=reshaped, - data_tensor=placeholder, - feed_dict={placeholder: reshaped}) - # The proto-equality check is all we need. - - def test_when_bucket_count_not_statically_known(self): - self.skipTest('TODO: figure out how to test this') - placeholder = tf.compat.v1.placeholder(tf.int32, shape=()) - bucket_count = 44 - pb = self.histogram( - bucket_count=bucket_count, - bucket_count_tensor=placeholder, - feed_dict={placeholder: bucket_count}) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - self.assertEqual(buckets.shape, (bucket_count, 3)) + def setUp(self): + super(SummaryBaseTest, self).setUp() + np.random.seed(0) + self.gaussian = np.random.normal(size=[100]) + + def histogram(self, *args, **kwargs): + raise NotImplementedError() + + def test_metadata(self): + pb = self.histogram("h", [], description="foo") + self.assertEqual(len(pb.value), 1) + summary_metadata = pb.value[0].metadata + self.assertEqual(summary_metadata.summary_description, "foo") + plugin_data = summary_metadata.plugin_data + self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME) + parsed = metadata.parse_plugin_metadata(plugin_data.content) + self.assertEqual(metadata.PROTO_VERSION, parsed.version) + + def test_empty_input(self): + pb = self.histogram("empty", []) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([]).reshape((0, 3))) + + def test_empty_input_of_high_rank(self): + pb = self.histogram("empty_but_fancy", [[[], []], [[], []]]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + 
np.testing.assert_allclose(buckets, np.array([]).reshape((0, 3))) + + def test_singleton_input(self): + pb = self.histogram("twelve", [12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 1]])) + + def test_input_with_all_same_values(self): + pb = self.histogram("twelven", [12, 12, 12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) + + def test_fixed_input(self): + pass # TODO: test a small fixed input + + def test_normal_distribution_input(self): + bucket_count = 44 + pb = self.histogram( + "normal", data=self.gaussian.reshape((5, -1)), buckets=bucket_count + ) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + self.assertEqual(buckets[:, 0].min(), self.gaussian.min()) + # Assert near, not equal, since TF's linspace op introduces floating point + # error in the upper bound of the result. + self.assertNear(buckets[:, 1].max(), self.gaussian.max(), 1.0 ** -10) + self.assertEqual(buckets[:, 2].sum(), self.gaussian.size) + np.testing.assert_allclose(buckets[1:, 0], buckets[:-1, 1]) + + def test_when_shape_not_statically_known(self): + self.skipTest("TODO: figure out how to test this") + placeholder = tf.compat.v1.placeholder(tf.float64, shape=None) + reshaped = self.gaussian.reshape((25, -1)) + self.histogram( + data=reshaped, + data_tensor=placeholder, + feed_dict={placeholder: reshaped}, + ) + # The proto-equality check is all we need. 
+ + def test_when_bucket_count_not_statically_known(self): + self.skipTest("TODO: figure out how to test this") + placeholder = tf.compat.v1.placeholder(tf.int32, shape=()) + bucket_count = 44 + pb = self.histogram( + bucket_count=bucket_count, + bucket_count_tensor=placeholder, + feed_dict={placeholder: bucket_count}, + ) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + self.assertEqual(buckets.shape, (bucket_count, 3)) class SummaryV1PbTest(SummaryBaseTest, tf.test.TestCase): - def histogram(self, *args, **kwargs): - # Map new name to the old name. - if 'buckets' in kwargs: - kwargs['bucket_count'] = kwargs.pop('buckets') - return summary.pb(*args, **kwargs) + def histogram(self, *args, **kwargs): + # Map new name to the old name. + if "buckets" in kwargs: + kwargs["bucket_count"] = kwargs.pop("buckets") + return summary.pb(*args, **kwargs) - def test_tag(self): - self.assertEqual('a/histogram_summary', - self.histogram('a', []).value[0].tag) - self.assertEqual('a/b/histogram_summary', - self.histogram('a/b', []).value[0].tag) + def test_tag(self): + self.assertEqual( + "a/histogram_summary", self.histogram("a", []).value[0].tag + ) + self.assertEqual( + "a/b/histogram_summary", self.histogram("a/b", []).value[0].tag + ) class SummaryV1OpTest(SummaryBaseTest, tf.test.TestCase): - def histogram(self, *args, **kwargs): - # Map new name to the old name. - if 'buckets' in kwargs: - kwargs['bucket_count'] = kwargs.pop('buckets') - return summary_pb2.Summary.FromString(summary.op(*args, **kwargs).numpy()) - - def test_tag(self): - self.assertEqual('a/histogram_summary', - self.histogram('a', []).value[0].tag) - self.assertEqual('a/b/histogram_summary', - self.histogram('a/b', []).value[0].tag) - - def test_scoped_tag(self): - with tf.name_scope('scope'): - self.assertEqual('scope/a/histogram_summary', - self.histogram('a', []).value[0].tag) + def histogram(self, *args, **kwargs): + # Map new name to the old name. 
+ if "buckets" in kwargs: + kwargs["bucket_count"] = kwargs.pop("buckets") + return summary_pb2.Summary.FromString( + summary.op(*args, **kwargs).numpy() + ) + + def test_tag(self): + self.assertEqual( + "a/histogram_summary", self.histogram("a", []).value[0].tag + ) + self.assertEqual( + "a/b/histogram_summary", self.histogram("a/b", []).value[0].tag + ) + + def test_scoped_tag(self): + with tf.name_scope("scope"): + self.assertEqual( + "scope/a/histogram_summary", + self.histogram("a", []).value[0].tag, + ) class SummaryV2PbTest(SummaryBaseTest, tf.test.TestCase): - def histogram(self, *args, **kwargs): - return summary.histogram_pb(*args, **kwargs) + def histogram(self, *args, **kwargs): + return summary.histogram_pb(*args, **kwargs) class SummaryV2OpTest(SummaryBaseTest, tf.test.TestCase): - def setUp(self): - super(SummaryV2OpTest, self).setUp() - if tf2 is None: - self.skipTest('v2 summary API not available') - - def histogram(self, *args, **kwargs): - return self.histogram_event(*args, **kwargs).summary - - def histogram_event(self, *args, **kwargs): - kwargs.setdefault('step', 1) - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - summary.histogram(*args, **kwargs) - writer.close() - event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), '*'))) - self.assertEqual(len(event_files), 1) - events = list(tf.compat.v1.train.summary_iterator(event_files[0])) - # Expect a boilerplate event for the file_version, then the summary one. - self.assertEqual(len(events), 2) - # Delete the event file to reset to an empty directory for later calls. - # TODO(nickfelt): use a unique subdirectory per writer instead. 
- os.remove(event_files[0]) - return events[1] - - def write_histogram_event(self, *args, **kwargs): - kwargs.setdefault('step', 1) - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - summary.histogram(*args, **kwargs) - writer.close() - - def test_scoped_tag(self): - with tf.name_scope('scope'): - self.assertEqual('scope/a', self.histogram('a', []).value[0].tag) - - def test_step(self): - event = self.histogram_event('a', [], step=333) - self.assertEqual(333, event.step) - - def test_default_step(self): - try: - tf2.summary.experimental.set_step(333) - # TODO(nickfelt): change test logic so we can just omit `step` entirely. - event = self.histogram_event('a', [], step=None) - self.assertEqual(333, event.step) - finally: - # Reset to default state for other tests. - tf2.summary.experimental.set_step(None) + def setUp(self): + super(SummaryV2OpTest, self).setUp() + if tf2 is None: + self.skipTest("v2 summary API not available") + + def histogram(self, *args, **kwargs): + return self.histogram_event(*args, **kwargs).summary + + def histogram_event(self, *args, **kwargs): + kwargs.setdefault("step", 1) + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + summary.histogram(*args, **kwargs) + writer.close() + event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), "*"))) + self.assertEqual(len(event_files), 1) + events = list(tf.compat.v1.train.summary_iterator(event_files[0])) + # Expect a boilerplate event for the file_version, then the summary one. + self.assertEqual(len(events), 2) + # Delete the event file to reset to an empty directory for later calls. + # TODO(nickfelt): use a unique subdirectory per writer instead. 
+ os.remove(event_files[0]) + return events[1] + + def write_histogram_event(self, *args, **kwargs): + kwargs.setdefault("step", 1) + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + summary.histogram(*args, **kwargs) + writer.close() + + def test_scoped_tag(self): + with tf.name_scope("scope"): + self.assertEqual("scope/a", self.histogram("a", []).value[0].tag) + + def test_step(self): + event = self.histogram_event("a", [], step=333) + self.assertEqual(333, event.step) + + def test_default_step(self): + try: + tf2.summary.experimental.set_step(333) + # TODO(nickfelt): change test logic so we can just omit `step` entirely. + event = self.histogram_event("a", [], step=None) + self.assertEqual(333, event.step) + finally: + # Reset to default state for other tests. + tf2.summary.experimental.set_step(None) class SummaryV2OpGraphTest(SummaryV2OpTest, tf.test.TestCase): - def write_histogram_event(self, *args, **kwargs): - kwargs.setdefault('step', 1) - # Hack to extract current scope since there's no direct API for it. - with tf.name_scope('_') as temp_scope: - scope = temp_scope.rstrip('/_') - @tf2.function - def graph_fn(): - # Recreate the active scope inside the defun since it won't propagate. - with tf.name_scope(scope): - summary.histogram(*args, **kwargs) - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - graph_fn() - writer.close() - - -if __name__ == '__main__': - tf.test.main() + def write_histogram_event(self, *args, **kwargs): + kwargs.setdefault("step", 1) + # Hack to extract current scope since there's no direct API for it. + with tf.name_scope("_") as temp_scope: + scope = temp_scope.rstrip("/_") + + @tf2.function + def graph_fn(): + # Recreate the active scope inside the defun since it won't propagate. 
+ with tf.name_scope(scope): + summary.histogram(*args, **kwargs) + + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + graph_fn() + writer.close() + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/histogram/summary_v2.py b/tensorboard/plugins/histogram/summary_v2.py index 172909f043..4877944773 100644 --- a/tensorboard/plugins/histogram/summary_v2.py +++ b/tensorboard/plugins/histogram/summary_v2.py @@ -42,170 +42,191 @@ def histogram(name, data, step=None, buckets=None, description=None): - """Write a histogram summary. - - Arguments: - name: A name for this summary. The summary tag used for TensorBoard will - be this name prefixed by any active name scopes. - data: A `Tensor` of any shape. Must be castable to `float64`. - step: Explicit `int64`-castable monotonic step value for this summary. If - omitted, this defaults to `tf.summary.experimental.get_step()`, which must - not be None. - buckets: Optional positive `int`. The output will have this - many buckets, except in two edge cases. If there is no data, then - there are no buckets. If there is data but all points have the - same value, then there is one bucket whose left and right - endpoints are the same. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. - - Returns: - True on success, or false if no summary was emitted because no default - summary writer was available. - - Raises: - ValueError: if a default writer exists, but no step was provided and - `tf.summary.experimental.get_step()` is None. 
- """ - summary_metadata = metadata.create_summary_metadata( - display_name=None, description=description) - # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback - summary_scope = ( - getattr(tf.summary.experimental, 'summary_scope', None) or - tf.summary.summary_scope) - - def histogram_summary(data, buckets, histogram_metadata, step): - with summary_scope( - name, 'histogram_summary', values=[data, buckets, step]) as (tag, _): - # Defer histogram bucketing logic by passing it as a callable to write(), - # wrapped in a LazyTensorCreator for backwards compatibility, so that we - # only do this work when summaries are actually written. - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - return _buckets(data, buckets) - return tf.summary.write( - tag=tag, tensor=lazy_tensor, step=step, metadata=summary_metadata) - - # `_buckets()` has dynamic output shapes which is not supported on TPU's. As so, place - # the bucketing ops on outside compilation cluster so that the function in executed on CPU. - # TODO(https://github.com/tensorflow/tensorboard/issues/2885): Remove this special - # handling once dynamic shapes are supported on TPU's. - if isinstance(tf.distribute.get_strategy(), - tf.distribute.experimental.TPUStrategy): - return tf.compat.v1.tpu.outside_compilation( - histogram_summary, data, buckets, summary_metadata, step) - return histogram_summary(data, buckets, summary_metadata, step) + """Write a histogram summary. + + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will + be this name prefixed by any active name scopes. + data: A `Tensor` of any shape. Must be castable to `float64`. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + buckets: Optional positive `int`. The output will have this + many buckets, except in two edge cases. 
If there is no data, then + there are no buckets. If there is data but all points have the + same value, then there is one bucket whose left and right + endpoints are the same. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. + + Returns: + True on success, or false if no summary was emitted because no default + summary writer was available. + + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. + """ + summary_metadata = metadata.create_summary_metadata( + display_name=None, description=description + ) + # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback + summary_scope = ( + getattr(tf.summary.experimental, "summary_scope", None) + or tf.summary.summary_scope + ) + + def histogram_summary(data, buckets, histogram_metadata, step): + with summary_scope( + name, "histogram_summary", values=[data, buckets, step] + ) as (tag, _): + # Defer histogram bucketing logic by passing it as a callable to write(), + # wrapped in a LazyTensorCreator for backwards compatibility, so that we + # only do this work when summaries are actually written. + @lazy_tensor_creator.LazyTensorCreator + def lazy_tensor(): + return _buckets(data, buckets) + + return tf.summary.write( + tag=tag, + tensor=lazy_tensor, + step=step, + metadata=summary_metadata, + ) + + # `_buckets()` has dynamic output shapes which is not supported on TPU's. As so, place + # the bucketing ops on outside compilation cluster so that the function in executed on CPU. + # TODO(https://github.com/tensorflow/tensorboard/issues/2885): Remove this special + # handling once dynamic shapes are supported on TPU's. 
+ if isinstance( + tf.distribute.get_strategy(), tf.distribute.experimental.TPUStrategy + ): + return tf.compat.v1.tpu.outside_compilation( + histogram_summary, data, buckets, summary_metadata, step + ) + return histogram_summary(data, buckets, summary_metadata, step) def _buckets(data, bucket_count=None): - """Create a TensorFlow op to group data into histogram buckets. - - Arguments: - data: A `Tensor` of any shape. Must be castable to `float64`. - bucket_count: Optional positive `int` or scalar `int32` `Tensor`. - Returns: - A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is - a triple `[left_edge, right_edge, count]` for a single bucket. - The value of `k` is either `bucket_count` or `1` or `0`. - """ - if bucket_count is None: - bucket_count = DEFAULT_BUCKET_COUNT - with tf.name_scope('buckets'): - tf.debugging.assert_scalar(bucket_count) - tf.debugging.assert_type(bucket_count, tf.int32) - data = tf.reshape(data, shape=[-1]) # flatten - data = tf.cast(data, tf.float64) - is_empty = tf.equal(tf.size(input=data), 0) - - def when_empty(): - return tf.constant([], shape=(0, 3), dtype=tf.float64) - - def when_nonempty(): - min_ = tf.reduce_min(input_tensor=data) - max_ = tf.reduce_max(input_tensor=data) - range_ = max_ - min_ - is_singular = tf.equal(range_, 0) - - def when_nonsingular(): - bucket_width = range_ / tf.cast(bucket_count, tf.float64) - offsets = data - min_ - bucket_indices = tf.cast(tf.floor(offsets / bucket_width), - dtype=tf.int32) - clamped_indices = tf.minimum(bucket_indices, bucket_count - 1) - one_hots = tf.one_hot(clamped_indices, depth=bucket_count) - bucket_counts = tf.cast(tf.reduce_sum(input_tensor=one_hots, axis=0), - dtype=tf.float64) - edges = tf.linspace(min_, max_, bucket_count + 1) - # Ensure edges[-1] == max_, which TF's linspace implementation does not - # do, leaving it subject to the whim of floating point rounding error. 
- edges = tf.concat([edges[:-1], [max_]], 0) - left_edges = edges[:-1] - right_edges = edges[1:] - return tf.transpose(a=tf.stack( - [left_edges, right_edges, bucket_counts])) - - def when_singular(): - center = min_ - bucket_starts = tf.stack([center - 0.5]) - bucket_ends = tf.stack([center + 0.5]) - bucket_counts = tf.stack([tf.cast(tf.size(input=data), tf.float64)]) - return tf.transpose( - a=tf.stack([bucket_starts, bucket_ends, bucket_counts])) - - return tf.cond(is_singular, when_singular, when_nonsingular) - - return tf.cond(is_empty, when_empty, when_nonempty) + """Create a TensorFlow op to group data into histogram buckets. + + Arguments: + data: A `Tensor` of any shape. Must be castable to `float64`. + bucket_count: Optional positive `int` or scalar `int32` `Tensor`. + Returns: + A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is + a triple `[left_edge, right_edge, count]` for a single bucket. + The value of `k` is either `bucket_count` or `1` or `0`. + """ + if bucket_count is None: + bucket_count = DEFAULT_BUCKET_COUNT + with tf.name_scope("buckets"): + tf.debugging.assert_scalar(bucket_count) + tf.debugging.assert_type(bucket_count, tf.int32) + data = tf.reshape(data, shape=[-1]) # flatten + data = tf.cast(data, tf.float64) + is_empty = tf.equal(tf.size(input=data), 0) + + def when_empty(): + return tf.constant([], shape=(0, 3), dtype=tf.float64) + + def when_nonempty(): + min_ = tf.reduce_min(input_tensor=data) + max_ = tf.reduce_max(input_tensor=data) + range_ = max_ - min_ + is_singular = tf.equal(range_, 0) + + def when_nonsingular(): + bucket_width = range_ / tf.cast(bucket_count, tf.float64) + offsets = data - min_ + bucket_indices = tf.cast( + tf.floor(offsets / bucket_width), dtype=tf.int32 + ) + clamped_indices = tf.minimum(bucket_indices, bucket_count - 1) + one_hots = tf.one_hot(clamped_indices, depth=bucket_count) + bucket_counts = tf.cast( + tf.reduce_sum(input_tensor=one_hots, axis=0), + dtype=tf.float64, + ) + edges = 
tf.linspace(min_, max_, bucket_count + 1) + # Ensure edges[-1] == max_, which TF's linspace implementation does not + # do, leaving it subject to the whim of floating point rounding error. + edges = tf.concat([edges[:-1], [max_]], 0) + left_edges = edges[:-1] + right_edges = edges[1:] + return tf.transpose( + a=tf.stack([left_edges, right_edges, bucket_counts]) + ) + + def when_singular(): + center = min_ + bucket_starts = tf.stack([center - 0.5]) + bucket_ends = tf.stack([center + 0.5]) + bucket_counts = tf.stack( + [tf.cast(tf.size(input=data), tf.float64)] + ) + return tf.transpose( + a=tf.stack([bucket_starts, bucket_ends, bucket_counts]) + ) + + return tf.cond(is_singular, when_singular, when_nonsingular) + + return tf.cond(is_empty, when_empty, when_nonempty) def histogram_pb(tag, data, buckets=None, description=None): - """Create a histogram summary protobuf. - - Arguments: - tag: String tag for the summary. - data: A `np.array` or array-like form of any shape. Must have type - castable to `float`. - buckets: Optional positive `int`. The output will have this - many buckets, except in two edge cases. If there is no data, then - there are no buckets. If there is data but all points have the - same value, then there is one bucket whose left and right - endpoints are the same. - description: Optional long-form description for this summary, as a - `str`. Markdown is supported. Defaults to empty. - - Returns: - A `summary_pb2.Summary` protobuf object. - """ - bucket_count = DEFAULT_BUCKET_COUNT if buckets is None else buckets - data = np.array(data).flatten().astype(float) - if data.size == 0: - buckets = np.array([]).reshape((0, 3)) - else: - min_ = np.min(data) - max_ = np.max(data) - range_ = max_ - min_ - if range_ == 0: - center = min_ - buckets = np.array([[center - 0.5, center + 0.5, float(data.size)]]) + """Create a histogram summary protobuf. + + Arguments: + tag: String tag for the summary. + data: A `np.array` or array-like form of any shape. 
Must have type + castable to `float`. + buckets: Optional positive `int`. The output will have this + many buckets, except in two edge cases. If there is no data, then + there are no buckets. If there is data but all points have the + same value, then there is one bucket whose left and right + endpoints are the same. + description: Optional long-form description for this summary, as a + `str`. Markdown is supported. Defaults to empty. + + Returns: + A `summary_pb2.Summary` protobuf object. + """ + bucket_count = DEFAULT_BUCKET_COUNT if buckets is None else buckets + data = np.array(data).flatten().astype(float) + if data.size == 0: + buckets = np.array([]).reshape((0, 3)) else: - bucket_width = range_ / bucket_count - offsets = data - min_ - bucket_indices = np.floor(offsets / bucket_width).astype(int) - clamped_indices = np.minimum(bucket_indices, bucket_count - 1) - one_hots = (np.array([clamped_indices]).transpose() - == np.arange(0, bucket_count)) # broadcast - assert one_hots.shape == (data.size, bucket_count), ( - one_hots.shape, (data.size, bucket_count)) - bucket_counts = np.sum(one_hots, axis=0) - edges = np.linspace(min_, max_, bucket_count + 1) - left_edges = edges[:-1] - right_edges = edges[1:] - buckets = np.array([left_edges, right_edges, bucket_counts]).transpose() - tensor = tensor_util.make_tensor_proto(buckets, dtype=np.float64) - - summary_metadata = metadata.create_summary_metadata( - display_name=None, description=description) - summary = summary_pb2.Summary() - summary.value.add(tag=tag, - metadata=summary_metadata, - tensor=tensor) - return summary + min_ = np.min(data) + max_ = np.max(data) + range_ = max_ - min_ + if range_ == 0: + center = min_ + buckets = np.array([[center - 0.5, center + 0.5, float(data.size)]]) + else: + bucket_width = range_ / bucket_count + offsets = data - min_ + bucket_indices = np.floor(offsets / bucket_width).astype(int) + clamped_indices = np.minimum(bucket_indices, bucket_count - 1) + one_hots = 
np.array([clamped_indices]).transpose() == np.arange( + 0, bucket_count + ) # broadcast + assert one_hots.shape == (data.size, bucket_count), ( + one_hots.shape, + (data.size, bucket_count), + ) + bucket_counts = np.sum(one_hots, axis=0) + edges = np.linspace(min_, max_, bucket_count + 1) + left_edges = edges[:-1] + right_edges = edges[1:] + buckets = np.array( + [left_edges, right_edges, bucket_counts] + ).transpose() + tensor = tensor_util.make_tensor_proto(buckets, dtype=np.float64) + + summary_metadata = metadata.create_summary_metadata( + display_name=None, description=description + ) + summary = summary_pb2.Summary() + summary.value.add(tag=tag, metadata=summary_metadata, tensor=tensor) + return summary diff --git a/tensorboard/plugins/hparams/api_test.py b/tensorboard/plugins/hparams/api_test.py index cbb5e5a2a9..7c3a0dbd2d 100644 --- a/tensorboard/plugins/hparams/api_test.py +++ b/tensorboard/plugins/hparams/api_test.py @@ -24,13 +24,12 @@ class ApiTest(test.TestCase): + def test_has_core_attributes(self): + self.assertIs(api.HParam, summary_v2.HParam) - def test_has_core_attributes(self): - self.assertIs(api.HParam, summary_v2.HParam) - - def test_has_keras_dependent_attributes(self): - self.assertIs(api.KerasCallback, keras.Callback) + def test_has_keras_dependent_attributes(self): + self.assertIs(api.KerasCallback, keras.Callback) if __name__ == "__main__": - test.main() + test.main() diff --git a/tensorboard/plugins/hparams/backend_context.py b/tensorboard/plugins/hparams/backend_context.py index 1ccd7cd2c9..630c50caf2 100644 --- a/tensorboard/plugins/hparams/backend_context.py +++ b/tensorboard/plugins/hparams/backend_context.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Wraps the base_plugin.TBContext to stores additional data shared across - API handlers for the HParams plugin backend. 
-""" +"""Wraps the base_plugin.TBContext to stores additional data shared across API +handlers for the HParams plugin backend.""" from __future__ import absolute_import from __future__ import division @@ -33,277 +32,297 @@ class Context(object): - """Wraps the base_plugin.TBContext to stores additional data shared across - API handlers for the HParams plugin backend. - - Before adding fields to this class, carefully consider whether the field - truelly needs to be accessible to all API handlers or if it can be passed - separately to the handler constructor. - We want to avoid this class becoming a magic container of variables that - have no better place. See http://wiki.c2.com/?MagicContainer - """ - - def __init__(self, - tb_context, - max_domain_discrete_len=10): - """Instantiates a context. - - Args: - tb_context: base_plugin.TBContext. The "base" context we extend. - max_domain_discrete_len: int. Only used when computing the experiment - from the session runs. The maximum number of disticnt values a string - hyperparameter can have for us to populate its 'domain_discrete' field. - Typically, only tests should specify a value for this parameter. - """ - self._tb_context = tb_context - self._experiment_from_tag = None - self._experiment_from_tag_lock = threading.Lock() - self._max_domain_discrete_len = max_domain_discrete_len - - def experiment(self): - """Returns the experiment protobuffer defining the experiment. - - This method first attempts to find a metadata.EXPERIMENT_TAG tag and - retrieve the associated protobuffer. If no such tag is found, the method - will attempt to build a minimal experiment protobuffer by scanning for - all metadata.SESSION_START_INFO_TAG tags (to compute the hparam_infos - field of the experiment) and for all scalar tags (to compute the - metric_infos field of the experiment). - - Returns: - The experiment protobuffer. 
If no tags are found from which an experiment - protobuffer can be built (possibly, because the event data has not been - completely loaded yet), returns None. + """Wraps the base_plugin.TBContext to stores additional data shared across + API handlers for the HParams plugin backend. + + Before adding fields to this class, carefully consider whether the + field truelly needs to be accessible to all API handlers or if it + can be passed separately to the handler constructor. We want to + avoid this class becoming a magic container of variables that have + no better place. See http://wiki.c2.com/?MagicContainer """ - experiment = self._find_experiment_tag() - if experiment is None: - return self._compute_experiment_from_runs() - return experiment - @property - def multiplexer(self): - return self._tb_context.multiplexer + def __init__(self, tb_context, max_domain_discrete_len=10): + """Instantiates a context. + + Args: + tb_context: base_plugin.TBContext. The "base" context we extend. + max_domain_discrete_len: int. Only used when computing the experiment + from the session runs. The maximum number of disticnt values a string + hyperparameter can have for us to populate its 'domain_discrete' field. + Typically, only tests should specify a value for this parameter. + """ + self._tb_context = tb_context + self._experiment_from_tag = None + self._experiment_from_tag_lock = threading.Lock() + self._max_domain_discrete_len = max_domain_discrete_len + + def experiment(self): + """Returns the experiment protobuffer defining the experiment. + + This method first attempts to find a metadata.EXPERIMENT_TAG tag and + retrieve the associated protobuffer. If no such tag is found, the method + will attempt to build a minimal experiment protobuffer by scanning for + all metadata.SESSION_START_INFO_TAG tags (to compute the hparam_infos + field of the experiment) and for all scalar tags (to compute the + metric_infos field of the experiment). + + Returns: + The experiment protobuffer. 
If no tags are found from which an experiment + protobuffer can be built (possibly, because the event data has not been + completely loaded yet), returns None. + """ + experiment = self._find_experiment_tag() + if experiment is None: + return self._compute_experiment_from_runs() + return experiment + + @property + def multiplexer(self): + return self._tb_context.multiplexer + + @property + def tb_context(self): + return self._tb_context + + def _find_experiment_tag(self): + """Finds the experiment associcated with the metadata.EXPERIMENT_TAG + tag. + + Caches the experiment if it was found. + + Returns: + The experiment or None if no such experiment is found. + """ + with self._experiment_from_tag_lock: + if self._experiment_from_tag is None: + mapping = self.multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME + ) + for tag_to_content in mapping.values(): + if metadata.EXPERIMENT_TAG in tag_to_content: + self._experiment_from_tag = metadata.parse_experiment_plugin_data( + tag_to_content[metadata.EXPERIMENT_TAG] + ) + break + return self._experiment_from_tag + + def _compute_experiment_from_runs(self): + """Computes a minimal Experiment protocol buffer by scanning the + runs.""" + hparam_infos = self._compute_hparam_infos() + if not hparam_infos: + return None + metric_infos = self._compute_metric_infos() + return api_pb2.Experiment( + hparam_infos=hparam_infos, metric_infos=metric_infos + ) + + def _compute_hparam_infos(self): + """Computes a list of api_pb2.HParamInfo from the current run, tag + info. + + Finds all the SessionStartInfo messages and collects the hparams values + appearing in each one. For each hparam attempts to deduce a type that fits + all its values. Finally, sets the 'domain' of the resulting HParamInfo + to be discrete if the type is string and the number of distinct values is + small enough. + + Returns: + A list of api_pb2.HParamInfo messages. 
+ """ + run_to_tag_to_content = self.multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME + ) + # Construct a dict mapping an hparam name to its list of values. + hparams = collections.defaultdict(list) + for tag_to_content in run_to_tag_to_content.values(): + if metadata.SESSION_START_INFO_TAG not in tag_to_content: + continue + start_info = metadata.parse_session_start_info_plugin_data( + tag_to_content[metadata.SESSION_START_INFO_TAG] + ) + for (name, value) in six.iteritems(start_info.hparams): + hparams[name].append(value) + + # Try to construct an HParamInfo for each hparam from its name and list + # of values. + result = [] + for (name, values) in six.iteritems(hparams): + hparam_info = self._compute_hparam_info_from_values(name, values) + if hparam_info is not None: + result.append(hparam_info) + return result + + def _compute_hparam_info_from_values(self, name, values): + """Builds an HParamInfo message from the hparam name and list of + values. + + Args: + name: string. The hparam name. + values: list of google.protobuf.Value messages. The list of values for the + hparam. + + Returns: + An api_pb2.HParamInfo message. + """ + # Figure out the type from the values. + # Ignore values whose type is not listed in api_pb2.DataType + # If all values have the same type, then that is the type used. + # Otherwise, the returned type is DATA_TYPE_STRING. + result = api_pb2.HParamInfo(name=name, type=api_pb2.DATA_TYPE_UNSET) + distinct_values = set( + _protobuf_value_to_string(v) + for v in values + if _protobuf_value_type(v) + ) + for v in values: + v_type = _protobuf_value_type(v) + if not v_type: + continue + if result.type == api_pb2.DATA_TYPE_UNSET: + result.type = v_type + elif result.type != v_type: + result.type = api_pb2.DATA_TYPE_STRING + if result.type == api_pb2.DATA_TYPE_STRING: + # A string result.type does not change, so we can exit the loop. + break + + # If we couldn't figure out a type, then we can't compute the hparam_info. 
+ if result.type == api_pb2.DATA_TYPE_UNSET: + return None + + # If the result is a string, set the domain to be the distinct values if + # there aren't too many of them. + if ( + result.type == api_pb2.DATA_TYPE_STRING + and len(distinct_values) <= self._max_domain_discrete_len + ): + result.domain_discrete.extend(distinct_values) + + return result + + def _compute_metric_infos(self): + return ( + api_pb2.MetricInfo(name=api_pb2.MetricName(group=group, tag=tag)) + for tag, group in self._compute_metric_names() + ) + + def _compute_metric_names(self): + """Computes the list of metric names from all the scalar (run, tag) + pairs. + + The return value is a list of (tag, group) pairs representing the metric + names. The list is sorted in Python tuple-order (lexicographical). + + For example, if the scalar (run, tag) pairs are: + ("exp/session1", "loss") + ("exp/session2", "loss") + ("exp/session2/eval", "loss") + ("exp/session2/validation", "accuracy") + ("exp/no-session", "loss_2"), + and the runs corresponding to sessions are "exp/session1", "exp/session2", + this method will return [("loss", ""), ("loss", "/eval"), ("accuracy", + "/validation")] + + More precisely, each scalar (run, tag) pair is converted to a (tag, group) + metric name, where group is the suffix of run formed by removing the + longest prefix which is a session run. If no session run is a prefix of + 'run', the pair is skipped. + + Returns: + A python list containing pairs. Each pair is a (tag, group) pair + representing a metric name used in some session. + """ + session_runs = self._build_session_runs_set() + metric_names_set = set() + run_to_tag_to_content = self.multiplexer.PluginRunToTagToContent( + scalar_metadata.PLUGIN_NAME + ) + for (run, tag_to_content) in six.iteritems(run_to_tag_to_content): + session = _find_longest_parent_path(session_runs, run) + if not session: + continue + group = os.path.relpath(run, session) + # relpath() returns "." 
for the 'session' directory, we use an empty + # string. + if group == ".": + group = "" + metric_names_set.update( + (tag, group) for tag in tag_to_content.keys() + ) + metric_names_list = list(metric_names_set) + # Sort metrics for determinism. + metric_names_list.sort() + return metric_names_list + + def _build_session_runs_set(self): + result = set() + run_to_tag_to_content = self.multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME + ) + for (run, tag_to_content) in six.iteritems(run_to_tag_to_content): + if metadata.SESSION_START_INFO_TAG in tag_to_content: + result.add(run) + return result - @property - def tb_context(self): - return self._tb_context - def _find_experiment_tag(self): - """Finds the experiment associcated with the metadata.EXPERIMENT_TAG tag. +def _find_longest_parent_path(path_set, path): + """Finds the longest "parent-path" of 'path' in 'path_set'. - Caches the experiment if it was found. - - Returns: - The experiment or None if no such experiment is found. - """ - with self._experiment_from_tag_lock: - if self._experiment_from_tag is None: - mapping = self.multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME) - for tag_to_content in mapping.values(): - if metadata.EXPERIMENT_TAG in tag_to_content: - self._experiment_from_tag = metadata.parse_experiment_plugin_data( - tag_to_content[metadata.EXPERIMENT_TAG]) - break - return self._experiment_from_tag - - def _compute_experiment_from_runs(self): - """Computes a minimal Experiment protocol buffer by scanning the runs.""" - hparam_infos = self._compute_hparam_infos() - if not hparam_infos: - return None - metric_infos = self._compute_metric_infos() - return api_pb2.Experiment(hparam_infos=hparam_infos, - metric_infos=metric_infos) - - def _compute_hparam_infos(self): - """Computes a list of api_pb2.HParamInfo from the current run, tag info. - - Finds all the SessionStartInfo messages and collects the hparams values - appearing in each one. 
For each hparam attempts to deduce a type that fits - all its values. Finally, sets the 'domain' of the resulting HParamInfo - to be discrete if the type is string and the number of distinct values is - small enough. - - Returns: - A list of api_pb2.HParamInfo messages. - """ - run_to_tag_to_content = self.multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME) - # Construct a dict mapping an hparam name to its list of values. - hparams = collections.defaultdict(list) - for tag_to_content in run_to_tag_to_content.values(): - if metadata.SESSION_START_INFO_TAG not in tag_to_content: - continue - start_info = metadata.parse_session_start_info_plugin_data( - tag_to_content[metadata.SESSION_START_INFO_TAG]) - for (name, value) in six.iteritems(start_info.hparams): - hparams[name].append(value) - - # Try to construct an HParamInfo for each hparam from its name and list - # of values. - result = [] - for (name, values) in six.iteritems(hparams): - hparam_info = self._compute_hparam_info_from_values(name, values) - if hparam_info is not None: - result.append(hparam_info) - return result - - def _compute_hparam_info_from_values(self, name, values): - """Builds an HParamInfo message from the hparam name and list of values. + This function takes and returns "path-like" strings which are strings + made of strings separated by os.sep. No file access is performed here, so + these strings need not correspond to actual files in some file-system.. + This function returns the longest ancestor path + For example, for path_set=["/foo/bar", "/foo", "/bar/foo"] and + path="/foo/bar/sub_dir", returns "/foo/bar". Args: - name: string. The hparam name. - values: list of google.protobuf.Value messages. The list of values for the - hparam. + path_set: set of path-like strings -- e.g. a list of strings separated by + os.sep. No actual disk-access is performed here, so these need not + correspond to actual files. + path: a path-like string. Returns: - An api_pb2.HParamInfo message. 
+ The element in path_set which is the longest parent directory of 'path'. """ - # Figure out the type from the values. - # Ignore values whose type is not listed in api_pb2.DataType - # If all values have the same type, then that is the type used. - # Otherwise, the returned type is DATA_TYPE_STRING. - result = api_pb2.HParamInfo(name=name, type=api_pb2.DATA_TYPE_UNSET) - distinct_values = set( - _protobuf_value_to_string(v) for v in values if _protobuf_value_type(v)) - for v in values: - v_type = _protobuf_value_type(v) - if not v_type: - continue - if result.type == api_pb2.DATA_TYPE_UNSET: - result.type = v_type - elif result.type != v_type: - result.type = api_pb2.DATA_TYPE_STRING - if result.type == api_pb2.DATA_TYPE_STRING: - # A string result.type does not change, so we can exit the loop. - break - - # If we couldn't figure out a type, then we can't compute the hparam_info. - if result.type == api_pb2.DATA_TYPE_UNSET: - return None - - # If the result is a string, set the domain to be the distinct values if - # there aren't too many of them. - if (result.type == api_pb2.DATA_TYPE_STRING - and len(distinct_values) <= self._max_domain_discrete_len): - result.domain_discrete.extend(distinct_values) - - return result - - def _compute_metric_infos(self): - return (api_pb2.MetricInfo(name=api_pb2.MetricName(group=group, tag=tag)) - for tag, group in self._compute_metric_names()) - - def _compute_metric_names(self): - """Computes the list of metric names from all the scalar (run, tag) pairs. - - The return value is a list of (tag, group) pairs representing the metric - names. The list is sorted in Python tuple-order (lexicographical). 
- - For example, if the scalar (run, tag) pairs are: - ("exp/session1", "loss") - ("exp/session2", "loss") - ("exp/session2/eval", "loss") - ("exp/session2/validation", "accuracy") - ("exp/no-session", "loss_2"), - and the runs corresponding to sessions are "exp/session1", "exp/session2", - this method will return [("loss", ""), ("loss", "/eval"), ("accuracy", - "/validation")] - - More precisely, each scalar (run, tag) pair is converted to a (tag, group) - metric name, where group is the suffix of run formed by removing the - longest prefix which is a session run. If no session run is a prefix of - 'run', the pair is skipped. - - Returns: - A python list containing pairs. Each pair is a (tag, group) pair - representing a metric name used in some session. - """ - session_runs = self._build_session_runs_set() - metric_names_set = set() - run_to_tag_to_content = self.multiplexer.PluginRunToTagToContent( - scalar_metadata.PLUGIN_NAME) - for (run, tag_to_content) in six.iteritems(run_to_tag_to_content): - session = _find_longest_parent_path(session_runs, run) - if not session: - continue - group = os.path.relpath(run, session) - # relpath() returns "." for the 'session' directory, we use an empty - # string. - if group == ".": - group = "" - metric_names_set.update((tag, group) for tag in tag_to_content.keys()) - metric_names_list = list(metric_names_set) - # Sort metrics for determinism. - metric_names_list.sort() - return metric_names_list - - def _build_session_runs_set(self): - result = set() - run_to_tag_to_content = self.multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME) - for (run, tag_to_content) in six.iteritems(run_to_tag_to_content): - if metadata.SESSION_START_INFO_TAG in tag_to_content: - result.add(run) - return result - - -def _find_longest_parent_path(path_set, path): - """Finds the longest "parent-path" of 'path' in 'path_set'. - - This function takes and returns "path-like" strings which are strings - made of strings separated by os.sep. 
No file access is performed here, so - these strings need not correspond to actual files in some file-system.. - This function returns the longest ancestor path - For example, for path_set=["/foo/bar", "/foo", "/bar/foo"] and - path="/foo/bar/sub_dir", returns "/foo/bar". - - Args: - path_set: set of path-like strings -- e.g. a list of strings separated by - os.sep. No actual disk-access is performed here, so these need not - correspond to actual files. - path: a path-like string. - - Returns: - The element in path_set which is the longest parent directory of 'path'. - """ - # This could likely be more efficiently implemented with a trie - # data-structure, but we don't want to add an extra dependency for that. - while path not in path_set: - if not path: - return None - path = os.path.dirname(path) - return path + # This could likely be more efficiently implemented with a trie + # data-structure, but we don't want to add an extra dependency for that. + while path not in path_set: + if not path: + return None + path = os.path.dirname(path) + return path def _protobuf_value_type(value): - """Returns the type of the google.protobuf.Value message as an api.DataType. + """Returns the type of the google.protobuf.Value message as an + api.DataType. - Returns None if the type of 'value' is not one of the types supported in - api_pb2.DataType. + Returns None if the type of 'value' is not one of the types supported in + api_pb2.DataType. - Args: - value: google.protobuf.Value message. - """ - if value.HasField("number_value"): - return api_pb2.DATA_TYPE_FLOAT64 - if value.HasField("string_value"): - return api_pb2.DATA_TYPE_STRING - if value.HasField("bool_value"): - return api_pb2.DATA_TYPE_BOOL - return None + Args: + value: google.protobuf.Value message. 
+ """ + if value.HasField("number_value"): + return api_pb2.DATA_TYPE_FLOAT64 + if value.HasField("string_value"): + return api_pb2.DATA_TYPE_STRING + if value.HasField("bool_value"): + return api_pb2.DATA_TYPE_BOOL + return None def _protobuf_value_to_string(value): - """Returns a string representation of given google.protobuf.Value message. - - Args: - value: google.protobuf.Value message. Assumed to be of type 'number', - 'string' or 'bool'. - """ - value_in_json = json_format.MessageToJson(value) - if value.HasField("string_value"): - # Remove the quotations. - return value_in_json[1:-1] - return value_in_json + """Returns a string representation of given google.protobuf.Value message. + + Args: + value: google.protobuf.Value message. Assumed to be of type 'number', + 'string' or 'bool'. + """ + value_in_json = json_format.MessageToJson(value) + if value.HasField("string_value"): + # Remove the quotations. + return value_in_json[1:-1] + return value_in_json diff --git a/tensorboard/plugins/hparams/backend_context_test.py b/tensorboard/plugins/hparams/backend_context_test.py index e34133e80d..190a5c68f6 100644 --- a/tensorboard/plugins/hparams/backend_context_test.py +++ b/tensorboard/plugins/hparams/backend_context_test.py @@ -21,10 +21,10 @@ import operator try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import import tensorflow as tf from google.protobuf import text_format @@ -36,130 +36,108 @@ from tensorboard.plugins.hparams import metadata from tensorboard.plugins.hparams import plugin_data_pb2 -DATA_TYPE_EXPERIMENT = 'experiment' -DATA_TYPE_SESSION_START_INFO = 'session_start_info' -DATA_TYPE_SESSION_END_INFO = 'session_end_info' +DATA_TYPE_EXPERIMENT = "experiment" +DATA_TYPE_SESSION_START_INFO = "session_start_info" +DATA_TYPE_SESSION_END_INFO = "session_end_info" class 
BackendContextTest(tf.test.TestCase): - # Make assertProtoEquals print all the diff. - maxDiff = None # pylint: disable=invalid-name + # Make assertProtoEquals print all the diff. + maxDiff = None # pylint: disable=invalid-name - def setUp(self): - self._mock_tb_context = mock.create_autospec(base_plugin.TBContext) - self._mock_multiplexer = mock.create_autospec( - plugin_event_multiplexer.EventMultiplexer) - self._mock_tb_context.multiplexer = self._mock_multiplexer - self._mock_multiplexer.PluginRunToTagToContent.side_effect = ( - self._mock_plugin_run_to_tag_to_content) - self.session_1_start_info_ = '' - self.session_2_start_info_ = '' - self.session_3_start_info_ = '' + def setUp(self): + self._mock_tb_context = mock.create_autospec(base_plugin.TBContext) + self._mock_multiplexer = mock.create_autospec( + plugin_event_multiplexer.EventMultiplexer + ) + self._mock_tb_context.multiplexer = self._mock_multiplexer + self._mock_multiplexer.PluginRunToTagToContent.side_effect = ( + self._mock_plugin_run_to_tag_to_content + ) + self.session_1_start_info_ = "" + self.session_2_start_info_ = "" + self.session_3_start_info_ = "" - def _mock_plugin_run_to_tag_to_content(self, plugin_name): - if plugin_name == metadata.PLUGIN_NAME: - return { - 'exp/session_1': { - metadata.SESSION_START_INFO_TAG: - self._serialized_plugin_data( - DATA_TYPE_SESSION_START_INFO, - self.session_1_start_info_ - ), - }, - 'exp/session_2': { - metadata.SESSION_START_INFO_TAG: - self._serialized_plugin_data( - DATA_TYPE_SESSION_START_INFO, - self.session_2_start_info_ - ), - }, - 'exp/session_3': { - metadata.SESSION_START_INFO_TAG: - self._serialized_plugin_data( - DATA_TYPE_SESSION_START_INFO, - self.session_3_start_info_ - ), - }, - } - SCALARS = event_accumulator.SCALARS # pylint: disable=invalid-name - if plugin_name == SCALARS: - return { - # We use None as the content here, since the content is not - # used in the test. 
- 'exp/session_1': { - 'loss': None, - 'accuracy': None - }, - 'exp/session_1/eval': { - 'loss': None, - }, - 'exp/session_1/train': { - 'loss': None, - }, - 'exp/session_2': { - 'loss': None, - 'accuracy': None, - }, - 'exp/session_2/eval': { - 'loss': None, - }, - 'exp/session_2/train': { - 'loss': None, - }, - 'exp/session_3': { - 'loss': None, - 'accuracy': None, - }, - 'exp/session_3/eval': { - 'loss': None, - }, - 'exp/session_3xyz/': { - 'loss2': None, - }, - } - self.fail("Unexpected plugin_name '%s' passed to" - ' EventMultiplexer.PluginRunToTagToContent' % plugin_name) + def _mock_plugin_run_to_tag_to_content(self, plugin_name): + if plugin_name == metadata.PLUGIN_NAME: + return { + "exp/session_1": { + metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_START_INFO, self.session_1_start_info_ + ), + }, + "exp/session_2": { + metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_START_INFO, self.session_2_start_info_ + ), + }, + "exp/session_3": { + metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_START_INFO, self.session_3_start_info_ + ), + }, + } + SCALARS = event_accumulator.SCALARS # pylint: disable=invalid-name + if plugin_name == SCALARS: + return { + # We use None as the content here, since the content is not + # used in the test. 
+ "exp/session_1": {"loss": None, "accuracy": None}, + "exp/session_1/eval": {"loss": None,}, + "exp/session_1/train": {"loss": None,}, + "exp/session_2": {"loss": None, "accuracy": None,}, + "exp/session_2/eval": {"loss": None,}, + "exp/session_2/train": {"loss": None,}, + "exp/session_3": {"loss": None, "accuracy": None,}, + "exp/session_3/eval": {"loss": None,}, + "exp/session_3xyz/": {"loss2": None,}, + } + self.fail( + "Unexpected plugin_name '%s' passed to" + " EventMultiplexer.PluginRunToTagToContent" % plugin_name + ) - def test_experiment_with_experiment_tag(self): - experiment = """ + def test_experiment_with_experiment_tag(self): + experiment = """ description: 'Test experiment' metric_infos: [ { name: { tag: 'current_temp' } } ] """ - self._mock_multiplexer.PluginRunToTagToContent.side_effect = None - self._mock_multiplexer.PluginRunToTagToContent.return_value = { - 'exp': { - metadata.EXPERIMENT_TAG: - self._serialized_plugin_data(DATA_TYPE_EXPERIMENT, experiment) + self._mock_multiplexer.PluginRunToTagToContent.side_effect = None + self._mock_multiplexer.PluginRunToTagToContent.return_value = { + "exp": { + metadata.EXPERIMENT_TAG: self._serialized_plugin_data( + DATA_TYPE_EXPERIMENT, experiment + ) + } } - } - ctxt = backend_context.Context(self._mock_tb_context) - self.assertProtoEquals(experiment, ctxt.experiment()) + ctxt = backend_context.Context(self._mock_tb_context) + self.assertProtoEquals(experiment, ctxt.experiment()) - def test_experiment_without_experiment_tag(self): - self.session_1_start_info_ = """ + def test_experiment_without_experiment_tag(self): + self.session_1_start_info_ = """ hparams:[ {key: 'batch_size' value: {number_value: 100}}, {key: 'lr' value: {number_value: 0.01}}, {key: 'model_type' value: {string_value: 'CNN'}} ] """ - self.session_2_start_info_ = """ + self.session_2_start_info_ = """ hparams:[ {key: 'batch_size' value: {number_value: 200}}, {key: 'lr' value: {number_value: 0.02}}, {key: 'model_type' value: 
{string_value: 'LATTICE'}} ] """ - self.session_3_start_info_ = """ + self.session_3_start_info_ = """ hparams:[ {key: 'batch_size' value: {number_value: 300}}, {key: 'lr' value: {number_value: 0.05}}, {key: 'model_type' value: {string_value: 'CNN'}} ] """ - expected_exp = """ + expected_exp = """ hparam_infos: { name: 'batch_size' type: DATA_TYPE_FLOAT64 @@ -189,31 +167,31 @@ def test_experiment_without_experiment_tag(self): name: {group: 'train', tag: 'loss'} } """ - ctxt = backend_context.Context(self._mock_tb_context) - actual_exp = ctxt.experiment() - _canonicalize_experiment(actual_exp) - self.assertProtoEquals(expected_exp, actual_exp) + ctxt = backend_context.Context(self._mock_tb_context) + actual_exp = ctxt.experiment() + _canonicalize_experiment(actual_exp) + self.assertProtoEquals(expected_exp, actual_exp) - def test_experiment_without_experiment_tag_different_hparam_types(self): - self.session_1_start_info_ = """ + def test_experiment_without_experiment_tag_different_hparam_types(self): + self.session_1_start_info_ = """ hparams:[ {key: 'batch_size' value: {number_value: 100}}, {key: 'lr' value: {string_value: '0.01'}} ] """ - self.session_2_start_info_ = """ + self.session_2_start_info_ = """ hparams:[ {key: 'lr' value: {number_value: 0.02}}, {key: 'model_type' value: {string_value: 'LATTICE'}} ] """ - self.session_3_start_info_ = """ + self.session_3_start_info_ = """ hparams:[ {key: 'batch_size' value: {bool_value: true}}, {key: 'model_type' value: {string_value: 'CNN'}} ] """ - expected_exp = """ + expected_exp = """ hparam_infos: { name: 'batch_size' type: DATA_TYPE_STRING @@ -251,31 +229,31 @@ def test_experiment_without_experiment_tag_different_hparam_types(self): name: {group: 'train', tag: 'loss'} } """ - ctxt = backend_context.Context(self._mock_tb_context) - actual_exp = ctxt.experiment() - _canonicalize_experiment(actual_exp) - self.assertProtoEquals(expected_exp, actual_exp) + ctxt = backend_context.Context(self._mock_tb_context) + 
actual_exp = ctxt.experiment() + _canonicalize_experiment(actual_exp) + self.assertProtoEquals(expected_exp, actual_exp) - def test_experiment_without_experiment_tag_many_distinct_values(self): - self.session_1_start_info_ = """ + def test_experiment_without_experiment_tag_many_distinct_values(self): + self.session_1_start_info_ = """ hparams:[ {key: 'batch_size' value: {number_value: 100}}, {key: 'lr' value: {string_value: '0.01'}} ] """ - self.session_2_start_info_ = """ + self.session_2_start_info_ = """ hparams:[ {key: 'lr' value: {number_value: 0.02}}, {key: 'model_type' value: {string_value: 'CNN'}} ] """ - self.session_3_start_info_ = """ + self.session_3_start_info_ = """ hparams:[ {key: 'batch_size' value: {bool_value: true}}, {key: 'model_type' value: {string_value: 'CNN'}} ] """ - expected_exp = """ + expected_exp = """ hparam_infos: { name: 'batch_size' type: DATA_TYPE_STRING @@ -304,33 +282,37 @@ def test_experiment_without_experiment_tag_many_distinct_values(self): name: {group: 'train', tag: 'loss'} } """ - ctxt = backend_context.Context(self._mock_tb_context, - max_domain_discrete_len=1) - actual_exp = ctxt.experiment() - _canonicalize_experiment(actual_exp) - self.assertProtoEquals(expected_exp, actual_exp) + ctxt = backend_context.Context( + self._mock_tb_context, max_domain_discrete_len=1 + ) + actual_exp = ctxt.experiment() + _canonicalize_experiment(actual_exp) + self.assertProtoEquals(expected_exp, actual_exp) - def _serialized_plugin_data(self, data_oneof_field, text_protobuffer): - oneof_type_dict = { - DATA_TYPE_EXPERIMENT: api_pb2.Experiment, - DATA_TYPE_SESSION_START_INFO: plugin_data_pb2.SessionStartInfo, - DATA_TYPE_SESSION_END_INFO: plugin_data_pb2.SessionEndInfo - } - protobuffer = text_format.Merge(text_protobuffer, - oneof_type_dict[data_oneof_field]()) - plugin_data = plugin_data_pb2.HParamsPluginData() - getattr(plugin_data, data_oneof_field).CopyFrom(protobuffer) - return 
metadata.create_summary_metadata(plugin_data).plugin_data.content + def _serialized_plugin_data(self, data_oneof_field, text_protobuffer): + oneof_type_dict = { + DATA_TYPE_EXPERIMENT: api_pb2.Experiment, + DATA_TYPE_SESSION_START_INFO: plugin_data_pb2.SessionStartInfo, + DATA_TYPE_SESSION_END_INFO: plugin_data_pb2.SessionEndInfo, + } + protobuffer = text_format.Merge( + text_protobuffer, oneof_type_dict[data_oneof_field]() + ) + plugin_data = plugin_data_pb2.HParamsPluginData() + getattr(plugin_data, data_oneof_field).CopyFrom(protobuffer) + return metadata.create_summary_metadata(plugin_data).plugin_data.content def _canonicalize_experiment(exp): - """Sorts the repeated fields of an Experiment message.""" - exp.hparam_infos.sort(key=operator.attrgetter('name')) - exp.metric_infos.sort(key=operator.attrgetter('name.group', 'name.tag')) - for hparam_info in exp.hparam_infos: - if hparam_info.HasField('domain_discrete'): - hparam_info.domain_discrete.values.sort( - key=operator.attrgetter('string_value')) + """Sorts the repeated fields of an Experiment message.""" + exp.hparam_infos.sort(key=operator.attrgetter("name")) + exp.metric_infos.sort(key=operator.attrgetter("name.group", "name.tag")) + for hparam_info in exp.hparam_infos: + if hparam_info.HasField("domain_discrete"): + hparam_info.domain_discrete.values.sort( + key=operator.attrgetter("string_value") + ) + -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/hparams/error.py b/tensorboard/plugins/hparams/error.py index b88446a825..02c7b89658 100644 --- a/tensorboard/plugins/hparams/error.py +++ b/tensorboard/plugins/hparams/error.py @@ -16,9 +16,11 @@ class HParamsError(Exception): - """Represents an error that is meaningful to the end-user. Such an error - should have a meaningful error message. Other errors, (such as resulting - from some internal invariants being violated) should be represented by - other exceptions. 
- """ - pass + """Represents an error that is meaningful to the end-user. + + Such an error should have a meaningful error message. Other errors, + (such as resulting from some internal invariants being violated) + should be represented by other exceptions. + """ + + pass diff --git a/tensorboard/plugins/hparams/get_experiment.py b/tensorboard/plugins/hparams/get_experiment.py index b690111d99..91dbf29f52 100644 --- a/tensorboard/plugins/hparams/get_experiment.py +++ b/tensorboard/plugins/hparams/get_experiment.py @@ -23,30 +23,30 @@ class Handler(object): - """Handles a GetExperiment request. """ - - def __init__(self, context): - """Constructor. - - Args: - context: A backend_context.Context instance. - """ - self._context = context - - def run(self): - """Handles the request specified on construction. - - Returns: - An Experiment object. - - """ - experiment = self._context.experiment() - if experiment is None: - raise error.HParamsError( - "Can't find an HParams-plugin experiment data in" - " the log directory. Note that it takes some time to" - " scan the log directory; if you just started" - " Tensorboard it could be that we haven't finished" - " scanning it yet. Consider trying again in a" - " few seconds.") - return experiment + """Handles a GetExperiment request.""" + + def __init__(self, context): + """Constructor. + + Args: + context: A backend_context.Context instance. + """ + self._context = context + + def run(self): + """Handles the request specified on construction. + + Returns: + An Experiment object. + """ + experiment = self._context.experiment() + if experiment is None: + raise error.HParamsError( + "Can't find an HParams-plugin experiment data in" + " the log directory. Note that it takes some time to" + " scan the log directory; if you just started" + " Tensorboard it could be that we haven't finished" + " scanning it yet. Consider trying again in a" + " few seconds." 
+ ) + return experiment diff --git a/tensorboard/plugins/hparams/hparams_demo.py b/tensorboard/plugins/hparams/hparams_demo.py index ac4e762bcb..0d669f33d4 100644 --- a/tensorboard/plugins/hparams/hparams_demo.py +++ b/tensorboard/plugins/hparams/hparams_demo.py @@ -39,10 +39,10 @@ if int(tf.__version__.split(".")[0]) < 2: - # The tag names emitted for Keras metrics changed from "acc" (in 1.x) - # to "accuracy" (in 2.x), so this demo does not work properly in - # TensorFlow 1.x (even with `tf.enable_eager_execution()`). - raise ImportError("TensorFlow 2.x is required to run this demo.") + # The tag names emitted for Keras metrics changed from "acc" (in 1.x) + # to "accuracy" (in 2.x), so this demo does not work properly in + # TensorFlow 1.x (even with `tf.enable_eager_execution()`). + raise ImportError("TensorFlow 2.x is required to run this demo.") flags.DEFINE_integer( @@ -59,12 +59,10 @@ "summary_freq", 600, "Summaries will be written every n steps, where n is the value of " - "this flag.", + "this flag.", ) flags.DEFINE_integer( - "num_epochs", - 5, - "Number of epochs per trial.", + "num_epochs", 5, "Number of epochs per trial.", ) @@ -89,164 +87,154 @@ METRICS = [ hp.Metric( - "epoch_accuracy", - group="validation", - display_name="accuracy (val.)", + "epoch_accuracy", group="validation", display_name="accuracy (val.)", ), + hp.Metric("epoch_loss", group="validation", display_name="loss (val.)",), hp.Metric( - "epoch_loss", - group="validation", - display_name="loss (val.)", - ), - hp.Metric( - "batch_accuracy", - group="train", - display_name="accuracy (train)", - ), - hp.Metric( - "batch_loss", - group="train", - display_name="loss (train)", + "batch_accuracy", group="train", display_name="accuracy (train)", ), + hp.Metric("batch_loss", group="train", display_name="loss (train)",), ] def model_fn(hparams, seed): - """Create a Keras model with the given hyperparameters. - - Args: - hparams: A dict mapping hyperparameters in `HPARAMS` to values. 
- seed: A hashable object to be used as a random seed (e.g., to - construct dropout layers in the model). - - Returns: - A compiled Keras model. - """ - rng = random.Random(seed) - - model = tf.keras.models.Sequential() - model.add(tf.keras.layers.Input(INPUT_SHAPE)) - model.add(tf.keras.layers.Reshape(INPUT_SHAPE + (1,))) # grayscale channel - - # Add convolutional layers. - conv_filters = 8 - for _ in xrange(hparams[HP_CONV_LAYERS]): - model.add(tf.keras.layers.Conv2D( - filters=conv_filters, - kernel_size=hparams[HP_CONV_KERNEL_SIZE], - padding="same", - activation="relu", - )) - model.add(tf.keras.layers.MaxPool2D(pool_size=2, padding="same")) - conv_filters *= 2 - - model.add(tf.keras.layers.Flatten()) - model.add(tf.keras.layers.Dropout(hparams[HP_DROPOUT], seed=rng.random())) - - # Add fully connected layers. - dense_neurons = 32 - for _ in xrange(hparams[HP_DENSE_LAYERS]): - model.add(tf.keras.layers.Dense(dense_neurons, activation="relu")) - dense_neurons *= 2 - - # Add the final output layer. - model.add(tf.keras.layers.Dense(OUTPUT_CLASSES, activation="softmax")) - - model.compile( - loss="sparse_categorical_crossentropy", - optimizer=hparams[HP_OPTIMIZER], - metrics=["accuracy"], - ) - return model + """Create a Keras model with the given hyperparameters. + + Args: + hparams: A dict mapping hyperparameters in `HPARAMS` to values. + seed: A hashable object to be used as a random seed (e.g., to + construct dropout layers in the model). + + Returns: + A compiled Keras model. + """ + rng = random.Random(seed) + + model = tf.keras.models.Sequential() + model.add(tf.keras.layers.Input(INPUT_SHAPE)) + model.add(tf.keras.layers.Reshape(INPUT_SHAPE + (1,))) # grayscale channel + + # Add convolutional layers. 
+ conv_filters = 8 + for _ in xrange(hparams[HP_CONV_LAYERS]): + model.add( + tf.keras.layers.Conv2D( + filters=conv_filters, + kernel_size=hparams[HP_CONV_KERNEL_SIZE], + padding="same", + activation="relu", + ) + ) + model.add(tf.keras.layers.MaxPool2D(pool_size=2, padding="same")) + conv_filters *= 2 + + model.add(tf.keras.layers.Flatten()) + model.add(tf.keras.layers.Dropout(hparams[HP_DROPOUT], seed=rng.random())) + + # Add fully connected layers. + dense_neurons = 32 + for _ in xrange(hparams[HP_DENSE_LAYERS]): + model.add(tf.keras.layers.Dense(dense_neurons, activation="relu")) + dense_neurons *= 2 + + # Add the final output layer. + model.add(tf.keras.layers.Dense(OUTPUT_CLASSES, activation="softmax")) + + model.compile( + loss="sparse_categorical_crossentropy", + optimizer=hparams[HP_OPTIMIZER], + metrics=["accuracy"], + ) + return model def run(data, base_logdir, session_id, hparams): - """Run a training/validation session. - - Flags must have been parsed for this function to behave. - - Args: - data: The data as loaded by `prepare_data()`. - base_logdir: The top-level logdir to which to write summary data. - session_id: A unique string ID for this session. - hparams: A dict mapping hyperparameters in `HPARAMS` to values. - """ - model = model_fn(hparams=hparams, seed=session_id) - logdir = os.path.join(base_logdir, session_id) - - callback = tf.keras.callbacks.TensorBoard( - logdir, - update_freq=flags.FLAGS.summary_freq, - profile_batch=0, # workaround for issue #2084 - ) - hparams_callback = hp.KerasCallback(logdir, hparams) - ((x_train, y_train), (x_test, y_test)) = data - result = model.fit( - x=x_train, - y=y_train, - epochs=flags.FLAGS.num_epochs, - shuffle=False, - validation_data=(x_test, y_test), - callbacks=[callback, hparams_callback], - ) + """Run a training/validation session. + + Flags must have been parsed for this function to behave. + + Args: + data: The data as loaded by `prepare_data()`. 
+ base_logdir: The top-level logdir to which to write summary data. + session_id: A unique string ID for this session. + hparams: A dict mapping hyperparameters in `HPARAMS` to values. + """ + model = model_fn(hparams=hparams, seed=session_id) + logdir = os.path.join(base_logdir, session_id) + + callback = tf.keras.callbacks.TensorBoard( + logdir, + update_freq=flags.FLAGS.summary_freq, + profile_batch=0, # workaround for issue #2084 + ) + hparams_callback = hp.KerasCallback(logdir, hparams) + ((x_train, y_train), (x_test, y_test)) = data + result = model.fit( + x=x_train, + y=y_train, + epochs=flags.FLAGS.num_epochs, + shuffle=False, + validation_data=(x_test, y_test), + callbacks=[callback, hparams_callback], + ) def prepare_data(): - """Load and normalize data.""" - ((x_train, y_train), (x_test, y_test)) = DATASET.load_data() - x_train = x_train.astype("float32") - x_test = x_test.astype("float32") - x_train /= 255.0 - x_test /= 255.0 - return ((x_train, y_train), (x_test, y_test)) + """Load and normalize data.""" + ((x_train, y_train), (x_test, y_test)) = DATASET.load_data() + x_train = x_train.astype("float32") + x_test = x_test.astype("float32") + x_train /= 255.0 + x_test /= 255.0 + return ((x_train, y_train), (x_test, y_test)) def run_all(logdir, verbose=False): - """Perform random search over the hyperparameter space. - - Arguments: - logdir: The top-level directory into which to write data. This - directory should be empty or nonexistent. - verbose: If true, print out each run's name as it begins. 
- """ - data = prepare_data() - rng = random.Random(0) - - with tf.summary.create_file_writer(logdir).as_default(): - hp.hparams_config(hparams=HPARAMS, metrics=METRICS) - - sessions_per_group = 2 - num_sessions = flags.FLAGS.num_session_groups * sessions_per_group - session_index = 0 # across all session groups - for group_index in xrange(flags.FLAGS.num_session_groups): - hparams = {h: h.domain.sample_uniform(rng) for h in HPARAMS} - hparams_string = str(hparams) - for repeat_index in xrange(sessions_per_group): - session_id = str(session_index) - session_index += 1 - if verbose: - print( - "--- Running training session %d/%d" - % (session_index, num_sessions) - ) - print(hparams_string) - print("--- repeat #: %d" % (repeat_index + 1)) - run( - data=data, - base_logdir=logdir, - session_id=session_id, - hparams=hparams, - ) + """Perform random search over the hyperparameter space. + + Arguments: + logdir: The top-level directory into which to write data. This + directory should be empty or nonexistent. + verbose: If true, print out each run's name as it begins. 
+ """ + data = prepare_data() + rng = random.Random(0) + + with tf.summary.create_file_writer(logdir).as_default(): + hp.hparams_config(hparams=HPARAMS, metrics=METRICS) + + sessions_per_group = 2 + num_sessions = flags.FLAGS.num_session_groups * sessions_per_group + session_index = 0 # across all session groups + for group_index in xrange(flags.FLAGS.num_session_groups): + hparams = {h: h.domain.sample_uniform(rng) for h in HPARAMS} + hparams_string = str(hparams) + for repeat_index in xrange(sessions_per_group): + session_id = str(session_index) + session_index += 1 + if verbose: + print( + "--- Running training session %d/%d" + % (session_index, num_sessions) + ) + print(hparams_string) + print("--- repeat #: %d" % (repeat_index + 1)) + run( + data=data, + base_logdir=logdir, + session_id=session_id, + hparams=hparams, + ) def main(unused_argv): - np.random.seed(0) - logdir = flags.FLAGS.logdir - shutil.rmtree(logdir, ignore_errors=True) - print("Saving output to %s." % logdir) - run_all(logdir=logdir, verbose=True) - print("Done. Output saved to %s." % logdir) + np.random.seed(0) + logdir = flags.FLAGS.logdir + shutil.rmtree(logdir, ignore_errors=True) + print("Saving output to %s." % logdir) + run_all(logdir=logdir, verbose=True) + print("Done. Output saved to %s." 
% logdir) if __name__ == "__main__": - app.run(main) + app.run(main) diff --git a/tensorboard/plugins/hparams/hparams_minimal_demo.py b/tensorboard/plugins/hparams/hparams_minimal_demo.py index ffd648d82d..6fc8e640ba 100644 --- a/tensorboard/plugins/hparams/hparams_minimal_demo.py +++ b/tensorboard/plugins/hparams/hparams_minimal_demo.py @@ -55,23 +55,27 @@ FLAGS = flags.FLAGS -flags.DEFINE_integer('num_session_groups', 50, - 'The approximate number of session groups to create.') -flags.DEFINE_string('logdir', '/tmp/hparams_minimal_demo', - 'The directory to write the summary information to.') -flags.DEFINE_integer('summary_freq', 1, - 'Summaries will be every n steps, ' - 'where n is the value of this flag.') -flags.DEFINE_integer('num_steps', 100, - 'Number of steps per trial.') +flags.DEFINE_integer( + "num_session_groups", + 50, + "The approximate number of session groups to create.", +) +flags.DEFINE_string( + "logdir", + "/tmp/hparams_minimal_demo", + "The directory to write the summary information to.", +) +flags.DEFINE_integer( + "summary_freq", + 1, + "Summaries will be every n steps, " "where n is the value of this flag.", +) +flags.DEFINE_integer("num_steps", 100, "Number of steps per trial.") # Total number of sessions is given by: # len(TEMPERATURE_LIST)^2 * len(HEAT_COEFFICIENTS) * 2 -HEAT_COEFFICIENTS = { - 'water': 0.001, - 'air': 0.003 -} +HEAT_COEFFICIENTS = {"water": 0.001, "air": 0.003} TEMPERATURE_LIST = [] @@ -79,191 +83,223 @@ # depends on a flag and flag parsing hasn't happened yet. Instead, we use # a function that we call in main() below. 
def init_temperature_list(): - global TEMPERATURE_LIST - TEMPERATURE_LIST = [ - 270+i*50.0 - for i in xrange( - 0, int(math.sqrt(FLAGS.num_session_groups/len(HEAT_COEFFICIENTS))))] + global TEMPERATURE_LIST + TEMPERATURE_LIST = [ + 270 + i * 50.0 + for i in xrange( + 0, int(math.sqrt(FLAGS.num_session_groups / len(HEAT_COEFFICIENTS))) + ) + ] def fingerprint(string): - m = hashlib.md5() - m.update(string.encode('utf-8')) - return m.hexdigest() + m = hashlib.md5() + m.update(string.encode("utf-8")) + return m.hexdigest() def create_experiment_summary(): - """Returns a summary proto buffer holding this experiment.""" - - # Convert TEMPERATURE_LIST to google.protobuf.ListValue - temperature_list = struct_pb2.ListValue() - temperature_list.extend(TEMPERATURE_LIST) - materials = struct_pb2.ListValue() - materials.extend(HEAT_COEFFICIENTS.keys()) - return summary.experiment_pb( - hparam_infos=[ - api_pb2.HParamInfo(name='initial_temperature', - display_name='Initial temperature', - type=api_pb2.DATA_TYPE_FLOAT64, - domain_discrete=temperature_list), - api_pb2.HParamInfo(name='ambient_temperature', - display_name='Ambient temperature', - type=api_pb2.DATA_TYPE_FLOAT64, - domain_discrete=temperature_list), - api_pb2.HParamInfo(name='material', - display_name='Material', - type=api_pb2.DATA_TYPE_STRING, - domain_discrete=materials) - ], - metric_infos=[ - api_pb2.MetricInfo( - name=api_pb2.MetricName( - tag='temperature/current/scalar_summary'), - display_name='Current Temp.'), - api_pb2.MetricInfo( - name=api_pb2.MetricName( - tag='temperature/difference_to_ambient/scalar_summary'), - display_name='Difference To Ambient Temp.'), - api_pb2.MetricInfo( - name=api_pb2.MetricName( - tag='delta/scalar_summary'), - display_name='Delta T') - ] - ) + """Returns a summary proto buffer holding this experiment.""" + + # Convert TEMPERATURE_LIST to google.protobuf.ListValue + temperature_list = struct_pb2.ListValue() + temperature_list.extend(TEMPERATURE_LIST) + materials = 
struct_pb2.ListValue() + materials.extend(HEAT_COEFFICIENTS.keys()) + return summary.experiment_pb( + hparam_infos=[ + api_pb2.HParamInfo( + name="initial_temperature", + display_name="Initial temperature", + type=api_pb2.DATA_TYPE_FLOAT64, + domain_discrete=temperature_list, + ), + api_pb2.HParamInfo( + name="ambient_temperature", + display_name="Ambient temperature", + type=api_pb2.DATA_TYPE_FLOAT64, + domain_discrete=temperature_list, + ), + api_pb2.HParamInfo( + name="material", + display_name="Material", + type=api_pb2.DATA_TYPE_STRING, + domain_discrete=materials, + ), + ], + metric_infos=[ + api_pb2.MetricInfo( + name=api_pb2.MetricName( + tag="temperature/current/scalar_summary" + ), + display_name="Current Temp.", + ), + api_pb2.MetricInfo( + name=api_pb2.MetricName( + tag="temperature/difference_to_ambient/scalar_summary" + ), + display_name="Difference To Ambient Temp.", + ), + api_pb2.MetricInfo( + name=api_pb2.MetricName(tag="delta/scalar_summary"), + display_name="Delta T", + ), + ], + ) def run(logdir, session_id, hparams, group_name): - """Runs a temperature simulation. - - This will simulate an object at temperature `initial_temperature` - sitting at rest in a large room at temperature `ambient_temperature`. - The object has some intrinsic `heat_coefficient`, which indicates - how much thermal conductivity it has: for instance, metals have high - thermal conductivity, while the thermal conductivity of water is low. - - Over time, the object's temperature will adjust to match the - temperature of its environment. We'll track the object's temperature, - how far it is from the room's temperature, and how much it changes at - each time step. - - Arguments: - logdir: the top-level directory into which to write summary data - session_id: an id for the session. - hparams: A dictionary mapping a hyperparameter name to its value. - group_name: an id for the session group this session belongs to. 
- """ - tf.reset_default_graph() - tf.set_random_seed(0) - - initial_temperature = hparams['initial_temperature'] - ambient_temperature = hparams['ambient_temperature'] - heat_coefficient = HEAT_COEFFICIENTS[hparams['material']] - session_dir = os.path.join(logdir, session_id) - writer = tf.summary.FileWriter(session_dir) - writer.add_summary(summary.session_start_pb(hparams=hparams, - group_name=group_name)) - writer.flush() - with tf.name_scope('temperature'): - # Create a mutable variable to hold the object's temperature, and - # create a scalar summary to track its value over time. The name of - # the summary will appear as 'temperature/current' due to the - # name-scope above. - temperature = tf.Variable( - tf.constant(initial_temperature), - name='temperature') - scalar_summary.op('current', temperature, - display_name='Temperature', - description='The temperature of the object under ' - 'simulation, in Kelvins.') - - # Compute how much the object's temperature differs from that of its - # environment, and track this, too: likewise, as - # 'temperature/difference_to_ambient'. - ambient_difference = temperature - ambient_temperature - scalar_summary.op('difference_to_ambient', ambient_difference, - display_name='Difference to ambient temperature', - description=('The difference between the ambient ' - 'temperature and the temperature of the ' - 'object under simulation, in Kelvins.')) - - # Newton suggested that the rate of change of the temperature of an - # object is directly proportional to this `ambient_difference` above, - # where the proportionality constant is what we called the heat - # coefficient. But in real life, not everything is quite so clean, so - # we'll add in some noise. (The value of 50 is arbitrary, chosen to - # make the data look somewhat interesting. 
:-) ) - noise = 50 * tf.random.normal([]) - delta = -heat_coefficient * (ambient_difference + noise) - scalar_summary.op('delta', delta, - description='The change in temperature from the previous ' - 'step, in Kelvins.') - - # Collect all the scalars that we want to keep track of. - summ = tf.summary.merge_all() - - # Now, augment the current temperature by this delta that we computed, - # blocking the assignment on summary collection to avoid race conditions - # and ensure that the summary always reports the pre-update value. - with tf.control_dependencies([summ]): - update_step = temperature.assign_add(delta) - - sess = tf.Session() - sess.run(tf.global_variables_initializer()) - for step in xrange(FLAGS.num_steps): - # By asking TensorFlow to compute the update step, we force it to - # change the value of the temperature variable. We don't actually - # care about this value, so we discard it; instead, we grab the - # summary data computed along the way. - (s, _) = sess.run([summ, update_step]) - if (step % FLAGS.summary_freq) == 0: - writer.add_summary(s, global_step=step) - writer.add_summary(summary.session_end_pb(api_pb2.STATUS_SUCCESS)) - writer.close() + """Runs a temperature simulation. + + This will simulate an object at temperature `initial_temperature` + sitting at rest in a large room at temperature `ambient_temperature`. + The object has some intrinsic `heat_coefficient`, which indicates + how much thermal conductivity it has: for instance, metals have high + thermal conductivity, while the thermal conductivity of water is low. + + Over time, the object's temperature will adjust to match the + temperature of its environment. We'll track the object's temperature, + how far it is from the room's temperature, and how much it changes at + each time step. + + Arguments: + logdir: the top-level directory into which to write summary data + session_id: an id for the session. + hparams: A dictionary mapping a hyperparameter name to its value. 
+ group_name: an id for the session group this session belongs to. + """ + tf.reset_default_graph() + tf.set_random_seed(0) + + initial_temperature = hparams["initial_temperature"] + ambient_temperature = hparams["ambient_temperature"] + heat_coefficient = HEAT_COEFFICIENTS[hparams["material"]] + session_dir = os.path.join(logdir, session_id) + writer = tf.summary.FileWriter(session_dir) + writer.add_summary( + summary.session_start_pb(hparams=hparams, group_name=group_name) + ) + writer.flush() + with tf.name_scope("temperature"): + # Create a mutable variable to hold the object's temperature, and + # create a scalar summary to track its value over time. The name of + # the summary will appear as 'temperature/current' due to the + # name-scope above. + temperature = tf.Variable( + tf.constant(initial_temperature), name="temperature" + ) + scalar_summary.op( + "current", + temperature, + display_name="Temperature", + description="The temperature of the object under " + "simulation, in Kelvins.", + ) + + # Compute how much the object's temperature differs from that of its + # environment, and track this, too: likewise, as + # 'temperature/difference_to_ambient'. + ambient_difference = temperature - ambient_temperature + scalar_summary.op( + "difference_to_ambient", + ambient_difference, + display_name="Difference to ambient temperature", + description=( + "The difference between the ambient " + "temperature and the temperature of the " + "object under simulation, in Kelvins." + ), + ) + + # Newton suggested that the rate of change of the temperature of an + # object is directly proportional to this `ambient_difference` above, + # where the proportionality constant is what we called the heat + # coefficient. But in real life, not everything is quite so clean, so + # we'll add in some noise. (The value of 50 is arbitrary, chosen to + # make the data look somewhat interesting. 
:-) ) + noise = 50 * tf.random.normal([]) + delta = -heat_coefficient * (ambient_difference + noise) + scalar_summary.op( + "delta", + delta, + description="The change in temperature from the previous " + "step, in Kelvins.", + ) + + # Collect all the scalars that we want to keep track of. + summ = tf.summary.merge_all() + + # Now, augment the current temperature by this delta that we computed, + # blocking the assignment on summary collection to avoid race conditions + # and ensure that the summary always reports the pre-update value. + with tf.control_dependencies([summ]): + update_step = temperature.assign_add(delta) + + sess = tf.Session() + sess.run(tf.global_variables_initializer()) + for step in xrange(FLAGS.num_steps): + # By asking TensorFlow to compute the update step, we force it to + # change the value of the temperature variable. We don't actually + # care about this value, so we discard it; instead, we grab the + # summary data computed along the way. + (s, _) = sess.run([summ, update_step]) + if (step % FLAGS.summary_freq) == 0: + writer.add_summary(s, global_step=step) + writer.add_summary(summary.session_end_pb(api_pb2.STATUS_SUCCESS)) + writer.close() def run_all(logdir, verbose=False): - """Run simulations on a reasonable set of parameters. - - Arguments: - logdir: the directory into which to store all the runs' data - verbose: if true, print out each run's name as it begins. 
- """ - writer = tf.summary.FileWriter(logdir) - writer.add_summary(create_experiment_summary()) - writer.close() - session_num = 0 - num_sessions = (len(TEMPERATURE_LIST)*len(TEMPERATURE_LIST)* - len(HEAT_COEFFICIENTS)*2) - for initial_temperature in TEMPERATURE_LIST: - for ambient_temperature in TEMPERATURE_LIST: - for material in HEAT_COEFFICIENTS: - hparams = {u'initial_temperature': initial_temperature, - u'ambient_temperature': ambient_temperature, - u'material': material} - hparam_str = str(hparams) - group_name = fingerprint(hparam_str) - for repeat_idx in xrange(2): - session_id = str(session_num) - if verbose: - print('--- Running training session %d/%d' % (session_num + 1, - num_sessions)) - print(hparam_str) - print('--- repeat #: %d' % (repeat_idx+1)) - run(logdir, session_id, hparams, group_name) - session_num += 1 + """Run simulations on a reasonable set of parameters. + + Arguments: + logdir: the directory into which to store all the runs' data + verbose: if true, print out each run's name as it begins. 
+ """ + writer = tf.summary.FileWriter(logdir) + writer.add_summary(create_experiment_summary()) + writer.close() + session_num = 0 + num_sessions = ( + len(TEMPERATURE_LIST) + * len(TEMPERATURE_LIST) + * len(HEAT_COEFFICIENTS) + * 2 + ) + for initial_temperature in TEMPERATURE_LIST: + for ambient_temperature in TEMPERATURE_LIST: + for material in HEAT_COEFFICIENTS: + hparams = { + u"initial_temperature": initial_temperature, + u"ambient_temperature": ambient_temperature, + u"material": material, + } + hparam_str = str(hparams) + group_name = fingerprint(hparam_str) + for repeat_idx in xrange(2): + session_id = str(session_num) + if verbose: + print( + "--- Running training session %d/%d" + % (session_num + 1, num_sessions) + ) + print(hparam_str) + print("--- repeat #: %d" % (repeat_idx + 1)) + run(logdir, session_id, hparams, group_name) + session_num += 1 def main(unused_argv): - if tf.executing_eagerly(): - print('Sorry, this demo currently can\'t be run in eager mode.') - return + if tf.executing_eagerly(): + print("Sorry, this demo currently can't be run in eager mode.") + return - init_temperature_list() - shutil.rmtree(FLAGS.logdir, ignore_errors=True) - print('Saving output to %s.' % FLAGS.logdir) - run_all(FLAGS.logdir, verbose=True) - print('Done. Output saved to %s.' % FLAGS.logdir) + init_temperature_list() + shutil.rmtree(FLAGS.logdir, ignore_errors=True) + print("Saving output to %s." % FLAGS.logdir) + run_all(FLAGS.logdir, verbose=True) + print("Done. Output saved to %s." % FLAGS.logdir) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/tensorboard/plugins/hparams/hparams_plugin.py b/tensorboard/plugins/hparams/hparams_plugin.py index a04e2e885b..89ca26f3bb 100644 --- a/tensorboard/plugins/hparams/hparams_plugin.py +++ b/tensorboard/plugins/hparams/hparams_plugin.py @@ -14,8 +14,8 @@ # ============================================================================== """The TensorBoard HParams plugin. 
-See `http_api.md` in this directory for specifications of the routes for this -plugin. +See `http_api.md` in this directory for specifications of the routes for +this plugin. """ from __future__ import absolute_import @@ -45,119 +45,138 @@ class HParamsPlugin(base_plugin.TBPlugin): - """HParams Plugin for TensorBoard. - It supports both GETs and POSTs. See 'http_api.md' for more details. - """ + """HParams Plugin for TensorBoard. - plugin_name = metadata.PLUGIN_NAME - - def __init__(self, context): - """Instantiates HParams plugin via TensorBoard core. - - Args: - context: A base_plugin.TBContext instance. + It supports both GETs and POSTs. See 'http_api.md' for more details. """ - self._context = backend_context.Context(context) - - def get_plugin_apps(self): - """See base class.""" - - return { - '/experiment': self.get_experiment_route, - '/session_groups': self.list_session_groups_route, - '/metric_evals': self.list_metric_evals_route, - } - def is_active(self): - """Returns True if the hparams plugin is active. - - The hparams plugin is active iff there is a tag with - the hparams plugin name as its plugin name and the scalars plugin is - registered and active. - """ - if not self._context.multiplexer: - return False - scalars_plugin = self._get_scalars_plugin() - if not scalars_plugin or not scalars_plugin.is_active(): - return False - return bool(self._context.multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME)) - - def frontend_metadata(self): - return base_plugin.FrontendMetadata(element_name='tf-hparams-dashboard') - - # ---- /experiment ----------------------------------------------------------- - @wrappers.Request.application - def get_experiment_route(self, request): - try: - # This backend currently ignores the request parameters, but (for a POST) - # we must advance the input stream to skip them -- otherwise the next HTTP - # request will be parsed incorrectly. 
- _ = _parse_request_argument(request, api_pb2.GetExperimentRequest) - return http_util.Respond( - request, - json_format.MessageToJson( - get_experiment.Handler(self._context).run(), - including_default_value_fields=True, - ), 'application/json') - except error.HParamsError as e: - logger.error('HParams error: %s' % e) - raise werkzeug.exceptions.BadRequest(description=str(e)) - - # ---- /session_groups ------------------------------------------------------- - @wrappers.Request.application - def list_session_groups_route(self, request): - try: - request_proto = _parse_request_argument( - request, api_pb2.ListSessionGroupsRequest) - return http_util.Respond( - request, - json_format.MessageToJson( - list_session_groups.Handler(self._context, request_proto).run(), - including_default_value_fields=True, - ), - 'application/json') - except error.HParamsError as e: - logger.error('HParams error: %s' % e) - raise werkzeug.exceptions.BadRequest(description=str(e)) - - # ---- /metric_evals --------------------------------------------------------- - @wrappers.Request.application - def list_metric_evals_route(self, request): - try: - request_proto = _parse_request_argument( - request, api_pb2.ListMetricEvalsRequest) - scalars_plugin = self._get_scalars_plugin() - if not scalars_plugin: - raise error.HParamsError('Internal error: the scalars plugin is not' - ' registered; yet, the hparams plugin is' - ' active.') - return http_util.Respond( - request, - json.dumps( - list_metric_evals.Handler(request_proto, scalars_plugin).run()), - 'application/json') - except error.HParamsError as e: - logger.error('HParams error: %s' % e) - raise werkzeug.exceptions.BadRequest(description=str(e)) - - def _get_scalars_plugin(self): - """Tries to get the scalars plugin. - - Returns: - The scalars plugin or None if it is not yet registered. 
- """ - return self._context.tb_context.plugin_name_to_instance.get( - scalars_metadata.PLUGIN_NAME) + plugin_name = metadata.PLUGIN_NAME + + def __init__(self, context): + """Instantiates HParams plugin via TensorBoard core. + + Args: + context: A base_plugin.TBContext instance. + """ + self._context = backend_context.Context(context) + + def get_plugin_apps(self): + """See base class.""" + + return { + "/experiment": self.get_experiment_route, + "/session_groups": self.list_session_groups_route, + "/metric_evals": self.list_metric_evals_route, + } + + def is_active(self): + """Returns True if the hparams plugin is active. + + The hparams plugin is active iff there is a tag with the hparams + plugin name as its plugin name and the scalars plugin is + registered and active. + """ + if not self._context.multiplexer: + return False + scalars_plugin = self._get_scalars_plugin() + if not scalars_plugin or not scalars_plugin.is_active(): + return False + return bool( + self._context.multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME + ) + ) + + def frontend_metadata(self): + return base_plugin.FrontendMetadata(element_name="tf-hparams-dashboard") + + # ---- /experiment ----------------------------------------------------------- + @wrappers.Request.application + def get_experiment_route(self, request): + try: + # This backend currently ignores the request parameters, but (for a POST) + # we must advance the input stream to skip them -- otherwise the next HTTP + # request will be parsed incorrectly. 
+ _ = _parse_request_argument(request, api_pb2.GetExperimentRequest) + return http_util.Respond( + request, + json_format.MessageToJson( + get_experiment.Handler(self._context).run(), + including_default_value_fields=True, + ), + "application/json", + ) + except error.HParamsError as e: + logger.error("HParams error: %s" % e) + raise werkzeug.exceptions.BadRequest(description=str(e)) + + # ---- /session_groups ------------------------------------------------------- + @wrappers.Request.application + def list_session_groups_route(self, request): + try: + request_proto = _parse_request_argument( + request, api_pb2.ListSessionGroupsRequest + ) + return http_util.Respond( + request, + json_format.MessageToJson( + list_session_groups.Handler( + self._context, request_proto + ).run(), + including_default_value_fields=True, + ), + "application/json", + ) + except error.HParamsError as e: + logger.error("HParams error: %s" % e) + raise werkzeug.exceptions.BadRequest(description=str(e)) + + # ---- /metric_evals --------------------------------------------------------- + @wrappers.Request.application + def list_metric_evals_route(self, request): + try: + request_proto = _parse_request_argument( + request, api_pb2.ListMetricEvalsRequest + ) + scalars_plugin = self._get_scalars_plugin() + if not scalars_plugin: + raise error.HParamsError( + "Internal error: the scalars plugin is not" + " registered; yet, the hparams plugin is" + " active." + ) + return http_util.Respond( + request, + json.dumps( + list_metric_evals.Handler( + request_proto, scalars_plugin + ).run() + ), + "application/json", + ) + except error.HParamsError as e: + logger.error("HParams error: %s" % e) + raise werkzeug.exceptions.BadRequest(description=str(e)) + + def _get_scalars_plugin(self): + """Tries to get the scalars plugin. + + Returns: + The scalars plugin or None if it is not yet registered. 
+ """ + return self._context.tb_context.plugin_name_to_instance.get( + scalars_metadata.PLUGIN_NAME + ) def _parse_request_argument(request, proto_class): - if request.method == 'POST': - return json_format.Parse(request.data, proto_class()) - - # args.get() returns the request URI-unescaped. - request_json = request.args.get('request') - if request_json is None: - raise error.HParamsError( - 'Expected a JSON-formatted \'request\' arg of type: %s' % proto_class) - return json_format.Parse(request_json, proto_class()) + if request.method == "POST": + return json_format.Parse(request.data, proto_class()) + + # args.get() returns the request URI-unescaped. + request_json = request.args.get("request") + if request_json is None: + raise error.HParamsError( + "Expected a JSON-formatted 'request' arg of type: %s" % proto_class + ) + return json_format.Parse(request_json, proto_class()) diff --git a/tensorboard/plugins/hparams/hparams_util.py b/tensorboard/plugins/hparams/hparams_util.py index 7390254daf..0df4a83763 100644 --- a/tensorboard/plugins/hparams/hparams_util.py +++ b/tensorboard/plugins/hparams/hparams_util.py @@ -60,135 +60,185 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "action", "", "The action to perform. One of {'create_experiment'," - " 'start_session', 'end_session'}.") + "action", + "", + "The action to perform. One of {'create_experiment'," + " 'start_session', 'end_session'}.", +) flags.DEFINE_string("logdir", "", "The log directory to write the summary to.") # --action=create_experiment flags. flags.DEFINE_string( - "hparam_infos", "", "Only used when --action=create_experiment." + "hparam_infos", + "", + "Only used when --action=create_experiment." " A text-formatted HParamsInfoList protobuf describing" - " the hyperparameters used in the experiment.") + " the hyperparameters used in the experiment.", +) flags.DEFINE_string( - "metric_infos", "", "Only used when --action=create_experiment." 
+ "metric_infos", + "", + "Only used when --action=create_experiment." " A text-formatted MetricInfosList protobuf describing" - " the metrics used in the experiment.") + " the metrics used in the experiment.", +) flags.DEFINE_string( - "description", "", "(Optional) only used when --action=create_experiment." - " The description for the experiment.") + "description", + "", + "(Optional) only used when --action=create_experiment." + " The description for the experiment.", +) flags.DEFINE_string( - "user", getpass.getuser(), + "user", + getpass.getuser(), "(Optional) only used when --action=create_experiment." - " The name of the user creating the experiment.") + " The name of the user creating the experiment.", +) flags.DEFINE_float( - "time_created_secs", time.time(), + "time_created_secs", + time.time(), "(Optional) only used when --action=create_experiment." " The creation time of the experiment in seconds since" - " epoch.") + " epoch.", +) # --action=start_session flags. flags.DEFINE_string( - "hparams", "", "Only used when --action=start_session." + "hparams", + "", + "Only used when --action=start_session." " A text-formatted HParams protobuf describing" - " the hyperparameter values used in the session.") + " the hyperparameter values used in the session.", +) flags.DEFINE_string( - "model_uri", "", "(Optional) only used when --action=start_session." + "model_uri", + "", + "(Optional) only used when --action=start_session." " A uri describing the location where model checkpoints" - " are saved.") + " are saved.", +) flags.DEFINE_string( - "monitor_url", "", "(Optional) only used when --action=start_session." + "monitor_url", + "", + "(Optional) only used when --action=start_session." " A url for a webpage showing monitoring information on" - " the session job.") + " the session job.", +) flags.DEFINE_string( - "group_name", "", "(Optional) only used when --action=start_session." + "group_name", + "", + "(Optional) only used when --action=start_session." 
" The name of the group this session belongs to" " (empty group means the session is the only one in its" - " group).") + " group).", +) flags.DEFINE_float( - "start_time_secs", time.time(), + "start_time_secs", + time.time(), "(Optional) only used when --action=start_session." - " The time the session started in seconds since epoch.") + " The time the session started in seconds since epoch.", +) # --action=end_session flags. flags.DEFINE_string( - "status", "", "Only used when --action=end_session." + "status", + "", + "Only used when --action=end_session." " A string representation of a member of the Status enum." - " The status the session ended at.") + " The status the session ended at.", +) flags.DEFINE_float( - "end_time_secs", time.time(), + "end_time_secs", + time.time(), "(Optional) only used when --action=end_session." - " The time the session ended in seconds since epoch.") + " The time the session ended in seconds since epoch.", +) def main(argv): - del argv # Unused. - if FLAGS.action == "create_experiment": - create_experiment() - elif FLAGS.action == "start_session": - start_session() - elif FLAGS.action == "end_session": - end_session() - else: - raise ValueError("Invalid action requested: '%s'" % FLAGS.action) + del argv # Unused. 
+ if FLAGS.action == "create_experiment": + create_experiment() + elif FLAGS.action == "start_session": + start_session() + elif FLAGS.action == "end_session": + end_session() + else: + raise ValueError("Invalid action requested: '%s'" % FLAGS.action) def create_experiment(): - hparam_infos = hparams_util_pb2.HParamInfosList() - text_format.Merge(FLAGS.hparam_infos, hparam_infos) - metric_infos = hparams_util_pb2.MetricInfosList() - text_format.Merge(FLAGS.metric_infos, metric_infos) - write_summary( - summary.experiment_pb(hparam_infos.hparam_infos, - metric_infos.metric_infos, FLAGS.user, - FLAGS.description, FLAGS.time_created_secs)) + hparam_infos = hparams_util_pb2.HParamInfosList() + text_format.Merge(FLAGS.hparam_infos, hparam_infos) + metric_infos = hparams_util_pb2.MetricInfosList() + text_format.Merge(FLAGS.metric_infos, metric_infos) + write_summary( + summary.experiment_pb( + hparam_infos.hparam_infos, + metric_infos.metric_infos, + FLAGS.user, + FLAGS.description, + FLAGS.time_created_secs, + ) + ) def start_session(): - hparams = hparams_util_pb2.HParams() - text_format.Merge(FLAGS.hparams, hparams) - # Convert hparams.hparams values from google.protobuf.Value to Python native - # objects. - hparams = { - key: value_to_python(value) - for (key, value) in six.iteritems(hparams.hparams) - } - write_summary( - summary.session_start_pb(hparams, FLAGS.model_uri, FLAGS.monitor_url, - FLAGS.group_name, FLAGS.start_time_secs)) + hparams = hparams_util_pb2.HParams() + text_format.Merge(FLAGS.hparams, hparams) + # Convert hparams.hparams values from google.protobuf.Value to Python native + # objects. 
+ hparams = { + key: value_to_python(value) + for (key, value) in six.iteritems(hparams.hparams) + } + write_summary( + summary.session_start_pb( + hparams, + FLAGS.model_uri, + FLAGS.monitor_url, + FLAGS.group_name, + FLAGS.start_time_secs, + ) + ) def value_to_python(value): - """Converts a google.protobuf.Value to a native python object.""" + """Converts a google.protobuf.Value to a native python object.""" - # We use the ListValue Well Known Type Value-to-native Python conversion - # logic to avoid depending on value's protobuf representation. - l = struct_pb2.ListValue(values=[value]) - return l[0] + # We use the ListValue Well Known Type Value-to-native Python conversion + # logic to avoid depending on value's protobuf representation. + l = struct_pb2.ListValue(values=[value]) + return l[0] def end_session(): - write_summary( - summary.session_end_pb( - api_pb2.Status.Value(FLAGS.status), FLAGS.end_time_secs)) + write_summary( + summary.session_end_pb( + api_pb2.Status.Value(FLAGS.status), FLAGS.end_time_secs + ) + ) def write_summary(summary_pb): - tf.compat.v1.enable_eager_execution() - writer = tf.compat.v2.summary.create_file_writer(FLAGS.logdir) - with writer.as_default(): - if hasattr(tf.compat.v2.summary.experimental, "write_raw_pb"): - tf.compat.v2.summary.experimental.write_raw_pb( - summary_pb.SerializeToString(), step=0) - else: - # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove the - # fallback to import_event(). - event = tf.compat.v1.Event(summary=summary_pb) - tf.compat.v2.summary.import_event( - tf.constant(event.SerializeToString(), dtype=tf.string)) - # The following may not be required since the context manager may - # already flush on __exit__, but it doesn't hurt to do it here, as well. 
- tf.compat.v2.summary.flush() + tf.compat.v1.enable_eager_execution() + writer = tf.compat.v2.summary.create_file_writer(FLAGS.logdir) + with writer.as_default(): + if hasattr(tf.compat.v2.summary.experimental, "write_raw_pb"): + tf.compat.v2.summary.experimental.write_raw_pb( + summary_pb.SerializeToString(), step=0 + ) + else: + # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove the + # fallback to import_event(). + event = tf.compat.v1.Event(summary=summary_pb) + tf.compat.v2.summary.import_event( + tf.constant(event.SerializeToString(), dtype=tf.string) + ) + # The following may not be required since the context manager may + # already flush on __exit__, but it doesn't hurt to do it here, as well. + tf.compat.v2.summary.flush() if __name__ == "__main__": - app.run(main) + app.run(main) diff --git a/tensorboard/plugins/hparams/keras.py b/tensorboard/plugins/hparams/keras.py index 1fce3bd0f7..889d0b7c77 100644 --- a/tensorboard/plugins/hparams/keras.py +++ b/tensorboard/plugins/hparams/keras.py @@ -31,67 +31,69 @@ class Callback(tf.keras.callbacks.Callback): - """Callback for logging hyperparameters to TensorBoard. + """Callback for logging hyperparameters to TensorBoard. - NOTE: This callback only works in TensorFlow eager mode. - """ + NOTE: This callback only works in TensorFlow eager mode. + """ - def __init__(self, writer, hparams, trial_id=None): - """Create a callback for logging hyperparameters to TensorBoard. + def __init__(self, writer, hparams, trial_id=None): + """Create a callback for logging hyperparameters to TensorBoard. - As with the standard `tf.keras.callbacks.TensorBoard` class, each - callback object is valid for only one call to `model.fit`. + As with the standard `tf.keras.callbacks.TensorBoard` class, each + callback object is valid for only one call to `model.fit`. 
- Args: - writer: The `SummaryWriter` object to which hparams should be - written, or a logdir (as a `str`) to be passed to - `tf.summary.create_file_writer` to create such a writer. - hparams: A `dict` mapping hyperparameters to the values used in - this session. Keys should be the names of `HParam` objects used - in an experiment, or the `HParam` objects themselves. Values - should be Python `bool`, `int`, `float`, or `string` values, - depending on the type of the hyperparameter. - trial_id: An optional `str` ID for the set of hyperparameter - values used in this trial. Defaults to a hash of the - hyperparameters. + Args: + writer: The `SummaryWriter` object to which hparams should be + written, or a logdir (as a `str`) to be passed to + `tf.summary.create_file_writer` to create such a writer. + hparams: A `dict` mapping hyperparameters to the values used in + this session. Keys should be the names of `HParam` objects used + in an experiment, or the `HParam` objects themselves. Values + should be Python `bool`, `int`, `float`, or `string` values, + depending on the type of the hyperparameter. + trial_id: An optional `str` ID for the set of hyperparameter + values used in this trial. Defaults to a hash of the + hyperparameters. - Raises: - ValueError: If two entries in `hparams` share the same - hyperparameter name. - """ - # Defer creating the actual summary until we write it, so that the - # timestamp is correct. But create a "dry-run" first to fail fast in - # case the `hparams` are invalid. - self._hparams = dict(hparams) - self._trial_id = trial_id - summary_v2.hparams_pb(self._hparams, trial_id=self._trial_id) - if writer is None: - raise TypeError("writer must be a `SummaryWriter` or `str`, not None") - elif isinstance(writer, str): - self._writer = tf.compat.v2.summary.create_file_writer(writer) - else: - self._writer = writer + Raises: + ValueError: If two entries in `hparams` share the same + hyperparameter name. 
+ """ + # Defer creating the actual summary until we write it, so that the + # timestamp is correct. But create a "dry-run" first to fail fast in + # case the `hparams` are invalid. + self._hparams = dict(hparams) + self._trial_id = trial_id + summary_v2.hparams_pb(self._hparams, trial_id=self._trial_id) + if writer is None: + raise TypeError( + "writer must be a `SummaryWriter` or `str`, not None" + ) + elif isinstance(writer, str): + self._writer = tf.compat.v2.summary.create_file_writer(writer) + else: + self._writer = writer - def _get_writer(self): - if self._writer is None: - raise RuntimeError( - "hparams Keras callback cannot be reused across training sessions" - ) - if not tf.executing_eagerly(): - raise RuntimeError( - "hparams Keras callback only supported in TensorFlow eager mode" - ) - return self._writer + def _get_writer(self): + if self._writer is None: + raise RuntimeError( + "hparams Keras callback cannot be reused across training sessions" + ) + if not tf.executing_eagerly(): + raise RuntimeError( + "hparams Keras callback only supported in TensorFlow eager mode" + ) + return self._writer - def on_train_begin(self, logs=None): - del logs # unused - with self._get_writer().as_default(): - summary_v2.hparams(self._hparams, trial_id=self._trial_id) + def on_train_begin(self, logs=None): + del logs # unused + with self._get_writer().as_default(): + summary_v2.hparams(self._hparams, trial_id=self._trial_id) - def on_train_end(self, logs=None): - del logs # unused - with self._get_writer().as_default(): - pb = summary.session_end_pb(api_pb2.STATUS_SUCCESS) - raw_pb = pb.SerializeToString() - tf.compat.v2.summary.experimental.write_raw_pb(raw_pb, step=0) - self._writer = None + def on_train_end(self, logs=None): + del logs # unused + with self._get_writer().as_default(): + pb = summary.session_end_pb(api_pb2.STATUS_SUCCESS) + raw_pb = pb.SerializeToString() + tf.compat.v2.summary.experimental.write_raw_pb(raw_pb, step=0) + self._writer = None diff --git 
a/tensorboard/plugins/hparams/keras_test.py b/tensorboard/plugins/hparams/keras_test.py index 6c3ddb761c..152611eed7 100644 --- a/tensorboard/plugins/hparams/keras_test.py +++ b/tensorboard/plugins/hparams/keras_test.py @@ -25,10 +25,10 @@ import tensorflow as tf try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import from tensorboard.plugins.hparams import keras from tensorboard.plugins.hparams import metadata @@ -40,69 +40,76 @@ class CallbackTest(tf.test.TestCase): - - def setUp(self): - super(CallbackTest, self).setUp() - self.logdir = os.path.join(self.get_temp_dir(), "logs") - - def _initialize_model(self, writer): - HP_DENSE_NEURONS = hp.HParam("dense_neurons", hp.IntInterval(4, 16)) - self.hparams = { - "optimizer": "adam", - HP_DENSE_NEURONS: 8, - } - self.model = tf.keras.models.Sequential([ - tf.keras.layers.Dense(self.hparams[HP_DENSE_NEURONS], input_shape=(1,)), - tf.keras.layers.Dense(1, activation="sigmoid"), - ]) - self.model.compile(loss="mse", optimizer=self.hparams["optimizer"]) - self.trial_id = "my_trial" - self.callback = keras.Callback(writer, self.hparams, trial_id=self.trial_id) - - def test_eager(self): - def mock_time(): - mock_time.time += 1 - return mock_time.time - mock_time.time = 1556227801.875 - initial_time = mock_time.time - with mock.patch("time.time", mock_time): - self._initialize_model(writer=self.logdir) - self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback]) - final_time = mock_time.time - - files = os.listdir(self.logdir) - self.assertEqual(len(files), 1, files) - events_file = os.path.join(self.logdir, files[0]) - plugin_data = [] - for event in tf.compat.v1.train.summary_iterator(events_file): - if event.WhichOneof("what") != "summary": - continue - self.assertEqual(len(event.summary.value), 1, event.summary.value) - value = event.summary.value[0] 
- self.assertEqual( - value.metadata.plugin_data.plugin_name, - metadata.PLUGIN_NAME, - ) - plugin_data.append(value.metadata.plugin_data.content) - - self.assertEqual(len(plugin_data), 2, plugin_data) - (start_plugin_data, end_plugin_data) = plugin_data - start_pb = metadata.parse_session_start_info_plugin_data(start_plugin_data) - end_pb = metadata.parse_session_end_info_plugin_data(end_plugin_data) - - # We're not the only callers of `time.time`; Keras calls it - # internally an unspecified number of times, so we're not guaranteed - # to know the exact values. Instead, we perform relative checks... - self.assertGreater(start_pb.start_time_secs, initial_time) - self.assertLess(start_pb.start_time_secs, end_pb.end_time_secs) - self.assertLessEqual(start_pb.start_time_secs, final_time) - # ...and then stub out the times for proto equality checks below. - start_pb.start_time_secs = 1234.5 - end_pb.end_time_secs = 6789.0 - - expected_start_pb = plugin_data_pb2.SessionStartInfo() - text_format.Merge( - """ + def setUp(self): + super(CallbackTest, self).setUp() + self.logdir = os.path.join(self.get_temp_dir(), "logs") + + def _initialize_model(self, writer): + HP_DENSE_NEURONS = hp.HParam("dense_neurons", hp.IntInterval(4, 16)) + self.hparams = { + "optimizer": "adam", + HP_DENSE_NEURONS: 8, + } + self.model = tf.keras.models.Sequential( + [ + tf.keras.layers.Dense( + self.hparams[HP_DENSE_NEURONS], input_shape=(1,) + ), + tf.keras.layers.Dense(1, activation="sigmoid"), + ] + ) + self.model.compile(loss="mse", optimizer=self.hparams["optimizer"]) + self.trial_id = "my_trial" + self.callback = keras.Callback( + writer, self.hparams, trial_id=self.trial_id + ) + + def test_eager(self): + def mock_time(): + mock_time.time += 1 + return mock_time.time + + mock_time.time = 1556227801.875 + initial_time = mock_time.time + with mock.patch("time.time", mock_time): + self._initialize_model(writer=self.logdir) + self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback]) + 
final_time = mock_time.time + + files = os.listdir(self.logdir) + self.assertEqual(len(files), 1, files) + events_file = os.path.join(self.logdir, files[0]) + plugin_data = [] + for event in tf.compat.v1.train.summary_iterator(events_file): + if event.WhichOneof("what") != "summary": + continue + self.assertEqual(len(event.summary.value), 1, event.summary.value) + value = event.summary.value[0] + self.assertEqual( + value.metadata.plugin_data.plugin_name, metadata.PLUGIN_NAME, + ) + plugin_data.append(value.metadata.plugin_data.content) + + self.assertEqual(len(plugin_data), 2, plugin_data) + (start_plugin_data, end_plugin_data) = plugin_data + start_pb = metadata.parse_session_start_info_plugin_data( + start_plugin_data + ) + end_pb = metadata.parse_session_end_info_plugin_data(end_plugin_data) + + # We're not the only callers of `time.time`; Keras calls it + # internally an unspecified number of times, so we're not guaranteed + # to know the exact values. Instead, we perform relative checks... + self.assertGreater(start_pb.start_time_secs, initial_time) + self.assertLess(start_pb.start_time_secs, end_pb.end_time_secs) + self.assertLessEqual(start_pb.start_time_secs, final_time) + # ...and then stub out the times for proto equality checks below. 
+ start_pb.start_time_secs = 1234.5 + end_pb.end_time_secs = 6789.0 + + expected_start_pb = plugin_data_pb2.SessionStartInfo() + text_format.Merge( + """ start_time_secs: 1234.5 group_name: "my_trial" hparams { @@ -118,78 +125,85 @@ def mock_time(): } } """, - expected_start_pb, - ) - self.assertEqual(start_pb, expected_start_pb) + expected_start_pb, + ) + self.assertEqual(start_pb, expected_start_pb) - expected_end_pb = plugin_data_pb2.SessionEndInfo() - text_format.Merge( - """ + expected_end_pb = plugin_data_pb2.SessionEndInfo() + text_format.Merge( + """ end_time_secs: 6789.0 status: STATUS_SUCCESS """, - expected_end_pb, - ) - self.assertEqual(end_pb, expected_end_pb) - - def test_explicit_writer(self): - writer = tf.compat.v2.summary.create_file_writer( - self.logdir, - filename_suffix=".magic", - ) - self._initialize_model(writer=writer) - self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback]) - - files = os.listdir(self.logdir) - self.assertEqual(len(files), 1, files) - filename = files[0] - self.assertTrue(filename.endswith(".magic"), filename) - # We'll assume that the contents are correct, as in the case where - # the file writer was constructed implicitly. 
- - def test_non_eager_failure(self): - with tf.compat.v1.Graph().as_default(): - assert not tf.executing_eagerly() - self._initialize_model(writer=self.logdir) - with six.assertRaisesRegex( - self, RuntimeError, "only supported in TensorFlow eager mode"): + expected_end_pb, + ) + self.assertEqual(end_pb, expected_end_pb) + + def test_explicit_writer(self): + writer = tf.compat.v2.summary.create_file_writer( + self.logdir, filename_suffix=".magic", + ) + self._initialize_model(writer=writer) + self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback]) + + files = os.listdir(self.logdir) + self.assertEqual(len(files), 1, files) + filename = files[0] + self.assertTrue(filename.endswith(".magic"), filename) + # We'll assume that the contents are correct, as in the case where + # the file writer was constructed implicitly. + + def test_non_eager_failure(self): + with tf.compat.v1.Graph().as_default(): + assert not tf.executing_eagerly() + self._initialize_model(writer=self.logdir) + with six.assertRaisesRegex( + self, RuntimeError, "only supported in TensorFlow eager mode" + ): + self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback]) + + def test_reuse_failure(self): + self._initialize_model(writer=self.logdir) self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback]) + with six.assertRaisesRegex( + self, RuntimeError, "cannot be reused across training sessions" + ): + self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback]) + + def test_invalid_writer(self): + with six.assertRaisesRegex( + self, + TypeError, + "writer must be a `SummaryWriter` or `str`, not None", + ): + keras.Callback(writer=None, hparams={}) + + def test_duplicate_hparam_names_across_object_and_string(self): + hparams = { + "foo": 1, + hp.HParam("foo"): 1, + } + with six.assertRaisesRegex( + self, ValueError, "multiple values specified for hparam 'foo'" + ): + keras.Callback(self.get_temp_dir(), hparams) + + def test_duplicate_hparam_names_from_two_objects(self): + hparams = { + 
hp.HParam("foo"): 1, + hp.HParam("foo"): 1, + } + with six.assertRaisesRegex( + self, ValueError, "multiple values specified for hparam 'foo'" + ): + keras.Callback(self.get_temp_dir(), hparams) - def test_reuse_failure(self): - self._initialize_model(writer=self.logdir) - self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback]) - with six.assertRaisesRegex( - self, RuntimeError, "cannot be reused across training sessions"): - self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback]) - - def test_invalid_writer(self): - with six.assertRaisesRegex( - self, TypeError, "writer must be a `SummaryWriter` or `str`, not None"): - keras.Callback(writer=None, hparams={}) - - def test_duplicate_hparam_names_across_object_and_string(self): - hparams = { - "foo": 1, - hp.HParam("foo"): 1, - } - with six.assertRaisesRegex( - self, ValueError, "multiple values specified for hparam 'foo'"): - keras.Callback(self.get_temp_dir(), hparams) - - def test_duplicate_hparam_names_from_two_objects(self): - hparams = { - hp.HParam("foo"): 1, - hp.HParam("foo"): 1, - } - with six.assertRaisesRegex( - self, ValueError, "multiple values specified for hparam 'foo'"): - keras.Callback(self.get_temp_dir(), hparams) - - def test_invalid_trial_id(self): - with six.assertRaisesRegex( - self, TypeError, "`trial_id` should be a `str`, but got: 12"): - keras.Callback(self.get_temp_dir(), {}, trial_id=12) + def test_invalid_trial_id(self): + with six.assertRaisesRegex( + self, TypeError, "`trial_id` should be a `str`, but got: 12" + ): + keras.Callback(self.get_temp_dir(), {}, trial_id=12) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/hparams/list_metric_evals.py b/tensorboard/plugins/hparams/list_metric_evals.py index dc54ce20f0..1d22ede1ed 100644 --- a/tensorboard/plugins/hparams/list_metric_evals.py +++ b/tensorboard/plugins/hparams/list_metric_evals.py @@ -23,27 +23,29 @@ class Handler(object): - """Handles a ListMetricEvals request. 
""" - - def __init__(self, request, scalars_plugin_instance): - """Constructor. - - Args: - request: A ListSessionGroupsRequest protobuf. - scalars_plugin_instance: A scalars_plugin.ScalarsPlugin. - """ - self._request = request - self._scalars_plugin_instance = scalars_plugin_instance - - def run(self): - """Executes the request. - - Returns: - An array of tuples representing the metric evaluations--each of the form - (, , ). - """ - run, tag = metrics.run_tag_from_session_and_metric( - self._request.session_name, self._request.metric_name) - body, _ = self._scalars_plugin_instance.scalars_impl( - tag, run, None, scalars_plugin.OutputFormat.JSON) - return body + """Handles a ListMetricEvals request.""" + + def __init__(self, request, scalars_plugin_instance): + """Constructor. + + Args: + request: A ListSessionGroupsRequest protobuf. + scalars_plugin_instance: A scalars_plugin.ScalarsPlugin. + """ + self._request = request + self._scalars_plugin_instance = scalars_plugin_instance + + def run(self): + """Executes the request. + + Returns: + An array of tuples representing the metric evaluations--each of the form + (, , ). 
+ """ + run, tag = metrics.run_tag_from_session_and_metric( + self._request.session_name, self._request.metric_name + ) + body, _ = self._scalars_plugin_instance.scalars_impl( + tag, run, None, scalars_plugin.OutputFormat.JSON + ) + return body diff --git a/tensorboard/plugins/hparams/list_metric_evals_test.py b/tensorboard/plugins/hparams/list_metric_evals_test.py index 92485cca4c..da14f7dea5 100644 --- a/tensorboard/plugins/hparams/list_metric_evals_test.py +++ b/tensorboard/plugins/hparams/list_metric_evals_test.py @@ -19,10 +19,10 @@ from __future__ import print_function try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import import tensorflow as tf from google.protobuf import text_format @@ -32,35 +32,39 @@ class ListMetricEvalsTest(tf.test.TestCase): + def setUp(self): + self._mock_scalars_plugin = mock.create_autospec( + scalars_plugin.ScalarsPlugin + ) + self._mock_scalars_plugin.scalars_impl.side_effect = ( + self._mock_scalars_impl + ) - def setUp(self): - self._mock_scalars_plugin = mock.create_autospec( - scalars_plugin.ScalarsPlugin) - self._mock_scalars_plugin.scalars_impl.side_effect = self._mock_scalars_impl + def _mock_scalars_impl(self, tag, run, experiment, output_format): + del experiment # unused + self.assertEqual("metric_tag", tag) + self.assertEqual("/this/is/a/session/metric_group", run) + self.assertEqual(scalars_plugin.OutputFormat.JSON, output_format) + return ([(1, 1, 1.0), (2, 2, 2.0), (3, 3, 3.0)]), "application/json" - def _mock_scalars_impl(self, tag, run, experiment, output_format): - del experiment # unused - self.assertEqual('metric_tag', tag) - self.assertEqual('/this/is/a/session/metric_group', run) - self.assertEqual(scalars_plugin.OutputFormat.JSON, output_format) - return ([(1, 1, 1.0), (2, 2, 2.0), (3, 3, 3.0)]), 'application/json' + def _run_handler(self, 
request): + request_proto = api_pb2.ListMetricEvalsRequest() + text_format.Merge(request, request_proto) + handler = list_metric_evals.Handler( + request_proto, self._mock_scalars_plugin + ) + return handler.run() - def _run_handler(self, request): - request_proto = api_pb2.ListMetricEvalsRequest() - text_format.Merge(request, request_proto) - handler = list_metric_evals.Handler( - request_proto, self._mock_scalars_plugin) - return handler.run() - - def test_run(self): - result = self._run_handler( - '''session_name: '/this/is/a/session' + def test_run(self): + result = self._run_handler( + """session_name: '/this/is/a/session' metric_name: { tag: 'metric_tag' group: 'metric_group' - }''') - self.assertEqual([(1, 1, 1.0), (2, 2, 2.0), (3, 3, 3.0)], result) + }""" + ) + self.assertEqual([(1, 1, 1.0), (2, 2, 2.0), (3, 3, 3.0)], result) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/hparams/list_session_groups.py b/tensorboard/plugins/hparams/list_session_groups.py index 4818d62cc6..b5a86281ef 100644 --- a/tensorboard/plugins/hparams/list_session_groups.py +++ b/tensorboard/plugins/hparams/list_session_groups.py @@ -33,316 +33,347 @@ class Handler(object): - """Handles a ListSessionGroups request.""" + """Handles a ListSessionGroups request.""" + + def __init__(self, context, request): + """Constructor. + + Args: + context: A backend_context.Context instance. + request: A ListSessionGroupsRequest protobuf. + """ + self._context = context + self._request = request + self._extractors = _create_extractors(request.col_params) + self._filters = _create_filters(request.col_params, self._extractors) + # Since an context.experiment() call may search through all the runs, we + # cache it here. + self._experiment = context.experiment() + + def run(self): + """Handles the request specified on construction. + + Returns: + A ListSessionGroupsResponse object. 
+ """ + session_groups = self._build_session_groups() + session_groups = self._filter(session_groups) + self._sort(session_groups) + return self._create_response(session_groups) + + def _build_session_groups(self): + """Returns a list of SessionGroups protobuffers from the summary + data.""" + + # Algorithm: We keep a dict 'groups_by_name' mapping a SessionGroup name + # (str) to a SessionGroup protobuffer. We traverse the runs associated with + # the plugin--each representing a single session. We form a Session + # protobuffer from each run and add it to the relevant SessionGroup object + # in the 'groups_by_name' dict. We create the SessionGroup object, if this + # is the first session of that group we encounter. + groups_by_name = {} + run_to_tag_to_content = self._context.multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME + ) + for (run, tag_to_content) in six.iteritems(run_to_tag_to_content): + if metadata.SESSION_START_INFO_TAG not in tag_to_content: + continue + start_info = metadata.parse_session_start_info_plugin_data( + tag_to_content[metadata.SESSION_START_INFO_TAG] + ) + end_info = None + if metadata.SESSION_END_INFO_TAG in tag_to_content: + end_info = metadata.parse_session_end_info_plugin_data( + tag_to_content[metadata.SESSION_END_INFO_TAG] + ) + session = self._build_session(run, start_info, end_info) + if session.status in self._request.allowed_statuses: + self._add_session(session, start_info, groups_by_name) + + # Compute the session group's aggregated metrics for each group. + groups = groups_by_name.values() + for group in groups: + # We sort the sessions in a group so that the order is deterministic. + group.sessions.sort(key=operator.attrgetter("name")) + self._aggregate_metrics(group) + return groups + + def _add_session(self, session, start_info, groups_by_name): + """Adds a new Session protobuffer to the 'groups_by_name' dictionary. + + Called by _build_session_groups when we encounter a new session. 
Creates + the Session protobuffer and adds it to the relevant group in the + 'groups_by_name' dict. Creates the session group if this is the first time + we encounter it. + + Args: + session: api_pb2.Session. The session to add. + start_info: The SessionStartInfo protobuffer associated with the session. + groups_by_name: A str to SessionGroup protobuffer dict. Representing the + session groups and sessions found so far. + """ + # If the group_name is empty, this session's group contains only + # this session. Use the session name for the group name since session + # names are unique. + group_name = start_info.group_name or session.name + if group_name in groups_by_name: + groups_by_name[group_name].sessions.extend([session]) + else: + # Create the group and add the session as the first one. + group = api_pb2.SessionGroup( + name=group_name, + sessions=[session], + monitor_url=start_info.monitor_url, + ) + # Copy hparams from the first session (all sessions should have the same + # hyperparameter values) into result. + # There doesn't seem to be a way to initialize a protobuffer map in the + # constructor. + for (key, value) in six.iteritems(start_info.hparams): + group.hparams[key].CopyFrom(value) + groups_by_name[group_name] = group + + def _build_session(self, name, start_info, end_info): + """Builds a session object.""" + + assert start_info is not None + result = api_pb2.Session( + name=name, + start_time_secs=start_info.start_time_secs, + model_uri=start_info.model_uri, + metric_values=self._build_session_metric_values(name), + monitor_url=start_info.monitor_url, + ) + if end_info is not None: + result.status = end_info.status + result.end_time_secs = end_info.end_time_secs + return result + + def _build_session_metric_values(self, session_name): + """Builds the session metric values.""" + + # result is a list of api_pb2.MetricValue instances. 
+ result = [] + metric_infos = self._experiment.metric_infos + for metric_info in metric_infos: + metric_name = metric_info.name + try: + metric_eval = metrics.last_metric_eval( + self._context.multiplexer, session_name, metric_name + ) + except KeyError: + # It's ok if we don't find the metric in the session. + # We skip it here. For filtering and sorting purposes its value is None. + continue + + # metric_eval is a 3-tuple of the form [wall_time, step, value] + result.append( + api_pb2.MetricValue( + name=metric_name, + wall_time_secs=metric_eval[0], + training_step=metric_eval[1], + value=metric_eval[2], + ) + ) + return result + + def _aggregate_metrics(self, session_group): + """Sets the metrics of the group based on aggregation_type.""" + + if ( + self._request.aggregation_type == api_pb2.AGGREGATION_AVG + or self._request.aggregation_type == api_pb2.AGGREGATION_UNSET + ): + _set_avg_session_metrics(session_group) + elif self._request.aggregation_type == api_pb2.AGGREGATION_MEDIAN: + _set_median_session_metrics( + session_group, self._request.aggregation_metric + ) + elif self._request.aggregation_type == api_pb2.AGGREGATION_MIN: + _set_extremum_session_metrics( + session_group, self._request.aggregation_metric, min + ) + elif self._request.aggregation_type == api_pb2.AGGREGATION_MAX: + _set_extremum_session_metrics( + session_group, self._request.aggregation_metric, max + ) + else: + raise error.HParamsError( + "Unknown aggregation_type in request: %s" + % self._request.aggregation_type + ) + + def _filter(self, session_groups): + return [sg for sg in session_groups if self._passes_all_filters(sg)] + + def _passes_all_filters(self, session_group): + return all(filter_fn(session_group) for filter_fn in self._filters) + + def _sort(self, session_groups): + """Sorts 'session_groups' in place according to _request.col_params.""" + + # Sort by session_group name so we have a deterministic order. 
+ session_groups.sort(key=operator.attrgetter("name")) + # Sort by lexicographical order of the _request.col_params whose order + # is not ORDER_UNSPECIFIED. The first such column is the primary sorting + # key, the second is the secondary sorting key, etc. To achieve that we + # need to iterate on these columns in reverse order (thus the primary key + # is the key used in the last sort). + for col_param, extractor in reversed( + list(zip(self._request.col_params, self._extractors)) + ): + if col_param.order == api_pb2.ORDER_UNSPECIFIED: + continue + if col_param.order == api_pb2.ORDER_ASC: + session_groups.sort( + key=_create_key_func( + extractor, + none_is_largest=not col_param.missing_values_first, + ) + ) + elif col_param.order == api_pb2.ORDER_DESC: + session_groups.sort( + key=_create_key_func( + extractor, + none_is_largest=col_param.missing_values_first, + ), + reverse=True, + ) + else: + raise error.HParamsError( + "Unknown col_param.order given: %s" % col_param + ) + + def _create_response(self, session_groups): + return api_pb2.ListSessionGroupsResponse( + session_groups=session_groups[ + self._request.start_index : self._request.start_index + + self._request.slice_size + ], + total_size=len(session_groups), + ) - def __init__(self, context, request): - """Constructor. - Args: - context: A backend_context.Context instance. - request: A ListSessionGroupsRequest protobuf. - """ - self._context = context - self._request = request - self._extractors = _create_extractors(request.col_params) - self._filters = _create_filters(request.col_params, self._extractors) - # Since an context.experiment() call may search through all the runs, we - # cache it here. - self._experiment = context.experiment() - - def run(self): - """Handles the request specified on construction. +def _create_key_func(extractor, none_is_largest): + """Returns a key_func to be used in list.sort(). - Returns: - A ListSessionGroupsResponse object. 
- - """ - session_groups = self._build_session_groups() - session_groups = self._filter(session_groups) - self._sort(session_groups) - return self._create_response(session_groups) - - def _build_session_groups(self): - """Returns a list of SessionGroups protobuffers from the summary data.""" - - # Algorithm: We keep a dict 'groups_by_name' mapping a SessionGroup name - # (str) to a SessionGroup protobuffer. We traverse the runs associated with - # the plugin--each representing a single session. We form a Session - # protobuffer from each run and add it to the relevant SessionGroup object - # in the 'groups_by_name' dict. We create the SessionGroup object, if this - # is the first session of that group we encounter. - groups_by_name = {} - run_to_tag_to_content = self._context.multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME) - for (run, tag_to_content) in six.iteritems(run_to_tag_to_content): - if metadata.SESSION_START_INFO_TAG not in tag_to_content: - continue - start_info = metadata.parse_session_start_info_plugin_data( - tag_to_content[metadata.SESSION_START_INFO_TAG]) - end_info = None - if metadata.SESSION_END_INFO_TAG in tag_to_content: - end_info = metadata.parse_session_end_info_plugin_data( - tag_to_content[metadata.SESSION_END_INFO_TAG]) - session = self._build_session(run, start_info, end_info) - if session.status in self._request.allowed_statuses: - self._add_session(session, start_info, groups_by_name) - - # Compute the session group's aggregated metrics for each group. - groups = groups_by_name.values() - for group in groups: - # We sort the sessions in a group so that the order is deterministic. - group.sessions.sort(key=operator.attrgetter('name')) - self._aggregate_metrics(group) - return groups - - def _add_session(self, session, start_info, groups_by_name): - """Adds a new Session protobuffer to the 'groups_by_name' dictionary. - - Called by _build_session_groups when we encounter a new session. 
Creates - the Session protobuffer and adds it to the relevant group in the - 'groups_by_name' dict. Creates the session group if this is the first time - we encounter it. + Returns a key_func to be used in list.sort() that sorts session groups + by the value extracted by extractor. 'None' extracted values will either + be considered largest or smallest as specified by the "none_is_largest" + boolean parameter. Args: - session: api_pb2.Session. The session to add. - start_info: The SessionStartInfo protobuffer associated with the session. - groups_by_name: A str to SessionGroup protobuffer dict. Representing the - session groups and sessions found so far. + extractor: An extractor function that extract the key from the session + group. + none_is_largest: bool. If true treats 'None's as largest; otherwise + smallest. """ - # If the group_name is empty, this session's group contains only - # this session. Use the session name for the group name since session - # names are unique. - group_name = start_info.group_name or session.name - if group_name in groups_by_name: - groups_by_name[group_name].sessions.extend([session]) - else: - # Create the group and add the session as the first one. - group = api_pb2.SessionGroup( - name=group_name, - sessions=[session], - monitor_url=start_info.monitor_url) - # Copy hparams from the first session (all sessions should have the same - # hyperparameter values) into result. - # There doesn't seem to be a way to initialize a protobuffer map in the - # constructor. 
- for (key, value) in six.iteritems(start_info.hparams): - group.hparams[key].CopyFrom(value) - groups_by_name[group_name] = group - - def _build_session(self, name, start_info, end_info): - """Builds a session object.""" - - assert start_info is not None - result = api_pb2.Session( - name=name, - start_time_secs=start_info.start_time_secs, - model_uri=start_info.model_uri, - metric_values=self._build_session_metric_values(name), - monitor_url=start_info.monitor_url) - if end_info is not None: - result.status = end_info.status - result.end_time_secs = end_info.end_time_secs - return result + if none_is_largest: - def _build_session_metric_values(self, session_name): - """Builds the session metric values.""" + def key_func_none_is_largest(session_group): + value = extractor(session_group) + return (value is None, value) - # result is a list of api_pb2.MetricValue instances. - result = [] - metric_infos = self._experiment.metric_infos - for metric_info in metric_infos: - metric_name = metric_info.name - try: - metric_eval = metrics.last_metric_eval( - self._context.multiplexer, - session_name, - metric_name) - except KeyError: - # It's ok if we don't find the metric in the session. - # We skip it here. For filtering and sorting purposes its value is None. 
- continue - - # metric_eval is a 3-tuple of the form [wall_time, step, value] - result.append(api_pb2.MetricValue(name=metric_name, - wall_time_secs=metric_eval[0], - training_step=metric_eval[1], - value=metric_eval[2])) - return result - - def _aggregate_metrics(self, session_group): - """Sets the metrics of the group based on aggregation_type.""" - - if (self._request.aggregation_type == api_pb2.AGGREGATION_AVG or - self._request.aggregation_type == api_pb2.AGGREGATION_UNSET): - _set_avg_session_metrics(session_group) - elif self._request.aggregation_type == api_pb2.AGGREGATION_MEDIAN: - _set_median_session_metrics(session_group, - self._request.aggregation_metric) - elif self._request.aggregation_type == api_pb2.AGGREGATION_MIN: - _set_extremum_session_metrics(session_group, - self._request.aggregation_metric, - min) - elif self._request.aggregation_type == api_pb2.AGGREGATION_MAX: - _set_extremum_session_metrics(session_group, - self._request.aggregation_metric, - max) - else: - raise error.HParamsError('Unknown aggregation_type in request: %s' % - self._request.aggregation_type) - - def _filter(self, session_groups): - return [sg for sg in session_groups if self._passes_all_filters(sg)] - - def _passes_all_filters(self, session_group): - return all(filter_fn(session_group) for filter_fn in self._filters) - - def _sort(self, session_groups): - """Sorts 'session_groups' in place according to _request.col_params.""" - - # Sort by session_group name so we have a deterministic order. - session_groups.sort(key=operator.attrgetter('name')) - # Sort by lexicographical order of the _request.col_params whose order - # is not ORDER_UNSPECIFIED. The first such column is the primary sorting - # key, the second is the secondary sorting key, etc. To achieve that we - # need to iterate on these columns in reverse order (thus the primary key - # is the key used in the last sort). 
- for col_param, extractor in reversed(list(zip(self._request.col_params, - self._extractors))): - if col_param.order == api_pb2.ORDER_UNSPECIFIED: - continue - if col_param.order == api_pb2.ORDER_ASC: - session_groups.sort( - key=_create_key_func( - extractor, - none_is_largest=not col_param.missing_values_first)) - elif col_param.order == api_pb2.ORDER_DESC: - session_groups.sort( - key=_create_key_func( - extractor, - none_is_largest=col_param.missing_values_first), - reverse=True) - else: - raise error.HParamsError('Unknown col_param.order given: %s' % - col_param) - - def _create_response(self, session_groups): - return api_pb2.ListSessionGroupsResponse( - session_groups=session_groups[ - self._request.start_index: - self._request.start_index+self._request.slice_size], - total_size=len(session_groups)) + return key_func_none_is_largest + def key_func_none_is_smallest(session_group): + value = extractor(session_group) + return (value is not None, value) -def _create_key_func(extractor, none_is_largest): - """Returns a key_func to be used in list.sort(). - - Returns a key_func to be used in list.sort() that sorts session groups - by the value extracted by extractor. 'None' extracted values will either - be considered largest or smallest as specified by the "none_is_largest" - boolean parameter. - - Args: - extractor: An extractor function that extract the key from the session - group. - none_is_largest: bool. If true treats 'None's as largest; otherwise - smallest. - - """ - if none_is_largest: - def key_func_none_is_largest(session_group): - value = extractor(session_group) - return (value is None, value) - return key_func_none_is_largest - def key_func_none_is_smallest(session_group): - value = extractor(session_group) - return (value is not None, value) - return key_func_none_is_smallest + return key_func_none_is_smallest # Extractors. An extractor is a function that extracts some property (a metric # or a hyperparameter) from a SessionGroup instance. 
def _create_extractors(col_params): - """Creates extractors to extract properties corresponding to 'col_params'. + """Creates extractors to extract properties corresponding to 'col_params'. - Args: - col_params: List of ListSessionGroupsRequest.ColParam protobufs. - Returns: - A list of extractor functions. The ith element in the - returned list extracts the column corresponding to the ith element of - _request.col_params - """ - result = [] - for col_param in col_params: - result.append(_create_extractor(col_param)) - return result + Args: + col_params: List of ListSessionGroupsRequest.ColParam protobufs. + Returns: + A list of extractor functions. The ith element in the + returned list extracts the column corresponding to the ith element of + _request.col_params + """ + result = [] + for col_param in col_params: + result.append(_create_extractor(col_param)) + return result def _create_extractor(col_param): - if col_param.HasField('metric'): - return _create_metric_extractor(col_param.metric) - elif col_param.HasField('hparam'): - return _create_hparam_extractor(col_param.hparam) - else: - raise error.HParamsError( - 'Got ColParam with both "metric" and "hparam" fields unset: %s' % - col_param) + if col_param.HasField("metric"): + return _create_metric_extractor(col_param.metric) + elif col_param.HasField("hparam"): + return _create_hparam_extractor(col_param.hparam) + else: + raise error.HParamsError( + 'Got ColParam with both "metric" and "hparam" fields unset: %s' + % col_param + ) def _create_metric_extractor(metric_name): - """Returns function that extracts a metric from a session group or a session. + """Returns function that extracts a metric from a session group or a + session. - Args: - metric_name: tensorboard.hparams.MetricName protobuffer. Identifies the - metric to extract from the session group. 
- Returns: - A function that takes a tensorboard.hparams.SessionGroup or - tensorborad.hparams.Session protobuffer and returns the value of the metric - identified by 'metric_name' or None if the value doesn't exist. - """ - def extractor_fn(session_or_group): - metric_value = _find_metric_value(session_or_group, - metric_name) - return metric_value.value if metric_value else None + Args: + metric_name: tensorboard.hparams.MetricName protobuffer. Identifies the + metric to extract from the session group. + Returns: + A function that takes a tensorboard.hparams.SessionGroup or + tensorborad.hparams.Session protobuffer and returns the value of the metric + identified by 'metric_name' or None if the value doesn't exist. + """ + + def extractor_fn(session_or_group): + metric_value = _find_metric_value(session_or_group, metric_name) + return metric_value.value if metric_value else None - return extractor_fn + return extractor_fn def _find_metric_value(session_or_group, metric_name): - """Returns the metric_value for a given metric in a session or session group. - - Args: - session_or_group: A Session protobuffer or SessionGroup protobuffer. - metric_name: A MetricName protobuffer. The metric to search for. - Returns: - A MetricValue protobuffer representing the value of the given metric or - None if no such metric was found in session_or_group. - """ - # Note: We can speed this up by converting the metric_values field - # to a dictionary on initialization, to avoid a linear search here. We'll - # need to wrap the SessionGroup and Session protos in a python object for - # that. - for metric_value in session_or_group.metric_values: - if (metric_value.name.tag == metric_name.tag and - metric_value.name.group == metric_name.group): - return metric_value + """Returns the metric_value for a given metric in a session or session + group. + + Args: + session_or_group: A Session protobuffer or SessionGroup protobuffer. + metric_name: A MetricName protobuffer. 
The metric to search for. + Returns: + A MetricValue protobuffer representing the value of the given metric or + None if no such metric was found in session_or_group. + """ + # Note: We can speed this up by converting the metric_values field + # to a dictionary on initialization, to avoid a linear search here. We'll + # need to wrap the SessionGroup and Session protos in a python object for + # that. + for metric_value in session_or_group.metric_values: + if ( + metric_value.name.tag == metric_name.tag + and metric_value.name.group == metric_name.group + ): + return metric_value def _create_hparam_extractor(hparam_name): - """Returns an extractor function that extracts an hparam from a session group. + """Returns an extractor function that extracts an hparam from a session + group. + + Args: + hparam_name: str. Identies the hparam to extract from the session group. + Returns: + A function that takes a tensorboard.hparams.SessionGroup protobuffer and + returns the value, as a native Python object, of the hparam identified by + 'hparam_name'. + """ - Args: - hparam_name: str. Identies the hparam to extract from the session group. - Returns: - A function that takes a tensorboard.hparams.SessionGroup protobuffer and - returns the value, as a native Python object, of the hparam identified by - 'hparam_name'. - """ - def extractor_fn(session_group): - if hparam_name in session_group.hparams: - return _value_to_python(session_group.hparams[hparam_name]) - return None + def extractor_fn(session_group): + if hparam_name in session_group.hparams: + return _value_to_python(session_group.hparams[hparam_name]) + return None - return extractor_fn + return extractor_fn # Filters. A filter is a boolean function that takes a session group and returns @@ -350,137 +381,143 @@ def extractor_fn(session_group): # of a single column value extracted from the session group with a given # extractor specified in the construction of the filter. 
def _create_filters(col_params, extractors): - """Creates filters for the given col_params. - - Args: - col_params: List of ListSessionGroupsRequest.ColParam protobufs. - extractors: list of extractor functions of the same length as col_params. - Each element should extract the column described by the corresponding - element of col_params. - Returns: - A list of filter functions. Each corresponding to a single - col_params.filter oneof field of _request - """ - result = [] - for col_param, extractor in zip(col_params, extractors): - a_filter = _create_filter(col_param, extractor) - if a_filter: - result.append(a_filter) - return result + """Creates filters for the given col_params. + + Args: + col_params: List of ListSessionGroupsRequest.ColParam protobufs. + extractors: list of extractor functions of the same length as col_params. + Each element should extract the column described by the corresponding + element of col_params. + Returns: + A list of filter functions. Each corresponding to a single + col_params.filter oneof field of _request + """ + result = [] + for col_param, extractor in zip(col_params, extractors): + a_filter = _create_filter(col_param, extractor) + if a_filter: + result.append(a_filter) + return result def _create_filter(col_param, extractor): - """Creates a filter for the given col_param and extractor. - - Args: - col_param: A tensorboard.hparams.ColParams object identifying the column - and describing the filter to apply. - extractor: A function that extract the column value identified by - 'col_param' from a tensorboard.hparams.SessionGroup protobuffer. - Returns: - A boolean function taking a tensorboard.hparams.SessionGroup protobuffer - returning True if the session group passes the filter described by - 'col_param'. If col_param does not specify a filter (i.e. any session - group passes) returns None. 
- """ - include_missing_values = not col_param.exclude_missing_values - if col_param.HasField('filter_regexp'): - value_filter_fn = _create_regexp_filter(col_param.filter_regexp) - elif col_param.HasField('filter_interval'): - value_filter_fn = _create_interval_filter(col_param.filter_interval) - elif col_param.HasField('filter_discrete'): - value_filter_fn = _create_discrete_set_filter(col_param.filter_discrete) - elif include_missing_values: - # No 'filter' field and include_missing_values is True. - # Thus, the resulting filter always returns True, so to optimize for this - # common case we do not include it in the list of filters to check. - return None - else: - value_filter_fn = lambda _: True - - def filter_fn(session_group): - value = extractor(session_group) - if value is None: - return include_missing_values - return value_filter_fn(value) - - return filter_fn + """Creates a filter for the given col_param and extractor. + + Args: + col_param: A tensorboard.hparams.ColParams object identifying the column + and describing the filter to apply. + extractor: A function that extract the column value identified by + 'col_param' from a tensorboard.hparams.SessionGroup protobuffer. + Returns: + A boolean function taking a tensorboard.hparams.SessionGroup protobuffer + returning True if the session group passes the filter described by + 'col_param'. If col_param does not specify a filter (i.e. any session + group passes) returns None. + """ + include_missing_values = not col_param.exclude_missing_values + if col_param.HasField("filter_regexp"): + value_filter_fn = _create_regexp_filter(col_param.filter_regexp) + elif col_param.HasField("filter_interval"): + value_filter_fn = _create_interval_filter(col_param.filter_interval) + elif col_param.HasField("filter_discrete"): + value_filter_fn = _create_discrete_set_filter(col_param.filter_discrete) + elif include_missing_values: + # No 'filter' field and include_missing_values is True. 
+ # Thus, the resulting filter always returns True, so to optimize for this + # common case we do not include it in the list of filters to check. + return None + else: + value_filter_fn = lambda _: True + + def filter_fn(session_group): + value = extractor(session_group) + if value is None: + return include_missing_values + return value_filter_fn(value) + + return filter_fn def _create_regexp_filter(regex): - """Returns a boolean function that filters strings based on a regular exp. - - Args: - regex: A string describing the regexp to use. - Returns: - A function taking a string and returns True if any of its substrings - matches regex. - """ - # Warning: Note that python's regex library allows inputs that take - # exponential time. Time-limiting it is difficult. When we move to - # a true multi-tenant tensorboard server, the regexp implementation here - # would need to be replaced by something more secure. - compiled_regex = re.compile(regex) - def filter_fn(value): - if not isinstance(value, six.string_types): - raise error.HParamsError( - 'Cannot use a regexp filter for a value of type %s. Value: %s' % - (type(value), value)) - return re.search(compiled_regex, value) is not None - - return filter_fn + """Returns a boolean function that filters strings based on a regular exp. + + Args: + regex: A string describing the regexp to use. + Returns: + A function taking a string and returns True if any of its substrings + matches regex. + """ + # Warning: Note that python's regex library allows inputs that take + # exponential time. Time-limiting it is difficult. When we move to + # a true multi-tenant tensorboard server, the regexp implementation here + # would need to be replaced by something more secure. + compiled_regex = re.compile(regex) + + def filter_fn(value): + if not isinstance(value, six.string_types): + raise error.HParamsError( + "Cannot use a regexp filter for a value of type %s. 
Value: %s" + % (type(value), value) + ) + return re.search(compiled_regex, value) is not None + + return filter_fn def _create_interval_filter(interval): - """Returns a function that checkes whether a number belongs to an interval. - - Args: - interval: A tensorboard.hparams.Interval protobuf describing the interval. - Returns: - A function taking a number (a float or an object of a type in - six.integer_types) that returns True if the number belongs to (the closed) - 'interval'. - """ - def filter_fn(value): - if (not isinstance(value, six.integer_types) and - not isinstance(value, float)): - raise error.HParamsError( - 'Cannot use an interval filter for a value of type: %s, Value: %s' % - (type(value), value)) - return interval.min_value <= value and value <= interval.max_value - - return filter_fn + """Returns a function that checkes whether a number belongs to an interval. + + Args: + interval: A tensorboard.hparams.Interval protobuf describing the interval. + Returns: + A function taking a number (a float or an object of a type in + six.integer_types) that returns True if the number belongs to (the closed) + 'interval'. + """ + + def filter_fn(value): + if not isinstance(value, six.integer_types) and not isinstance( + value, float + ): + raise error.HParamsError( + "Cannot use an interval filter for a value of type: %s, Value: %s" + % (type(value), value) + ) + return interval.min_value <= value and value <= interval.max_value + + return filter_fn def _create_discrete_set_filter(discrete_set): - """Returns a function that checks whether a value belongs to a set. + """Returns a function that checks whether a value belongs to a set. - Args: - discrete_set: A list of objects representing the set. - Returns: - A function taking an object and returns True if its in the set. Membership - is tested using the Python 'in' operator (thus, equality of distinct - objects is computed using the '==' operator). 
- """ - def filter_fn(value): - return value in discrete_set + Args: + discrete_set: A list of objects representing the set. + Returns: + A function taking an object and returns True if its in the set. Membership + is tested using the Python 'in' operator (thus, equality of distinct + objects is computed using the '==' operator). + """ - return filter_fn + def filter_fn(value): + return value in discrete_set + return filter_fn -def _value_to_python(value): - """Converts a google.protobuf.Value to a native Python object.""" - assert isinstance(value, struct_pb2.Value) - field = value.WhichOneof('kind') - if field == 'number_value': - return value.number_value - elif field == 'string_value': - return value.string_value - elif field == 'bool_value': - return value.bool_value - else: - raise ValueError('Unknown struct_pb2.Value oneof field set: %s' % field) +def _value_to_python(value): + """Converts a google.protobuf.Value to a native Python object.""" + + assert isinstance(value, struct_pb2.Value) + field = value.WhichOneof("kind") + if field == "number_value": + return value.number_value + elif field == "string_value": + return value.string_value + elif field == "bool_value": + return value.bool_value + else: + raise ValueError("Unknown struct_pb2.Value oneof field set: %s" % field) # As protobuffers are mutable we can't use MetricName directly as a dict's key. @@ -489,136 +526,148 @@ def _value_to_python(value): # immutable class that defines equality and __hash__ methods based on the text # representation of the protocol buffer. This is more complex, but won't # require modification if we ever add fields to MetricName. -_MetricIdentifier = collections.namedtuple('_MetricIdentifier', 'group tag') +_MetricIdentifier = collections.namedtuple("_MetricIdentifier", "group tag") class _MetricStats(object): - """A simple class to hold metric stats used in calculating metric averages. - - Used in _set_avg_session_metrics(). 
See the comments in that function - for more details. - - Attributes: - total: int. The sum of the metric measurements seen so far. - count: int. The number of largest-step measuremens seen so far. - total_step: int. The sum of the steps at which the measurements were taken - total_wall_time_secs: float. The sum of the wall_time_secs at - which the measurements were taken. - """ - # We use slots here to catch typos in attributes earlier. Note that this makes - # this class incompatible with 'pickle'. - __slots__ = [ - 'total', - 'count', - 'total_step', - 'total_wall_time_secs', - ] - - def __init__(self): - self.total = 0 - self.count = 0 - self.total_step = 0 - self.total_wall_time_secs = 0.0 + """A simple class to hold metric stats used in calculating metric averages. + + Used in _set_avg_session_metrics(). See the comments in that function + for more details. + + Attributes: + total: int. The sum of the metric measurements seen so far. + count: int. The number of largest-step measuremens seen so far. + total_step: int. The sum of the steps at which the measurements were taken + total_wall_time_secs: float. The sum of the wall_time_secs at + which the measurements were taken. + """ + + # We use slots here to catch typos in attributes earlier. Note that this makes + # this class incompatible with 'pickle'. + __slots__ = [ + "total", + "count", + "total_step", + "total_wall_time_secs", + ] + + def __init__(self): + self.total = 0 + self.count = 0 + self.total_step = 0 + self.total_wall_time_secs = 0.0 def _set_avg_session_metrics(session_group): - """Sets the metrics for the group to be the average of its sessions. - - The resulting session group metrics consist of the union of metrics across - the group's sessions. The value of each session group metric is the average - of that metric values across the sessions in the group. 
The 'step' and - 'wall_time_secs' fields of the resulting MetricValue field in the session - group are populated with the corresponding averages (truncated for 'step') - as well. - - Args: - session_group: A SessionGroup protobuffer. - """ - assert session_group.sessions, 'SessionGroup cannot be empty.' - # Algorithm: Iterate over all (session, metric) pairs and maintain a - # dict from _MetricIdentifier to _MetricStats objects. - # Then use the final dict state to compute the average for each metric. - metric_stats = collections.defaultdict(_MetricStats) - for session in session_group.sessions: - for metric_value in session.metric_values: - metric_name = _MetricIdentifier(group=metric_value.name.group, - tag=metric_value.name.tag) - stats = metric_stats[metric_name] - stats.total += metric_value.value - stats.count += 1 - stats.total_step += metric_value.training_step - stats.total_wall_time_secs += metric_value.wall_time_secs - - del session_group.metric_values[:] - for (metric_name, stats) in six.iteritems(metric_stats): - session_group.metric_values.add( - name=api_pb2.MetricName(group=metric_name.group, tag=metric_name.tag), - value=float(stats.total)/float(stats.count), - training_step=stats.total_step // stats.count, - wall_time_secs=stats.total_wall_time_secs / stats.count) + """Sets the metrics for the group to be the average of its sessions. + + The resulting session group metrics consist of the union of metrics across + the group's sessions. The value of each session group metric is the average + of that metric values across the sessions in the group. The 'step' and + 'wall_time_secs' fields of the resulting MetricValue field in the session + group are populated with the corresponding averages (truncated for 'step') + as well. + + Args: + session_group: A SessionGroup protobuffer. + """ + assert session_group.sessions, "SessionGroup cannot be empty." 
+ # Algorithm: Iterate over all (session, metric) pairs and maintain a + # dict from _MetricIdentifier to _MetricStats objects. + # Then use the final dict state to compute the average for each metric. + metric_stats = collections.defaultdict(_MetricStats) + for session in session_group.sessions: + for metric_value in session.metric_values: + metric_name = _MetricIdentifier( + group=metric_value.name.group, tag=metric_value.name.tag + ) + stats = metric_stats[metric_name] + stats.total += metric_value.value + stats.count += 1 + stats.total_step += metric_value.training_step + stats.total_wall_time_secs += metric_value.wall_time_secs + + del session_group.metric_values[:] + for (metric_name, stats) in six.iteritems(metric_stats): + session_group.metric_values.add( + name=api_pb2.MetricName( + group=metric_name.group, tag=metric_name.tag + ), + value=float(stats.total) / float(stats.count), + training_step=stats.total_step // stats.count, + wall_time_secs=stats.total_wall_time_secs / stats.count, + ) # A namedtuple to hold a session's metric value. # 'session_index' is the index of the session in its group. -_Measurement = collections.namedtuple('_Measurement', ['metric_value', - 'session_index']) +_Measurement = collections.namedtuple( + "_Measurement", ["metric_value", "session_index"] +) def _set_median_session_metrics(session_group, aggregation_metric): - """Sets the metrics for session_group to those of its "median session". - - The median session is the session in session_group with the median value - of the metric given by 'aggregation_metric'. The median is taken over the - subset of sessions in the group whose 'aggregation_metric' was measured - at the largest training step among the sessions in the group. - - Args: - session_group: A SessionGroup protobuffer. - aggregation_metric: A MetricName protobuffer. 
- """ - measurements = sorted(_measurements(session_group, aggregation_metric), - key=operator.attrgetter('metric_value.value')) - median_session = measurements[(len(measurements) - 1) // 2].session_index - del session_group.metric_values[:] - session_group.metric_values.MergeFrom( - session_group.sessions[median_session].metric_values) - - -def _set_extremum_session_metrics(session_group, aggregation_metric, - extremum_fn): - """Sets the metrics for session_group to those of its "extremum session". - - The extremum session is the session in session_group with the extremum value - of the metric given by 'aggregation_metric'. The extremum is taken over the - subset of sessions in the group whose 'aggregation_metric' was measured - at the largest training step among the sessions in the group. - - Args: - session_group: A SessionGroup protobuffer. - aggregation_metric: A MetricName protobuffer. - extremum_fn: callable. Must be either 'min' or 'max'. Determines the type of - extremum to compute. - """ - measurements = _measurements(session_group, aggregation_metric) - ext_session = extremum_fn( - measurements, - key=operator.attrgetter('metric_value.value')).session_index - del session_group.metric_values[:] - session_group.metric_values.MergeFrom( - session_group.sessions[ext_session].metric_values) + """Sets the metrics for session_group to those of its "median session". + + The median session is the session in session_group with the median value + of the metric given by 'aggregation_metric'. The median is taken over the + subset of sessions in the group whose 'aggregation_metric' was measured + at the largest training step among the sessions in the group. + + Args: + session_group: A SessionGroup protobuffer. + aggregation_metric: A MetricName protobuffer. 
+ """ + measurements = sorted( + _measurements(session_group, aggregation_metric), + key=operator.attrgetter("metric_value.value"), + ) + median_session = measurements[(len(measurements) - 1) // 2].session_index + del session_group.metric_values[:] + session_group.metric_values.MergeFrom( + session_group.sessions[median_session].metric_values + ) + + +def _set_extremum_session_metrics( + session_group, aggregation_metric, extremum_fn +): + """Sets the metrics for session_group to those of its "extremum session". + + The extremum session is the session in session_group with the extremum value + of the metric given by 'aggregation_metric'. The extremum is taken over the + subset of sessions in the group whose 'aggregation_metric' was measured + at the largest training step among the sessions in the group. + + Args: + session_group: A SessionGroup protobuffer. + aggregation_metric: A MetricName protobuffer. + extremum_fn: callable. Must be either 'min' or 'max'. Determines the type of + extremum to compute. + """ + measurements = _measurements(session_group, aggregation_metric) + ext_session = extremum_fn( + measurements, key=operator.attrgetter("metric_value.value") + ).session_index + del session_group.metric_values[:] + session_group.metric_values.MergeFrom( + session_group.sessions[ext_session].metric_values + ) def _measurements(session_group, metric_name): - """A generator for the values of the metric across the sessions in the group. - - Args: - session_group: A SessionGroup protobuffer. - metric_name: A MetricName protobuffer. - Yields: - The next metric value wrapped in a _Measurement instance. - """ - for session_index, session in enumerate(session_group.sessions): - metric_value = _find_metric_value(session, metric_name) - if not metric_value: - continue - yield _Measurement(metric_value, session_index) + """A generator for the values of the metric across the sessions in the + group. + + Args: + session_group: A SessionGroup protobuffer. 
+ metric_name: A MetricName protobuffer. + Yields: + The next metric value wrapped in a _Measurement instance. + """ + for session_index, session in enumerate(session_group.sessions): + metric_value = _find_metric_value(session, metric_name) + if not metric_value: + continue + yield _Measurement(metric_value, session_index) diff --git a/tensorboard/plugins/hparams/list_session_groups_test.py b/tensorboard/plugins/hparams/list_session_groups_test.py index 9184c12e5c..800ad4d259 100644 --- a/tensorboard/plugins/hparams/list_session_groups_test.py +++ b/tensorboard/plugins/hparams/list_session_groups_test.py @@ -21,11 +21,12 @@ import operator import tensorflow as tf + try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import from google.protobuf import text_format from tensorboard.backend.event_processing import event_accumulator @@ -38,9 +39,9 @@ from tensorboard.plugins.hparams import plugin_data_pb2 -DATA_TYPE_EXPERIMENT = 'experiment' -DATA_TYPE_SESSION_START_INFO = 'session_start_info' -DATA_TYPE_SESSION_END_INFO = 'session_end_info' +DATA_TYPE_EXPERIMENT = "experiment" +DATA_TYPE_SESSION_START_INFO = "session_start_info" +DATA_TYPE_SESSION_END_INFO = "session_end_info" # Allow us to abbreviate event_accumulator.TensorEvent @@ -48,20 +49,20 @@ class ListSessionGroupsTest(tf.test.TestCase): - # Make assertProtoEquals print all the diff. 
- maxDiff = None # pylint: disable=invalid-name - - def setUp(self): - self._mock_tb_context = mock.create_autospec( - base_plugin.TBContext) - self._mock_multiplexer = mock.create_autospec( - plugin_event_multiplexer.EventMultiplexer) - self._mock_tb_context.multiplexer = self._mock_multiplexer - self._mock_multiplexer.PluginRunToTagToContent.return_value = { - '': { - metadata.EXPERIMENT_TAG: - self._serialized_plugin_data( - DATA_TYPE_EXPERIMENT, ''' + # Make assertProtoEquals print all the diff. + maxDiff = None # pylint: disable=invalid-name + + def setUp(self): + self._mock_tb_context = mock.create_autospec(base_plugin.TBContext) + self._mock_multiplexer = mock.create_autospec( + plugin_event_multiplexer.EventMultiplexer + ) + self._mock_tb_context.multiplexer = self._mock_multiplexer + self._mock_multiplexer.PluginRunToTagToContent.return_value = { + "": { + metadata.EXPERIMENT_TAG: self._serialized_plugin_data( + DATA_TYPE_EXPERIMENT, + """ description: 'Test experiment' user: 'Test user' hparam_infos: [ @@ -82,12 +83,13 @@ def setUp(self): { name: { tag: 'delta_temp' } }, { name: { tag: 'optional_metric' } } ] - ''') - }, - 'session_1': { - metadata.SESSION_START_INFO_TAG: - self._serialized_plugin_data( - DATA_TYPE_SESSION_START_INFO, ''' + """, + ) + }, + "session_1": { + metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_START_INFO, + """ hparams:{ key: 'initial_temp' value: { number_value: 270 } }, hparams:{ key: 'final_temp' value: { number_value: 150 } }, hparams:{ @@ -96,18 +98,20 @@ def setUp(self): hparams:{ key: 'bool_hparam' value: { bool_value: true } } group_name: 'group_1' start_time_secs: 314159 - '''), - metadata.SESSION_END_INFO_TAG: - self._serialized_plugin_data( - DATA_TYPE_SESSION_END_INFO, ''' + """, + ), + metadata.SESSION_END_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_END_INFO, + """ status: STATUS_SUCCESS end_time_secs: 314164 - ''') - }, - 'session_2': { - 
metadata.SESSION_START_INFO_TAG: - self._serialized_plugin_data( - DATA_TYPE_SESSION_START_INFO, ''' + """, + ), + }, + "session_2": { + metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_START_INFO, + """ hparams:{ key: 'initial_temp' value: { number_value: 280 } }, hparams:{ key: 'final_temp' value: { number_value: 100 } }, hparams:{ @@ -116,18 +120,20 @@ def setUp(self): hparams:{ key: 'bool_hparam' value: { bool_value: false } } group_name: 'group_2' start_time_secs: 314159 - '''), - metadata.SESSION_END_INFO_TAG: - self._serialized_plugin_data( - DATA_TYPE_SESSION_END_INFO, ''' + """, + ), + metadata.SESSION_END_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_END_INFO, + """ status: STATUS_SUCCESS end_time_secs: 314164 - ''') - }, - 'session_3': { - metadata.SESSION_START_INFO_TAG: - self._serialized_plugin_data( - DATA_TYPE_SESSION_START_INFO, ''' + """, + ), + }, + "session_3": { + metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_START_INFO, + """ hparams:{ key: 'initial_temp' value: { number_value: 280 } }, hparams:{ key: 'final_temp' value: { number_value: 100 } }, hparams:{ @@ -136,18 +142,20 @@ def setUp(self): hparams:{ key: 'bool_hparam' value: { bool_value: false } } group_name: 'group_2' start_time_secs: 314159 - '''), - metadata.SESSION_END_INFO_TAG: - self._serialized_plugin_data( - DATA_TYPE_SESSION_END_INFO, ''' + """, + ), + metadata.SESSION_END_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_END_INFO, + """ status: STATUS_FAILURE end_time_secs: 314164 - ''') - }, - 'session_4': { - metadata.SESSION_START_INFO_TAG: - self._serialized_plugin_data( - DATA_TYPE_SESSION_START_INFO, ''' + """, + ), + }, + "session_4": { + metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_START_INFO, + """ hparams:{ key: 'initial_temp' value: { number_value: 300 } }, hparams:{ key: 'final_temp' value: { number_value: 120 } }, hparams:{ @@ -159,18 
+167,20 @@ def setUp(self): }, group_name: 'group_3' start_time_secs: 314159 - '''), - metadata.SESSION_END_INFO_TAG: - self._serialized_plugin_data( - DATA_TYPE_SESSION_END_INFO, ''' + """, + ), + metadata.SESSION_END_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_END_INFO, + """ status: STATUS_UNKNOWN end_time_secs: 314164 - ''') - }, - 'session_5': { - metadata.SESSION_START_INFO_TAG: - self._serialized_plugin_data( - DATA_TYPE_SESSION_START_INFO, ''' + """, + ), + }, + "session_5": { + metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_START_INFO, + """ hparams:{ key: 'initial_temp' value: { number_value: 280 } }, hparams:{ key: 'final_temp' value: { number_value: 100 } }, hparams:{ @@ -179,113 +189,148 @@ def setUp(self): hparams:{ key: 'bool_hparam' value: { bool_value: false } } group_name: 'group_2' start_time_secs: 314159 - '''), - metadata.SESSION_END_INFO_TAG: - self._serialized_plugin_data( - DATA_TYPE_SESSION_END_INFO, ''' + """, + ), + metadata.SESSION_END_INFO_TAG: self._serialized_plugin_data( + DATA_TYPE_SESSION_END_INFO, + """ status: STATUS_SUCCESS end_time_secs: 314164 - ''') - }, - } - self._mock_multiplexer.Tensors.side_effect = self._mock_tensors - - # A mock version of EventMultiplexer.Tensors - def _mock_tensors(self, run, tag): - result_dict = { - 'session_1': { - 'current_temp': [ - TensorEvent( - wall_time=1, step=1, - tensor_proto=tf.compat.v1.make_tensor_proto(10.0)) - ], - 'delta_temp': [ - TensorEvent( - wall_time=1, step=1, - tensor_proto=tf.compat.v1.make_tensor_proto(20.0)), - TensorEvent( - wall_time=10, step=2, - tensor_proto=tf.compat.v1.make_tensor_proto(15.0)) - ], - 'optional_metric': [ - TensorEvent( - wall_time=1, step=1, - tensor_proto=tf.compat.v1.make_tensor_proto(20.0)), - TensorEvent( - wall_time=2, step=20, - tensor_proto=tf.compat.v1.make_tensor_proto(33.0)) - ] - }, - 'session_2': { - 'current_temp': [ - TensorEvent( - wall_time=1, step=1, - 
tensor_proto=tf.compat.v1.make_tensor_proto(100.0)), - ], - 'delta_temp': [ - TensorEvent( - wall_time=1, step=1, - tensor_proto=tf.compat.v1.make_tensor_proto(200.0)), - TensorEvent( - wall_time=11, step=3, - tensor_proto=tf.compat.v1.make_tensor_proto(150.0)) - ] - }, - 'session_3': { - 'current_temp': [ - TensorEvent( - wall_time=1, step=1, - tensor_proto=tf.compat.v1.make_tensor_proto(1.0)), - ], - 'delta_temp': [ - TensorEvent( - wall_time=1, step=1, - tensor_proto=tf.compat.v1.make_tensor_proto(2.0)), - TensorEvent( - wall_time=10, step=2, - tensor_proto=tf.compat.v1.make_tensor_proto(1.5)) - ] - }, - 'session_4': { - 'current_temp': [ - TensorEvent( - wall_time=1, step=1, - tensor_proto=tf.compat.v1.make_tensor_proto(101.0)), - ], - 'delta_temp': [ - TensorEvent( - wall_time=1, step=1, - tensor_proto=tf.compat.v1.make_tensor_proto(201.0)), - TensorEvent( - wall_time=10, step=2, - tensor_proto=tf.compat.v1.make_tensor_proto(-151.0)) - ] - }, - 'session_5': { - 'current_temp': [ - TensorEvent( - wall_time=1, step=1, - tensor_proto=tf.compat.v1.make_tensor_proto(52.0)), - ], - 'delta_temp': [ - TensorEvent( - wall_time=1, step=1, - tensor_proto=tf.compat.v1.make_tensor_proto(2.0)), - TensorEvent( - wall_time=10, step=2, - tensor_proto=tf.compat.v1.make_tensor_proto(-18)) - ] - }, - } - return result_dict[run][tag] - - def test_empty_request(self): - # Since we don't allow any statuses, result should be empty. 
- self.assertProtoEquals('total_size: 0', - self._run_handler(request='')) - - def test_no_filter_no_sort(self): - request = ''' + """, + ), + }, + } + self._mock_multiplexer.Tensors.side_effect = self._mock_tensors + + # A mock version of EventMultiplexer.Tensors + def _mock_tensors(self, run, tag): + result_dict = { + "session_1": { + "current_temp": [ + TensorEvent( + wall_time=1, + step=1, + tensor_proto=tf.compat.v1.make_tensor_proto(10.0), + ) + ], + "delta_temp": [ + TensorEvent( + wall_time=1, + step=1, + tensor_proto=tf.compat.v1.make_tensor_proto(20.0), + ), + TensorEvent( + wall_time=10, + step=2, + tensor_proto=tf.compat.v1.make_tensor_proto(15.0), + ), + ], + "optional_metric": [ + TensorEvent( + wall_time=1, + step=1, + tensor_proto=tf.compat.v1.make_tensor_proto(20.0), + ), + TensorEvent( + wall_time=2, + step=20, + tensor_proto=tf.compat.v1.make_tensor_proto(33.0), + ), + ], + }, + "session_2": { + "current_temp": [ + TensorEvent( + wall_time=1, + step=1, + tensor_proto=tf.compat.v1.make_tensor_proto(100.0), + ), + ], + "delta_temp": [ + TensorEvent( + wall_time=1, + step=1, + tensor_proto=tf.compat.v1.make_tensor_proto(200.0), + ), + TensorEvent( + wall_time=11, + step=3, + tensor_proto=tf.compat.v1.make_tensor_proto(150.0), + ), + ], + }, + "session_3": { + "current_temp": [ + TensorEvent( + wall_time=1, + step=1, + tensor_proto=tf.compat.v1.make_tensor_proto(1.0), + ), + ], + "delta_temp": [ + TensorEvent( + wall_time=1, + step=1, + tensor_proto=tf.compat.v1.make_tensor_proto(2.0), + ), + TensorEvent( + wall_time=10, + step=2, + tensor_proto=tf.compat.v1.make_tensor_proto(1.5), + ), + ], + }, + "session_4": { + "current_temp": [ + TensorEvent( + wall_time=1, + step=1, + tensor_proto=tf.compat.v1.make_tensor_proto(101.0), + ), + ], + "delta_temp": [ + TensorEvent( + wall_time=1, + step=1, + tensor_proto=tf.compat.v1.make_tensor_proto(201.0), + ), + TensorEvent( + wall_time=10, + step=2, + tensor_proto=tf.compat.v1.make_tensor_proto(-151.0), + ), + 
], + }, + "session_5": { + "current_temp": [ + TensorEvent( + wall_time=1, + step=1, + tensor_proto=tf.compat.v1.make_tensor_proto(52.0), + ), + ], + "delta_temp": [ + TensorEvent( + wall_time=1, + step=1, + tensor_proto=tf.compat.v1.make_tensor_proto(2.0), + ), + TensorEvent( + wall_time=10, + step=2, + tensor_proto=tf.compat.v1.make_tensor_proto(-18), + ), + ], + }, + } + return result_dict[run][tag] + + def test_empty_request(self): + # Since we don't allow any statuses, result should be empty. + self.assertProtoEquals("total_size: 0", self._run_handler(request="")) + + def test_no_filter_no_sort(self): + request = """ start_index: 0 slice_size: 3 allowed_statuses: [STATUS_UNKNOWN, @@ -293,10 +338,10 @@ def test_no_filter_no_sort(self): STATUS_FAILURE, STATUS_RUNNING] aggregation_type: AGGREGATION_AVG - ''' - response = self._run_handler(request) - self.assertProtoEquals( - ''' + """ + response = self._run_handler(request) + self.assertProtoEquals( + """ session_groups { name: "group_1" hparams { key: "bool_hparam" value { bool_value: true } } @@ -450,47 +495,52 @@ def test_no_filter_no_sort(self): } } total_size: 3 - ''', - response) + """, + response, + ) - def test_no_allowed_statuses(self): - request = ''' + def test_no_allowed_statuses(self): + request = """ start_index: 0 slice_size: 3 allowed_statuses: [] aggregation_type: AGGREGATION_AVG - ''' - response = self._run_handler(request) - self.assertEquals(len(response.session_groups), 0) + """ + response = self._run_handler(request) + self.assertEquals(len(response.session_groups), 0) - def test_some_allowed_statuses(self): - request = ''' + def test_some_allowed_statuses(self): + request = """ start_index: 0 slice_size: 3 allowed_statuses: [STATUS_UNKNOWN, STATUS_SUCCESS] aggregation_type: AGGREGATION_AVG - ''' - response = self._run_handler(request) - self.assertEquals( - _reduce_to_names(response.session_groups), - [('group_1', ['session_1']), - ('group_2', ['session_2', 'session_5']), - ('group_3', 
['session_4'])]) - - def test_some_allowed_statuses_empty_groups(self): - request = ''' + """ + response = self._run_handler(request) + self.assertEquals( + _reduce_to_names(response.session_groups), + [ + ("group_1", ["session_1"]), + ("group_2", ["session_2", "session_5"]), + ("group_3", ["session_4"]), + ], + ) + + def test_some_allowed_statuses_empty_groups(self): + request = """ start_index: 0 slice_size: 3 allowed_statuses: [STATUS_FAILURE] aggregation_type: AGGREGATION_AVG - ''' - response = self._run_handler(request) - self.assertEquals( - _reduce_to_names(response.session_groups), - [('group_2', ['session_3'])]) - - def test_aggregation_median_current_temp(self): - request = ''' + """ + response = self._run_handler(request) + self.assertEquals( + _reduce_to_names(response.session_groups), + [("group_2", ["session_3"])], + ) + + def test_aggregation_median_current_temp(self): + request = """ start_index: 0 slice_size: 3 allowed_statuses: [STATUS_UNKNOWN, @@ -499,24 +549,26 @@ def test_aggregation_median_current_temp(self): STATUS_RUNNING] aggregation_type: AGGREGATION_MEDIAN aggregation_metric: { tag: "current_temp" } - ''' - response = self._run_handler(request) - self.assertEquals(len(response.session_groups[1].metric_values), 2) - self.assertProtoEquals( - '''name { tag: "current_temp" } + """ + response = self._run_handler(request) + self.assertEquals(len(response.session_groups[1].metric_values), 2) + self.assertProtoEquals( + """name { tag: "current_temp" } value: 52.0 training_step: 1 - wall_time_secs: 1.0''', - response.session_groups[1].metric_values[0]) - self.assertProtoEquals( - '''name { tag: "delta_temp" } + wall_time_secs: 1.0""", + response.session_groups[1].metric_values[0], + ) + self.assertProtoEquals( + """name { tag: "delta_temp" } value: -18.0 training_step: 2 - wall_time_secs: 10.0''', - response.session_groups[1].metric_values[1]) + wall_time_secs: 10.0""", + response.session_groups[1].metric_values[1], + ) - def 
test_aggregation_median_delta_temp(self): - request = ''' + def test_aggregation_median_delta_temp(self): + request = """ start_index: 0 slice_size: 3 allowed_statuses: [STATUS_UNKNOWN, @@ -525,24 +577,26 @@ def test_aggregation_median_delta_temp(self): STATUS_RUNNING] aggregation_type: AGGREGATION_MEDIAN aggregation_metric: { tag: "delta_temp" } - ''' - response = self._run_handler(request) - self.assertEquals(len(response.session_groups[1].metric_values), 2) - self.assertProtoEquals( - '''name { tag: "current_temp" } + """ + response = self._run_handler(request) + self.assertEquals(len(response.session_groups[1].metric_values), 2) + self.assertProtoEquals( + """name { tag: "current_temp" } value: 1.0 training_step: 1 - wall_time_secs: 1.0''', - response.session_groups[1].metric_values[0]) - self.assertProtoEquals( - '''name { tag: "delta_temp" } + wall_time_secs: 1.0""", + response.session_groups[1].metric_values[0], + ) + self.assertProtoEquals( + """name { tag: "delta_temp" } value: 1.5 training_step: 2 - wall_time_secs: 10.0''', - response.session_groups[1].metric_values[1]) + wall_time_secs: 10.0""", + response.session_groups[1].metric_values[1], + ) - def test_aggregation_max_current_temp(self): - request = ''' + def test_aggregation_max_current_temp(self): + request = """ start_index: 0 slice_size: 3 allowed_statuses: [STATUS_UNKNOWN, @@ -551,24 +605,26 @@ def test_aggregation_max_current_temp(self): STATUS_RUNNING] aggregation_type: AGGREGATION_MAX aggregation_metric: { tag: "current_temp" } - ''' - response = self._run_handler(request) - self.assertEquals(len(response.session_groups[1].metric_values), 2) - self.assertProtoEquals( - '''name { tag: "current_temp" } + """ + response = self._run_handler(request) + self.assertEquals(len(response.session_groups[1].metric_values), 2) + self.assertProtoEquals( + """name { tag: "current_temp" } value: 100 training_step: 1 - wall_time_secs: 1.0''', - response.session_groups[1].metric_values[0]) - 
self.assertProtoEquals( - '''name { tag: "delta_temp" } + wall_time_secs: 1.0""", + response.session_groups[1].metric_values[0], + ) + self.assertProtoEquals( + """name { tag: "delta_temp" } value: 150.0 training_step: 3 - wall_time_secs: 11.0''', - response.session_groups[1].metric_values[1]) + wall_time_secs: 11.0""", + response.session_groups[1].metric_values[1], + ) - def test_aggregation_max_delta_temp(self): - request = ''' + def test_aggregation_max_delta_temp(self): + request = """ start_index: 0 slice_size: 3 allowed_statuses: [STATUS_UNKNOWN, @@ -577,24 +633,26 @@ def test_aggregation_max_delta_temp(self): STATUS_RUNNING] aggregation_type: AGGREGATION_MAX aggregation_metric: { tag: "delta_temp" } - ''' - response = self._run_handler(request) - self.assertEquals(len(response.session_groups[1].metric_values), 2) - self.assertProtoEquals( - '''name { tag: "current_temp" } + """ + response = self._run_handler(request) + self.assertEquals(len(response.session_groups[1].metric_values), 2) + self.assertProtoEquals( + """name { tag: "current_temp" } value: 100.0 training_step: 1 - wall_time_secs: 1.0''', - response.session_groups[1].metric_values[0]) - self.assertProtoEquals( - '''name { tag: "delta_temp" } + wall_time_secs: 1.0""", + response.session_groups[1].metric_values[0], + ) + self.assertProtoEquals( + """name { tag: "delta_temp" } value: 150.0 training_step: 3 - wall_time_secs: 11.0''', - response.session_groups[1].metric_values[1]) + wall_time_secs: 11.0""", + response.session_groups[1].metric_values[1], + ) - def test_aggregation_min_current_temp(self): - request = ''' + def test_aggregation_min_current_temp(self): + request = """ start_index: 0 slice_size: 3 allowed_statuses: [STATUS_UNKNOWN, @@ -603,24 +661,26 @@ def test_aggregation_min_current_temp(self): STATUS_RUNNING] aggregation_type: AGGREGATION_MIN aggregation_metric: { tag: "current_temp" } - ''' - response = self._run_handler(request) - 
self.assertEquals(len(response.session_groups[1].metric_values), 2) - self.assertProtoEquals( - '''name { tag: "current_temp" } + """ + response = self._run_handler(request) + self.assertEquals(len(response.session_groups[1].metric_values), 2) + self.assertProtoEquals( + """name { tag: "current_temp" } value: 1.0 training_step: 1 - wall_time_secs: 1.0''', - response.session_groups[1].metric_values[0]) - self.assertProtoEquals( - '''name { tag: "delta_temp" } + wall_time_secs: 1.0""", + response.session_groups[1].metric_values[0], + ) + self.assertProtoEquals( + """name { tag: "delta_temp" } value: 1.5 training_step: 2 - wall_time_secs: 10.0''', - response.session_groups[1].metric_values[1]) + wall_time_secs: 10.0""", + response.session_groups[1].metric_values[1], + ) - def test_aggregation_min_delta_temp(self): - request = ''' + def test_aggregation_min_delta_temp(self): + request = """ start_index: 0 slice_size: 3 allowed_statuses: [STATUS_UNKNOWN, @@ -629,38 +689,41 @@ def test_aggregation_min_delta_temp(self): STATUS_RUNNING] aggregation_type: AGGREGATION_MIN aggregation_metric: { tag: "delta_temp" } - ''' - response = self._run_handler(request) - self.assertEquals(len(response.session_groups[1].metric_values), 2) - self.assertProtoEquals( - '''name { tag: "current_temp" } + """ + response = self._run_handler(request) + self.assertEquals(len(response.session_groups[1].metric_values), 2) + self.assertProtoEquals( + """name { tag: "current_temp" } value: 52.0 training_step: 1 - wall_time_secs: 1.0''', - response.session_groups[1].metric_values[0]) - self.assertProtoEquals( - '''name { tag: "delta_temp" } + wall_time_secs: 1.0""", + response.session_groups[1].metric_values[0], + ) + self.assertProtoEquals( + """name { tag: "delta_temp" } value: -18.0 training_step: 2 - wall_time_secs: 10.0''', - response.session_groups[1].metric_values[1]) + wall_time_secs: 10.0""", + response.session_groups[1].metric_values[1], + ) - def test_no_filter_no_sort_partial_slice(self): 
- self._verify_handler( - request=''' + def test_no_filter_no_sort_partial_slice(self): + self._verify_handler( + request=""" start_index: 1 slice_size: 1 allowed_statuses: [STATUS_UNKNOWN, STATUS_SUCCESS, STATUS_FAILURE, STATUS_RUNNING] - ''', - expected_session_group_names=['group_2'], - expected_total_size=3) + """, + expected_session_group_names=["group_2"], + expected_total_size=3, + ) - def test_no_filter_exclude_missing_values(self): - self._verify_handler( - request=''' + def test_no_filter_exclude_missing_values(self): + self._verify_handler( + request=""" col_params: { metric: { tag: 'optional_metric' } exclude_missing_values: true @@ -671,13 +734,14 @@ def test_no_filter_exclude_missing_values(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_1'], - expected_total_size=1) + """, + expected_session_group_names=["group_1"], + expected_total_size=1, + ) - def test_filter_regexp(self): - self._verify_handler( - request=''' + def test_filter_regexp(self): + self._verify_handler( + request=""" col_params: { hparam: 'string_hparam' filter_regexp: 'AA' @@ -688,12 +752,13 @@ def test_filter_regexp(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_2'], - expected_total_size=1) - # Test filtering out all session groups. - self._verify_handler( - request=''' + """, + expected_session_group_names=["group_2"], + expected_total_size=1, + ) + # Test filtering out all session groups. 
+ self._verify_handler( + request=""" col_params: { hparam: 'string_hparam' filter_regexp: 'a string_100' @@ -704,13 +769,14 @@ def test_filter_regexp(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=[], - expected_total_size=0) + """, + expected_session_group_names=[], + expected_total_size=0, + ) - def test_filter_interval(self): - self._verify_handler( - request=''' + def test_filter_interval(self): + self._verify_handler( + request=""" col_params: { hparam: 'initial_temp' filter_interval: { min_value: 270 max_value: 282 } @@ -721,13 +787,14 @@ def test_filter_interval(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_1', 'group_2'], - expected_total_size=2) + """, + expected_session_group_names=["group_1", "group_2"], + expected_total_size=2, + ) - def test_filter_discrete_set(self): - self._verify_handler( - request=''' + def test_filter_discrete_set(self): + self._verify_handler( + request=""" col_params: { metric: { tag: 'current_temp' } filter_discrete: { values: [{ number_value: 101.0 }, @@ -739,13 +806,14 @@ def test_filter_discrete_set(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_1', 'group_3'], - expected_total_size=2) + """, + expected_session_group_names=["group_1", "group_3"], + expected_total_size=2, + ) - def test_filter_multiple_columns(self): - self._verify_handler( - request=''' + def test_filter_multiple_columns(self): + self._verify_handler( + request=""" col_params: { metric: { tag: 'current_temp' } filter_discrete: { values: [{ number_value: 101.0 }, @@ -761,13 +829,14 @@ def test_filter_multiple_columns(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_1'], - expected_total_size=1) + """, + expected_session_group_names=["group_1"], + expected_total_size=1, + ) - def test_filter_single_column_with_missing_values(self): - self._verify_handler( - 
request=''' + def test_filter_single_column_with_missing_values(self): + self._verify_handler( + request=""" col_params: { hparam: 'optional_string_hparam' filter_regexp: 'B' @@ -779,11 +848,12 @@ def test_filter_single_column_with_missing_values(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_3'], - expected_total_size=1) - self._verify_handler( - request=''' + """, + expected_session_group_names=["group_3"], + expected_total_size=1, + ) + self._verify_handler( + request=""" col_params: { hparam: 'optional_string_hparam' filter_regexp: 'B' @@ -795,12 +865,13 @@ def test_filter_single_column_with_missing_values(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_1', 'group_2', 'group_3'], - expected_total_size=3) + """, + expected_session_group_names=["group_1", "group_2", "group_3"], + expected_total_size=3, + ) - self._verify_handler( - request=''' + self._verify_handler( + request=""" col_params: { metric: { tag: 'optional_metric' } filter_discrete: { values: { number_value: 33.0 } } @@ -812,13 +883,14 @@ def test_filter_single_column_with_missing_values(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_1'], - expected_total_size=1) + """, + expected_session_group_names=["group_1"], + expected_total_size=1, + ) - def test_sort_one_column(self): - self._verify_handler( - request=''' + def test_sort_one_column(self): + self._verify_handler( + request=""" col_params: { metric: { tag: 'delta_temp' } order: ORDER_ASC @@ -829,11 +901,12 @@ def test_sort_one_column(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_3', 'group_1', 'group_2'], - expected_total_size=3) - self._verify_handler( - request=''' + """, + expected_session_group_names=["group_3", "group_1", "group_2"], + expected_total_size=3, + ) + self._verify_handler( + request=""" col_params: { hparam: 'string_hparam' 
order: ORDER_ASC @@ -844,12 +917,13 @@ def test_sort_one_column(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_2', 'group_1', 'group_3'], - expected_total_size=3) - # Test descending order. - self._verify_handler( - request=''' + """, + expected_session_group_names=["group_2", "group_1", "group_3"], + expected_total_size=3, + ) + # Test descending order. + self._verify_handler( + request=""" col_params: { hparam: 'string_hparam' order: ORDER_DESC @@ -860,13 +934,14 @@ def test_sort_one_column(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_3', 'group_1', 'group_2'], - expected_total_size=3) + """, + expected_session_group_names=["group_3", "group_1", "group_2"], + expected_total_size=3, + ) - def test_sort_multiple_columns(self): - self._verify_handler( - request=''' + def test_sort_multiple_columns(self): + self._verify_handler( + request=""" col_params: { hparam: 'bool_hparam' order: ORDER_ASC @@ -881,12 +956,13 @@ def test_sort_multiple_columns(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_2', 'group_3', 'group_1'], - expected_total_size=3) - # Primary key in descending order. Secondary key in ascending order. - self._verify_handler( - request=''' + """, + expected_session_group_names=["group_2", "group_3", "group_1"], + expected_total_size=3, + ) + # Primary key in descending order. Secondary key in ascending order. 
+ self._verify_handler( + request=""" col_params: { hparam: 'bool_hparam' order: ORDER_DESC @@ -901,13 +977,14 @@ def test_sort_multiple_columns(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_3', 'group_1', 'group_2'], - expected_total_size=3) + """, + expected_session_group_names=["group_3", "group_1", "group_2"], + expected_total_size=3, + ) - def test_sort_one_column_with_missing_values(self): - self._verify_handler( - request=''' + def test_sort_one_column_with_missing_values(self): + self._verify_handler( + request=""" col_params: { metric: { tag: 'optional_metric' } order: ORDER_ASC @@ -919,11 +996,12 @@ def test_sort_one_column_with_missing_values(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_1', 'group_2', 'group_3'], - expected_total_size=3) - self._verify_handler( - request=''' + """, + expected_session_group_names=["group_1", "group_2", "group_3"], + expected_total_size=3, + ) + self._verify_handler( + request=""" col_params: { metric: { tag: 'optional_metric' } order: ORDER_ASC @@ -935,11 +1013,12 @@ def test_sort_one_column_with_missing_values(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_2', 'group_3', 'group_1'], - expected_total_size=3) - self._verify_handler( - request=''' + """, + expected_session_group_names=["group_2", "group_3", "group_1"], + expected_total_size=3, + ) + self._verify_handler( + request=""" col_params: { hparam: 'optional_string_hparam' order: ORDER_ASC @@ -951,11 +1030,12 @@ def test_sort_one_column_with_missing_values(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_3', 'group_1', 'group_2'], - expected_total_size=3) - self._verify_handler( - request=''' + """, + expected_session_group_names=["group_3", "group_1", "group_2"], + expected_total_size=3, + ) + self._verify_handler( + request=""" col_params: { hparam: 
'optional_string_hparam' order: ORDER_ASC @@ -967,51 +1047,58 @@ def test_sort_one_column_with_missing_values(self): STATUS_RUNNING] start_index: 0 slice_size: 3 - ''', - expected_session_group_names=['group_1', 'group_2', 'group_3'], - expected_total_size=3) - - def _run_handler(self, request): - request_proto = api_pb2.ListSessionGroupsRequest() - text_format.Merge(request, request_proto) - handler = list_session_groups.Handler( - backend_context.Context(self._mock_tb_context), - request_proto) - response = handler.run() - # Sort the metric values repeated field in each session group to - # canonicalize the response. - for group in response.session_groups: - group.metric_values.sort(key=operator.attrgetter('name.tag')) - return response - - def _verify_handler( - self, request, expected_session_group_names, expected_total_size): - response = self._run_handler(request) - self.assertEqual(expected_session_group_names, - [sg.name for sg in response.session_groups]) - self.assertEqual(expected_total_size, response.total_size) - - def _serialized_plugin_data(self, data_oneof_field, text_protobuffer): - oneof_type_dict = { - DATA_TYPE_EXPERIMENT: api_pb2.Experiment, - DATA_TYPE_SESSION_START_INFO: plugin_data_pb2.SessionStartInfo, - DATA_TYPE_SESSION_END_INFO: plugin_data_pb2.SessionEndInfo - } - protobuffer = text_format.Merge(text_protobuffer, - oneof_type_dict[data_oneof_field]()) - plugin_data = plugin_data_pb2.HParamsPluginData() - getattr(plugin_data, data_oneof_field).CopyFrom(protobuffer) - return metadata.create_summary_metadata(plugin_data).plugin_data.content + """, + expected_session_group_names=["group_1", "group_2", "group_3"], + expected_total_size=3, + ) + + def _run_handler(self, request): + request_proto = api_pb2.ListSessionGroupsRequest() + text_format.Merge(request, request_proto) + handler = list_session_groups.Handler( + backend_context.Context(self._mock_tb_context), request_proto + ) + response = handler.run() + # Sort the metric values 
repeated field in each session group to + # canonicalize the response. + for group in response.session_groups: + group.metric_values.sort(key=operator.attrgetter("name.tag")) + return response + + def _verify_handler( + self, request, expected_session_group_names, expected_total_size + ): + response = self._run_handler(request) + self.assertEqual( + expected_session_group_names, + [sg.name for sg in response.session_groups], + ) + self.assertEqual(expected_total_size, response.total_size) + + def _serialized_plugin_data(self, data_oneof_field, text_protobuffer): + oneof_type_dict = { + DATA_TYPE_EXPERIMENT: api_pb2.Experiment, + DATA_TYPE_SESSION_START_INFO: plugin_data_pb2.SessionStartInfo, + DATA_TYPE_SESSION_END_INFO: plugin_data_pb2.SessionEndInfo, + } + protobuffer = text_format.Merge( + text_protobuffer, oneof_type_dict[data_oneof_field]() + ) + plugin_data = plugin_data_pb2.HParamsPluginData() + getattr(plugin_data, data_oneof_field).CopyFrom(protobuffer) + return metadata.create_summary_metadata(plugin_data).plugin_data.content def _reduce_session_group_to_names(session_group): - return [session.name for session in session_group.sessions] + return [session.name for session in session_group.sessions] def _reduce_to_names(session_groups): - return [(session_group.name, _reduce_session_group_to_names(session_group)) - for session_group in session_groups] + return [ + (session_group.name, _reduce_session_group_to_names(session_group)) + for session_group in session_groups + ] -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/hparams/metadata.py b/tensorboard/plugins/hparams/metadata.py index 1581787d76..92f2dc4141 100644 --- a/tensorboard/plugins/hparams/metadata.py +++ b/tensorboard/plugins/hparams/metadata.py @@ -23,90 +23,99 @@ from tensorboard.plugins.hparams import plugin_data_pb2 -PLUGIN_NAME = 'hparams' +PLUGIN_NAME = "hparams" PLUGIN_DATA_VERSION = 0 -EXPERIMENT_TAG = 
'_hparams_/experiment' -SESSION_START_INFO_TAG = '_hparams_/session_start_info' -SESSION_END_INFO_TAG = '_hparams_/session_end_info' +EXPERIMENT_TAG = "_hparams_/experiment" +SESSION_START_INFO_TAG = "_hparams_/session_start_info" +SESSION_END_INFO_TAG = "_hparams_/session_end_info" def create_summary_metadata(hparams_plugin_data_pb): - """Returns a summary metadata for the HParams plugin. - - Returns a summary_pb2.SummaryMetadata holding a copy of the given - HParamsPluginData message in its plugin_data.content field. - Sets the version field of the hparams_plugin_data_pb copy to - PLUGIN_DATA_VERSION. - - Args: - hparams_plugin_data_pb: the HParamsPluginData protobuffer to use. - """ - if not isinstance(hparams_plugin_data_pb, plugin_data_pb2.HParamsPluginData): - raise TypeError('Needed an instance of plugin_data_pb2.HParamsPluginData.' - ' Got: %s' % type(hparams_plugin_data_pb)) - content = plugin_data_pb2.HParamsPluginData() - content.CopyFrom(hparams_plugin_data_pb) - content.version = PLUGIN_DATA_VERSION - return summary_pb2.SummaryMetadata( - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name=PLUGIN_NAME, content=content.SerializeToString())) + """Returns a summary metadata for the HParams plugin. + + Returns a summary_pb2.SummaryMetadata holding a copy of the given + HParamsPluginData message in its plugin_data.content field. + Sets the version field of the hparams_plugin_data_pb copy to + PLUGIN_DATA_VERSION. + + Args: + hparams_plugin_data_pb: the HParamsPluginData protobuffer to use. + """ + if not isinstance( + hparams_plugin_data_pb, plugin_data_pb2.HParamsPluginData + ): + raise TypeError( + "Needed an instance of plugin_data_pb2.HParamsPluginData." 
+ " Got: %s" % type(hparams_plugin_data_pb) + ) + content = plugin_data_pb2.HParamsPluginData() + content.CopyFrom(hparams_plugin_data_pb) + content.version = PLUGIN_DATA_VERSION + return summary_pb2.SummaryMetadata( + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, content=content.SerializeToString() + ) + ) def parse_experiment_plugin_data(content): - """Returns the experiment from HParam's SummaryMetadata.plugin_data.content. + """Returns the experiment from HParam's + SummaryMetadata.plugin_data.content. - Raises HParamsError if the content doesn't have 'experiment' set or - this file is incompatible with the version of the metadata stored. + Raises HParamsError if the content doesn't have 'experiment' set or + this file is incompatible with the version of the metadata stored. - Args: - content: The SummaryMetadata.plugin_data.content to use. - """ - return _parse_plugin_data_as(content, 'experiment') + Args: + content: The SummaryMetadata.plugin_data.content to use. + """ + return _parse_plugin_data_as(content, "experiment") def parse_session_start_info_plugin_data(content): - """Returns session_start_info from the plugin_data.content. + """Returns session_start_info from the plugin_data.content. - Raises HParamsError if the content doesn't have 'session_start_info' set or - this file is incompatible with the version of the metadata stored. + Raises HParamsError if the content doesn't have 'session_start_info' set or + this file is incompatible with the version of the metadata stored. - Args: - content: The SummaryMetadata.plugin_data.content to use. - """ - return _parse_plugin_data_as(content, 'session_start_info') + Args: + content: The SummaryMetadata.plugin_data.content to use. + """ + return _parse_plugin_data_as(content, "session_start_info") def parse_session_end_info_plugin_data(content): - """Returns session_end_info from the plugin_data.content. + """Returns session_end_info from the plugin_data.content. 
- Raises HParamsError if the content doesn't have 'session_end_info' set or - this file is incompatible with the version of the metadata stored. + Raises HParamsError if the content doesn't have 'session_end_info' set or + this file is incompatible with the version of the metadata stored. - Args: - content: The SummaryMetadata.plugin_data.content to use. - """ - return _parse_plugin_data_as(content, 'session_end_info') + Args: + content: The SummaryMetadata.plugin_data.content to use. + """ + return _parse_plugin_data_as(content, "session_end_info") def _parse_plugin_data_as(content, data_oneof_field): - """Returns a data oneof's field from plugin_data.content. - - Raises HParamsError if the content doesn't have 'data_oneof_field' set or - this file is incompatible with the version of the metadata stored. - - Args: - content: The SummaryMetadata.plugin_data.content to use. - data_oneof_field: string. The name of the data oneof field to return. - """ - plugin_data = plugin_data_pb2.HParamsPluginData.FromString(content) - if plugin_data.version != PLUGIN_DATA_VERSION: - raise error.HParamsError( - 'Only supports plugin_data version: %s; found: %s in: %s' % - (PLUGIN_DATA_VERSION, plugin_data.version, plugin_data)) - if not plugin_data.HasField(data_oneof_field): - raise error.HParamsError( - 'Expected plugin_data.%s to be set. Got: %s' % - (data_oneof_field, plugin_data)) - return getattr(plugin_data, data_oneof_field) + """Returns a data oneof's field from plugin_data.content. + + Raises HParamsError if the content doesn't have 'data_oneof_field' set or + this file is incompatible with the version of the metadata stored. + + Args: + content: The SummaryMetadata.plugin_data.content to use. + data_oneof_field: string. The name of the data oneof field to return. 
+ """ + plugin_data = plugin_data_pb2.HParamsPluginData.FromString(content) + if plugin_data.version != PLUGIN_DATA_VERSION: + raise error.HParamsError( + "Only supports plugin_data version: %s; found: %s in: %s" + % (PLUGIN_DATA_VERSION, plugin_data.version, plugin_data) + ) + if not plugin_data.HasField(data_oneof_field): + raise error.HParamsError( + "Expected plugin_data.%s to be set. Got: %s" + % (data_oneof_field, plugin_data) + ) + return getattr(plugin_data, data_oneof_field) diff --git a/tensorboard/plugins/hparams/metrics.py b/tensorboard/plugins/hparams/metrics.py index 442f7bd0df..f8e7b88c12 100644 --- a/tensorboard/plugins/hparams/metrics.py +++ b/tensorboard/plugins/hparams/metrics.py @@ -27,53 +27,57 @@ def run_tag_from_session_and_metric(session_name, metric_name): - """Returns a (run,tag) tuple storing the evaluations of the specified metric. + """Returns a (run,tag) tuple storing the evaluations of the specified + metric. - Args: - session_name: str. - metric_name: MetricName protobuffer. - Returns: (run, tag) tuple. - """ - assert isinstance(session_name, six.string_types) - assert isinstance(metric_name, api_pb2.MetricName) - # os.path.join() will append a final slash if the group is empty; it seems - # like multiplexer.Tensors won't recognize paths that end with a '/' so - # we normalize the result of os.path.join() to remove the final '/' in that - # case. - run = os.path.normpath(os.path.join(session_name, metric_name.group)) - tag = metric_name.tag - return run, tag + Args: + session_name: str. + metric_name: MetricName protobuffer. + Returns: (run, tag) tuple. + """ + assert isinstance(session_name, six.string_types) + assert isinstance(metric_name, api_pb2.MetricName) + # os.path.join() will append a final slash if the group is empty; it seems + # like multiplexer.Tensors won't recognize paths that end with a '/' so + # we normalize the result of os.path.join() to remove the final '/' in that + # case. 
+ run = os.path.normpath(os.path.join(session_name, metric_name.group)) + tag = metric_name.tag + return run, tag def last_metric_eval(multiplexer, session_name, metric_name): - """Returns the last evaluations of the given metric at the given session. + """Returns the last evaluations of the given metric at the given session. - Args: - multiplexer: The EventMultiplexer instance allowing access to - the exported summary data. - session_name: String. The session name for which to get the metric - evaluations. - metric_name: api_pb2.MetricName proto. The name of the metric to use. + Args: + multiplexer: The EventMultiplexer instance allowing access to + the exported summary data. + session_name: String. The session name for which to get the metric + evaluations. + metric_name: api_pb2.MetricName proto. The name of the metric to use. - Returns: - A 3-tuples, of the form [wall-time, step, value], denoting - the last evaluation of the metric, where wall-time denotes the wall time - in seconds since UNIX epoch of the time of the evaluation, step denotes - the training step at which the model is evaluated, and value denotes the - (scalar real) value of the metric. + Returns: + A 3-tuples, of the form [wall-time, step, value], denoting + the last evaluation of the metric, where wall-time denotes the wall time + in seconds since UNIX epoch of the time of the evaluation, step denotes + the training step at which the model is evaluated, and value denotes the + (scalar real) value of the metric. - Raises: - KeyError if the given session does not have the metric. - """ - try: - run, tag = run_tag_from_session_and_metric(session_name, metric_name) - tensor_events = multiplexer.Tensors(run=run, tag=tag) - except KeyError as e: - raise KeyError( - 'Can\'t find metric %s for session: %s. Underlying error message: %s' - % (metric_name, session_name, e)) - last_event = tensor_events[-1] - # TODO(erez): Raise HParamsError if the tensor is not a 0-D real scalar. 
- return (last_event.wall_time, - last_event.step, - tensor_util.make_ndarray(last_event.tensor_proto).item()) + Raises: + KeyError if the given session does not have the metric. + """ + try: + run, tag = run_tag_from_session_and_metric(session_name, metric_name) + tensor_events = multiplexer.Tensors(run=run, tag=tag) + except KeyError as e: + raise KeyError( + "Can't find metric %s for session: %s. Underlying error message: %s" + % (metric_name, session_name, e) + ) + last_event = tensor_events[-1] + # TODO(erez): Raise HParamsError if the tensor is not a 0-D real scalar. + return ( + last_event.wall_time, + last_event.step, + tensor_util.make_ndarray(last_event.tensor_proto).item(), + ) diff --git a/tensorboard/plugins/hparams/summary.py b/tensorboard/plugins/hparams/summary.py index 30db75dc6b..0d8234d0b6 100644 --- a/tensorboard/plugins/hparams/summary.py +++ b/tensorboard/plugins/hparams/summary.py @@ -47,145 +47,151 @@ def experiment_pb( - hparam_infos, - metric_infos, - user='', - description='', - time_created_secs=None): - """Creates a summary that defines a hyperparameter-tuning experiment. - - Args: - hparam_infos: Array of api_pb2.HParamInfo messages. Describes the - hyperparameters used in the experiment. - metric_infos: Array of api_pb2.MetricInfo messages. Describes the metrics - used in the experiment. See the documentation at the top of this file - for how to populate this. - user: String. An id for the user running the experiment - description: String. A description for the experiment. May contain markdown. - time_created_secs: float. The time the experiment is created in seconds - since the UNIX epoch. If None uses the current time. - - Returns: - A summary protobuffer containing the experiment definition. 
- """ - if time_created_secs is None: - time_created_secs = time.time() - experiment = api_pb2.Experiment( - description=description, - user=user, - time_created_secs=time_created_secs, - hparam_infos=hparam_infos, - metric_infos=metric_infos) - return _summary(metadata.EXPERIMENT_TAG, - plugin_data_pb2.HParamsPluginData(experiment=experiment)) - - -def session_start_pb(hparams, - model_uri='', - monitor_url='', - group_name='', - start_time_secs=None): - """Constructs a SessionStartInfo protobuffer. - - Creates a summary that contains a training session metadata information. - One such summary per training session should be created. Each should have - a different run. - - Args: - hparams: A dictionary with string keys. Describes the hyperparameter values - used in the session, mapping each hyperparameter name to its value. - Supported value types are `bool`, `int`, `float`, `str`, `list`, - `tuple`. - The type of value must correspond to the type of hyperparameter - (defined in the corresponding api_pb2.HParamInfo member of the - Experiment protobuf) as follows: - - +-----------------+---------------------------------+ - |Hyperparameter | Allowed (Python) value types | - |type | | - +-----------------+---------------------------------+ - |DATA_TYPE_BOOL | bool | - |DATA_TYPE_FLOAT64| int, float | - |DATA_TYPE_STRING | six.string_types, tuple, list | - +-----------------+---------------------------------+ - - Tuple and list instances will be converted to their string - representation. - model_uri: See the comment for the field with the same name of - plugin_data_pb2.SessionStartInfo. - monitor_url: See the comment for the field with the same name of + hparam_infos, metric_infos, user="", description="", time_created_secs=None +): + """Creates a summary that defines a hyperparameter-tuning experiment. + + Args: + hparam_infos: Array of api_pb2.HParamInfo messages. Describes the + hyperparameters used in the experiment. 
+ metric_infos: Array of api_pb2.MetricInfo messages. Describes the metrics + used in the experiment. See the documentation at the top of this file + for how to populate this. + user: String. An id for the user running the experiment + description: String. A description for the experiment. May contain markdown. + time_created_secs: float. The time the experiment is created in seconds + since the UNIX epoch. If None uses the current time. + + Returns: + A summary protobuffer containing the experiment definition. + """ + if time_created_secs is None: + time_created_secs = time.time() + experiment = api_pb2.Experiment( + description=description, + user=user, + time_created_secs=time_created_secs, + hparam_infos=hparam_infos, + metric_infos=metric_infos, + ) + return _summary( + metadata.EXPERIMENT_TAG, + plugin_data_pb2.HParamsPluginData(experiment=experiment), + ) + + +def session_start_pb( + hparams, model_uri="", monitor_url="", group_name="", start_time_secs=None +): + """Constructs a SessionStartInfo protobuffer. + + Creates a summary that contains a training session metadata information. + One such summary per training session should be created. Each should have + a different run. + + Args: + hparams: A dictionary with string keys. Describes the hyperparameter values + used in the session, mapping each hyperparameter name to its value. + Supported value types are `bool`, `int`, `float`, `str`, `list`, + `tuple`. 
+ The type of value must correspond to the type of hyperparameter + (defined in the corresponding api_pb2.HParamInfo member of the + Experiment protobuf) as follows: + + +-----------------+---------------------------------+ + |Hyperparameter | Allowed (Python) value types | + |type | | + +-----------------+---------------------------------+ + |DATA_TYPE_BOOL | bool | + |DATA_TYPE_FLOAT64| int, float | + |DATA_TYPE_STRING | six.string_types, tuple, list | + +-----------------+---------------------------------+ + + Tuple and list instances will be converted to their string + representation. + model_uri: See the comment for the field with the same name of plugin_data_pb2.SessionStartInfo. - group_name: See the comment for the field with the same name of - plugin_data_pb2.SessionStartInfo. - start_time_secs: float. The time to use as the session start time. - Represented as seconds since the UNIX epoch. If None uses - the current time. - Returns: - The summary protobuffer mentioned above. - """ - if start_time_secs is None: - start_time_secs = time.time() - session_start_info = plugin_data_pb2.SessionStartInfo( - model_uri=model_uri, - monitor_url=monitor_url, - group_name=group_name, - start_time_secs=start_time_secs) - for (hp_name, hp_val) in six.iteritems(hparams): - if isinstance(hp_val, (float, int)): - session_start_info.hparams[hp_name].number_value = hp_val - elif isinstance(hp_val, six.string_types): - session_start_info.hparams[hp_name].string_value = hp_val - elif isinstance(hp_val, bool): - session_start_info.hparams[hp_name].bool_value = hp_val - elif isinstance(hp_val, (list, tuple)): - session_start_info.hparams[hp_name].string_value = str(hp_val) - else: - raise TypeError('hparams[%s]=%s has type: %s which is not supported' % - (hp_name, hp_val, type(hp_val))) - return _summary(metadata.SESSION_START_INFO_TAG, - plugin_data_pb2.HParamsPluginData( - session_start_info=session_start_info)) + monitor_url: See the comment for the field with the same name 
of + plugin_data_pb2.SessionStartInfo. + group_name: See the comment for the field with the same name of + plugin_data_pb2.SessionStartInfo. + start_time_secs: float. The time to use as the session start time. + Represented as seconds since the UNIX epoch. If None uses + the current time. + Returns: + The summary protobuffer mentioned above. + """ + if start_time_secs is None: + start_time_secs = time.time() + session_start_info = plugin_data_pb2.SessionStartInfo( + model_uri=model_uri, + monitor_url=monitor_url, + group_name=group_name, + start_time_secs=start_time_secs, + ) + for (hp_name, hp_val) in six.iteritems(hparams): + if isinstance(hp_val, (float, int)): + session_start_info.hparams[hp_name].number_value = hp_val + elif isinstance(hp_val, six.string_types): + session_start_info.hparams[hp_name].string_value = hp_val + elif isinstance(hp_val, bool): + session_start_info.hparams[hp_name].bool_value = hp_val + elif isinstance(hp_val, (list, tuple)): + session_start_info.hparams[hp_name].string_value = str(hp_val) + else: + raise TypeError( + "hparams[%s]=%s has type: %s which is not supported" + % (hp_name, hp_val, type(hp_val)) + ) + return _summary( + metadata.SESSION_START_INFO_TAG, + plugin_data_pb2.HParamsPluginData( + session_start_info=session_start_info + ), + ) def session_end_pb(status, end_time_secs=None): - """Constructs a SessionEndInfo protobuffer. - - Creates a summary that contains status information for a completed - training session. Should be exported after the training session is completed. - One such summary per training session should be created. Each should have - a different run. - - Args: - status: A tensorboard.hparams.Status enumeration value denoting the - status of the session. - end_time_secs: float. The time to use as the session end time. Represented - as seconds since the unix epoch. If None uses the current time. - - Returns: - The summary protobuffer mentioned above. 
- """ - if end_time_secs is None: - end_time_secs = time.time() - - session_end_info = plugin_data_pb2.SessionEndInfo(status=status, - end_time_secs=end_time_secs) - return _summary(metadata.SESSION_END_INFO_TAG, - plugin_data_pb2.HParamsPluginData( - session_end_info=session_end_info)) + """Constructs a SessionEndInfo protobuffer. + + Creates a summary that contains status information for a completed + training session. Should be exported after the training session is completed. + One such summary per training session should be created. Each should have + a different run. + + Args: + status: A tensorboard.hparams.Status enumeration value denoting the + status of the session. + end_time_secs: float. The time to use as the session end time. Represented + as seconds since the unix epoch. If None uses the current time. + + Returns: + The summary protobuffer mentioned above. + """ + if end_time_secs is None: + end_time_secs = time.time() + + session_end_info = plugin_data_pb2.SessionEndInfo( + status=status, end_time_secs=end_time_secs + ) + return _summary( + metadata.SESSION_END_INFO_TAG, + plugin_data_pb2.HParamsPluginData(session_end_info=session_end_info), + ) def _summary(tag, hparams_plugin_data): - """Returns a summary holding the given HParamsPluginData message. - - Helper function. - - Args: - tag: string. The tag to use. - hparams_plugin_data: The HParamsPluginData message to use. - """ - summary = tf.compat.v1.Summary() - tb_metadata = metadata.create_summary_metadata(hparams_plugin_data) - raw_metadata = tb_metadata.SerializeToString() - tf_metadata = tf.compat.v1.SummaryMetadata.FromString(raw_metadata) - summary.value.add(tag=tag, metadata=tf_metadata) - return summary + """Returns a summary holding the given HParamsPluginData message. + + Helper function. + + Args: + tag: string. The tag to use. + hparams_plugin_data: The HParamsPluginData message to use. 
+ """ + summary = tf.compat.v1.Summary() + tb_metadata = metadata.create_summary_metadata(hparams_plugin_data) + raw_metadata = tb_metadata.SerializeToString() + tf_metadata = tf.compat.v1.SummaryMetadata.FromString(raw_metadata) + summary.value.add(tag=tag, metadata=tf_metadata) + return summary diff --git a/tensorboard/plugins/hparams/summary_test.py b/tensorboard/plugins/hparams/summary_test.py index a0cfb37d66..3a64c8c901 100644 --- a/tensorboard/plugins/hparams/summary_test.py +++ b/tensorboard/plugins/hparams/summary_test.py @@ -27,70 +27,98 @@ class SummaryTest(tf.test.TestCase): + def test_experiment_pb(self): + hparam_infos = [ + api_pb2.HParamInfo( + name="param1", + display_name="display_name1", + description="foo", + type=api_pb2.DATA_TYPE_STRING, + domain_discrete=struct_pb2.ListValue( + values=[ + struct_pb2.Value(string_value="a"), + struct_pb2.Value(string_value="b"), + ] + ), + ), + api_pb2.HParamInfo( + name="param2", + display_name="display_name2", + description="bar", + type=api_pb2.DATA_TYPE_FLOAT64, + domain_interval=api_pb2.Interval( + min_value=-100.0, max_value=100.0 + ), + ), + ] + metric_infos = [ + api_pb2.MetricInfo( + name=api_pb2.MetricName(tag="loss"), + dataset_type=api_pb2.DATASET_VALIDATION, + ), + api_pb2.MetricInfo( + name=api_pb2.MetricName(group="train/", tag="acc"), + dataset_type=api_pb2.DATASET_TRAINING, + ), + ] + time_created_secs = 314159.0 + self.assertEqual( + summary.experiment_pb( + hparam_infos, metric_infos, time_created_secs=time_created_secs + ), + tf.compat.v1.Summary( + value=[ + tf.compat.v1.Summary.Value( + tag="_hparams_/experiment", + metadata=tf.compat.v1.SummaryMetadata( + plugin_data=tf.compat.v1.SummaryMetadata.PluginData( + plugin_name="hparams", + content=( + plugin_data_pb2.HParamsPluginData( + version=0, + experiment=api_pb2.Experiment( + time_created_secs=time_created_secs, + hparam_infos=hparam_infos, + metric_infos=metric_infos, + ), + ).SerializeToString() + ), + ) + ), + ) + ] + ), + ) - def 
test_experiment_pb(self): - hparam_infos = [ - api_pb2.HParamInfo( - name="param1", - display_name="display_name1", - description="foo", - type=api_pb2.DATA_TYPE_STRING, - domain_discrete=struct_pb2.ListValue(values=[ - struct_pb2.Value(string_value="a"), - struct_pb2.Value(string_value="b") - ])), - api_pb2.HParamInfo( - name="param2", - display_name="display_name2", - description="bar", - type=api_pb2.DATA_TYPE_FLOAT64, - domain_interval=api_pb2.Interval(min_value=-100.0, max_value=100.0)) - ] - metric_infos = [ - api_pb2.MetricInfo( - name=api_pb2.MetricName(tag="loss"), - dataset_type=api_pb2.DATASET_VALIDATION), - api_pb2.MetricInfo( - name=api_pb2.MetricName(group="train/", tag="acc"), - dataset_type=api_pb2.DATASET_TRAINING), - ] - time_created_secs = 314159.0 - self.assertEqual( - summary.experiment_pb( - hparam_infos, metric_infos, time_created_secs=time_created_secs), - tf.compat.v1.Summary(value=[ - tf.compat.v1.Summary.Value( - tag="_hparams_/experiment", - metadata=tf.compat.v1.SummaryMetadata( - plugin_data=tf.compat.v1.SummaryMetadata.PluginData( - plugin_name="hparams", - content=(plugin_data_pb2.HParamsPluginData( - version=0, - experiment=api_pb2.Experiment( - time_created_secs=time_created_secs, - hparam_infos=hparam_infos, - metric_infos=metric_infos)) - .SerializeToString())))) - ])) - - def test_session_end_pb(self): - end_time_secs = 1234.0 - self.assertEqual( - summary.session_end_pb(api_pb2.STATUS_SUCCESS, end_time_secs), - tf.compat.v1.Summary(value=[ - tf.compat.v1.Summary.Value( - tag="_hparams_/session_end_info", - metadata=tf.compat.v1.SummaryMetadata( - plugin_data=tf.compat.v1.SummaryMetadata.PluginData( - plugin_name="hparams", - content=(plugin_data_pb2.HParamsPluginData( - version=0, - session_end_info=(plugin_data_pb2.SessionEndInfo( - status=api_pb2.STATUS_SUCCESS, - end_time_secs=end_time_secs, - ))).SerializeToString())))) - ])) + def test_session_end_pb(self): + end_time_secs = 1234.0 + self.assertEqual( + 
summary.session_end_pb(api_pb2.STATUS_SUCCESS, end_time_secs), + tf.compat.v1.Summary( + value=[ + tf.compat.v1.Summary.Value( + tag="_hparams_/session_end_info", + metadata=tf.compat.v1.SummaryMetadata( + plugin_data=tf.compat.v1.SummaryMetadata.PluginData( + plugin_name="hparams", + content=( + plugin_data_pb2.HParamsPluginData( + version=0, + session_end_info=( + plugin_data_pb2.SessionEndInfo( + status=api_pb2.STATUS_SUCCESS, + end_time_secs=end_time_secs, + ) + ), + ).SerializeToString() + ), + ) + ), + ) + ] + ), + ) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/hparams/summary_v2.py b/tensorboard/plugins/hparams/summary_v2.py index b974841e3a..3f85ac57b2 100644 --- a/tensorboard/plugins/hparams/summary_v2.py +++ b/tensorboard/plugins/hparams/summary_v2.py @@ -37,505 +37,507 @@ def hparams(hparams, trial_id=None, start_time_secs=None): - # NOTE: Keep docs in sync with `hparams_pb` below. - """Write hyperparameter values for a single trial. - - Args: - hparams: A `dict` mapping hyperparameters to the values used in this - trial. Keys should be the names of `HParam` objects used in an - experiment, or the `HParam` objects themselves. Values should be - Python `bool`, `int`, `float`, or `string` values, depending on - the type of the hyperparameter. - trial_id: An optional `str` ID for the set of hyperparameter values - used in this trial. Defaults to a hash of the hyperparameters. - start_time_secs: The time that this trial started training, as - seconds since epoch. Defaults to the current time. - - Returns: - A tensor whose value is `True` on success, or `False` if no summary - was written because no default summary writer was available. - """ - pb = hparams_pb( - hparams=hparams, - trial_id=trial_id, - start_time_secs=start_time_secs, - ) - return _write_summary("hparams", pb) + # NOTE: Keep docs in sync with `hparams_pb` below. + """Write hyperparameter values for a single trial. 
+ + Args: + hparams: A `dict` mapping hyperparameters to the values used in this + trial. Keys should be the names of `HParam` objects used in an + experiment, or the `HParam` objects themselves. Values should be + Python `bool`, `int`, `float`, or `string` values, depending on + the type of the hyperparameter. + trial_id: An optional `str` ID for the set of hyperparameter values + used in this trial. Defaults to a hash of the hyperparameters. + start_time_secs: The time that this trial started training, as + seconds since epoch. Defaults to the current time. + + Returns: + A tensor whose value is `True` on success, or `False` if no summary + was written because no default summary writer was available. + """ + pb = hparams_pb( + hparams=hparams, trial_id=trial_id, start_time_secs=start_time_secs, + ) + return _write_summary("hparams", pb) def hparams_pb(hparams, trial_id=None, start_time_secs=None): - # NOTE: Keep docs in sync with `hparams` above. - """Create a summary encoding hyperparameter values for a single trial. - - Args: - hparams: A `dict` mapping hyperparameters to the values used in this - trial. Keys should be the names of `HParam` objects used in an - experiment, or the `HParam` objects themselves. Values should be - Python `bool`, `int`, `float`, or `string` values, depending on - the type of the hyperparameter. - trial_id: An optional `str` ID for the set of hyperparameter values - used in this trial. Defaults to a hash of the hyperparameters. - start_time_secs: The time that this trial started training, as - seconds since epoch. Defaults to the current time. - - Returns: - A TensorBoard `summary_pb2.Summary` message. 
- """ - if start_time_secs is None: - start_time_secs = time.time() - hparams = _normalize_hparams(hparams) - group_name = _derive_session_group_name(trial_id, hparams) - - session_start_info = plugin_data_pb2.SessionStartInfo( - group_name=group_name, - start_time_secs=start_time_secs, - ) - for hp_name in sorted(hparams): - hp_value = hparams[hp_name] - if isinstance(hp_value, bool): - session_start_info.hparams[hp_name].bool_value = hp_value - elif isinstance(hp_value, (float, int)): - session_start_info.hparams[hp_name].number_value = hp_value - elif isinstance(hp_value, six.string_types): - session_start_info.hparams[hp_name].string_value = hp_value - else: - raise TypeError( - "hparams[%r] = %r, of unsupported type %r" - % (hp_name, hp_value, type(hp_value)) - ) - - return _summary_pb( - metadata.SESSION_START_INFO_TAG, - plugin_data_pb2.HParamsPluginData(session_start_info=session_start_info), - ) + # NOTE: Keep docs in sync with `hparams` above. + """Create a summary encoding hyperparameter values for a single trial. + + Args: + hparams: A `dict` mapping hyperparameters to the values used in this + trial. Keys should be the names of `HParam` objects used in an + experiment, or the `HParam` objects themselves. Values should be + Python `bool`, `int`, `float`, or `string` values, depending on + the type of the hyperparameter. + trial_id: An optional `str` ID for the set of hyperparameter values + used in this trial. Defaults to a hash of the hyperparameters. + start_time_secs: The time that this trial started training, as + seconds since epoch. Defaults to the current time. + + Returns: + A TensorBoard `summary_pb2.Summary` message. 
+ """ + if start_time_secs is None: + start_time_secs = time.time() + hparams = _normalize_hparams(hparams) + group_name = _derive_session_group_name(trial_id, hparams) + + session_start_info = plugin_data_pb2.SessionStartInfo( + group_name=group_name, start_time_secs=start_time_secs, + ) + for hp_name in sorted(hparams): + hp_value = hparams[hp_name] + if isinstance(hp_value, bool): + session_start_info.hparams[hp_name].bool_value = hp_value + elif isinstance(hp_value, (float, int)): + session_start_info.hparams[hp_name].number_value = hp_value + elif isinstance(hp_value, six.string_types): + session_start_info.hparams[hp_name].string_value = hp_value + else: + raise TypeError( + "hparams[%r] = %r, of unsupported type %r" + % (hp_name, hp_value, type(hp_value)) + ) + + return _summary_pb( + metadata.SESSION_START_INFO_TAG, + plugin_data_pb2.HParamsPluginData( + session_start_info=session_start_info + ), + ) def hparams_config(hparams, metrics, time_created_secs=None): - # NOTE: Keep docs in sync with `hparams_config_pb` below. - """Write a top-level experiment configuration. - - This configuration describes the hyperparameters and metrics that will - be tracked in the experiment, but does not record any actual values of - those hyperparameters and metrics. It can be created before any models - are actually trained. - - Args: - hparams: A list of `HParam` values. - metrics: A list of `Metric` values. - time_created_secs: The time that this experiment was created, as - seconds since epoch. Defaults to the current time. - - Returns: - A tensor whose value is `True` on success, or `False` if no summary - was written because no default summary writer was available. - """ - pb = hparams_config_pb( - hparams=hparams, - metrics=metrics, - time_created_secs=time_created_secs, - ) - return _write_summary("hparams_config", pb) + # NOTE: Keep docs in sync with `hparams_config_pb` below. + """Write a top-level experiment configuration. 
+ + This configuration describes the hyperparameters and metrics that will + be tracked in the experiment, but does not record any actual values of + those hyperparameters and metrics. It can be created before any models + are actually trained. + + Args: + hparams: A list of `HParam` values. + metrics: A list of `Metric` values. + time_created_secs: The time that this experiment was created, as + seconds since epoch. Defaults to the current time. + + Returns: + A tensor whose value is `True` on success, or `False` if no summary + was written because no default summary writer was available. + """ + pb = hparams_config_pb( + hparams=hparams, metrics=metrics, time_created_secs=time_created_secs, + ) + return _write_summary("hparams_config", pb) def hparams_config_pb(hparams, metrics, time_created_secs=None): - # NOTE: Keep docs in sync with `hparams_config` above. - """Create a top-level experiment configuration. - - This configuration describes the hyperparameters and metrics that will - be tracked in the experiment, but does not record any actual values of - those hyperparameters and metrics. It can be created before any models - are actually trained. - - Args: - hparams: A list of `HParam` values. - metrics: A list of `Metric` values. - time_created_secs: The time that this experiment was created, as - seconds since epoch. Defaults to the current time. - - Returns: - A TensorBoard `summary_pb2.Summary` message. - """ - hparam_infos = [] - for hparam in hparams: - info = api_pb2.HParamInfo( - name=hparam.name, - description=hparam.description, - display_name=hparam.display_name, + # NOTE: Keep docs in sync with `hparams_config` above. + """Create a top-level experiment configuration. + + This configuration describes the hyperparameters and metrics that will + be tracked in the experiment, but does not record any actual values of + those hyperparameters and metrics. It can be created before any models + are actually trained. 
+ + Args: + hparams: A list of `HParam` values. + metrics: A list of `Metric` values. + time_created_secs: The time that this experiment was created, as + seconds since epoch. Defaults to the current time. + + Returns: + A TensorBoard `summary_pb2.Summary` message. + """ + hparam_infos = [] + for hparam in hparams: + info = api_pb2.HParamInfo( + name=hparam.name, + description=hparam.description, + display_name=hparam.display_name, + ) + domain = hparam.domain + if domain is not None: + domain.update_hparam_info(info) + hparam_infos.append(info) + metric_infos = [metric.as_proto() for metric in metrics] + experiment = api_pb2.Experiment( + hparam_infos=hparam_infos, + metric_infos=metric_infos, + time_created_secs=time_created_secs, + ) + return _summary_pb( + metadata.EXPERIMENT_TAG, + plugin_data_pb2.HParamsPluginData(experiment=experiment), ) - domain = hparam.domain - if domain is not None: - domain.update_hparam_info(info) - hparam_infos.append(info) - metric_infos = [metric.as_proto() for metric in metrics] - experiment = api_pb2.Experiment( - hparam_infos=hparam_infos, - metric_infos=metric_infos, - time_created_secs=time_created_secs, - ) - return _summary_pb( - metadata.EXPERIMENT_TAG, - plugin_data_pb2.HParamsPluginData(experiment=experiment), - ) def _normalize_hparams(hparams): - """Normalize a dict keyed by `HParam`s and/or raw strings. - - Args: - hparams: A `dict` whose keys are `HParam` objects and/or strings - representing hyperparameter names, and whose values are - hyperparameter values. No two keys may have the same name. - - Returns: - A `dict` whose keys are hyperparameter names (as strings) and whose - values are the corresponding hyperparameter values. - - Raises: - ValueError: If two entries in `hparams` share the same - hyperparameter name. 
- """ - result = {} - for (k, v) in six.iteritems(hparams): - if isinstance(k, HParam): - k = k.name - if k in result: - raise ValueError("multiple values specified for hparam %r" % (k,)) - result[k] = v - return result + """Normalize a dict keyed by `HParam`s and/or raw strings. + + Args: + hparams: A `dict` whose keys are `HParam` objects and/or strings + representing hyperparameter names, and whose values are + hyperparameter values. No two keys may have the same name. + + Returns: + A `dict` whose keys are hyperparameter names (as strings) and whose + values are the corresponding hyperparameter values. + + Raises: + ValueError: If two entries in `hparams` share the same + hyperparameter name. + """ + result = {} + for (k, v) in six.iteritems(hparams): + if isinstance(k, HParam): + k = k.name + if k in result: + raise ValueError("multiple values specified for hparam %r" % (k,)) + result[k] = v + return result def _derive_session_group_name(trial_id, hparams): - if trial_id is not None: - if not isinstance(trial_id, six.string_types): - raise TypeError("`trial_id` should be a `str`, but got: %r" % (trial_id,)) - return trial_id - # Use `json.dumps` rather than `str` to ensure invariance under string - # type (incl. across Python versions) and dict iteration order. - jparams = json.dumps(hparams, sort_keys=True, separators=(",", ":")) - return hashlib.sha256(jparams.encode("utf-8")).hexdigest() + if trial_id is not None: + if not isinstance(trial_id, six.string_types): + raise TypeError( + "`trial_id` should be a `str`, but got: %r" % (trial_id,) + ) + return trial_id + # Use `json.dumps` rather than `str` to ensure invariance under string + # type (incl. across Python versions) and dict iteration order. + jparams = json.dumps(hparams, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(jparams.encode("utf-8")).hexdigest() def _write_summary(name, pb): - """Write a summary, returning the writing op. + """Write a summary, returning the writing op. 
- Args: - name: As passed to `summary_scope`. - pb: A `summary_pb2.Summary` message. + Args: + name: As passed to `summary_scope`. + pb: A `summary_pb2.Summary` message. - Returns: - A tensor whose value is `True` on success, or `False` if no summary - was written because no default summary writer was available. - """ - raw_pb = pb.SerializeToString() - summary_scope = ( - getattr(tf.summary.experimental, "summary_scope", None) - or tf.summary.summary_scope - ) - with summary_scope(name): - return tf.summary.experimental.write_raw_pb(raw_pb, step=0) + Returns: + A tensor whose value is `True` on success, or `False` if no summary + was written because no default summary writer was available. + """ + raw_pb = pb.SerializeToString() + summary_scope = ( + getattr(tf.summary.experimental, "summary_scope", None) + or tf.summary.summary_scope + ) + with summary_scope(name): + return tf.summary.experimental.write_raw_pb(raw_pb, step=0) def _summary_pb(tag, hparams_plugin_data): - """Create a summary holding the given `HParamsPluginData` message. + """Create a summary holding the given `HParamsPluginData` message. - Args: - tag: The `str` tag to use. - hparams_plugin_data: The `HParamsPluginData` message to use. + Args: + tag: The `str` tag to use. + hparams_plugin_data: The `HParamsPluginData` message to use. - Returns: - A TensorBoard `summary_pb2.Summary` message. - """ - summary = summary_pb2.Summary() - summary_metadata = metadata.create_summary_metadata(hparams_plugin_data) - summary.value.add(tag=tag, metadata=summary_metadata) - return summary + Returns: + A TensorBoard `summary_pb2.Summary` message. + """ + summary = summary_pb2.Summary() + summary_metadata = metadata.create_summary_metadata(hparams_plugin_data) + summary.value.add(tag=tag, metadata=summary_metadata) + return summary class HParam(object): - """A hyperparameter in an experiment. + """A hyperparameter in an experiment. - This class describes a hyperparameter in the abstract. 
It ranges over - a domain of values, but is not bound to any particular value. - """ - - def __init__(self, name, domain=None, display_name=None, description=None): - """Create a hyperparameter object. - - Args: - name: A string ID for this hyperparameter, which should be unique - within an experiment. - domain: An optional `Domain` object describing the values that - this hyperparameter can take on. - display_name: An optional human-readable display name (`str`). - description: An optional Markdown string describing this - hyperparameter. - - Raises: - ValueError: If `domain` is not a `Domain`. + This class describes a hyperparameter in the abstract. It ranges + over a domain of values, but is not bound to any particular value. """ - self._name = name - self._domain = domain - self._display_name = display_name - self._description = description - if not isinstance(self._domain, (Domain, type(None))): - raise ValueError("not a domain: %r" % (self._domain,)) - - def __str__(self): - return "" % (self._name, self._domain) - - def __repr__(self): - fields = [ - ("name", self._name), - ("domain", self._domain), - ("display_name", self._display_name), - ("description", self._description), - ] - fields_string = ", ".join("%s=%r" % (k, v) for (k, v) in fields) - return "HParam(%s)" % fields_string - - @property - def name(self): - return self._name - - @property - def domain(self): - return self._domain - - @property - def display_name(self): - return self._display_name - - @property - def description(self): - return self._description + + def __init__(self, name, domain=None, display_name=None, description=None): + """Create a hyperparameter object. + + Args: + name: A string ID for this hyperparameter, which should be unique + within an experiment. + domain: An optional `Domain` object describing the values that + this hyperparameter can take on. + display_name: An optional human-readable display name (`str`). 
+ description: An optional Markdown string describing this + hyperparameter. + + Raises: + ValueError: If `domain` is not a `Domain`. + """ + self._name = name + self._domain = domain + self._display_name = display_name + self._description = description + if not isinstance(self._domain, (Domain, type(None))): + raise ValueError("not a domain: %r" % (self._domain,)) + + def __str__(self): + return "" % (self._name, self._domain) + + def __repr__(self): + fields = [ + ("name", self._name), + ("domain", self._domain), + ("display_name", self._display_name), + ("description", self._description), + ] + fields_string = ", ".join("%s=%r" % (k, v) for (k, v) in fields) + return "HParam(%s)" % fields_string + + @property + def name(self): + return self._name + + @property + def domain(self): + return self._domain + + @property + def display_name(self): + return self._display_name + + @property + def description(self): + return self._description @six.add_metaclass(abc.ABCMeta) class Domain(object): - """The domain of a hyperparameter. + """The domain of a hyperparameter. - Domains are restricted to values of the simple types `float`, `int`, - `str`, and `bool`. - """ + Domains are restricted to values of the simple types `float`, `int`, + `str`, and `bool`. + """ - @abc.abstractproperty - def dtype(self): - """Data type of this domain: `float`, `int`, `str`, or `bool`.""" - pass + @abc.abstractproperty + def dtype(self): + """Data type of this domain: `float`, `int`, `str`, or `bool`.""" + pass - @abc.abstractmethod - def sample_uniform(self, rng=random): - """Sample a value from this domain uniformly at random. + @abc.abstractmethod + def sample_uniform(self, rng=random): + """Sample a value from this domain uniformly at random. - Args: - rng: A `random.Random` interface; defaults to the `random` module - itself. + Args: + rng: A `random.Random` interface; defaults to the `random` module + itself. - Raises: - IndexError: If the domain is empty. 
- """ - pass + Raises: + IndexError: If the domain is empty. + """ + pass - @abc.abstractmethod - def update_hparam_info(self, hparam_info): - """Update an `HParamInfo` proto to include this domain. + @abc.abstractmethod + def update_hparam_info(self, hparam_info): + """Update an `HParamInfo` proto to include this domain. - This should update the `type` field on the proto and exactly one of - the `domain` variants on the proto. + This should update the `type` field on the proto and exactly one of + the `domain` variants on the proto. - Args: - hparam_info: An `api_pb2.HParamInfo` proto to modify. - """ - pass + Args: + hparam_info: An `api_pb2.HParamInfo` proto to modify. + """ + pass class IntInterval(Domain): - """A domain that takes on all integer values in a closed interval.""" + """A domain that takes on all integer values in a closed interval.""" - def __init__(self, min_value=None, max_value=None): - """Create an `IntInterval`. + def __init__(self, min_value=None, max_value=None): + """Create an `IntInterval`. - Args: - min_value: The lower bound (inclusive) of the interval. - max_value: The upper bound (inclusive) of the interval. + Args: + min_value: The lower bound (inclusive) of the interval. + max_value: The upper bound (inclusive) of the interval. - Raises: - TypeError: If `min_value` or `max_value` is not an `int`. - ValueError: If `min_value > max_value`. - """ - if not isinstance(min_value, int): - raise TypeError("min_value must be an int: %r" % (min_value,)) - if not isinstance(max_value, int): - raise TypeError("max_value must be an int: %r" % (max_value,)) - if min_value > max_value: - raise ValueError("%r > %r" % (min_value, max_value)) - self._min_value = min_value - self._max_value = max_value + Raises: + TypeError: If `min_value` or `max_value` is not an `int`. + ValueError: If `min_value > max_value`. 
+ """ + if not isinstance(min_value, int): + raise TypeError("min_value must be an int: %r" % (min_value,)) + if not isinstance(max_value, int): + raise TypeError("max_value must be an int: %r" % (max_value,)) + if min_value > max_value: + raise ValueError("%r > %r" % (min_value, max_value)) + self._min_value = min_value + self._max_value = max_value - def __str__(self): - return "[%s, %s]" % (self._min_value, self._max_value) + def __str__(self): + return "[%s, %s]" % (self._min_value, self._max_value) - def __repr__(self): - return "IntInterval(%r, %r)" % (self._min_value, self._max_value) + def __repr__(self): + return "IntInterval(%r, %r)" % (self._min_value, self._max_value) - @property - def dtype(self): - return int + @property + def dtype(self): + return int - @property - def min_value(self): - return self._min_value + @property + def min_value(self): + return self._min_value - @property - def max_value(self): - return self._max_value + @property + def max_value(self): + return self._max_value - def sample_uniform(self, rng=random): - return rng.randint(self._min_value, self._max_value) + def sample_uniform(self, rng=random): + return rng.randint(self._min_value, self._max_value) - def update_hparam_info(self, hparam_info): - hparam_info.type = api_pb2.DATA_TYPE_FLOAT64 # TODO(#1998): Add int dtype. - hparam_info.domain_interval.min_value = self._min_value - hparam_info.domain_interval.max_value = self._max_value + def update_hparam_info(self, hparam_info): + hparam_info.type = ( + api_pb2.DATA_TYPE_FLOAT64 + ) # TODO(#1998): Add int dtype. + hparam_info.domain_interval.min_value = self._min_value + hparam_info.domain_interval.max_value = self._max_value class RealInterval(Domain): - """A domain that takes on all real values in a closed interval.""" + """A domain that takes on all real values in a closed interval.""" - def __init__(self, min_value=None, max_value=None): - """Create a `RealInterval`. 
+ def __init__(self, min_value=None, max_value=None): + """Create a `RealInterval`. - Args: - min_value: The lower bound (inclusive) of the interval. - max_value: The upper bound (inclusive) of the interval. + Args: + min_value: The lower bound (inclusive) of the interval. + max_value: The upper bound (inclusive) of the interval. - Raises: - TypeError: If `min_value` or `max_value` is not an `float`. - ValueError: If `min_value > max_value`. - """ - if not isinstance(min_value, float): - raise TypeError("min_value must be a float: %r" % (min_value,)) - if not isinstance(max_value, float): - raise TypeError("max_value must be a float: %r" % (max_value,)) - if min_value > max_value: - raise ValueError("%r > %r" % (min_value, max_value)) - self._min_value = min_value - self._max_value = max_value + Raises: + TypeError: If `min_value` or `max_value` is not an `float`. + ValueError: If `min_value > max_value`. + """ + if not isinstance(min_value, float): + raise TypeError("min_value must be a float: %r" % (min_value,)) + if not isinstance(max_value, float): + raise TypeError("max_value must be a float: %r" % (max_value,)) + if min_value > max_value: + raise ValueError("%r > %r" % (min_value, max_value)) + self._min_value = min_value + self._max_value = max_value - def __str__(self): - return "[%s, %s]" % (self._min_value, self._max_value) + def __str__(self): + return "[%s, %s]" % (self._min_value, self._max_value) - def __repr__(self): - return "RealInterval(%r, %r)" % (self._min_value, self._max_value) + def __repr__(self): + return "RealInterval(%r, %r)" % (self._min_value, self._max_value) - @property - def dtype(self): - return float + @property + def dtype(self): + return float - @property - def min_value(self): - return self._min_value + @property + def min_value(self): + return self._min_value - @property - def max_value(self): - return self._max_value + @property + def max_value(self): + return self._max_value - def sample_uniform(self, rng=random): - return 
rng.uniform(self._min_value, self._max_value) + def sample_uniform(self, rng=random): + return rng.uniform(self._min_value, self._max_value) - def update_hparam_info(self, hparam_info): - hparam_info.type = api_pb2.DATA_TYPE_FLOAT64 - hparam_info.domain_interval.min_value = self._min_value - hparam_info.domain_interval.max_value = self._max_value + def update_hparam_info(self, hparam_info): + hparam_info.type = api_pb2.DATA_TYPE_FLOAT64 + hparam_info.domain_interval.min_value = self._min_value + hparam_info.domain_interval.max_value = self._max_value class Discrete(Domain): - """A domain that takes on a fixed set of values. - - These values may be of any (single) domain type. - """ - - def __init__(self, values, dtype=None): - """Construct a discrete domain. - - Args: - values: A iterable of the values in this domain. - dtype: The Python data type of values in this domain: one of - `int`, `float`, `bool`, or `str`. If `values` is non-empty, - `dtype` may be `None`, in which case it will be inferred as the - type of the first element of `values`. + """A domain that takes on a fixed set of values. - Raises: - ValueError: If `values` is empty but no `dtype` is specified. - ValueError: If `dtype` or its inferred value is not `int`, - `float`, `bool`, or `str`. - TypeError: If an element of `values` is not an instance of - `dtype`. + These values may be of any (single) domain type. 
""" - self._values = list(values) - if dtype is None: - if self._values: - dtype = type(self._values[0]) - else: - raise ValueError("Empty domain with no dtype specified") - if dtype not in (int, float, bool, str): - raise ValueError("Unknown dtype: %r" % (dtype,)) - self._dtype = dtype - for value in self._values: - if not isinstance(value, self._dtype): - raise TypeError( - "dtype mismatch: not isinstance(%r, %s)" - % (value, self._dtype.__name__) - ) - self._values.sort() - - def __str__(self): - return "{%s}" % (", ".join(repr(x) for x in self._values)) - - def __repr__(self): - return "Discrete(%r)" % (self._values,) - @property - def dtype(self): - return self._dtype - - @property - def values(self): - return list(self._values) - - def sample_uniform(self, rng=random): - return rng.choice(self._values) - - def update_hparam_info(self, hparam_info): - hparam_info.type = { - int: api_pb2.DATA_TYPE_FLOAT64, # TODO(#1998): Add int dtype. - float: api_pb2.DATA_TYPE_FLOAT64, - bool: api_pb2.DATA_TYPE_BOOL, - str: api_pb2.DATA_TYPE_STRING, - }[self._dtype] - hparam_info.ClearField("domain_discrete") - hparam_info.domain_discrete.extend(self._values) + def __init__(self, values, dtype=None): + """Construct a discrete domain. + + Args: + values: A iterable of the values in this domain. + dtype: The Python data type of values in this domain: one of + `int`, `float`, `bool`, or `str`. If `values` is non-empty, + `dtype` may be `None`, in which case it will be inferred as the + type of the first element of `values`. + + Raises: + ValueError: If `values` is empty but no `dtype` is specified. + ValueError: If `dtype` or its inferred value is not `int`, + `float`, `bool`, or `str`. + TypeError: If an element of `values` is not an instance of + `dtype`. 
+ """ + self._values = list(values) + if dtype is None: + if self._values: + dtype = type(self._values[0]) + else: + raise ValueError("Empty domain with no dtype specified") + if dtype not in (int, float, bool, str): + raise ValueError("Unknown dtype: %r" % (dtype,)) + self._dtype = dtype + for value in self._values: + if not isinstance(value, self._dtype): + raise TypeError( + "dtype mismatch: not isinstance(%r, %s)" + % (value, self._dtype.__name__) + ) + self._values.sort() + + def __str__(self): + return "{%s}" % (", ".join(repr(x) for x in self._values)) + + def __repr__(self): + return "Discrete(%r)" % (self._values,) + + @property + def dtype(self): + return self._dtype + + @property + def values(self): + return list(self._values) + + def sample_uniform(self, rng=random): + return rng.choice(self._values) + + def update_hparam_info(self, hparam_info): + hparam_info.type = { + int: api_pb2.DATA_TYPE_FLOAT64, # TODO(#1998): Add int dtype. + float: api_pb2.DATA_TYPE_FLOAT64, + bool: api_pb2.DATA_TYPE_BOOL, + str: api_pb2.DATA_TYPE_STRING, + }[self._dtype] + hparam_info.ClearField("domain_discrete") + hparam_info.domain_discrete.extend(self._values) class Metric(object): - """A metric in an experiment. - - A metric is a real-valued function of a model. Each metric is - associated with a TensorBoard scalar summary, which logs the metric's - value as the model trains. - """ - TRAINING = api_pb2.DATASET_TRAINING - VALIDATION = api_pb2.DATASET_VALIDATION - - def __init__( - self, - tag, - group=None, - display_name=None, - description=None, - dataset_type=None, - ): + """A metric in an experiment. + + A metric is a real-valued function of a model. Each metric is + associated with a TensorBoard scalar summary, which logs the + metric's value as the model trains. 
""" + + TRAINING = api_pb2.DATASET_TRAINING + VALIDATION = api_pb2.DATASET_VALIDATION + + def __init__( + self, + tag, + group=None, + display_name=None, + description=None, + dataset_type=None, + ): + """ Args: tag: The tag name of the scalar summary that corresponds to this metric (as a `str`). @@ -551,21 +553,18 @@ def __init__( dataset_type: Either `Metric.TRAINING` or `Metric.VALIDATION`, or `None`. """ - self._tag = tag - self._group = group - self._display_name = display_name - self._description = description - self._dataset_type = dataset_type - if self._dataset_type not in (None, Metric.TRAINING, Metric.VALIDATION): - raise ValueError("invalid dataset type: %r" % (self._dataset_type,)) - - def as_proto(self): - return api_pb2.MetricInfo( - name=api_pb2.MetricName( - group=self._group, - tag=self._tag, - ), - display_name=self._display_name, - description=self._description, - dataset_type=self._dataset_type, - ) + self._tag = tag + self._group = group + self._display_name = display_name + self._description = description + self._dataset_type = dataset_type + if self._dataset_type not in (None, Metric.TRAINING, Metric.VALIDATION): + raise ValueError("invalid dataset type: %r" % (self._dataset_type,)) + + def as_proto(self): + return api_pb2.MetricInfo( + name=api_pb2.MetricName(group=self._group, tag=self._tag,), + display_name=self._display_name, + description=self._description, + dataset_type=self._dataset_type, + ) diff --git a/tensorboard/plugins/hparams/summary_v2_test.py b/tensorboard/plugins/hparams/summary_v2_test.py index aa915bfedd..a0e10a9a11 100644 --- a/tensorboard/plugins/hparams/summary_v2_test.py +++ b/tensorboard/plugins/hparams/summary_v2_test.py @@ -30,10 +30,10 @@ from six.moves import xrange # pylint: disable=redefined-builtin try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: 
disable=unused-import from tensorboard import test from tensorboard.compat import tf @@ -45,48 +45,48 @@ if tf.__version__ == "stub": - tf = None + tf = None if tf is not None: - tf.compat.v1.enable_eager_execution() + tf.compat.v1.enable_eager_execution() requires_tf = unittest.skipIf(tf is None, "Requires TensorFlow.") class HParamsTest(test.TestCase): - """Tests for `summary_v2.hparams` and `summary_v2.hparams_pb`.""" - - def setUp(self): - self.logdir = os.path.join(self.get_temp_dir(), "logs") - self.hparams = { - hp.HParam("learning_rate", hp.RealInterval(1e-2, 1e-1)): 0.02, - hp.HParam("dense_layers", hp.IntInterval(2, 7)): 5, - hp.HParam("optimizer", hp.Discrete(["adam", "sgd"])): "adam", - hp.HParam("who_knows_what"): "???", - hp.HParam( - "magic", - hp.Discrete([False, True]), - display_name="~*~ Magic ~*~", - description="descriptive", - ): True, - "dropout": 0.3, - } - self.normalized_hparams = { - "learning_rate": 0.02, - "dense_layers": 5, - "optimizer": "adam", - "who_knows_what": "???", - "magic": True, - "dropout": 0.3, - } - self.start_time_secs = 123.45 - self.trial_id = "psl27" - - self.expected_session_start_pb = plugin_data_pb2.SessionStartInfo() - text_format.Merge( - """ + """Tests for `summary_v2.hparams` and `summary_v2.hparams_pb`.""" + + def setUp(self): + self.logdir = os.path.join(self.get_temp_dir(), "logs") + self.hparams = { + hp.HParam("learning_rate", hp.RealInterval(1e-2, 1e-1)): 0.02, + hp.HParam("dense_layers", hp.IntInterval(2, 7)): 5, + hp.HParam("optimizer", hp.Discrete(["adam", "sgd"])): "adam", + hp.HParam("who_knows_what"): "???", + hp.HParam( + "magic", + hp.Discrete([False, True]), + display_name="~*~ Magic ~*~", + description="descriptive", + ): True, + "dropout": 0.3, + } + self.normalized_hparams = { + "learning_rate": 0.02, + "dense_layers": 5, + "optimizer": "adam", + "who_knows_what": "???", + "magic": True, + "dropout": 0.3, + } + self.start_time_secs = 123.45 + self.trial_id = "psl27" + + 
self.expected_session_start_pb = plugin_data_pb2.SessionStartInfo() + text_format.Merge( + """ hparams { key: "learning_rate" value { number_value: 0.02 } } hparams { key: "dense_layers" value { number_value: 5 } } hparams { key: "optimizer" value { string_value: "adam" } } @@ -94,206 +94,218 @@ def setUp(self): hparams { key: "magic" value { bool_value: true } } hparams { key: "dropout" value { number_value: 0.3 } } """, - self.expected_session_start_pb, - ) - self.expected_session_start_pb.group_name = self.trial_id - self.expected_session_start_pb.start_time_secs = self.start_time_secs - - def _check_summary(self, summary_pb, check_group_name=False): - """Test that a summary contains exactly the expected hparams PB.""" - values = summary_pb.value - self.assertEqual(len(values), 1, values) - actual_value = values[0] - self.assertEqual( - actual_value.metadata.plugin_data.plugin_name, - metadata.PLUGIN_NAME, - ) - plugin_content = actual_value.metadata.plugin_data.content - info_pb = metadata.parse_session_start_info_plugin_data(plugin_content) - # Usually ignore the `group_name` field; its properties are checked - # separately. 
- if not check_group_name: - info_pb.group_name = self.expected_session_start_pb.group_name - self.assertEqual(info_pb, self.expected_session_start_pb) - - def _check_logdir(self, logdir, check_group_name=False): - """Test that the hparams summary was written to `logdir`.""" - self._check_summary( - _get_unique_summary(self, logdir), - check_group_name=check_group_name, - ) - - @requires_tf - def test_eager(self): - with tf.compat.v2.summary.create_file_writer(self.logdir).as_default(): - result = hp.hparams( - self.hparams, - trial_id=self.trial_id, - start_time_secs=self.start_time_secs, - ) - self.assertTrue(result) - self._check_logdir(self.logdir) - - @requires_tf - def test_graph_mode(self): - with \ - tf.compat.v1.Graph().as_default(), \ - tf.compat.v1.Session() as sess, \ - tf.compat.v2.summary.create_file_writer(self.logdir).as_default() as w: - sess.run(w.init()) - summ = hp.hparams(self.hparams, start_time_secs=self.start_time_secs) - self.assertTrue(sess.run(summ)) - sess.run(w.flush()) - self._check_logdir(self.logdir) - - @requires_tf - def test_eager_no_default_writer(self): - result = hp.hparams(self.hparams, start_time_secs=self.start_time_secs) - self.assertFalse(result) # no default writer - - def test_pb_contents(self): - result = hp.hparams_pb(self.hparams, start_time_secs=self.start_time_secs) - self._check_summary(result) - - def test_pb_is_tensorboard_copy_of_proto(self): - result = hp.hparams_pb(self.hparams, start_time_secs=self.start_time_secs) - self.assertIsInstance(result, summary_pb2.Summary) - if tf is not None: - self.assertNotIsInstance(result, tf.compat.v1.Summary) - - def test_pb_explicit_trial_id(self): - result = hp.hparams_pb( - self.hparams, - trial_id=self.trial_id, - start_time_secs=self.start_time_secs, - ) - self._check_summary(result, check_group_name=True) - - def test_pb_invalid_trial_id(self): - with six.assertRaisesRegex( - self, TypeError, "`trial_id` should be a `str`, but got: 12"): - hp.hparams_pb(self.hparams, 
trial_id=12) - - def assert_hparams_summaries_equal(self, summary_1, summary_2): - def canonical(summary): - """Return a canonical form for `summary`. - - The result is such that `canonical(a) == canonical(b)` if and only - if `a` and `b` are logically equivalent. - - Args: - summary: A `summary_pb2.Summary` containing hparams plugin data. - """ - new_summary = summary_pb2.Summary() - new_summary.MergeFrom(summary) - values = new_summary.value - self.assertEqual(len(values), 1, values) - value = values[0] - raw_content = value.metadata.plugin_data.content - value.metadata.plugin_data.content = b"" - content = plugin_data_pb2.HParamsPluginData.FromString(raw_content) - return (new_summary, content) - - self.assertEqual(canonical(summary_1), canonical(summary_2)) - - def test_consistency_across_string_key_and_object_key(self): - hparams_1 = { - hp.HParam("optimizer", hp.Discrete(["adam", "sgd"])): "adam", - "learning_rate": 0.02, - } - hparams_2 = { - "optimizer": "adam", - hp.HParam("learning_rate", hp.RealInterval(1e-2, 1e-1)): 0.02, - } - self.assert_hparams_summaries_equal( - hp.hparams_pb(hparams_1, start_time_secs=self.start_time_secs), - hp.hparams_pb(hparams_2, start_time_secs=self.start_time_secs), - ) - - def test_duplicate_hparam_names_across_object_and_string(self): - hparams = { - "foo": 1, - hp.HParam("foo"): 1, - } - with six.assertRaisesRegex( - self, ValueError, "multiple values specified for hparam 'foo'"): - hp.hparams_pb(hparams) - - def test_duplicate_hparam_names_from_two_objects(self): - hparams = { - hp.HParam("foo"): 1, - hp.HParam("foo"): 1, - } - with six.assertRaisesRegex( - self, ValueError, "multiple values specified for hparam 'foo'"): - hp.hparams_pb(hparams) - - def test_invariant_under_permutation(self): - # In particular, the group name should be the same. 
- hparams_1 = { - "optimizer": "adam", - "learning_rate": 0.02, - } - hparams_2 = { - "learning_rate": 0.02, - "optimizer": "adam", - } - self.assert_hparams_summaries_equal( - hp.hparams_pb(hparams_1, start_time_secs=self.start_time_secs), - hp.hparams_pb(hparams_2, start_time_secs=self.start_time_secs), - ) - - def test_group_name_differs_across_hparams_values(self): - hparams_1 = {"foo": 1, "bar": 2, "baz": 4} - hparams_2 = {"foo": 1, "bar": 3, "baz": 4} - def get_group_name(hparams): - summary_pb = hp.hparams_pb(hparams) - values = summary_pb.value - self.assertEqual(len(values), 1, values) - actual_value = values[0] - self.assertEqual( - actual_value.metadata.plugin_data.plugin_name, - metadata.PLUGIN_NAME, - ) - plugin_content = actual_value.metadata.plugin_data.content - info = metadata.parse_session_start_info_plugin_data(plugin_content) - return info.group_name - - self.assertNotEqual(get_group_name(hparams_1), get_group_name(hparams_2)) + self.expected_session_start_pb, + ) + self.expected_session_start_pb.group_name = self.trial_id + self.expected_session_start_pb.start_time_secs = self.start_time_secs + + def _check_summary(self, summary_pb, check_group_name=False): + """Test that a summary contains exactly the expected hparams PB.""" + values = summary_pb.value + self.assertEqual(len(values), 1, values) + actual_value = values[0] + self.assertEqual( + actual_value.metadata.plugin_data.plugin_name, metadata.PLUGIN_NAME, + ) + plugin_content = actual_value.metadata.plugin_data.content + info_pb = metadata.parse_session_start_info_plugin_data(plugin_content) + # Usually ignore the `group_name` field; its properties are checked + # separately. 
+ if not check_group_name: + info_pb.group_name = self.expected_session_start_pb.group_name + self.assertEqual(info_pb, self.expected_session_start_pb) + + def _check_logdir(self, logdir, check_group_name=False): + """Test that the hparams summary was written to `logdir`.""" + self._check_summary( + _get_unique_summary(self, logdir), + check_group_name=check_group_name, + ) + + @requires_tf + def test_eager(self): + with tf.compat.v2.summary.create_file_writer(self.logdir).as_default(): + result = hp.hparams( + self.hparams, + trial_id=self.trial_id, + start_time_secs=self.start_time_secs, + ) + self.assertTrue(result) + self._check_logdir(self.logdir) + + @requires_tf + def test_graph_mode(self): + with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess, tf.compat.v2.summary.create_file_writer( + self.logdir + ).as_default() as w: + sess.run(w.init()) + summ = hp.hparams( + self.hparams, start_time_secs=self.start_time_secs + ) + self.assertTrue(sess.run(summ)) + sess.run(w.flush()) + self._check_logdir(self.logdir) + + @requires_tf + def test_eager_no_default_writer(self): + result = hp.hparams(self.hparams, start_time_secs=self.start_time_secs) + self.assertFalse(result) # no default writer + + def test_pb_contents(self): + result = hp.hparams_pb( + self.hparams, start_time_secs=self.start_time_secs + ) + self._check_summary(result) + + def test_pb_is_tensorboard_copy_of_proto(self): + result = hp.hparams_pb( + self.hparams, start_time_secs=self.start_time_secs + ) + self.assertIsInstance(result, summary_pb2.Summary) + if tf is not None: + self.assertNotIsInstance(result, tf.compat.v1.Summary) + + def test_pb_explicit_trial_id(self): + result = hp.hparams_pb( + self.hparams, + trial_id=self.trial_id, + start_time_secs=self.start_time_secs, + ) + self._check_summary(result, check_group_name=True) + + def test_pb_invalid_trial_id(self): + with six.assertRaisesRegex( + self, TypeError, "`trial_id` should be a `str`, but got: 12" + ): + 
hp.hparams_pb(self.hparams, trial_id=12) + + def assert_hparams_summaries_equal(self, summary_1, summary_2): + def canonical(summary): + """Return a canonical form for `summary`. + + The result is such that `canonical(a) == canonical(b)` if and only + if `a` and `b` are logically equivalent. + + Args: + summary: A `summary_pb2.Summary` containing hparams plugin data. + """ + new_summary = summary_pb2.Summary() + new_summary.MergeFrom(summary) + values = new_summary.value + self.assertEqual(len(values), 1, values) + value = values[0] + raw_content = value.metadata.plugin_data.content + value.metadata.plugin_data.content = b"" + content = plugin_data_pb2.HParamsPluginData.FromString(raw_content) + return (new_summary, content) + + self.assertEqual(canonical(summary_1), canonical(summary_2)) + + def test_consistency_across_string_key_and_object_key(self): + hparams_1 = { + hp.HParam("optimizer", hp.Discrete(["adam", "sgd"])): "adam", + "learning_rate": 0.02, + } + hparams_2 = { + "optimizer": "adam", + hp.HParam("learning_rate", hp.RealInterval(1e-2, 1e-1)): 0.02, + } + self.assert_hparams_summaries_equal( + hp.hparams_pb(hparams_1, start_time_secs=self.start_time_secs), + hp.hparams_pb(hparams_2, start_time_secs=self.start_time_secs), + ) + + def test_duplicate_hparam_names_across_object_and_string(self): + hparams = { + "foo": 1, + hp.HParam("foo"): 1, + } + with six.assertRaisesRegex( + self, ValueError, "multiple values specified for hparam 'foo'" + ): + hp.hparams_pb(hparams) + + def test_duplicate_hparam_names_from_two_objects(self): + hparams = { + hp.HParam("foo"): 1, + hp.HParam("foo"): 1, + } + with six.assertRaisesRegex( + self, ValueError, "multiple values specified for hparam 'foo'" + ): + hp.hparams_pb(hparams) + + def test_invariant_under_permutation(self): + # In particular, the group name should be the same. 
+ hparams_1 = { + "optimizer": "adam", + "learning_rate": 0.02, + } + hparams_2 = { + "learning_rate": 0.02, + "optimizer": "adam", + } + self.assert_hparams_summaries_equal( + hp.hparams_pb(hparams_1, start_time_secs=self.start_time_secs), + hp.hparams_pb(hparams_2, start_time_secs=self.start_time_secs), + ) + + def test_group_name_differs_across_hparams_values(self): + hparams_1 = {"foo": 1, "bar": 2, "baz": 4} + hparams_2 = {"foo": 1, "bar": 3, "baz": 4} + + def get_group_name(hparams): + summary_pb = hp.hparams_pb(hparams) + values = summary_pb.value + self.assertEqual(len(values), 1, values) + actual_value = values[0] + self.assertEqual( + actual_value.metadata.plugin_data.plugin_name, + metadata.PLUGIN_NAME, + ) + plugin_content = actual_value.metadata.plugin_data.content + info = metadata.parse_session_start_info_plugin_data(plugin_content) + return info.group_name + + self.assertNotEqual( + get_group_name(hparams_1), get_group_name(hparams_2) + ) class HParamsConfigTest(test.TestCase): - def setUp(self): - self.logdir = os.path.join(self.get_temp_dir(), "logs") - - self.hparams = [ - hp.HParam("learning_rate", hp.RealInterval(1e-2, 1e-1)), - hp.HParam("dense_layers", hp.IntInterval(2, 7)), - hp.HParam("optimizer", hp.Discrete(["adam", "sgd"])), - hp.HParam("who_knows_what"), - hp.HParam( - "magic", - hp.Discrete([False, True]), - display_name="~*~ Magic ~*~", - description="descriptive", - ), - ] - self.metrics = [ - hp.Metric("samples_per_second"), - hp.Metric(group="train", tag="batch_loss", display_name="loss (train)"), - hp.Metric( - group="validation", - tag="epoch_accuracy", - display_name="accuracy (val.)", - description="Accuracy on the _validation_ dataset.", - dataset_type=hp.Metric.VALIDATION, - ), - ] - self.time_created_secs = 1555624767.0 - - self.expected_experiment_pb = api_pb2.Experiment() - text_format.Merge( - """ + def setUp(self): + self.logdir = os.path.join(self.get_temp_dir(), "logs") + + self.hparams = [ + hp.HParam("learning_rate", 
hp.RealInterval(1e-2, 1e-1)), + hp.HParam("dense_layers", hp.IntInterval(2, 7)), + hp.HParam("optimizer", hp.Discrete(["adam", "sgd"])), + hp.HParam("who_knows_what"), + hp.HParam( + "magic", + hp.Discrete([False, True]), + display_name="~*~ Magic ~*~", + description="descriptive", + ), + ] + self.metrics = [ + hp.Metric("samples_per_second"), + hp.Metric( + group="train", tag="batch_loss", display_name="loss (train)" + ), + hp.Metric( + group="validation", + tag="epoch_accuracy", + display_name="accuracy (val.)", + description="Accuracy on the _validation_ dataset.", + dataset_type=hp.Metric.VALIDATION, + ), + ] + self.time_created_secs = 1555624767.0 + + self.expected_experiment_pb = api_pb2.Experiment() + text_format.Merge( + """ time_created_secs: 1555624767.0 hparam_infos { name: "learning_rate" @@ -362,259 +374,261 @@ def setUp(self): dataset_type: DATASET_VALIDATION } """, - self.expected_experiment_pb, - ) - - def _check_summary(self, summary_pb): - """Test that a summary contains exactly the expected experiment PB.""" - values = summary_pb.value - self.assertEqual(len(values), 1, values) - actual_value = values[0] - self.assertEqual( - actual_value.metadata.plugin_data.plugin_name, - metadata.PLUGIN_NAME, - ) - plugin_content = actual_value.metadata.plugin_data.content - self.assertEqual( - metadata.parse_experiment_plugin_data(plugin_content), - self.expected_experiment_pb, - ) - - def _check_logdir(self, logdir): - """Test that the experiment summary was written to `logdir`.""" - self._check_summary(_get_unique_summary(self, logdir)) - - @requires_tf - def test_eager(self): - with tf.compat.v2.summary.create_file_writer(self.logdir).as_default(): - result = hp.hparams_config( - hparams=self.hparams, - metrics=self.metrics, - time_created_secs=self.time_created_secs, - ) - self.assertTrue(result) - self._check_logdir(self.logdir) - - @requires_tf - def test_graph_mode(self): - with \ - tf.compat.v1.Graph().as_default(), \ - tf.compat.v1.Session() as sess, 
\ - tf.compat.v2.summary.create_file_writer(self.logdir).as_default() as w: - sess.run(w.init()) - summ = hp.hparams_config( - hparams=self.hparams, - metrics=self.metrics, - time_created_secs=self.time_created_secs, - ) - self.assertTrue(sess.run(summ)) - sess.run(w.flush()) - self._check_logdir(self.logdir) - - @requires_tf - def test_eager_no_default_writer(self): - result = hp.hparams_config( - hparams=self.hparams, - metrics=self.metrics, - time_created_secs=self.time_created_secs, - ) - self.assertFalse(result) # no default writer - - def test_pb_contents(self): - result = hp.hparams_config_pb( - hparams=self.hparams, - metrics=self.metrics, - time_created_secs=self.time_created_secs, - ) - self._check_summary(result) - - def test_pb_is_tensorboard_copy_of_proto(self): - result = hp.hparams_config_pb( - hparams=self.hparams, - metrics=self.metrics, - time_created_secs=self.time_created_secs, - ) - self.assertIsInstance(result, summary_pb2.Summary) - if tf is not None: - self.assertNotIsInstance(result, tf.compat.v1.Summary) + self.expected_experiment_pb, + ) + + def _check_summary(self, summary_pb): + """Test that a summary contains exactly the expected experiment PB.""" + values = summary_pb.value + self.assertEqual(len(values), 1, values) + actual_value = values[0] + self.assertEqual( + actual_value.metadata.plugin_data.plugin_name, metadata.PLUGIN_NAME, + ) + plugin_content = actual_value.metadata.plugin_data.content + self.assertEqual( + metadata.parse_experiment_plugin_data(plugin_content), + self.expected_experiment_pb, + ) + + def _check_logdir(self, logdir): + """Test that the experiment summary was written to `logdir`.""" + self._check_summary(_get_unique_summary(self, logdir)) + + @requires_tf + def test_eager(self): + with tf.compat.v2.summary.create_file_writer(self.logdir).as_default(): + result = hp.hparams_config( + hparams=self.hparams, + metrics=self.metrics, + time_created_secs=self.time_created_secs, + ) + self.assertTrue(result) + 
self._check_logdir(self.logdir) + + @requires_tf + def test_graph_mode(self): + with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess, tf.compat.v2.summary.create_file_writer( + self.logdir + ).as_default() as w: + sess.run(w.init()) + summ = hp.hparams_config( + hparams=self.hparams, + metrics=self.metrics, + time_created_secs=self.time_created_secs, + ) + self.assertTrue(sess.run(summ)) + sess.run(w.flush()) + self._check_logdir(self.logdir) + + @requires_tf + def test_eager_no_default_writer(self): + result = hp.hparams_config( + hparams=self.hparams, + metrics=self.metrics, + time_created_secs=self.time_created_secs, + ) + self.assertFalse(result) # no default writer + + def test_pb_contents(self): + result = hp.hparams_config_pb( + hparams=self.hparams, + metrics=self.metrics, + time_created_secs=self.time_created_secs, + ) + self._check_summary(result) + + def test_pb_is_tensorboard_copy_of_proto(self): + result = hp.hparams_config_pb( + hparams=self.hparams, + metrics=self.metrics, + time_created_secs=self.time_created_secs, + ) + self.assertIsInstance(result, summary_pb2.Summary) + if tf is not None: + self.assertNotIsInstance(result, tf.compat.v1.Summary) def _get_unique_summary(self, logdir): - """Get the unique `Summary` stored in `logdir`. - - Specifically, `logdir` must be a directory containing exactly one - entry, which must be an events file of whose events exactly one is a - summary. This unique summary will be returned. - - Args: - self: A `TestCase` object, used for assertions. - logdir: String path to a logdir. - - Returns: - A `summary_pb2.Summary` object. 
- """ - files = os.listdir(logdir) - self.assertEqual(len(files), 1, files) - events_file = os.path.join(logdir, files[0]) - summaries = [ - event.summary - for event in tf.compat.v1.train.summary_iterator(events_file) - if event.WhichOneof("what") == "summary" - ] - self.assertEqual(len(summaries), 1, summaries) - return summaries[0] + """Get the unique `Summary` stored in `logdir`. + + Specifically, `logdir` must be a directory containing exactly one + entry, which must be an events file of whose events exactly one is a + summary. This unique summary will be returned. + + Args: + self: A `TestCase` object, used for assertions. + logdir: String path to a logdir. + + Returns: + A `summary_pb2.Summary` object. + """ + files = os.listdir(logdir) + self.assertEqual(len(files), 1, files) + events_file = os.path.join(logdir, files[0]) + summaries = [ + event.summary + for event in tf.compat.v1.train.summary_iterator(events_file) + if event.WhichOneof("what") == "summary" + ] + self.assertEqual(len(summaries), 1, summaries) + return summaries[0] class IntIntervalTest(test.TestCase): - def test_simple(self): - domain = hp.IntInterval(3, 7) - self.assertEqual(domain.min_value, 3) - self.assertEqual(domain.max_value, 7) - self.assertEqual(domain.dtype, int) - - def test_singleton_domain(self): - domain = hp.IntInterval(61, 61) - self.assertEqual(domain.min_value, 61) - self.assertEqual(domain.max_value, 61) - self.assertEqual(domain.dtype, int) - - def test_non_ints(self): - with six.assertRaisesRegex( - self, TypeError, "min_value must be an int: -inf"): - hp.IntInterval(float("-inf"), 0) - with six.assertRaisesRegex( - self, TypeError, "max_value must be an int: 'eleven'"): - hp.IntInterval(7, "eleven") - - def test_backward_endpoints(self): - with six.assertRaisesRegex( - self, ValueError, "123 > 45"): - hp.IntInterval(123, 45) - - def test_sample_uniform(self): - domain = hp.IntInterval(2, 7) - rng = mock.Mock() - sentinel = object() - # Note: `randint` samples from a 
closed interval, which is what we - # want (as opposed to `randrange`). - rng.randint.return_value = sentinel - result = domain.sample_uniform(rng) - self.assertIs(result, sentinel) - rng.randint.assert_called_once_with(2, 7) - - def test_sample_uniform_unseeded(self): - domain = hp.IntInterval(2, 7) - # Note: `randint` samples from a closed interval, which is what we - # want (as opposed to `randrange`). - with mock.patch.object(random, "randint") as m: - sentinel = object() - m.return_value = sentinel - result = domain.sample_uniform() - self.assertIs(result, sentinel) - m.assert_called_once_with(2, 7) + def test_simple(self): + domain = hp.IntInterval(3, 7) + self.assertEqual(domain.min_value, 3) + self.assertEqual(domain.max_value, 7) + self.assertEqual(domain.dtype, int) + + def test_singleton_domain(self): + domain = hp.IntInterval(61, 61) + self.assertEqual(domain.min_value, 61) + self.assertEqual(domain.max_value, 61) + self.assertEqual(domain.dtype, int) + + def test_non_ints(self): + with six.assertRaisesRegex( + self, TypeError, "min_value must be an int: -inf" + ): + hp.IntInterval(float("-inf"), 0) + with six.assertRaisesRegex( + self, TypeError, "max_value must be an int: 'eleven'" + ): + hp.IntInterval(7, "eleven") + + def test_backward_endpoints(self): + with six.assertRaisesRegex(self, ValueError, "123 > 45"): + hp.IntInterval(123, 45) + + def test_sample_uniform(self): + domain = hp.IntInterval(2, 7) + rng = mock.Mock() + sentinel = object() + # Note: `randint` samples from a closed interval, which is what we + # want (as opposed to `randrange`). + rng.randint.return_value = sentinel + result = domain.sample_uniform(rng) + self.assertIs(result, sentinel) + rng.randint.assert_called_once_with(2, 7) + + def test_sample_uniform_unseeded(self): + domain = hp.IntInterval(2, 7) + # Note: `randint` samples from a closed interval, which is what we + # want (as opposed to `randrange`). 
+ with mock.patch.object(random, "randint") as m: + sentinel = object() + m.return_value = sentinel + result = domain.sample_uniform() + self.assertIs(result, sentinel) + m.assert_called_once_with(2, 7) class RealIntervalTest(test.TestCase): - def test_simple(self): - domain = hp.RealInterval(3.1, 7.7) - self.assertEqual(domain.min_value, 3.1) - self.assertEqual(domain.max_value, 7.7) - self.assertEqual(domain.dtype, float) - - def test_singleton_domain(self): - domain = hp.RealInterval(61.318, 61.318) - self.assertEqual(domain.min_value, 61.318) - self.assertEqual(domain.max_value, 61.318) - self.assertEqual(domain.dtype, float) - - def test_infinite_domain(self): - inf = float("inf") - domain = hp.RealInterval(-inf, inf) - self.assertEqual(domain.min_value, -inf) - self.assertEqual(domain.max_value, inf) - self.assertEqual(domain.dtype, float) - - def test_non_ints(self): - with six.assertRaisesRegex( - self, TypeError, "min_value must be a float: True"): - hp.RealInterval(True, 2.0) - with six.assertRaisesRegex( - self, TypeError, "max_value must be a float: 'wat'"): - hp.RealInterval(1.2, "wat") - - def test_backward_endpoints(self): - with six.assertRaisesRegex( - self, ValueError, "2.1 > 1.2"): - hp.RealInterval(2.1, 1.2) - - def test_sample_uniform(self): - domain = hp.RealInterval(2.0, 4.0) - rng = mock.Mock() - sentinel = object() - rng.uniform.return_value = sentinel - result = domain.sample_uniform(rng) - self.assertIs(result, sentinel) - rng.uniform.assert_called_once_with(2.0, 4.0) - - def test_sample_uniform_unseeded(self): - domain = hp.RealInterval(2.0, 4.0) - with mock.patch.object(random, "uniform") as m: - sentinel = object() - m.return_value = sentinel - result = domain.sample_uniform() - self.assertIs(result, sentinel) - m.assert_called_once_with(2.0, 4.0) + def test_simple(self): + domain = hp.RealInterval(3.1, 7.7) + self.assertEqual(domain.min_value, 3.1) + self.assertEqual(domain.max_value, 7.7) + self.assertEqual(domain.dtype, float) + + 
def test_singleton_domain(self): + domain = hp.RealInterval(61.318, 61.318) + self.assertEqual(domain.min_value, 61.318) + self.assertEqual(domain.max_value, 61.318) + self.assertEqual(domain.dtype, float) + + def test_infinite_domain(self): + inf = float("inf") + domain = hp.RealInterval(-inf, inf) + self.assertEqual(domain.min_value, -inf) + self.assertEqual(domain.max_value, inf) + self.assertEqual(domain.dtype, float) + + def test_non_ints(self): + with six.assertRaisesRegex( + self, TypeError, "min_value must be a float: True" + ): + hp.RealInterval(True, 2.0) + with six.assertRaisesRegex( + self, TypeError, "max_value must be a float: 'wat'" + ): + hp.RealInterval(1.2, "wat") + + def test_backward_endpoints(self): + with six.assertRaisesRegex(self, ValueError, "2.1 > 1.2"): + hp.RealInterval(2.1, 1.2) + + def test_sample_uniform(self): + domain = hp.RealInterval(2.0, 4.0) + rng = mock.Mock() + sentinel = object() + rng.uniform.return_value = sentinel + result = domain.sample_uniform(rng) + self.assertIs(result, sentinel) + rng.uniform.assert_called_once_with(2.0, 4.0) + + def test_sample_uniform_unseeded(self): + domain = hp.RealInterval(2.0, 4.0) + with mock.patch.object(random, "uniform") as m: + sentinel = object() + m.return_value = sentinel + result = domain.sample_uniform() + self.assertIs(result, sentinel) + m.assert_called_once_with(2.0, 4.0) class DiscreteTest(test.TestCase): - def test_simple(self): - domain = hp.Discrete([1, 2, 5]) - self.assertEqual(domain.values, [1, 2, 5]) - self.assertEqual(domain.dtype, int) - - def test_values_sorted(self): - domain = hp.Discrete([2, 3, 1]) - self.assertEqual(domain.values, [1, 2, 3]) - self.assertEqual(domain.dtype, int) - - def test_empty_with_explicit_dtype(self): - domain = hp.Discrete([], dtype=bool) - self.assertIs(domain.dtype, bool) - self.assertEqual(domain.values, []) - - def test_empty_with_unspecified_dtype(self): - with six.assertRaisesRegex( - self, ValueError, "Empty domain with no dtype 
specified"): - hp.Discrete([]) - - def test_dtype_mismatch(self): - with six.assertRaisesRegex( - self, TypeError, r"dtype mismatch: not isinstance\(2, str\)"): - hp.Discrete(["one", 2]) - - def test_sample_uniform(self): - domain = hp.Discrete(["red", "green", "blue"]) - rng = mock.Mock() - sentinel = object() - rng.choice.return_value = sentinel - result = domain.sample_uniform(rng) - self.assertIs(result, sentinel) - # Call to `sorted` is an implementation detail of `sample_uniform`. - rng.choice.assert_called_once_with(sorted(["red", "green", "blue"])) - - def test_sample_uniform_unseeded(self): - domain = hp.Discrete(["red", "green", "blue"]) - with mock.patch.object(random, "choice") as m: - sentinel = object() - m.return_value = sentinel - result = domain.sample_uniform() - self.assertIs(result, sentinel) - # Call to `sorted` is an implementation detail of `sample_uniform`. - m.assert_called_once_with(sorted(["red", "green", "blue"])) + def test_simple(self): + domain = hp.Discrete([1, 2, 5]) + self.assertEqual(domain.values, [1, 2, 5]) + self.assertEqual(domain.dtype, int) + + def test_values_sorted(self): + domain = hp.Discrete([2, 3, 1]) + self.assertEqual(domain.values, [1, 2, 3]) + self.assertEqual(domain.dtype, int) + + def test_empty_with_explicit_dtype(self): + domain = hp.Discrete([], dtype=bool) + self.assertIs(domain.dtype, bool) + self.assertEqual(domain.values, []) + + def test_empty_with_unspecified_dtype(self): + with six.assertRaisesRegex( + self, ValueError, "Empty domain with no dtype specified" + ): + hp.Discrete([]) + + def test_dtype_mismatch(self): + with six.assertRaisesRegex( + self, TypeError, r"dtype mismatch: not isinstance\(2, str\)" + ): + hp.Discrete(["one", 2]) + + def test_sample_uniform(self): + domain = hp.Discrete(["red", "green", "blue"]) + rng = mock.Mock() + sentinel = object() + rng.choice.return_value = sentinel + result = domain.sample_uniform(rng) + self.assertIs(result, sentinel) + # Call to `sorted` is an 
implementation detail of `sample_uniform`. + rng.choice.assert_called_once_with(sorted(["red", "green", "blue"])) + + def test_sample_uniform_unseeded(self): + domain = hp.Discrete(["red", "green", "blue"]) + with mock.patch.object(random, "choice") as m: + sentinel = object() + m.return_value = sentinel + result = domain.sample_uniform() + self.assertIs(result, sentinel) + # Call to `sorted` is an implementation detail of `sample_uniform`. + m.assert_called_once_with(sorted(["red", "green", "blue"])) if __name__ == "__main__": - if tf is not None: - tf.test.main() - else: - test.main() + if tf is not None: + tf.test.main() + else: + test.main() diff --git a/tensorboard/plugins/image/images_demo.py b/tensorboard/plugins/image/images_demo.py index 8e017ef665..3141afbe8e 100644 --- a/tensorboard/plugins/image/images_demo.py +++ b/tensorboard/plugins/image/images_demo.py @@ -36,10 +36,10 @@ logger = tb_logging.get_logger() # Directory into which to write tensorboard data. -LOGDIR = '/tmp/images_demo' +LOGDIR = "/tmp/images_demo" # pylint: disable=line-too-long -IMAGE_URL = r'https://upload.wikimedia.org/wikipedia/commons/f/f0/Valve_original_%281%29.PNG' +IMAGE_URL = r"https://upload.wikimedia.org/wikipedia/commons/f/f0/Valve_original_%281%29.PNG" # pylint: enable=line-too-long IMAGE_CREDIT = textwrap.dedent( """\ @@ -48,240 +48,278 @@ [User:Tauraloke]: https://commons.wikimedia.org/wiki/User:Tauraloke [Source]: https://commons.wikimedia.org/wiki/File:Valve_original_(1).PNG - """) + """ +) (IMAGE_WIDTH, IMAGE_HEIGHT) = (640, 480) _IMAGE_DATA = None def image_data(verbose=False): - """Get the raw encoded image data, downloading it if necessary.""" - # This is a principled use of the `global` statement; don't lint me. 
- global _IMAGE_DATA # pylint: disable=global-statement - if _IMAGE_DATA is None: - if verbose: - logger.info("--- Downloading image.") - with contextlib.closing(urllib.request.urlopen(IMAGE_URL)) as infile: - _IMAGE_DATA = infile.read() - return _IMAGE_DATA + """Get the raw encoded image data, downloading it if necessary.""" + # This is a principled use of the `global` statement; don't lint me. + global _IMAGE_DATA # pylint: disable=global-statement + if _IMAGE_DATA is None: + if verbose: + logger.info("--- Downloading image.") + with contextlib.closing(urllib.request.urlopen(IMAGE_URL)) as infile: + _IMAGE_DATA = infile.read() + return _IMAGE_DATA def convolve(image, pixel_filter, channels=3, name=None): - """Perform a 2D pixel convolution on the given image. - - Arguments: - image: A 3D `float32` `Tensor` of shape `[height, width, channels]`, - where `channels` is the third argument to this function and the - first two dimensions are arbitrary. - pixel_filter: A 2D `Tensor`, representing pixel weightings for the - kernel. This will be used to create a 4D kernel---the extra two - dimensions are for channels (see `tf.nn.conv2d` documentation), - and the kernel will be constructed so that the channels are - independent: each channel only observes the data from neighboring - pixels of the same channel. - channels: An integer representing the number of channels in the - image (e.g., 3 for RGB). - - Returns: - A 3D `float32` `Tensor` of the same shape as the input. - """ - with tf.name_scope(name, 'convolve'): - tf.compat.v1.assert_type(image, tf.float32) - channel_filter = tf.eye(channels) - filter_ = (tf.expand_dims(tf.expand_dims(pixel_filter, -1), -1) * - tf.expand_dims(tf.expand_dims(channel_filter, 0), 0)) - result_batch = tf.nn.conv2d(tf.stack([image]), # batch - filter=filter_, - strides=[1, 1, 1, 1], - padding='SAME') - return result_batch[0] # unbatch + """Perform a 2D pixel convolution on the given image. 
+ + Arguments: + image: A 3D `float32` `Tensor` of shape `[height, width, channels]`, + where `channels` is the third argument to this function and the + first two dimensions are arbitrary. + pixel_filter: A 2D `Tensor`, representing pixel weightings for the + kernel. This will be used to create a 4D kernel---the extra two + dimensions are for channels (see `tf.nn.conv2d` documentation), + and the kernel will be constructed so that the channels are + independent: each channel only observes the data from neighboring + pixels of the same channel. + channels: An integer representing the number of channels in the + image (e.g., 3 for RGB). + + Returns: + A 3D `float32` `Tensor` of the same shape as the input. + """ + with tf.name_scope(name, "convolve"): + tf.compat.v1.assert_type(image, tf.float32) + channel_filter = tf.eye(channels) + filter_ = tf.expand_dims( + tf.expand_dims(pixel_filter, -1), -1 + ) * tf.expand_dims(tf.expand_dims(channel_filter, 0), 0) + result_batch = tf.nn.conv2d( + tf.stack([image]), # batch + filter=filter_, + strides=[1, 1, 1, 1], + padding="SAME", + ) + return result_batch[0] # unbatch def get_image(verbose=False): - """Get the image as a TensorFlow variable. + """Get the image as a TensorFlow variable. - Returns: - A `tf.Variable`, which must be initialized prior to use: - invoke `sess.run(result.initializer)`.""" - base_data = tf.constant(image_data(verbose=verbose)) - base_image = tf.image.decode_image(base_data, channels=3) - base_image.set_shape((IMAGE_HEIGHT, IMAGE_WIDTH, 3)) - parsed_image = tf.Variable(base_image, name='image', dtype=tf.uint8) - return parsed_image + Returns: + A `tf.Variable`, which must be initialized prior to use: + invoke `sess.run(result.initializer)`. 
+ """ + base_data = tf.constant(image_data(verbose=verbose)) + base_image = tf.image.decode_image(base_data, channels=3) + base_image.set_shape((IMAGE_HEIGHT, IMAGE_WIDTH, 3)) + parsed_image = tf.Variable(base_image, name="image", dtype=tf.uint8) + return parsed_image def run_box_to_gaussian(logdir, verbose=False): - """Run a box-blur-to-Gaussian-blur demonstration. - - See the summary description for more details. - - Arguments: - logdir: Directory into which to write event logs. - verbose: Boolean; whether to log any output. - """ - if verbose: - logger.info('--- Starting run: box_to_gaussian') - - tf.compat.v1.reset_default_graph() - tf.compat.v1.set_random_seed(0) - - image = get_image(verbose=verbose) - blur_radius = tf.compat.v1.placeholder(shape=(), dtype=tf.int32) - with tf.name_scope('filter'): - blur_side_length = blur_radius * 2 + 1 - pixel_filter = tf.ones((blur_side_length, blur_side_length)) - pixel_filter = (pixel_filter - / tf.cast(tf.size(input=pixel_filter), tf.float32)) # normalize - - iterations = 4 - images = [tf.cast(image, tf.float32) / 255.0] - for _ in xrange(iterations): - images.append(convolve(images[-1], pixel_filter)) - with tf.name_scope('convert_to_uint8'): - images = tf.stack( - [tf.cast(255 * tf.clip_by_value(image_, 0.0, 1.0), tf.uint8) - for image_ in images]) - - summ = image_summary.op( - 'box_to_gaussian', images, max_outputs=iterations, - display_name='Gaussian blur as a limit process of box blurs', - description=('Demonstration of forming a Gaussian blur by ' - 'composing box blurs, each of which can be expressed ' - 'as a 2D convolution.\n\n' - 'A Gaussian blur is formed by convolving a Gaussian ' - 'kernel over an image. But a Gaussian kernel is ' - 'itself the limit of convolving a constant kernel ' - 'with itself many times. 
Thus, while applying ' - 'a box-filter convolution just once produces ' - 'results that are noticeably different from those ' - 'of a Gaussian blur, repeating the same convolution ' - 'just a few times causes the result to rapidly ' - 'converge to an actual Gaussian blur.\n\n' - 'Here, the step value controls the blur radius, ' - 'and the image sample controls the number of times ' - 'that the convolution is applied (plus one). ' - 'So, when *sample*=1, the original image is shown; ' - '*sample*=2 shows a box blur; and a hypothetical ' - '*sample*=∞ would show a true Gaussian blur.\n\n' - 'This is one ingredient in a recipe to compute very ' - 'fast Gaussian blurs. The other pieces require ' - 'special treatment for the box blurs themselves ' - '(decomposition to dual one-dimensional box blurs, ' - 'each of which is computed with a sliding window); ' - 'we don’t perform those optimizations here.\n\n' - '[Here are some slides describing the full process.]' - '(%s)\n\n' - '%s' - % ('http://elynxsdk.free.fr/ext-docs/Blur/Fast_box_blur.pdf', - IMAGE_CREDIT))) - - with tf.compat.v1.Session() as sess: - sess.run(image.initializer) - writer = tf.summary.FileWriter(os.path.join(logdir, 'box_to_gaussian')) - writer.add_graph(sess.graph) - for step in xrange(8): - if verbose: - logger.info('--- box_to_gaussian: step: %s' % step) - feed_dict = {blur_radius: step} - run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE) - run_metadata = config_pb2.RunMetadata() - s = sess.run(summ, feed_dict=feed_dict, - options=run_options, run_metadata=run_metadata) - writer.add_summary(s, global_step=step) - writer.add_run_metadata(run_metadata, 'step_%04d' % step) - writer.close() + """Run a box-blur-to-Gaussian-blur demonstration. + + See the summary description for more details. + + Arguments: + logdir: Directory into which to write event logs. + verbose: Boolean; whether to log any output. 
+ """ + if verbose: + logger.info("--- Starting run: box_to_gaussian") + + tf.compat.v1.reset_default_graph() + tf.compat.v1.set_random_seed(0) + + image = get_image(verbose=verbose) + blur_radius = tf.compat.v1.placeholder(shape=(), dtype=tf.int32) + with tf.name_scope("filter"): + blur_side_length = blur_radius * 2 + 1 + pixel_filter = tf.ones((blur_side_length, blur_side_length)) + pixel_filter = pixel_filter / tf.cast( + tf.size(input=pixel_filter), tf.float32 + ) # normalize + + iterations = 4 + images = [tf.cast(image, tf.float32) / 255.0] + for _ in xrange(iterations): + images.append(convolve(images[-1], pixel_filter)) + with tf.name_scope("convert_to_uint8"): + images = tf.stack( + [ + tf.cast(255 * tf.clip_by_value(image_, 0.0, 1.0), tf.uint8) + for image_ in images + ] + ) + + summ = image_summary.op( + "box_to_gaussian", + images, + max_outputs=iterations, + display_name="Gaussian blur as a limit process of box blurs", + description=( + "Demonstration of forming a Gaussian blur by " + "composing box blurs, each of which can be expressed " + "as a 2D convolution.\n\n" + "A Gaussian blur is formed by convolving a Gaussian " + "kernel over an image. But a Gaussian kernel is " + "itself the limit of convolving a constant kernel " + "with itself many times. Thus, while applying " + "a box-filter convolution just once produces " + "results that are noticeably different from those " + "of a Gaussian blur, repeating the same convolution " + "just a few times causes the result to rapidly " + "converge to an actual Gaussian blur.\n\n" + "Here, the step value controls the blur radius, " + "and the image sample controls the number of times " + "that the convolution is applied (plus one). " + "So, when *sample*=1, the original image is shown; " + "*sample*=2 shows a box blur; and a hypothetical " + "*sample*=∞ would show a true Gaussian blur.\n\n" + "This is one ingredient in a recipe to compute very " + "fast Gaussian blurs. 
The other pieces require " + "special treatment for the box blurs themselves " + "(decomposition to dual one-dimensional box blurs, " + "each of which is computed with a sliding window); " + "we don’t perform those optimizations here.\n\n" + "[Here are some slides describing the full process.]" + "(%s)\n\n" + "%s" + % ( + "http://elynxsdk.free.fr/ext-docs/Blur/Fast_box_blur.pdf", + IMAGE_CREDIT, + ) + ), + ) + + with tf.compat.v1.Session() as sess: + sess.run(image.initializer) + writer = tf.summary.FileWriter(os.path.join(logdir, "box_to_gaussian")) + writer.add_graph(sess.graph) + for step in xrange(8): + if verbose: + logger.info("--- box_to_gaussian: step: %s" % step) + feed_dict = {blur_radius: step} + run_options = tf.compat.v1.RunOptions( + trace_level=tf.compat.v1.RunOptions.FULL_TRACE + ) + run_metadata = config_pb2.RunMetadata() + s = sess.run( + summ, + feed_dict=feed_dict, + options=run_options, + run_metadata=run_metadata, + ) + writer.add_summary(s, global_step=step) + writer.add_run_metadata(run_metadata, "step_%04d" % step) + writer.close() def run_sobel(logdir, verbose=False): - """Run a Sobel edge detection demonstration. - - See the summary description for more details. - - Arguments: - logdir: Directory into which to write event logs. - verbose: Boolean; whether to log any output. - """ - if verbose: - logger.info('--- Starting run: sobel') - - tf.compat.v1.reset_default_graph() - tf.compat.v1.set_random_seed(0) - - image = get_image(verbose=verbose) - kernel_radius = tf.compat.v1.placeholder(shape=(), dtype=tf.int32) - - with tf.name_scope('horizontal_kernel'): - kernel_side_length = kernel_radius * 2 + 1 - # Drop off influence for pixels further away from the center. 
- weighting_kernel = ( - 1.0 - tf.abs(tf.linspace(-1.0, 1.0, num=kernel_side_length))) - differentiation_kernel = tf.linspace(-1.0, 1.0, num=kernel_side_length) - horizontal_kernel = tf.matmul(tf.expand_dims(weighting_kernel, 1), - tf.expand_dims(differentiation_kernel, 0)) - - with tf.name_scope('vertical_kernel'): - vertical_kernel = tf.transpose(a=horizontal_kernel) - - float_image = tf.cast(image, tf.float32) - dx = convolve(float_image, horizontal_kernel, name='convolve_dx') - dy = convolve(float_image, vertical_kernel, name='convolve_dy') - gradient_magnitude = tf.norm(tensor=[dx, dy], axis=0, name='gradient_magnitude') - with tf.name_scope('normalized_gradient'): - normalized_gradient = gradient_magnitude / tf.reduce_max(input_tensor=gradient_magnitude) - with tf.name_scope('output_image'): - output_image = tf.cast(255 * normalized_gradient, tf.uint8) - - summ = image_summary.op( - 'sobel', tf.stack([output_image]), - display_name='Sobel edge detection', - description=(u'Demonstration of [Sobel edge detection]. The step ' - 'parameter adjusts the radius of the kernel. 
' - 'The kernel can be of arbitrary size, and considers ' - u'nearby pixels with \u2113\u2082-linear falloff.\n\n' - # (that says ``$\ell_2$-linear falloff'') - 'Edge detection is done on a per-channel basis, so ' - 'you can observe which edges are “mostly red ' - 'edges,” for instance.\n\n' - 'For practical edge detection, a small kernel ' - '(usually not more than more than *r*=2) is best.\n\n' - '[Sobel edge detection]: %s\n\n' - "%s" - % ('https://en.wikipedia.org/wiki/Sobel_operator', - IMAGE_CREDIT))) - - with tf.compat.v1.Session() as sess: - sess.run(image.initializer) - writer = tf.summary.FileWriter(os.path.join(logdir, 'sobel')) - writer.add_graph(sess.graph) - for step in xrange(8): - if verbose: - logger.info("--- sobel: step: %s" % step) - feed_dict = {kernel_radius: step} - run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE) - run_metadata = config_pb2.RunMetadata() - s = sess.run(summ, feed_dict=feed_dict, - options=run_options, run_metadata=run_metadata) - writer.add_summary(s, global_step=step) - writer.add_run_metadata(run_metadata, 'step_%04d' % step) - writer.close() + """Run a Sobel edge detection demonstration. + + See the summary description for more details. + + Arguments: + logdir: Directory into which to write event logs. + verbose: Boolean; whether to log any output. + """ + if verbose: + logger.info("--- Starting run: sobel") + + tf.compat.v1.reset_default_graph() + tf.compat.v1.set_random_seed(0) + + image = get_image(verbose=verbose) + kernel_radius = tf.compat.v1.placeholder(shape=(), dtype=tf.int32) + + with tf.name_scope("horizontal_kernel"): + kernel_side_length = kernel_radius * 2 + 1 + # Drop off influence for pixels further away from the center. 
+ weighting_kernel = 1.0 - tf.abs( + tf.linspace(-1.0, 1.0, num=kernel_side_length) + ) + differentiation_kernel = tf.linspace(-1.0, 1.0, num=kernel_side_length) + horizontal_kernel = tf.matmul( + tf.expand_dims(weighting_kernel, 1), + tf.expand_dims(differentiation_kernel, 0), + ) + + with tf.name_scope("vertical_kernel"): + vertical_kernel = tf.transpose(a=horizontal_kernel) + + float_image = tf.cast(image, tf.float32) + dx = convolve(float_image, horizontal_kernel, name="convolve_dx") + dy = convolve(float_image, vertical_kernel, name="convolve_dy") + gradient_magnitude = tf.norm( + tensor=[dx, dy], axis=0, name="gradient_magnitude" + ) + with tf.name_scope("normalized_gradient"): + normalized_gradient = gradient_magnitude / tf.reduce_max( + input_tensor=gradient_magnitude + ) + with tf.name_scope("output_image"): + output_image = tf.cast(255 * normalized_gradient, tf.uint8) + + summ = image_summary.op( + "sobel", + tf.stack([output_image]), + display_name="Sobel edge detection", + description=( + u"Demonstration of [Sobel edge detection]. The step " + "parameter adjusts the radius of the kernel. 
" + "The kernel can be of arbitrary size, and considers " + u"nearby pixels with \u2113\u2082-linear falloff.\n\n" + # (that says ``$\ell_2$-linear falloff'') + "Edge detection is done on a per-channel basis, so " + "you can observe which edges are “mostly red " + "edges,” for instance.\n\n" + "For practical edge detection, a small kernel " + "(usually not more than more than *r*=2) is best.\n\n" + "[Sobel edge detection]: %s\n\n" + "%s" + % ("https://en.wikipedia.org/wiki/Sobel_operator", IMAGE_CREDIT) + ), + ) + + with tf.compat.v1.Session() as sess: + sess.run(image.initializer) + writer = tf.summary.FileWriter(os.path.join(logdir, "sobel")) + writer.add_graph(sess.graph) + for step in xrange(8): + if verbose: + logger.info("--- sobel: step: %s" % step) + feed_dict = {kernel_radius: step} + run_options = tf.compat.v1.RunOptions( + trace_level=tf.compat.v1.RunOptions.FULL_TRACE + ) + run_metadata = config_pb2.RunMetadata() + s = sess.run( + summ, + feed_dict=feed_dict, + options=run_options, + run_metadata=run_metadata, + ) + writer.add_summary(s, global_step=step) + writer.add_run_metadata(run_metadata, "step_%04d" % step) + writer.close() def run_all(logdir, verbose=False): - """Run simulations on a reasonable set of parameters. + """Run simulations on a reasonable set of parameters. - Arguments: - logdir: the directory into which to store all the runs' data - verbose: if true, print out each run's name as it begins - """ - run_box_to_gaussian(logdir, verbose=verbose) - run_sobel(logdir, verbose=verbose) + Arguments: + logdir: the directory into which to store all the runs' data + verbose: if true, print out each run's name as it begins + """ + run_box_to_gaussian(logdir, verbose=verbose) + run_sobel(logdir, verbose=verbose) def main(unused_argv): - logging.set_verbosity(logging.INFO) - logger.info('Saving output to %s.' % LOGDIR) - run_all(LOGDIR, verbose=True) - logger.info('Done. Output saved to %s.' 
% LOGDIR) + logging.set_verbosity(logging.INFO) + logger.info("Saving output to %s." % LOGDIR) + run_all(LOGDIR, verbose=True) + logger.info("Done. Output saved to %s." % LOGDIR) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/tensorboard/plugins/image/images_plugin.py b/tensorboard/plugins/image/images_plugin.py index dfb0c795bf..a07b83d37e 100644 --- a/tensorboard/plugins/image/images_plugin.py +++ b/tensorboard/plugins/image/images_plugin.py @@ -33,73 +33,78 @@ _IMGHDR_TO_MIMETYPE = { - 'bmp': 'image/bmp', - 'gif': 'image/gif', - 'jpeg': 'image/jpeg', - 'png': 'image/png', - 'svg': 'image/svg+xml' + "bmp": "image/bmp", + "gif": "image/gif", + "jpeg": "image/jpeg", + "png": "image/png", + "svg": "image/svg+xml", } -_DEFAULT_IMAGE_MIMETYPE = 'application/octet-stream' +_DEFAULT_IMAGE_MIMETYPE = "application/octet-stream" # Extend imghdr.tests to include svg. def detect_svg(data, f): - del f # Unused. - # Assume XML documents attached to image tag to be SVG. - if data.startswith(b'= 1 - ''', - {'plugin': metadata.PLUGIN_NAME}) - result = collections.defaultdict(dict) - for row in cursor: - run_name, tag_name, display_name, description, samples = row - description = description or '' # Handle missing descriptions. 
- result[run_name][tag_name] = { - 'displayName': display_name, - 'description': plugin_util.markdown_to_safe_html(description), - 'samples': samples - } - return result - - runs = self._multiplexer.Runs() - result = {run: {} for run in runs} - mapping = self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) - for (run, tag_to_content) in six.iteritems(mapping): - for tag in tag_to_content: - summary_metadata = self._multiplexer.SummaryMetadata(run, tag) - tensor_events = self._multiplexer.Tensors(run, tag) - samples = max([len(event.tensor_proto.string_val[2:]) # width, height - for event in tensor_events] + [0]) - result[run][tag] = {'displayName': summary_metadata.display_name, - 'description': plugin_util.markdown_to_safe_html( - summary_metadata.summary_description), - 'samples': samples} - return result - - @wrappers.Request.application - def _serve_image_metadata(self, request): - """Given a tag and list of runs, serve a list of metadata for images. - - Note that the images themselves are not sent; instead, we respond with URLs - to the images. The frontend should treat these URLs as opaque and should not - try to parse information about them or generate them itself, as the format - may change. - - Args: - request: A werkzeug.wrappers.Request object. - - Returns: - A werkzeug.Response application. - """ - tag = request.args.get('tag') - run = request.args.get('run') - sample = int(request.args.get('sample', 0)) - try: - response = self._image_response_for_run(run, tag, sample) - except KeyError: - return http_util.Respond( - request, 'Invalid run or tag', 'text/plain', code=400 - ) - return http_util.Respond(request, response, 'application/json') - - def _image_response_for_run(self, run, tag, sample): - """Builds a JSON-serializable object with information about images. - - Args: - run: The name of the run. - tag: The name of the tag the images all belong to. - sample: The zero-indexed sample of the image for which to retrieve - information. 
For instance, setting `sample` to `2` will fetch - information about only the third image of each batch. Steps with - fewer than three images will be omitted from the results. - - Returns: - A list of dictionaries containing the wall time, step, URL, width, and - height for each image. - """ - if self._db_connection_provider: - db = self._db_connection_provider() - cursor = db.execute( - ''' + """, + {"plugin": metadata.PLUGIN_NAME}, + ) + result = collections.defaultdict(dict) + for row in cursor: + run_name, tag_name, display_name, description, samples = row + description = description or "" # Handle missing descriptions. + result[run_name][tag_name] = { + "displayName": display_name, + "description": plugin_util.markdown_to_safe_html( + description + ), + "samples": samples, + } + return result + + runs = self._multiplexer.Runs() + result = {run: {} for run in runs} + mapping = self._multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME + ) + for (run, tag_to_content) in six.iteritems(mapping): + for tag in tag_to_content: + summary_metadata = self._multiplexer.SummaryMetadata(run, tag) + tensor_events = self._multiplexer.Tensors(run, tag) + samples = max( + [ + len(event.tensor_proto.string_val[2:]) # width, height + for event in tensor_events + ] + + [0] + ) + result[run][tag] = { + "displayName": summary_metadata.display_name, + "description": plugin_util.markdown_to_safe_html( + summary_metadata.summary_description + ), + "samples": samples, + } + return result + + @wrappers.Request.application + def _serve_image_metadata(self, request): + """Given a tag and list of runs, serve a list of metadata for images. + + Note that the images themselves are not sent; instead, we respond with URLs + to the images. The frontend should treat these URLs as opaque and should not + try to parse information about them or generate them itself, as the format + may change. + + Args: + request: A werkzeug.wrappers.Request object. 
+ + Returns: + A werkzeug.Response application. + """ + tag = request.args.get("tag") + run = request.args.get("run") + sample = int(request.args.get("sample", 0)) + try: + response = self._image_response_for_run(run, tag, sample) + except KeyError: + return http_util.Respond( + request, "Invalid run or tag", "text/plain", code=400 + ) + return http_util.Respond(request, response, "application/json") + + def _image_response_for_run(self, run, tag, sample): + """Builds a JSON-serializable object with information about images. + + Args: + run: The name of the run. + tag: The name of the tag the images all belong to. + sample: The zero-indexed sample of the image for which to retrieve + information. For instance, setting `sample` to `2` will fetch + information about only the third image of each batch. Steps with + fewer than three images will be omitted from the results. + + Returns: + A list of dictionaries containing the wall time, step, URL, width, and + height for each image. + """ + if self._db_connection_provider: + db = self._db_connection_provider() + cursor = db.execute( + """ SELECT computed_time, step, @@ -214,81 +232,94 @@ def _image_response_for_run(self, run, tag, sample): AND T0.idx = 0 AND T1.idx = 1 ORDER BY step - ''', - {'run': run, 'tag': tag, 'dtype': tf.string.as_datatype_enum}) - return [{ - 'wall_time': computed_time, - 'step': step, - 'width': width, - 'height': height, - 'query': self._query_for_individual_image(run, tag, sample, index) - } for index, (computed_time, step, width, height) in enumerate(cursor)] - response = [] - index = 0 - tensor_events = self._multiplexer.Tensors(run, tag) - filtered_events = self._filter_by_sample(tensor_events, sample) - for (index, tensor_event) in enumerate(filtered_events): - (width, height) = tensor_event.tensor_proto.string_val[:2] - response.append({ - 'wall_time': tensor_event.wall_time, - 'step': tensor_event.step, - # We include the size so that the frontend can add that to the - # tag so that the 
page layout doesn't change when the image loads. - 'width': int(width), - 'height': int(height), - 'query': self._query_for_individual_image(run, tag, sample, index) - }) - return response - - def _filter_by_sample(self, tensor_events, sample): - return [tensor_event for tensor_event in tensor_events - if (len(tensor_event.tensor_proto.string_val) - 2 # width, height - > sample)] - - def _query_for_individual_image(self, run, tag, sample, index): - """Builds a URL for accessing the specified image. - - This should be kept in sync with _serve_image_metadata. Note that the URL is - *not* guaranteed to always return the same image, since images may be - unloaded from the reservoir as new images come in. - - Args: - run: The name of the run. - tag: The tag. - sample: The relevant sample index, zero-indexed. See documentation - on `_image_response_for_run` for more details. - index: The index of the image. Negative values are OK. - - Returns: - A string representation of a URL that will load the index-th sampled image - in the given run with the given tag. - """ - query_string = urllib.parse.urlencode({ - 'run': run, - 'tag': tag, - 'sample': sample, - 'index': index, - }) - return query_string - - def _get_individual_image(self, run, tag, index, sample): - """ - Returns the actual image bytes for a given image. - - Args: - run: The name of the run the image belongs to. - tag: The name of the tag the images belongs to. - index: The index of the image in the current reservoir. - sample: The zero-indexed sample of the image to retrieve (for example, - setting `sample` to `2` will fetch the third image sample at `step`). - - Returns: - A bytestring of the raw image bytes. 
- """ - if self._db_connection_provider: - db = self._db_connection_provider() - cursor = db.execute( - ''' + """, + {"run": run, "tag": tag, "dtype": tf.string.as_datatype_enum}, + ) + return [ + { + "wall_time": computed_time, + "step": step, + "width": width, + "height": height, + "query": self._query_for_individual_image( + run, tag, sample, index + ), + } + for index, (computed_time, step, width, height) in enumerate( + cursor + ) + ] + response = [] + index = 0 + tensor_events = self._multiplexer.Tensors(run, tag) + filtered_events = self._filter_by_sample(tensor_events, sample) + for (index, tensor_event) in enumerate(filtered_events): + (width, height) = tensor_event.tensor_proto.string_val[:2] + response.append( + { + "wall_time": tensor_event.wall_time, + "step": tensor_event.step, + # We include the size so that the frontend can add that to the + # tag so that the page layout doesn't change when the image loads. + "width": int(width), + "height": int(height), + "query": self._query_for_individual_image( + run, tag, sample, index + ), + } + ) + return response + + def _filter_by_sample(self, tensor_events, sample): + return [ + tensor_event + for tensor_event in tensor_events + if ( + len(tensor_event.tensor_proto.string_val) - 2 # width, height + > sample + ) + ] + + def _query_for_individual_image(self, run, tag, sample, index): + """Builds a URL for accessing the specified image. + + This should be kept in sync with _serve_image_metadata. Note that the URL is + *not* guaranteed to always return the same image, since images may be + unloaded from the reservoir as new images come in. + + Args: + run: The name of the run. + tag: The tag. + sample: The relevant sample index, zero-indexed. See documentation + on `_image_response_for_run` for more details. + index: The index of the image. Negative values are OK. + + Returns: + A string representation of a URL that will load the index-th sampled image + in the given run with the given tag. 
+ """ + query_string = urllib.parse.urlencode( + {"run": run, "tag": tag, "sample": sample, "index": index,} + ) + return query_string + + def _get_individual_image(self, run, tag, index, sample): + """Returns the actual image bytes for a given image. + + Args: + run: The name of the run the image belongs to. + tag: The name of the tag the images belongs to. + index: The index of the image in the current reservoir. + sample: The zero-indexed sample of the image to retrieve (for example, + setting `sample` to `2` will fetch the third image sample at `step`). + + Returns: + A bytestring of the raw image bytes. + """ + if self._db_connection_provider: + db = self._db_connection_provider() + cursor = db.execute( + """ SELECT data FROM TensorStrings WHERE @@ -312,37 +343,47 @@ def _get_individual_image(self, run, tag, index, sample): ORDER BY step LIMIT 1 OFFSET :index) - ''', - {'run': run, - 'tag': tag, - 'sample': sample, - 'index': index, - 'dtype': tf.string.as_datatype_enum}) - (data,) = cursor.fetchone() - return six.binary_type(data) - - events = self._filter_by_sample(self._multiplexer.Tensors(run, tag), sample) - images = events[index].tensor_proto.string_val[2:] # skip width, height - return images[sample] - - @wrappers.Request.application - def _serve_individual_image(self, request): - """Serves an individual image.""" - run = request.args.get('run') - tag = request.args.get('tag') - index = int(request.args.get('index', '0')) - sample = int(request.args.get('sample', '0')) - try: - data = self._get_individual_image(run, tag, index, sample) - except (KeyError, IndexError): - return http_util.Respond( - request, 'Invalid run, tag, index, or sample', 'text/plain', code=400 - ) - image_type = imghdr.what(None, data) - content_type = _IMGHDR_TO_MIMETYPE.get(image_type, _DEFAULT_IMAGE_MIMETYPE) - return http_util.Respond(request, data, content_type) - - @wrappers.Request.application - def _serve_tags(self, request): - index = self._index_impl() - return 
http_util.Respond(request, index, 'application/json') + """, + { + "run": run, + "tag": tag, + "sample": sample, + "index": index, + "dtype": tf.string.as_datatype_enum, + }, + ) + (data,) = cursor.fetchone() + return six.binary_type(data) + + events = self._filter_by_sample( + self._multiplexer.Tensors(run, tag), sample + ) + images = events[index].tensor_proto.string_val[2:] # skip width, height + return images[sample] + + @wrappers.Request.application + def _serve_individual_image(self, request): + """Serves an individual image.""" + run = request.args.get("run") + tag = request.args.get("tag") + index = int(request.args.get("index", "0")) + sample = int(request.args.get("sample", "0")) + try: + data = self._get_individual_image(run, tag, index, sample) + except (KeyError, IndexError): + return http_util.Respond( + request, + "Invalid run, tag, index, or sample", + "text/plain", + code=400, + ) + image_type = imghdr.what(None, data) + content_type = _IMGHDR_TO_MIMETYPE.get( + image_type, _DEFAULT_IMAGE_MIMETYPE + ) + return http_util.Respond(request, data, content_type) + + @wrappers.Request.application + def _serve_tags(self, request): + index = self._index_impl() + return http_util.Respond(request, index, "application/json") diff --git a/tensorboard/plugins/image/images_plugin_test.py b/tensorboard/plugins/image/images_plugin_test.py index 4184dc1561..ef7daa2f61 100644 --- a/tensorboard/plugins/image/images_plugin_test.py +++ b/tensorboard/plugins/image/images_plugin_test.py @@ -32,7 +32,9 @@ from werkzeug import wrappers from tensorboard.backend import application -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.image import summary from tensorboard.plugins.image import images_plugin @@ -42,179 +44,204 @@ class 
ImagesPluginTest(tf.test.TestCase): - - def setUp(self): - self.log_dir = tempfile.mkdtemp() - - # We use numpy.random to generate images. We seed to avoid non-determinism - # in this test. - numpy.random.seed(42) - - # Create old-style image summaries for run "foo". - tf.compat.v1.reset_default_graph() - sess = tf.compat.v1.Session() - placeholder = tf.compat.v1.placeholder(tf.uint8) - tf.compat.v1.summary.image(name="baz", tensor=placeholder) - merged_summary_op = tf.compat.v1.summary.merge_all() - foo_directory = os.path.join(self.log_dir, "foo") - with test_util.FileWriterCache.get(foo_directory) as writer: - writer.add_graph(sess.graph) - for step in xrange(2): - writer.add_summary(sess.run(merged_summary_op, feed_dict={ - placeholder: (numpy.random.rand(1, 16, 42, 3) * 255).astype( - numpy.uint8) - }), global_step=step) - - # Create new-style image summaries for run bar. - tf.compat.v1.reset_default_graph() - sess = tf.compat.v1.Session() - placeholder = tf.compat.v1.placeholder(tf.uint8) - summary.op(name="quux", images=placeholder, - description="how do you pronounce that, anyway?") - merged_summary_op = tf.compat.v1.summary.merge_all() - bar_directory = os.path.join(self.log_dir, "bar") - with test_util.FileWriterCache.get(bar_directory) as writer: - writer.add_graph(sess.graph) - for step in xrange(2): - writer.add_summary(sess.run(merged_summary_op, feed_dict={ - placeholder: (numpy.random.rand(1, 8, 6, 3) * 255).astype( - numpy.uint8) - }), global_step=step) - - # Start a server with the plugin. 
- multiplexer = event_multiplexer.EventMultiplexer({ - "foo": foo_directory, - "bar": bar_directory, - }) - multiplexer.Reload() - context = base_plugin.TBContext( - logdir=self.log_dir, multiplexer=multiplexer) - plugin = images_plugin.ImagesPlugin(context) - wsgi_app = application.TensorBoardWSGI([plugin]) - self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) - self.routes = plugin.get_plugin_apps() - - def tearDown(self): - shutil.rmtree(self.log_dir, ignore_errors=True) - - def _DeserializeResponse(self, byte_content): - """Deserializes byte content that is a JSON encoding. - - Args: - byte_content: The byte content of a response. - - Returns: - The deserialized python object decoded from JSON. - """ - return json.loads(byte_content.decode("utf-8")) - - def testRoutesProvided(self): - """Tests that the plugin offers the correct routes.""" - self.assertIsInstance(self.routes["/images"], collections.Callable) - self.assertIsInstance(self.routes["/individualImage"], collections.Callable) - self.assertIsInstance(self.routes["/tags"], collections.Callable) - - def testOldStyleImagesRoute(self): - """Tests that the /images routes returns correct old-style data.""" - response = self.server.get( - "/data/plugin/images/images?run=foo&tag=baz/image/0&sample=0") - self.assertEqual(200, response.status_code) - - # Verify that the correct entries are returned. - entries = self._DeserializeResponse(response.get_data()) - self.assertEqual(2, len(entries)) - - # Verify that the 1st entry is correct. - entry = entries[0] - self.assertEqual(42, entry["width"]) - self.assertEqual(16, entry["height"]) - self.assertEqual(0, entry["step"]) - parsed_query = urllib.parse.parse_qs(entry["query"]) - self.assertListEqual(["foo"], parsed_query["run"]) - self.assertListEqual(["baz/image/0"], parsed_query["tag"]) - self.assertListEqual(["0"], parsed_query["sample"]) - self.assertListEqual(["0"], parsed_query["index"]) - - # Verify that the 2nd entry is correct. 
- entry = entries[1] - self.assertEqual(42, entry["width"]) - self.assertEqual(16, entry["height"]) - self.assertEqual(1, entry["step"]) - parsed_query = urllib.parse.parse_qs(entry["query"]) - self.assertListEqual(["foo"], parsed_query["run"]) - self.assertListEqual(["baz/image/0"], parsed_query["tag"]) - self.assertListEqual(["0"], parsed_query["sample"]) - self.assertListEqual(["1"], parsed_query["index"]) - - def testNewStyleImagesRoute(self): - """Tests that the /images routes returns correct new-style data.""" - response = self.server.get( - "/data/plugin/images/images?run=bar&tag=quux/image_summary&sample=0") - self.assertEqual(200, response.status_code) - - # Verify that the correct entries are returned. - entries = self._DeserializeResponse(response.get_data()) - self.assertEqual(2, len(entries)) - - # Verify that the 1st entry is correct. - entry = entries[0] - self.assertEqual(6, entry["width"]) - self.assertEqual(8, entry["height"]) - self.assertEqual(0, entry["step"]) - parsed_query = urllib.parse.parse_qs(entry["query"]) - self.assertListEqual(["bar"], parsed_query["run"]) - self.assertListEqual(["quux/image_summary"], parsed_query["tag"]) - self.assertListEqual(["0"], parsed_query["sample"]) - self.assertListEqual(["0"], parsed_query["index"]) - - # Verify that the 2nd entry is correct. 
- entry = entries[1] - self.assertEqual(6, entry["width"]) - self.assertEqual(8, entry["height"]) - self.assertEqual(1, entry["step"]) - parsed_query = urllib.parse.parse_qs(entry["query"]) - self.assertListEqual(["bar"], parsed_query["run"]) - self.assertListEqual(["quux/image_summary"], parsed_query["tag"]) - self.assertListEqual(["0"], parsed_query["sample"]) - self.assertListEqual(["1"], parsed_query["index"]) - - def testOldStyleIndividualImageRoute(self): - """Tests fetching an individual image from an old-style summary.""" - response = self.server.get( - "/data/plugin/images/individualImage" - "?run=foo&tag=baz/image/0&sample=0&index=0") - self.assertEqual(200, response.status_code) - self.assertEqual("image/png", response.headers.get("content-type")) - - def testNewStyleIndividualImageRoute(self): - """Tests fetching an individual image from a new-style summary.""" - response = self.server.get( - "/data/plugin/images/individualImage" - "?run=bar&tag=quux/image_summary&sample=0&index=0") - self.assertEqual(200, response.status_code) - self.assertEqual("image/png", response.headers.get("content-type")) - - def testRunsRoute(self): - """Tests that the /runs route offers the correct run to tag mapping.""" - response = self.server.get("/data/plugin/images/tags") - self.assertEqual(200, response.status_code) - self.assertDictEqual({ - "foo": { - "baz/image/0": { - "displayName": "baz/image/0", - "description": "", - "samples": 1, - }, - }, - "bar": { - "quux/image_summary": { - "displayName": "quux", - "description": "

how do you pronounce that, anyway?

", - "samples": 1, + def setUp(self): + self.log_dir = tempfile.mkdtemp() + + # We use numpy.random to generate images. We seed to avoid non-determinism + # in this test. + numpy.random.seed(42) + + # Create old-style image summaries for run "foo". + tf.compat.v1.reset_default_graph() + sess = tf.compat.v1.Session() + placeholder = tf.compat.v1.placeholder(tf.uint8) + tf.compat.v1.summary.image(name="baz", tensor=placeholder) + merged_summary_op = tf.compat.v1.summary.merge_all() + foo_directory = os.path.join(self.log_dir, "foo") + with test_util.FileWriterCache.get(foo_directory) as writer: + writer.add_graph(sess.graph) + for step in xrange(2): + writer.add_summary( + sess.run( + merged_summary_op, + feed_dict={ + placeholder: ( + numpy.random.rand(1, 16, 42, 3) * 255 + ).astype(numpy.uint8) + }, + ), + global_step=step, + ) + + # Create new-style image summaries for run bar. + tf.compat.v1.reset_default_graph() + sess = tf.compat.v1.Session() + placeholder = tf.compat.v1.placeholder(tf.uint8) + summary.op( + name="quux", + images=placeholder, + description="how do you pronounce that, anyway?", + ) + merged_summary_op = tf.compat.v1.summary.merge_all() + bar_directory = os.path.join(self.log_dir, "bar") + with test_util.FileWriterCache.get(bar_directory) as writer: + writer.add_graph(sess.graph) + for step in xrange(2): + writer.add_summary( + sess.run( + merged_summary_op, + feed_dict={ + placeholder: ( + numpy.random.rand(1, 8, 6, 3) * 255 + ).astype(numpy.uint8) + }, + ), + global_step=step, + ) + + # Start a server with the plugin. 
+ multiplexer = event_multiplexer.EventMultiplexer( + {"foo": foo_directory, "bar": bar_directory,} + ) + multiplexer.Reload() + context = base_plugin.TBContext( + logdir=self.log_dir, multiplexer=multiplexer + ) + plugin = images_plugin.ImagesPlugin(context) + wsgi_app = application.TensorBoardWSGI([plugin]) + self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) + self.routes = plugin.get_plugin_apps() + + def tearDown(self): + shutil.rmtree(self.log_dir, ignore_errors=True) + + def _DeserializeResponse(self, byte_content): + """Deserializes byte content that is a JSON encoding. + + Args: + byte_content: The byte content of a response. + + Returns: + The deserialized python object decoded from JSON. + """ + return json.loads(byte_content.decode("utf-8")) + + def testRoutesProvided(self): + """Tests that the plugin offers the correct routes.""" + self.assertIsInstance(self.routes["/images"], collections.Callable) + self.assertIsInstance( + self.routes["/individualImage"], collections.Callable + ) + self.assertIsInstance(self.routes["/tags"], collections.Callable) + + def testOldStyleImagesRoute(self): + """Tests that the /images routes returns correct old-style data.""" + response = self.server.get( + "/data/plugin/images/images?run=foo&tag=baz/image/0&sample=0" + ) + self.assertEqual(200, response.status_code) + + # Verify that the correct entries are returned. + entries = self._DeserializeResponse(response.get_data()) + self.assertEqual(2, len(entries)) + + # Verify that the 1st entry is correct. + entry = entries[0] + self.assertEqual(42, entry["width"]) + self.assertEqual(16, entry["height"]) + self.assertEqual(0, entry["step"]) + parsed_query = urllib.parse.parse_qs(entry["query"]) + self.assertListEqual(["foo"], parsed_query["run"]) + self.assertListEqual(["baz/image/0"], parsed_query["tag"]) + self.assertListEqual(["0"], parsed_query["sample"]) + self.assertListEqual(["0"], parsed_query["index"]) + + # Verify that the 2nd entry is correct. 
+ entry = entries[1] + self.assertEqual(42, entry["width"]) + self.assertEqual(16, entry["height"]) + self.assertEqual(1, entry["step"]) + parsed_query = urllib.parse.parse_qs(entry["query"]) + self.assertListEqual(["foo"], parsed_query["run"]) + self.assertListEqual(["baz/image/0"], parsed_query["tag"]) + self.assertListEqual(["0"], parsed_query["sample"]) + self.assertListEqual(["1"], parsed_query["index"]) + + def testNewStyleImagesRoute(self): + """Tests that the /images routes returns correct new-style data.""" + response = self.server.get( + "/data/plugin/images/images?run=bar&tag=quux/image_summary&sample=0" + ) + self.assertEqual(200, response.status_code) + + # Verify that the correct entries are returned. + entries = self._DeserializeResponse(response.get_data()) + self.assertEqual(2, len(entries)) + + # Verify that the 1st entry is correct. + entry = entries[0] + self.assertEqual(6, entry["width"]) + self.assertEqual(8, entry["height"]) + self.assertEqual(0, entry["step"]) + parsed_query = urllib.parse.parse_qs(entry["query"]) + self.assertListEqual(["bar"], parsed_query["run"]) + self.assertListEqual(["quux/image_summary"], parsed_query["tag"]) + self.assertListEqual(["0"], parsed_query["sample"]) + self.assertListEqual(["0"], parsed_query["index"]) + + # Verify that the 2nd entry is correct. 
+ entry = entries[1] + self.assertEqual(6, entry["width"]) + self.assertEqual(8, entry["height"]) + self.assertEqual(1, entry["step"]) + parsed_query = urllib.parse.parse_qs(entry["query"]) + self.assertListEqual(["bar"], parsed_query["run"]) + self.assertListEqual(["quux/image_summary"], parsed_query["tag"]) + self.assertListEqual(["0"], parsed_query["sample"]) + self.assertListEqual(["1"], parsed_query["index"]) + + def testOldStyleIndividualImageRoute(self): + """Tests fetching an individual image from an old-style summary.""" + response = self.server.get( + "/data/plugin/images/individualImage" + "?run=foo&tag=baz/image/0&sample=0&index=0" + ) + self.assertEqual(200, response.status_code) + self.assertEqual("image/png", response.headers.get("content-type")) + + def testNewStyleIndividualImageRoute(self): + """Tests fetching an individual image from a new-style summary.""" + response = self.server.get( + "/data/plugin/images/individualImage" + "?run=bar&tag=quux/image_summary&sample=0&index=0" + ) + self.assertEqual(200, response.status_code) + self.assertEqual("image/png", response.headers.get("content-type")) + + def testRunsRoute(self): + """Tests that the /runs route offers the correct run to tag mapping.""" + response = self.server.get("/data/plugin/images/tags") + self.assertEqual(200, response.status_code) + self.assertDictEqual( + { + "foo": { + "baz/image/0": { + "displayName": "baz/image/0", + "description": "", + "samples": 1, + }, + }, + "bar": { + "quux/image_summary": { + "displayName": "quux", + "description": "

how do you pronounce that, anyway?

", + "samples": 1, + }, + }, }, - }, - }, self._DeserializeResponse(response.get_data())) + self._DeserializeResponse(response.get_data()), + ) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/image/metadata.py b/tensorboard/plugins/image/metadata.py index f96d9b252e..93f6512579 100644 --- a/tensorboard/plugins/image/metadata.py +++ b/tensorboard/plugins/image/metadata.py @@ -24,7 +24,7 @@ logger = tb_logging.get_logger() -PLUGIN_NAME = 'images' +PLUGIN_NAME = "images" # The most recent value for the `version` field of the `ImagePluginData` # proto. @@ -32,39 +32,43 @@ def create_summary_metadata(display_name, description): - """Create a `summary_pb2.SummaryMetadata` proto for image plugin data. + """Create a `summary_pb2.SummaryMetadata` proto for image plugin data. - Returns: - A `summary_pb2.SummaryMetadata` protobuf object. - """ - content = plugin_data_pb2.ImagePluginData(version=PROTO_VERSION) - metadata = summary_pb2.SummaryMetadata( - display_name=display_name, - summary_description=description, - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name=PLUGIN_NAME, - content=content.SerializeToString())) - return metadata + Returns: + A `summary_pb2.SummaryMetadata` protobuf object. + """ + content = plugin_data_pb2.ImagePluginData(version=PROTO_VERSION) + metadata = summary_pb2.SummaryMetadata( + display_name=display_name, + summary_description=description, + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, content=content.SerializeToString() + ), + ) + return metadata def parse_plugin_metadata(content): - """Parse summary metadata to a Python object. + """Parse summary metadata to a Python object. - Arguments: - content: The `content` field of a `SummaryMetadata` proto - corresponding to the image plugin. + Arguments: + content: The `content` field of a `SummaryMetadata` proto + corresponding to the image plugin. - Returns: - An `ImagePluginData` protobuf object. 
- """ - if not isinstance(content, bytes): - raise TypeError('Content type must be bytes') - result = plugin_data_pb2.ImagePluginData.FromString(content) - if result.version == 0: - return result - else: - logger.warn( - 'Unknown metadata version: %s. The latest version known to ' - 'this build of TensorBoard is %s; perhaps a newer build is ' - 'available?', result.version, PROTO_VERSION) - return result + Returns: + An `ImagePluginData` protobuf object. + """ + if not isinstance(content, bytes): + raise TypeError("Content type must be bytes") + result = plugin_data_pb2.ImagePluginData.FromString(content) + if result.version == 0: + return result + else: + logger.warn( + "Unknown metadata version: %s. The latest version known to " + "this build of TensorBoard is %s; perhaps a newer build is " + "available?", + result.version, + PROTO_VERSION, + ) + return result diff --git a/tensorboard/plugins/image/summary.py b/tensorboard/plugins/image/summary.py index 22fc5829b3..1f52acd7f9 100644 --- a/tensorboard/plugins/image/summary.py +++ b/tensorboard/plugins/image/summary.py @@ -36,110 +36,129 @@ image = summary_v2.image -def op(name, - images, - max_outputs=3, - display_name=None, - description=None, - collections=None): - """Create a legacy image summary op for use in a TensorFlow graph. - - Arguments: - name: A unique name for the generated summary node. - images: A `Tensor` representing pixel data with shape `[k, h, w, c]`, - where `k` is the number of images, `h` and `w` are the height and - width of the images, and `c` is the number of channels, which - should be 1, 3, or 4. Any of the dimensions may be statically - unknown (i.e., `None`). - max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this - many images will be emitted at each step. When more than - `max_outputs` many images are provided, the first `max_outputs` many - images will be used and the rest silently discarded. 
- display_name: Optional name for this summary in TensorBoard, as a - constant `str`. Defaults to `name`. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. - collections: Optional list of graph collections keys. The new - summary op is added to these collections. Defaults to - `[Graph Keys.SUMMARIES]`. - - Returns: - A TensorFlow summary op. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - - if display_name is None: - display_name = name - summary_metadata = metadata.create_summary_metadata( - display_name=display_name, description=description) - with tf.name_scope(name), \ - tf.control_dependencies([tf.assert_rank(images, 4), - tf.assert_type(images, tf.uint8), - tf.assert_non_negative(max_outputs)]): - limited_images = images[:max_outputs] - encoded_images = tf.map_fn(tf.image.encode_png, limited_images, - dtype=tf.string, - name='encode_each_image') - image_shape = tf.shape(input=images) - dimensions = tf.stack([tf.as_string(image_shape[2], name='width'), - tf.as_string(image_shape[1], name='height')], - name='dimensions') - tensor = tf.concat([dimensions, encoded_images], axis=0) - return tf.summary.tensor_summary(name='image_summary', - tensor=tensor, - collections=collections, - summary_metadata=summary_metadata) +def op( + name, + images, + max_outputs=3, + display_name=None, + description=None, + collections=None, +): + """Create a legacy image summary op for use in a TensorFlow graph. + + Arguments: + name: A unique name for the generated summary node. + images: A `Tensor` representing pixel data with shape `[k, h, w, c]`, + where `k` is the number of images, `h` and `w` are the height and + width of the images, and `c` is the number of channels, which + should be 1, 3, or 4. Any of the dimensions may be statically + unknown (i.e., `None`). + max_outputs: Optional `int` or rank-0 integer `Tensor`. 
At most this + many images will be emitted at each step. When more than + `max_outputs` many images are provided, the first `max_outputs` many + images will be used and the rest silently discarded. + display_name: Optional name for this summary in TensorBoard, as a + constant `str`. Defaults to `name`. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. + collections: Optional list of graph collections keys. The new + summary op is added to these collections. Defaults to + `[Graph Keys.SUMMARIES]`. + + Returns: + A TensorFlow summary op. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. + import tensorflow.compat.v1 as tf + + if display_name is None: + display_name = name + summary_metadata = metadata.create_summary_metadata( + display_name=display_name, description=description + ) + with tf.name_scope(name), tf.control_dependencies( + [ + tf.assert_rank(images, 4), + tf.assert_type(images, tf.uint8), + tf.assert_non_negative(max_outputs), + ] + ): + limited_images = images[:max_outputs] + encoded_images = tf.map_fn( + tf.image.encode_png, + limited_images, + dtype=tf.string, + name="encode_each_image", + ) + image_shape = tf.shape(input=images) + dimensions = tf.stack( + [ + tf.as_string(image_shape[2], name="width"), + tf.as_string(image_shape[1], name="height"), + ], + name="dimensions", + ) + tensor = tf.concat([dimensions, encoded_images], axis=0) + return tf.summary.tensor_summary( + name="image_summary", + tensor=tensor, + collections=collections, + summary_metadata=summary_metadata, + ) def pb(name, images, max_outputs=3, display_name=None, description=None): - """Create a legacy image summary protobuf. - - This behaves as if you were to create an `op` with the same arguments - (wrapped with constant tensors where appropriate) and then execute - that summary op in a TensorFlow session. 
- - Arguments: - name: A unique name for the generated summary, including any desired - name scopes. - images: An `np.array` representing pixel data with shape - `[k, h, w, c]`, where `k` is the number of images, `w` and `h` are - the width and height of the images, and `c` is the number of - channels, which should be 1, 3, or 4. - max_outputs: Optional `int`. At most this many images will be - emitted. If more than this many images are provided, the first - `max_outputs` many images will be used and the rest silently - discarded. - display_name: Optional name for this summary in TensorBoard, as a - `str`. Defaults to `name`. - description: Optional long-form description for this summary, as a - `str`. Markdown is supported. Defaults to empty. - - Returns: - A `tf.Summary` protobuf object. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - - images = np.array(images).astype(np.uint8) - if images.ndim != 4: - raise ValueError('Shape %r must have rank 4' % (images.shape, )) - - limited_images = images[:max_outputs] - encoded_images = [encoder.encode_png(image) for image in limited_images] - (width, height) = (images.shape[2], images.shape[1]) - content = [str(width), str(height)] + encoded_images - tensor = tf.make_tensor_proto(content, dtype=tf.string) - - if display_name is None: - display_name = name - summary_metadata = metadata.create_summary_metadata( - display_name=display_name, description=description) - tf_summary_metadata = tf.SummaryMetadata.FromString( - summary_metadata.SerializeToString()) - - summary = tf.Summary() - summary.value.add(tag='%s/image_summary' % name, - metadata=tf_summary_metadata, - tensor=tensor) - return summary + """Create a legacy image summary protobuf. + + This behaves as if you were to create an `op` with the same arguments + (wrapped with constant tensors where appropriate) and then execute + that summary op in a TensorFlow session. 
+ + Arguments: + name: A unique name for the generated summary, including any desired + name scopes. + images: An `np.array` representing pixel data with shape + `[k, h, w, c]`, where `k` is the number of images, `w` and `h` are + the width and height of the images, and `c` is the number of + channels, which should be 1, 3, or 4. + max_outputs: Optional `int`. At most this many images will be + emitted. If more than this many images are provided, the first + `max_outputs` many images will be used and the rest silently + discarded. + display_name: Optional name for this summary in TensorBoard, as a + `str`. Defaults to `name`. + description: Optional long-form description for this summary, as a + `str`. Markdown is supported. Defaults to empty. + + Returns: + A `tf.Summary` protobuf object. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. + import tensorflow.compat.v1 as tf + + images = np.array(images).astype(np.uint8) + if images.ndim != 4: + raise ValueError("Shape %r must have rank 4" % (images.shape,)) + + limited_images = images[:max_outputs] + encoded_images = [encoder.encode_png(image) for image in limited_images] + (width, height) = (images.shape[2], images.shape[1]) + content = [str(width), str(height)] + encoded_images + tensor = tf.make_tensor_proto(content, dtype=tf.string) + + if display_name is None: + display_name = name + summary_metadata = metadata.create_summary_metadata( + display_name=display_name, description=description + ) + tf_summary_metadata = tf.SummaryMetadata.FromString( + summary_metadata.SerializeToString() + ) + + summary = tf.Summary() + summary.value.add( + tag="%s/image_summary" % name, + metadata=tf_summary_metadata, + tensor=tensor, + ) + return summary diff --git a/tensorboard/plugins/image/summary_test.py b/tensorboard/plugins/image/summary_test.py index af7e58ca48..7efa02d3fb 100644 --- a/tensorboard/plugins/image/summary_test.py +++ b/tensorboard/plugins/image/summary_test.py @@ -32,208 +32,215 
@@ from tensorboard.plugins.image import summary try: - tf2.__version__ # Force lazy import to resolve + tf2.__version__ # Force lazy import to resolve except ImportError: - tf2 = None + tf2 = None try: - tf.compat.v1.enable_eager_execution() + tf.compat.v1.enable_eager_execution() except AttributeError: - # TF 2.0 doesn't have this symbol because eager is the default. - pass + # TF 2.0 doesn't have this symbol because eager is the default. + pass class SummaryBaseTest(object): - - def setUp(self): - super(SummaryBaseTest, self).setUp() - np.random.seed(0) - self.image_width = 20 - self.image_height = 15 - self.image_count = 1 - self.image_channels = 3 - - def _generate_images(self, **kwargs): - size = [ - kwargs.get('n', self.image_count), - kwargs.get('h', self.image_height), - kwargs.get('w', self.image_width), - kwargs.get('c', self.image_channels), - ] - return np.random.uniform(low=0, high=255, size=size).astype(np.uint8) - - def image(self, *args, **kwargs): - raise NotImplementedError() - - def test_tag(self): - data = np.array(1, np.uint8, ndmin=4) - self.assertEqual('a', self.image('a', data).value[0].tag) - self.assertEqual('a/b', self.image('a/b', data).value[0].tag) - - def test_metadata(self): - data = np.array(1, np.uint8, ndmin=4) - description = 'By Leonardo da Vinci' - pb = self.image('mona_lisa', data, description=description) - summary_metadata = pb.value[0].metadata - self.assertEqual(summary_metadata.summary_description, description) - plugin_data = summary_metadata.plugin_data - self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME) - content = summary_metadata.plugin_data.content - # There's no content, so successfully parsing is fine. 
- metadata.parse_plugin_metadata(content) - - def test_png_format_roundtrip(self): - images = self._generate_images(c=1) - pb = self.image('mona_lisa', images) - encoded = pb.value[0].tensor.string_val[2] # skip width, height - self.assertAllEqual(images[0], tf.image.decode_png(encoded)) - - def _test_dimensions(self, images): - pb = self.image('mona_lisa', images) - self.assertEqual(1, len(pb.value)) - result = pb.value[0].tensor.string_val - # Check annotated dimensions. - self.assertEqual(tf.compat.as_bytes(str(self.image_width)), result[0]) - self.assertEqual(tf.compat.as_bytes(str(self.image_height)), result[1]) - for i, encoded in enumerate(result[2:]): - decoded = tf.image.decode_png(encoded) - self.assertEqual(images[i].shape, decoded.shape) - - def test_dimensions(self): - self._test_dimensions(self._generate_images(c=1)) - self._test_dimensions(self._generate_images(c=2)) - self._test_dimensions(self._generate_images(c=3)) - self._test_dimensions(self._generate_images(c=4)) - - def test_image_count_zero(self): - shape = (0, self.image_height, self.image_width, 3) - data = np.array([], np.uint8).reshape(shape) - pb = self.image('mona_lisa', data, max_outputs=3) - self.assertEqual(1, len(pb.value)) - result = pb.value[0].tensor.string_val - self.assertEqual(tf.compat.as_bytes(str(self.image_width)), result[0]) - self.assertEqual(tf.compat.as_bytes(str(self.image_height)), result[1]) - self.assertEqual(2, len(result)) - - def test_image_count_less_than_max_outputs(self): - max_outputs = 3 - data = self._generate_images(n=(max_outputs - 1)) - pb = self.image('mona_lisa', data, max_outputs=max_outputs) - self.assertEqual(1, len(pb.value)) - result = pb.value[0].tensor.string_val - image_results = result[2:] # skip width, height - self.assertEqual(len(data), len(image_results)) - - def test_image_count_more_than_max_outputs(self): - max_outputs = 3 - data = self._generate_images(n=(max_outputs + 1)) - pb = self.image('mona_lisa', data, max_outputs=max_outputs) 
- self.assertEqual(1, len(pb.value)) - result = pb.value[0].tensor.string_val - image_results = result[2:] # skip width, height - self.assertEqual(max_outputs, len(image_results)) - - def test_requires_nonnegative_max_outputs(self): - data = np.array(1, np.uint8, ndmin=4) - with six.assertRaisesRegex( - self, (ValueError, tf.errors.InvalidArgumentError), '>= 0'): - self.image('mona_lisa', data, max_outputs=-1) - - def test_requires_rank_4(self): - with six.assertRaisesRegex(self, ValueError, 'must have rank 4'): - self.image('mona_lisa', [[[1], [2]], [[3], [4]]]) + def setUp(self): + super(SummaryBaseTest, self).setUp() + np.random.seed(0) + self.image_width = 20 + self.image_height = 15 + self.image_count = 1 + self.image_channels = 3 + + def _generate_images(self, **kwargs): + size = [ + kwargs.get("n", self.image_count), + kwargs.get("h", self.image_height), + kwargs.get("w", self.image_width), + kwargs.get("c", self.image_channels), + ] + return np.random.uniform(low=0, high=255, size=size).astype(np.uint8) + + def image(self, *args, **kwargs): + raise NotImplementedError() + + def test_tag(self): + data = np.array(1, np.uint8, ndmin=4) + self.assertEqual("a", self.image("a", data).value[0].tag) + self.assertEqual("a/b", self.image("a/b", data).value[0].tag) + + def test_metadata(self): + data = np.array(1, np.uint8, ndmin=4) + description = "By Leonardo da Vinci" + pb = self.image("mona_lisa", data, description=description) + summary_metadata = pb.value[0].metadata + self.assertEqual(summary_metadata.summary_description, description) + plugin_data = summary_metadata.plugin_data + self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME) + content = summary_metadata.plugin_data.content + # There's no content, so successfully parsing is fine. 
+ metadata.parse_plugin_metadata(content) + + def test_png_format_roundtrip(self): + images = self._generate_images(c=1) + pb = self.image("mona_lisa", images) + encoded = pb.value[0].tensor.string_val[2] # skip width, height + self.assertAllEqual(images[0], tf.image.decode_png(encoded)) + + def _test_dimensions(self, images): + pb = self.image("mona_lisa", images) + self.assertEqual(1, len(pb.value)) + result = pb.value[0].tensor.string_val + # Check annotated dimensions. + self.assertEqual(tf.compat.as_bytes(str(self.image_width)), result[0]) + self.assertEqual(tf.compat.as_bytes(str(self.image_height)), result[1]) + for i, encoded in enumerate(result[2:]): + decoded = tf.image.decode_png(encoded) + self.assertEqual(images[i].shape, decoded.shape) + + def test_dimensions(self): + self._test_dimensions(self._generate_images(c=1)) + self._test_dimensions(self._generate_images(c=2)) + self._test_dimensions(self._generate_images(c=3)) + self._test_dimensions(self._generate_images(c=4)) + + def test_image_count_zero(self): + shape = (0, self.image_height, self.image_width, 3) + data = np.array([], np.uint8).reshape(shape) + pb = self.image("mona_lisa", data, max_outputs=3) + self.assertEqual(1, len(pb.value)) + result = pb.value[0].tensor.string_val + self.assertEqual(tf.compat.as_bytes(str(self.image_width)), result[0]) + self.assertEqual(tf.compat.as_bytes(str(self.image_height)), result[1]) + self.assertEqual(2, len(result)) + + def test_image_count_less_than_max_outputs(self): + max_outputs = 3 + data = self._generate_images(n=(max_outputs - 1)) + pb = self.image("mona_lisa", data, max_outputs=max_outputs) + self.assertEqual(1, len(pb.value)) + result = pb.value[0].tensor.string_val + image_results = result[2:] # skip width, height + self.assertEqual(len(data), len(image_results)) + + def test_image_count_more_than_max_outputs(self): + max_outputs = 3 + data = self._generate_images(n=(max_outputs + 1)) + pb = self.image("mona_lisa", data, max_outputs=max_outputs) 
+ self.assertEqual(1, len(pb.value)) + result = pb.value[0].tensor.string_val + image_results = result[2:] # skip width, height + self.assertEqual(max_outputs, len(image_results)) + + def test_requires_nonnegative_max_outputs(self): + data = np.array(1, np.uint8, ndmin=4) + with six.assertRaisesRegex( + self, (ValueError, tf.errors.InvalidArgumentError), ">= 0" + ): + self.image("mona_lisa", data, max_outputs=-1) + + def test_requires_rank_4(self): + with six.assertRaisesRegex(self, ValueError, "must have rank 4"): + self.image("mona_lisa", [[[1], [2]], [[3], [4]]]) class SummaryV1PbTest(SummaryBaseTest, tf.test.TestCase): - def image(self, *args, **kwargs): - return summary.pb(*args, **kwargs) + def image(self, *args, **kwargs): + return summary.pb(*args, **kwargs) - def test_tag(self): - data = np.array(1, np.uint8, ndmin=4) - self.assertEqual('a/image_summary', self.image('a', data).value[0].tag) - self.assertEqual('a/b/image_summary', self.image('a/b', data).value[0].tag) + def test_tag(self): + data = np.array(1, np.uint8, ndmin=4) + self.assertEqual("a/image_summary", self.image("a", data).value[0].tag) + self.assertEqual( + "a/b/image_summary", self.image("a/b", data).value[0].tag + ) - def test_requires_nonnegative_max_outputs(self): - self.skipTest('summary V1 pb does not actually enforce this') + def test_requires_nonnegative_max_outputs(self): + self.skipTest("summary V1 pb does not actually enforce this") class SummaryV1OpTest(SummaryBaseTest, tf.test.TestCase): - def image(self, *args, **kwargs): - args = list(args) - # Force first argument to tf.uint8 since the V1 version requires this. 
- args[1] = tf.cast(tf.constant(args[1]), tf.uint8) - return summary_pb2.Summary.FromString(summary.op(*args, **kwargs).numpy()) - - def test_tag(self): - data = np.array(1, np.uint8, ndmin=4) - self.assertEqual('a/image_summary', self.image('a', data).value[0].tag) - self.assertEqual('a/b/image_summary', self.image('a/b', data).value[0].tag) - - def test_scoped_tag(self): - data = np.array(1, np.uint8, ndmin=4) - with tf.name_scope('scope'): - self.assertEqual('scope/a/image_summary', - self.image('a', data).value[0].tag) - - def test_image_count_zero(self): - self.skipTest('fails under eager because map_fn() returns float dtype') + def image(self, *args, **kwargs): + args = list(args) + # Force first argument to tf.uint8 since the V1 version requires this. + args[1] = tf.cast(tf.constant(args[1]), tf.uint8) + return summary_pb2.Summary.FromString( + summary.op(*args, **kwargs).numpy() + ) + + def test_tag(self): + data = np.array(1, np.uint8, ndmin=4) + self.assertEqual("a/image_summary", self.image("a", data).value[0].tag) + self.assertEqual( + "a/b/image_summary", self.image("a/b", data).value[0].tag + ) + + def test_scoped_tag(self): + data = np.array(1, np.uint8, ndmin=4) + with tf.name_scope("scope"): + self.assertEqual( + "scope/a/image_summary", self.image("a", data).value[0].tag + ) + + def test_image_count_zero(self): + self.skipTest("fails under eager because map_fn() returns float dtype") class SummaryV2OpTest(SummaryBaseTest, tf.test.TestCase): - def setUp(self): - super(SummaryV2OpTest, self).setUp() - if tf2 is None: - self.skipTest('TF v2 summary API not available') - - def image(self, *args, **kwargs): - return self.image_event(*args, **kwargs).summary - - def image_event(self, *args, **kwargs): - kwargs.setdefault('step', 1) - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - summary.image(*args, **kwargs) - writer.close() - event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), '*'))) - 
self.assertEqual(len(event_files), 1) - events = list(tf.compat.v1.train.summary_iterator(event_files[0])) - # Expect a boilerplate event for the file_version, then the summary one. - self.assertEqual(len(events), 2) - # Delete the event file to reset to an empty directory for later calls. - # TODO(nickfelt): use a unique subdirectory per writer instead. - os.remove(event_files[0]) - return events[1] - - def test_scoped_tag(self): - data = np.array(1, np.uint8, ndmin=4) - with tf.name_scope('scope'): - self.assertEqual('scope/a', self.image('a', data).value[0].tag) - - def test_step(self): - data = np.array(1, np.uint8, ndmin=4) - event = self.image_event('a', data, step=333) - self.assertEqual(333, event.step) - - def test_default_step(self): - data = np.array(1, np.uint8, ndmin=4) - try: - tf2.summary.experimental.set_step(333) - # TODO(nickfelt): change test logic so we can just omit `step` entirely. - event = self.image_event('a', data, step=None) - self.assertEqual(333, event.step) - finally: - # Reset to default state for other tests. - tf2.summary.experimental.set_step(None) - - def test_floating_point_data(self): - data = np.array([-0.01, 0.0, 0.9, 1.0, 1.1]).reshape((1, -1, 1, 1)) - pb = self.image('mona_lisa', data) - encoded = pb.value[0].tensor.string_val[2] # skip width, height - decoded = tf.image.decode_png(encoded).numpy() - # Float values outside [0, 1) are truncated, and everything is scaled to the - # range [0, 255] with 229 = 0.9 * 255, truncated. 
- self.assertAllEqual([0, 0, 229, 255, 255], list(decoded.flat)) - - -if __name__ == '__main__': - tf.test.main() + def setUp(self): + super(SummaryV2OpTest, self).setUp() + if tf2 is None: + self.skipTest("TF v2 summary API not available") + + def image(self, *args, **kwargs): + return self.image_event(*args, **kwargs).summary + + def image_event(self, *args, **kwargs): + kwargs.setdefault("step", 1) + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + summary.image(*args, **kwargs) + writer.close() + event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), "*"))) + self.assertEqual(len(event_files), 1) + events = list(tf.compat.v1.train.summary_iterator(event_files[0])) + # Expect a boilerplate event for the file_version, then the summary one. + self.assertEqual(len(events), 2) + # Delete the event file to reset to an empty directory for later calls. + # TODO(nickfelt): use a unique subdirectory per writer instead. + os.remove(event_files[0]) + return events[1] + + def test_scoped_tag(self): + data = np.array(1, np.uint8, ndmin=4) + with tf.name_scope("scope"): + self.assertEqual("scope/a", self.image("a", data).value[0].tag) + + def test_step(self): + data = np.array(1, np.uint8, ndmin=4) + event = self.image_event("a", data, step=333) + self.assertEqual(333, event.step) + + def test_default_step(self): + data = np.array(1, np.uint8, ndmin=4) + try: + tf2.summary.experimental.set_step(333) + # TODO(nickfelt): change test logic so we can just omit `step` entirely. + event = self.image_event("a", data, step=None) + self.assertEqual(333, event.step) + finally: + # Reset to default state for other tests. 
+ tf2.summary.experimental.set_step(None) + + def test_floating_point_data(self): + data = np.array([-0.01, 0.0, 0.9, 1.0, 1.1]).reshape((1, -1, 1, 1)) + pb = self.image("mona_lisa", data) + encoded = pb.value[0].tensor.string_val[2] # skip width, height + decoded = tf.image.decode_png(encoded).numpy() + # Float values outside [0, 1) are truncated, and everything is scaled to the + # range [0, 255] with 229 = 0.9 * 255, truncated. + self.assertAllEqual([0, 0, 229, 255, 255], list(decoded.flat)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/image/summary_v2.py b/tensorboard/plugins/image/summary_v2.py index 2b87278f89..a3f146fd0e 100644 --- a/tensorboard/plugins/image/summary_v2.py +++ b/tensorboard/plugins/image/summary_v2.py @@ -27,71 +27,80 @@ from tensorboard.util import lazy_tensor_creator -def image(name, - data, - step=None, - max_outputs=3, - description=None): - """Write an image summary. +def image(name, data, step=None, max_outputs=3, description=None): + """Write an image summary. - Arguments: - name: A name for this summary. The summary tag used for TensorBoard will - be this name prefixed by any active name scopes. - data: A `Tensor` representing pixel data with shape `[k, h, w, c]`, - where `k` is the number of images, `h` and `w` are the height and - width of the images, and `c` is the number of channels, which - should be 1, 2, 3, or 4 (grayscale, grayscale with alpha, RGB, RGBA). - Any of the dimensions may be statically unknown (i.e., `None`). - Floating point data will be clipped to the range [0,1). - step: Explicit `int64`-castable monotonic step value for this summary. If - omitted, this defaults to `tf.summary.experimental.get_step()`, which must - not be None. - max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this - many images will be emitted at each step. 
When more than - `max_outputs` many images are provided, the first `max_outputs` many - images will be used and the rest silently discarded. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will + be this name prefixed by any active name scopes. + data: A `Tensor` representing pixel data with shape `[k, h, w, c]`, + where `k` is the number of images, `h` and `w` are the height and + width of the images, and `c` is the number of channels, which + should be 1, 2, 3, or 4 (grayscale, grayscale with alpha, RGB, RGBA). + Any of the dimensions may be statically unknown (i.e., `None`). + Floating point data will be clipped to the range [0,1). + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this + many images will be emitted at each step. When more than + `max_outputs` many images are provided, the first `max_outputs` many + images will be used and the rest silently discarded. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. - Returns: - True on success, or false if no summary was emitted because no default - summary writer was available. + Returns: + True on success, or false if no summary was emitted because no default + summary writer was available. - Raises: - ValueError: if a default writer exists, but no step was provided and - `tf.summary.experimental.get_step()` is None. 
- """ - summary_metadata = metadata.create_summary_metadata( - display_name=None, description=description) - # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback - summary_scope = ( - getattr(tf.summary.experimental, 'summary_scope', None) or - tf.summary.summary_scope) - with summary_scope( - name, 'image_summary', values=[data, max_outputs, step]) as (tag, _): - # Defer image encoding preprocessing by passing it as a callable to write(), - # wrapped in a LazyTensorCreator for backwards compatibility, so that we - # only do this work when summaries are actually written. - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - tf.debugging.assert_rank(data, 4) - tf.debugging.assert_non_negative(max_outputs) - images = tf.image.convert_image_dtype(data, tf.uint8, saturate=True) - limited_images = images[:max_outputs] - encoded_images = tf.map_fn(tf.image.encode_png, limited_images, - dtype=tf.string, - name='encode_each_image') - # Workaround for map_fn returning float dtype for an empty elems input. - encoded_images = tf.cond( - tf.shape(input=encoded_images)[0] > 0, - lambda: encoded_images, lambda: tf.constant([], tf.string)) - image_shape = tf.shape(input=images) - dimensions = tf.stack([tf.as_string(image_shape[2], name='width'), - tf.as_string(image_shape[1], name='height')], - name='dimensions') - return tf.concat([dimensions, encoded_images], axis=0) + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. 
+ """ + summary_metadata = metadata.create_summary_metadata( + display_name=None, description=description + ) + # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback + summary_scope = ( + getattr(tf.summary.experimental, "summary_scope", None) + or tf.summary.summary_scope + ) + with summary_scope( + name, "image_summary", values=[data, max_outputs, step] + ) as (tag, _): + # Defer image encoding preprocessing by passing it as a callable to write(), + # wrapped in a LazyTensorCreator for backwards compatibility, so that we + # only do this work when summaries are actually written. + @lazy_tensor_creator.LazyTensorCreator + def lazy_tensor(): + tf.debugging.assert_rank(data, 4) + tf.debugging.assert_non_negative(max_outputs) + images = tf.image.convert_image_dtype(data, tf.uint8, saturate=True) + limited_images = images[:max_outputs] + encoded_images = tf.map_fn( + tf.image.encode_png, + limited_images, + dtype=tf.string, + name="encode_each_image", + ) + # Workaround for map_fn returning float dtype for an empty elems input. + encoded_images = tf.cond( + tf.shape(input=encoded_images)[0] > 0, + lambda: encoded_images, + lambda: tf.constant([], tf.string), + ) + image_shape = tf.shape(input=images) + dimensions = tf.stack( + [ + tf.as_string(image_shape[2], name="width"), + tf.as_string(image_shape[1], name="height"), + ], + name="dimensions", + ) + return tf.concat([dimensions, encoded_images], axis=0) - # To ensure that image encoding logic is only executed when summaries - # are written, we pass callable to `tensor` parameter. - return tf.summary.write( - tag=tag, tensor=lazy_tensor, step=step, metadata=summary_metadata) + # To ensure that image encoding logic is only executed when summaries + # are written, we pass callable to `tensor` parameter. 
+ return tf.summary.write( + tag=tag, tensor=lazy_tensor, step=step, metadata=summary_metadata + ) diff --git a/tensorboard/plugins/interactive_inference/interactive_inference_plugin.py b/tensorboard/plugins/interactive_inference/interactive_inference_plugin.py index 1ba875aad5..edd203678b 100644 --- a/tensorboard/plugins/interactive_inference/interactive_inference_plugin.py +++ b/tensorboard/plugins/interactive_inference/interactive_inference_plugin.py @@ -48,380 +48,494 @@ class InteractiveInferencePlugin(base_plugin.TBPlugin): - """Plugin for understanding/debugging model inference. - """ - - # This string field is used by TensorBoard to generate the paths for routes - # provided by this plugin. It must thus be URL-friendly. This field is also - # used to uniquely identify this plugin throughout TensorBoard. See BasePlugin - # for details. - plugin_name = 'whatif' - examples = [] - updated_example_indices = set() - sprite = None - example_class = tf.train.Example - - # The standard name for encoded image features in TensorFlow. - image_feature_name = 'image/encoded' - - # The width and height of the thumbnail for any images for Facets Dive. - sprite_thumbnail_dim_px = 32 - - # The vocab of inference class indices to label names for the model. - label_vocab = [] - - def __init__(self, context): - """Constructs an interactive inference plugin for TensorBoard. - - Args: - context: A base_plugin.TBContext instance. - """ - self._logdir = context.logdir - self._has_auth_group = (context.flags and - 'authorized_groups' in context.flags and - context.flags.authorized_groups != '') - - def get_plugin_apps(self): - """Obtains a mapping between routes and handlers. Stores the logdir. - - Returns: - A mapping between routes and handlers (functions that respond to - requests). 
- """ - return { - '/infer': self._infer, - '/update_example': self._update_example, - '/examples_from_path': self._examples_from_path_handler, - '/sprite': self._serve_sprite, - '/duplicate_example': self._duplicate_example, - '/delete_example': self._delete_example, - '/infer_mutants': self._infer_mutants_handler, - '/eligible_features': self._eligible_features_from_example_handler, - '/sort_eligible_features': self._sort_eligible_features_handler, - } - - def is_active(self): - """Determines whether this plugin is active. - - Returns: - A boolean. Whether this plugin is active. - """ - # TODO(jameswex): Maybe enable if config flags were specified? - return False - - def frontend_metadata(self): - # TODO(#2338): Keep this in sync with the `registerDashboard` call - # on the frontend until that call is removed. - return base_plugin.FrontendMetadata( - element_name='tf-interactive-inference-dashboard', - tab_name='What-If Tool', - ) - - def generate_sprite(self, example_strings): - # Generate a sprite image for the examples if the examples contain the - # standard encoded image feature. - feature_list = (self.examples[0].features.feature - if self.example_class == tf.train.Example - else self.examples[0].context.feature) - self.sprite = ( - inference_utils.create_sprite_image(example_strings) - if (len(self.examples) and self.image_feature_name in feature_list) else - None) - - @wrappers.Request.application - def _examples_from_path_handler(self, request): - """Returns JSON of the specified examples. - - Args: - request: A request that should contain 'examples_path' and 'max_examples'. - - Returns: - JSON of up to max_examlpes of the examples in the path. 
- """ - examples_count = int(request.args.get('max_examples')) - examples_path = request.args.get('examples_path') - sampling_odds = float(request.args.get('sampling_odds')) - self.example_class = (tf.train.SequenceExample - if request.args.get('sequence_examples') == 'true' - else tf.train.Example) - try: - platform_utils.throw_if_file_access_not_allowed(examples_path, - self._logdir, - self._has_auth_group) - example_strings = platform_utils.example_protos_from_path( - examples_path, examples_count, parse_examples=False, - sampling_odds=sampling_odds, example_class=self.example_class) - self.examples = [ - self.example_class.FromString(ex) for ex in example_strings] - self.generate_sprite(example_strings) - json_examples = [ - json_format.MessageToJson(example) for example in self.examples - ] - self.updated_example_indices = set(range(len(json_examples))) - return http_util.Respond( - request, - {'examples': json_examples, - 'sprite': True if self.sprite else False}, 'application/json') - except common_utils.InvalidUserInputError as e: - return http_util.Respond(request, {'error': e.message}, - 'application/json', code=400) - - @wrappers.Request.application - def _serve_sprite(self, request): - return http_util.Respond(request, self.sprite, 'image/png') - - @wrappers.Request.application - def _update_example(self, request): - """Updates the specified example. - - Args: - request: A request that should contain 'index' and 'example'. - - Returns: - An empty response. 
- """ - if request.method != 'POST': - return http_util.Respond(request, {'error': 'invalid non-POST request'}, - 'application/json', code=405) - example_json = request.form['example'] - index = int(request.form['index']) - if index >= len(self.examples): - return http_util.Respond(request, {'error': 'invalid index provided'}, - 'application/json', code=400) - new_example = self.example_class() - json_format.Parse(example_json, new_example) - self.examples[index] = new_example - self.updated_example_indices.add(index) - self.generate_sprite([ex.SerializeToString() for ex in self.examples]) - return http_util.Respond(request, {}, 'application/json') - - @wrappers.Request.application - def _duplicate_example(self, request): - """Duplicates the specified example. - - Args: - request: A request that should contain 'index'. - - Returns: - An empty response. - """ - index = int(request.args.get('index')) - if index >= len(self.examples): - return http_util.Respond(request, {'error': 'invalid index provided'}, - 'application/json', code=400) - new_example = self.example_class() - new_example.CopyFrom(self.examples[index]) - self.examples.append(new_example) - self.updated_example_indices.add(len(self.examples) - 1) - self.generate_sprite([ex.SerializeToString() for ex in self.examples]) - return http_util.Respond(request, {}, 'application/json') - - @wrappers.Request.application - def _delete_example(self, request): - """Deletes the specified example. - - Args: - request: A request that should contain 'index'. - - Returns: - An empty response. 
- """ - index = int(request.args.get('index')) - if index >= len(self.examples): - return http_util.Respond(request, {'error': 'invalid index provided'}, - 'application/json', code=400) - del self.examples[index] - self.updated_example_indices = set([ - i if i < index else i - 1 for i in self.updated_example_indices]) - self.generate_sprite([ex.SerializeToString() for ex in self.examples]) - return http_util.Respond(request, {}, 'application/json') - - def _parse_request_arguments(self, request): - """Parses comma separated request arguments - - Args: - request: A request that should contain 'inference_address', 'model_name', - 'model_version', 'model_signature'. - - Returns: - A tuple of lists for model parameters - """ - inference_addresses = request.args.get('inference_address').split(',') - model_names = request.args.get('model_name').split(',') - model_versions = request.args.get('model_version').split(',') - model_signatures = request.args.get('model_signature').split(',') - if len(model_names) != len(inference_addresses): - raise common_utils.InvalidUserInputError('Every model should have a ' + - 'name and address.') - return inference_addresses, model_names, model_versions, model_signatures - - @wrappers.Request.application - def _infer(self, request): - """Returns JSON for the `vz-line-chart`s for a feature. - - Args: - request: A request that should contain 'inference_address', 'model_name', - 'model_type, 'model_version', 'model_signature' and 'label_vocab_path'. - - Returns: - A list of JSON objects, one for each chart. 
- """ - label_vocab = inference_utils.get_label_vocab( - request.args.get('label_vocab_path')) - - try: - if request.method != 'GET': - logger.error('%s requests are forbidden.', request.method) - return http_util.Respond(request, {'error': 'invalid non-GET request'}, - 'application/json', code=405) - - (inference_addresses, model_names, model_versions, - model_signatures) = self._parse_request_arguments(request) - - indices_to_infer = sorted(self.updated_example_indices) - examples_to_infer = [self.examples[index] for index in indices_to_infer] - infer_objs = [] - for model_num in xrange(len(inference_addresses)): - serving_bundle = inference_utils.ServingBundle( - inference_addresses[model_num], - model_names[model_num], - request.args.get('model_type'), - model_versions[model_num], - model_signatures[model_num], - request.args.get('use_predict') == 'true', - request.args.get('predict_input_tensor'), - request.args.get('predict_output_tensor')) - (predictions, _) = inference_utils.run_inference_for_inference_results( - examples_to_infer, serving_bundle) - infer_objs.append(predictions) - - resp = {'indices': indices_to_infer, 'results': infer_objs} - self.updated_example_indices = set() - return http_util.Respond(request, {'inferences': json.dumps(resp), - 'vocab': json.dumps(label_vocab)}, - 'application/json') - except common_utils.InvalidUserInputError as e: - return http_util.Respond(request, {'error': e.message}, - 'application/json', code=400) - except AbortionError as e: - return http_util.Respond(request, {'error': e.details}, - 'application/json', code=400) - - @wrappers.Request.application - def _eligible_features_from_example_handler(self, request): - """Returns a list of JSON objects for each feature in the example. - - Args: - request: A request for features. - - Returns: - A list with a JSON object for each feature. - Numeric features are represented as {name: observedMin: observedMax:}. - Categorical features are repesented as {name: samples:[]}. 
- """ - features_list = inference_utils.get_eligible_features( - self.examples[0: NUM_EXAMPLES_TO_SCAN], NUM_MUTANTS) - return http_util.Respond(request, features_list, 'application/json') - - @wrappers.Request.application - def _sort_eligible_features_handler(self, request): - """Returns a sorted list of JSON objects for each feature in the example. - - The list is sorted by interestingness in terms of the resulting change in - inference values across feature values, for partial dependence plots. - - Args: - request: A request for sorted features. - - Returns: - A sorted list with a JSON object for each feature. - Numeric features are represented as - {name: observedMin: observedMax: interestingness:}. - Categorical features are repesented as - {name: samples:[] interestingness:}. - """ - try: - features_list = inference_utils.get_eligible_features( - self.examples[0: NUM_EXAMPLES_TO_SCAN], NUM_MUTANTS) - example_index = int(request.args.get('example_index', '0')) - (inference_addresses, model_names, model_versions, - model_signatures) = self._parse_request_arguments(request) - chart_data = {} - for feat in features_list: - chart_data[feat['name']] = self._infer_mutants_impl( - feat['name'], example_index, - inference_addresses, model_names, request.args.get('model_type'), - model_versions, model_signatures, - request.args.get('use_predict') == 'true', - request.args.get('predict_input_tensor'), - request.args.get('predict_output_tensor'), - feat['observedMin'] if 'observedMin' in feat else 0, - feat['observedMax'] if 'observedMin' in feat else 0, - None) - features_list = inference_utils.sort_eligible_features( - features_list, chart_data) - return http_util.Respond(request, features_list, 'application/json') - except common_utils.InvalidUserInputError as e: - return http_util.Respond(request, {'error': e.message}, - 'application/json', code=400) - - @wrappers.Request.application - def _infer_mutants_handler(self, request): - """Returns JSON for the partial 
dependence plots for a feature. - - Args: - request: A request that should contain 'feature_name', 'example_index', - 'inference_address', 'model_name', 'model_type', 'model_version', and - 'model_signature'. - - Returns: - A list of JSON objects, one for each chart. - """ - try: - if request.method != 'GET': - logger.error('%s requests are forbidden.', request.method) - return http_util.Respond(request, {'error': 'invalid non-GET request'}, - 'application/json', code=405) - - example_index = int(request.args.get('example_index', '0')) - feature_name = request.args.get('feature_name') - (inference_addresses, model_names, model_versions, - model_signatures) = self._parse_request_arguments(request) - json_mapping = self._infer_mutants_impl(feature_name, example_index, - inference_addresses, model_names, request.args.get('model_type'), - model_versions, model_signatures, - request.args.get('use_predict') == 'true', - request.args.get('predict_input_tensor'), - request.args.get('predict_output_tensor'), - request.args.get('x_min'), request.args.get('x_max'), - request.args.get('feature_index_pattern')) - return http_util.Respond(request, json_mapping, 'application/json') - except common_utils.InvalidUserInputError as e: - return http_util.Respond(request, {'error': e.message}, - 'application/json', code=400) - - def _infer_mutants_impl(self, feature_name, example_index, inference_addresses, - model_names, model_type, model_versions, model_signatures, use_predict, - predict_input_tensor, predict_output_tensor, x_min, x_max, - feature_index_pattern): - """Helper for generating PD plots for a feature.""" - examples = (self.examples if example_index == -1 - else [self.examples[example_index]]) - serving_bundles = [] - for model_num in xrange(len(inference_addresses)): - serving_bundles.append(inference_utils.ServingBundle( - inference_addresses[model_num], - model_names[model_num], - model_type, - model_versions[model_num], - model_signatures[model_num], - use_predict, - 
predict_input_tensor, - predict_output_tensor)) - - viz_params = inference_utils.VizParams( - x_min, x_max, - self.examples[0:NUM_EXAMPLES_TO_SCAN], NUM_MUTANTS, - feature_index_pattern) - return inference_utils.mutant_charts_for_feature( - examples, feature_name, serving_bundles, viz_params) + """Plugin for understanding/debugging model inference.""" + + # This string field is used by TensorBoard to generate the paths for routes + # provided by this plugin. It must thus be URL-friendly. This field is also + # used to uniquely identify this plugin throughout TensorBoard. See BasePlugin + # for details. + plugin_name = "whatif" + examples = [] + updated_example_indices = set() + sprite = None + example_class = tf.train.Example + + # The standard name for encoded image features in TensorFlow. + image_feature_name = "image/encoded" + + # The width and height of the thumbnail for any images for Facets Dive. + sprite_thumbnail_dim_px = 32 + + # The vocab of inference class indices to label names for the model. + label_vocab = [] + + def __init__(self, context): + """Constructs an interactive inference plugin for TensorBoard. + + Args: + context: A base_plugin.TBContext instance. + """ + self._logdir = context.logdir + self._has_auth_group = ( + context.flags + and "authorized_groups" in context.flags + and context.flags.authorized_groups != "" + ) + + def get_plugin_apps(self): + """Obtains a mapping between routes and handlers. Stores the logdir. + + Returns: + A mapping between routes and handlers (functions that respond to + requests). 
+ """ + return { + "/infer": self._infer, + "/update_example": self._update_example, + "/examples_from_path": self._examples_from_path_handler, + "/sprite": self._serve_sprite, + "/duplicate_example": self._duplicate_example, + "/delete_example": self._delete_example, + "/infer_mutants": self._infer_mutants_handler, + "/eligible_features": self._eligible_features_from_example_handler, + "/sort_eligible_features": self._sort_eligible_features_handler, + } + + def is_active(self): + """Determines whether this plugin is active. + + Returns: + A boolean. Whether this plugin is active. + """ + # TODO(jameswex): Maybe enable if config flags were specified? + return False + + def frontend_metadata(self): + # TODO(#2338): Keep this in sync with the `registerDashboard` call + # on the frontend until that call is removed. + return base_plugin.FrontendMetadata( + element_name="tf-interactive-inference-dashboard", + tab_name="What-If Tool", + ) + + def generate_sprite(self, example_strings): + # Generate a sprite image for the examples if the examples contain the + # standard encoded image feature. + feature_list = ( + self.examples[0].features.feature + if self.example_class == tf.train.Example + else self.examples[0].context.feature + ) + self.sprite = ( + inference_utils.create_sprite_image(example_strings) + if (len(self.examples) and self.image_feature_name in feature_list) + else None + ) + + @wrappers.Request.application + def _examples_from_path_handler(self, request): + """Returns JSON of the specified examples. + + Args: + request: A request that should contain 'examples_path' and 'max_examples'. + + Returns: + JSON of up to max_examlpes of the examples in the path. 
+ """ + examples_count = int(request.args.get("max_examples")) + examples_path = request.args.get("examples_path") + sampling_odds = float(request.args.get("sampling_odds")) + self.example_class = ( + tf.train.SequenceExample + if request.args.get("sequence_examples") == "true" + else tf.train.Example + ) + try: + platform_utils.throw_if_file_access_not_allowed( + examples_path, self._logdir, self._has_auth_group + ) + example_strings = platform_utils.example_protos_from_path( + examples_path, + examples_count, + parse_examples=False, + sampling_odds=sampling_odds, + example_class=self.example_class, + ) + self.examples = [ + self.example_class.FromString(ex) for ex in example_strings + ] + self.generate_sprite(example_strings) + json_examples = [ + json_format.MessageToJson(example) for example in self.examples + ] + self.updated_example_indices = set(range(len(json_examples))) + return http_util.Respond( + request, + { + "examples": json_examples, + "sprite": True if self.sprite else False, + }, + "application/json", + ) + except common_utils.InvalidUserInputError as e: + return http_util.Respond( + request, {"error": e.message}, "application/json", code=400 + ) + + @wrappers.Request.application + def _serve_sprite(self, request): + return http_util.Respond(request, self.sprite, "image/png") + + @wrappers.Request.application + def _update_example(self, request): + """Updates the specified example. + + Args: + request: A request that should contain 'index' and 'example'. + + Returns: + An empty response. 
+ """ + if request.method != "POST": + return http_util.Respond( + request, + {"error": "invalid non-POST request"}, + "application/json", + code=405, + ) + example_json = request.form["example"] + index = int(request.form["index"]) + if index >= len(self.examples): + return http_util.Respond( + request, + {"error": "invalid index provided"}, + "application/json", + code=400, + ) + new_example = self.example_class() + json_format.Parse(example_json, new_example) + self.examples[index] = new_example + self.updated_example_indices.add(index) + self.generate_sprite([ex.SerializeToString() for ex in self.examples]) + return http_util.Respond(request, {}, "application/json") + + @wrappers.Request.application + def _duplicate_example(self, request): + """Duplicates the specified example. + + Args: + request: A request that should contain 'index'. + + Returns: + An empty response. + """ + index = int(request.args.get("index")) + if index >= len(self.examples): + return http_util.Respond( + request, + {"error": "invalid index provided"}, + "application/json", + code=400, + ) + new_example = self.example_class() + new_example.CopyFrom(self.examples[index]) + self.examples.append(new_example) + self.updated_example_indices.add(len(self.examples) - 1) + self.generate_sprite([ex.SerializeToString() for ex in self.examples]) + return http_util.Respond(request, {}, "application/json") + + @wrappers.Request.application + def _delete_example(self, request): + """Deletes the specified example. + + Args: + request: A request that should contain 'index'. + + Returns: + An empty response. 
+ """ + index = int(request.args.get("index")) + if index >= len(self.examples): + return http_util.Respond( + request, + {"error": "invalid index provided"}, + "application/json", + code=400, + ) + del self.examples[index] + self.updated_example_indices = set( + [i if i < index else i - 1 for i in self.updated_example_indices] + ) + self.generate_sprite([ex.SerializeToString() for ex in self.examples]) + return http_util.Respond(request, {}, "application/json") + + def _parse_request_arguments(self, request): + """Parses comma separated request arguments. + + Args: + request: A request that should contain 'inference_address', 'model_name', + 'model_version', 'model_signature'. + + Returns: + A tuple of lists for model parameters + """ + inference_addresses = request.args.get("inference_address").split(",") + model_names = request.args.get("model_name").split(",") + model_versions = request.args.get("model_version").split(",") + model_signatures = request.args.get("model_signature").split(",") + if len(model_names) != len(inference_addresses): + raise common_utils.InvalidUserInputError( + "Every model should have a " + "name and address." + ) + return ( + inference_addresses, + model_names, + model_versions, + model_signatures, + ) + + @wrappers.Request.application + def _infer(self, request): + """Returns JSON for the `vz-line-chart`s for a feature. + + Args: + request: A request that should contain 'inference_address', 'model_name', + 'model_type, 'model_version', 'model_signature' and 'label_vocab_path'. + + Returns: + A list of JSON objects, one for each chart. 
+ """ + label_vocab = inference_utils.get_label_vocab( + request.args.get("label_vocab_path") + ) + + try: + if request.method != "GET": + logger.error("%s requests are forbidden.", request.method) + return http_util.Respond( + request, + {"error": "invalid non-GET request"}, + "application/json", + code=405, + ) + + ( + inference_addresses, + model_names, + model_versions, + model_signatures, + ) = self._parse_request_arguments(request) + + indices_to_infer = sorted(self.updated_example_indices) + examples_to_infer = [ + self.examples[index] for index in indices_to_infer + ] + infer_objs = [] + for model_num in xrange(len(inference_addresses)): + serving_bundle = inference_utils.ServingBundle( + inference_addresses[model_num], + model_names[model_num], + request.args.get("model_type"), + model_versions[model_num], + model_signatures[model_num], + request.args.get("use_predict") == "true", + request.args.get("predict_input_tensor"), + request.args.get("predict_output_tensor"), + ) + ( + predictions, + _, + ) = inference_utils.run_inference_for_inference_results( + examples_to_infer, serving_bundle + ) + infer_objs.append(predictions) + + resp = {"indices": indices_to_infer, "results": infer_objs} + self.updated_example_indices = set() + return http_util.Respond( + request, + { + "inferences": json.dumps(resp), + "vocab": json.dumps(label_vocab), + }, + "application/json", + ) + except common_utils.InvalidUserInputError as e: + return http_util.Respond( + request, {"error": e.message}, "application/json", code=400 + ) + except AbortionError as e: + return http_util.Respond( + request, {"error": e.details}, "application/json", code=400 + ) + + @wrappers.Request.application + def _eligible_features_from_example_handler(self, request): + """Returns a list of JSON objects for each feature in the example. + + Args: + request: A request for features. + + Returns: + A list with a JSON object for each feature. 
+ Numeric features are represented as {name: observedMin: observedMax:}. + Categorical features are repesented as {name: samples:[]}. + """ + features_list = inference_utils.get_eligible_features( + self.examples[0:NUM_EXAMPLES_TO_SCAN], NUM_MUTANTS + ) + return http_util.Respond(request, features_list, "application/json") + + @wrappers.Request.application + def _sort_eligible_features_handler(self, request): + """Returns a sorted list of JSON objects for each feature in the + example. + + The list is sorted by interestingness in terms of the resulting change in + inference values across feature values, for partial dependence plots. + + Args: + request: A request for sorted features. + + Returns: + A sorted list with a JSON object for each feature. + Numeric features are represented as + {name: observedMin: observedMax: interestingness:}. + Categorical features are repesented as + {name: samples:[] interestingness:}. + """ + try: + features_list = inference_utils.get_eligible_features( + self.examples[0:NUM_EXAMPLES_TO_SCAN], NUM_MUTANTS + ) + example_index = int(request.args.get("example_index", "0")) + ( + inference_addresses, + model_names, + model_versions, + model_signatures, + ) = self._parse_request_arguments(request) + chart_data = {} + for feat in features_list: + chart_data[feat["name"]] = self._infer_mutants_impl( + feat["name"], + example_index, + inference_addresses, + model_names, + request.args.get("model_type"), + model_versions, + model_signatures, + request.args.get("use_predict") == "true", + request.args.get("predict_input_tensor"), + request.args.get("predict_output_tensor"), + feat["observedMin"] if "observedMin" in feat else 0, + feat["observedMax"] if "observedMin" in feat else 0, + None, + ) + features_list = inference_utils.sort_eligible_features( + features_list, chart_data + ) + return http_util.Respond(request, features_list, "application/json") + except common_utils.InvalidUserInputError as e: + return http_util.Respond( + request, 
{"error": e.message}, "application/json", code=400 + ) + + @wrappers.Request.application + def _infer_mutants_handler(self, request): + """Returns JSON for the partial dependence plots for a feature. + + Args: + request: A request that should contain 'feature_name', 'example_index', + 'inference_address', 'model_name', 'model_type', 'model_version', and + 'model_signature'. + + Returns: + A list of JSON objects, one for each chart. + """ + try: + if request.method != "GET": + logger.error("%s requests are forbidden.", request.method) + return http_util.Respond( + request, + {"error": "invalid non-GET request"}, + "application/json", + code=405, + ) + + example_index = int(request.args.get("example_index", "0")) + feature_name = request.args.get("feature_name") + ( + inference_addresses, + model_names, + model_versions, + model_signatures, + ) = self._parse_request_arguments(request) + json_mapping = self._infer_mutants_impl( + feature_name, + example_index, + inference_addresses, + model_names, + request.args.get("model_type"), + model_versions, + model_signatures, + request.args.get("use_predict") == "true", + request.args.get("predict_input_tensor"), + request.args.get("predict_output_tensor"), + request.args.get("x_min"), + request.args.get("x_max"), + request.args.get("feature_index_pattern"), + ) + return http_util.Respond(request, json_mapping, "application/json") + except common_utils.InvalidUserInputError as e: + return http_util.Respond( + request, {"error": e.message}, "application/json", code=400 + ) + + def _infer_mutants_impl( + self, + feature_name, + example_index, + inference_addresses, + model_names, + model_type, + model_versions, + model_signatures, + use_predict, + predict_input_tensor, + predict_output_tensor, + x_min, + x_max, + feature_index_pattern, + ): + """Helper for generating PD plots for a feature.""" + examples = ( + self.examples + if example_index == -1 + else [self.examples[example_index]] + ) + serving_bundles = [] + for model_num 
in xrange(len(inference_addresses)): + serving_bundles.append( + inference_utils.ServingBundle( + inference_addresses[model_num], + model_names[model_num], + model_type, + model_versions[model_num], + model_signatures[model_num], + use_predict, + predict_input_tensor, + predict_output_tensor, + ) + ) + + viz_params = inference_utils.VizParams( + x_min, + x_max, + self.examples[0:NUM_EXAMPLES_TO_SCAN], + NUM_MUTANTS, + feature_index_pattern, + ) + return inference_utils.mutant_charts_for_feature( + examples, feature_name, serving_bundles, viz_params + ) diff --git a/tensorboard/plugins/interactive_inference/interactive_inference_plugin_loader.py b/tensorboard/plugins/interactive_inference/interactive_inference_plugin_loader.py index eceaf1ad62..72e71daea3 100644 --- a/tensorboard/plugins/interactive_inference/interactive_inference_plugin_loader.py +++ b/tensorboard/plugins/interactive_inference/interactive_inference_plugin_loader.py @@ -22,24 +22,27 @@ class InteractiveInferencePluginLoader(base_plugin.TBLoader): - """InteractiveInferencePlugin factory. + """InteractiveInferencePlugin factory. - This class checks for `tensorflow` install and dependency. - """ - - def load(self, context): - """Returns the plugin, if possible. - - Args: - context: The TBContext flags. - - Returns: - A InteractiveInferencePlugin instance or None if it couldn't be loaded. + This class checks for `tensorflow` install and dependency. """ - try: - # pylint: disable=unused-import - import tensorflow - except ImportError: - return - from tensorboard.plugins.interactive_inference.interactive_inference_plugin import InteractiveInferencePlugin - return InteractiveInferencePlugin(context) + + def load(self, context): + """Returns the plugin, if possible. + + Args: + context: The TBContext flags. + + Returns: + A InteractiveInferencePlugin instance or None if it couldn't be loaded. 
+ """ + try: + # pylint: disable=unused-import + import tensorflow + except ImportError: + return + from tensorboard.plugins.interactive_inference.interactive_inference_plugin import ( + InteractiveInferencePlugin, + ) + + return InteractiveInferencePlugin(context) diff --git a/tensorboard/plugins/interactive_inference/interactive_inference_plugin_test.py b/tensorboard/plugins/interactive_inference/interactive_inference_plugin_test.py index 417dd17687..a18ae484ee 100644 --- a/tensorboard/plugins/interactive_inference/interactive_inference_plugin_test.py +++ b/tensorboard/plugins/interactive_inference/interactive_inference_plugin_test.py @@ -26,10 +26,10 @@ import tensorflow as tf try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import from six.moves import urllib_parse from google.protobuf import json_format @@ -38,233 +38,286 @@ from werkzeug import wrappers from tensorboard.backend import application -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.interactive_inference.utils import inference_utils from tensorboard.plugins.interactive_inference.utils import platform_utils from tensorboard.plugins.interactive_inference.utils import test_utils -from tensorboard.plugins.interactive_inference import interactive_inference_plugin +from tensorboard.plugins.interactive_inference import ( + interactive_inference_plugin, +) class InferencePluginTest(tf.test.TestCase): - - def setUp(self): - self.logdir = tf.compat.v1.test.get_temp_dir() - - self.context = base_plugin.TBContext(logdir=self.logdir) - self.plugin = interactive_inference_plugin.InteractiveInferencePlugin( - self.context) - 
wsgi_app = application.TensorBoardWSGI([self.plugin]) - self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) - - def get_fake_example(self, single_int_value=0): - example = tf.train.Example() - example.features.feature['single_int'].int64_list.value.extend( - [single_int_value]) - return example - - def test_examples_from_path(self): - examples = [self.get_fake_example(0), self.get_fake_example(1)] - examples_path = os.path.join(self.logdir, 'test_examples.rio') - test_utils.write_out_examples(examples, examples_path) - - response = self.server.get( - '/data/plugin/whatif/examples_from_path?' + - urllib_parse.urlencode({ - 'examples_path': examples_path, - 'max_examples': 2, - 'sampling_odds': 1, - })) - self.assertEqual(200, response.status_code) - example_strings = json.loads(response.get_data().decode('utf-8'))['examples'] - received_examples = [json.loads(x) for x in example_strings] - self.assertEqual(2, len(received_examples)) - self.assertEqual(0, - int(received_examples[0]['features']['feature'][ - 'single_int']['int64List']['value'][0])) - self.assertEqual(1, - int(received_examples[1]['features']['feature'][ - 'single_int']['int64List']['value'][0])) - - def test_examples_from_path_if_path_does_not_exist(self): - response = self.server.get( - '/data/plugin/whatif/examples_from_path?' 
+ - urllib_parse.urlencode({ - 'examples_path': 'does_not_exist', - 'max_examples': 2, - 'sampling_odds': 1, - })) - error = json.loads(response.get_data().decode('utf-8'))['error'] - self.assertTrue(error) - - def test_update_example(self): - self.plugin.examples = [tf.train.Example()] - example = self.get_fake_example() - response = self.server.post( - '/data/plugin/whatif/update_example', - data=dict(example=json_format.MessageToJson(example), index='0')) - self.assertEqual(200, response.status_code) - self.assertEqual(example, self.plugin.examples[0]) - self.assertTrue(0 in self.plugin.updated_example_indices) - - def test_update_example_invalid_index(self): - self.plugin.examples = [tf.train.Example()] - example = self.get_fake_example() - response = self.server.post( - '/data/plugin/whatif/update_example', - data=dict(example=json_format.MessageToJson(example), index='1')) - error = json.loads(response.get_data().decode('utf-8'))['error'] - self.assertTrue(error) - - @mock.patch.object(platform_utils, 'call_servo') - def test_infer(self, mock_call_servo): - self.plugin.examples = [ - self.get_fake_example(0), - self.get_fake_example(1), - self.get_fake_example(2) - ] - self.plugin.updated_example_indices = set([0, 2]) - - inference_result_proto = regression_pb2.RegressionResponse() - regression = inference_result_proto.result.regressions.add() - regression.value = 0.45 - regression = inference_result_proto.result.regressions.add() - regression.value = 0.55 - mock_call_servo.return_value = inference_result_proto - - response = self.server.get( - '/data/plugin/whatif/infer?' 
+ urllib_parse.urlencode({ - 'inference_address': 'addr', - 'model_name': 'name', - 'model_type': 'regression', - 'model_version': ',', - 'model_signature': ',', - })) - - self.assertEqual(200, response.status_code) - self.assertEqual(0, len(self.plugin.updated_example_indices)) - inferences = json.loads(json.loads(response.get_data().decode('utf-8'))[ - 'inferences']) - self.assertTrue(0 in inferences['indices']) - self.assertFalse(1 in inferences['indices']) - self.assertTrue(2 in inferences['indices']) - - def _DeserializeResponse(self, byte_content): - """Deserializes byte content that is a JSON encoding. - - Args: - byte_content: The byte content of a JSON response. - - Returns: - The deserialized python object decoded from JSON. - """ - return json.loads(byte_content.decode('utf-8')) - - def test_eligible_features_from_example_proto(self): - example = test_utils.make_fake_example(single_int_val=2) - self.plugin.examples = [example] - - response = self.server.get('/data/plugin/whatif/eligible_features') - self.assertEqual(200, response.status_code) - - # Returns a list of dict objects that have been sorted by feature_name. - data = self._DeserializeResponse(response.get_data()) - - sorted_feature_names = [ - 'non_numeric', 'repeated_float', 'repeated_int', 'single_float', - 'single_int' - ] - self.assertEqual(sorted_feature_names, [d['name'] for d in data]) - np.testing.assert_almost_equal([-1, 1., 10, 24.5, 2.], - [d.get('observedMin', -1) for d in data]) - np.testing.assert_almost_equal([-1, 4., 20, 24.5, 2.], - [d.get('observedMax', -1) for d in data]) - - # Test that only non_numeric feature has samples. - self.assertFalse(any(d.get('samples') for d in data[1:])) - self.assertEqual(['cat'], data[0]['samples']) - - @mock.patch.object(inference_utils, 'mutant_charts_for_feature') - def test_infer_mutants_handler(self, mock_mutant_charts_for_feature): - - # A no-op that just passes the example passed to mutant_charts_for_feature - # back through. 
This tests that the URL parameters get processed properly - # within infer_mutants_handler. - def pass_through(example, feature_name, serving_bundles, viz_params): - return { - 'example': str(example), - 'feature_name': feature_name, - 'serving_bundles': [{ - 'inference_address': serving_bundles[0].inference_address, - 'model_name': serving_bundles[0].model_name, - 'model_type': serving_bundles[0].model_type, - }], - 'viz_params': { - 'x_min': viz_params.x_min, - 'x_max': viz_params.x_max - } - } - - mock_mutant_charts_for_feature.side_effect = pass_through - - example = test_utils.make_fake_example() - self.plugin.examples = [example] - - response = self.server.get( - '/data/plugin/whatif/infer_mutants?' + urllib_parse.urlencode({ - 'feature_name': 'single_int', - 'model_name': '/ml/cassandrax/iris_classification', - 'inference_address': 'ml-serving-temp.prediction', - 'model_type': 'classification', - 'model_version': ',', - 'model_signature': ',', - 'x_min': '-10', - 'x_max': '10', - })) - result = self._DeserializeResponse(response.get_data()) - self.assertEqual(str([example]), result['example']) - self.assertEqual('single_int', result['feature_name']) - self.assertEqual('ml-serving-temp.prediction', - result['serving_bundles'][0]['inference_address']) - self.assertEqual('/ml/cassandrax/iris_classification', - result['serving_bundles'][0]['model_name']) - self.assertEqual('classification', result['serving_bundles'][0]['model_type']) - self.assertAlmostEqual(-10, result['viz_params']['x_min']) - self.assertAlmostEqual(10, result['viz_params']['x_max']) - - @mock.patch.object(inference_utils, 'sort_eligible_features') - @mock.patch.object(inference_utils, 'mutant_charts_for_feature') - def test_infer( - self, mock_mutant_charts_for_feature, mock_sort_eligible_features): - self.plugin.examples = [ - self.get_fake_example(0), - self.get_fake_example(1), - self.get_fake_example(2) - ] - - mock_mutant_charts_for_feature.return_value = [] - sorted_features_list = [ - 
{'name': 'feat1', 'interestingness': .2}, - {'name': 'feat2', 'interestingness': .1} - ] - mock_sort_eligible_features.return_value = sorted_features_list - - url_options = urllib_parse.urlencode({ - 'inference_address': 'addr', - 'model_name': 'name', - 'model_type': 'regression', - 'model_version': '', - 'model_signature': '', - }) - response = self.server.get( - '/data/plugin/whatif/sort_eligible_features?' + url_options) - - self.assertEqual(200, response.status_code) - self.assertEqual(0, len(self.plugin.updated_example_indices)) - output_list = json.loads(response.get_data().decode('utf-8')) - self.assertEquals('feat1', output_list[0]['name']) - self.assertEquals('feat2', output_list[1]['name']) - - -if __name__ == '__main__': - tf.test.main() + def setUp(self): + self.logdir = tf.compat.v1.test.get_temp_dir() + + self.context = base_plugin.TBContext(logdir=self.logdir) + self.plugin = interactive_inference_plugin.InteractiveInferencePlugin( + self.context + ) + wsgi_app = application.TensorBoardWSGI([self.plugin]) + self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) + + def get_fake_example(self, single_int_value=0): + example = tf.train.Example() + example.features.feature["single_int"].int64_list.value.extend( + [single_int_value] + ) + return example + + def test_examples_from_path(self): + examples = [self.get_fake_example(0), self.get_fake_example(1)] + examples_path = os.path.join(self.logdir, "test_examples.rio") + test_utils.write_out_examples(examples, examples_path) + + response = self.server.get( + "/data/plugin/whatif/examples_from_path?" 
+ + urllib_parse.urlencode( + { + "examples_path": examples_path, + "max_examples": 2, + "sampling_odds": 1, + } + ) + ) + self.assertEqual(200, response.status_code) + example_strings = json.loads(response.get_data().decode("utf-8"))[ + "examples" + ] + received_examples = [json.loads(x) for x in example_strings] + self.assertEqual(2, len(received_examples)) + self.assertEqual( + 0, + int( + received_examples[0]["features"]["feature"]["single_int"][ + "int64List" + ]["value"][0] + ), + ) + self.assertEqual( + 1, + int( + received_examples[1]["features"]["feature"]["single_int"][ + "int64List" + ]["value"][0] + ), + ) + + def test_examples_from_path_if_path_does_not_exist(self): + response = self.server.get( + "/data/plugin/whatif/examples_from_path?" + + urllib_parse.urlencode( + { + "examples_path": "does_not_exist", + "max_examples": 2, + "sampling_odds": 1, + } + ) + ) + error = json.loads(response.get_data().decode("utf-8"))["error"] + self.assertTrue(error) + + def test_update_example(self): + self.plugin.examples = [tf.train.Example()] + example = self.get_fake_example() + response = self.server.post( + "/data/plugin/whatif/update_example", + data=dict(example=json_format.MessageToJson(example), index="0"), + ) + self.assertEqual(200, response.status_code) + self.assertEqual(example, self.plugin.examples[0]) + self.assertTrue(0 in self.plugin.updated_example_indices) + + def test_update_example_invalid_index(self): + self.plugin.examples = [tf.train.Example()] + example = self.get_fake_example() + response = self.server.post( + "/data/plugin/whatif/update_example", + data=dict(example=json_format.MessageToJson(example), index="1"), + ) + error = json.loads(response.get_data().decode("utf-8"))["error"] + self.assertTrue(error) + + @mock.patch.object(platform_utils, "call_servo") + def test_infer(self, mock_call_servo): + self.plugin.examples = [ + self.get_fake_example(0), + self.get_fake_example(1), + self.get_fake_example(2), + ] + 
self.plugin.updated_example_indices = set([0, 2]) + + inference_result_proto = regression_pb2.RegressionResponse() + regression = inference_result_proto.result.regressions.add() + regression.value = 0.45 + regression = inference_result_proto.result.regressions.add() + regression.value = 0.55 + mock_call_servo.return_value = inference_result_proto + + response = self.server.get( + "/data/plugin/whatif/infer?" + + urllib_parse.urlencode( + { + "inference_address": "addr", + "model_name": "name", + "model_type": "regression", + "model_version": ",", + "model_signature": ",", + } + ) + ) + + self.assertEqual(200, response.status_code) + self.assertEqual(0, len(self.plugin.updated_example_indices)) + inferences = json.loads( + json.loads(response.get_data().decode("utf-8"))["inferences"] + ) + self.assertTrue(0 in inferences["indices"]) + self.assertFalse(1 in inferences["indices"]) + self.assertTrue(2 in inferences["indices"]) + + def _DeserializeResponse(self, byte_content): + """Deserializes byte content that is a JSON encoding. + + Args: + byte_content: The byte content of a JSON response. + + Returns: + The deserialized python object decoded from JSON. + """ + return json.loads(byte_content.decode("utf-8")) + + def test_eligible_features_from_example_proto(self): + example = test_utils.make_fake_example(single_int_val=2) + self.plugin.examples = [example] + + response = self.server.get("/data/plugin/whatif/eligible_features") + self.assertEqual(200, response.status_code) + + # Returns a list of dict objects that have been sorted by feature_name. 
+ data = self._DeserializeResponse(response.get_data()) + + sorted_feature_names = [ + "non_numeric", + "repeated_float", + "repeated_int", + "single_float", + "single_int", + ] + self.assertEqual(sorted_feature_names, [d["name"] for d in data]) + np.testing.assert_almost_equal( + [-1, 1.0, 10, 24.5, 2.0], [d.get("observedMin", -1) for d in data] + ) + np.testing.assert_almost_equal( + [-1, 4.0, 20, 24.5, 2.0], [d.get("observedMax", -1) for d in data] + ) + + # Test that only non_numeric feature has samples. + self.assertFalse(any(d.get("samples") for d in data[1:])) + self.assertEqual(["cat"], data[0]["samples"]) + + @mock.patch.object(inference_utils, "mutant_charts_for_feature") + def test_infer_mutants_handler(self, mock_mutant_charts_for_feature): + + # A no-op that just passes the example passed to mutant_charts_for_feature + # back through. This tests that the URL parameters get processed properly + # within infer_mutants_handler. + def pass_through(example, feature_name, serving_bundles, viz_params): + return { + "example": str(example), + "feature_name": feature_name, + "serving_bundles": [ + { + "inference_address": serving_bundles[ + 0 + ].inference_address, + "model_name": serving_bundles[0].model_name, + "model_type": serving_bundles[0].model_type, + } + ], + "viz_params": { + "x_min": viz_params.x_min, + "x_max": viz_params.x_max, + }, + } + + mock_mutant_charts_for_feature.side_effect = pass_through + + example = test_utils.make_fake_example() + self.plugin.examples = [example] + + response = self.server.get( + "/data/plugin/whatif/infer_mutants?" 
+ + urllib_parse.urlencode( + { + "feature_name": "single_int", + "model_name": "/ml/cassandrax/iris_classification", + "inference_address": "ml-serving-temp.prediction", + "model_type": "classification", + "model_version": ",", + "model_signature": ",", + "x_min": "-10", + "x_max": "10", + } + ) + ) + result = self._DeserializeResponse(response.get_data()) + self.assertEqual(str([example]), result["example"]) + self.assertEqual("single_int", result["feature_name"]) + self.assertEqual( + "ml-serving-temp.prediction", + result["serving_bundles"][0]["inference_address"], + ) + self.assertEqual( + "/ml/cassandrax/iris_classification", + result["serving_bundles"][0]["model_name"], + ) + self.assertEqual( + "classification", result["serving_bundles"][0]["model_type"] + ) + self.assertAlmostEqual(-10, result["viz_params"]["x_min"]) + self.assertAlmostEqual(10, result["viz_params"]["x_max"]) + + @mock.patch.object(inference_utils, "sort_eligible_features") + @mock.patch.object(inference_utils, "mutant_charts_for_feature") + def test_infer( + self, mock_mutant_charts_for_feature, mock_sort_eligible_features + ): + self.plugin.examples = [ + self.get_fake_example(0), + self.get_fake_example(1), + self.get_fake_example(2), + ] + + mock_mutant_charts_for_feature.return_value = [] + sorted_features_list = [ + {"name": "feat1", "interestingness": 0.2}, + {"name": "feat2", "interestingness": 0.1}, + ] + mock_sort_eligible_features.return_value = sorted_features_list + + url_options = urllib_parse.urlencode( + { + "inference_address": "addr", + "model_name": "name", + "model_type": "regression", + "model_version": "", + "model_signature": "", + } + ) + response = self.server.get( + "/data/plugin/whatif/sort_eligible_features?" 
+ url_options + ) + + self.assertEqual(200, response.status_code) + self.assertEqual(0, len(self.plugin.updated_example_indices)) + output_list = json.loads(response.get_data().decode("utf-8")) + self.assertEquals("feat1", output_list[0]["name"]) + self.assertEquals("feat2", output_list[1]["name"]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/interactive_inference/utils/common_utils.py b/tensorboard/plugins/interactive_inference/utils/common_utils.py index ce916a0bc7..3ca6b21ba3 100644 --- a/tensorboard/plugins/interactive_inference/utils/common_utils.py +++ b/tensorboard/plugins/interactive_inference/utils/common_utils.py @@ -19,73 +19,78 @@ class InvalidUserInputError(Exception): - """An exception to throw if user input is detected to be invalid. + """An exception to throw if user input is detected to be invalid. - Attributes: - original_exception: The triggering `Exception` object to be wrapped, or - a string. - """ + Attributes: + original_exception: The triggering `Exception` object to be wrapped, or + a string. + """ - def __init__(self, original_exception): - """Inits InvalidUserInputError.""" - self.original_exception = original_exception - Exception.__init__(self) + def __init__(self, original_exception): + """Inits InvalidUserInputError.""" + self.original_exception = original_exception + Exception.__init__(self) - @property - def message(self): - return 'InvalidUserInputError: ' + str(self.original_exception) + @property + def message(self): + return "InvalidUserInputError: " + str(self.original_exception) def convert_predict_response(pred, serving_bundle): - """Converts a PredictResponse to ClassificationResponse or RegressionResponse. + """Converts a PredictResponse to ClassificationResponse or + RegressionResponse. - Args: - pred: PredictResponse to convert. - serving_bundle: A `ServingBundle` object that contains the information about - the serving request that the response was generated by. 
+ Args: + pred: PredictResponse to convert. + serving_bundle: A `ServingBundle` object that contains the information about + the serving request that the response was generated by. + + Returns: + A ClassificationResponse or RegressionResponse. + """ + output = pred.outputs[serving_bundle.predict_output_tensor] + raw_output = output.float_val + if serving_bundle.model_type == "classification": + values = [] + for example_index in range(output.tensor_shape.dim[0].size): + start = example_index * output.tensor_shape.dim[1].size + values.append( + raw_output[start : start + output.tensor_shape.dim[1].size] + ) + else: + values = raw_output + return convert_prediction_values(values, serving_bundle, pred.model_spec) - Returns: - A ClassificationResponse or RegressionResponse. - """ - output = pred.outputs[serving_bundle.predict_output_tensor] - raw_output = output.float_val - if serving_bundle.model_type == 'classification': - values = [] - for example_index in range(output.tensor_shape.dim[0].size): - start = example_index * output.tensor_shape.dim[1].size - values.append(raw_output[start:start + output.tensor_shape.dim[1].size]) - else: - values = raw_output - return convert_prediction_values(values, serving_bundle, pred.model_spec) def convert_prediction_values(values, serving_bundle, model_spec=None): - """Converts tensor values into ClassificationResponse or RegressionResponse. + """Converts tensor values into ClassificationResponse or + RegressionResponse. - Args: - values: For classification, a 2D list of numbers. The first dimension is for - each example being predicted. The second dimension are the probabilities - for each class ID in the prediction. For regression, a 1D list of numbers, - with a regression score for each example being predicted. - serving_bundle: A `ServingBundle` object that contains the information about - the serving request that the response was generated by. - model_spec: Optional model spec to put into the response. 
+ Args: + values: For classification, a 2D list of numbers. The first dimension is for + each example being predicted. The second dimension are the probabilities + for each class ID in the prediction. For regression, a 1D list of numbers, + with a regression score for each example being predicted. + serving_bundle: A `ServingBundle` object that contains the information about + the serving request that the response was generated by. + model_spec: Optional model spec to put into the response. - Returns: - A ClassificationResponse or RegressionResponse. - """ - if serving_bundle.model_type == 'classification': - response = classification_pb2.ClassificationResponse() - for example_index in range(len(values)): - classification = response.result.classifications.add() - for class_index in range(len(values[example_index])): - class_score = classification.classes.add() - class_score.score = values[example_index][class_index] - class_score.label = str(class_index) - else: - response = regression_pb2.RegressionResponse() - for example_index in range(len(values)): - regression = response.result.regressions.add() - regression.value = values[example_index] - if model_spec: - response.model_spec.CopyFrom(model_spec) - return response + Returns: + A ClassificationResponse or RegressionResponse. 
+ """ + if serving_bundle.model_type == "classification": + response = classification_pb2.ClassificationResponse() + for example_index in range(len(values)): + classification = response.result.classifications.add() + for class_index in range(len(values[example_index])): + class_score = classification.classes.add() + class_score.score = values[example_index][class_index] + class_score.label = str(class_index) + else: + response = regression_pb2.RegressionResponse() + for example_index in range(len(values)): + regression = response.result.regressions.add() + regression.value = values[example_index] + if model_spec: + response.model_spec.CopyFrom(model_spec) + return response diff --git a/tensorboard/plugins/interactive_inference/utils/inference_utils.py b/tensorboard/plugins/interactive_inference/utils/inference_utils.py index 6b552b6767..c612052db6 100644 --- a/tensorboard/plugins/interactive_inference/utils/inference_utils.py +++ b/tensorboard/plugins/interactive_inference/utils/inference_utils.py @@ -37,607 +37,690 @@ class VizParams(object): - """Light-weight class for holding UI state. - - Attributes: - x_min: The minimum value to use to generate mutants for the feature - (as specified the user on the UI). - x_max: The maximum value to use to generate mutants for the feature - (as specified the user on the UI). - examples: A list of examples to scan in order to generate statistics for - mutants. - num_mutants: Int number of mutants to generate per chart. - feature_index_pattern: String that specifies a restricted set of indices - of the feature to generate mutants for (useful for features that is a - long repeated field. See `convert_pattern_to_indices` for more details. 
- """ - - def __init__(self, x_min, x_max, examples, num_mutants, - feature_index_pattern): - """Inits VizParams may raise InvalidUserInputError for bad user inputs.""" - - def to_float_or_none(x): - try: - return float(x) - except (ValueError, TypeError): - return None - - def to_int(x): - try: - return int(x) - except (ValueError, TypeError) as e: - raise common_utils.InvalidUserInputError(e) - - def convert_pattern_to_indices(pattern): - """Converts a printer-page-style pattern and returns a list of indices. - - Args: - pattern: A printer-page-style pattern with only numeric characters, - commas, dashes, and optionally spaces. - - For example, a pattern of '0,2,4-6' would yield [0, 2, 4, 5, 6]. - - Returns: - A list of indices represented by the pattern. - """ - pieces = [token.strip() for token in pattern.split(',')] - indices = [] - for piece in pieces: - if '-' in piece: - lower, upper = [int(x.strip()) for x in piece.split('-', 1)] - indices.extend(range(lower, upper + 1)) - else: - indices.append(int(piece.strip())) - return sorted(indices) - - self.x_min = to_float_or_none(x_min) - self.x_max = to_float_or_none(x_max) - self.examples = examples - self.num_mutants = to_int(num_mutants) + """Light-weight class for holding UI state. + + Attributes: + x_min: The minimum value to use to generate mutants for the feature + (as specified the user on the UI). + x_max: The maximum value to use to generate mutants for the feature + (as specified the user on the UI). + examples: A list of examples to scan in order to generate statistics for + mutants. + num_mutants: Int number of mutants to generate per chart. + feature_index_pattern: String that specifies a restricted set of indices + of the feature to generate mutants for (useful for features that is a + long repeated field. See `convert_pattern_to_indices` for more details. + """ - # By default, there are no specific user-requested feature indices. 
- self.feature_indices = [] - if feature_index_pattern: - try: - self.feature_indices = convert_pattern_to_indices( - feature_index_pattern) - except ValueError as e: - # If the user-requested range is invalid, use the default range. - pass + def __init__( + self, x_min, x_max, examples, num_mutants, feature_index_pattern + ): + """Inits VizParams may raise InvalidUserInputError for bad user + inputs.""" + + def to_float_or_none(x): + try: + return float(x) + except (ValueError, TypeError): + return None + + def to_int(x): + try: + return int(x) + except (ValueError, TypeError) as e: + raise common_utils.InvalidUserInputError(e) + + def convert_pattern_to_indices(pattern): + """Converts a printer-page-style pattern and returns a list of + indices. + + Args: + pattern: A printer-page-style pattern with only numeric characters, + commas, dashes, and optionally spaces. + + For example, a pattern of '0,2,4-6' would yield [0, 2, 4, 5, 6]. + + Returns: + A list of indices represented by the pattern. + """ + pieces = [token.strip() for token in pattern.split(",")] + indices = [] + for piece in pieces: + if "-" in piece: + lower, upper = [int(x.strip()) for x in piece.split("-", 1)] + indices.extend(range(lower, upper + 1)) + else: + indices.append(int(piece.strip())) + return sorted(indices) + + self.x_min = to_float_or_none(x_min) + self.x_max = to_float_or_none(x_max) + self.examples = examples + self.num_mutants = to_int(num_mutants) + + # By default, there are no specific user-requested feature indices. + self.feature_indices = [] + if feature_index_pattern: + try: + self.feature_indices = convert_pattern_to_indices( + feature_index_pattern + ) + except ValueError as e: + # If the user-requested range is invalid, use the default range. + pass class OriginalFeatureList(object): - """Light-weight class for holding the original values in the example. + """Light-weight class for holding the original values in the example. 
- Should not be created by hand, but rather generated via - `parse_original_feature_from_example`. Just used to hold inferred info - about the example. + Should not be created by hand, but rather generated via + `parse_original_feature_from_example`. Just used to hold inferred info + about the example. - Attributes: - feature_name: String name of the feature. - original_value: The value of the feature in the original example. - feature_type: One of ['int64_list', 'float_list']. + Attributes: + feature_name: String name of the feature. + original_value: The value of the feature in the original example. + feature_type: One of ['int64_list', 'float_list']. - Raises: - ValueError: If OriginalFeatureList fails init validation. - """ + Raises: + ValueError: If OriginalFeatureList fails init validation. + """ - def __init__(self, feature_name, original_value, feature_type): - """Inits OriginalFeatureList.""" - self.feature_name = feature_name - self.original_value = [ - ensure_not_binary(value) for value in original_value] - self.feature_type = feature_type + def __init__(self, feature_name, original_value, feature_type): + """Inits OriginalFeatureList.""" + self.feature_name = feature_name + self.original_value = [ + ensure_not_binary(value) for value in original_value + ] + self.feature_type = feature_type - # Derived attributes. - self.length = sum(1 for _ in original_value) + # Derived attributes. + self.length = sum(1 for _ in original_value) class MutantFeatureValue(object): - """Light-weight class for holding mutated values in the example. - - Should not be created by hand but rather generated via `make_mutant_features`. + """Light-weight class for holding mutated values in the example. - Used to represent a "mutant example": an example that is mostly identical to - the user-provided original example, but has one feature that is different. + Should not be created by hand but rather generated via `make_mutant_features`. 
- Attributes: - original_feature: An `OriginalFeatureList` object representing the feature - to create mutants for. - index: The index of the feature to create mutants for. The feature can be - a repeated field, and we want to plot mutations of its various indices. - mutant_value: The proposed mutant value for the given index. + Used to represent a "mutant example": an example that is mostly identical to + the user-provided original example, but has one feature that is different. - Raises: - ValueError: If MutantFeatureValue fails init validation. - """ + Attributes: + original_feature: An `OriginalFeatureList` object representing the feature + to create mutants for. + index: The index of the feature to create mutants for. The feature can be + a repeated field, and we want to plot mutations of its various indices. + mutant_value: The proposed mutant value for the given index. - def __init__(self, original_feature, index, mutant_value): - """Inits MutantFeatureValue.""" - if not isinstance(original_feature, OriginalFeatureList): - raise ValueError( - 'original_feature should be `OriginalFeatureList`, but had ' - 'unexpected type: {}'.format(type(original_feature))) - self.original_feature = original_feature + Raises: + ValueError: If MutantFeatureValue fails init validation. 
+ """ - if index is not None and not isinstance(index, integer_types): - raise ValueError( - 'index should be None or int, but had unexpected type: {}'.format( - type(index))) - self.index = index - self.mutant_value = (mutant_value.encode() - if isinstance(mutant_value, string_types) else mutant_value) + def __init__(self, original_feature, index, mutant_value): + """Inits MutantFeatureValue.""" + if not isinstance(original_feature, OriginalFeatureList): + raise ValueError( + "original_feature should be `OriginalFeatureList`, but had " + "unexpected type: {}".format(type(original_feature)) + ) + self.original_feature = original_feature + + if index is not None and not isinstance(index, integer_types): + raise ValueError( + "index should be None or int, but had unexpected type: {}".format( + type(index) + ) + ) + self.index = index + self.mutant_value = ( + mutant_value.encode() + if isinstance(mutant_value, string_types) + else mutant_value + ) class ServingBundle(object): - """Light-weight class for holding info to make the inference request. - - Attributes: - inference_address: An address (such as "hostname:port") to send inference - requests to. - model_name: The Servo model name. - model_type: One of ['classification', 'regression']. - model_version: The version number of the model as a string. If set to an - empty string, the latest model will be used. - signature: The signature of the model to infer. If set to an empty string, - the default signuature will be used. - use_predict: If true then use the servo Predict API as opposed to - Classification or Regression. - predict_input_tensor: The name of the input tensor to parse when using the - Predict API. - predict_output_tensor: The name of the output tensor to parse when using the - Predict API. - estimator: An estimator to use instead of calling an external model. - feature_spec: A feature spec for use with the estimator. - custom_predict_fn: A custom prediction function. 
-
-  Raises:
-    ValueError: If ServingBundle fails init validation.
-  """
-
-  def __init__(self, inference_address, model_name, model_type, model_version,
-               signature, use_predict, predict_input_tensor,
-               predict_output_tensor, estimator=None, feature_spec=None,
-               custom_predict_fn=None):
-    """Inits ServingBundle."""
-    if not isinstance(inference_address, string_types):
-      raise ValueError('Invalid inference_address has type: {}'.format(
-          type(inference_address)))
-    # Clean the inference_address so that SmartStub likes it.
-    self.inference_address = inference_address.replace('http://', '').replace(
-        'https://', '')
-
-    if not isinstance(model_name, string_types):
-      raise ValueError('Invalid model_name has type: {}'.format(
-          type(model_name)))
-    self.model_name = model_name
-
-    if model_type not in ['classification', 'regression']:
-      raise ValueError('Invalid model_type: {}'.format(model_type))
-    self.model_type = model_type
-
-    self.model_version = int(model_version) if model_version else None
-
-    self.signature = signature if signature else None
-
-    self.use_predict = use_predict
-    self.predict_input_tensor = predict_input_tensor
-    self.predict_output_tensor = predict_output_tensor
-    self.estimator = estimator
-    self.feature_spec = feature_spec
-    self.custom_predict_fn = custom_predict_fn
+    """Light-weight class for holding info to make the inference request.
+
+    Attributes:
+      inference_address: An address (such as "hostname:port") to send inference
+        requests to.
+      model_name: The Servo model name.
+      model_type: One of ['classification', 'regression'].
+      model_version: The version number of the model as a string. If set to an
+        empty string, the latest model will be used.
+      signature: The signature of the model to infer. If set to an empty string,
+        the default signature will be used.
+      use_predict: If true then use the servo Predict API as opposed to
+        Classification or Regression.
+ predict_input_tensor: The name of the input tensor to parse when using the + Predict API. + predict_output_tensor: The name of the output tensor to parse when using the + Predict API. + estimator: An estimator to use instead of calling an external model. + feature_spec: A feature spec for use with the estimator. + custom_predict_fn: A custom prediction function. + + Raises: + ValueError: If ServingBundle fails init validation. + """ + + def __init__( + self, + inference_address, + model_name, + model_type, + model_version, + signature, + use_predict, + predict_input_tensor, + predict_output_tensor, + estimator=None, + feature_spec=None, + custom_predict_fn=None, + ): + """Inits ServingBundle.""" + if not isinstance(inference_address, string_types): + raise ValueError( + "Invalid inference_address has type: {}".format( + type(inference_address) + ) + ) + # Clean the inference_address so that SmartStub likes it. + self.inference_address = inference_address.replace( + "http://", "" + ).replace("https://", "") + + if not isinstance(model_name, string_types): + raise ValueError( + "Invalid model_name has type: {}".format(type(model_name)) + ) + self.model_name = model_name + + if model_type not in ["classification", "regression"]: + raise ValueError("Invalid model_type: {}".format(model_type)) + self.model_type = model_type + + self.model_version = int(model_version) if model_version else None + + self.signature = signature if signature else None + + self.use_predict = use_predict + self.predict_input_tensor = predict_input_tensor + self.predict_output_tensor = predict_output_tensor + self.estimator = estimator + self.feature_spec = feature_spec + self.custom_predict_fn = custom_predict_fn def ensure_not_binary(value): - """Return non-binary version of value.""" - try: - return value.decode() if isinstance(value, binary_type) else value - except UnicodeDecodeError: - # If the value cannot be decoded as a string (such as an encoded image), - # then just return the 
value. - return value + """Return non-binary version of value.""" + try: + return value.decode() if isinstance(value, binary_type) else value + except UnicodeDecodeError: + # If the value cannot be decoded as a string (such as an encoded image), + # then just return the value. + return value def proto_value_for_feature(example, feature_name): - """Get the value of a feature from Example regardless of feature type.""" - feature = get_example_features(example)[feature_name] - if feature is None: - raise ValueError('Feature {} is not on example proto.'.format(feature_name)) - feature_type = feature.WhichOneof('kind') - if feature_type is None: - raise ValueError('Feature {} on example proto has no declared type.'.format( - feature_name)) - return getattr(feature, feature_type).value + """Get the value of a feature from Example regardless of feature type.""" + feature = get_example_features(example)[feature_name] + if feature is None: + raise ValueError( + "Feature {} is not on example proto.".format(feature_name) + ) + feature_type = feature.WhichOneof("kind") + if feature_type is None: + raise ValueError( + "Feature {} on example proto has no declared type.".format( + feature_name + ) + ) + return getattr(feature, feature_type).value def parse_original_feature_from_example(example, feature_name): - """Returns an `OriginalFeatureList` for the specified feature_name. + """Returns an `OriginalFeatureList` for the specified feature_name. - Args: - example: An example. - feature_name: A string feature name. + Args: + example: An example. + feature_name: A string feature name. - Returns: - A filled in `OriginalFeatureList` object representing the feature. - """ - feature = get_example_features(example)[feature_name] - feature_type = feature.WhichOneof('kind') - original_value = proto_value_for_feature(example, feature_name) + Returns: + A filled in `OriginalFeatureList` object representing the feature. 
+ """ + feature = get_example_features(example)[feature_name] + feature_type = feature.WhichOneof("kind") + original_value = proto_value_for_feature(example, feature_name) - return OriginalFeatureList(feature_name, original_value, feature_type) + return OriginalFeatureList(feature_name, original_value, feature_type) def wrap_inference_results(inference_result_proto): - """Returns packaged inference results from the provided proto. + """Returns packaged inference results from the provided proto. - Args: - inference_result_proto: The classification or regression response proto. + Args: + inference_result_proto: The classification or regression response proto. - Returns: - An InferenceResult proto with the result from the response. - """ - inference_proto = inference_pb2.InferenceResult() - if isinstance(inference_result_proto, - classification_pb2.ClassificationResponse): - inference_proto.classification_result.CopyFrom( - inference_result_proto.result) - elif isinstance(inference_result_proto, regression_pb2.RegressionResponse): - inference_proto.regression_result.CopyFrom(inference_result_proto.result) - return inference_proto + Returns: + An InferenceResult proto with the result from the response. + """ + inference_proto = inference_pb2.InferenceResult() + if isinstance( + inference_result_proto, classification_pb2.ClassificationResponse + ): + inference_proto.classification_result.CopyFrom( + inference_result_proto.result + ) + elif isinstance(inference_result_proto, regression_pb2.RegressionResponse): + inference_proto.regression_result.CopyFrom( + inference_result_proto.result + ) + return inference_proto def get_numeric_feature_names(example): - """Returns a list of feature names for float and int64 type features. + """Returns a list of feature names for float and int64 type features. - Args: - example: An example. + Args: + example: An example. - Returns: - A list of strings of the names of numeric features. 
- """ - numeric_features = ('float_list', 'int64_list') - features = get_example_features(example) - return sorted([ - feature_name for feature_name in features - if features[feature_name].WhichOneof('kind') in numeric_features - ]) + Returns: + A list of strings of the names of numeric features. + """ + numeric_features = ("float_list", "int64_list") + features = get_example_features(example) + return sorted( + [ + feature_name + for feature_name in features + if features[feature_name].WhichOneof("kind") in numeric_features + ] + ) def get_categorical_feature_names(example): - """Returns a list of feature names for byte type features. + """Returns a list of feature names for byte type features. - Args: - example: An example. + Args: + example: An example. - Returns: - A list of categorical feature names (e.g. ['education', 'marital_status'] ) - """ - features = get_example_features(example) - return sorted([ - feature_name for feature_name in features - if features[feature_name].WhichOneof('kind') == 'bytes_list' - ]) + Returns: + A list of categorical feature names (e.g. ['education', 'marital_status'] ) + """ + features = get_example_features(example) + return sorted( + [ + feature_name + for feature_name in features + if features[feature_name].WhichOneof("kind") == "bytes_list" + ] + ) def get_numeric_features_to_observed_range(examples): - """Returns numerical features and their observed ranges. - - Args: - examples: Examples to read to get ranges. - - Returns: - A dict mapping feature_name -> {'observedMin': 'observedMax': } dicts, - with a key for each numerical feature. 
- """ - observed_features = collections.defaultdict(list) # name -> [value, ] - for example in examples: - for feature_name in get_numeric_feature_names(example): - original_feature = parse_original_feature_from_example( - example, feature_name) - observed_features[feature_name].extend(original_feature.original_value) - return { - feature_name: { - 'observedMin': min(feature_values), - 'observedMax': max(feature_values), - } - for feature_name, feature_values in iteritems(observed_features) - } + """Returns numerical features and their observed ranges. + + Args: + examples: Examples to read to get ranges. + + Returns: + A dict mapping feature_name -> {'observedMin': 'observedMax': } dicts, + with a key for each numerical feature. + """ + observed_features = collections.defaultdict(list) # name -> [value, ] + for example in examples: + for feature_name in get_numeric_feature_names(example): + original_feature = parse_original_feature_from_example( + example, feature_name + ) + observed_features[feature_name].extend( + original_feature.original_value + ) + return { + feature_name: { + "observedMin": min(feature_values), + "observedMax": max(feature_values), + } + for feature_name, feature_values in iteritems(observed_features) + } def get_categorical_features_to_sampling(examples, top_k): - """Returns categorical features and a sampling of their most-common values. - - The results of this slow function are used by the visualization repeatedly, - so the results are cached. - - Args: - examples: Examples to read to get feature samples. - top_k: Max number of samples to return per feature. - - Returns: - A dict of feature_name -> {'samples': ['Married-civ-spouse', - 'Never-married', 'Divorced']}. - - There is one key for each categorical feature. - - Currently, the inner dict just has one key, but this structure leaves room - for further expansion, and mirrors the structure used by - `get_numeric_features_to_observed_range`. 
- """ - observed_features = collections.defaultdict(list) # name -> [value, ] - for example in examples: - for feature_name in get_categorical_feature_names(example): - original_feature = parse_original_feature_from_example( - example, feature_name) - observed_features[feature_name].extend(original_feature.original_value) - - result = {} - for feature_name, feature_values in sorted(iteritems(observed_features)): - samples = [ - word - for word, count in collections.Counter(feature_values).most_common( - top_k) if count > 1 - ] - if samples: - result[feature_name] = {'samples': samples} - return result + """Returns categorical features and a sampling of their most-common values. + The results of this slow function are used by the visualization repeatedly, + so the results are cached. -def make_mutant_features(original_feature, index_to_mutate, viz_params): - """Return a list of `MutantFeatureValue`s that are variants of original.""" - lower = viz_params.x_min - upper = viz_params.x_max - examples = viz_params.examples - num_mutants = viz_params.num_mutants - - if original_feature.feature_type == 'float_list': - return [ - MutantFeatureValue(original_feature, index_to_mutate, value) - for value in np.linspace(lower, upper, num_mutants) - ] - elif original_feature.feature_type == 'int64_list': - mutant_values = np.linspace(int(lower), int(upper), - num_mutants).astype(int).tolist() - # Remove duplicates that can occur due to integer constraint. 
- mutant_values = sorted(set(mutant_values)) - return [ - MutantFeatureValue(original_feature, index_to_mutate, value) - for value in mutant_values - ] - elif original_feature.feature_type == 'bytes_list': - feature_to_samples = get_categorical_features_to_sampling( - examples, num_mutants) - - # `mutant_values` looks like: - # [['Married-civ-spouse'], ['Never-married'], ['Divorced'], ['Separated']] - mutant_values = feature_to_samples[original_feature.feature_name]['samples'] - return [ - MutantFeatureValue(original_feature, None, value) - for value in mutant_values - ] - else: - raise ValueError('Malformed original feature had type of: ' + - original_feature.feature_type) - - -def make_mutant_tuples(example_protos, original_feature, index_to_mutate, - viz_params): - """Return a list of `MutantFeatureValue`s and a list of mutant Examples. - - Args: - example_protos: The examples to mutate. - original_feature: A `OriginalFeatureList` that encapsulates the feature to - mutate. - index_to_mutate: The index of the int64_list or float_list to mutate. - viz_params: A `VizParams` object that contains the UI state of the request. - - Returns: - A list of `MutantFeatureValue`s and a list of mutant examples. 
- """ - mutant_features = make_mutant_features(original_feature, index_to_mutate, - viz_params) - mutant_examples = [] - for example_proto in example_protos: - for mutant_feature in mutant_features: - copied_example = copy.deepcopy(example_proto) - feature_name = mutant_feature.original_feature.feature_name - - try: - feature_list = proto_value_for_feature(copied_example, feature_name) - if index_to_mutate is None: - new_values = mutant_feature.mutant_value - else: - new_values = list(feature_list) - new_values[index_to_mutate] = mutant_feature.mutant_value - - del feature_list[:] - feature_list.extend(new_values) - mutant_examples.append(copied_example) - except (ValueError, IndexError): - # If the mutant value can't be set, still add the example to the - # mutant_example even though no change was made. This is necessary to - # allow for computation of global PD plots when not all examples have - # the same number of feature values for a feature. - mutant_examples.append(copied_example) - - return mutant_features, mutant_examples - - -def mutant_charts_for_feature(example_protos, feature_name, serving_bundles, - viz_params): - """Returns JSON formatted for rendering all charts for a feature. - - Args: - example_proto: The example protos to mutate. - feature_name: The string feature name to mutate. - serving_bundles: One `ServingBundle` object per model, that contains the - information to make the serving request. - viz_params: A `VizParams` object that contains the UI state of the request. - - Raises: - InvalidUserInputError if `viz_params.feature_index_pattern` requests out of - range indices for `feature_name` within `example_proto`. - - Returns: - A JSON-able dict for rendering a single mutant chart. parsed in - `tf-inference-dashboard.html`. 
- { - 'chartType': 'numeric', # oneof('numeric', 'categorical') - 'data': [A list of data] # parseable by vz-line-chart or vz-bar-chart - } - """ - - def chart_for_index(index_to_mutate): - mutant_features, mutant_examples = make_mutant_tuples( - example_protos, original_feature, index_to_mutate, viz_params) - - charts = [] - for serving_bundle in serving_bundles: - (inference_result_proto, _) = run_inference( - mutant_examples, serving_bundle) - charts.append(make_json_formatted_for_single_chart( - mutant_features, inference_result_proto, index_to_mutate)) - return charts - try: - original_feature = parse_original_feature_from_example( - example_protos[0], feature_name) - except ValueError as e: - return { - 'chartType': 'categorical', - 'data': [] - } + Args: + examples: Examples to read to get feature samples. + top_k: Max number of samples to return per feature. - indices_to_mutate = viz_params.feature_indices or range( - original_feature.length) - chart_type = ('categorical' if original_feature.feature_type == 'bytes_list' - else 'numeric') + Returns: + A dict of feature_name -> {'samples': ['Married-civ-spouse', + 'Never-married', 'Divorced']}. - try: - return { - 'chartType': chart_type, - 'data': [ - chart_for_index(index_to_mutate) - for index_to_mutate in indices_to_mutate + There is one key for each categorical feature. + + Currently, the inner dict just has one key, but this structure leaves room + for further expansion, and mirrors the structure used by + `get_numeric_features_to_observed_range`. 
+ """ + observed_features = collections.defaultdict(list) # name -> [value, ] + for example in examples: + for feature_name in get_categorical_feature_names(example): + original_feature = parse_original_feature_from_example( + example, feature_name + ) + observed_features[feature_name].extend( + original_feature.original_value + ) + + result = {} + for feature_name, feature_values in sorted(iteritems(observed_features)): + samples = [ + word + for word, count in collections.Counter(feature_values).most_common( + top_k + ) + if count > 1 ] - } - except IndexError as e: - raise common_utils.InvalidUserInputError(e) - - -def make_json_formatted_for_single_chart(mutant_features, - inference_result_proto, - index_to_mutate): - """Returns JSON formatted for a single mutant chart. - - Args: - mutant_features: An iterable of `MutantFeatureValue`s representing the - X-axis. - inference_result_proto: A ClassificationResponse or RegressionResponse - returned by Servo, representing the Y-axis. - It contains one 'classification' or 'regression' for every Example that - was sent for inference. The length of that field should be the same length - of mutant_features. - index_to_mutate: The index of the feature being mutated for this chart. - - Returns: - A JSON-able dict for rendering a single mutant chart, parseable by - `vz-line-chart` or `vz-bar-chart`. - """ - x_label = 'step' - y_label = 'scalar' - - if isinstance(inference_result_proto, - classification_pb2.ClassificationResponse): - # classification_label -> [{x_label: y_label:}] - series = {} - - # ClassificationResponse has a separate probability for each label - for idx, classification in enumerate( - inference_result_proto.result.classifications): - # For each example to use for mutant inference, we create a copied example - # with the feature in question changed to each possible mutant value. So - # when we get the inferences back, we get num_examples*num_mutants - # results. 
So, modding by len(mutant_features) allows us to correctly - # lookup the mutant value for each inference. - mutant_feature = mutant_features[idx % len(mutant_features)] - for class_index, classification_class in enumerate( - classification.classes): - # Fill in class index when labels are missing - if classification_class.label == '': - classification_class.label = str(class_index) - # Special case to not include the "0" class in binary classification. - # Since that just results in a chart that is symmetric around 0.5. - if len( - classification.classes) == 2 and classification_class.label == '0': - continue - key = classification_class.label - if index_to_mutate: - key += ' (index %d)' % index_to_mutate - if not key in series: - series[key] = {} - mutant_val = ensure_not_binary(mutant_feature.mutant_value) - if not mutant_val in series[key]: - series[key][mutant_val] = [] - series[key][mutant_val].append( - classification_class.score) - - # Post-process points to have separate list for each class - return_series = collections.defaultdict(list) - for key, mutant_values in iteritems(series): - for value, y_list in iteritems(mutant_values): - return_series[key].append({ - x_label: value, - y_label: sum(y_list) / float(len(y_list)) - }) - return_series[key].sort(key=lambda p: p[x_label]) - return return_series - - elif isinstance(inference_result_proto, regression_pb2.RegressionResponse): - points = {} - - for idx, regression in enumerate(inference_result_proto.result.regressions): - # For each example to use for mutant inference, we create a copied example - # with the feature in question changed to each possible mutant value. So - # when we get the inferences back, we get num_examples*num_mutants - # results. So, modding by len(mutant_features) allows us to correctly - # lookup the mutant value for each inference. 
- mutant_feature = mutant_features[idx % len(mutant_features)] - mutant_val = ensure_not_binary(mutant_feature.mutant_value) - if not mutant_val in points: - points[mutant_val] = [] - points[mutant_val].append(regression.value) - key = 'value' - if (index_to_mutate != 0): - key += ' (index %d)' % index_to_mutate - list_of_points = [] - for value, y_list in iteritems(points): - list_of_points.append({ - x_label: value, - y_label: sum(y_list) / float(len(y_list)) - }) - list_of_points.sort(key=lambda p: p[x_label]) - return {key: list_of_points} - - else: - raise NotImplementedError('Only classification and regression implemented.') + if samples: + result[feature_name] = {"samples": samples} + return result + + +def make_mutant_features(original_feature, index_to_mutate, viz_params): + """Return a list of `MutantFeatureValue`s that are variants of original.""" + lower = viz_params.x_min + upper = viz_params.x_max + examples = viz_params.examples + num_mutants = viz_params.num_mutants + + if original_feature.feature_type == "float_list": + return [ + MutantFeatureValue(original_feature, index_to_mutate, value) + for value in np.linspace(lower, upper, num_mutants) + ] + elif original_feature.feature_type == "int64_list": + mutant_values = ( + np.linspace(int(lower), int(upper), num_mutants) + .astype(int) + .tolist() + ) + # Remove duplicates that can occur due to integer constraint. 
+ mutant_values = sorted(set(mutant_values)) + return [ + MutantFeatureValue(original_feature, index_to_mutate, value) + for value in mutant_values + ] + elif original_feature.feature_type == "bytes_list": + feature_to_samples = get_categorical_features_to_sampling( + examples, num_mutants + ) + + # `mutant_values` looks like: + # [['Married-civ-spouse'], ['Never-married'], ['Divorced'], ['Separated']] + mutant_values = feature_to_samples[original_feature.feature_name][ + "samples" + ] + return [ + MutantFeatureValue(original_feature, None, value) + for value in mutant_values + ] + else: + raise ValueError( + "Malformed original feature had type of: " + + original_feature.feature_type + ) + + +def make_mutant_tuples( + example_protos, original_feature, index_to_mutate, viz_params +): + """Return a list of `MutantFeatureValue`s and a list of mutant Examples. + + Args: + example_protos: The examples to mutate. + original_feature: A `OriginalFeatureList` that encapsulates the feature to + mutate. + index_to_mutate: The index of the int64_list or float_list to mutate. + viz_params: A `VizParams` object that contains the UI state of the request. + + Returns: + A list of `MutantFeatureValue`s and a list of mutant examples. 
+ """ + mutant_features = make_mutant_features( + original_feature, index_to_mutate, viz_params + ) + mutant_examples = [] + for example_proto in example_protos: + for mutant_feature in mutant_features: + copied_example = copy.deepcopy(example_proto) + feature_name = mutant_feature.original_feature.feature_name + + try: + feature_list = proto_value_for_feature( + copied_example, feature_name + ) + if index_to_mutate is None: + new_values = mutant_feature.mutant_value + else: + new_values = list(feature_list) + new_values[index_to_mutate] = mutant_feature.mutant_value + + del feature_list[:] + feature_list.extend(new_values) + mutant_examples.append(copied_example) + except (ValueError, IndexError): + # If the mutant value can't be set, still add the example to the + # mutant_example even though no change was made. This is necessary to + # allow for computation of global PD plots when not all examples have + # the same number of feature values for a feature. + mutant_examples.append(copied_example) + + return mutant_features, mutant_examples + + +def mutant_charts_for_feature( + example_protos, feature_name, serving_bundles, viz_params +): + """Returns JSON formatted for rendering all charts for a feature. + + Args: + example_proto: The example protos to mutate. + feature_name: The string feature name to mutate. + serving_bundles: One `ServingBundle` object per model, that contains the + information to make the serving request. + viz_params: A `VizParams` object that contains the UI state of the request. + + Raises: + InvalidUserInputError if `viz_params.feature_index_pattern` requests out of + range indices for `feature_name` within `example_proto`. + + Returns: + A JSON-able dict for rendering a single mutant chart. parsed in + `tf-inference-dashboard.html`. 
+ { + 'chartType': 'numeric', # oneof('numeric', 'categorical') + 'data': [A list of data] # parseable by vz-line-chart or vz-bar-chart + } + """ + + def chart_for_index(index_to_mutate): + mutant_features, mutant_examples = make_mutant_tuples( + example_protos, original_feature, index_to_mutate, viz_params + ) + + charts = [] + for serving_bundle in serving_bundles: + (inference_result_proto, _) = run_inference( + mutant_examples, serving_bundle + ) + charts.append( + make_json_formatted_for_single_chart( + mutant_features, inference_result_proto, index_to_mutate + ) + ) + return charts + + try: + original_feature = parse_original_feature_from_example( + example_protos[0], feature_name + ) + except ValueError as e: + return {"chartType": "categorical", "data": []} + + indices_to_mutate = viz_params.feature_indices or range( + original_feature.length + ) + chart_type = ( + "categorical" + if original_feature.feature_type == "bytes_list" + else "numeric" + ) + + try: + return { + "chartType": chart_type, + "data": [ + chart_for_index(index_to_mutate) + for index_to_mutate in indices_to_mutate + ], + } + except IndexError as e: + raise common_utils.InvalidUserInputError(e) + + +def make_json_formatted_for_single_chart( + mutant_features, inference_result_proto, index_to_mutate +): + """Returns JSON formatted for a single mutant chart. + + Args: + mutant_features: An iterable of `MutantFeatureValue`s representing the + X-axis. + inference_result_proto: A ClassificationResponse or RegressionResponse + returned by Servo, representing the Y-axis. + It contains one 'classification' or 'regression' for every Example that + was sent for inference. The length of that field should be the same length + of mutant_features. + index_to_mutate: The index of the feature being mutated for this chart. + + Returns: + A JSON-able dict for rendering a single mutant chart, parseable by + `vz-line-chart` or `vz-bar-chart`. 
+ """ + x_label = "step" + y_label = "scalar" + + if isinstance( + inference_result_proto, classification_pb2.ClassificationResponse + ): + # classification_label -> [{x_label: y_label:}] + series = {} + + # ClassificationResponse has a separate probability for each label + for idx, classification in enumerate( + inference_result_proto.result.classifications + ): + # For each example to use for mutant inference, we create a copied example + # with the feature in question changed to each possible mutant value. So + # when we get the inferences back, we get num_examples*num_mutants + # results. So, modding by len(mutant_features) allows us to correctly + # lookup the mutant value for each inference. + mutant_feature = mutant_features[idx % len(mutant_features)] + for class_index, classification_class in enumerate( + classification.classes + ): + # Fill in class index when labels are missing + if classification_class.label == "": + classification_class.label = str(class_index) + # Special case to not include the "0" class in binary classification. + # Since that just results in a chart that is symmetric around 0.5. 
+ if ( + len(classification.classes) == 2 + and classification_class.label == "0" + ): + continue + key = classification_class.label + if index_to_mutate: + key += " (index %d)" % index_to_mutate + if not key in series: + series[key] = {} + mutant_val = ensure_not_binary(mutant_feature.mutant_value) + if not mutant_val in series[key]: + series[key][mutant_val] = [] + series[key][mutant_val].append(classification_class.score) + + # Post-process points to have separate list for each class + return_series = collections.defaultdict(list) + for key, mutant_values in iteritems(series): + for value, y_list in iteritems(mutant_values): + return_series[key].append( + {x_label: value, y_label: sum(y_list) / float(len(y_list))} + ) + return_series[key].sort(key=lambda p: p[x_label]) + return return_series + + elif isinstance(inference_result_proto, regression_pb2.RegressionResponse): + points = {} + + for idx, regression in enumerate( + inference_result_proto.result.regressions + ): + # For each example to use for mutant inference, we create a copied example + # with the feature in question changed to each possible mutant value. So + # when we get the inferences back, we get num_examples*num_mutants + # results. So, modding by len(mutant_features) allows us to correctly + # lookup the mutant value for each inference. + mutant_feature = mutant_features[idx % len(mutant_features)] + mutant_val = ensure_not_binary(mutant_feature.mutant_value) + if not mutant_val in points: + points[mutant_val] = [] + points[mutant_val].append(regression.value) + key = "value" + if index_to_mutate != 0: + key += " (index %d)" % index_to_mutate + list_of_points = [] + for value, y_list in iteritems(points): + list_of_points.append( + {x_label: value, y_label: sum(y_list) / float(len(y_list))} + ) + list_of_points.sort(key=lambda p: p[x_label]) + return {key: list_of_points} + + else: + raise NotImplementedError( + "Only classification and regression implemented." 
+ ) def get_example_features(example): - """Returns the non-sequence features from the provided example.""" - return (example.features.feature if isinstance(example, tf.train.Example) - else example.context.feature) + """Returns the non-sequence features from the provided example.""" + return ( + example.features.feature + if isinstance(example, tf.train.Example) + else example.context.feature + ) + def run_inference_for_inference_results(examples, serving_bundle): - """Calls servo and wraps the inference results.""" - (inference_result_proto, extra_results) = run_inference( - examples, serving_bundle) - inferences = wrap_inference_results(inference_result_proto) - infer_json = json_format.MessageToJson( - inferences, including_default_value_fields=True) - return json.loads(infer_json), extra_results + """Calls servo and wraps the inference results.""" + (inference_result_proto, extra_results) = run_inference( + examples, serving_bundle + ) + inferences = wrap_inference_results(inference_result_proto) + infer_json = json_format.MessageToJson( + inferences, including_default_value_fields=True + ) + return json.loads(infer_json), extra_results + def get_eligible_features(examples, num_mutants): - """Returns a list of JSON objects for each feature in the examples. + """Returns a list of JSON objects for each feature in the examples. This list is used to drive partial dependence plots in the plugin. @@ -650,84 +733,88 @@ def get_eligible_features(examples, num_mutants): Numeric features are represented as {name: observedMin: observedMax:}. Categorical features are repesented as {name: samples:[]}. """ - features_dict = ( - get_numeric_features_to_observed_range( - examples)) - - features_dict.update( - get_categorical_features_to_sampling( - examples, num_mutants)) - - # Massage the features_dict into a sorted list before returning because - # Polymer dom-repeat needs a list. 
- features_list = [] - for k, v in sorted(features_dict.items()): - v['name'] = k - features_list.append(v) - return features_list + features_dict = get_numeric_features_to_observed_range(examples) + + features_dict.update( + get_categorical_features_to_sampling(examples, num_mutants) + ) + + # Massage the features_dict into a sorted list before returning because + # Polymer dom-repeat needs a list. + features_list = [] + for k, v in sorted(features_dict.items()): + v["name"] = k + features_list.append(v) + return features_list + def sort_eligible_features(features_list, chart_data): - """Returns a sorted list of objects representing each feature. - - The list is sorted by interestingness in terms of the resulting change in - inference values across feature values, for partial dependence plots. - - Args: - features_list: A list of eligible features in the format of the return - from the get_eligible_features function. - chart_data: A dict of feature names to chart data, formatted as the - output from the mutant_charts_for_feature function. - - Returns: - A sorted list of the inputted features_list, with the addition of - an 'interestingness' key with a calculated number for feature feature. - The list is sorted with the feature with highest interestingness first. - """ - sorted_features_list = copy.deepcopy(features_list) - for feature in sorted_features_list: - name = feature['name'] - charts = chart_data[name] - max_measure = 0 - is_numeric = charts['chartType'] == 'numeric' - for models in charts['data']: - for chart in models: - for series in chart.values(): - if is_numeric: - # For numeric features, interestingness is the total Y distance - # traveled across the line chart. 
- measure = 0 - for i in range(len(series) - 1): - measure += abs(series[i]['scalar'] - series[i + 1]['scalar']) - else: - # For categorical features, interestingness is the difference - # between the min and max Y values in the chart, as interestingness - # for categorical charts shouldn't depend on the order of items - # being charted. - min_y = float("inf") - max_y = float("-inf") - for i in range(len(series)): - val = series[i]['scalar'] - if val < min_y: - min_y = val - if val > max_y: - max_y = val - measure = max_y - min_y - if measure > max_measure: - max_measure = measure - feature['interestingness'] = max_measure - - return sorted( - sorted_features_list, key=lambda x: x['interestingness'], reverse=True) + """Returns a sorted list of objects representing each feature. + + The list is sorted by interestingness in terms of the resulting change in + inference values across feature values, for partial dependence plots. + + Args: + features_list: A list of eligible features in the format of the return + from the get_eligible_features function. + chart_data: A dict of feature names to chart data, formatted as the + output from the mutant_charts_for_feature function. + + Returns: + A sorted list of the inputted features_list, with the addition of + an 'interestingness' key with a calculated number for feature feature. + The list is sorted with the feature with highest interestingness first. + """ + sorted_features_list = copy.deepcopy(features_list) + for feature in sorted_features_list: + name = feature["name"] + charts = chart_data[name] + max_measure = 0 + is_numeric = charts["chartType"] == "numeric" + for models in charts["data"]: + for chart in models: + for series in chart.values(): + if is_numeric: + # For numeric features, interestingness is the total Y distance + # traveled across the line chart. 
+ measure = 0 + for i in range(len(series) - 1): + measure += abs( + series[i]["scalar"] - series[i + 1]["scalar"] + ) + else: + # For categorical features, interestingness is the difference + # between the min and max Y values in the chart, as interestingness + # for categorical charts shouldn't depend on the order of items + # being charted. + min_y = float("inf") + max_y = float("-inf") + for i in range(len(series)): + val = series[i]["scalar"] + if val < min_y: + min_y = val + if val > max_y: + max_y = val + measure = max_y - min_y + if measure > max_measure: + max_measure = measure + feature["interestingness"] = max_measure + + return sorted( + sorted_features_list, key=lambda x: x["interestingness"], reverse=True + ) + def get_label_vocab(vocab_path): - """Returns a list of label strings loaded from the provided path.""" - if vocab_path: - try: - with tf.io.gfile.GFile(vocab_path, 'r') as f: - return [line.rstrip('\n') for line in f] - except tf.errors.NotFoundError as err: - logging.error('error reading vocab file: %s', err) - return [] + """Returns a list of label strings loaded from the provided path.""" + if vocab_path: + try: + with tf.io.gfile.GFile(vocab_path, "r") as f: + return [line.rstrip("\n") for line in f] + except tf.errors.NotFoundError as err: + logging.error("error reading vocab file: %s", err) + return [] + def create_sprite_image(examples): """Returns an encoded sprite image for use in Facets Dive. 
@@ -740,123 +827,141 @@ def create_sprite_image(examples): """ def generate_image_from_thubnails(thumbnails, thumbnail_dims): - """Generates a sprite atlas image from a set of thumbnails.""" - num_thumbnails = tf.shape(thumbnails)[0].eval() - images_per_row = int(math.ceil(math.sqrt(num_thumbnails))) - thumb_height = thumbnail_dims[0] - thumb_width = thumbnail_dims[1] - master_height = images_per_row * thumb_height - master_width = images_per_row * thumb_width - num_channels = 3 - master = np.zeros([master_height, master_width, num_channels]) - for idx, image in enumerate(thumbnails.eval()): - left_idx = idx % images_per_row - top_idx = int(math.floor(idx / images_per_row)) - left_start = left_idx * thumb_width - left_end = left_start + thumb_width - top_start = top_idx * thumb_height - top_end = top_start + thumb_height - master[top_start:top_end, left_start:left_end, :] = image - return tf.image.encode_png(master) - - image_feature_name = 'image/encoded' + """Generates a sprite atlas image from a set of thumbnails.""" + num_thumbnails = tf.shape(thumbnails)[0].eval() + images_per_row = int(math.ceil(math.sqrt(num_thumbnails))) + thumb_height = thumbnail_dims[0] + thumb_width = thumbnail_dims[1] + master_height = images_per_row * thumb_height + master_width = images_per_row * thumb_width + num_channels = 3 + master = np.zeros([master_height, master_width, num_channels]) + for idx, image in enumerate(thumbnails.eval()): + left_idx = idx % images_per_row + top_idx = int(math.floor(idx / images_per_row)) + left_start = left_idx * thumb_width + left_end = left_start + thumb_width + top_start = top_idx * thumb_height + top_end = top_start + thumb_height + master[top_start:top_end, left_start:left_end, :] = image + return tf.image.encode_png(master) + + image_feature_name = "image/encoded" sprite_thumbnail_dim_px = 32 with tf.compat.v1.Session(): - keys_to_features = { - image_feature_name: - tf.io.FixedLenFeature((), tf.string, default_value=''), - } - parsed = 
tf.io.parse_example(examples, keys_to_features) - images = tf.zeros([1, 1, 1, 1], tf.float32) - i = tf.constant(0) - thumbnail_dims = (sprite_thumbnail_dim_px, - sprite_thumbnail_dim_px) - num_examples = tf.constant(len(examples)) - encoded_images = parsed[image_feature_name] - - # Loop over all examples, decoding the image feature value, resizing - # and appending to a list of all images. - def loop_body(i, encoded_images, images): - encoded_image = encoded_images[i] - image = tf.image.decode_jpeg(encoded_image, channels=3) - resized_image = tf.image.resize(image, thumbnail_dims) - expanded_image = tf.expand_dims(resized_image, 0) - images = tf.cond( - tf.equal(i, 0), lambda: expanded_image, - lambda: tf.concat([images, expanded_image], 0)) - return i + 1, encoded_images, images - - loop_out = tf.while_loop( - lambda i, encoded_images, images: tf.less(i, num_examples), - loop_body, [i, encoded_images, images], - shape_invariants=[ - i.get_shape(), - encoded_images.get_shape(), - tf.TensorShape(None) - ]) - - # Create the single sprite atlas image from these thumbnails. - sprite = generate_image_from_thubnails(loop_out[2], thumbnail_dims) - return sprite.eval() + keys_to_features = { + image_feature_name: tf.io.FixedLenFeature( + (), tf.string, default_value="" + ), + } + parsed = tf.io.parse_example(examples, keys_to_features) + images = tf.zeros([1, 1, 1, 1], tf.float32) + i = tf.constant(0) + thumbnail_dims = (sprite_thumbnail_dim_px, sprite_thumbnail_dim_px) + num_examples = tf.constant(len(examples)) + encoded_images = parsed[image_feature_name] + + # Loop over all examples, decoding the image feature value, resizing + # and appending to a list of all images. 
+ def loop_body(i, encoded_images, images): + encoded_image = encoded_images[i] + image = tf.image.decode_jpeg(encoded_image, channels=3) + resized_image = tf.image.resize(image, thumbnail_dims) + expanded_image = tf.expand_dims(resized_image, 0) + images = tf.cond( + tf.equal(i, 0), + lambda: expanded_image, + lambda: tf.concat([images, expanded_image], 0), + ) + return i + 1, encoded_images, images + + loop_out = tf.while_loop( + lambda i, encoded_images, images: tf.less(i, num_examples), + loop_body, + [i, encoded_images, images], + shape_invariants=[ + i.get_shape(), + encoded_images.get_shape(), + tf.TensorShape(None), + ], + ) + + # Create the single sprite atlas image from these thumbnails. + sprite = generate_image_from_thubnails(loop_out[2], thumbnail_dims) + return sprite.eval() + def run_inference(examples, serving_bundle): - """Run inference on examples given model information - - Args: - examples: A list of examples that matches the model spec. - serving_bundle: A `ServingBundle` object that contains the information to - make the inference request. - - Returns: - A tuple with the first entry being the ClassificationResponse or - RegressionResponse proto and the second entry being a dictionary of extra - data for each example, such as attributions, or None if no data exists. - """ - batch_size = 64 - if serving_bundle.estimator and serving_bundle.feature_spec: - # If provided an estimator and feature spec then run inference locally. - preds = serving_bundle.estimator.predict( - lambda: tf.data.Dataset.from_tensor_slices( - tf.io.parse_example([ex.SerializeToString() for ex in examples], - serving_bundle.feature_spec)).batch(batch_size)) - - # Use the specified key if one is provided. - key_to_use = (serving_bundle.predict_output_tensor - if serving_bundle.use_predict else None) - - values = [] - for pred in preds: - if key_to_use is None: - # If the prediction dictionary only contains one key, use it. 
- returned_keys = list(pred.keys()) - if len(returned_keys) == 1: - key_to_use = returned_keys[0] - # Use default keys if necessary. - elif serving_bundle.model_type == 'classification': - key_to_use = 'probabilities' + """Run inference on examples given model information. + + Args: + examples: A list of examples that matches the model spec. + serving_bundle: A `ServingBundle` object that contains the information to + make the inference request. + + Returns: + A tuple with the first entry being the ClassificationResponse or + RegressionResponse proto and the second entry being a dictionary of extra + data for each example, such as attributions, or None if no data exists. + """ + batch_size = 64 + if serving_bundle.estimator and serving_bundle.feature_spec: + # If provided an estimator and feature spec then run inference locally. + preds = serving_bundle.estimator.predict( + lambda: tf.data.Dataset.from_tensor_slices( + tf.io.parse_example( + [ex.SerializeToString() for ex in examples], + serving_bundle.feature_spec, + ) + ).batch(batch_size) + ) + + # Use the specified key if one is provided. + key_to_use = ( + serving_bundle.predict_output_tensor + if serving_bundle.use_predict + else None + ) + + values = [] + for pred in preds: + if key_to_use is None: + # If the prediction dictionary only contains one key, use it. + returned_keys = list(pred.keys()) + if len(returned_keys) == 1: + key_to_use = returned_keys[0] + # Use default keys if necessary. + elif serving_bundle.model_type == "classification": + key_to_use = "probabilities" + else: + key_to_use = "predictions" + if key_to_use not in pred: + raise KeyError( + '"%s" not found in model predictions dictionary' + % key_to_use + ) + + values.append(pred[key_to_use]) + return ( + common_utils.convert_prediction_values(values, serving_bundle), + None, + ) + elif serving_bundle.custom_predict_fn: + # If custom_predict_fn is provided, pass examples directly for local + # inference. 
+ values = serving_bundle.custom_predict_fn(examples) + extra_results = None + # If the custom prediction function returned a dict, then parse out the + # prediction scores. If it is just a list, then the results are the + # prediction results without attributions or other data. + if isinstance(values, dict): + preds = values.pop("predictions") + extra_results = values else: - key_to_use = 'predictions' - if key_to_use not in pred: - raise KeyError( - '"%s" not found in model predictions dictionary' % key_to_use) - - values.append(pred[key_to_use]) - return (common_utils.convert_prediction_values(values, serving_bundle), - None) - elif serving_bundle.custom_predict_fn: - # If custom_predict_fn is provided, pass examples directly for local - # inference. - values = serving_bundle.custom_predict_fn(examples) - extra_results = None - # If the custom prediction function returned a dict, then parse out the - # prediction scores. If it is just a list, then the results are the - # prediction results without attributions or other data. 
- if isinstance(values, dict): - preds = values.pop('predictions') - extra_results = values + preds = values + return ( + common_utils.convert_prediction_values(preds, serving_bundle), + extra_results, + ) else: - preds = values - return (common_utils.convert_prediction_values(preds, serving_bundle), - extra_results) - else: - return (platform_utils.call_servo(examples, serving_bundle), None) + return (platform_utils.call_servo(examples, serving_bundle), None) diff --git a/tensorboard/plugins/interactive_inference/utils/inference_utils_test.py b/tensorboard/plugins/interactive_inference/utils/inference_utils_test.py index 6598a4940a..16abcc7bee 100644 --- a/tensorboard/plugins/interactive_inference/utils/inference_utils_test.py +++ b/tensorboard/plugins/interactive_inference/utils/inference_utils_test.py @@ -25,10 +25,10 @@ import tensorflow as tf try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import from tensorflow_serving.apis import classification_pb2 from tensorflow_serving.apis import predict_pb2 @@ -41,518 +41,630 @@ class InferenceUtilsTest(tf.test.TestCase): - - def setUp(self): - self.logdir = tf.compat.v1.test.get_temp_dir() - self.examples_path = os.path.join(self.logdir, 'example.pb') - - def tearDown(self): - try: - os.remove(self.examples_path) - except EnvironmentError: - pass - - def make_and_write_fake_example(self): - """Make example and write it to self.examples_path.""" - example = test_utils.make_fake_example() - test_utils.write_out_examples([example], self.examples_path) - return example - - def test_parse_original_feature_from_example(self): - example = test_utils.make_fake_example() - original_feature = inference_utils.parse_original_feature_from_example( - example, 'repeated_float') - self.assertEqual('repeated_float', original_feature.feature_name) - 
self.assertEqual([1.0, 2.0, 3.0, 4.0], original_feature.original_value) - self.assertEqual('float_list', original_feature.feature_type) - self.assertEqual(4, original_feature.length) - - original_feature = inference_utils.parse_original_feature_from_example( - example, 'repeated_int') - self.assertEqual('repeated_int', original_feature.feature_name) - self.assertEqual([10, 20], original_feature.original_value) - self.assertEqual('int64_list', original_feature.feature_type) - self.assertEqual(2, original_feature.length) - - original_feature = inference_utils.parse_original_feature_from_example( - example, 'single_int') - self.assertEqual('single_int', original_feature.feature_name) - self.assertEqual([0], original_feature.original_value) - self.assertEqual('int64_list', original_feature.feature_type) - self.assertEqual(1, original_feature.length) - - def test_parse_original_feature_from_example_binary(self): - example = tf.train.Example() - example.features.feature['img'].bytes_list.value.extend([b'\xef']) - original_feature = inference_utils.parse_original_feature_from_example( - example, 'img') - self.assertEqual('img', original_feature.feature_name) - self.assertEqual([b'\xef'], original_feature.original_value) - - def test_example_protos_from_path_get_all_in_file(self): - cns_path = os.path.join(tf.compat.v1.test.get_temp_dir(), - 'dummy_example') - example = test_utils.make_fake_example() - test_utils.write_out_examples([example], cns_path) - dummy_examples = platform_utils.example_protos_from_path(cns_path) - self.assertEqual(1, len(dummy_examples)) - self.assertEqual(example, dummy_examples[0]) - - def test_example_protos_from_path_get_two(self): - cns_path = os.path.join(tf.compat.v1.test.get_temp_dir(), - 'dummy_example') - example_one = test_utils.make_fake_example(1) - example_two = test_utils.make_fake_example(2) - example_three = test_utils.make_fake_example(3) - test_utils.write_out_examples([example_one, example_two, example_three], - cns_path) - 
dummy_examples = platform_utils.example_protos_from_path(cns_path, 2) - self.assertEqual(2, len(dummy_examples)) - self.assertEqual(example_one, dummy_examples[0]) - self.assertEqual(example_two, dummy_examples[1]) - - def test_example_protos_from_path_use_wildcard(self): - cns_path = os.path.join(tf.compat.v1.test.get_temp_dir(), - 'wildcard_example1') - example1 = test_utils.make_fake_example(1) - test_utils.write_out_examples([example1], cns_path) - cns_path = os.path.join(tf.compat.v1.test.get_temp_dir(), - 'wildcard_example2') - example2 = test_utils.make_fake_example(2) - test_utils.write_out_examples([example2], cns_path) - - wildcard_path = os.path.join(tf.compat.v1.test.get_temp_dir(), - 'wildcard_example*') - dummy_examples = platform_utils.example_protos_from_path( - wildcard_path) - self.assertEqual(2, len(dummy_examples)) - - def test_example_proto_from_path_if_does_not_exist(self): - cns_path = os.path.join(tf.compat.v1.test.get_temp_dir(), 'does_not_exist') - with self.assertRaises(common_utils.InvalidUserInputError): - platform_utils.example_protos_from_path(cns_path) - - def test_get_numeric_features(self): - example = test_utils.make_fake_example(single_int_val=2) - data = inference_utils.get_numeric_feature_names(example) - self.assertEqual( - ['repeated_float', 'repeated_int', 'single_float', 'single_int'], data) - - def test_get_numeric_features_to_observed_range(self): - example = test_utils.make_fake_example(single_int_val=2) - - data = inference_utils.get_numeric_features_to_observed_range( - [example]) - - # Returns a sorted list by feature_name. 
- self.assertDictEqual({ - 'repeated_float': { - 'observedMin': 1., - 'observedMax': 4., - }, - 'repeated_int': { - 'observedMin': 10, - 'observedMax': 20, - }, - 'single_float': { - 'observedMin': 24.5, - 'observedMax': 24.5, - }, - 'single_int': { - 'observedMin': 2., - 'observedMax': 2., - }, - }, data) - - def test_get_categorical_features_to_sampling(self): - cat_example = tf.train.Example() - cat_example.features.feature['non_numeric'].bytes_list.value.extend( - [b'cat']) - - cow_example = tf.train.Example() - cow_example.features.feature['non_numeric'].bytes_list.value.extend( - [b'cow']) - - pony_example = tf.train.Example() - pony_example.features.feature['non_numeric'].bytes_list.value.extend( - [b'pony']) - - examples = [cat_example] * 4 + [cow_example] * 5 + [pony_example] * 10 - - # If we stop sampling at the first 3 examples, the only example should be - # cat example. - data = inference_utils.get_categorical_features_to_sampling( - examples[0: 3], top_k=1) - self.assertDictEqual({ - 'non_numeric': { - 'samples': ['cat'] - } - }, data) - - # If we sample more examples, the top 2 examples should be cow and pony. 
- data = inference_utils.get_categorical_features_to_sampling( - examples[0: 20], top_k=2) - self.assertDictEqual({ - 'non_numeric': { - 'samples': ['pony', 'cow'] - } - }, data) - - def test_wrap_inference_results_classification(self): - """Test wrapping a classification result.""" - inference_result_proto = classification_pb2.ClassificationResponse() - classification = inference_result_proto.result.classifications.add() - inference_class = classification.classes.add() - inference_class.label = 'class_b' - inference_class.score = 0.3 - inference_class = classification.classes.add() - inference_class.label = 'class_a' - inference_class.score = 0.7 - - wrapped = inference_utils.wrap_inference_results(inference_result_proto) - self.assertEqual(1, len(wrapped.classification_result.classifications)) - self.assertEqual( - 2, len(wrapped.classification_result.classifications[0].classes)) - - def test_wrap_inference_results_regression(self): - """Test wrapping a regression result.""" - inference_result_proto = regression_pb2.RegressionResponse() - regression = inference_result_proto.result.regressions.add() - regression.value = 0.45 - regression = inference_result_proto.result.regressions.add() - regression.value = 0.55 - - wrapped = inference_utils.wrap_inference_results(inference_result_proto) - self.assertEqual(2, len(wrapped.regression_result.regressions)) - - @mock.patch.object(inference_utils, 'make_json_formatted_for_single_chart') - @mock.patch.object(platform_utils, 'call_servo') - def test_mutant_charts_for_feature(self, mock_call_servo, - mock_make_json_formatted_for_single_chart): - example = self.make_and_write_fake_example() - serving_bundles = [inference_utils.ServingBundle('', '', 'classification', - '', '', False, '', '')] - num_mutants = 10 - viz_params = inference_utils.VizParams( - x_min=1, - x_max=10, - examples=[example], - num_mutants=num_mutants, - feature_index_pattern=None) - - mock_call_servo = lambda _, __: None - 
mock_make_json_formatted_for_single_chart = lambda _, __: {} - charts = inference_utils.mutant_charts_for_feature( - [example], 'repeated_float', serving_bundles, viz_params) - self.assertEqual('numeric', charts['chartType']) - self.assertEqual(4, len(charts['data'])) - charts = inference_utils.mutant_charts_for_feature( - [example], 'repeated_int', serving_bundles, viz_params) - self.assertEqual('numeric', charts['chartType']) - self.assertEqual(2, len(charts['data'])) - charts = inference_utils.mutant_charts_for_feature( - [example], 'single_int', serving_bundles, viz_params) - self.assertEqual('numeric', charts['chartType']) - self.assertEqual(1, len(charts['data'])) - charts = inference_utils.mutant_charts_for_feature( - [example], 'non_numeric', serving_bundles, viz_params) - self.assertEqual('categorical', charts['chartType']) - self.assertEqual(3, len(charts['data'])) - - @mock.patch.object(inference_utils, 'make_json_formatted_for_single_chart') - @mock.patch.object(platform_utils, 'call_servo') - def test_mutant_charts_for_feature_with_feature_index_pattern( - self, mock_call_servo, mock_make_json_formatted_for_single_chart): - example = self.make_and_write_fake_example() - serving_bundles = [inference_utils.ServingBundle('', '', 'classification', - '', '', False, '', '')] - num_mutants = 10 - viz_params = inference_utils.VizParams( - x_min=1, - x_max=10, - examples=[example], - num_mutants=num_mutants, - feature_index_pattern='0 , 2-3') - - mock_call_servo = lambda _, __: None - mock_make_json_formatted_for_single_chart = lambda _, __: {} - charts = inference_utils.mutant_charts_for_feature( - [example], 'repeated_float', serving_bundles, viz_params) - self.assertEqual('numeric', charts['chartType']) - self.assertEqual(3, len(charts['data'])) - - # These should return 3 charts even though all fields from the index - # pattern don't exist for the example. 
- charts = inference_utils.mutant_charts_for_feature( - [example], 'repeated_int', serving_bundles, viz_params) - self.assertEqual('numeric', charts['chartType']) - self.assertEqual(3, len(charts['data'])) - - charts = inference_utils.mutant_charts_for_feature( - [example], 'single_int', serving_bundles, viz_params) - self.assertEqual('numeric', charts['chartType']) - self.assertEqual(3, len(charts['data'])) - - def test_make_mutant_tuples_float_list(self): - example = self.make_and_write_fake_example() - index_to_mutate = 1 - num_mutants = 10 - viz_params = inference_utils.VizParams( - x_min=1, - x_max=10, - examples = [example], - num_mutants=num_mutants, - feature_index_pattern=None) - - original_feature = inference_utils.parse_original_feature_from_example( - example, 'repeated_float') - mutant_features, mutant_examples = inference_utils.make_mutant_tuples( - [example], - original_feature, - index_to_mutate=index_to_mutate, - viz_params=viz_params) - - # Check that values in mutant_features and mutant_examples are as expected. - expected_values = np.linspace(1, 10, num_mutants) - np.testing.assert_almost_equal( - expected_values, - [mutant_feature.mutant_value for mutant_feature in mutant_features]) - np.testing.assert_almost_equal(expected_values, [ - mutant_example.features.feature['repeated_float'] - .float_list.value[index_to_mutate] for mutant_example in mutant_examples - ]) - - # Check that the example (other than the mutant value) is the same. 
- for expected_value, mutant_example in zip(expected_values, mutant_examples): - mutant_values = test_utils.value_from_example(mutant_example, - 'repeated_float') - original_values = test_utils.value_from_example(example, 'repeated_float') - original_values[index_to_mutate] = expected_value - self.assertEqual(original_values, mutant_values) - - def test_make_mutant_tuples_int_list(self): - example = self.make_and_write_fake_example() - index_to_mutate = 1 - num_mutants = 10 - viz_params = inference_utils.VizParams( - x_min=1, - x_max=10, - examples = [example], - num_mutants=num_mutants, - feature_index_pattern=None) - original_feature = inference_utils.parse_original_feature_from_example( - example, 'repeated_int') - mutant_features, mutant_examples = inference_utils.make_mutant_tuples( - [example], - original_feature, - index_to_mutate=index_to_mutate, - viz_params=viz_params) - - # Check that values in mutant_features and mutant_examples are as expected. - expected_values = np.linspace(1, 10, num_mutants) - np.testing.assert_almost_equal( - expected_values, - [mutant_feature.mutant_value for mutant_feature in mutant_features]) - np.testing.assert_almost_equal(expected_values, [ - mutant_example.features.feature['repeated_int'] - .int64_list.value[index_to_mutate] for mutant_example in mutant_examples - ]) - - # Check that the example (other than the mutant value) is the same. 
- for expected_value, mutant_example in zip(expected_values, mutant_examples): - mutant_values = test_utils.value_from_example(mutant_example, - 'repeated_int') - original_values = test_utils.value_from_example(example, 'repeated_int') - original_values[index_to_mutate] = expected_value - self.assertEqual(original_values, mutant_values) - - def test_make_json_formatted_for_single_chart_classification(self): - """Test making a classification chart with a single point on it.""" - inference_result_proto = classification_pb2.ClassificationResponse() - classification = inference_result_proto.result.classifications.add() - inference_class = classification.classes.add() - inference_class.label = 'class_a' - inference_class.score = 0.7 - - inference_class = classification.classes.add() - inference_class.label = 'class_b' - inference_class.score = 0.3 - - original_feature = inference_utils.OriginalFeatureList( - 'feature_name', [2.], 'float_list') - mutant_feature = inference_utils.MutantFeatureValue( - original_feature, index=0, mutant_value=20) - - jsonable = inference_utils.make_json_formatted_for_single_chart( - [mutant_feature], inference_result_proto, 0) - - self.assertEqual(['class_a', 'class_b'], sorted(jsonable.keys())) - self.assertEqual(1, len(jsonable['class_a'])) - self.assertEqual(20, jsonable['class_a'][0]['step']) - self.assertAlmostEqual(0.7, jsonable['class_a'][0]['scalar']) - - self.assertEqual(1, len(jsonable['class_b'])) - self.assertEqual(20, jsonable['class_b'][0]['step']) - self.assertAlmostEqual(0.3, jsonable['class_b'][0]['scalar']) - - def test_make_json_formatted_for_single_chart_regression(self): - """Test making a regression chart with a single point on it.""" - inference_result_proto = regression_pb2.RegressionResponse() - regression = inference_result_proto.result.regressions.add() - regression.value = 0.45 - regression = inference_result_proto.result.regressions.add() - regression.value = 0.55 - - original_feature = 
inference_utils.OriginalFeatureList( - 'feature_name', [2.], 'float_list') - mutant_feature = inference_utils.MutantFeatureValue( - original_feature, index=0, mutant_value=20) - mutant_feature_2 = inference_utils.MutantFeatureValue( - original_feature, index=0, mutant_value=10) - - jsonable = inference_utils.make_json_formatted_for_single_chart( - [mutant_feature, mutant_feature_2], inference_result_proto, 0) - - self.assertEqual(['value'], list(jsonable.keys())) - self.assertEqual(2, len(jsonable['value'])) - self.assertEqual(10, jsonable['value'][0]['step']) - self.assertAlmostEqual(0.55, jsonable['value'][0]['scalar']) - self.assertEqual(20, jsonable['value'][1]['step']) - self.assertAlmostEqual(0.45, jsonable['value'][1]['scalar']) - - def test_convert_predict_response_regression(self): - """Test converting a PredictResponse to a RegressionResponse.""" - predict = predict_pb2.PredictResponse() - output = predict.outputs['scores'] - dim = output.tensor_shape.dim.add() - dim.size = 2 - output.float_val.extend([0.1, 0.2]) - - bundle = inference_utils.ServingBundle( - '', '', 'regression', '', '', True, '', 'scores') - converted = common_utils.convert_predict_response(predict, bundle) - - self.assertAlmostEqual(0.1, converted.result.regressions[0].value) - self.assertAlmostEqual(0.2, converted.result.regressions[1].value) - - def test_convert_predict_response_classification(self): - """Test converting a PredictResponse to a ClassificationResponse.""" - predict = predict_pb2.PredictResponse() - output = predict.outputs['probabilities'] - dim = output.tensor_shape.dim.add() - dim.size = 3 - dim = output.tensor_shape.dim.add() - dim.size = 2 - output.float_val.extend([1., 0., .9, .1, .8, .2]) - - bundle = inference_utils.ServingBundle( - '', '', 'classification', '', '', True, '', 'probabilities') - converted = common_utils.convert_predict_response(predict, bundle) - - self.assertEqual("0", converted.result.classifications[0].classes[0].label) - 
self.assertAlmostEqual( - 1, converted.result.classifications[0].classes[0].score) - self.assertEqual("1", converted.result.classifications[0].classes[1].label) - self.assertAlmostEqual( - 0, converted.result.classifications[0].classes[1].score) - - self.assertEqual("0", converted.result.classifications[1].classes[0].label) - self.assertAlmostEqual( - .9, converted.result.classifications[1].classes[0].score) - self.assertEqual("1", converted.result.classifications[1].classes[1].label) - self.assertAlmostEqual( - .1, converted.result.classifications[1].classes[1].score) - - self.assertEqual("0", converted.result.classifications[2].classes[0].label) - self.assertAlmostEqual( - .8, converted.result.classifications[2].classes[0].score) - self.assertEqual("1", converted.result.classifications[2].classes[1].label) - self.assertAlmostEqual( - .2, converted.result.classifications[2].classes[1].score) - - def test_vizparams_pattern_parser(self): - viz_params = inference_utils.VizParams( - x_min=1, - x_max=10, - examples=[], - num_mutants=0, - feature_index_pattern=None) - self.assertEqual([], viz_params.feature_indices) - viz_params = inference_utils.VizParams( - x_min=1, - x_max=10, - examples=[], - num_mutants=0, - feature_index_pattern='1-3') - self.assertEqual([1, 2, 3], viz_params.feature_indices) - viz_params = inference_utils.VizParams( - x_min=1, - x_max=10, - examples=[], - num_mutants=0, - feature_index_pattern='3-1') - self.assertEqual([], viz_params.feature_indices) - viz_params = inference_utils.VizParams( - x_min=1, - x_max=10, - examples=[], - num_mutants=0, - feature_index_pattern='1-1') - self.assertEqual([1], viz_params.feature_indices) - viz_params = inference_utils.VizParams( - x_min=1, - x_max=10, - examples=[], - num_mutants=0, - feature_index_pattern='3, 1') - self.assertEqual([1, 3], viz_params.feature_indices) - viz_params = inference_utils.VizParams( - x_min=1, - x_max=10, - examples=[], - num_mutants=0, - feature_index_pattern='0-') - 
self.assertEqual([], viz_params.feature_indices) - viz_params = inference_utils.VizParams( - x_min=1, - x_max=10, - examples=[], - num_mutants=0, - feature_index_pattern='0-error') - self.assertEqual([], viz_params.feature_indices) - viz_params = inference_utils.VizParams( - x_min=1, - x_max=10, - examples=[], - num_mutants=0, - feature_index_pattern='0-3-5') - self.assertEqual([], viz_params.feature_indices) - - def test_sort_eligible_features(self): - features_list = [{'name': 'feat1'}, {'name': 'feat2'}] - chart_data = { - 'feat1': { - 'chartType': 'numeric', - 'data': [[ - {'series1': [{'scalar': .2}, {'scalar': .1}, {'scalar': .3}]}, - {'series2': [{'scalar': .2}, {'scalar': .1}, {'scalar': .4}]}, - ]] - }, - 'feat2': { - 'chartType': 'categorical', - 'data': [[ - {'series1': [{'scalar': .2}, {'scalar': .1}, {'scalar': .3}]}, - {'series2': [{'scalar': .2}, {'scalar': .1}, {'scalar': .9}]}, - ]] + def setUp(self): + self.logdir = tf.compat.v1.test.get_temp_dir() + self.examples_path = os.path.join(self.logdir, "example.pb") + + def tearDown(self): + try: + os.remove(self.examples_path) + except EnvironmentError: + pass + + def make_and_write_fake_example(self): + """Make example and write it to self.examples_path.""" + example = test_utils.make_fake_example() + test_utils.write_out_examples([example], self.examples_path) + return example + + def test_parse_original_feature_from_example(self): + example = test_utils.make_fake_example() + original_feature = inference_utils.parse_original_feature_from_example( + example, "repeated_float" + ) + self.assertEqual("repeated_float", original_feature.feature_name) + self.assertEqual([1.0, 2.0, 3.0, 4.0], original_feature.original_value) + self.assertEqual("float_list", original_feature.feature_type) + self.assertEqual(4, original_feature.length) + + original_feature = inference_utils.parse_original_feature_from_example( + example, "repeated_int" + ) + self.assertEqual("repeated_int", original_feature.feature_name) + 
self.assertEqual([10, 20], original_feature.original_value) + self.assertEqual("int64_list", original_feature.feature_type) + self.assertEqual(2, original_feature.length) + + original_feature = inference_utils.parse_original_feature_from_example( + example, "single_int" + ) + self.assertEqual("single_int", original_feature.feature_name) + self.assertEqual([0], original_feature.original_value) + self.assertEqual("int64_list", original_feature.feature_type) + self.assertEqual(1, original_feature.length) + + def test_parse_original_feature_from_example_binary(self): + example = tf.train.Example() + example.features.feature["img"].bytes_list.value.extend([b"\xef"]) + original_feature = inference_utils.parse_original_feature_from_example( + example, "img" + ) + self.assertEqual("img", original_feature.feature_name) + self.assertEqual([b"\xef"], original_feature.original_value) + + def test_example_protos_from_path_get_all_in_file(self): + cns_path = os.path.join( + tf.compat.v1.test.get_temp_dir(), "dummy_example" + ) + example = test_utils.make_fake_example() + test_utils.write_out_examples([example], cns_path) + dummy_examples = platform_utils.example_protos_from_path(cns_path) + self.assertEqual(1, len(dummy_examples)) + self.assertEqual(example, dummy_examples[0]) + + def test_example_protos_from_path_get_two(self): + cns_path = os.path.join( + tf.compat.v1.test.get_temp_dir(), "dummy_example" + ) + example_one = test_utils.make_fake_example(1) + example_two = test_utils.make_fake_example(2) + example_three = test_utils.make_fake_example(3) + test_utils.write_out_examples( + [example_one, example_two, example_three], cns_path + ) + dummy_examples = platform_utils.example_protos_from_path(cns_path, 2) + self.assertEqual(2, len(dummy_examples)) + self.assertEqual(example_one, dummy_examples[0]) + self.assertEqual(example_two, dummy_examples[1]) + + def test_example_protos_from_path_use_wildcard(self): + cns_path = os.path.join( + tf.compat.v1.test.get_temp_dir(), 
"wildcard_example1" + ) + example1 = test_utils.make_fake_example(1) + test_utils.write_out_examples([example1], cns_path) + cns_path = os.path.join( + tf.compat.v1.test.get_temp_dir(), "wildcard_example2" + ) + example2 = test_utils.make_fake_example(2) + test_utils.write_out_examples([example2], cns_path) + + wildcard_path = os.path.join( + tf.compat.v1.test.get_temp_dir(), "wildcard_example*" + ) + dummy_examples = platform_utils.example_protos_from_path(wildcard_path) + self.assertEqual(2, len(dummy_examples)) + + def test_example_proto_from_path_if_does_not_exist(self): + cns_path = os.path.join( + tf.compat.v1.test.get_temp_dir(), "does_not_exist" + ) + with self.assertRaises(common_utils.InvalidUserInputError): + platform_utils.example_protos_from_path(cns_path) + + def test_get_numeric_features(self): + example = test_utils.make_fake_example(single_int_val=2) + data = inference_utils.get_numeric_feature_names(example) + self.assertEqual( + ["repeated_float", "repeated_int", "single_float", "single_int"], + data, + ) + + def test_get_numeric_features_to_observed_range(self): + example = test_utils.make_fake_example(single_int_val=2) + + data = inference_utils.get_numeric_features_to_observed_range([example]) + + # Returns a sorted list by feature_name. 
+ self.assertDictEqual( + { + "repeated_float": {"observedMin": 1.0, "observedMax": 4.0,}, + "repeated_int": {"observedMin": 10, "observedMax": 20,}, + "single_float": {"observedMin": 24.5, "observedMax": 24.5,}, + "single_int": {"observedMin": 2.0, "observedMax": 2.0,}, + }, + data, + ) + + def test_get_categorical_features_to_sampling(self): + cat_example = tf.train.Example() + cat_example.features.feature["non_numeric"].bytes_list.value.extend( + [b"cat"] + ) + + cow_example = tf.train.Example() + cow_example.features.feature["non_numeric"].bytes_list.value.extend( + [b"cow"] + ) + + pony_example = tf.train.Example() + pony_example.features.feature["non_numeric"].bytes_list.value.extend( + [b"pony"] + ) + + examples = [cat_example] * 4 + [cow_example] * 5 + [pony_example] * 10 + + # If we stop sampling at the first 3 examples, the only example should be + # cat example. + data = inference_utils.get_categorical_features_to_sampling( + examples[0:3], top_k=1 + ) + self.assertDictEqual({"non_numeric": {"samples": ["cat"]}}, data) + + # If we sample more examples, the top 2 examples should be cow and pony. 
+ data = inference_utils.get_categorical_features_to_sampling( + examples[0:20], top_k=2 + ) + self.assertDictEqual( + {"non_numeric": {"samples": ["pony", "cow"]}}, data + ) + + def test_wrap_inference_results_classification(self): + """Test wrapping a classification result.""" + inference_result_proto = classification_pb2.ClassificationResponse() + classification = inference_result_proto.result.classifications.add() + inference_class = classification.classes.add() + inference_class.label = "class_b" + inference_class.score = 0.3 + inference_class = classification.classes.add() + inference_class.label = "class_a" + inference_class.score = 0.7 + + wrapped = inference_utils.wrap_inference_results(inference_result_proto) + self.assertEqual(1, len(wrapped.classification_result.classifications)) + self.assertEqual( + 2, len(wrapped.classification_result.classifications[0].classes) + ) + + def test_wrap_inference_results_regression(self): + """Test wrapping a regression result.""" + inference_result_proto = regression_pb2.RegressionResponse() + regression = inference_result_proto.result.regressions.add() + regression.value = 0.45 + regression = inference_result_proto.result.regressions.add() + regression.value = 0.55 + + wrapped = inference_utils.wrap_inference_results(inference_result_proto) + self.assertEqual(2, len(wrapped.regression_result.regressions)) + + @mock.patch.object(inference_utils, "make_json_formatted_for_single_chart") + @mock.patch.object(platform_utils, "call_servo") + def test_mutant_charts_for_feature( + self, mock_call_servo, mock_make_json_formatted_for_single_chart + ): + example = self.make_and_write_fake_example() + serving_bundles = [ + inference_utils.ServingBundle( + "", "", "classification", "", "", False, "", "" + ) + ] + num_mutants = 10 + viz_params = inference_utils.VizParams( + x_min=1, + x_max=10, + examples=[example], + num_mutants=num_mutants, + feature_index_pattern=None, + ) + + mock_call_servo = lambda _, __: None + 
mock_make_json_formatted_for_single_chart = lambda _, __: {} + charts = inference_utils.mutant_charts_for_feature( + [example], "repeated_float", serving_bundles, viz_params + ) + self.assertEqual("numeric", charts["chartType"]) + self.assertEqual(4, len(charts["data"])) + charts = inference_utils.mutant_charts_for_feature( + [example], "repeated_int", serving_bundles, viz_params + ) + self.assertEqual("numeric", charts["chartType"]) + self.assertEqual(2, len(charts["data"])) + charts = inference_utils.mutant_charts_for_feature( + [example], "single_int", serving_bundles, viz_params + ) + self.assertEqual("numeric", charts["chartType"]) + self.assertEqual(1, len(charts["data"])) + charts = inference_utils.mutant_charts_for_feature( + [example], "non_numeric", serving_bundles, viz_params + ) + self.assertEqual("categorical", charts["chartType"]) + self.assertEqual(3, len(charts["data"])) + + @mock.patch.object(inference_utils, "make_json_formatted_for_single_chart") + @mock.patch.object(platform_utils, "call_servo") + def test_mutant_charts_for_feature_with_feature_index_pattern( + self, mock_call_servo, mock_make_json_formatted_for_single_chart + ): + example = self.make_and_write_fake_example() + serving_bundles = [ + inference_utils.ServingBundle( + "", "", "classification", "", "", False, "", "" + ) + ] + num_mutants = 10 + viz_params = inference_utils.VizParams( + x_min=1, + x_max=10, + examples=[example], + num_mutants=num_mutants, + feature_index_pattern="0 , 2-3", + ) + + mock_call_servo = lambda _, __: None + mock_make_json_formatted_for_single_chart = lambda _, __: {} + charts = inference_utils.mutant_charts_for_feature( + [example], "repeated_float", serving_bundles, viz_params + ) + self.assertEqual("numeric", charts["chartType"]) + self.assertEqual(3, len(charts["data"])) + + # These should return 3 charts even though all fields from the index + # pattern don't exist for the example. 
+ charts = inference_utils.mutant_charts_for_feature( + [example], "repeated_int", serving_bundles, viz_params + ) + self.assertEqual("numeric", charts["chartType"]) + self.assertEqual(3, len(charts["data"])) + + charts = inference_utils.mutant_charts_for_feature( + [example], "single_int", serving_bundles, viz_params + ) + self.assertEqual("numeric", charts["chartType"]) + self.assertEqual(3, len(charts["data"])) + + def test_make_mutant_tuples_float_list(self): + example = self.make_and_write_fake_example() + index_to_mutate = 1 + num_mutants = 10 + viz_params = inference_utils.VizParams( + x_min=1, + x_max=10, + examples=[example], + num_mutants=num_mutants, + feature_index_pattern=None, + ) + + original_feature = inference_utils.parse_original_feature_from_example( + example, "repeated_float" + ) + mutant_features, mutant_examples = inference_utils.make_mutant_tuples( + [example], + original_feature, + index_to_mutate=index_to_mutate, + viz_params=viz_params, + ) + + # Check that values in mutant_features and mutant_examples are as expected. + expected_values = np.linspace(1, 10, num_mutants) + np.testing.assert_almost_equal( + expected_values, + [mutant_feature.mutant_value for mutant_feature in mutant_features], + ) + np.testing.assert_almost_equal( + expected_values, + [ + mutant_example.features.feature[ + "repeated_float" + ].float_list.value[index_to_mutate] + for mutant_example in mutant_examples + ], + ) + + # Check that the example (other than the mutant value) is the same. 
+ for expected_value, mutant_example in zip( + expected_values, mutant_examples + ): + mutant_values = test_utils.value_from_example( + mutant_example, "repeated_float" + ) + original_values = test_utils.value_from_example( + example, "repeated_float" + ) + original_values[index_to_mutate] = expected_value + self.assertEqual(original_values, mutant_values) + + def test_make_mutant_tuples_int_list(self): + example = self.make_and_write_fake_example() + index_to_mutate = 1 + num_mutants = 10 + viz_params = inference_utils.VizParams( + x_min=1, + x_max=10, + examples=[example], + num_mutants=num_mutants, + feature_index_pattern=None, + ) + original_feature = inference_utils.parse_original_feature_from_example( + example, "repeated_int" + ) + mutant_features, mutant_examples = inference_utils.make_mutant_tuples( + [example], + original_feature, + index_to_mutate=index_to_mutate, + viz_params=viz_params, + ) + + # Check that values in mutant_features and mutant_examples are as expected. + expected_values = np.linspace(1, 10, num_mutants) + np.testing.assert_almost_equal( + expected_values, + [mutant_feature.mutant_value for mutant_feature in mutant_features], + ) + np.testing.assert_almost_equal( + expected_values, + [ + mutant_example.features.feature[ + "repeated_int" + ].int64_list.value[index_to_mutate] + for mutant_example in mutant_examples + ], + ) + + # Check that the example (other than the mutant value) is the same. 
+ for expected_value, mutant_example in zip( + expected_values, mutant_examples + ): + mutant_values = test_utils.value_from_example( + mutant_example, "repeated_int" + ) + original_values = test_utils.value_from_example( + example, "repeated_int" + ) + original_values[index_to_mutate] = expected_value + self.assertEqual(original_values, mutant_values) + + def test_make_json_formatted_for_single_chart_classification(self): + """Test making a classification chart with a single point on it.""" + inference_result_proto = classification_pb2.ClassificationResponse() + classification = inference_result_proto.result.classifications.add() + inference_class = classification.classes.add() + inference_class.label = "class_a" + inference_class.score = 0.7 + + inference_class = classification.classes.add() + inference_class.label = "class_b" + inference_class.score = 0.3 + + original_feature = inference_utils.OriginalFeatureList( + "feature_name", [2.0], "float_list" + ) + mutant_feature = inference_utils.MutantFeatureValue( + original_feature, index=0, mutant_value=20 + ) + + jsonable = inference_utils.make_json_formatted_for_single_chart( + [mutant_feature], inference_result_proto, 0 + ) + + self.assertEqual(["class_a", "class_b"], sorted(jsonable.keys())) + self.assertEqual(1, len(jsonable["class_a"])) + self.assertEqual(20, jsonable["class_a"][0]["step"]) + self.assertAlmostEqual(0.7, jsonable["class_a"][0]["scalar"]) + + self.assertEqual(1, len(jsonable["class_b"])) + self.assertEqual(20, jsonable["class_b"][0]["step"]) + self.assertAlmostEqual(0.3, jsonable["class_b"][0]["scalar"]) + + def test_make_json_formatted_for_single_chart_regression(self): + """Test making a regression chart with a single point on it.""" + inference_result_proto = regression_pb2.RegressionResponse() + regression = inference_result_proto.result.regressions.add() + regression.value = 0.45 + regression = inference_result_proto.result.regressions.add() + regression.value = 0.55 + + original_feature = 
inference_utils.OriginalFeatureList( + "feature_name", [2.0], "float_list" + ) + mutant_feature = inference_utils.MutantFeatureValue( + original_feature, index=0, mutant_value=20 + ) + mutant_feature_2 = inference_utils.MutantFeatureValue( + original_feature, index=0, mutant_value=10 + ) + + jsonable = inference_utils.make_json_formatted_for_single_chart( + [mutant_feature, mutant_feature_2], inference_result_proto, 0 + ) + + self.assertEqual(["value"], list(jsonable.keys())) + self.assertEqual(2, len(jsonable["value"])) + self.assertEqual(10, jsonable["value"][0]["step"]) + self.assertAlmostEqual(0.55, jsonable["value"][0]["scalar"]) + self.assertEqual(20, jsonable["value"][1]["step"]) + self.assertAlmostEqual(0.45, jsonable["value"][1]["scalar"]) + + def test_convert_predict_response_regression(self): + """Test converting a PredictResponse to a RegressionResponse.""" + predict = predict_pb2.PredictResponse() + output = predict.outputs["scores"] + dim = output.tensor_shape.dim.add() + dim.size = 2 + output.float_val.extend([0.1, 0.2]) + + bundle = inference_utils.ServingBundle( + "", "", "regression", "", "", True, "", "scores" + ) + converted = common_utils.convert_predict_response(predict, bundle) + + self.assertAlmostEqual(0.1, converted.result.regressions[0].value) + self.assertAlmostEqual(0.2, converted.result.regressions[1].value) + + def test_convert_predict_response_classification(self): + """Test converting a PredictResponse to a ClassificationResponse.""" + predict = predict_pb2.PredictResponse() + output = predict.outputs["probabilities"] + dim = output.tensor_shape.dim.add() + dim.size = 3 + dim = output.tensor_shape.dim.add() + dim.size = 2 + output.float_val.extend([1.0, 0.0, 0.9, 0.1, 0.8, 0.2]) + + bundle = inference_utils.ServingBundle( + "", "", "classification", "", "", True, "", "probabilities" + ) + converted = common_utils.convert_predict_response(predict, bundle) + + self.assertEqual( + "0", 
converted.result.classifications[0].classes[0].label + ) + self.assertAlmostEqual( + 1, converted.result.classifications[0].classes[0].score + ) + self.assertEqual( + "1", converted.result.classifications[0].classes[1].label + ) + self.assertAlmostEqual( + 0, converted.result.classifications[0].classes[1].score + ) + + self.assertEqual( + "0", converted.result.classifications[1].classes[0].label + ) + self.assertAlmostEqual( + 0.9, converted.result.classifications[1].classes[0].score + ) + self.assertEqual( + "1", converted.result.classifications[1].classes[1].label + ) + self.assertAlmostEqual( + 0.1, converted.result.classifications[1].classes[1].score + ) + + self.assertEqual( + "0", converted.result.classifications[2].classes[0].label + ) + self.assertAlmostEqual( + 0.8, converted.result.classifications[2].classes[0].score + ) + self.assertEqual( + "1", converted.result.classifications[2].classes[1].label + ) + self.assertAlmostEqual( + 0.2, converted.result.classifications[2].classes[1].score + ) + + def test_vizparams_pattern_parser(self): + viz_params = inference_utils.VizParams( + x_min=1, + x_max=10, + examples=[], + num_mutants=0, + feature_index_pattern=None, + ) + self.assertEqual([], viz_params.feature_indices) + viz_params = inference_utils.VizParams( + x_min=1, + x_max=10, + examples=[], + num_mutants=0, + feature_index_pattern="1-3", + ) + self.assertEqual([1, 2, 3], viz_params.feature_indices) + viz_params = inference_utils.VizParams( + x_min=1, + x_max=10, + examples=[], + num_mutants=0, + feature_index_pattern="3-1", + ) + self.assertEqual([], viz_params.feature_indices) + viz_params = inference_utils.VizParams( + x_min=1, + x_max=10, + examples=[], + num_mutants=0, + feature_index_pattern="1-1", + ) + self.assertEqual([1], viz_params.feature_indices) + viz_params = inference_utils.VizParams( + x_min=1, + x_max=10, + examples=[], + num_mutants=0, + feature_index_pattern="3, 1", + ) + self.assertEqual([1, 3], viz_params.feature_indices) + 
viz_params = inference_utils.VizParams( + x_min=1, + x_max=10, + examples=[], + num_mutants=0, + feature_index_pattern="0-", + ) + self.assertEqual([], viz_params.feature_indices) + viz_params = inference_utils.VizParams( + x_min=1, + x_max=10, + examples=[], + num_mutants=0, + feature_index_pattern="0-error", + ) + self.assertEqual([], viz_params.feature_indices) + viz_params = inference_utils.VizParams( + x_min=1, + x_max=10, + examples=[], + num_mutants=0, + feature_index_pattern="0-3-5", + ) + self.assertEqual([], viz_params.feature_indices) + + def test_sort_eligible_features(self): + features_list = [{"name": "feat1"}, {"name": "feat2"}] + chart_data = { + "feat1": { + "chartType": "numeric", + "data": [ + [ + { + "series1": [ + {"scalar": 0.2}, + {"scalar": 0.1}, + {"scalar": 0.3}, + ] + }, + { + "series2": [ + {"scalar": 0.2}, + {"scalar": 0.1}, + {"scalar": 0.4}, + ] + }, + ] + ], + }, + "feat2": { + "chartType": "categorical", + "data": [ + [ + { + "series1": [ + {"scalar": 0.2}, + {"scalar": 0.1}, + {"scalar": 0.3}, + ] + }, + { + "series2": [ + {"scalar": 0.2}, + {"scalar": 0.1}, + {"scalar": 0.9}, + ] + }, + ] + ], + }, } - } - sorted_list = inference_utils.sort_eligible_features( - features_list, chart_data) - print(sorted_list) - self.assertEqual('feat2', sorted_list[0]['name']) - self.assertEqual(.8, sorted_list[0]['interestingness']) - self.assertEqual('feat1', sorted_list[1]['name']) - self.assertEqual(.4, sorted_list[1]['interestingness']) - -if __name__ == '__main__': - tf.test.main() + sorted_list = inference_utils.sort_eligible_features( + features_list, chart_data + ) + print(sorted_list) + self.assertEqual("feat2", sorted_list[0]["name"]) + self.assertEqual(0.8, sorted_list[0]["interestingness"]) + self.assertEqual("feat1", sorted_list[1]["name"]) + self.assertEqual(0.4, sorted_list[1]["interestingness"]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/interactive_inference/utils/platform_utils.py 
b/tensorboard/plugins/interactive_inference/utils/platform_utils.py index bc87b03694..dbfd80374b 100644 --- a/tensorboard/plugins/interactive_inference/utils/platform_utils.py +++ b/tensorboard/plugins/interactive_inference/utils/platform_utils.py @@ -30,176 +30,194 @@ def filepath_to_filepath_list(file_path): - """Returns a list of files given by a filepath. + """Returns a list of files given by a filepath. - Args: - file_path: A path, possibly representing a single file, or containing a - wildcard or sharded path. + Args: + file_path: A path, possibly representing a single file, or containing a + wildcard or sharded path. - Returns: - A list of files represented by the provided path. - """ - file_path = file_path.strip() - if '*' in file_path: - return glob(file_path) - else: - return [file_path] + Returns: + A list of files represented by the provided path. + """ + file_path = file_path.strip() + if "*" in file_path: + return glob(file_path) + else: + return [file_path] def throw_if_file_access_not_allowed(file_path, logdir, has_auth_group): - """Throws an error if a file cannot be loaded for inference. - - Args: - file_path: A file path. - logdir: The path to the logdir of the TensorBoard context. - has_auth_group: True if TensorBoard was started with an authorized group, - in which case we allow access to all visible files. - - Raises: - InvalidUserInputError: If the file is not in the logdir and is not globally - readable. - """ - return - - -def example_protos_from_path(path, - num_examples=10, - start_index=0, - parse_examples=True, - sampling_odds=1, - example_class=tf.train.Example): - """Returns a number of examples from the provided path. - - Args: - path: A string path to the examples. - num_examples: The maximum number of examples to return from the path. - parse_examples: If true then parses the serialized proto from the path into - proto objects. Defaults to True. - sampling_odds: Odds of loading an example, used for sampling. 
When >= 1 - (the default), then all examples are loaded. - example_class: tf.train.Example or tf.train.SequenceExample class to load. - Defaults to tf.train.Example. - - Returns: - A list of Example protos or serialized proto strings at the path. - - Raises: - InvalidUserInputError: If examples cannot be procured from the path. - """ - - def append_examples_from_iterable(iterable, examples): - for value in iterable: - if sampling_odds >= 1 or random.random() < sampling_odds: - examples.append( - example_class.FromString(value) if parse_examples else value) - if len(examples) >= num_examples: - return - - examples = [] - - if path.endswith('.csv'): - def are_floats(values): - for value in values: + """Throws an error if a file cannot be loaded for inference. + + Args: + file_path: A file path. + logdir: The path to the logdir of the TensorBoard context. + has_auth_group: True if TensorBoard was started with an authorized group, + in which case we allow access to all visible files. + + Raises: + InvalidUserInputError: If the file is not in the logdir and is not globally + readable. + """ + return + + +def example_protos_from_path( + path, + num_examples=10, + start_index=0, + parse_examples=True, + sampling_odds=1, + example_class=tf.train.Example, +): + """Returns a number of examples from the provided path. + + Args: + path: A string path to the examples. + num_examples: The maximum number of examples to return from the path. + parse_examples: If true then parses the serialized proto from the path into + proto objects. Defaults to True. + sampling_odds: Odds of loading an example, used for sampling. When >= 1 + (the default), then all examples are loaded. + example_class: tf.train.Example or tf.train.SequenceExample class to load. + Defaults to tf.train.Example. + + Returns: + A list of Example protos or serialized proto strings at the path. + + Raises: + InvalidUserInputError: If examples cannot be procured from the path. 
+ """ + + def append_examples_from_iterable(iterable, examples): + for value in iterable: + if sampling_odds >= 1 or random.random() < sampling_odds: + examples.append( + example_class.FromString(value) if parse_examples else value + ) + if len(examples) >= num_examples: + return + + examples = [] + + if path.endswith(".csv"): + + def are_floats(values): + for value in values: + try: + float(value) + except ValueError: + return False + return True + + csv.register_dialect("CsvDialect", skipinitialspace=True) + rows = csv.DictReader(open(path), dialect="CsvDialect") + for row in rows: + if sampling_odds < 1 and random.random() > sampling_odds: + continue + example = tf.train.Example() + for col in row.keys(): + # Parse out individual values from vertical-bar-delimited lists + values = [val.strip() for val in row[col].split("|")] + if are_floats(values): + example.features.feature[col].float_list.value.extend( + [float(val) for val in values] + ) + else: + example.features.feature[col].bytes_list.value.extend( + [val.encode("utf-8") for val in values] + ) + examples.append( + example if parse_examples else example.SerializeToString() + ) + if len(examples) >= num_examples: + break + return examples + + filenames = filepath_to_filepath_list(path) + compression_types = [ + "", # no compression (distinct from `None`!) 
+ "GZIP", + "ZLIB", + ] + current_compression_idx = 0 + current_file_index = 0 + while current_file_index < len(filenames) and current_compression_idx < len( + compression_types + ): try: - float(value) - except ValueError: - return False - return True - csv.register_dialect('CsvDialect', skipinitialspace=True) - rows = csv.DictReader(open(path), dialect='CsvDialect') - for row in rows: - if sampling_odds < 1 and random.random() > sampling_odds: - continue - example = tf.train.Example() - for col in row.keys(): - # Parse out individual values from vertical-bar-delimited lists - values = [val.strip() for val in row[col].split('|')] - if are_floats(values): - example.features.feature[col].float_list.value.extend( - [float(val) for val in values]) - else: - example.features.feature[col].bytes_list.value.extend( - [val.encode('utf-8') for val in values]) - examples.append( - example if parse_examples else example.SerializeToString()) - if len(examples) >= num_examples: - break - return examples - - filenames = filepath_to_filepath_list(path) - compression_types = [ - '', # no compression (distinct from `None`!) - 'GZIP', - 'ZLIB', - ] - current_compression_idx = 0 - current_file_index = 0 - while (current_file_index < len(filenames) and - current_compression_idx < len(compression_types)): - try: - record_iterator = tf.compat.v1.python_io.tf_record_iterator( - path=filenames[current_file_index], - options=tf.io.TFRecordOptions( - compression_types[current_compression_idx])) - append_examples_from_iterable(record_iterator, examples) - current_file_index += 1 - if len(examples) >= num_examples: - break - except tf.errors.DataLossError: - current_compression_idx += 1 - except (IOError, tf.errors.NotFoundError) as e: - raise common_utils.InvalidUserInputError(e) - - if examples: - return examples - else: - raise common_utils.InvalidUserInputError( - 'No examples found at ' + path + - '. 
Valid formats are TFRecord files.') + record_iterator = tf.compat.v1.python_io.tf_record_iterator( + path=filenames[current_file_index], + options=tf.io.TFRecordOptions( + compression_types[current_compression_idx] + ), + ) + append_examples_from_iterable(record_iterator, examples) + current_file_index += 1 + if len(examples) >= num_examples: + break + except tf.errors.DataLossError: + current_compression_idx += 1 + except (IOError, tf.errors.NotFoundError) as e: + raise common_utils.InvalidUserInputError(e) + + if examples: + return examples + else: + raise common_utils.InvalidUserInputError( + "No examples found at " + + path + + ". Valid formats are TFRecord files." + ) + def call_servo(examples, serving_bundle): - """Send an RPC request to the Servomatic prediction service. - - Args: - examples: A list of examples that matches the model spec. - serving_bundle: A `ServingBundle` object that contains the information to - make the serving request. - - Returns: - A ClassificationResponse or RegressionResponse proto. - """ - parsed_url = urlparse('http://' + serving_bundle.inference_address) - channel = implementations.insecure_channel(parsed_url.hostname, - parsed_url.port) - stub = prediction_service_pb2.beta_create_PredictionService_stub(channel) - - if serving_bundle.use_predict: - request = predict_pb2.PredictRequest() - elif serving_bundle.model_type == 'classification': - request = classification_pb2.ClassificationRequest() - else: - request = regression_pb2.RegressionRequest() - request.model_spec.name = serving_bundle.model_name - if serving_bundle.model_version is not None: - request.model_spec.version.value = serving_bundle.model_version - if serving_bundle.signature is not None: - request.model_spec.signature_name = serving_bundle.signature - - if serving_bundle.use_predict: - # tf.compat.v1 API used here to convert tf.example into proto. This - # utility file is bundled in the witwidget pip package which has a dep - # on TensorFlow. 
- request.inputs[serving_bundle.predict_input_tensor].CopyFrom( - tf.compat.v1.make_tensor_proto( - values=[ex.SerializeToString() for ex in examples], - dtype=types_pb2.DT_STRING)) - else: - request.input.example_list.examples.extend(examples) - - if serving_bundle.use_predict: - return common_utils.convert_predict_response( - stub.Predict(request, 30.0), serving_bundle) # 30 secs timeout - elif serving_bundle.model_type == 'classification': - return stub.Classify(request, 30.0) # 30 secs timeout - else: - return stub.Regress(request, 30.0) # 30 secs timeout + """Send an RPC request to the Servomatic prediction service. + + Args: + examples: A list of examples that matches the model spec. + serving_bundle: A `ServingBundle` object that contains the information to + make the serving request. + + Returns: + A ClassificationResponse or RegressionResponse proto. + """ + parsed_url = urlparse("http://" + serving_bundle.inference_address) + channel = implementations.insecure_channel( + parsed_url.hostname, parsed_url.port + ) + stub = prediction_service_pb2.beta_create_PredictionService_stub(channel) + + if serving_bundle.use_predict: + request = predict_pb2.PredictRequest() + elif serving_bundle.model_type == "classification": + request = classification_pb2.ClassificationRequest() + else: + request = regression_pb2.RegressionRequest() + request.model_spec.name = serving_bundle.model_name + if serving_bundle.model_version is not None: + request.model_spec.version.value = serving_bundle.model_version + if serving_bundle.signature is not None: + request.model_spec.signature_name = serving_bundle.signature + + if serving_bundle.use_predict: + # tf.compat.v1 API used here to convert tf.example into proto. This + # utility file is bundled in the witwidget pip package which has a dep + # on TensorFlow. 
+ request.inputs[serving_bundle.predict_input_tensor].CopyFrom( + tf.compat.v1.make_tensor_proto( + values=[ex.SerializeToString() for ex in examples], + dtype=types_pb2.DT_STRING, + ) + ) + else: + request.input.example_list.examples.extend(examples) + + if serving_bundle.use_predict: + return common_utils.convert_predict_response( + stub.Predict(request, 30.0), serving_bundle + ) # 30 secs timeout + elif serving_bundle.model_type == "classification": + return stub.Classify(request, 30.0) # 30 secs timeout + else: + return stub.Regress(request, 30.0) # 30 secs timeout diff --git a/tensorboard/plugins/interactive_inference/utils/test_utils.py b/tensorboard/plugins/interactive_inference/utils/test_utils.py index d2715427c8..f5cb660d10 100644 --- a/tensorboard/plugins/interactive_inference/utils/test_utils.py +++ b/tensorboard/plugins/interactive_inference/utils/test_utils.py @@ -22,30 +22,33 @@ def make_fake_example(single_int_val=0): - """Make a fake example with numeric and string features.""" - example = tf.train.Example() - example.features.feature['repeated_float'].float_list.value.extend( - [1.0, 2.0, 3.0, 4.0]) - example.features.feature['repeated_int'].int64_list.value.extend([10, 20]) - - example.features.feature['single_int'].int64_list.value.extend( - [single_int_val]) - example.features.feature['single_float'].float_list.value.extend([24.5]) - example.features.feature['non_numeric'].bytes_list.value.extend( - [b'cat', b'cat', b'woof']) - return example + """Make a fake example with numeric and string features.""" + example = tf.train.Example() + example.features.feature["repeated_float"].float_list.value.extend( + [1.0, 2.0, 3.0, 4.0] + ) + example.features.feature["repeated_int"].int64_list.value.extend([10, 20]) + + example.features.feature["single_int"].int64_list.value.extend( + [single_int_val] + ) + example.features.feature["single_float"].float_list.value.extend([24.5]) + example.features.feature["non_numeric"].bytes_list.value.extend( + [b"cat", 
b"cat", b"woof"] + ) + return example def write_out_examples(examples, path): - """Writes protos to the CNS path.""" + """Writes protos to the CNS path.""" - writer = tf.io.TFRecordWriter(path) - for example in examples: - writer.write(example.SerializeToString()) + writer = tf.io.TFRecordWriter(path) + for example in examples: + writer.write(example.SerializeToString()) def value_from_example(example, feature_name): - """Returns the feature as a Python list.""" - feature = example.features.feature[feature_name] - feature_type = feature.WhichOneof('kind') - return getattr(feature, feature_type).value[:] + """Returns the feature as a Python list.""" + feature = example.features.feature[feature_name] + feature_type = feature.WhichOneof("kind") + return getattr(feature, feature_type).value[:] diff --git a/tensorboard/plugins/interactive_inference/witwidget/__init__.py b/tensorboard/plugins/interactive_inference/witwidget/__init__.py index c809332077..eed38f9112 100644 --- a/tensorboard/plugins/interactive_inference/witwidget/__init__.py +++ b/tensorboard/plugins/interactive_inference/witwidget/__init__.py @@ -14,10 +14,13 @@ from witwidget.notebook.visualization import * + def _jupyter_nbextension_paths(): - return [{ - 'section': 'notebook', - 'src': 'static', - 'dest': 'wit-widget', - 'require': 'wit-widget/extension' - }] + return [ + { + "section": "notebook", + "src": "static", + "dest": "wit-widget", + "require": "wit-widget/extension", + } + ] diff --git a/tensorboard/plugins/interactive_inference/witwidget/notebook/base.py b/tensorboard/plugins/interactive_inference/witwidget/notebook/base.py index 17fbde4ac2..98a9ee0454 100644 --- a/tensorboard/plugins/interactive_inference/witwidget/notebook/base.py +++ b/tensorboard/plugins/interactive_inference/witwidget/notebook/base.py @@ -32,582 +32,707 @@ NUM_EXAMPLES_FOR_MUTANT_ANALYSIS = 50 # Custom user agent for tracking number of calls to Cloud AI Platform. 
-USER_AGENT_FOR_CAIP_TRACKING = 'WhatIfTool' +USER_AGENT_FOR_CAIP_TRACKING = "WhatIfTool" try: - POOL_SIZE = max(multiprocessing.cpu_count() - 1, 1) + POOL_SIZE = max(multiprocessing.cpu_count() - 1, 1) except Exception: - POOL_SIZE = 1 + POOL_SIZE = 1 class WitWidgetBase(object): - """WIT widget base class for common code between Jupyter and Colab.""" - - def __init__(self, config_builder): - """Constructor for WitWidgetBase. - - Args: - config_builder: WitConfigBuilder object containing settings for WIT. - """ - tf.get_logger().setLevel(logging.WARNING) - config = config_builder.build() - copied_config = dict(config) - self.estimator_and_spec = ( - dict(config.get('estimator_and_spec')) - if 'estimator_and_spec' in config else {}) - self.compare_estimator_and_spec = ( - dict(config.get('compare_estimator_and_spec')) - if 'compare_estimator_and_spec' in config else {}) - if 'estimator_and_spec' in copied_config: - del copied_config['estimator_and_spec'] - if 'compare_estimator_and_spec' in copied_config: - del copied_config['compare_estimator_and_spec'] - - self.custom_predict_fn = config.get('custom_predict_fn') - self.compare_custom_predict_fn = config.get('compare_custom_predict_fn') - self.custom_distance_fn = config.get('custom_distance_fn') - self.adjust_prediction_fn = config.get('adjust_prediction') - self.compare_adjust_prediction_fn = config.get('compare_adjust_prediction') - self.adjust_example_fn = config.get('adjust_example') - self.compare_adjust_example_fn = config.get('compare_adjust_example') - self.adjust_attribution_fn = config.get('adjust_attribution') - self.compare_adjust_attribution_fn = config.get('compare_adjust_attribution') - - if 'custom_predict_fn' in copied_config: - del copied_config['custom_predict_fn'] - if 'compare_custom_predict_fn' in copied_config: - del copied_config['compare_custom_predict_fn'] - if 'custom_distance_fn' in copied_config: - del copied_config['custom_distance_fn'] - copied_config['uses_custom_distance_fn'] = 
True - if 'adjust_prediction' in copied_config: - del copied_config['adjust_prediction'] - if 'compare_adjust_prediction' in copied_config: - del copied_config['compare_adjust_prediction'] - if 'adjust_example' in copied_config: - del copied_config['adjust_example'] - if 'compare_adjust_example' in copied_config: - del copied_config['compare_adjust_example'] - if 'adjust_attribution' in copied_config: - del copied_config['adjust_attribution'] - if 'compare_adjust_attribution' in copied_config: - del copied_config['compare_adjust_attribution'] - - examples = copied_config.pop('examples') - self.config = copied_config - self.set_examples(examples) - - # This tracks whether mutant inference is running in order to - # skip calling for explanations for CAIP models when inferring - # for mutant inference, for performance reasons. - self.running_mutant_infer = False - - # If using AI Platform for prediction, set the correct custom prediction - # functions. - if self.config.get('use_aip'): - self.custom_predict_fn = self._predict_aip_model - if self.config.get('compare_use_aip'): - self.compare_custom_predict_fn = self._predict_aip_compare_model - - # If using JSON input (not Example protos) and a custom predict - # function, then convert examples to JSON before sending to the - # custom predict function. 
- if self.config.get('uses_json_input'): - if self.custom_predict_fn is not None and not self.config.get('use_aip'): - user_predict = self.custom_predict_fn - def wrapped_custom_predict_fn(examples): - return user_predict(self._json_from_tf_examples(examples)) - self.custom_predict_fn = wrapped_custom_predict_fn - if (self.compare_custom_predict_fn is not None and - not self.config.get('compare_use_aip')): - compare_user_predict = self.compare_custom_predict_fn - def wrapped_compare_custom_predict_fn(examples): - return compare_user_predict(self._json_from_tf_examples(examples)) - self.compare_custom_predict_fn = wrapped_compare_custom_predict_fn - - def _get_element_html(self): - return """ + """WIT widget base class for common code between Jupyter and Colab.""" + + def __init__(self, config_builder): + """Constructor for WitWidgetBase. + + Args: + config_builder: WitConfigBuilder object containing settings for WIT. + """ + tf.get_logger().setLevel(logging.WARNING) + config = config_builder.build() + copied_config = dict(config) + self.estimator_and_spec = ( + dict(config.get("estimator_and_spec")) + if "estimator_and_spec" in config + else {} + ) + self.compare_estimator_and_spec = ( + dict(config.get("compare_estimator_and_spec")) + if "compare_estimator_and_spec" in config + else {} + ) + if "estimator_and_spec" in copied_config: + del copied_config["estimator_and_spec"] + if "compare_estimator_and_spec" in copied_config: + del copied_config["compare_estimator_and_spec"] + + self.custom_predict_fn = config.get("custom_predict_fn") + self.compare_custom_predict_fn = config.get("compare_custom_predict_fn") + self.custom_distance_fn = config.get("custom_distance_fn") + self.adjust_prediction_fn = config.get("adjust_prediction") + self.compare_adjust_prediction_fn = config.get( + "compare_adjust_prediction" + ) + self.adjust_example_fn = config.get("adjust_example") + self.compare_adjust_example_fn = config.get("compare_adjust_example") + self.adjust_attribution_fn 
= config.get("adjust_attribution") + self.compare_adjust_attribution_fn = config.get( + "compare_adjust_attribution" + ) + + if "custom_predict_fn" in copied_config: + del copied_config["custom_predict_fn"] + if "compare_custom_predict_fn" in copied_config: + del copied_config["compare_custom_predict_fn"] + if "custom_distance_fn" in copied_config: + del copied_config["custom_distance_fn"] + copied_config["uses_custom_distance_fn"] = True + if "adjust_prediction" in copied_config: + del copied_config["adjust_prediction"] + if "compare_adjust_prediction" in copied_config: + del copied_config["compare_adjust_prediction"] + if "adjust_example" in copied_config: + del copied_config["adjust_example"] + if "compare_adjust_example" in copied_config: + del copied_config["compare_adjust_example"] + if "adjust_attribution" in copied_config: + del copied_config["adjust_attribution"] + if "compare_adjust_attribution" in copied_config: + del copied_config["compare_adjust_attribution"] + + examples = copied_config.pop("examples") + self.config = copied_config + self.set_examples(examples) + + # This tracks whether mutant inference is running in order to + # skip calling for explanations for CAIP models when inferring + # for mutant inference, for performance reasons. + self.running_mutant_infer = False + + # If using AI Platform for prediction, set the correct custom prediction + # functions. + if self.config.get("use_aip"): + self.custom_predict_fn = self._predict_aip_model + if self.config.get("compare_use_aip"): + self.compare_custom_predict_fn = self._predict_aip_compare_model + + # If using JSON input (not Example protos) and a custom predict + # function, then convert examples to JSON before sending to the + # custom predict function. 
+ if self.config.get("uses_json_input"): + if self.custom_predict_fn is not None and not self.config.get( + "use_aip" + ): + user_predict = self.custom_predict_fn + + def wrapped_custom_predict_fn(examples): + return user_predict(self._json_from_tf_examples(examples)) + + self.custom_predict_fn = wrapped_custom_predict_fn + if ( + self.compare_custom_predict_fn is not None + and not self.config.get("compare_use_aip") + ): + compare_user_predict = self.compare_custom_predict_fn + + def wrapped_compare_custom_predict_fn(examples): + return compare_user_predict( + self._json_from_tf_examples(examples) + ) + + self.compare_custom_predict_fn = ( + wrapped_compare_custom_predict_fn + ) + + def _get_element_html(self): + return """ """ - def set_examples(self, examples): - """Sets the examples shown in WIT. - - The examples are initially set by the examples specified in the config - builder during construction. This method can change which examples WIT - displays. - """ - if self.config.get('uses_json_input'): - tf_examples = self._json_to_tf_examples(examples) - self.examples = [json_format.MessageToJson(ex) for ex in tf_examples] - else: - self.examples = [json_format.MessageToJson(ex) for ex in examples] - self.updated_example_indices = set(range(len(examples))) - - def compute_custom_distance_impl(self, index, params=None): - exs_for_distance = [ - self.json_to_proto(example) for example in self.examples] - selected_ex = exs_for_distance[index] - return self.custom_distance_fn(selected_ex, exs_for_distance, params) - - def json_to_proto(self, json): - ex = (tf.train.SequenceExample() - if self.config.get('are_sequence_examples') - else tf.train.Example()) - json_format.Parse(json, ex) - return ex - - def infer_impl(self): - """Performs inference on examples that require inference.""" - indices_to_infer = sorted(self.updated_example_indices) - examples_to_infer = [ - self.json_to_proto(self.examples[index]) for index in indices_to_infer] - infer_objs = [] - 
extra_output_objs = [] - serving_bundle = inference_utils.ServingBundle( - self.config.get('inference_address'), - self.config.get('model_name'), - self.config.get('model_type'), - self.config.get('model_version'), - self.config.get('model_signature'), - self.config.get('uses_predict_api'), - self.config.get('predict_input_tensor'), - self.config.get('predict_output_tensor'), - self.estimator_and_spec.get('estimator'), - self.estimator_and_spec.get('feature_spec'), - self.custom_predict_fn) - (predictions, extra_output) = ( - inference_utils.run_inference_for_inference_results( - examples_to_infer, serving_bundle)) - infer_objs.append(predictions) - extra_output_objs.append(extra_output) - if ('inference_address_2' in self.config or - self.compare_estimator_and_spec.get('estimator') or - self.compare_custom_predict_fn): - serving_bundle = inference_utils.ServingBundle( - self.config.get('inference_address_2'), - self.config.get('model_name_2'), - self.config.get('model_type'), - self.config.get('model_version_2'), - self.config.get('model_signature_2'), - self.config.get('uses_predict_api'), - self.config.get('predict_input_tensor'), - self.config.get('predict_output_tensor'), - self.compare_estimator_and_spec.get('estimator'), - self.compare_estimator_and_spec.get('feature_spec'), - self.compare_custom_predict_fn) - (predictions, extra_output) = ( - inference_utils.run_inference_for_inference_results( - examples_to_infer, serving_bundle)) - infer_objs.append(predictions) - extra_output_objs.append(extra_output) - self.updated_example_indices = set() - return { - 'inferences': {'indices': indices_to_infer, 'results': infer_objs}, - 'label_vocab': self.config.get('label_vocab'), - 'extra_outputs': extra_output_objs} - - def infer_mutants_impl(self, info): - """Performs mutant inference on specified examples.""" - example_index = int(info['example_index']) - feature_name = info['feature_name'] - examples = (self.examples if example_index == -1 - else 
[self.examples[example_index]]) - examples = [self.json_to_proto(ex) for ex in examples] - scan_examples = [self.json_to_proto(ex) for ex in self.examples[0:50]] - serving_bundles = [] - serving_bundles.append(inference_utils.ServingBundle( - self.config.get('inference_address'), - self.config.get('model_name'), - self.config.get('model_type'), - self.config.get('model_version'), - self.config.get('model_signature'), - self.config.get('uses_predict_api'), - self.config.get('predict_input_tensor'), - self.config.get('predict_output_tensor'), - self.estimator_and_spec.get('estimator'), - self.estimator_and_spec.get('feature_spec'), - self.custom_predict_fn)) - if ('inference_address_2' in self.config or - self.compare_estimator_and_spec.get('estimator') or - self.compare_custom_predict_fn): - serving_bundles.append(inference_utils.ServingBundle( - self.config.get('inference_address_2'), - self.config.get('model_name_2'), - self.config.get('model_type'), - self.config.get('model_version_2'), - self.config.get('model_signature_2'), - self.config.get('uses_predict_api'), - self.config.get('predict_input_tensor'), - self.config.get('predict_output_tensor'), - self.compare_estimator_and_spec.get('estimator'), - self.compare_estimator_and_spec.get('feature_spec'), - self.compare_custom_predict_fn)) - viz_params = inference_utils.VizParams( - info['x_min'], info['x_max'], - scan_examples, 10, - info['feature_index_pattern']) - self.running_mutant_infer = True - charts = inference_utils.mutant_charts_for_feature( - examples, feature_name, serving_bundles, viz_params) - self.running_mutant_infer = False - return charts - - def get_eligible_features_impl(self): - """Returns information about features eligible for mutant inference.""" - examples = [self.json_to_proto(ex) for ex in self.examples[ - 0:NUM_EXAMPLES_FOR_MUTANT_ANALYSIS]] - return inference_utils.get_eligible_features( - examples, NUM_MUTANTS_TO_GENERATE) - - def sort_eligible_features_impl(self, info): - """Returns 
sorted list of interesting features for mutant inference.""" - features_list = info['features'] - chart_data = {} - for feat in features_list: - chart_data[feat['name']] = self.infer_mutants_impl({ - 'x_min': feat['observedMin'] if 'observedMin' in feat else 0, - 'x_max': feat['observedMax'] if 'observedMin' in feat else 0, - 'feature_index_pattern': None, - 'feature_name': feat['name'], - 'example_index': info['example_index'], - }) - return inference_utils.sort_eligible_features( - features_list, chart_data) - - def create_sprite(self): - """Returns an encoded image of thumbnails for image examples.""" - # Generate a sprite image for the examples if the examples contain the - # standard encoded image feature. - if not self.examples: - return None - example_to_check = self.json_to_proto(self.examples[0]) - feature_list = (example_to_check.context.feature - if self.config.get('are_sequence_examples') - else example_to_check.features.feature) - if 'image/encoded' in feature_list: - example_strings = [ - self.json_to_proto(ex).SerializeToString() - for ex in self.examples] - encoded = ensure_str(base64.b64encode( - inference_utils.create_sprite_image(example_strings))) - return 'data:image/png;base64,{}'.format(encoded) - else: - return None - - def _json_from_tf_examples(self, tf_examples): - json_exs = [] - feature_names = self.config.get('feature_names') - for ex in tf_examples: - # Create a JSON list or dict for each example depending on settings. - # Strip out any explicitly-labeled target feature from the example. - # This is needed because AI Platform models that accept JSON cannot handle - # when non-input features are provided as part of the object to run - # prediction on. 
- if self.config.get('uses_json_list'): - json_ex = [] - for feat in ex.features.feature: - if feature_names and feat in feature_names: - feat_idx = feature_names.index(feat) - else: - feat_idx = int(feat) - if (feat == self.config.get('target_feature') or - feat_idx == self.config.get('target_feature')): - continue - # Ensure the example value list is long enough to add the next feature - # from the tf.Example. - if feat_idx >= len(json_ex): - json_ex.extend([None] * (feat_idx - len(json_ex) + 1)) - if ex.features.feature[feat].HasField('int64_list'): - json_ex[feat_idx] = ex.features.feature[feat].int64_list.value[0] - elif ex.features.feature[feat].HasField('float_list'): - json_ex[feat_idx] = ex.features.feature[feat].float_list.value[0] - else: - json_ex[feat_idx] = ensure_str( - ex.features.feature[feat].bytes_list.value[0]) - else: - json_ex = {} - for feat in ex.features.feature: - if feat == self.config.get('target_feature'): - continue - if ex.features.feature[feat].HasField('int64_list'): - json_ex[feat] = ex.features.feature[feat].int64_list.value[0] - elif ex.features.feature[feat].HasField('float_list'): - json_ex[feat] = ex.features.feature[feat].float_list.value[0] - else: - json_ex[feat] = ensure_str( - ex.features.feature[feat].bytes_list.value[0]) - json_exs.append(json_ex) - return json_exs - - def _json_to_tf_examples(self, examples): - def add_single_feature(feat, value, ex): - if isinstance(value, integer_types): - ex.features.feature[feat].int64_list.value.append(value) - elif isinstance(value, Number): - ex.features.feature[feat].float_list.value.append(value) - else: - ex.features.feature[feat].bytes_list.value.append(value.encode('utf-8')) - - tf_examples = [] - for json_ex in examples: - ex = tf.train.Example() - # JSON examples can be lists of values (for xgboost models for instance), - # or dicts of key/value pairs. 
- if self.config.get('uses_json_list'): - feature_names = self.config.get('feature_names') - for (i, value) in enumerate(json_ex): - # If feature names have been provided, use those feature names instead - # of list indices for feature name when storing as tf.Example. - if feature_names and len(feature_names) > i: - feat = feature_names[i] - else: - feat = str(i) - add_single_feature(feat, value, ex) - tf_examples.append(ex) - else: - for feat in json_ex: - add_single_feature(feat, json_ex[feat], ex) - tf_examples.append(ex) - return tf_examples - - def _predict_aip_model(self, examples): - return self._predict_aip_impl( - examples, - self.config.get('inference_address'), - self.config.get('model_name'), - self.config.get('model_signature'), - self.config.get('force_json_input'), - self.adjust_example_fn, - self.adjust_prediction_fn, - self.adjust_attribution_fn, - self.config.get('aip_service_name'), - self.config.get('aip_service_version'), - self.config.get('get_explanations'), - self.config.get('aip_batch_size'), - self.config.get('aip_api_key')) - - def _predict_aip_compare_model(self, examples): - return self._predict_aip_impl( - examples, - self.config.get('inference_address_2'), - self.config.get('model_name_2'), - self.config.get('model_signature_2'), - self.config.get('compare_force_json_input'), - self.compare_adjust_example_fn, - self.compare_adjust_prediction_fn, - self.compare_adjust_attribution_fn, - self.config.get('compare_aip_service_name'), - self.config.get('compare_aip_service_version'), - self.config.get('compare_get_explanations'), - self.config.get('compare_aip_batch_size'), - self.config.get('compare_aip_api_key')) - - def _predict_aip_impl(self, examples, project, model, version, force_json, - adjust_example, adjust_prediction, adjust_attribution, - service_name, service_version, get_explanations, - batch_size, api_key): - """Custom prediction function for running inference through AI Platform.""" - - # Set up environment for GCP call for 
# NOTE(review): the methods below are instance methods of the What-If Tool
# widget base class (every one takes `self` and reads `self.config` /
# `self.examples`); they are reconstructed here from the patch post-image.
# `tf`, `json_format`, `inference_utils`, `googleapiclient`,
# `GoogleCredentials`, `ensure_str`, `integer_types`, and the NUM_*/POOL_SIZE/
# USER_AGENT_FOR_CAIP_TRACKING constants come from the file's top-level
# imports and module constants.


def set_examples(self, examples):
    """Sets the examples shown in WIT.

    The examples are initially set by the examples specified in the
    config builder during construction. This method can change which
    examples WIT displays.
    """
    if self.config.get("uses_json_input"):
        tf_examples = self._json_to_tf_examples(examples)
        self.examples = [
            json_format.MessageToJson(ex) for ex in tf_examples
        ]
    else:
        self.examples = [json_format.MessageToJson(ex) for ex in examples]
    # Every example is new, so mark all of them as needing inference.
    self.updated_example_indices = set(range(len(examples)))


def compute_custom_distance_impl(self, index, params=None):
    """Runs the user-provided distance function for one selected example.

    Args:
      index: Index into self.examples of the anchor example.
      params: Optional user parameters forwarded to the distance function.

    Returns:
      Whatever self.custom_distance_fn returns (distances from the
      selected example to all examples).
    """
    exs_for_distance = [
        self.json_to_proto(example) for example in self.examples
    ]
    selected_ex = exs_for_distance[index]
    return self.custom_distance_fn(selected_ex, exs_for_distance, params)


def json_to_proto(self, json):
    """Parses a JSON-serialized proto string into an Example or
    SequenceExample, depending on the widget configuration."""
    ex = (
        tf.train.SequenceExample()
        if self.config.get("are_sequence_examples")
        else tf.train.Example()
    )
    json_format.Parse(json, ex)
    return ex


def infer_impl(self):
    """Performs inference on examples that require inference.

    Runs the primary model, and the comparison model when one is
    configured, then clears the updated-example set.

    Returns:
      Dict with "inferences" (indices + per-model results),
      "label_vocab", and "extra_outputs" keys.
    """
    indices_to_infer = sorted(self.updated_example_indices)
    examples_to_infer = [
        self.json_to_proto(self.examples[index])
        for index in indices_to_infer
    ]
    infer_objs = []
    extra_output_objs = []
    serving_bundle = inference_utils.ServingBundle(
        self.config.get("inference_address"),
        self.config.get("model_name"),
        self.config.get("model_type"),
        self.config.get("model_version"),
        self.config.get("model_signature"),
        self.config.get("uses_predict_api"),
        self.config.get("predict_input_tensor"),
        self.config.get("predict_output_tensor"),
        self.estimator_and_spec.get("estimator"),
        self.estimator_and_spec.get("feature_spec"),
        self.custom_predict_fn,
    )
    (
        predictions,
        extra_output,
    ) = inference_utils.run_inference_for_inference_results(
        examples_to_infer, serving_bundle
    )
    infer_objs.append(predictions)
    extra_output_objs.append(extra_output)
    # A second model is in use if any of its three configuration routes
    # (address, estimator, or custom predict fn) is set.
    if (
        "inference_address_2" in self.config
        or self.compare_estimator_and_spec.get("estimator")
        or self.compare_custom_predict_fn
    ):
        serving_bundle = inference_utils.ServingBundle(
            self.config.get("inference_address_2"),
            self.config.get("model_name_2"),
            self.config.get("model_type"),
            self.config.get("model_version_2"),
            self.config.get("model_signature_2"),
            self.config.get("uses_predict_api"),
            self.config.get("predict_input_tensor"),
            self.config.get("predict_output_tensor"),
            self.compare_estimator_and_spec.get("estimator"),
            self.compare_estimator_and_spec.get("feature_spec"),
            self.compare_custom_predict_fn,
        )
        (
            predictions,
            extra_output,
        ) = inference_utils.run_inference_for_inference_results(
            examples_to_infer, serving_bundle
        )
        infer_objs.append(predictions)
        extra_output_objs.append(extra_output)
    self.updated_example_indices = set()
    return {
        "inferences": {"indices": indices_to_infer, "results": infer_objs},
        "label_vocab": self.config.get("label_vocab"),
        "extra_outputs": extra_output_objs,
    }


def infer_mutants_impl(self, info):
    """Performs mutant inference on specified examples.

    Args:
      info: Dict with "example_index" (-1 means all examples),
        "feature_name", "x_min", "x_max", and "feature_index_pattern".

    Returns:
      Mutant charts from inference_utils.mutant_charts_for_feature.
    """
    example_index = int(info["example_index"])
    feature_name = info["feature_name"]
    examples = (
        self.examples
        if example_index == -1
        else [self.examples[example_index]]
    )
    examples = [self.json_to_proto(ex) for ex in examples]
    # Only scan the first 50 examples to derive the feature value range.
    scan_examples = [self.json_to_proto(ex) for ex in self.examples[0:50]]
    serving_bundles = []
    serving_bundles.append(
        inference_utils.ServingBundle(
            self.config.get("inference_address"),
            self.config.get("model_name"),
            self.config.get("model_type"),
            self.config.get("model_version"),
            self.config.get("model_signature"),
            self.config.get("uses_predict_api"),
            self.config.get("predict_input_tensor"),
            self.config.get("predict_output_tensor"),
            self.estimator_and_spec.get("estimator"),
            self.estimator_and_spec.get("feature_spec"),
            self.custom_predict_fn,
        )
    )
    if (
        "inference_address_2" in self.config
        or self.compare_estimator_and_spec.get("estimator")
        or self.compare_custom_predict_fn
    ):
        serving_bundles.append(
            inference_utils.ServingBundle(
                self.config.get("inference_address_2"),
                self.config.get("model_name_2"),
                self.config.get("model_type"),
                self.config.get("model_version_2"),
                self.config.get("model_signature_2"),
                self.config.get("uses_predict_api"),
                self.config.get("predict_input_tensor"),
                self.config.get("predict_output_tensor"),
                self.compare_estimator_and_spec.get("estimator"),
                self.compare_estimator_and_spec.get("feature_spec"),
                self.compare_custom_predict_fn,
            )
        )
    viz_params = inference_utils.VizParams(
        info["x_min"],
        info["x_max"],
        scan_examples,
        10,
        info["feature_index_pattern"],
    )
    # Flag mutant inference so AIP prediction skips explanations (they are
    # expensive and unused for mutant charts).
    self.running_mutant_infer = True
    charts = inference_utils.mutant_charts_for_feature(
        examples, feature_name, serving_bundles, viz_params
    )
    self.running_mutant_infer = False
    return charts


def get_eligible_features_impl(self):
    """Returns information about features eligible for mutant inference."""
    examples = [
        self.json_to_proto(ex)
        for ex in self.examples[0:NUM_EXAMPLES_FOR_MUTANT_ANALYSIS]
    ]
    return inference_utils.get_eligible_features(
        examples, NUM_MUTANTS_TO_GENERATE
    )


def sort_eligible_features_impl(self, info):
    """Returns sorted list of interesting features for mutant inference."""
    features_list = info["features"]
    chart_data = {}
    for feat in features_list:
        chart_data[feat["name"]] = self.infer_mutants_impl(
            {
                "x_min": feat["observedMin"]
                if "observedMin" in feat
                else 0,
                # Fixed: guard x_max on "observedMax" (previously checked
                # "observedMin", a copy-paste slip; both keys are set
                # together by get_eligible_features, so behavior for
                # well-formed input is unchanged).
                "x_max": feat["observedMax"]
                if "observedMax" in feat
                else 0,
                "feature_index_pattern": None,
                "feature_name": feat["name"],
                "example_index": info["example_index"],
            }
        )
    return inference_utils.sort_eligible_features(features_list, chart_data)


def create_sprite(self):
    """Returns an encoded image of thumbnails for image examples."""
    # Generate a sprite image for the examples if the examples contain the
    # standard encoded image feature.
    if not self.examples:
        return None
    example_to_check = self.json_to_proto(self.examples[0])
    feature_list = (
        example_to_check.context.feature
        if self.config.get("are_sequence_examples")
        else example_to_check.features.feature
    )
    if "image/encoded" in feature_list:
        example_strings = [
            self.json_to_proto(ex).SerializeToString()
            for ex in self.examples
        ]
        encoded = ensure_str(
            base64.b64encode(
                inference_utils.create_sprite_image(example_strings)
            )
        )
        return "data:image/png;base64,{}".format(encoded)
    else:
        return None


def _json_from_tf_examples(self, tf_examples):
    """Converts tf.Examples into JSON-style lists or dicts for AIP.

    Returns one JSON object per example: a positional list when
    "uses_json_list" is configured, otherwise a feature-name dict.
    """
    json_exs = []
    feature_names = self.config.get("feature_names")
    for ex in tf_examples:
        # Create a JSON list or dict for each example depending on settings.
        # Strip out any explicitly-labeled target feature from the example.
        # This is needed because AI Platform models that accept JSON cannot
        # handle when non-input features are provided as part of the object
        # to run prediction on.
        if self.config.get("uses_json_list"):
            json_ex = []
            for feat in ex.features.feature:
                if feature_names and feat in feature_names:
                    feat_idx = feature_names.index(feat)
                else:
                    feat_idx = int(feat)
                if feat == self.config.get(
                    "target_feature"
                ) or feat_idx == self.config.get("target_feature"):
                    continue
                # Ensure the example value list is long enough to add the
                # next feature from the tf.Example.
                if feat_idx >= len(json_ex):
                    json_ex.extend([None] * (feat_idx - len(json_ex) + 1))
                if ex.features.feature[feat].HasField("int64_list"):
                    json_ex[feat_idx] = ex.features.feature[
                        feat
                    ].int64_list.value[0]
                elif ex.features.feature[feat].HasField("float_list"):
                    json_ex[feat_idx] = ex.features.feature[
                        feat
                    ].float_list.value[0]
                else:
                    json_ex[feat_idx] = ensure_str(
                        ex.features.feature[feat].bytes_list.value[0]
                    )
        else:
            json_ex = {}
            for feat in ex.features.feature:
                if feat == self.config.get("target_feature"):
                    continue
                if ex.features.feature[feat].HasField("int64_list"):
                    json_ex[feat] = ex.features.feature[
                        feat
                    ].int64_list.value[0]
                elif ex.features.feature[feat].HasField("float_list"):
                    json_ex[feat] = ex.features.feature[
                        feat
                    ].float_list.value[0]
                else:
                    json_ex[feat] = ensure_str(
                        ex.features.feature[feat].bytes_list.value[0]
                    )
        json_exs.append(json_ex)
    return json_exs


def _json_to_tf_examples(self, examples):
    """Converts JSON lists/dicts into tf.Examples (inverse of
    _json_from_tf_examples)."""

    def add_single_feature(feat, value, ex):
        # Route the value into the appropriate typed feature list.
        if isinstance(value, integer_types):
            ex.features.feature[feat].int64_list.value.append(value)
        elif isinstance(value, Number):
            ex.features.feature[feat].float_list.value.append(value)
        else:
            ex.features.feature[feat].bytes_list.value.append(
                value.encode("utf-8")
            )

    tf_examples = []
    for json_ex in examples:
        ex = tf.train.Example()
        # JSON examples can be lists of values (for xgboost models for
        # instance), or dicts of key/value pairs.
        if self.config.get("uses_json_list"):
            feature_names = self.config.get("feature_names")
            for (i, value) in enumerate(json_ex):
                # If feature names have been provided, use those feature
                # names instead of list indices for feature name when
                # storing as tf.Example.
                if feature_names and len(feature_names) > i:
                    feat = feature_names[i]
                else:
                    feat = str(i)
                add_single_feature(feat, value, ex)
            tf_examples.append(ex)
        else:
            for feat in json_ex:
                add_single_feature(feat, json_ex[feat], ex)
            tf_examples.append(ex)
    return tf_examples


def _predict_aip_model(self, examples):
    """Runs prediction for the primary model through AI Platform."""
    return self._predict_aip_impl(
        examples,
        self.config.get("inference_address"),
        self.config.get("model_name"),
        self.config.get("model_signature"),
        self.config.get("force_json_input"),
        self.adjust_example_fn,
        self.adjust_prediction_fn,
        self.adjust_attribution_fn,
        self.config.get("aip_service_name"),
        self.config.get("aip_service_version"),
        self.config.get("get_explanations"),
        self.config.get("aip_batch_size"),
        self.config.get("aip_api_key"),
    )


def _predict_aip_compare_model(self, examples):
    """Runs prediction for the comparison model through AI Platform."""
    return self._predict_aip_impl(
        examples,
        self.config.get("inference_address_2"),
        self.config.get("model_name_2"),
        self.config.get("model_signature_2"),
        self.config.get("compare_force_json_input"),
        self.compare_adjust_example_fn,
        self.compare_adjust_prediction_fn,
        self.compare_adjust_attribution_fn,
        self.config.get("compare_aip_service_name"),
        self.config.get("compare_aip_service_version"),
        self.config.get("compare_get_explanations"),
        self.config.get("compare_aip_batch_size"),
        self.config.get("compare_aip_api_key"),
    )


def _predict_aip_impl(
    self,
    examples,
    project,
    model,
    version,
    force_json,
    adjust_example,
    adjust_prediction,
    adjust_attribution,
    service_name,
    service_version,
    get_explanations,
    batch_size,
    api_key,
):
    """Custom prediction function for running inference through AI
    Platform.

    Sends the examples (optionally converted to JSON and adjusted by the
    user callbacks) to the CAIP predict API in threaded batches, optionally
    fetches explanations, and normalizes the responses into WIT's expected
    result dict.

    Raises:
      RuntimeError: if any batch's predict request failed.
      KeyError: if the configured/derived output key is missing from a
        prediction dict.
    """

    # Set up environment for GCP call for specified project.
    os.environ["GOOGLE_CLOUD_PROJECT"] = project

    # Skip explanations entirely during mutant inference — only plain
    # predictions are needed for mutant charts.
    should_explain = get_explanations and not self.running_mutant_infer

    def predict(exs):
        """Run prediction on a list of examples and return results."""
        # Properly package the examples to send for prediction.
        discovery_url = None
        error_during_prediction = False
        if api_key is not None:
            discovery_url = (
                "https://%s.googleapis.com/$discovery/rest"
                "?labels=GOOGLE_INTERNAL&key=%s&version=%s"
            ) % (service_name, api_key, "v1")
            credentials = GoogleCredentials.get_application_default()
            service = googleapiclient.discovery.build(
                service_name,
                service_version,
                cache_discovery=False,
                developerKey=api_key,
                discoveryServiceUrl=discovery_url,
                credentials=credentials,
            )
        else:
            service = googleapiclient.discovery.build(
                service_name, service_version, cache_discovery=False
            )

        name = "projects/{}/models/{}".format(project, model)
        if version is not None:
            name += "/versions/{}".format(version)

        if self.config.get("uses_json_input") or force_json:
            examples_for_predict = self._json_from_tf_examples(exs)
        else:
            # Default wire format: base64-encoded serialized tf.Examples.
            examples_for_predict = [
                {
                    "b64": base64.b64encode(
                        example.SerializeToString()
                    ).decode("utf-8")
                }
                for example in exs
            ]

        # If there is a user-specified input example adjustment to make,
        # make it.
        if adjust_example:
            examples_for_predict = [
                adjust_example(ex) for ex in examples_for_predict
            ]

        # Send request, including custom user-agent for tracking.
        request_builder = service.projects().predict(
            name=name, body={"instances": examples_for_predict}
        )
        user_agent = request_builder.headers.get("user-agent")
        request_builder.headers["user-agent"] = (
            USER_AGENT_FOR_CAIP_TRACKING
            + ("-" + user_agent if user_agent else "")
        )
        try:
            response = request_builder.execute()
        except Exception as e:
            error_during_prediction = True
            response = {"error": str(e)}

        # Get the attributions and baseline score if explanation is
        # enabled.
        if should_explain and not error_during_prediction:
            try:
                request_builder = service.projects().explain(
                    name=name, body={"instances": examples_for_predict}
                )
                request_builder.headers["user-agent"] = (
                    USER_AGENT_FOR_CAIP_TRACKING
                    + ("-" + user_agent if user_agent else "")
                )
                explain_response = request_builder.execute()
                explanations = [
                    explain["attributions_by_label"][0]["attributions"]
                    for explain in explain_response["explanations"]
                ]
                baseline_scores = []
                for i, explain in enumerate(explanations):
                    baseline_scores.append(
                        explain_response["explanations"][i][
                            "attributions_by_label"
                        ][0]["baseline_score"]
                    )
                response.update(
                    {
                        "explanations": explanations,
                        "baseline_scores": baseline_scores,
                    }
                )
            except Exception:
                # Explanations are best-effort: predictions are still
                # returned even when the explain call fails.
                pass
        return response

    def chunks(l, n):
        """Yield successive n-sized chunks from l."""
        for i in range(0, len(l), n):
            yield l[i : i + n]

    # Run prediction in batches in threads.
    if batch_size is None:
        batch_size = len(examples)
    batched_examples = list(chunks(examples, batch_size))

    pool = multiprocessing.pool.ThreadPool(processes=POOL_SIZE)
    responses = pool.map(predict, batched_examples)
    pool.close()
    pool.join()

    for response in responses:
        if "error" in response:
            raise RuntimeError(response["error"])

    # Parse the results from the responses and return them.
    all_predictions = []
    all_baseline_scores = []
    all_attributions = []

    for response in responses:
        if "explanations" in response:
            # If an attribution adjustment function was provided, use it
            # to adjust the attributions.
            if adjust_attribution is not None:
                all_attributions.extend(
                    [
                        adjust_attribution(attr)
                        for attr in response["explanations"]
                    ]
                )
            else:
                all_attributions.extend(response["explanations"])

        if "baseline_scores" in response:
            all_baseline_scores.extend(response["baseline_scores"])

        # Use the specified key if one is provided.
        key_to_use = self.config.get("predict_output_tensor")

        for pred in response["predictions"]:
            # If the prediction contains a key to fetch the prediction,
            # use it.
            if isinstance(pred, dict):
                if key_to_use is None:
                    # If the dictionary only contains one key, use it.
                    returned_keys = list(pred.keys())
                    if len(returned_keys) == 1:
                        key_to_use = returned_keys[0]
                    # Use a default key if necessary.
                    elif self.config.get("model_type") == "classification":
                        key_to_use = "probabilities"
                    else:
                        key_to_use = "outputs"

                if key_to_use not in pred:
                    raise KeyError(
                        '"%s" not found in model predictions dictionary'
                        % key_to_use
                    )

                pred = pred[key_to_use]

            # If the model is regression and the response is a list,
            # extract the score by taking the first element.
            if self.config.get("model_type") == "regression" and isinstance(
                pred, list
            ):
                pred = pred[0]

            # If a prediction adjustment function was provided, use it to
            # adjust the prediction.
            if adjust_prediction:
                pred = adjust_prediction(pred)

            # If the model is classification and the response is a single
            # number, treat that as the positive class score for a binary
            # classification and convert it into a list of those two class
            # scores. WIT only accepts lists of class scores as results
            # from classification models.
            if self.config.get("model_type") == "classification":
                if not isinstance(pred, list):
                    pred = [pred]
                if len(pred) == 1:
                    pred = [1 - pred[0], pred[0]]

            all_predictions.append(pred)

    results = {"predictions": all_predictions}
    if all_attributions:
        results.update({"attributions": all_attributions})
    if all_baseline_scores:
        results.update({"baseline_score": all_baseline_scores})
    return results


def create_selection_callback(self, examples, max_examples):
    """Returns an example selection callback for use with TFMA.

    The returned function can be provided as an event handler for a TFMA
    visualization to dynamically load examples matching a selected slice
    into WIT.

    Args:
      examples: A list of tf.Examples to filter and use with WIT.
      max_examples: The maximum number of examples to create.
    """

    def handle_selection(selected):
        def extract_values(feat):
            # Pull the raw values out of whichever typed list is set.
            if feat.HasField("bytes_list"):
                return [v.decode("utf-8") for v in feat.bytes_list.value]
            elif feat.HasField("int64_list"):
                return feat.int64_list.value
            elif feat.HasField("float_list"):
                return feat.float_list.value
            return None

        filtered_examples = []
        for ex in examples:
            if selected["sliceName"] == "Overall":
                filtered_examples.append(ex)
            else:
                values = extract_values(
                    ex.features.feature[selected["sliceName"]]
                )
                if selected["sliceValue"] in values:
                    filtered_examples.append(ex)
            if len(filtered_examples) == max_examples:
                break

        self.set_examples(filtered_examples)

    return handle_selection
# Python bridge functions invoked from the WIT frontend javascript via
# colab's callback mechanism. Each dispatches to the WitWidget instance
# identified by `wit_id`.


def infer_examples(wit_id):
    WitWidget.widgets[wit_id].infer()


def delete_example(wit_id, index):
    WitWidget.widgets[wit_id].delete_example(index)


def duplicate_example(wit_id, index):
    WitWidget.widgets[wit_id].duplicate_example(index)


def update_example(wit_id, index, example):
    WitWidget.widgets[wit_id].update_example(index, example)


def get_eligible_features(wit_id):
    WitWidget.widgets[wit_id].get_eligible_features()


def sort_eligible_features(wit_id, details):
    WitWidget.widgets[wit_id].sort_eligible_features(details)


def infer_mutants(wit_id, details):
    WitWidget.widgets[wit_id].infer_mutants(details)


def compute_custom_distance(wit_id, index, callback_name, params):
    WitWidget.widgets[wit_id].compute_custom_distance(
        index, callback_name, params
    )


# Register every bridge function under its frontend callback name.
for _callback_name, _callback_fn in (
    ("notebook.InferExamples", infer_examples),
    ("notebook.DeleteExample", delete_example),
    ("notebook.DuplicateExample", duplicate_example),
    ("notebook.UpdateExample", update_example),
    ("notebook.GetEligibleFeatures", get_eligible_features),
    ("notebook.SortEligibleFeatures", sort_eligible_features),
    ("notebook.InferMutants", infer_mutants),
    ("notebook.ComputeCustomDistance", compute_custom_distance),
):
    output.register_callback(_callback_name, _callback_fn)
del _callback_name, _callback_fn
- display.display(display.HTML(self._get_element_html())) - display.display(display.HTML( - WIT_HTML.format(height=height, id=self.id))) - - # Increment the static instance WitWidget index counter - WitWidget.index += 1 - - # Send the provided config and examples to JS - output.eval_js("""configCallback({config})""".format( - config=json.dumps(self.config))) - output.eval_js("""updateExamplesCallback({examples})""".format( - examples=json.dumps(self.examples))) - self._generate_sprite() - self._ctor_complete = True - - def _get_element_html(self): - return """ + """WIT widget for colab.""" + + # Static instance list of constructed WitWidgets so python global functions + # can call into instances of this object + widgets = [] + + # Static instance index to keep track of ID number of each constructed + # WitWidget. + index = 0 + + def __init__(self, config_builder, height=1000): + """Constructor for colab notebook WitWidget. + + Args: + config_builder: WitConfigBuilder object containing settings for WIT. + height: Optional height in pixels for WIT to occupy. Defaults to 1000. + """ + self._ctor_complete = False + self.id = WitWidget.index + base.WitWidgetBase.__init__(self, config_builder) + # Add this instance to the static instance list. + WitWidget.widgets.append(self) + + # Display WIT Polymer element. 
+ display.display(display.HTML(self._get_element_html())) + display.display( + display.HTML(WIT_HTML.format(height=height, id=self.id)) + ) + + # Increment the static instance WitWidget index counter + WitWidget.index += 1 + + # Send the provided config and examples to JS + output.eval_js( + """configCallback({config})""".format( + config=json.dumps(self.config) + ) + ) + output.eval_js( + """updateExamplesCallback({examples})""".format( + examples=json.dumps(self.examples) + ) + ) + self._generate_sprite() + self._ctor_complete = True + + def _get_element_html(self): + return """ """ - def set_examples(self, examples): - base.WitWidgetBase.set_examples(self, examples) - # If this is called outside of the ctor, use a BroadcastChannel to send - # the updated examples to the visualization. Inside of the ctor, no action - # is necessary as the ctor handles all communication. - if self._ctor_complete: - # Use BroadcastChannel to allow this call to be made in a separate colab - # cell from the cell that displays WIT. - channel_name = 'updateExamples{}'.format(self.id) - output.eval_js("""(new BroadcastChannel('{channel_name}')).postMessage( + def set_examples(self, examples): + base.WitWidgetBase.set_examples(self, examples) + # If this is called outside of the ctor, use a BroadcastChannel to send + # the updated examples to the visualization. Inside of the ctor, no action + # is necessary as the ctor handles all communication. + if self._ctor_complete: + # Use BroadcastChannel to allow this call to be made in a separate colab + # cell from the cell that displays WIT. 
+ channel_name = "updateExamples{}".format(self.id) + output.eval_js( + """(new BroadcastChannel('{channel_name}')).postMessage( {examples})""".format( - examples=json.dumps(self.examples), channel_name=channel_name)) - self._generate_sprite() - - def infer(self): - try: - inferences = base.WitWidgetBase.infer_impl(self) - output.eval_js("""inferenceCallback({inferences})""".format( - inferences=json.dumps(inferences))) - except Exception as e: - output.eval_js("""backendError({error})""".format( - error=json.dumps({'msg': repr(e)}))) - - def delete_example(self, index): - self.examples.pop(index) - self.updated_example_indices = set([ - i if i < index else i - 1 for i in self.updated_example_indices]) - self._generate_sprite() - - def update_example(self, index, example): - self.updated_example_indices.add(index) - self.examples[index] = example - self._generate_sprite() - - def duplicate_example(self, index): - self.examples.append(self.examples[index]) - self.updated_example_indices.add(len(self.examples) - 1) - self._generate_sprite() - - def compute_custom_distance(self, index, callback_fn, params): - try: - distances = base.WitWidgetBase.compute_custom_distance_impl( - self, index, params['distanceParams']) - callback_dict = { - 'distances': distances, - 'exInd': index, - 'funId': callback_fn, - 'params': params['callbackParams'] - } - output.eval_js("""distanceCallback({callback_dict})""".format( - callback_dict=json.dumps(callback_dict))) - except Exception as e: - output.eval_js( - """backendError({error})""".format( - error=json.dumps({'msg': repr(e)}))) - - def get_eligible_features(self): - features_list = base.WitWidgetBase.get_eligible_features_impl(self) - output.eval_js("""eligibleFeaturesCallback({features_list})""".format( - features_list=json.dumps(features_list))) - - def infer_mutants(self, info): - try: - json_mapping = base.WitWidgetBase.infer_mutants_impl(self, info) - output.eval_js("""inferMutantsCallback({json_mapping})""".format( - 
json_mapping=json.dumps(json_mapping))) - except Exception as e: - output.eval_js("""backendError({error})""".format( - error=json.dumps({'msg': repr(e)}))) - - def sort_eligible_features(self, info): - try: - features_list = base.WitWidgetBase.sort_eligible_features_impl(self, info) - output.eval_js("""sortEligibleFeaturesCallback({features_list})""".format( - features_list=json.dumps(features_list))) - except Exception as e: - output.eval_js("""backendError({error})""".format( - error=json.dumps({'msg': repr(e)}))) - - def _generate_sprite(self): - sprite = base.WitWidgetBase.create_sprite(self) - if sprite is not None: - output.eval_js("""spriteCallback('{sprite}')""".format(sprite=sprite)) + examples=json.dumps(self.examples), + channel_name=channel_name, + ) + ) + self._generate_sprite() + + def infer(self): + try: + inferences = base.WitWidgetBase.infer_impl(self) + output.eval_js( + """inferenceCallback({inferences})""".format( + inferences=json.dumps(inferences) + ) + ) + except Exception as e: + output.eval_js( + """backendError({error})""".format( + error=json.dumps({"msg": repr(e)}) + ) + ) + + def delete_example(self, index): + self.examples.pop(index) + self.updated_example_indices = set( + [i if i < index else i - 1 for i in self.updated_example_indices] + ) + self._generate_sprite() + + def update_example(self, index, example): + self.updated_example_indices.add(index) + self.examples[index] = example + self._generate_sprite() + + def duplicate_example(self, index): + self.examples.append(self.examples[index]) + self.updated_example_indices.add(len(self.examples) - 1) + self._generate_sprite() + + def compute_custom_distance(self, index, callback_fn, params): + try: + distances = base.WitWidgetBase.compute_custom_distance_impl( + self, index, params["distanceParams"] + ) + callback_dict = { + "distances": distances, + "exInd": index, + "funId": callback_fn, + "params": params["callbackParams"], + } + output.eval_js( + 
"""distanceCallback({callback_dict})""".format( + callback_dict=json.dumps(callback_dict) + ) + ) + except Exception as e: + output.eval_js( + """backendError({error})""".format( + error=json.dumps({"msg": repr(e)}) + ) + ) + + def get_eligible_features(self): + features_list = base.WitWidgetBase.get_eligible_features_impl(self) + output.eval_js( + """eligibleFeaturesCallback({features_list})""".format( + features_list=json.dumps(features_list) + ) + ) + + def infer_mutants(self, info): + try: + json_mapping = base.WitWidgetBase.infer_mutants_impl(self, info) + output.eval_js( + """inferMutantsCallback({json_mapping})""".format( + json_mapping=json.dumps(json_mapping) + ) + ) + except Exception as e: + output.eval_js( + """backendError({error})""".format( + error=json.dumps({"msg": repr(e)}) + ) + ) + + def sort_eligible_features(self, info): + try: + features_list = base.WitWidgetBase.sort_eligible_features_impl( + self, info + ) + output.eval_js( + """sortEligibleFeaturesCallback({features_list})""".format( + features_list=json.dumps(features_list) + ) + ) + except Exception as e: + output.eval_js( + """backendError({error})""".format( + error=json.dumps({"msg": repr(e)}) + ) + ) + + def _generate_sprite(self): + sprite = base.WitWidgetBase.create_sprite(self) + if sprite is not None: + output.eval_js( + """spriteCallback('{sprite}')""".format(sprite=sprite) + ) diff --git a/tensorboard/plugins/interactive_inference/witwidget/notebook/jupyter/wit.py b/tensorboard/plugins/interactive_inference/witwidget/notebook/jupyter/wit.py index 19b1d92532..aef9ddbfd7 100644 --- a/tensorboard/plugins/interactive_inference/witwidget/notebook/jupyter/wit.py +++ b/tensorboard/plugins/interactive_inference/witwidget/notebook/jupyter/wit.py @@ -27,124 +27,128 @@ @widgets.register class WitWidget(widgets.DOMWidget, base.WitWidgetBase): - """WIT widget for Jupyter.""" - _view_name = Unicode('WITView').tag(sync=True) - _view_module = Unicode('wit-widget').tag(sync=True) - 
_view_module_version = Unicode('^0.1.0').tag(sync=True) - - # Traitlets for communicating between python and javascript. - config = Dict(dict()).tag(sync=True) - examples = List([]).tag(sync=True) - inferences = Dict(dict()).tag(sync=True) - infer = Int(0).tag(sync=True) - update_example = Dict(dict()).tag(sync=True) - delete_example = Dict(dict()).tag(sync=True) - duplicate_example = Dict(dict()).tag(sync=True) - updated_example_indices = Set(set()) - get_eligible_features = Int(0).tag(sync=True) - sort_eligible_features = Dict(dict()).tag(sync=True) - eligible_features = List([]).tag(sync=True) - infer_mutants = Dict(dict()).tag(sync=True) - mutant_charts = Dict([]).tag(sync=True) - mutant_charts_counter = Int(0) - sprite = Unicode('').tag(sync=True) - error = Dict(dict()).tag(sync=True) - compute_custom_distance = Dict(dict()).tag(sync=True) - custom_distance_dict = Dict(dict()).tag(sync=True) - - def __init__(self, config_builder, height=1000): - """Constructor for Jupyter notebook WitWidget. - - Args: - config_builder: WitConfigBuilder object containing settings for WIT. - height: Optional height in pixels for WIT to occupy. Defaults to 1000. - """ - widgets.DOMWidget.__init__(self, layout=Layout(height='%ipx' % height)) - base.WitWidgetBase.__init__(self, config_builder) - self.error_counter = 0 - - # Ensure the visualization takes all available width. - display(HTML("")) - - def set_examples(self, examples): - base.WitWidgetBase.set_examples(self, examples) - self._generate_sprite() - - def _report_error(self, err): - self.error = { - 'msg': repr(err), - 'counter': self.error_counter - } - self.error_counter += 1 - - @observe('infer') - def _infer(self, change): - try: - self.inferences = base.WitWidgetBase.infer_impl(self) - except Exception as e: - self._report_error(e) - - # Observer callbacks for changes from javascript. 
- @observe('get_eligible_features') - def _get_eligible_features(self, change): - features_list = base.WitWidgetBase.get_eligible_features_impl(self) - self.eligible_features = features_list - - @observe('sort_eligible_features') - def _sort_eligible_features(self, change): - info = self.sort_eligible_features - features_list = base.WitWidgetBase.sort_eligible_features_impl(self, info) - self.eligible_features = features_list - - @observe('infer_mutants') - def _infer_mutants(self, change): - info = self.infer_mutants - try: - json_mapping = base.WitWidgetBase.infer_mutants_impl(self, info) - json_mapping['counter'] = self.mutant_charts_counter - self.mutant_charts_counter += 1 - self.mutant_charts = json_mapping - except Exception as e: - self._report_error(e) - - @observe('update_example') - def _update_example(self, change): - index = self.update_example['index'] - self.updated_example_indices.add(index) - self.examples[index] = self.update_example['example'] - self._generate_sprite() - - @observe('duplicate_example') - def _duplicate_example(self, change): - self.examples.append(self.examples[self.duplicate_example['index']]) - self.updated_example_indices.add(len(self.examples) - 1) - self._generate_sprite() - - @observe('delete_example') - def _delete_example(self, change): - index = self.delete_example['index'] - self.examples.pop(index) - self.updated_example_indices = set([ - i if i < index else i - 1 for i in self.updated_example_indices]) - self._generate_sprite() - - @observe('compute_custom_distance') - def _compute_custom_distance(self, change): - info = self.compute_custom_distance - index = info['index'] - params = info['params'] - callback_fn = info['callback'] - try: - distances = base.WitWidgetBase.compute_custom_distance_impl(self, index, - params['distanceParams']) - self.custom_distance_dict = {'distances': distances, - 'exInd': index, - 'funId': callback_fn, - 'params': params['callbackParams']} - except Exception as e: - 
self._report_error(e) - - def _generate_sprite(self): - sprite = base.WitWidgetBase.create_sprite(self) - if sprite is not None: - self.sprite = sprite + """WIT widget for Jupyter.""" + + _view_name = Unicode("WITView").tag(sync=True) + _view_module = Unicode("wit-widget").tag(sync=True) + _view_module_version = Unicode("^0.1.0").tag(sync=True) + + # Traitlets for communicating between python and javascript. + config = Dict(dict()).tag(sync=True) + examples = List([]).tag(sync=True) + inferences = Dict(dict()).tag(sync=True) + infer = Int(0).tag(sync=True) + update_example = Dict(dict()).tag(sync=True) + delete_example = Dict(dict()).tag(sync=True) + duplicate_example = Dict(dict()).tag(sync=True) + updated_example_indices = Set(set()) + get_eligible_features = Int(0).tag(sync=True) + sort_eligible_features = Dict(dict()).tag(sync=True) + eligible_features = List([]).tag(sync=True) + infer_mutants = Dict(dict()).tag(sync=True) + mutant_charts = Dict([]).tag(sync=True) + mutant_charts_counter = Int(0) + sprite = Unicode("").tag(sync=True) + error = Dict(dict()).tag(sync=True) + compute_custom_distance = Dict(dict()).tag(sync=True) + custom_distance_dict = Dict(dict()).tag(sync=True) + + def __init__(self, config_builder, height=1000): + """Constructor for Jupyter notebook WitWidget. + + Args: + config_builder: WitConfigBuilder object containing settings for WIT. + height: Optional height in pixels for WIT to occupy. Defaults to 1000. + """ + widgets.DOMWidget.__init__(self, layout=Layout(height="%ipx" % height)) + base.WitWidgetBase.__init__(self, config_builder) + self.error_counter = 0 + + # Ensure the visualization takes all available width. 
+ display(HTML("")) + + def set_examples(self, examples): + base.WitWidgetBase.set_examples(self, examples) + self._generate_sprite() + + def _report_error(self, err): + self.error = {"msg": repr(err), "counter": self.error_counter} + self.error_counter += 1 + + @observe("infer") + def _infer(self, change): + try: + self.inferences = base.WitWidgetBase.infer_impl(self) + except Exception as e: + self._report_error(e) + + # Observer callbacks for changes from javascript. + @observe("get_eligible_features") + def _get_eligible_features(self, change): + features_list = base.WitWidgetBase.get_eligible_features_impl(self) + self.eligible_features = features_list + + @observe("sort_eligible_features") + def _sort_eligible_features(self, change): + info = self.sort_eligible_features + features_list = base.WitWidgetBase.sort_eligible_features_impl( + self, info + ) + self.eligible_features = features_list + + @observe("infer_mutants") + def _infer_mutants(self, change): + info = self.infer_mutants + try: + json_mapping = base.WitWidgetBase.infer_mutants_impl(self, info) + json_mapping["counter"] = self.mutant_charts_counter + self.mutant_charts_counter += 1 + self.mutant_charts = json_mapping + except Exception as e: + self._report_error(e) + + @observe("update_example") + def _update_example(self, change): + index = self.update_example["index"] + self.updated_example_indices.add(index) + self.examples[index] = self.update_example["example"] + self._generate_sprite() + + @observe("duplicate_example") + def _duplicate_example(self, change): + self.examples.append(self.examples[self.duplicate_example["index"]]) + self.updated_example_indices.add(len(self.examples) - 1) + self._generate_sprite() + + @observe("delete_example") + def _delete_example(self, change): + index = self.delete_example["index"] + self.examples.pop(index) + self.updated_example_indices = set( + [i if i < index else i - 1 for i in self.updated_example_indices] + ) + self._generate_sprite() + + 
@observe("compute_custom_distance") + def _compute_custom_distance(self, change): + info = self.compute_custom_distance + index = info["index"] + params = info["params"] + callback_fn = info["callback"] + try: + distances = base.WitWidgetBase.compute_custom_distance_impl( + self, index, params["distanceParams"] + ) + self.custom_distance_dict = { + "distances": distances, + "exInd": index, + "funId": callback_fn, + "params": params["callbackParams"], + } + except Exception as e: + self._report_error(e) + + def _generate_sprite(self): + sprite = base.WitWidgetBase.create_sprite(self) + if sprite is not None: + self.sprite = sprite diff --git a/tensorboard/plugins/interactive_inference/witwidget/notebook/visualization.py b/tensorboard/plugins/interactive_inference/witwidget/notebook/visualization.py index 98ca6c9de1..e077a3a242 100644 --- a/tensorboard/plugins/interactive_inference/witwidget/notebook/visualization.py +++ b/tensorboard/plugins/interactive_inference/witwidget/notebook/visualization.py @@ -18,673 +18,706 @@ def _is_colab(): - return "google.colab" in sys.modules + return "google.colab" in sys.modules if _is_colab(): - from witwidget.notebook.colab.wit import * # pylint: disable=wildcard-import + from witwidget.notebook.colab.wit import * # pylint: disable=wildcard-import else: - from witwidget.notebook.jupyter.wit import * # pylint: disable=wildcard-import + from witwidget.notebook.jupyter.wit import * # pylint: disable=wildcard-import class WitConfigBuilder(object): - """Configuration builder for WitWidget settings.""" - - def __init__(self, examples, feature_names=None): - """Constructs the WitConfigBuilder object. - - Args: - examples: A list of tf.Example or tf.SequenceExample proto objects, or - raw JSON objects. JSON is allowed only for AI Platform-hosted models (see - 'set_ai_platform_model' and 'set_compare_ai_platform_model methods). - These are the examples that will be displayed in WIT. 
If no model to - infer these examples with is specified through the methods on this class, - then WIT will display the examples for exploration, but no model inference - will be performed by the tool. - feature_names: Optional, defaults to None. If examples are provided as - JSON lists of numbers (not as feature dictionaries), then this array - maps indices in the feature value lists to human-readable names of those - features, used for display purposes. - """ - self.config = {} - self.set_model_type('classification') - self.set_label_vocab([]) - self.set_examples(examples, feature_names) - - def build(self): - """Returns the configuration set through use of this builder object. - - Used by WitWidget to set the settings on an instance of the What-If Tool. - """ - return self.config - - def store(self, key, value): - self.config[key] = value - - def delete(self, key): - if key in self.config: - del self.config[key] - - def set_examples(self, examples, feature_names=None): - """Sets the examples to be displayed in WIT. - - Args: - examples: List of example protos or JSON objects. - feature_names: Optional, defaults to None. If examples are provided as - JSON lists of numbers (not as feature dictionaries), then this array - maps indices in the feature value lists to human-readable names of those - features, used just for display purposes. - - Returns: - self, in order to enabled method chaining. - """ - self.store('examples', examples) - if feature_names: - self.store('feature_names', feature_names) - if len(examples) > 0 and not ( - isinstance(examples[0], tf.train.Example) or - isinstance(examples[0], tf.train.SequenceExample)): - self._set_uses_json_input(True) - if isinstance(examples[0], list): - self._set_uses_json_list(True) - elif len(examples) > 0: - self.store('are_sequence_examples', - isinstance(examples[0], tf.train.SequenceExample)) - return self - - def set_model_type(self, model): - """Sets the type of the model being used for inference. 
- - Args: - model: The model type, such as "classification" or "regression". - The model type defaults to "classification". - - Returns: - self, in order to enabled method chaining. - """ - self.store('model_type', model) - return self - - def set_inference_address(self, address): - """Sets the inference address for model inference through TF Serving. - - Args: - address: The address of the served model, including port, such as - "localhost:8888". - - Returns: - self, in order to enabled method chaining. - """ - self.store('inference_address', address) - return self - - def set_model_name(self, name): - """Sets the model name for model inference through TF Serving. - - Setting a model name is required if inferring through a model hosted by - TF Serving. - - Args: - name: The name of the model to be queried through TF Serving at the - address provided by set_inference_address. - - Returns: - self, in order to enabled method chaining. - """ - self.store('model_name', name) - return self - - def has_model_name(self): - return 'model_name' in self.config - - def set_model_version(self, version): - """Sets the optional model version for model inference through TF Serving. - - Args: - version: The string version number of the model to be queried through TF - Serving. This is optional, as TF Serving will use the latest model version - if none is provided. - - Returns: - self, in order to enabled method chaining. - """ - self.store('model_version', version) - return self - - def set_model_signature(self, signature): - """Sets the optional model signature for model inference through TF Serving. - - Args: - signature: The string signature of the model to be queried through TF - Serving. This is optional, as TF Serving will use the default model - signature if none is provided. - - Returns: - self, in order to enabled method chaining. 
- """ - self.store('model_signature', signature) - return self - - def set_compare_inference_address(self, address): - """Sets the inference address for model inference for a second model hosted - by TF Serving. - - If you wish to compare the results of two models in WIT, use this method - to setup the details of the second model. - - Args: - address: The address of the served model, including port, such as - "localhost:8888". - - Returns: - self, in order to enabled method chaining. - """ - self.store('inference_address_2', address) - return self - - def set_compare_model_name(self, name): - """Sets the model name for a second model hosted by TF Serving. - - If you wish to compare the results of two models in WIT, use this method - to setup the details of the second model. - - Setting a model name is required if inferring through a model hosted by - TF Serving. - - Args: - name: The name of the model to be queried through TF Serving at the - address provided by set_compare_inference_address. - - Returns: - self, in order to enabled method chaining. - """ - self.store('model_name_2', name) - return self - - def has_compare_model_name(self): - return 'model_name_2' in self.config - - def set_compare_model_version(self, version): - """Sets the optional model version for a second model hosted by TF Serving. - - If you wish to compare the results of two models in WIT, use this method - to setup the details of the second model. - - Args: - version: The string version number of the model to be queried through TF - Serving. This is optional, as TF Serving will use the latest model version - if none is provided. - - Returns: - self, in order to enabled method chaining. - """ - self.store('model_version_2', version) - return self - - def set_compare_model_signature(self, signature): - """Sets the optional model signature for a second model hosted by TF - Serving. 
- - If you wish to compare the results of two models in WIT, use this method - to setup the details of the second model. - - Args: - signature: The string signature of the model to be queried through TF - Serving. This is optional, as TF Serving will use the default model - signature if none is provided. - - Returns: - self, in order to enabled method chaining. - """ - self.store('model_signature_2', signature) - return self - - def set_uses_predict_api(self, predict): - """Indicates that the model uses the Predict API, as opposed to the - Classification or Regression API. - - If the model doesn't use the standard Classification or Regression APIs - provided through TF Serving, but instead uses the more flexible Predict API, - then use this method to indicate that. If this is true, then use the - set_predict_input_tensor and set_predict_output_tensor methods to indicate - the names of the tensors that are used as the input and output for the - models provided in order to perform the appropriate inference request. - - Args: - predict: True if the model or models use the Predict API. - - Returns: - self, in order to enabled method chaining. - """ - self.store('uses_predict_api', predict) - return self - - def set_max_classes_to_display(self, max_classes): - """Sets the maximum number of class results to display for multiclass - classification models. - - When using WIT with a multiclass model with a large number of possible - classes, it can be helpful to restrict WIT to only display some smaller - number of the highest-scoring classes as inference results for any given - example. This method sets that limit. - - Args: - max_classes: The maximum number of classes to display for inference - results for multiclass classification models. - - Returns: - self, in order to enabled method chaining. - """ - self.store('max_classes', max_classes) - return self - - def set_multi_class(self, multiclass): - """Sets if the model(s) to query are mutliclass classification models. 
- - Args: - multiclass: True if the model or models are multiclass classififcation - models. Defaults to false. - - Returns: - self, in order to enabled method chaining. - """ - self.store('multiclass', multiclass) - return self - - def set_predict_input_tensor(self, tensor): - """Sets the name of the input tensor for models that use the Predict API. - - If using WIT with set_uses_predict_api(True), then call this to specify - the name of the input tensor of the model or models that accepts the - example proto for inference. - - Args: - tensor: The name of the input tensor. - - Returns: - self, in order to enabled method chaining. - """ - self.store('predict_input_tensor', tensor) - return self - - def set_predict_output_tensor(self, tensor): - """Sets the name of the output tensor for models that need output parsing. - - If using WIT with set_uses_predict_api(True), then call this to specify - the name of the output tensor of the model or models that returns the - inference results to be explored by WIT. - - If using an AI Platform model which returns multiple prediction - results in a dictionary, this method specifies the key corresponding to - the inference results to be explored by WIT. - - Args: - tensor: The name of the output tensor. - - Returns: - self, in order to enabled method chaining. - """ - self.store('predict_output_tensor', tensor) - return self - - def set_label_vocab(self, vocab): - """Sets the string value of numeric labels for classification models. - - For classification models, the model returns scores for each class ID - number (classes 0 and 1 for binary classification models). In order for - WIT to visually display the results in a more-readable way, you can specify - string labels for each class ID. - - Args: - vocab: A list of strings, where the string at each index corresponds to - the label for that class ID. For example ['<=50K', '>50K'] for the UCI - census binary classification task. 
- - Returns: - self, in order to enabled method chaining. - """ - self.store('label_vocab', vocab) - return self - - def set_estimator_and_feature_spec(self, estimator, feature_spec): - """Sets the model for inference as a TF Estimator. - - Instead of using TF Serving to host a model for WIT to query, WIT can - directly use a TF Estimator object as the model to query. In order to - accomplish this, a feature_spec must also be provided to parse the - example protos for input into the estimator. - - Args: - estimator: The TF Estimator which will be used for model inference. - feature_spec: The feature_spec object which will be used for example - parsing. - - Returns: - self, in order to enabled method chaining. - """ - # If custom function is set, remove it before setting estimator - self.delete('custom_predict_fn') - - self.store('estimator_and_spec', { - 'estimator': estimator, 'feature_spec': feature_spec}) - self.set_inference_address('estimator') - # If no model name has been set, give a default - if not self.has_model_name(): - self.set_model_name('1') - return self - - def set_compare_estimator_and_feature_spec(self, estimator, feature_spec): - """Sets a second model for inference as a TF Estimator. - - If you wish to compare the results of two models in WIT, use this method - to setup the details of the second model. - - Instead of using TF Serving to host a model for WIT to query, WIT can - directly use a TF Estimator object as the model to query. In order to - accomplish this, a feature_spec must also be provided to parse the - example protos for input into the estimator. - - Args: - estimator: The TF Estimator which will be used for model inference. - feature_spec: The feature_spec object which will be used for example - parsing. - - Returns: - self, in order to enabled method chaining. 
- """ - # If custom function is set, remove it before setting estimator - self.delete('compare_custom_predict_fn') - - self.store('compare_estimator_and_spec', { - 'estimator': estimator, 'feature_spec': feature_spec}) - self.set_compare_inference_address('estimator') - # If no model name has been set, give a default - if not self.has_compare_model_name(): - self.set_compare_model_name('2') - return self - - def set_custom_predict_fn(self, predict_fn): - """Sets a custom function for inference. - - Instead of using TF Serving to host a model for WIT to query, WIT can - directly use a custom function as the model to query. In this case, the - provided function should accept example protos and return: - - For classification: A 2D list of numbers. The first dimension is for - each example being predicted. The second dimension are the probabilities - for each class ID in the prediction. - - For regression: A 1D list of numbers, with a regression score for each - example being predicted. - - Optionally, if attributions or other prediction-time information - can be returned by the model with each prediction, then this method - can return a dict with the key 'predictions' containing the predictions - result list described above, and with the key 'attributions' containing - a list of attributions for each example that was predicted. - - For each example, the attributions list should contain a dict mapping - input feature names to attribution values for that feature on that example. - The attribution value can be one of these things: - - A single number representing the attribution for the entire feature - - A list of numbers representing the attribution to each value in the - feature for multivalent features - such as attributions to individual - pixels in an image or numbers in a list of numbers. - - A 2D list for sparse feature attribution. Index 0 contains a list of - feature values that there are attribution scores for. 
Index 1 contains - a list of attribution values for the corresponding feature values in - the first list. - - This dict can contain any other keys, with their values being a list of - prediction-time strings or numbers for each example being predicted. These - values will be displayed in WIT as extra information for each example, - usable in the same ways by WIT as normal input features (such as for - creating plots and slicing performance data). - - Args: - predict_fn: The custom python function which will be used for model - inference. - - Returns: - self, in order to enabled method chaining. - """ - # If estimator is set, remove it before setting predict_fn - self.delete('estimator_and_spec') - - self.store('custom_predict_fn', predict_fn) - self.set_inference_address('custom_predict_fn') - # If no model name has been set, give a default - if not self.has_model_name(): - self.set_model_name('1') - return self - - def set_compare_custom_predict_fn(self, predict_fn): - """Sets a second custom function for inference. - - If you wish to compare the results of two models in WIT, use this method - to setup the details of the second model. - - Instead of using TF Serving to host a model for WIT to query, WIT can - directly use a custom function as the model to query. In this case, the - provided function should accept example protos and return: - - For classification: A 2D list of numbers. The first dimension is for - each example being predicted. The second dimension are the probabilities - for each class ID in the prediction. - - For regression: A 1D list of numbers, with a regression score for each - example being predicted. 
- - Optionally, if attributions or other prediction-time information - can be returned by the model with each prediction, then this method - can return a dict with the key 'predictions' containing the predictions - result list described above, and with the key 'attributions' containing - a list of attributions for each example that was predicted. - - For each example, the attributions list should contain a dict mapping - input feature names to attribution values for that feature on that example. - The attribution value can be one of these things: - - A single number representing the attribution for the entire feature - - A list of numbers representing the attribution to each value in the - feature for multivalent features - such as attributions to individual - pixels in an image or numbers in a list of numbers. - - A 2D list for sparse feature attribution. Index 0 contains a list of - feature values that there are attribution scores for. Index 1 contains - a list of attribution values for the corresponding feature values in - the first list. - - This dict can contain any other keys, with their values being a list of - prediction-time strings or numbers for each example being predicted. These - values will be displayed in WIT as extra information for each example, - usable in the same ways by WIT as normal input features (such as for - creating plots and slicing performance data). - - Args: - predict_fn: The custom python function which will be used for model - inference. - - Returns: - self, in order to enabled method chaining. 
- """ - # If estimator is set, remove it before setting predict_fn - self.delete('compare_estimator_and_spec') - - self.store('compare_custom_predict_fn', predict_fn) - self.set_compare_inference_address('custom_predict_fn') - # If no model name has been set, give a default - if not self.has_compare_model_name(): - self.set_compare_model_name('2') - return self - - def set_custom_distance_fn(self, distance_fn): - """Sets a custom function for distance computation. - - WIT can directly use a custom function for all distance computations within - the tool. In this case, the provided function should accept a query example - proto and a list of example protos to compute the distance against and - return a 1D list of numbers containing the distances. - - Args: - distance_fn: The python function which will be used for distance - computation. - - Returns: - self, in order to enabled method chaining. - """ - if distance_fn is None: - self.delete('custom_distance_fn') - else: - self.store('custom_distance_fn', distance_fn) - return self - - def set_ai_platform_model( - self, project, model, version=None, force_json_input=None, - adjust_prediction=None, adjust_example=None, adjust_attribution=None, - service_name='ml', service_version='v1', get_explanations=True, - batch_size=500, api_key=None): - """Sets the model information for a model served by AI Platform. - - AI Platform Prediction a Google Cloud serving platform. - - Args: - project: The name of the AI Platform Prediction project. - model: The name of the AI Platform Prediction model. - version: Optional, the version of the AI Platform Prediction model. - force_json_input: Optional. If True and examples are provided as - tf.Example protos, convert them to raw JSON objects before sending them - for inference to this model. - adjust_prediction: Optional. 
If not None then this function takes the - prediction output from the model for a single example and converts it to - the appopriate format - a regression score or a list of class scores. Only - necessary if the model doesn't already abide by this format. - adjust_example: Optional. If not None then this function takes an example - to run prediction on and converts it to the format expected by the model. - Necessary for example if the served model expects a single data value to - run inference on instead of a list or dict of values. - adjust_attribution: Optional. If not None and the model returns attribution - information, then this function takes the attribution information for an - example and converts it to the format expected by the tool, which is a - dictionary of input feature names to attribution scores. Usually necessary - if making use of adjust_example and the model returns attribution results. - service_name: Optional. Name of the AI Platform Prediction service. Defaults - to 'ml'. - service_version: Optional. Version of the AI Platform Prediction service. Defaults - to 'v1'. - get_explanations: Optional. If a model is deployed with explanations, - then this specifies if explainations will be calculated and displayed. - Defaults to True. - batch_size: Optional. Sets the individual batch size to send for - prediction. Defaults to 500. - api_key. Optional. A generated API key to send with the requests to AI - Platform. - - Returns: - self, in order to enabled method chaining. 
- """ - self.set_inference_address(project) - self.set_model_name(model) - self.store('use_aip', True) - self.store('aip_service_name', service_name) - self.store('aip_service_version', service_version) - self.store('aip_batch_size', batch_size) - self.store('get_explanations', get_explanations) - if version is not None: - self.set_model_signature(version) - if force_json_input: - self.store('force_json_input', True) - if adjust_prediction: - self.store('adjust_prediction', adjust_prediction) - if adjust_example: - self.store('adjust_example', adjust_example) - if adjust_attribution: - self.store('adjust_attribution', adjust_attribution) - if api_key: - self.store('aip_api_key', api_key) - return self - - def set_compare_ai_platform_model( - self, project, model, version=None, force_json_input=None, - adjust_prediction=None, adjust_example=None, adjust_attribution=None, - service_name='ml', service_version='v1', get_explanations=True, - batch_size=500, api_key=None): - """Sets the model information for a second model served by AI Platform. - - AI Platform Prediction a Google Cloud serving platform. - - Args: - project: The name of the AI Platform Prediction project. - model: The name of the AI Platform Prediction model. - version: Optional, the version of the AI Platform Prediction model. - force_json_input: Optional. If True and examples are provided as - tf.Example protos, convert them to raw JSON objects before sending them - for inference to this model. - adjust_prediction: Optional. If not None then this function takes the - prediction output from the model for a single example and converts it to - the appopriate format - a regression score or a list of class scores. Only - necessary if the model doesn't already abide by this format. - adjust_example: Optional. If not None then this function takes an example - to run prediction on and converts it to the format expected by the model. 
- Necessary for example if the served model expects a single data value to - run inference on instead of a list or dict of values. - adjust_attribution: Optional. If not None and the model returns attribution - information, then this function takes the attribution information for an - example and converts it to the format expected by the tool, which is a - dictionary of input feature names to attribution scores. Usually necessary - if making use of adjust_example and the model returns attribution results. - service_name: Optional. Name of the AI Platform Prediction service. Defaults - to 'ml'. - service_version: Optional. Version of the AI Platform Prediction service. Defaults - to 'v1'. - get_explanations: Optional. If a model is deployed with explanations, - then this specifies if explainations will be calculated and displayed. - Defaults to True. - batch_size: Optional. Sets the individual batch size to send for - prediction. Defaults to 500. - api_key. Optional. A generated API key to send with the requests to AI - Platform. - - Returns: - self, in order to enabled method chaining. 
- """ - self.set_compare_inference_address(project) - self.set_compare_model_name(model) - self.store('compare_use_aip', True) - self.store('compare_aip_service_name', service_name) - self.store('compare_aip_service_version', service_version) - self.store('compare_aip_batch_size', batch_size) - self.store('compare_get_explanations', get_explanations) - if version is not None: - self.set_compare_model_signature(version) - if force_json_input: - self.store('compare_force_json_input', True) - if adjust_prediction: - self.store('compare_adjust_prediction', adjust_prediction) - if adjust_example: - self.store('compare_adjust_example', adjust_example) - if adjust_attribution: - self.store('compare_adjust_attribution', adjust_attribution) - if api_key: - self.store('compare_aip_api_key', api_key) - return self - - def set_target_feature(self, target): - """Sets the name of the target feature in the provided examples. - - If the provided examples contain a feature that represents the target - that the model is trying to predict, it can be specified by this method. - This is necessary for AI Platform models so that the target feature isn't - sent to the model for prediction, which can cause model inference errors. - - Args: - target: The name of the feature in the examples that represents the value - that the model is trying to predict. - - Returns: - self, in order to enabled method chaining. - """ - self.store('target_feature', target) - return self - - def _set_uses_json_input(self, is_json): - self.store('uses_json_input', is_json) - return self - - def _set_uses_json_list(self, is_list): - self.store('uses_json_list', is_list) - return self + """Configuration builder for WitWidget settings.""" + + def __init__(self, examples, feature_names=None): + """Constructs the WitConfigBuilder object. + + Args: + examples: A list of tf.Example or tf.SequenceExample proto objects, or + raw JSON objects. 
JSON is allowed only for AI Platform-hosted models (see + 'set_ai_platform_model' and 'set_compare_ai_platform_model methods). + These are the examples that will be displayed in WIT. If no model to + infer these examples with is specified through the methods on this class, + then WIT will display the examples for exploration, but no model inference + will be performed by the tool. + feature_names: Optional, defaults to None. If examples are provided as + JSON lists of numbers (not as feature dictionaries), then this array + maps indices in the feature value lists to human-readable names of those + features, used for display purposes. + """ + self.config = {} + self.set_model_type("classification") + self.set_label_vocab([]) + self.set_examples(examples, feature_names) + + def build(self): + """Returns the configuration set through use of this builder object. + + Used by WitWidget to set the settings on an instance of the + What-If Tool. + """ + return self.config + + def store(self, key, value): + self.config[key] = value + + def delete(self, key): + if key in self.config: + del self.config[key] + + def set_examples(self, examples, feature_names=None): + """Sets the examples to be displayed in WIT. + + Args: + examples: List of example protos or JSON objects. + feature_names: Optional, defaults to None. If examples are provided as + JSON lists of numbers (not as feature dictionaries), then this array + maps indices in the feature value lists to human-readable names of those + features, used just for display purposes. + + Returns: + self, in order to enabled method chaining. 
+ """ + self.store("examples", examples) + if feature_names: + self.store("feature_names", feature_names) + if len(examples) > 0 and not ( + isinstance(examples[0], tf.train.Example) + or isinstance(examples[0], tf.train.SequenceExample) + ): + self._set_uses_json_input(True) + if isinstance(examples[0], list): + self._set_uses_json_list(True) + elif len(examples) > 0: + self.store( + "are_sequence_examples", + isinstance(examples[0], tf.train.SequenceExample), + ) + return self + + def set_model_type(self, model): + """Sets the type of the model being used for inference. + + Args: + model: The model type, such as "classification" or "regression". + The model type defaults to "classification". + + Returns: + self, in order to enabled method chaining. + """ + self.store("model_type", model) + return self + + def set_inference_address(self, address): + """Sets the inference address for model inference through TF Serving. + + Args: + address: The address of the served model, including port, such as + "localhost:8888". + + Returns: + self, in order to enabled method chaining. + """ + self.store("inference_address", address) + return self + + def set_model_name(self, name): + """Sets the model name for model inference through TF Serving. + + Setting a model name is required if inferring through a model hosted by + TF Serving. + + Args: + name: The name of the model to be queried through TF Serving at the + address provided by set_inference_address. + + Returns: + self, in order to enabled method chaining. + """ + self.store("model_name", name) + return self + + def has_model_name(self): + return "model_name" in self.config + + def set_model_version(self, version): + """Sets the optional model version for model inference through TF + Serving. + + Args: + version: The string version number of the model to be queried through TF + Serving. This is optional, as TF Serving will use the latest model version + if none is provided. 
+ + Returns: + self, in order to enabled method chaining. + """ + self.store("model_version", version) + return self + + def set_model_signature(self, signature): + """Sets the optional model signature for model inference through TF + Serving. + + Args: + signature: The string signature of the model to be queried through TF + Serving. This is optional, as TF Serving will use the default model + signature if none is provided. + + Returns: + self, in order to enabled method chaining. + """ + self.store("model_signature", signature) + return self + + def set_compare_inference_address(self, address): + """Sets the inference address for model inference for a second model + hosted by TF Serving. + + If you wish to compare the results of two models in WIT, use this method + to setup the details of the second model. + + Args: + address: The address of the served model, including port, such as + "localhost:8888". + + Returns: + self, in order to enabled method chaining. + """ + self.store("inference_address_2", address) + return self + + def set_compare_model_name(self, name): + """Sets the model name for a second model hosted by TF Serving. + + If you wish to compare the results of two models in WIT, use this method + to setup the details of the second model. + + Setting a model name is required if inferring through a model hosted by + TF Serving. + + Args: + name: The name of the model to be queried through TF Serving at the + address provided by set_compare_inference_address. + + Returns: + self, in order to enabled method chaining. + """ + self.store("model_name_2", name) + return self + + def has_compare_model_name(self): + return "model_name_2" in self.config + + def set_compare_model_version(self, version): + """Sets the optional model version for a second model hosted by TF + Serving. + + If you wish to compare the results of two models in WIT, use this method + to setup the details of the second model. 
+ + Args: + version: The string version number of the model to be queried through TF + Serving. This is optional, as TF Serving will use the latest model version + if none is provided. + + Returns: + self, in order to enabled method chaining. + """ + self.store("model_version_2", version) + return self + + def set_compare_model_signature(self, signature): + """Sets the optional model signature for a second model hosted by TF + Serving. + + If you wish to compare the results of two models in WIT, use this method + to setup the details of the second model. + + Args: + signature: The string signature of the model to be queried through TF + Serving. This is optional, as TF Serving will use the default model + signature if none is provided. + + Returns: + self, in order to enabled method chaining. + """ + self.store("model_signature_2", signature) + return self + + def set_uses_predict_api(self, predict): + """Indicates that the model uses the Predict API, as opposed to the + Classification or Regression API. + + If the model doesn't use the standard Classification or Regression APIs + provided through TF Serving, but instead uses the more flexible Predict API, + then use this method to indicate that. If this is true, then use the + set_predict_input_tensor and set_predict_output_tensor methods to indicate + the names of the tensors that are used as the input and output for the + models provided in order to perform the appropriate inference request. + + Args: + predict: True if the model or models use the Predict API. + + Returns: + self, in order to enabled method chaining. + """ + self.store("uses_predict_api", predict) + return self + + def set_max_classes_to_display(self, max_classes): + """Sets the maximum number of class results to display for multiclass + classification models. 
+
+        When using WIT with a multiclass model with a large number of possible
+        classes, it can be helpful to restrict WIT to only display some smaller
+        number of the highest-scoring classes as inference results for any given
+        example. This method sets that limit.
+
+        Args:
+          max_classes: The maximum number of classes to display for inference
+          results for multiclass classification models.
+
+        Returns:
+          self, in order to enabled method chaining.
+        """
+        self.store("max_classes", max_classes)
+        return self
+
+    def set_multi_class(self, multiclass):
+        """Sets if the model(s) to query are multiclass classification models.
+
+        Args:
+          multiclass: True if the model or models are multiclass classification
+          models. Defaults to false.
+
+        Returns:
+          self, in order to enabled method chaining.
+        """
+        self.store("multiclass", multiclass)
+        return self
+
+    def set_predict_input_tensor(self, tensor):
+        """Sets the name of the input tensor for models that use the Predict
+        API.
+
+        If using WIT with set_uses_predict_api(True), then call this to specify
+        the name of the input tensor of the model or models that accepts the
+        example proto for inference.
+
+        Args:
+          tensor: The name of the input tensor.
+
+        Returns:
+          self, in order to enabled method chaining.
+        """
+        self.store("predict_input_tensor", tensor)
+        return self
+
+    def set_predict_output_tensor(self, tensor):
+        """Sets the name of the output tensor for models that need output
+        parsing.
+
+        If using WIT with set_uses_predict_api(True), then call this to specify
+        the name of the output tensor of the model or models that returns the
+        inference results to be explored by WIT.
+
+        If using an AI Platform model which returns multiple prediction
+        results in a dictionary, this method specifies the key corresponding to
+        the inference results to be explored by WIT.
+
+        Args:
+          tensor: The name of the output tensor.
+
+        Returns:
+          self, in order to enabled method chaining.
+ """ + self.store("predict_output_tensor", tensor) + return self + + def set_label_vocab(self, vocab): + """Sets the string value of numeric labels for classification models. + + For classification models, the model returns scores for each class ID + number (classes 0 and 1 for binary classification models). In order for + WIT to visually display the results in a more-readable way, you can specify + string labels for each class ID. + + Args: + vocab: A list of strings, where the string at each index corresponds to + the label for that class ID. For example ['<=50K', '>50K'] for the UCI + census binary classification task. + + Returns: + self, in order to enabled method chaining. + """ + self.store("label_vocab", vocab) + return self + + def set_estimator_and_feature_spec(self, estimator, feature_spec): + """Sets the model for inference as a TF Estimator. + + Instead of using TF Serving to host a model for WIT to query, WIT can + directly use a TF Estimator object as the model to query. In order to + accomplish this, a feature_spec must also be provided to parse the + example protos for input into the estimator. + + Args: + estimator: The TF Estimator which will be used for model inference. + feature_spec: The feature_spec object which will be used for example + parsing. + + Returns: + self, in order to enabled method chaining. + """ + # If custom function is set, remove it before setting estimator + self.delete("custom_predict_fn") + + self.store( + "estimator_and_spec", + {"estimator": estimator, "feature_spec": feature_spec}, + ) + self.set_inference_address("estimator") + # If no model name has been set, give a default + if not self.has_model_name(): + self.set_model_name("1") + return self + + def set_compare_estimator_and_feature_spec(self, estimator, feature_spec): + """Sets a second model for inference as a TF Estimator. + + If you wish to compare the results of two models in WIT, use this method + to setup the details of the second model. 
+ + Instead of using TF Serving to host a model for WIT to query, WIT can + directly use a TF Estimator object as the model to query. In order to + accomplish this, a feature_spec must also be provided to parse the + example protos for input into the estimator. + + Args: + estimator: The TF Estimator which will be used for model inference. + feature_spec: The feature_spec object which will be used for example + parsing. + + Returns: + self, in order to enabled method chaining. + """ + # If custom function is set, remove it before setting estimator + self.delete("compare_custom_predict_fn") + + self.store( + "compare_estimator_and_spec", + {"estimator": estimator, "feature_spec": feature_spec}, + ) + self.set_compare_inference_address("estimator") + # If no model name has been set, give a default + if not self.has_compare_model_name(): + self.set_compare_model_name("2") + return self + + def set_custom_predict_fn(self, predict_fn): + """Sets a custom function for inference. + + Instead of using TF Serving to host a model for WIT to query, WIT can + directly use a custom function as the model to query. In this case, the + provided function should accept example protos and return: + - For classification: A 2D list of numbers. The first dimension is for + each example being predicted. The second dimension are the probabilities + for each class ID in the prediction. + - For regression: A 1D list of numbers, with a regression score for each + example being predicted. + + Optionally, if attributions or other prediction-time information + can be returned by the model with each prediction, then this method + can return a dict with the key 'predictions' containing the predictions + result list described above, and with the key 'attributions' containing + a list of attributions for each example that was predicted. + + For each example, the attributions list should contain a dict mapping + input feature names to attribution values for that feature on that example. 
+ The attribution value can be one of these things: + - A single number representing the attribution for the entire feature + - A list of numbers representing the attribution to each value in the + feature for multivalent features - such as attributions to individual + pixels in an image or numbers in a list of numbers. + - A 2D list for sparse feature attribution. Index 0 contains a list of + feature values that there are attribution scores for. Index 1 contains + a list of attribution values for the corresponding feature values in + the first list. + + This dict can contain any other keys, with their values being a list of + prediction-time strings or numbers for each example being predicted. These + values will be displayed in WIT as extra information for each example, + usable in the same ways by WIT as normal input features (such as for + creating plots and slicing performance data). + + Args: + predict_fn: The custom python function which will be used for model + inference. + + Returns: + self, in order to enabled method chaining. + """ + # If estimator is set, remove it before setting predict_fn + self.delete("estimator_and_spec") + + self.store("custom_predict_fn", predict_fn) + self.set_inference_address("custom_predict_fn") + # If no model name has been set, give a default + if not self.has_model_name(): + self.set_model_name("1") + return self + + def set_compare_custom_predict_fn(self, predict_fn): + """Sets a second custom function for inference. + + If you wish to compare the results of two models in WIT, use this method + to setup the details of the second model. + + Instead of using TF Serving to host a model for WIT to query, WIT can + directly use a custom function as the model to query. In this case, the + provided function should accept example protos and return: + - For classification: A 2D list of numbers. The first dimension is for + each example being predicted. 
The second dimension are the probabilities + for each class ID in the prediction. + - For regression: A 1D list of numbers, with a regression score for each + example being predicted. + + Optionally, if attributions or other prediction-time information + can be returned by the model with each prediction, then this method + can return a dict with the key 'predictions' containing the predictions + result list described above, and with the key 'attributions' containing + a list of attributions for each example that was predicted. + + For each example, the attributions list should contain a dict mapping + input feature names to attribution values for that feature on that example. + The attribution value can be one of these things: + - A single number representing the attribution for the entire feature + - A list of numbers representing the attribution to each value in the + feature for multivalent features - such as attributions to individual + pixels in an image or numbers in a list of numbers. + - A 2D list for sparse feature attribution. Index 0 contains a list of + feature values that there are attribution scores for. Index 1 contains + a list of attribution values for the corresponding feature values in + the first list. + + This dict can contain any other keys, with their values being a list of + prediction-time strings or numbers for each example being predicted. These + values will be displayed in WIT as extra information for each example, + usable in the same ways by WIT as normal input features (such as for + creating plots and slicing performance data). + + Args: + predict_fn: The custom python function which will be used for model + inference. + + Returns: + self, in order to enabled method chaining. 
+ """ + # If estimator is set, remove it before setting predict_fn + self.delete("compare_estimator_and_spec") + + self.store("compare_custom_predict_fn", predict_fn) + self.set_compare_inference_address("custom_predict_fn") + # If no model name has been set, give a default + if not self.has_compare_model_name(): + self.set_compare_model_name("2") + return self + + def set_custom_distance_fn(self, distance_fn): + """Sets a custom function for distance computation. + + WIT can directly use a custom function for all distance computations within + the tool. In this case, the provided function should accept a query example + proto and a list of example protos to compute the distance against and + return a 1D list of numbers containing the distances. + + Args: + distance_fn: The python function which will be used for distance + computation. + + Returns: + self, in order to enabled method chaining. + """ + if distance_fn is None: + self.delete("custom_distance_fn") + else: + self.store("custom_distance_fn", distance_fn) + return self + + def set_ai_platform_model( + self, + project, + model, + version=None, + force_json_input=None, + adjust_prediction=None, + adjust_example=None, + adjust_attribution=None, + service_name="ml", + service_version="v1", + get_explanations=True, + batch_size=500, + api_key=None, + ): + """Sets the model information for a model served by AI Platform. + + AI Platform Prediction a Google Cloud serving platform. + + Args: + project: The name of the AI Platform Prediction project. + model: The name of the AI Platform Prediction model. + version: Optional, the version of the AI Platform Prediction model. + force_json_input: Optional. If True and examples are provided as + tf.Example protos, convert them to raw JSON objects before sending them + for inference to this model. + adjust_prediction: Optional. 
If not None then this function takes the + prediction output from the model for a single example and converts it to + the appopriate format - a regression score or a list of class scores. Only + necessary if the model doesn't already abide by this format. + adjust_example: Optional. If not None then this function takes an example + to run prediction on and converts it to the format expected by the model. + Necessary for example if the served model expects a single data value to + run inference on instead of a list or dict of values. + adjust_attribution: Optional. If not None and the model returns attribution + information, then this function takes the attribution information for an + example and converts it to the format expected by the tool, which is a + dictionary of input feature names to attribution scores. Usually necessary + if making use of adjust_example and the model returns attribution results. + service_name: Optional. Name of the AI Platform Prediction service. Defaults + to 'ml'. + service_version: Optional. Version of the AI Platform Prediction service. Defaults + to 'v1'. + get_explanations: Optional. If a model is deployed with explanations, + then this specifies if explainations will be calculated and displayed. + Defaults to True. + batch_size: Optional. Sets the individual batch size to send for + prediction. Defaults to 500. + api_key. Optional. A generated API key to send with the requests to AI + Platform. + + Returns: + self, in order to enabled method chaining. 
+ """ + self.set_inference_address(project) + self.set_model_name(model) + self.store("use_aip", True) + self.store("aip_service_name", service_name) + self.store("aip_service_version", service_version) + self.store("aip_batch_size", batch_size) + self.store("get_explanations", get_explanations) + if version is not None: + self.set_model_signature(version) + if force_json_input: + self.store("force_json_input", True) + if adjust_prediction: + self.store("adjust_prediction", adjust_prediction) + if adjust_example: + self.store("adjust_example", adjust_example) + if adjust_attribution: + self.store("adjust_attribution", adjust_attribution) + if api_key: + self.store("aip_api_key", api_key) + return self + + def set_compare_ai_platform_model( + self, + project, + model, + version=None, + force_json_input=None, + adjust_prediction=None, + adjust_example=None, + adjust_attribution=None, + service_name="ml", + service_version="v1", + get_explanations=True, + batch_size=500, + api_key=None, + ): + """Sets the model information for a second model served by AI Platform. + + AI Platform Prediction a Google Cloud serving platform. + + Args: + project: The name of the AI Platform Prediction project. + model: The name of the AI Platform Prediction model. + version: Optional, the version of the AI Platform Prediction model. + force_json_input: Optional. If True and examples are provided as + tf.Example protos, convert them to raw JSON objects before sending them + for inference to this model. + adjust_prediction: Optional. If not None then this function takes the + prediction output from the model for a single example and converts it to + the appopriate format - a regression score or a list of class scores. Only + necessary if the model doesn't already abide by this format. + adjust_example: Optional. If not None then this function takes an example + to run prediction on and converts it to the format expected by the model. 
+ Necessary for example if the served model expects a single data value to + run inference on instead of a list or dict of values. + adjust_attribution: Optional. If not None and the model returns attribution + information, then this function takes the attribution information for an + example and converts it to the format expected by the tool, which is a + dictionary of input feature names to attribution scores. Usually necessary + if making use of adjust_example and the model returns attribution results. + service_name: Optional. Name of the AI Platform Prediction service. Defaults + to 'ml'. + service_version: Optional. Version of the AI Platform Prediction service. Defaults + to 'v1'. + get_explanations: Optional. If a model is deployed with explanations, + then this specifies if explainations will be calculated and displayed. + Defaults to True. + batch_size: Optional. Sets the individual batch size to send for + prediction. Defaults to 500. + api_key. Optional. A generated API key to send with the requests to AI + Platform. + + Returns: + self, in order to enabled method chaining. 
+ """ + self.set_compare_inference_address(project) + self.set_compare_model_name(model) + self.store("compare_use_aip", True) + self.store("compare_aip_service_name", service_name) + self.store("compare_aip_service_version", service_version) + self.store("compare_aip_batch_size", batch_size) + self.store("compare_get_explanations", get_explanations) + if version is not None: + self.set_compare_model_signature(version) + if force_json_input: + self.store("compare_force_json_input", True) + if adjust_prediction: + self.store("compare_adjust_prediction", adjust_prediction) + if adjust_example: + self.store("compare_adjust_example", adjust_example) + if adjust_attribution: + self.store("compare_adjust_attribution", adjust_attribution) + if api_key: + self.store("compare_aip_api_key", api_key) + return self + + def set_target_feature(self, target): + """Sets the name of the target feature in the provided examples. + + If the provided examples contain a feature that represents the target + that the model is trying to predict, it can be specified by this method. + This is necessary for AI Platform models so that the target feature isn't + sent to the model for prediction, which can cause model inference errors. + + Args: + target: The name of the feature in the examples that represents the value + that the model is trying to predict. + + Returns: + self, in order to enabled method chaining. 
+ """ + self.store("target_feature", target) + return self + + def _set_uses_json_input(self, is_json): + self.store("uses_json_input", is_json) + return self + + def _set_uses_json_list(self, is_list): + self.store("uses_json_list", is_list) + return self diff --git a/tensorboard/plugins/interactive_inference/witwidget/pip_package/setup.py b/tensorboard/plugins/interactive_inference/witwidget/pip_package/setup.py index 6a2f176b61..6304c6dd91 100644 --- a/tensorboard/plugins/interactive_inference/witwidget/pip_package/setup.py +++ b/tensorboard/plugins/interactive_inference/witwidget/pip_package/setup.py @@ -21,85 +21,81 @@ import sys from setuptools import find_packages, setup -project_name = 'witwidget' +project_name = "witwidget" # Set when building the pip package -if '--project_name' in sys.argv: - project_name_idx = sys.argv.index('--project_name') - project_name = sys.argv[project_name_idx + 1] - sys.argv.remove('--project_name') - sys.argv.pop(project_name_idx) +if "--project_name" in sys.argv: + project_name_idx = sys.argv.index("--project_name") + project_name = sys.argv[project_name_idx + 1] + sys.argv.remove("--project_name") + sys.argv.pop(project_name_idx) -_TF_REQ = [ - 'tensorflow>=1.12.0', - 'tensorflow-serving-api>=1.12.0' -] +_TF_REQ = ["tensorflow>=1.12.0", "tensorflow-serving-api>=1.12.0"] # GPU build (note: the only difference is we depend on tensorflow-gpu and # tensorflow-serving-api-gpu so pip doesn't overwrite them with the CPU builds) -if 'witwidget-gpu' in project_name: - _TF_REQ = [ - 'tensorflow-gpu>=1.12.0', - 'tensorflow-serving-api-gpu>=1.12.0' - ] +if "witwidget-gpu" in project_name: + _TF_REQ = ["tensorflow-gpu>=1.12.0", "tensorflow-serving-api-gpu>=1.12.0"] REQUIRED_PACKAGES = [ - 'absl-py >= 0.4', - 'google-api-python-client>=1.7.8', - 'ipywidgets>=7.0.0', - 'jupyter>=1.0,<2', - 'oauth2client>=4.1.3', - 'six>=1.12.0', + "absl-py >= 0.4", + "google-api-python-client>=1.7.8", + "ipywidgets>=7.0.0", + "jupyter>=1.0,<2", + 
"oauth2client>=4.1.3", + "six>=1.12.0", ] + _TF_REQ + def get_readme(): - with open('README.rst') as f: - return f.read() + with open("README.rst") as f: + return f.read() + def get_version(): - version_ns = {} - with open(os.path.join('witwidget', 'version.py')) as f: - exec(f.read(), {}, version_ns) - return version_ns['VERSION'].replace('-', '') + version_ns = {} + with open(os.path.join("witwidget", "version.py")) as f: + exec(f.read(), {}, version_ns) + return version_ns["VERSION"].replace("-", "") + setup( - name=project_name, - version=get_version(), - description='What-If Tool jupyter widget', - long_description=get_readme(), - author='Google Inc.', - author_email='packages@tensorflow.org', - url='https://github.com/tensorflow/tensorboard/tree/master/tensorboard/plugins/interactive_inference', - include_package_data=True, - data_files=[ - ('share/jupyter/nbextensions/wit-widget', [ - 'witwidget/static/extension.js', - 'witwidget/static/index.js', - 'witwidget/static/index.js.map', - 'witwidget/static/wit_jupyter.html', - ],), - ('etc/jupyter/nbconfig/notebook.d/', ['wit-widget.json']) - ], - packages=find_packages(), - zip_safe=False, - install_requires=REQUIRED_PACKAGES, - keywords=[ - 'ipython', - 'jupyter', - 'widgets', - ], - license='Apache 2.0', - classifiers=[ - 'Development Status :: 4 - Beta', - 'Framework :: IPython', - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'Topic :: Multimedia :: Graphics', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - ] + name=project_name, + version=get_version(), + description="What-If Tool jupyter widget", + long_description=get_readme(), + author="Google Inc.", + author_email="packages@tensorflow.org", + 
url="https://github.com/tensorflow/tensorboard/tree/master/tensorboard/plugins/interactive_inference", + include_package_data=True, + data_files=[ + ( + "share/jupyter/nbextensions/wit-widget", + [ + "witwidget/static/extension.js", + "witwidget/static/index.js", + "witwidget/static/index.js.map", + "witwidget/static/wit_jupyter.html", + ], + ), + ("etc/jupyter/nbconfig/notebook.d/", ["wit-widget.json"]), + ], + packages=find_packages(), + zip_safe=False, + install_requires=REQUIRED_PACKAGES, + keywords=["ipython", "jupyter", "widgets",], + license="Apache 2.0", + classifiers=[ + "Development Status :: 4 - Beta", + "Framework :: IPython", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Topic :: Multimedia :: Graphics", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + ], ) diff --git a/tensorboard/plugins/interactive_inference/witwidget/version.py b/tensorboard/plugins/interactive_inference/witwidget/version.py index 0b665b55ed..8a32721f2e 100644 --- a/tensorboard/plugins/interactive_inference/witwidget/version.py +++ b/tensorboard/plugins/interactive_inference/witwidget/version.py @@ -14,4 +14,4 @@ """Contains the version string.""" -VERSION = '1.5.0' +VERSION = "1.5.0" diff --git a/tensorboard/plugins/mesh/demo_utils.py b/tensorboard/plugins/mesh/demo_utils.py index e165745984..1b442f04c6 100644 --- a/tensorboard/plugins/mesh/demo_utils.py +++ b/tensorboard/plugins/mesh/demo_utils.py @@ -23,64 +23,69 @@ def _parse_vertex(vertex_row): - """Parses a line in a PLY file which encodes a vertex coordinates. + """Parses a line in a PLY file which encodes a vertex coordinates. - Args: - vertex_row: string with vertex coordinates and color. 
+ Args: + vertex_row: string with vertex coordinates and color. - Returns: - 2-tuple containing a length-3 array of vertex coordinates (as - floats) and a length-3 array of RGB color values (as ints between 0 - and 255, inclusive). - """ - vertex = vertex_row.strip().split() - # The row must contain coordinates with RGB/RGBA color in addition to that. - if len(vertex) >= 6: - # Supports only RGB colors now, alpha channel will be ignored. - # TODO(b/129298103): add support of RGBA in .ply files. - return ([float(coord) for coord in vertex[:3]], - [int(channel) for channel in vertex[3:6]]) - raise ValueError('PLY file must contain vertices with colors.') + Returns: + 2-tuple containing a length-3 array of vertex coordinates (as + floats) and a length-3 array of RGB color values (as ints between 0 + and 255, inclusive). + """ + vertex = vertex_row.strip().split() + # The row must contain coordinates with RGB/RGBA color in addition to that. + if len(vertex) >= 6: + # Supports only RGB colors now, alpha channel will be ignored. + # TODO(b/129298103): add support of RGBA in .ply files. + return ( + [float(coord) for coord in vertex[:3]], + [int(channel) for channel in vertex[3:6]], + ) + raise ValueError("PLY file must contain vertices with colors.") def _parse_face(face_row): - """Parses a line in a PLY file which encodes a face of the mesh.""" - face = [int(index) for index in face_row.strip().split()] - # Assert that number of vertices in a face is 3, i.e. it is a triangle - if len(face) != 4 or face[0] != 3: - raise ValueError( - 'Only supports face representation as a string with 4 numbers.') + """Parses a line in a PLY file which encodes a face of the mesh.""" + face = [int(index) for index in face_row.strip().split()] + # Assert that number of vertices in a face is 3, i.e. it is a triangle + if len(face) != 4 or face[0] != 3: + raise ValueError( + "Only supports face representation as a string with 4 numbers." 
+ ) - return face[1:] + return face[1:] def read_ascii_ply(filename): - """Reads a PLY file encoded in ASCII format. + """Reads a PLY file encoded in ASCII format. - NOTE: this util method is not intended to be comprehensive PLY reader - and serves as part of demo application. + NOTE: this util method is not intended to be comprehensive PLY reader + and serves as part of demo application. - Args: - filename: path to a PLY file to read. + Args: + filename: path to a PLY file to read. - Returns: - numpy `[dim_1, 3]` array of vertices, `[dim_1, 3]` array of colors and - `[dim_1, 3]` array of faces of the mesh. - """ - with tf.io.gfile.GFile(filename) as ply_file: - for line in ply_file: - if line.startswith('end_header'): - break - elif line.startswith('element vertex'): - vert_count = int(line.split()[-1]) - elif line.startswith('element face'): - face_count = int(line.split()[-1]) - # Read vertices and their colors. - vertex_data = [_parse_vertex(next(ply_file)) for _ in range(vert_count)] - vertices = [datum[0] for datum in vertex_data] - colors = [datum[1] for datum in vertex_data] - # Read faces. - faces = [_parse_face(next(ply_file)) for _ in range(face_count)] - return (np.array(vertices).astype(np.float32), + Returns: + numpy `[dim_1, 3]` array of vertices, `[dim_1, 3]` array of colors and + `[dim_1, 3]` array of faces of the mesh. + """ + with tf.io.gfile.GFile(filename) as ply_file: + for line in ply_file: + if line.startswith("end_header"): + break + elif line.startswith("element vertex"): + vert_count = int(line.split()[-1]) + elif line.startswith("element face"): + face_count = int(line.split()[-1]) + # Read vertices and their colors. + vertex_data = [_parse_vertex(next(ply_file)) for _ in range(vert_count)] + vertices = [datum[0] for datum in vertex_data] + colors = [datum[1] for datum in vertex_data] + # Read faces. 
+ faces = [_parse_face(next(ply_file)) for _ in range(face_count)] + return ( + np.array(vertices).astype(np.float32), np.array(colors).astype(np.uint8), - np.array(faces).astype(np.int32)) + np.array(faces).astype(np.int32), + ) diff --git a/tensorboard/plugins/mesh/demo_utils_test.py b/tensorboard/plugins/mesh/demo_utils_test.py index 47ad137dd6..437c7eaade 100644 --- a/tensorboard/plugins/mesh/demo_utils_test.py +++ b/tensorboard/plugins/mesh/demo_utils_test.py @@ -27,36 +27,41 @@ class TestPLYReader(tf.test.TestCase): - def test_parse_vertex(self): - """Tests vertex coordinate and color parsing.""" - # Vertex 3D coordinates with RGBA color. - vertex_data = [-0.249245, 1.119303, 0.3095566, 60, 253, 32, 255] - coords, colors = demo_utils._parse_vertex(' '.join(map(str, vertex_data))) - self.assertListEqual(coords, vertex_data[:3]) - self.assertListEqual(colors, vertex_data[3:6]) - - def test_prase_vertex_expects_colors(self): - """Tests that method will throw error if color is not poresent.""" - with self.assertRaisesRegexp(ValueError, - 'PLY file must contain vertices with colors'): - demo_utils._parse_vertex('1 2 3') - - def test_parse_face(self): - """Tests face line parsing.""" - face_data = [3, 10, 20, 30] - parsed_face = demo_utils._parse_face(' '.join(map(str, face_data))) - self.assertListEqual(parsed_face, face_data[1:]) - - def test_read_ascii_ply(self): - """Tests end-to-end PLY file reading and parsing.""" - test_ply = os.path.join( - os.path.dirname(os.environ['TEST_BINARY']), - 'test_data', 'icosphere.ply') - vertices, colors, faces = demo_utils.read_ascii_ply(test_ply) - self.assertEqual(len(vertices), 82) - self.assertEqual(len(vertices), len(colors)) - self.assertEqual(len(faces), 80) - - -if __name__ == '__main__': - tf.test.main() + def test_parse_vertex(self): + """Tests vertex coordinate and color parsing.""" + # Vertex 3D coordinates with RGBA color. 
+ vertex_data = [-0.249245, 1.119303, 0.3095566, 60, 253, 32, 255] + coords, colors = demo_utils._parse_vertex( + " ".join(map(str, vertex_data)) + ) + self.assertListEqual(coords, vertex_data[:3]) + self.assertListEqual(colors, vertex_data[3:6]) + + def test_prase_vertex_expects_colors(self): + """Tests that method will throw error if color is not poresent.""" + with self.assertRaisesRegexp( + ValueError, "PLY file must contain vertices with colors" + ): + demo_utils._parse_vertex("1 2 3") + + def test_parse_face(self): + """Tests face line parsing.""" + face_data = [3, 10, 20, 30] + parsed_face = demo_utils._parse_face(" ".join(map(str, face_data))) + self.assertListEqual(parsed_face, face_data[1:]) + + def test_read_ascii_ply(self): + """Tests end-to-end PLY file reading and parsing.""" + test_ply = os.path.join( + os.path.dirname(os.environ["TEST_BINARY"]), + "test_data", + "icosphere.ply", + ) + vertices, colors, faces = demo_utils.read_ascii_ply(test_ply) + self.assertEqual(len(vertices), 82) + self.assertEqual(len(vertices), len(colors)) + self.assertEqual(len(faces), 80) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/mesh/mesh_demo.py b/tensorboard/plugins/mesh/mesh_demo.py index db034ac79e..aab8b7c616 100644 --- a/tensorboard/plugins/mesh/mesh_demo.py +++ b/tensorboard/plugins/mesh/mesh_demo.py @@ -27,9 +27,10 @@ from tensorboard.plugins.mesh import summary as mesh_summary from tensorboard.plugins.mesh import demo_utils -flags.DEFINE_string('logdir', '/tmp/mesh_demo', - 'Directory to write event logs to.') -flags.DEFINE_string('mesh_path', None, 'Path to PLY file to visualize.') +flags.DEFINE_string( + "logdir", "/tmp/mesh_demo", "Directory to write event logs to." +) +flags.DEFINE_string("mesh_path", None, "Path to PLY file to visualize.") FLAGS = flags.FLAGS @@ -40,63 +41,66 @@ def run(): - """Runs session with a mesh summary.""" - # Mesh summaries only work on TensorFlow 1.x. 
- if int(tf.__version__.split('.')[0]) > 1: - raise ImportError('TensorFlow 1.x is required to run this demo.') - # Flag mesh_path is required. - if FLAGS.mesh_path is None: - raise ValueError( - 'Flag --mesh_path is required and must contain path to PLY file.') - # Camera and scene configuration. - config_dict = { - 'camera': {'cls': 'PerspectiveCamera', 'fov': 75} - } - - # Read sample PLY file. - vertices, colors, faces = demo_utils.read_ascii_ply(FLAGS.mesh_path) - - # Add batch dimension. - vertices = np.expand_dims(vertices, 0) - faces = np.expand_dims(faces, 0) - colors = np.expand_dims(colors, 0) - - # Create placeholders for tensors representing the mesh. - step = tf.placeholder(tf.int32, ()) - vertices_tensor = tf.placeholder( - tf.float32, vertices.shape) - faces_tensor = tf.placeholder( - tf.int32, faces.shape) - colors_tensor = tf.placeholder( - tf.int32, colors.shape) - - # Change colors over time. - t = tf.cast(step, tf.float32) / _MAX_STEPS - transformed_colors = t * (255 - colors) + (1 - t) * colors - - meshes_summary = mesh_summary.op( - 'mesh_color_tensor', vertices=vertices_tensor, faces=faces_tensor, - colors=transformed_colors, config_dict=config_dict) - - # Create summary writer and session. - writer = tf.summary.FileWriter(FLAGS.logdir) - sess = tf.Session() - - for i in range(_MAX_STEPS): - summary = sess.run(meshes_summary, feed_dict={ - vertices_tensor: vertices, - faces_tensor: faces, - colors_tensor: colors, - step: i, - }) - writer.add_summary(summary, global_step=i) + """Runs session with a mesh summary.""" + # Mesh summaries only work on TensorFlow 1.x. + if int(tf.__version__.split(".")[0]) > 1: + raise ImportError("TensorFlow 1.x is required to run this demo.") + # Flag mesh_path is required. + if FLAGS.mesh_path is None: + raise ValueError( + "Flag --mesh_path is required and must contain path to PLY file." + ) + # Camera and scene configuration. 
+ config_dict = {"camera": {"cls": "PerspectiveCamera", "fov": 75}} + + # Read sample PLY file. + vertices, colors, faces = demo_utils.read_ascii_ply(FLAGS.mesh_path) + + # Add batch dimension. + vertices = np.expand_dims(vertices, 0) + faces = np.expand_dims(faces, 0) + colors = np.expand_dims(colors, 0) + + # Create placeholders for tensors representing the mesh. + step = tf.placeholder(tf.int32, ()) + vertices_tensor = tf.placeholder(tf.float32, vertices.shape) + faces_tensor = tf.placeholder(tf.int32, faces.shape) + colors_tensor = tf.placeholder(tf.int32, colors.shape) + + # Change colors over time. + t = tf.cast(step, tf.float32) / _MAX_STEPS + transformed_colors = t * (255 - colors) + (1 - t) * colors + + meshes_summary = mesh_summary.op( + "mesh_color_tensor", + vertices=vertices_tensor, + faces=faces_tensor, + colors=transformed_colors, + config_dict=config_dict, + ) + + # Create summary writer and session. + writer = tf.summary.FileWriter(FLAGS.logdir) + sess = tf.Session() + + for i in range(_MAX_STEPS): + summary = sess.run( + meshes_summary, + feed_dict={ + vertices_tensor: vertices, + faces_tensor: faces, + colors_tensor: colors, + step: i, + }, + ) + writer.add_summary(summary, global_step=i) def main(unused_argv): - print('Saving output to %s.' % FLAGS.logdir) - run() - print('Done. Output saved to %s.' % FLAGS.logdir) + print("Saving output to %s." % FLAGS.logdir) + run() + print("Done. Output saved to %s." 
% FLAGS.logdir) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/tensorboard/plugins/mesh/mesh_demo_v2.py b/tensorboard/plugins/mesh/mesh_demo_v2.py index ce732dd04b..9e4a5a85da 100644 --- a/tensorboard/plugins/mesh/mesh_demo_v2.py +++ b/tensorboard/plugins/mesh/mesh_demo_v2.py @@ -28,9 +28,10 @@ from tensorboard.plugins.mesh import demo_utils -flags.DEFINE_string('logdir', '/tmp/mesh_demo', - 'Directory to write event logs to.') -flags.DEFINE_string('mesh_path', None, 'Path to PLY file to visualize.') +flags.DEFINE_string( + "logdir", "/tmp/mesh_demo", "Directory to write event logs to." +) +flags.DEFINE_string("mesh_path", None, "Path to PLY file to visualize.") FLAGS = flags.FLAGS @@ -41,50 +42,54 @@ def train_step(vertices, faces, colors, config_dict, step): - """Executes summary as a train step.""" - # Change colors over time. - t = float(step) / _MAX_STEPS - transformed_colors = t * (255 - colors) + (1 - t) * colors - mesh_summary.mesh( - 'mesh_color_tensor', vertices=vertices, faces=faces, - colors=transformed_colors, config_dict=config_dict, step=step) + """Executes summary as a train step.""" + # Change colors over time. + t = float(step) / _MAX_STEPS + transformed_colors = t * (255 - colors) + (1 - t) * colors + mesh_summary.mesh( + "mesh_color_tensor", + vertices=vertices, + faces=faces, + colors=transformed_colors, + config_dict=config_dict, + step=step, + ) def run(): - """Runs training steps with a mesh summary.""" - # Mesh summaries only work on TensorFlow 2.x. - if int(tf.__version__.split('.')[0]) < 1: - raise ImportError('TensorFlow 2.x is required to run this demo.') - # Flag mesh_path is required. - if FLAGS.mesh_path is None: - raise ValueError( - 'Flag --mesh_path is required and must contain path to PLY file.') - # Camera and scene configuration. - config_dict = { - 'camera': {'cls': 'PerspectiveCamera', 'fov': 75} - } - - # Read sample PLY file. 
- vertices, colors, faces = demo_utils.read_ascii_ply(FLAGS.mesh_path) - - # Add batch dimension. - vertices = np.expand_dims(vertices, 0) - faces = np.expand_dims(faces, 0) - colors = np.expand_dims(colors, 0) - - # Create summary writer. - writer = tf.summary.create_file_writer(FLAGS.logdir) - - with writer.as_default(): - for step in range(_MAX_STEPS): - train_step(vertices, faces, colors, config_dict, step) + """Runs training steps with a mesh summary.""" + # Mesh summaries only work on TensorFlow 2.x. + if int(tf.__version__.split(".")[0]) < 1: + raise ImportError("TensorFlow 2.x is required to run this demo.") + # Flag mesh_path is required. + if FLAGS.mesh_path is None: + raise ValueError( + "Flag --mesh_path is required and must contain path to PLY file." + ) + # Camera and scene configuration. + config_dict = {"camera": {"cls": "PerspectiveCamera", "fov": 75}} + + # Read sample PLY file. + vertices, colors, faces = demo_utils.read_ascii_ply(FLAGS.mesh_path) + + # Add batch dimension. + vertices = np.expand_dims(vertices, 0) + faces = np.expand_dims(faces, 0) + colors = np.expand_dims(colors, 0) + + # Create summary writer. + writer = tf.summary.create_file_writer(FLAGS.logdir) + + with writer.as_default(): + for step in range(_MAX_STEPS): + train_step(vertices, faces, colors, config_dict, step) def main(unused_argv): - print('Saving output to %s.' % FLAGS.logdir) - run() - print('Done. Output saved to %s.' % FLAGS.logdir) + print("Saving output to %s." % FLAGS.logdir) + run() + print("Done. Output saved to %s." 
% FLAGS.logdir) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/tensorboard/plugins/mesh/mesh_plugin.py b/tensorboard/plugins/mesh/mesh_plugin.py index d1b93debe1..efe4e7330d 100644 --- a/tensorboard/plugins/mesh/mesh_plugin.py +++ b/tensorboard/plugins/mesh/mesh_plugin.py @@ -35,244 +35,258 @@ class MeshPlugin(base_plugin.TBPlugin): - """A plugin that serves 3D visualization of meshes.""" - - plugin_name = metadata.PLUGIN_NAME - - def __init__(self, context): - """Instantiates a MeshPlugin via TensorBoard core. - - Args: - context: A base_plugin.TBContext instance. A magic container that - TensorBoard uses to make objects available to the plugin. - """ - # Retrieve the multiplexer from the context and store a reference to it. - self._multiplexer = context.multiplexer - self._tag_to_instance_tags = collections.defaultdict(list) - self._instance_tag_to_tag = dict() - self._instance_tag_to_metadata = dict() - self.prepare_metadata() - - def prepare_metadata(self): - """Processes all tags and caches metadata for each.""" - if self._tag_to_instance_tags: - return - # This is a dictionary mapping from run to (tag to string content). - # To be clear, the values of the dictionary are dictionaries. - all_runs = self._multiplexer.PluginRunToTagToContent(MeshPlugin.plugin_name) - - # tagToContent is itself a dictionary mapping tag name to string - # SummaryMetadata.plugin_data.content. Retrieve the keys of that dictionary - # to obtain a list of tags associated with each run. For each tag, estimate - # the number of samples. - self._tag_to_instance_tags = collections.defaultdict(list) - self._instance_tag_to_metadata = dict() - for run, tag_to_content in six.iteritems(all_runs): - for tag, content in six.iteritems(tag_to_content): - meta = metadata.parse_plugin_metadata(content) - self._instance_tag_to_metadata[(run, tag)] = meta - # Remember instance_name (instance_tag) for future reference. 
- self._tag_to_instance_tags[(run, meta.name)].append(tag) - self._instance_tag_to_tag[(run, tag)] = meta.name - - @wrappers.Request.application - def _serve_tags(self, request): - """A route (HTTP handler) that returns a response with tags. - - Args: - request: The werkzeug.Request object. - - Returns: - A response that contains a JSON object. The keys of the object - are all the runs. Each run is mapped to a (potentially empty) - list of all tags that are relevant to this plugin. - """ - # This is a dictionary mapping from run to (tag to string content). - # To be clear, the values of the dictionary are dictionaries. - all_runs = self._multiplexer.PluginRunToTagToContent( - MeshPlugin.plugin_name) - - # Make sure we populate tags mapping structures. - self.prepare_metadata() - - # tagToContent is itself a dictionary mapping tag name to string - # SummaryMetadata.plugin_data.content. Retrieve the keys of that dictionary - # to obtain a list of tags associated with each run. For each tag estimate - # number of samples. - response = dict() - for run, tag_to_content in six.iteritems(all_runs): - response[run] = dict() - for instance_tag, _ in six.iteritems(tag_to_content): - # Make sure we only operate on user-defined tags here. - tag = self._instance_tag_to_tag[(run, instance_tag)] - meta = self._instance_tag_to_metadata[(run, instance_tag)] - # Batch size must be defined, otherwise we don't know how many - # samples were there. - response[run][tag] = {'samples': meta.shape[0]} - return http_util.Respond(request, response, 'application/json') - - def get_plugin_apps(self): - """Gets all routes offered by the plugin. - - This method is called by TensorBoard when retrieving all the - routes offered by the plugin. - - Returns: - A dictionary mapping URL path to route that handles it. - """ - # Note that the methods handling routes are decorated with - # @wrappers.Request.application. 
- return { - '/tags': self._serve_tags, - '/meshes': self._serve_mesh_metadata, - '/data': self._serve_mesh_data, - } - - def is_active(self): - """Determines whether this plugin is active. - - This plugin is only active if TensorBoard sampled any summaries - relevant to the mesh plugin. - - Returns: - Whether this plugin is active. - """ - all_runs = self._multiplexer.PluginRunToTagToContent( - MeshPlugin.plugin_name) - - # The plugin is active if any of the runs has a tag relevant - # to the plugin. - return bool(self._multiplexer and any(six.itervalues(all_runs))) - - def frontend_metadata(self): - return base_plugin.FrontendMetadata(element_name='mesh-dashboard') - - def _get_sample(self, tensor_event, sample): - """Returns a single sample from a batch of samples.""" - data = tensor_util.make_ndarray(tensor_event.tensor_proto) - return data[sample].tolist() - - def _get_tensor_metadata( - self, event, content_type, components, data_shape, config): - """Converts a TensorEvent into a JSON-compatible response. - - Args: - event: TensorEvent object containing data in proto format. - content_type: enum plugin_data_pb2.MeshPluginData.ContentType value, - representing content type in TensorEvent. - components: Bitmask representing all parts (vertices, colors, etc.) that - belong to the summary. - data_shape: list of dimensions sizes of the tensor. - config: rendering scene configuration as dictionary. - - Returns: - Dictionary of transformed metadata. 
- """ - return { - 'wall_time': event.wall_time, - 'step': event.step, - 'content_type': content_type, - 'components': components, - 'config': config, - 'data_shape': list(data_shape), - } - - def _get_tensor_data(self, event, sample): - """Convert a TensorEvent into a JSON-compatible response.""" - data = self._get_sample(event, sample) - return data - - def _collect_tensor_events(self, request, step=None): - """Collects list of tensor events based on request.""" - run = request.args.get('run') - tag = request.args.get('tag') - - # TODO(b/128995556): investigate why this additional metadata mapping is - # necessary, it must have something todo with the lifecycle of the request. - # Make sure we populate tags mapping structures. - self.prepare_metadata() - - tensor_events = [] # List of tuples (meta, tensor) that contain tag. - for instance_tag in self._tag_to_instance_tags[(run, tag)]: - tensors = self._multiplexer.Tensors(run, instance_tag) - meta = self._instance_tag_to_metadata[(run, instance_tag)] - tensor_events += [(meta, tensor) for tensor in tensors] - - if step is not None: - tensor_events = [ - event for event in tensor_events if event[1].step == step] - else: - # Make sure tensors sorted by step in ascending order. - tensor_events = sorted( - tensor_events, key=lambda tensor_data: tensor_data[1].step) - - return tensor_events - - @wrappers.Request.application - def _serve_mesh_data(self, request): - """A route that returns data for particular summary of specified type. - - Data can represent vertices coordinates, vertices indices in faces, - vertices colors and so on. Each mesh may have different combination of - abovementioned data and each type/part of mesh summary must be served as - separate roundtrip to the server. - - Args: - request: werkzeug.Request containing content_type as a name of enum - plugin_data_pb2.MeshPluginData.ContentType. - - Returns: - werkzeug.Response either float32 or int32 data in binary format. 
- """ - step = float(request.args.get('step', 0.0)) - tensor_events = self._collect_tensor_events(request, step) - content_type = request.args.get('content_type') - try: - content_type = plugin_data_pb2.MeshPluginData.ContentType.Value( - content_type) - except ValueError: - return http_util.Respond(request, 'Bad content_type', 'text/plain', 400) - sample = int(request.args.get('sample', 0)) - - response = [ - self._get_tensor_data(tensor, sample) - for meta, tensor in tensor_events - if meta.content_type == content_type - ] - - np_type = { - plugin_data_pb2.MeshPluginData.VERTEX: np.float32, - plugin_data_pb2.MeshPluginData.FACE: np.int32, - plugin_data_pb2.MeshPluginData.COLOR: np.uint8, - }[content_type] - - response = np.array(response, dtype=np_type) - # Looks like reshape can take around 160ms, so why not store it reshaped. - response = response.reshape(-1).tobytes() - - return http_util.Respond(request, response, 'arraybuffer') - - @wrappers.Request.application - def _serve_mesh_metadata(self, request): - """A route that returns the mesh metadata associated with a tag. - - Metadata consists of wall time, type of elements in tensor, scene - configuration and so on. - - Args: - request: The werkzeug.Request object. - - Returns: - A JSON list of mesh data associated with the run and tag - combination. - """ - tensor_events = self._collect_tensor_events(request) - - # We convert the tensor data to text. - response = [ - self._get_tensor_metadata( - tensor, meta.content_type, meta.components, meta.shape, - meta.json_config) - for meta, tensor in tensor_events - ] - return http_util.Respond(request, response, 'application/json') + """A plugin that serves 3D visualization of meshes.""" + + plugin_name = metadata.PLUGIN_NAME + + def __init__(self, context): + """Instantiates a MeshPlugin via TensorBoard core. + + Args: + context: A base_plugin.TBContext instance. A magic container that + TensorBoard uses to make objects available to the plugin. 
+ """ + # Retrieve the multiplexer from the context and store a reference to it. + self._multiplexer = context.multiplexer + self._tag_to_instance_tags = collections.defaultdict(list) + self._instance_tag_to_tag = dict() + self._instance_tag_to_metadata = dict() + self.prepare_metadata() + + def prepare_metadata(self): + """Processes all tags and caches metadata for each.""" + if self._tag_to_instance_tags: + return + # This is a dictionary mapping from run to (tag to string content). + # To be clear, the values of the dictionary are dictionaries. + all_runs = self._multiplexer.PluginRunToTagToContent( + MeshPlugin.plugin_name + ) + + # tagToContent is itself a dictionary mapping tag name to string + # SummaryMetadata.plugin_data.content. Retrieve the keys of that dictionary + # to obtain a list of tags associated with each run. For each tag, estimate + # the number of samples. + self._tag_to_instance_tags = collections.defaultdict(list) + self._instance_tag_to_metadata = dict() + for run, tag_to_content in six.iteritems(all_runs): + for tag, content in six.iteritems(tag_to_content): + meta = metadata.parse_plugin_metadata(content) + self._instance_tag_to_metadata[(run, tag)] = meta + # Remember instance_name (instance_tag) for future reference. + self._tag_to_instance_tags[(run, meta.name)].append(tag) + self._instance_tag_to_tag[(run, tag)] = meta.name + + @wrappers.Request.application + def _serve_tags(self, request): + """A route (HTTP handler) that returns a response with tags. + + Args: + request: The werkzeug.Request object. + + Returns: + A response that contains a JSON object. The keys of the object + are all the runs. Each run is mapped to a (potentially empty) + list of all tags that are relevant to this plugin. + """ + # This is a dictionary mapping from run to (tag to string content). + # To be clear, the values of the dictionary are dictionaries. 
+ all_runs = self._multiplexer.PluginRunToTagToContent( + MeshPlugin.plugin_name + ) + + # Make sure we populate tags mapping structures. + self.prepare_metadata() + + # tagToContent is itself a dictionary mapping tag name to string + # SummaryMetadata.plugin_data.content. Retrieve the keys of that dictionary + # to obtain a list of tags associated with each run. For each tag estimate + # number of samples. + response = dict() + for run, tag_to_content in six.iteritems(all_runs): + response[run] = dict() + for instance_tag, _ in six.iteritems(tag_to_content): + # Make sure we only operate on user-defined tags here. + tag = self._instance_tag_to_tag[(run, instance_tag)] + meta = self._instance_tag_to_metadata[(run, instance_tag)] + # Batch size must be defined, otherwise we don't know how many + # samples were there. + response[run][tag] = {"samples": meta.shape[0]} + return http_util.Respond(request, response, "application/json") + + def get_plugin_apps(self): + """Gets all routes offered by the plugin. + + This method is called by TensorBoard when retrieving all the + routes offered by the plugin. + + Returns: + A dictionary mapping URL path to route that handles it. + """ + # Note that the methods handling routes are decorated with + # @wrappers.Request.application. + return { + "/tags": self._serve_tags, + "/meshes": self._serve_mesh_metadata, + "/data": self._serve_mesh_data, + } + + def is_active(self): + """Determines whether this plugin is active. + + This plugin is only active if TensorBoard sampled any summaries + relevant to the mesh plugin. + + Returns: + Whether this plugin is active. + """ + all_runs = self._multiplexer.PluginRunToTagToContent( + MeshPlugin.plugin_name + ) + + # The plugin is active if any of the runs has a tag relevant + # to the plugin. 
+ return bool(self._multiplexer and any(six.itervalues(all_runs))) + + def frontend_metadata(self): + return base_plugin.FrontendMetadata(element_name="mesh-dashboard") + + def _get_sample(self, tensor_event, sample): + """Returns a single sample from a batch of samples.""" + data = tensor_util.make_ndarray(tensor_event.tensor_proto) + return data[sample].tolist() + + def _get_tensor_metadata( + self, event, content_type, components, data_shape, config + ): + """Converts a TensorEvent into a JSON-compatible response. + + Args: + event: TensorEvent object containing data in proto format. + content_type: enum plugin_data_pb2.MeshPluginData.ContentType value, + representing content type in TensorEvent. + components: Bitmask representing all parts (vertices, colors, etc.) that + belong to the summary. + data_shape: list of dimensions sizes of the tensor. + config: rendering scene configuration as dictionary. + + Returns: + Dictionary of transformed metadata. + """ + return { + "wall_time": event.wall_time, + "step": event.step, + "content_type": content_type, + "components": components, + "config": config, + "data_shape": list(data_shape), + } + + def _get_tensor_data(self, event, sample): + """Convert a TensorEvent into a JSON-compatible response.""" + data = self._get_sample(event, sample) + return data + + def _collect_tensor_events(self, request, step=None): + """Collects list of tensor events based on request.""" + run = request.args.get("run") + tag = request.args.get("tag") + + # TODO(b/128995556): investigate why this additional metadata mapping is + # necessary, it must have something todo with the lifecycle of the request. + # Make sure we populate tags mapping structures. + self.prepare_metadata() + + tensor_events = [] # List of tuples (meta, tensor) that contain tag. 
+ for instance_tag in self._tag_to_instance_tags[(run, tag)]: + tensors = self._multiplexer.Tensors(run, instance_tag) + meta = self._instance_tag_to_metadata[(run, instance_tag)] + tensor_events += [(meta, tensor) for tensor in tensors] + + if step is not None: + tensor_events = [ + event for event in tensor_events if event[1].step == step + ] + else: + # Make sure tensors sorted by step in ascending order. + tensor_events = sorted( + tensor_events, key=lambda tensor_data: tensor_data[1].step + ) + + return tensor_events + + @wrappers.Request.application + def _serve_mesh_data(self, request): + """A route that returns data for particular summary of specified type. + + Data can represent vertices coordinates, vertices indices in faces, + vertices colors and so on. Each mesh may have different combination of + abovementioned data and each type/part of mesh summary must be served as + separate roundtrip to the server. + + Args: + request: werkzeug.Request containing content_type as a name of enum + plugin_data_pb2.MeshPluginData.ContentType. + + Returns: + werkzeug.Response either float32 or int32 data in binary format. + """ + step = float(request.args.get("step", 0.0)) + tensor_events = self._collect_tensor_events(request, step) + content_type = request.args.get("content_type") + try: + content_type = plugin_data_pb2.MeshPluginData.ContentType.Value( + content_type + ) + except ValueError: + return http_util.Respond( + request, "Bad content_type", "text/plain", 400 + ) + sample = int(request.args.get("sample", 0)) + + response = [ + self._get_tensor_data(tensor, sample) + for meta, tensor in tensor_events + if meta.content_type == content_type + ] + + np_type = { + plugin_data_pb2.MeshPluginData.VERTEX: np.float32, + plugin_data_pb2.MeshPluginData.FACE: np.int32, + plugin_data_pb2.MeshPluginData.COLOR: np.uint8, + }[content_type] + + response = np.array(response, dtype=np_type) + # Looks like reshape can take around 160ms, so why not store it reshaped. 
+ response = response.reshape(-1).tobytes() + + return http_util.Respond(request, response, "arraybuffer") + + @wrappers.Request.application + def _serve_mesh_metadata(self, request): + """A route that returns the mesh metadata associated with a tag. + + Metadata consists of wall time, type of elements in tensor, scene + configuration and so on. + + Args: + request: The werkzeug.Request object. + + Returns: + A JSON list of mesh data associated with the run and tag + combination. + """ + tensor_events = self._collect_tensor_events(request) + + # We convert the tensor data to text. + response = [ + self._get_tensor_metadata( + tensor, + meta.content_type, + meta.components, + meta.shape, + meta.json_config, + ) + for meta, tensor in tensor_events + ] + return http_util.Respond(request, response, "application/json") diff --git a/tensorboard/plugins/mesh/mesh_plugin_test.py b/tensorboard/plugins/mesh/mesh_plugin_test.py index 4c602d42f3..21452be7b9 100644 --- a/tensorboard/plugins/mesh/mesh_plugin_test.py +++ b/tensorboard/plugins/mesh/mesh_plugin_test.py @@ -27,7 +27,9 @@ from werkzeug import test as werkzeug_test from werkzeug import wrappers from tensorboard.backend import application -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.mesh import mesh_plugin from tensorboard.plugins.mesh import summary @@ -38,203 +40,250 @@ from mock import patch try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import class MeshPluginTest(tf.test.TestCase): - """Tests for mesh plugin server.""" - - def setUp(self): - # We use numpy.random to generate meshes. 
We seed to avoid non-determinism - # in this test. - np.random.seed(17) - - # Log dir to save temp events into. - self.log_dir = self.get_temp_dir() - - # Create mesh summary. - with tf.compat.v1.Graph().as_default(): - tf_placeholder = tf.compat.v1.placeholder - sess = tf.compat.v1.Session() - point_cloud = test_utils.get_random_mesh(1000) - point_cloud_vertices = tf_placeholder( - tf.float32, point_cloud.vertices.shape - ) - - mesh_no_color = test_utils.get_random_mesh(2000, add_faces=True) - mesh_no_color_extended = test_utils.get_random_mesh(2500, add_faces=True) - mesh_no_color_vertices = tf_placeholder(tf.float32, [1, None, 3]) - mesh_no_color_faces = tf_placeholder(tf.int32, [1, None, 3]) - - mesh_color = test_utils.get_random_mesh( - 3000, add_faces=True, add_colors=True) - mesh_color_vertices = tf_placeholder(tf.float32, mesh_color.vertices.shape) - mesh_color_faces = tf_placeholder(tf.int32, mesh_color.faces.shape) - mesh_color_colors = tf_placeholder(tf.uint8, mesh_color.colors.shape) - - self.data = [ - point_cloud, mesh_no_color, mesh_no_color_extended, mesh_color] - - # In case when name is present and display_name is not, we will reuse name - # as display_name. Summaries below intended to test both cases. 
- self.names = ["point_cloud", "mesh_no_color", "mesh_color"] - summary.op( - self.names[0], - point_cloud_vertices, - description="just point cloud") - summary.op( - self.names[1], - mesh_no_color_vertices, - faces=mesh_no_color_faces, - display_name="name_to_display_in_ui", - description="beautiful mesh in grayscale") - summary.op( - self.names[2], - mesh_color_vertices, - faces=mesh_color_faces, - colors=mesh_color_colors, - description="mesh with random colors") - - merged_summary_op = tf.compat.v1.summary.merge_all() - self.runs = ["bar"] - self.steps = 20 - bar_directory = os.path.join(self.log_dir, self.runs[0]) - with tensorboard_test_util.FileWriterCache.get(bar_directory) as writer: - writer.add_graph(sess.graph) - for step in range(self.steps): - # Alternate between two random meshes with different number of - # vertices. - no_color = mesh_no_color if step % 2 == 0 else mesh_no_color_extended - with patch.object(time, 'time', return_value=step): - writer.add_summary( - sess.run( - merged_summary_op, - feed_dict={ - point_cloud_vertices: point_cloud.vertices, - mesh_no_color_vertices: no_color.vertices, - mesh_no_color_faces: no_color.faces, - mesh_color_vertices: mesh_color.vertices, - mesh_color_faces: mesh_color.faces, - mesh_color_colors: mesh_color.colors, - }), - global_step=step) - - # Start a server that will receive requests. - self.multiplexer = event_multiplexer.EventMultiplexer({ - "bar": bar_directory, - }) - self.context = base_plugin.TBContext( - logdir=self.log_dir, multiplexer=self.multiplexer) - self.plugin = mesh_plugin.MeshPlugin(self.context) - # Wait until after plugin construction to reload the multiplexer because the - # plugin caches data from the multiplexer upon construction and this affects - # logic tested later down. - # TODO(https://github.com/tensorflow/tensorboard/issues/2579): Eliminate the - # caching of data at construction time and move this Reload() up to just - # after the multiplexer is created. 
- self.multiplexer.Reload() - wsgi_app = application.TensorBoardWSGI([self.plugin]) - self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) - self.routes = self.plugin.get_plugin_apps() - - def tearDown(self): - shutil.rmtree(self.log_dir, ignore_errors=True) - - def testRoutes(self): - """Tests that the /tags route offers the correct run to tag mapping.""" - self.assertIsInstance(self.routes["/tags"], collections.Callable) - self.assertIsInstance(self.routes["/meshes"], collections.Callable) - self.assertIsInstance(self.routes["/data"], collections.Callable) - - def testTagsRoute(self): - """Tests that the /tags route offers the correct run to tag mapping.""" - response = self.server.get("/data/plugin/mesh/tags") - self.assertEqual(200, response.status_code) - tags = test_utils.deserialize_json_response(response.get_data()) - self.assertIn(self.runs[0], tags) - for name in self.names: - self.assertIn(name, tags[self.runs[0]]) - - def validate_data_response( - self, run, tag, sample, content_type, dtype, ground_truth_data, - step=0): - """Makes request and checks that response has expected data.""" - response = self.server.get( - "/data/plugin/mesh/data?run=%s&tag=%s&sample=%d&content_type=" - "%s&step=%d" % - (run, tag, sample, content_type, step)) - self.assertEqual(200, response.status_code) - data = test_utils.deserialize_array_buffer_response( - next(response.response), dtype) - self.assertEqual(ground_truth_data.reshape(-1).tolist(), data.tolist()) - - def testDataRoute(self): - """Tests that the /data route returns correct data for meshes.""" - self.validate_data_response( - self.runs[0], self.names[0], 0, "VERTEX", np.float32, - self.data[0].vertices) - - self.validate_data_response( - self.runs[0], self.names[1], 0, "FACE", np.int32, self.data[1].faces) - - # Validate that the same summary has mesh with different number of faces at - # different step=1. 
- self.validate_data_response( - self.runs[0], self.names[1], 0, "FACE", np.int32, self.data[2].faces, - step=1) - - self.validate_data_response( - self.runs[0], self.names[2], 0, "COLOR", np.uint8, self.data[3].colors) - - def testMetadataRoute(self): - """Tests that the /meshes route returns correct metadata for meshes.""" - response = self.server.get( - "/data/plugin/mesh/meshes?run=%s&tag=%s&sample=%d" % - (self.runs[0], self.names[0], 0)) - self.assertEqual(200, response.status_code) - metadata = test_utils.deserialize_json_response(response.get_data()) - self.assertEqual(len(metadata), self.steps) - self.assertAllEqual(metadata[0]["content_type"], - plugin_data_pb2.MeshPluginData.VERTEX) - self.assertAllEqual(metadata[0]["data_shape"], self.data[0].vertices.shape) - - def testsEventsAlwaysSortedByStep(self): - """Tests that events always sorted by step.""" - response = self.server.get( - "/data/plugin/mesh/meshes?run=%s&tag=%s&sample=%d" % - (self.runs[0], self.names[1], 0)) - self.assertEqual(200, response.status_code) - metadata = test_utils.deserialize_json_response(response.get_data()) - for i in range(1, self.steps): - # Step will be equal when two tensors of different content type - # belong to the same mesh. 
- self.assertLessEqual(metadata[i - 1]["step"], - metadata[i]["step"]) - - @mock.patch.object( - event_multiplexer.EventMultiplexer, - "PluginRunToTagToContent", - return_value={"bar": {"foo": "".encode("utf-8")}}, - ) - def testMetadataComputedOnce(self, run_to_tag_mock): - """Tests that metadata mapping computed once.""" - self.plugin.prepare_metadata() - self.plugin.prepare_metadata() - self.assertEqual(1, run_to_tag_mock.call_count) - - def testIsActive(self): - self.assertTrue(self.plugin.is_active()) - - @mock.patch.object( - event_multiplexer.EventMultiplexer, - "PluginRunToTagToContent", - return_value={}) - def testIsInactive(self, get_random_mesh_stub): - self.assertFalse(self.plugin.is_active()) + """Tests for mesh plugin server.""" + + def setUp(self): + # We use numpy.random to generate meshes. We seed to avoid non-determinism + # in this test. + np.random.seed(17) + + # Log dir to save temp events into. + self.log_dir = self.get_temp_dir() + + # Create mesh summary. + with tf.compat.v1.Graph().as_default(): + tf_placeholder = tf.compat.v1.placeholder + sess = tf.compat.v1.Session() + point_cloud = test_utils.get_random_mesh(1000) + point_cloud_vertices = tf_placeholder( + tf.float32, point_cloud.vertices.shape + ) + + mesh_no_color = test_utils.get_random_mesh(2000, add_faces=True) + mesh_no_color_extended = test_utils.get_random_mesh( + 2500, add_faces=True + ) + mesh_no_color_vertices = tf_placeholder(tf.float32, [1, None, 3]) + mesh_no_color_faces = tf_placeholder(tf.int32, [1, None, 3]) + + mesh_color = test_utils.get_random_mesh( + 3000, add_faces=True, add_colors=True + ) + mesh_color_vertices = tf_placeholder( + tf.float32, mesh_color.vertices.shape + ) + mesh_color_faces = tf_placeholder(tf.int32, mesh_color.faces.shape) + mesh_color_colors = tf_placeholder( + tf.uint8, mesh_color.colors.shape + ) + + self.data = [ + point_cloud, + mesh_no_color, + mesh_no_color_extended, + mesh_color, + ] + + # In case when name is present and display_name is 
not, we will reuse name + # as display_name. Summaries below intended to test both cases. + self.names = ["point_cloud", "mesh_no_color", "mesh_color"] + summary.op( + self.names[0], + point_cloud_vertices, + description="just point cloud", + ) + summary.op( + self.names[1], + mesh_no_color_vertices, + faces=mesh_no_color_faces, + display_name="name_to_display_in_ui", + description="beautiful mesh in grayscale", + ) + summary.op( + self.names[2], + mesh_color_vertices, + faces=mesh_color_faces, + colors=mesh_color_colors, + description="mesh with random colors", + ) + + merged_summary_op = tf.compat.v1.summary.merge_all() + self.runs = ["bar"] + self.steps = 20 + bar_directory = os.path.join(self.log_dir, self.runs[0]) + with tensorboard_test_util.FileWriterCache.get( + bar_directory + ) as writer: + writer.add_graph(sess.graph) + for step in range(self.steps): + # Alternate between two random meshes with different number of + # vertices. + no_color = ( + mesh_no_color + if step % 2 == 0 + else mesh_no_color_extended + ) + with patch.object(time, "time", return_value=step): + writer.add_summary( + sess.run( + merged_summary_op, + feed_dict={ + point_cloud_vertices: point_cloud.vertices, + mesh_no_color_vertices: no_color.vertices, + mesh_no_color_faces: no_color.faces, + mesh_color_vertices: mesh_color.vertices, + mesh_color_faces: mesh_color.faces, + mesh_color_colors: mesh_color.colors, + }, + ), + global_step=step, + ) + + # Start a server that will receive requests. + self.multiplexer = event_multiplexer.EventMultiplexer( + {"bar": bar_directory,} + ) + self.context = base_plugin.TBContext( + logdir=self.log_dir, multiplexer=self.multiplexer + ) + self.plugin = mesh_plugin.MeshPlugin(self.context) + # Wait until after plugin construction to reload the multiplexer because the + # plugin caches data from the multiplexer upon construction and this affects + # logic tested later down. 
+ # TODO(https://github.com/tensorflow/tensorboard/issues/2579): Eliminate the + # caching of data at construction time and move this Reload() up to just + # after the multiplexer is created. + self.multiplexer.Reload() + wsgi_app = application.TensorBoardWSGI([self.plugin]) + self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) + self.routes = self.plugin.get_plugin_apps() + + def tearDown(self): + shutil.rmtree(self.log_dir, ignore_errors=True) + + def testRoutes(self): + """Tests that the /tags route offers the correct run to tag mapping.""" + self.assertIsInstance(self.routes["/tags"], collections.Callable) + self.assertIsInstance(self.routes["/meshes"], collections.Callable) + self.assertIsInstance(self.routes["/data"], collections.Callable) + + def testTagsRoute(self): + """Tests that the /tags route offers the correct run to tag mapping.""" + response = self.server.get("/data/plugin/mesh/tags") + self.assertEqual(200, response.status_code) + tags = test_utils.deserialize_json_response(response.get_data()) + self.assertIn(self.runs[0], tags) + for name in self.names: + self.assertIn(name, tags[self.runs[0]]) + + def validate_data_response( + self, run, tag, sample, content_type, dtype, ground_truth_data, step=0 + ): + """Makes request and checks that response has expected data.""" + response = self.server.get( + "/data/plugin/mesh/data?run=%s&tag=%s&sample=%d&content_type=" + "%s&step=%d" % (run, tag, sample, content_type, step) + ) + self.assertEqual(200, response.status_code) + data = test_utils.deserialize_array_buffer_response( + next(response.response), dtype + ) + self.assertEqual(ground_truth_data.reshape(-1).tolist(), data.tolist()) + + def testDataRoute(self): + """Tests that the /data route returns correct data for meshes.""" + self.validate_data_response( + self.runs[0], + self.names[0], + 0, + "VERTEX", + np.float32, + self.data[0].vertices, + ) + + self.validate_data_response( + self.runs[0], self.names[1], 0, "FACE", np.int32, 
self.data[1].faces + ) + + # Validate that the same summary has mesh with different number of faces at + # different step=1. + self.validate_data_response( + self.runs[0], + self.names[1], + 0, + "FACE", + np.int32, + self.data[2].faces, + step=1, + ) + + self.validate_data_response( + self.runs[0], + self.names[2], + 0, + "COLOR", + np.uint8, + self.data[3].colors, + ) + + def testMetadataRoute(self): + """Tests that the /meshes route returns correct metadata for meshes.""" + response = self.server.get( + "/data/plugin/mesh/meshes?run=%s&tag=%s&sample=%d" + % (self.runs[0], self.names[0], 0) + ) + self.assertEqual(200, response.status_code) + metadata = test_utils.deserialize_json_response(response.get_data()) + self.assertEqual(len(metadata), self.steps) + self.assertAllEqual( + metadata[0]["content_type"], plugin_data_pb2.MeshPluginData.VERTEX + ) + self.assertAllEqual( + metadata[0]["data_shape"], self.data[0].vertices.shape + ) + + def testsEventsAlwaysSortedByStep(self): + """Tests that events always sorted by step.""" + response = self.server.get( + "/data/plugin/mesh/meshes?run=%s&tag=%s&sample=%d" + % (self.runs[0], self.names[1], 0) + ) + self.assertEqual(200, response.status_code) + metadata = test_utils.deserialize_json_response(response.get_data()) + for i in range(1, self.steps): + # Step will be equal when two tensors of different content type + # belong to the same mesh. 
+ self.assertLessEqual(metadata[i - 1]["step"], metadata[i]["step"]) + + @mock.patch.object( + event_multiplexer.EventMultiplexer, + "PluginRunToTagToContent", + return_value={"bar": {"foo": "".encode("utf-8")}}, + ) + def testMetadataComputedOnce(self, run_to_tag_mock): + """Tests that metadata mapping computed once.""" + self.plugin.prepare_metadata() + self.plugin.prepare_metadata() + self.assertEqual(1, run_to_tag_mock.call_count) + + def testIsActive(self): + self.assertTrue(self.plugin.is_active()) + + @mock.patch.object( + event_multiplexer.EventMultiplexer, + "PluginRunToTagToContent", + return_value={}, + ) + def testIsInactive(self, get_random_mesh_stub): + self.assertFalse(self.plugin.is_active()) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/mesh/metadata.py b/tensorboard/plugins/mesh/metadata.py index 3ea386c682..b9702f5bc3 100644 --- a/tensorboard/plugins/mesh/metadata.py +++ b/tensorboard/plugins/mesh/metadata.py @@ -24,8 +24,9 @@ MeshTensor = collections.namedtuple( - 'MeshTensor', ('data', 'content_type', 'data_type')) -PLUGIN_NAME = 'mesh' + "MeshTensor", ("data", "content_type", "data_type") +) +PLUGIN_NAME = "mesh" # The most recent value for the `version` field of the # `MeshPluginData` proto. @@ -33,100 +34,111 @@ def get_components_bitmask(content_types): - """Creates bitmask for all existing components of the summary. + """Creates bitmask for all existing components of the summary. - Args: - content_type: list of plugin_data_pb2.MeshPluginData.ContentType, - representing all components related to the summary. - Returns: bitmask based on passed tensors. 
- """ - components = 0 - for content_type in content_types: - if content_type == plugin_data_pb2.MeshPluginData.UNDEFINED: - raise ValueError('Cannot include UNDEFINED content type in mask.') - components = components | (1 << content_type) - return components + Args: + content_type: list of plugin_data_pb2.MeshPluginData.ContentType, + representing all components related to the summary. + Returns: bitmask based on passed tensors. + """ + components = 0 + for content_type in content_types: + if content_type == plugin_data_pb2.MeshPluginData.UNDEFINED: + raise ValueError("Cannot include UNDEFINED content type in mask.") + components = components | (1 << content_type) + return components def get_current_version(): - """Returns current verions of the proto.""" - return _PROTO_VERSION + """Returns current verions of the proto.""" + return _PROTO_VERSION def get_instance_name(name, content_type): - """Returns a unique instance name for a given summary related to the mesh.""" - return '%s_%s' % ( - name, - plugin_data_pb2.MeshPluginData.ContentType.Name(content_type)) - - -def create_summary_metadata(name, - display_name, - content_type, - components, - shape, - description=None, - json_config=None): - """Creates summary metadata which defined at MeshPluginData proto. - - Arguments: - name: Original merged (summaries of different types) summary name. - display_name: The display name used in TensorBoard. - content_type: Value from MeshPluginData.ContentType enum describing data. - components: Bitmask representing present parts (vertices, colors, etc.) that - belong to the summary. - shape: list of dimensions sizes of the tensor. - description: The description to show in TensorBoard. - json_config: A string, JSON-serialized dictionary of ThreeJS classes - configuration. - - Returns: - A `summary_pb2.SummaryMetadata` protobuf object. - """ - # Shape should be at least BxNx3 where B represents the batch dimensions - # and N - the number of points, each with x,y,z coordinates. 
- if len(shape) != 3: - raise ValueError( - 'Tensor shape should be of shape BxNx3, but got %s.' % str(shape)) - mesh_plugin_data = plugin_data_pb2.MeshPluginData( - version=get_current_version(), - name=name, - content_type=content_type, - components=components, - shape=shape, - json_config=json_config) - content = mesh_plugin_data.SerializeToString() - return summary_pb2.SummaryMetadata( - display_name=display_name, # Will not be used in TensorBoard UI. - summary_description=description, - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name=PLUGIN_NAME, - content=content)) + """Returns a unique instance name for a given summary related to the + mesh.""" + return "%s_%s" % ( + name, + plugin_data_pb2.MeshPluginData.ContentType.Name(content_type), + ) + + +def create_summary_metadata( + name, + display_name, + content_type, + components, + shape, + description=None, + json_config=None, +): + """Creates summary metadata which defined at MeshPluginData proto. + + Arguments: + name: Original merged (summaries of different types) summary name. + display_name: The display name used in TensorBoard. + content_type: Value from MeshPluginData.ContentType enum describing data. + components: Bitmask representing present parts (vertices, colors, etc.) that + belong to the summary. + shape: list of dimensions sizes of the tensor. + description: The description to show in TensorBoard. + json_config: A string, JSON-serialized dictionary of ThreeJS classes + configuration. + + Returns: + A `summary_pb2.SummaryMetadata` protobuf object. + """ + # Shape should be at least BxNx3 where B represents the batch dimensions + # and N - the number of points, each with x,y,z coordinates. + if len(shape) != 3: + raise ValueError( + "Tensor shape should be of shape BxNx3, but got %s." 
% str(shape) + ) + mesh_plugin_data = plugin_data_pb2.MeshPluginData( + version=get_current_version(), + name=name, + content_type=content_type, + components=components, + shape=shape, + json_config=json_config, + ) + content = mesh_plugin_data.SerializeToString() + return summary_pb2.SummaryMetadata( + display_name=display_name, # Will not be used in TensorBoard UI. + summary_description=description, + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, content=content + ), + ) def parse_plugin_metadata(content): - """Parse summary metadata to a Python object. - - Arguments: - content: The `content` field of a `SummaryMetadata` proto - corresponding to the mesh plugin. - - Returns: - A `MeshPluginData` protobuf object. - Raises: Error if the version of the plugin is not supported. - """ - if not isinstance(content, bytes): - raise TypeError('Content type must be bytes.') - result = plugin_data_pb2.MeshPluginData.FromString(content) - if not 0 <= result.version <= get_current_version(): - raise ValueError('Unknown metadata version: %s. The latest version known to ' - 'this build of TensorBoard is %s; perhaps a newer build is ' - 'available?' % (result.version, get_current_version())) - # Add components field to older version of the proto. - if result.components == 0: - result.components = get_components_bitmask([ - plugin_data_pb2.MeshPluginData.VERTEX, - plugin_data_pb2.MeshPluginData.FACE, - plugin_data_pb2.MeshPluginData.COLOR, - ]) - return result + """Parse summary metadata to a Python object. + + Arguments: + content: The `content` field of a `SummaryMetadata` proto + corresponding to the mesh plugin. + + Returns: + A `MeshPluginData` protobuf object. + Raises: Error if the version of the plugin is not supported. 
+ """ + if not isinstance(content, bytes): + raise TypeError("Content type must be bytes.") + result = plugin_data_pb2.MeshPluginData.FromString(content) + if not 0 <= result.version <= get_current_version(): + raise ValueError( + "Unknown metadata version: %s. The latest version known to " + "this build of TensorBoard is %s; perhaps a newer build is " + "available?" % (result.version, get_current_version()) + ) + # Add components field to older version of the proto. + if result.components == 0: + result.components = get_components_bitmask( + [ + plugin_data_pb2.MeshPluginData.VERTEX, + plugin_data_pb2.MeshPluginData.FACE, + plugin_data_pb2.MeshPluginData.COLOR, + ] + ) + return result diff --git a/tensorboard/plugins/mesh/metadata_test.py b/tensorboard/plugins/mesh/metadata_test.py index a498caf9bc..6bbd0d3681 100644 --- a/tensorboard/plugins/mesh/metadata_test.py +++ b/tensorboard/plugins/mesh/metadata_test.py @@ -27,84 +27,94 @@ class MetadataTest(tf.test.TestCase): + def _create_metadata(self, shape=None): + """Creates metadata with dummy data.""" + self.name = "unique_name" + self.display_name = "my mesh" + self.json_config = "{}" + if shape is None: + shape = [1, 100, 3] + self.shape = shape + self.components = 14 + self.summary_metadata = metadata.create_summary_metadata( + self.name, + self.display_name, + plugin_data_pb2.MeshPluginData.ContentType.Value("VERTEX"), + self.components, + self.shape, + json_config=self.json_config, + ) - def _create_metadata(self, shape=None): - """Creates metadata with dummy data.""" - self.name = 'unique_name' - self.display_name = 'my mesh' - self.json_config = '{}' - if shape is None: - shape = [1, 100, 3] - self.shape = shape - self.components = 14 - self.summary_metadata = metadata.create_summary_metadata( - self.name, - self.display_name, - plugin_data_pb2.MeshPluginData.ContentType.Value('VERTEX'), - self.components, - self.shape, - json_config=self.json_config) + def test_get_instance_name(self): + """Tests proper 
creation of instance name based on display_name.""" + display_name = "my_mesh" + instance_name = metadata.get_instance_name( + display_name, + plugin_data_pb2.MeshPluginData.ContentType.Value("VERTEX"), + ) + self.assertEqual("%s_VERTEX" % display_name, instance_name) - def test_get_instance_name(self): - """Tests proper creation of instance name based on display_name.""" - display_name = 'my_mesh' - instance_name = metadata.get_instance_name( - display_name, - plugin_data_pb2.MeshPluginData.ContentType.Value('VERTEX')) - self.assertEqual('%s_VERTEX' % display_name, instance_name) + def test_create_summary_metadata(self): + """Tests MeshPlugin metadata creation.""" + self._create_metadata() + self.assertEqual(self.display_name, self.summary_metadata.display_name) + self.assertEqual( + metadata.PLUGIN_NAME, self.summary_metadata.plugin_data.plugin_name + ) - def test_create_summary_metadata(self): - """Tests MeshPlugin metadata creation.""" - self._create_metadata() - self.assertEqual(self.display_name, - self.summary_metadata.display_name) - self.assertEqual(metadata.PLUGIN_NAME, - self.summary_metadata.plugin_data.plugin_name) + def test_parse_plugin_metadata(self): + """Tests parsing of saved plugin metadata.""" + self._create_metadata() + parsed_metadata = metadata.parse_plugin_metadata( + self.summary_metadata.plugin_data.content + ) + self.assertEqual(self.name, parsed_metadata.name) + self.assertEqual( + plugin_data_pb2.MeshPluginData.ContentType.Value("VERTEX"), + parsed_metadata.content_type, + ) + self.assertEqual(self.shape, parsed_metadata.shape) + self.assertEqual(self.json_config, parsed_metadata.json_config) + self.assertEqual(self.components, parsed_metadata.components) - def test_parse_plugin_metadata(self): - """Tests parsing of saved plugin metadata.""" - self._create_metadata() - parsed_metadata = metadata.parse_plugin_metadata( - self.summary_metadata.plugin_data.content) - self.assertEqual(self.name, parsed_metadata.name) - 
self.assertEqual(plugin_data_pb2.MeshPluginData.ContentType.Value('VERTEX'), - parsed_metadata.content_type) - self.assertEqual(self.shape, parsed_metadata.shape) - self.assertEqual(self.json_config, parsed_metadata.json_config) - self.assertEqual(self.components, parsed_metadata.components) + def test_metadata_version(self): + """Tests that only the latest version of metadata is supported.""" + with patch.object(metadata, "get_current_version", return_value=100): + self._create_metadata() + # Change the version. + with patch.object(metadata, "get_current_version", return_value=1): + # Try to parse metadata from a prior version. + with self.assertRaises(ValueError): + metadata.parse_plugin_metadata( + self.summary_metadata.plugin_data.content + ) - def test_metadata_version(self): - """Tests that only the latest version of metadata is supported.""" - with patch.object(metadata, 'get_current_version', return_value=100): - self._create_metadata() - # Change the version. - with patch.object(metadata, 'get_current_version', return_value=1): - # Try to parse metadata from a prior version. - with self.assertRaises(ValueError): - metadata.parse_plugin_metadata( - self.summary_metadata.plugin_data.content) + def test_tensor_shape(self): + """Tests that target tensor should be of particular shape.""" + with six.assertRaisesRegex( + self, ValueError, r"Tensor shape should be of shape BxNx3.*" + ): + self._create_metadata([1]) - def test_tensor_shape(self): - """Tests that target tensor should be of particular shape.""" - with six.assertRaisesRegex( - self, ValueError, r'Tensor shape should be of shape BxNx3.*'): - self._create_metadata([1]) + def test_metadata_format(self): + """Tests that metadata content must be passed as a serialized + string.""" + with six.assertRaisesRegex( + self, TypeError, r"Content type must be bytes." 
+ ): + metadata.parse_plugin_metadata(123) - def test_metadata_format(self): - """Tests that metadata content must be passed as a serialized string.""" - with six.assertRaisesRegex(self, TypeError, r'Content type must be bytes.'): - metadata.parse_plugin_metadata(123) + def test_default_components(self): + """Tests that defult components are added when necessary.""" + self._create_metadata() + stored_metadata = plugin_data_pb2.MeshPluginData( + version=metadata.get_current_version(), components=0 + ) + parsed_metadata = metadata.parse_plugin_metadata( + stored_metadata.SerializeToString() + ) + self.assertGreater(parsed_metadata.components, 0) - def test_default_components(self): - """Tests that defult components are added when necessary.""" - self._create_metadata() - stored_metadata = plugin_data_pb2.MeshPluginData( - version=metadata.get_current_version(), - components=0) - parsed_metadata = metadata.parse_plugin_metadata( - stored_metadata.SerializeToString()) - self.assertGreater(parsed_metadata.components, 0) - -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/mesh/summary.py b/tensorboard/plugins/mesh/summary.py index 1755ad4d4d..b8ea5bd2ca 100644 --- a/tensorboard/plugins/mesh/summary.py +++ b/tensorboard/plugins/mesh/summary.py @@ -30,176 +30,218 @@ def _get_tensor_summary( - name, display_name, description, tensor, content_type, components, - json_config, collections): - """Creates a tensor summary with summary metadata. - - Args: - name: Uniquely identifiable name of the summary op. Could be replaced by - combination of name and type to make it unique even outside of this - summary. - display_name: Will be used as the display name in TensorBoard. - Defaults to `tag`. - description: A longform readable description of the summary data. Markdown - is supported. - tensor: Tensor to display in summary. - content_type: Type of content inside the Tensor. 
- components: Bitmask representing present parts (vertices, colors, etc.) that - belong to the summary. - json_config: A string, JSON-serialized dictionary of ThreeJS classes - configuration. - collections: List of collections to add this summary to. - - Returns: - Tensor summary with metadata. - """ - tensor = tf.convert_to_tensor(value=tensor) - shape = tensor.shape.as_list() - shape = [dim if dim is not None else -1 for dim in shape] - tensor_metadata = metadata.create_summary_metadata( - name, - display_name, - content_type, - components, - shape, - description, - json_config=json_config) - tensor_summary = tf.compat.v1.summary.tensor_summary( - metadata.get_instance_name(name, content_type), - tensor, - summary_metadata=tensor_metadata, - collections=collections) - return tensor_summary - - -def _get_display_name(name, display_name): - """Returns display_name from display_name and name.""" - if display_name is None: - return name - return display_name - - -def _get_json_config(config_dict): - """Parses and returns JSON string from python dictionary.""" - json_config = '{}' - if config_dict is not None: - json_config = json.dumps(config_dict, sort_keys=True) - return json_config - - -def op(name, vertices, faces=None, colors=None, display_name=None, - description=None, collections=None, config_dict=None): - """Creates a TensorFlow summary op for mesh rendering. - - Args: - name: A name for this summary operation. - vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D - coordinates of vertices. - faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of - vertices within each triangle. - colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each - vertex. - display_name: If set, will be used as the display name in TensorBoard. - Defaults to `name`. - description: A longform readable description of the summary data. Markdown - is supported. - collections: Which TensorFlow graph collections to add the summary op to. 
- Defaults to `['summaries']`. Can usually be ignored. - config_dict: Dictionary with ThreeJS classes names and configuration. - - Returns: - Merged summary for mesh/point cloud representation. - """ - display_name = _get_display_name(name, display_name) - json_config = _get_json_config(config_dict) - - # All tensors representing a single mesh will be represented as separate - # summaries internally. Those summaries will be regrouped on the client before - # rendering. - summaries = [] - tensors = [ - metadata.MeshTensor( - vertices, plugin_data_pb2.MeshPluginData.VERTEX, tf.float32), - metadata.MeshTensor(faces, plugin_data_pb2.MeshPluginData.FACE, tf.int32), - metadata.MeshTensor( - colors, plugin_data_pb2.MeshPluginData.COLOR, tf.uint8) - ] - tensors = [tensor for tensor in tensors if tensor.data is not None] - - components = metadata.get_components_bitmask([ - tensor.content_type for tensor in tensors]) - - for tensor in tensors: - summaries.append( - _get_tensor_summary(name, display_name, description, tensor.data, - tensor.content_type, components, json_config, - collections)) - - all_summaries = tf.compat.v1.summary.merge( - summaries, collections=collections, name=name) - return all_summaries - - -def pb(name, - vertices, - faces=None, - colors=None, - display_name=None, - description=None, - config_dict=None): - """Create a mesh summary to save in pb format. - - Args: - name: A name for this summary operation. - vertices: numpy array of shape `[dim_1, ..., dim_n, 3]` representing the 3D - coordinates of vertices. - faces: numpy array of shape `[dim_1, ..., dim_n, 3]` containing indices of - vertices within each triangle. - colors: numpy array of shape `[dim_1, ..., dim_n, 3]` containing colors for - each vertex. - display_name: If set, will be used as the display name in TensorBoard. - Defaults to `name`. - description: A longform readable description of the summary data. Markdown - is supported. 
- config_dict: Dictionary with ThreeJS classes names and configuration. - - Returns: - Instance of tf.Summary class. - """ - display_name = _get_display_name(name, display_name) - json_config = _get_json_config(config_dict) - - summaries = [] - tensors = [ - metadata.MeshTensor( - vertices, plugin_data_pb2.MeshPluginData.VERTEX, tf.float32), - metadata.MeshTensor(faces, plugin_data_pb2.MeshPluginData.FACE, tf.int32), - metadata.MeshTensor( - colors, plugin_data_pb2.MeshPluginData.COLOR, tf.uint8) - ] - tensors = [tensor for tensor in tensors if tensor.data is not None] - components = metadata.get_components_bitmask([ - tensor.content_type for tensor in tensors]) - for tensor in tensors: - shape = tensor.data.shape + name, + display_name, + description, + tensor, + content_type, + components, + json_config, + collections, +): + """Creates a tensor summary with summary metadata. + + Args: + name: Uniquely identifiable name of the summary op. Could be replaced by + combination of name and type to make it unique even outside of this + summary. + display_name: Will be used as the display name in TensorBoard. + Defaults to `tag`. + description: A longform readable description of the summary data. Markdown + is supported. + tensor: Tensor to display in summary. + content_type: Type of content inside the Tensor. + components: Bitmask representing present parts (vertices, colors, etc.) that + belong to the summary. + json_config: A string, JSON-serialized dictionary of ThreeJS classes + configuration. + collections: List of collections to add this summary to. + + Returns: + Tensor summary with metadata. 
+ """ + tensor = tf.convert_to_tensor(value=tensor) + shape = tensor.shape.as_list() shape = [dim if dim is not None else -1 for dim in shape] - tensor_proto = tf.compat.v1.make_tensor_proto( - tensor.data, dtype=tensor.data_type) - summary_metadata = metadata.create_summary_metadata( + tensor_metadata = metadata.create_summary_metadata( name, display_name, - tensor.content_type, + content_type, components, shape, description, - json_config=json_config) - tag = metadata.get_instance_name(name, tensor.content_type) - summaries.append((tag, summary_metadata, tensor_proto)) - - summary = tf.compat.v1.Summary() - for tag, summary_metadata, tensor_proto in summaries: - tf_summary_metadata = tf.compat.v1.SummaryMetadata.FromString( - summary_metadata.SerializeToString()) - summary.value.add( - tag=tag, metadata=tf_summary_metadata, tensor=tensor_proto) - return summary + json_config=json_config, + ) + tensor_summary = tf.compat.v1.summary.tensor_summary( + metadata.get_instance_name(name, content_type), + tensor, + summary_metadata=tensor_metadata, + collections=collections, + ) + return tensor_summary + + +def _get_display_name(name, display_name): + """Returns display_name from display_name and name.""" + if display_name is None: + return name + return display_name + + +def _get_json_config(config_dict): + """Parses and returns JSON string from python dictionary.""" + json_config = "{}" + if config_dict is not None: + json_config = json.dumps(config_dict, sort_keys=True) + return json_config + + +def op( + name, + vertices, + faces=None, + colors=None, + display_name=None, + description=None, + collections=None, + config_dict=None, +): + """Creates a TensorFlow summary op for mesh rendering. + + Args: + name: A name for this summary operation. + vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D + coordinates of vertices. + faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of + vertices within each triangle. 
+ colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each + vertex. + display_name: If set, will be used as the display name in TensorBoard. + Defaults to `name`. + description: A longform readable description of the summary data. Markdown + is supported. + collections: Which TensorFlow graph collections to add the summary op to. + Defaults to `['summaries']`. Can usually be ignored. + config_dict: Dictionary with ThreeJS classes names and configuration. + + Returns: + Merged summary for mesh/point cloud representation. + """ + display_name = _get_display_name(name, display_name) + json_config = _get_json_config(config_dict) + + # All tensors representing a single mesh will be represented as separate + # summaries internally. Those summaries will be regrouped on the client before + # rendering. + summaries = [] + tensors = [ + metadata.MeshTensor( + vertices, plugin_data_pb2.MeshPluginData.VERTEX, tf.float32 + ), + metadata.MeshTensor( + faces, plugin_data_pb2.MeshPluginData.FACE, tf.int32 + ), + metadata.MeshTensor( + colors, plugin_data_pb2.MeshPluginData.COLOR, tf.uint8 + ), + ] + tensors = [tensor for tensor in tensors if tensor.data is not None] + + components = metadata.get_components_bitmask( + [tensor.content_type for tensor in tensors] + ) + + for tensor in tensors: + summaries.append( + _get_tensor_summary( + name, + display_name, + description, + tensor.data, + tensor.content_type, + components, + json_config, + collections, + ) + ) + + all_summaries = tf.compat.v1.summary.merge( + summaries, collections=collections, name=name + ) + return all_summaries + + +def pb( + name, + vertices, + faces=None, + colors=None, + display_name=None, + description=None, + config_dict=None, +): + """Create a mesh summary to save in pb format. + + Args: + name: A name for this summary operation. + vertices: numpy array of shape `[dim_1, ..., dim_n, 3]` representing the 3D + coordinates of vertices. 
+ faces: numpy array of shape `[dim_1, ..., dim_n, 3]` containing indices of + vertices within each triangle. + colors: numpy array of shape `[dim_1, ..., dim_n, 3]` containing colors for + each vertex. + display_name: If set, will be used as the display name in TensorBoard. + Defaults to `name`. + description: A longform readable description of the summary data. Markdown + is supported. + config_dict: Dictionary with ThreeJS classes names and configuration. + + Returns: + Instance of tf.Summary class. + """ + display_name = _get_display_name(name, display_name) + json_config = _get_json_config(config_dict) + + summaries = [] + tensors = [ + metadata.MeshTensor( + vertices, plugin_data_pb2.MeshPluginData.VERTEX, tf.float32 + ), + metadata.MeshTensor( + faces, plugin_data_pb2.MeshPluginData.FACE, tf.int32 + ), + metadata.MeshTensor( + colors, plugin_data_pb2.MeshPluginData.COLOR, tf.uint8 + ), + ] + tensors = [tensor for tensor in tensors if tensor.data is not None] + components = metadata.get_components_bitmask( + [tensor.content_type for tensor in tensors] + ) + for tensor in tensors: + shape = tensor.data.shape + shape = [dim if dim is not None else -1 for dim in shape] + tensor_proto = tf.compat.v1.make_tensor_proto( + tensor.data, dtype=tensor.data_type + ) + summary_metadata = metadata.create_summary_metadata( + name, + display_name, + tensor.content_type, + components, + shape, + description, + json_config=json_config, + ) + tag = metadata.get_instance_name(name, tensor.content_type) + summaries.append((tag, summary_metadata, tensor_proto)) + + summary = tf.compat.v1.Summary() + for tag, summary_metadata, tensor_proto in summaries: + tf_summary_metadata = tf.compat.v1.SummaryMetadata.FromString( + summary_metadata.SerializeToString() + ) + summary.value.add( + tag=tag, metadata=tf_summary_metadata, tensor=tensor_proto + ) + return summary diff --git a/tensorboard/plugins/mesh/summary_test.py b/tensorboard/plugins/mesh/summary_test.py index 
c044e296b1..20ff4d5688 100644 --- a/tensorboard/plugins/mesh/summary_test.py +++ b/tensorboard/plugins/mesh/summary_test.py @@ -29,86 +29,107 @@ class MeshSummaryTest(tf.test.TestCase): + def pb_via_op(self, summary_op): + """Parses pb proto.""" + actual_pbtxt = summary_op.eval() + actual_proto = summary_pb2.Summary() + actual_proto.ParseFromString(actual_pbtxt) + return actual_proto - def pb_via_op(self, summary_op): - """Parses pb proto.""" - actual_pbtxt = summary_op.eval() - actual_proto = summary_pb2.Summary() - actual_proto.ParseFromString(actual_pbtxt) - return actual_proto + def get_components(self, proto): + return metadata.parse_plugin_metadata( + proto.metadata.plugin_data.content + ).components - def get_components(self, proto): - return metadata.parse_plugin_metadata( - proto.metadata.plugin_data.content).components + def verify_proto(self, proto, name): + """Validates proto.""" + self.assertEqual(3, len(proto.value)) + self.assertEqual("%s_VERTEX" % name, proto.value[0].tag) + self.assertEqual("%s_FACE" % name, proto.value[1].tag) + self.assertEqual("%s_COLOR" % name, proto.value[2].tag) - def verify_proto(self, proto, name): - """Validates proto.""" - self.assertEqual(3, len(proto.value)) - self.assertEqual("%s_VERTEX" % name, proto.value[0].tag) - self.assertEqual("%s_FACE" % name, proto.value[1].tag) - self.assertEqual("%s_COLOR" % name, proto.value[2].tag) + self.assertEqual(14, self.get_components(proto.value[0])) + self.assertEqual(14, self.get_components(proto.value[1])) + self.assertEqual(14, self.get_components(proto.value[2])) - self.assertEqual(14, self.get_components(proto.value[0])) - self.assertEqual(14, self.get_components(proto.value[1])) - self.assertEqual(14, self.get_components(proto.value[2])) + def test_get_tensor_summary(self): + """Tests proper creation of tensor summary with mesh plugin + metadata.""" + name = "my_mesh" + display_name = "my_display_name" + description = "my mesh is the best of meshes" + tensor_data = 
test_utils.get_random_mesh(100) + components = 14 + with tf.compat.v1.Graph().as_default(): + tensor_summary = summary._get_tensor_summary( + name, + display_name, + description, + tensor_data.vertices, + plugin_data_pb2.MeshPluginData.VERTEX, + components, + "", + None, + ) + with self.test_session(): + proto = self.pb_via_op(tensor_summary) + self.assertEqual("%s_VERTEX" % name, proto.value[0].tag) + self.assertEqual( + metadata.PLUGIN_NAME, + proto.value[0].metadata.plugin_data.plugin_name, + ) + self.assertEqual( + components, self.get_components(proto.value[0]) + ) - def test_get_tensor_summary(self): - """Tests proper creation of tensor summary with mesh plugin metadata.""" - name = "my_mesh" - display_name = "my_display_name" - description = "my mesh is the best of meshes" - tensor_data = test_utils.get_random_mesh(100) - components = 14 - with tf.compat.v1.Graph().as_default(): - tensor_summary = summary._get_tensor_summary( - name, display_name, description, tensor_data.vertices, - plugin_data_pb2.MeshPluginData.VERTEX, components, "", None) - with self.test_session(): - proto = self.pb_via_op(tensor_summary) - self.assertEqual("%s_VERTEX" % name, proto.value[0].tag) - self.assertEqual(metadata.PLUGIN_NAME, - proto.value[0].metadata.plugin_data.plugin_name) - self.assertEqual(components, self.get_components(proto.value[0])) + def test_op(self): + """Tests merged summary with different types of data.""" + name = "my_mesh" + tensor_data = test_utils.get_random_mesh( + 100, add_faces=True, add_colors=True + ) + config_dict = {"foo": 1} + with tf.compat.v1.Graph().as_default(): + tensor_summary = summary.op( + name, + tensor_data.vertices, + faces=tensor_data.faces, + colors=tensor_data.colors, + config_dict=config_dict, + ) + with self.test_session() as sess: + proto = self.pb_via_op(tensor_summary) + self.verify_proto(proto, name) + plugin_metadata = metadata.parse_plugin_metadata( + proto.value[0].metadata.plugin_data.content + ) + self.assertEqual( + 
json.dumps(config_dict, sort_keys=True), + plugin_metadata.json_config, + ) - def test_op(self): - """Tests merged summary with different types of data.""" - name = "my_mesh" - tensor_data = test_utils.get_random_mesh( - 100, add_faces=True, add_colors=True) - config_dict = {"foo": 1} - with tf.compat.v1.Graph().as_default(): - tensor_summary = summary.op( - name, - tensor_data.vertices, - faces=tensor_data.faces, - colors=tensor_data.colors, - config_dict=config_dict) - with self.test_session() as sess: - proto = self.pb_via_op(tensor_summary) + def test_pb(self): + """Tests merged summary protobuf with different types of data.""" + name = "my_mesh" + tensor_data = test_utils.get_random_mesh( + 100, add_faces=True, add_colors=True + ) + config_dict = {"foo": 1} + proto = summary.pb( + name, + tensor_data.vertices, + faces=tensor_data.faces, + colors=tensor_data.colors, + config_dict=config_dict, + ) self.verify_proto(proto, name) plugin_metadata = metadata.parse_plugin_metadata( - proto.value[0].metadata.plugin_data.content) + proto.value[0].metadata.plugin_data.content + ) self.assertEqual( - json.dumps(config_dict, sort_keys=True), plugin_metadata.json_config) - - def test_pb(self): - """Tests merged summary protobuf with different types of data.""" - name = "my_mesh" - tensor_data = test_utils.get_random_mesh( - 100, add_faces=True, add_colors=True) - config_dict = {"foo": 1} - proto = summary.pb( - name, - tensor_data.vertices, - faces=tensor_data.faces, - colors=tensor_data.colors, - config_dict=config_dict) - self.verify_proto(proto, name) - plugin_metadata = metadata.parse_plugin_metadata( - proto.value[0].metadata.plugin_data.content) - self.assertEqual( - json.dumps(config_dict, sort_keys=True), plugin_metadata.json_config) + json.dumps(config_dict, sort_keys=True), plugin_metadata.json_config + ) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/mesh/summary_v2.py b/tensorboard/plugins/mesh/summary_v2.py index 
b641464b95..2b428d439f 100644 --- a/tensorboard/plugins/mesh/summary_v2.py +++ b/tensorboard/plugins/mesh/summary_v2.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Mesh summaries and TensorFlow operations to create them. V2 versions""" +"""Mesh summaries and TensorFlow operations to create them. + +V2 versions +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -27,158 +30,188 @@ def _write_summary( - name, description, tensor, content_type, components, - json_config, step): - """Creates a tensor summary with summary metadata. - - Args: - name: A name for this summary. The summary tag used for TensorBoard will - be this name prefixed by any active name scopes. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. - tensor: Tensor to display in summary. - content_type: Type of content inside the Tensor. - components: Bitmask representing present parts (vertices, colors, etc.) that - belong to the summary. - json_config: A string, JSON-serialized dictionary of ThreeJS classes - configuration. - step: Explicit `int64`-castable monotonic step value for this summary. If - omitted, this defaults to `tf.summary.experimental.get_step()`, which must - not be None. - - Returns: - A boolean indicating if summary was saved successfully or not. 
- """ - tensor = tf.convert_to_tensor(value=tensor) - shape = tensor.shape.as_list() - shape = [dim if dim is not None else -1 for dim in shape] - tensor_metadata = metadata.create_summary_metadata( - name, - None, # display_name - content_type, - components, - shape, - description, - json_config=json_config) - return tf.summary.write( + name, description, tensor, content_type, components, json_config, step +): + """Creates a tensor summary with summary metadata. + + Args: + name: A name for this summary. The summary tag used for TensorBoard will + be this name prefixed by any active name scopes. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. + tensor: Tensor to display in summary. + content_type: Type of content inside the Tensor. + components: Bitmask representing present parts (vertices, colors, etc.) that + belong to the summary. + json_config: A string, JSON-serialized dictionary of ThreeJS classes + configuration. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + + Returns: + A boolean indicating if summary was saved successfully or not. 
+ """ + tensor = tf.convert_to_tensor(value=tensor) + shape = tensor.shape.as_list() + shape = [dim if dim is not None else -1 for dim in shape] + tensor_metadata = metadata.create_summary_metadata( + name, + None, # display_name + content_type, + components, + shape, + description, + json_config=json_config, + ) + return tf.summary.write( tag=metadata.get_instance_name(name, content_type), tensor=tensor, step=step, - metadata=tensor_metadata) + metadata=tensor_metadata, + ) def _get_json_config(config_dict): - """Parses and returns JSON string from python dictionary.""" - json_config = '{}' - if config_dict is not None: - json_config = json.dumps(config_dict, sort_keys=True) - return json_config - - -def mesh(name, vertices, faces=None, colors=None, config_dict=None, step=None, - description=None): - """Writes a TensorFlow mesh summary. - - Args: - name: A name for this summary. The summary tag used for TensorBoard will - be this name prefixed by any active name scopes. - vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D - coordinates of vertices. - faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of - vertices within each triangle. - colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each - vertex. - config_dict: Dictionary with ThreeJS classes names and configuration. - step: Explicit `int64`-castable monotonic step value for this summary. If - omitted, this defaults to `tf.summary.experimental.get_step()`, which must - not be None. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. - - Returns: - True if all components of the mesh were saved successfully and False - otherwise. - """ - json_config = _get_json_config(config_dict) - - # All tensors representing a single mesh will be represented as separate - # summaries internally. Those summaries will be regrouped on the client before - # rendering. 
- tensors = [ - metadata.MeshTensor( - vertices, plugin_data_pb2.MeshPluginData.VERTEX, tf.float32), - metadata.MeshTensor(faces, plugin_data_pb2.MeshPluginData.FACE, tf.int32), - metadata.MeshTensor( - colors, plugin_data_pb2.MeshPluginData.COLOR, tf.uint8) - ] - tensors = [tensor for tensor in tensors if tensor.data is not None] - - components = metadata.get_components_bitmask([ - tensor.content_type for tensor in tensors]) - - summary_scope = ( - getattr(tf.summary.experimental, 'summary_scope', None) or - tf.summary.summary_scope) - all_success = True - with summary_scope(name, 'mesh_summary', values=tensors): + """Parses and returns JSON string from python dictionary.""" + json_config = "{}" + if config_dict is not None: + json_config = json.dumps(config_dict, sort_keys=True) + return json_config + + +def mesh( + name, + vertices, + faces=None, + colors=None, + config_dict=None, + step=None, + description=None, +): + """Writes a TensorFlow mesh summary. + + Args: + name: A name for this summary. The summary tag used for TensorBoard will + be this name prefixed by any active name scopes. + vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D + coordinates of vertices. + faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of + vertices within each triangle. + colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each + vertex. + config_dict: Dictionary with ThreeJS classes names and configuration. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. + + Returns: + True if all components of the mesh were saved successfully and False + otherwise. 
+ """ + json_config = _get_json_config(config_dict) + + # All tensors representing a single mesh will be represented as separate + # summaries internally. Those summaries will be regrouped on the client before + # rendering. + tensors = [ + metadata.MeshTensor( + vertices, plugin_data_pb2.MeshPluginData.VERTEX, tf.float32 + ), + metadata.MeshTensor( + faces, plugin_data_pb2.MeshPluginData.FACE, tf.int32 + ), + metadata.MeshTensor( + colors, plugin_data_pb2.MeshPluginData.COLOR, tf.uint8 + ), + ] + tensors = [tensor for tensor in tensors if tensor.data is not None] + + components = metadata.get_components_bitmask( + [tensor.content_type for tensor in tensors] + ) + + summary_scope = ( + getattr(tf.summary.experimental, "summary_scope", None) + or tf.summary.summary_scope + ) + all_success = True + with summary_scope(name, "mesh_summary", values=tensors): + for tensor in tensors: + all_success = all_success and _write_summary( + name, + description, + tensor.data, + tensor.content_type, + components, + json_config, + step, + ) + + return all_success + + +def mesh_pb( + tag, vertices, faces=None, colors=None, config_dict=None, description=None +): + """Create a mesh summary to save in pb format. + + Args: + tag: String tag for the summary. + vertices: numpy array of shape `[dim_1, ..., dim_n, 3]` representing the 3D + coordinates of vertices. + faces: numpy array of shape `[dim_1, ..., dim_n, 3]` containing indices of + vertices within each triangle. + colors: numpy array of shape `[dim_1, ..., dim_n, 3]` containing colors for + each vertex. + config_dict: Dictionary with ThreeJS classes names and configuration. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. + + Returns: + Instance of tf.Summary class. 
+ """ + json_config = _get_json_config(config_dict) + + summaries = [] + tensors = [ + metadata.MeshTensor( + vertices, plugin_data_pb2.MeshPluginData.VERTEX, tf.float32 + ), + metadata.MeshTensor( + faces, plugin_data_pb2.MeshPluginData.FACE, tf.int32 + ), + metadata.MeshTensor( + colors, plugin_data_pb2.MeshPluginData.COLOR, tf.uint8 + ), + ] + tensors = [tensor for tensor in tensors if tensor.data is not None] + components = metadata.get_components_bitmask( + [tensor.content_type for tensor in tensors] + ) for tensor in tensors: - all_success = all_success and _write_summary( - name, description, tensor.data, tensor.content_type, - components, json_config, step) - - return all_success - - -def mesh_pb(tag, vertices, faces=None, colors=None, config_dict=None, - description=None): - """Create a mesh summary to save in pb format. - - Args: - tag: String tag for the summary. - vertices: numpy array of shape `[dim_1, ..., dim_n, 3]` representing the 3D - coordinates of vertices. - faces: numpy array of shape `[dim_1, ..., dim_n, 3]` containing indices of - vertices within each triangle. - colors: numpy array of shape `[dim_1, ..., dim_n, 3]` containing colors for - each vertex. - config_dict: Dictionary with ThreeJS classes names and configuration. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. - - Returns: - Instance of tf.Summary class. 
- """ - json_config = _get_json_config(config_dict) - - summaries = [] - tensors = [ - metadata.MeshTensor( - vertices, plugin_data_pb2.MeshPluginData.VERTEX, tf.float32), - metadata.MeshTensor(faces, plugin_data_pb2.MeshPluginData.FACE, tf.int32), - metadata.MeshTensor( - colors, plugin_data_pb2.MeshPluginData.COLOR, tf.uint8) - ] - tensors = [tensor for tensor in tensors if tensor.data is not None] - components = metadata.get_components_bitmask([ - tensor.content_type for tensor in tensors]) - for tensor in tensors: - shape = tensor.data.shape - shape = [dim if dim is not None else -1 for dim in shape] - tensor_proto = tensor_util.make_tensor_proto( - tensor.data, dtype=tensor.data_type) - summary_metadata = metadata.create_summary_metadata( - tag, - None, # display_name - tensor.content_type, - components, - shape, - description, - json_config=json_config) - instance_tag = metadata.get_instance_name(tag, tensor.content_type) - summaries.append((instance_tag, summary_metadata, tensor_proto)) - - summary = summary_pb2.Summary() - for instance_tag, summary_metadata, tensor_proto in summaries: - summary.value.add( - tag=instance_tag, metadata=summary_metadata, tensor=tensor_proto) - return summary + shape = tensor.data.shape + shape = [dim if dim is not None else -1 for dim in shape] + tensor_proto = tensor_util.make_tensor_proto( + tensor.data, dtype=tensor.data_type + ) + summary_metadata = metadata.create_summary_metadata( + tag, + None, # display_name + tensor.content_type, + components, + shape, + description, + json_config=json_config, + ) + instance_tag = metadata.get_instance_name(tag, tensor.content_type) + summaries.append((instance_tag, summary_metadata, tensor_proto)) + + summary = summary_pb2.Summary() + for instance_tag, summary_metadata, tensor_proto in summaries: + summary.value.add( + tag=instance_tag, metadata=summary_metadata, tensor=tensor_proto + ) + return summary diff --git a/tensorboard/plugins/mesh/summary_v2_test.py 
b/tensorboard/plugins/mesh/summary_v2_test.py index b97faf2fc1..e0299e4b9c 100644 --- a/tensorboard/plugins/mesh/summary_v2_test.py +++ b/tensorboard/plugins/mesh/summary_v2_test.py @@ -33,125 +33,144 @@ try: - tf2.__version__ # Force lazy import to resolve + tf2.__version__ # Force lazy import to resolve except ImportError: - tf2 = None + tf2 = None try: - tf.compat.v1.enable_eager_execution() + tf.compat.v1.enable_eager_execution() except AttributeError: - # TF 2.0 doesn't have this symbol because eager is the default. - pass + # TF 2.0 doesn't have this symbol because eager is the default. + pass class MeshSummaryV2Test(tf.test.TestCase): - def setUp(self): - super(MeshSummaryV2Test, self).setUp() - if tf2 is None: - self.skipTest('v2 summary API not available') - - def mesh_events(self, *args, **kwargs): - self.write_mesh_event(*args, **kwargs) - event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), '*'))) - self.assertEqual(len(event_files), 1) - events = list(tf.compat.v1.train.summary_iterator(event_files[0])) - # Expect a boilerplate event for the file_version, then the vertices - # summary one. - num_events = 2 - # All additional tensors (i.e. colors or faces) will be stored as separate - # events, so account for them as well. - num_events += len(frozenset(["colors", "faces"]).intersection(kwargs)) - self.assertEqual(len(events), num_events) - # Delete the event file to reset to an empty directory for later calls. 
- os.remove(event_files[0]) - return events[1:] - - def write_mesh_event(self, *args, **kwargs): - kwargs.setdefault('step', 1) - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - summary.mesh(*args, **kwargs) - writer.close() - - def get_metadata(self, event): - return metadata.parse_plugin_metadata( - event.summary.value[0].metadata.plugin_data.content) - - def test_step(self): - """Tests that different components of mesh summary share the same step.""" - tensor_data = test_utils.get_random_mesh( - 100, add_faces=True, add_colors=True) - config_dict = {"foo": 1} - events = self.mesh_events( - 'a', - tensor_data.vertices, - faces=tensor_data.faces, - colors=tensor_data.colors, - config_dict=config_dict, - step=333) - self.assertEqual(333, events[0].step) - self.assertEqual(333, events[1].step) - self.assertEqual(333, events[2].step) - - def test_tags(self): - """Tests proper tags for each event/tensor.""" - tensor_data = test_utils.get_random_mesh( - 100, add_faces=True, add_colors=True) - config_dict = {"foo": 1} - name = 'foo' - events = self.mesh_events( - name, - tensor_data.vertices, - faces=tensor_data.faces, - colors=tensor_data.colors, - config_dict=config_dict, - step=333) - expected_names_set = frozenset( - name_tpl % name for name_tpl in ["%s_VERTEX", "%s_FACE", "%s_COLOR"]) - actual_names_set = frozenset([event.summary.value[0].tag for event in events]) - self.assertEqual(expected_names_set, actual_names_set) - expected_bitmask = metadata.get_components_bitmask([ - plugin_data_pb2.MeshPluginData.VERTEX, - plugin_data_pb2.MeshPluginData.FACE, - plugin_data_pb2.MeshPluginData.COLOR, - ]) - for event in events: - self.assertEqual(expected_bitmask, self.get_metadata(event).components) - - def test_pb(self): - """Tests ProtoBuf interface.""" - name = "my_mesh" - tensor_data = test_utils.get_random_mesh( - 100, add_faces=True, add_colors=True) - config_dict = {"foo": 1} - proto = summary.mesh_pb( - name, - 
tensor_data.vertices, - faces=tensor_data.faces, - colors=tensor_data.colors, - config_dict=config_dict) - plugin_metadata = metadata.parse_plugin_metadata( - proto.value[0].metadata.plugin_data.content) - self.assertEqual( - json.dumps(config_dict, sort_keys=True), plugin_metadata.json_config) + def setUp(self): + super(MeshSummaryV2Test, self).setUp() + if tf2 is None: + self.skipTest("v2 summary API not available") + + def mesh_events(self, *args, **kwargs): + self.write_mesh_event(*args, **kwargs) + event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), "*"))) + self.assertEqual(len(event_files), 1) + events = list(tf.compat.v1.train.summary_iterator(event_files[0])) + # Expect a boilerplate event for the file_version, then the vertices + # summary one. + num_events = 2 + # All additional tensors (i.e. colors or faces) will be stored as separate + # events, so account for them as well. + num_events += len(frozenset(["colors", "faces"]).intersection(kwargs)) + self.assertEqual(len(events), num_events) + # Delete the event file to reset to an empty directory for later calls. 
+ os.remove(event_files[0]) + return events[1:] + + def write_mesh_event(self, *args, **kwargs): + kwargs.setdefault("step", 1) + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + summary.mesh(*args, **kwargs) + writer.close() + + def get_metadata(self, event): + return metadata.parse_plugin_metadata( + event.summary.value[0].metadata.plugin_data.content + ) + + def test_step(self): + """Tests that different components of mesh summary share the same + step.""" + tensor_data = test_utils.get_random_mesh( + 100, add_faces=True, add_colors=True + ) + config_dict = {"foo": 1} + events = self.mesh_events( + "a", + tensor_data.vertices, + faces=tensor_data.faces, + colors=tensor_data.colors, + config_dict=config_dict, + step=333, + ) + self.assertEqual(333, events[0].step) + self.assertEqual(333, events[1].step) + self.assertEqual(333, events[2].step) + + def test_tags(self): + """Tests proper tags for each event/tensor.""" + tensor_data = test_utils.get_random_mesh( + 100, add_faces=True, add_colors=True + ) + config_dict = {"foo": 1} + name = "foo" + events = self.mesh_events( + name, + tensor_data.vertices, + faces=tensor_data.faces, + colors=tensor_data.colors, + config_dict=config_dict, + step=333, + ) + expected_names_set = frozenset( + name_tpl % name for name_tpl in ["%s_VERTEX", "%s_FACE", "%s_COLOR"] + ) + actual_names_set = frozenset( + [event.summary.value[0].tag for event in events] + ) + self.assertEqual(expected_names_set, actual_names_set) + expected_bitmask = metadata.get_components_bitmask( + [ + plugin_data_pb2.MeshPluginData.VERTEX, + plugin_data_pb2.MeshPluginData.FACE, + plugin_data_pb2.MeshPluginData.COLOR, + ] + ) + for event in events: + self.assertEqual( + expected_bitmask, self.get_metadata(event).components + ) + + def test_pb(self): + """Tests ProtoBuf interface.""" + name = "my_mesh" + tensor_data = test_utils.get_random_mesh( + 100, add_faces=True, add_colors=True + ) + config_dict = {"foo": 1} + 
proto = summary.mesh_pb( + name, + tensor_data.vertices, + faces=tensor_data.faces, + colors=tensor_data.colors, + config_dict=config_dict, + ) + plugin_metadata = metadata.parse_plugin_metadata( + proto.value[0].metadata.plugin_data.content + ) + self.assertEqual( + json.dumps(config_dict, sort_keys=True), plugin_metadata.json_config + ) class MeshSummaryV2GraphTest(MeshSummaryV2Test, tf.test.TestCase): - def write_mesh_event(self, *args, **kwargs): - kwargs.setdefault('step', 1) - # Hack to extract current scope since there's no direct API for it. - with tf.name_scope('_') as temp_scope: - scope = temp_scope.rstrip('/_') - @tf2.function - def graph_fn(): - # Recreate the active scope inside the defun since it won't propagate. - with tf.name_scope(scope): - summary.mesh(*args, **kwargs) - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - graph_fn() - writer.close() - - -if __name__ == '__main__': - tf.test.main() + def write_mesh_event(self, *args, **kwargs): + kwargs.setdefault("step", 1) + # Hack to extract current scope since there's no direct API for it. + with tf.name_scope("_") as temp_scope: + scope = temp_scope.rstrip("/_") + + @tf2.function + def graph_fn(): + # Recreate the active scope inside the defun since it won't propagate. 
+ with tf.name_scope(scope): + summary.mesh(*args, **kwargs) + + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + graph_fn() + writer.close() + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/mesh/test_utils.py b/tensorboard/plugins/mesh/test_utils.py index 572dab95e2..fef0e222d0 100644 --- a/tensorboard/plugins/mesh/test_utils.py +++ b/tensorboard/plugins/mesh/test_utils.py @@ -30,67 +30,69 @@ from tensorboard.compat.proto import summary_pb2 from tensorboard.util import tb_logging -Mesh = collections.namedtuple('Mesh', ('vertices', 'faces', 'colors')) +Mesh = collections.namedtuple("Mesh", ("vertices", "faces", "colors")) logger = tb_logging.get_logger() -def get_random_mesh(num_vertices, - add_faces=False, - add_colors=False, - batch_size=1): - """Returns a random point cloud, optionally with random disconnected faces. - - Args: - num_vertices: Number of vertices in the point cloud or mesh. - add_faces: Random faces will be generated and added to the mesh when True. - add_colors: Random colors will be assigned to each vertex when True. Each - color will be in a range of [0, 255]. - batch_size: Size of batch dimension in output array. - - Returns: - Mesh namedtuple with vertices and optionally with faces and/or colors. - """ - vertices = np.random.random([num_vertices, 3]) * 1000 - # Add batch dimension. 
- vertices = np.tile(vertices, [batch_size, 1, 1]) - faces = None - colors = None - if add_faces: - arranged_vertices = np.random.permutation(num_vertices) - faces = [] - for i in range(num_vertices - 2): - faces.append([ - arranged_vertices[i], arranged_vertices[i + 1], - arranged_vertices[i + 2] - ]) - faces = np.array(faces) - faces = np.tile(faces, [batch_size, 1, 1]).astype(np.int32) - if add_colors: - colors = np.random.randint(low=0, high=255, size=[num_vertices, 3]) - colors = np.tile(colors, [batch_size, 1, 1]).astype(np.uint8) - return Mesh(vertices.astype(np.float32), faces, colors) +def get_random_mesh( + num_vertices, add_faces=False, add_colors=False, batch_size=1 +): + """Returns a random point cloud, optionally with random disconnected faces. + + Args: + num_vertices: Number of vertices in the point cloud or mesh. + add_faces: Random faces will be generated and added to the mesh when True. + add_colors: Random colors will be assigned to each vertex when True. Each + color will be in a range of [0, 255]. + batch_size: Size of batch dimension in output array. + + Returns: + Mesh namedtuple with vertices and optionally with faces and/or colors. + """ + vertices = np.random.random([num_vertices, 3]) * 1000 + # Add batch dimension. + vertices = np.tile(vertices, [batch_size, 1, 1]) + faces = None + colors = None + if add_faces: + arranged_vertices = np.random.permutation(num_vertices) + faces = [] + for i in range(num_vertices - 2): + faces.append( + [ + arranged_vertices[i], + arranged_vertices[i + 1], + arranged_vertices[i + 2], + ] + ) + faces = np.array(faces) + faces = np.tile(faces, [batch_size, 1, 1]).astype(np.int32) + if add_colors: + colors = np.random.randint(low=0, high=255, size=[num_vertices, 3]) + colors = np.tile(colors, [batch_size, 1, 1]).astype(np.uint8) + return Mesh(vertices.astype(np.float32), faces, colors) def deserialize_json_response(byte_content): - """Deserializes byte content that is a JSON encoding. 
+ """Deserializes byte content that is a JSON encoding. - Args: - byte_content: The byte content of a response. + Args: + byte_content: The byte content of a response. - Returns: - The deserialized python object decoded from JSON. - """ - return json.loads(byte_content.decode('utf-8')) + Returns: + The deserialized python object decoded from JSON. + """ + return json.loads(byte_content.decode("utf-8")) def deserialize_array_buffer_response(byte_content, data_type): - """Deserializes arraybuffer response and optionally tiles the array. + """Deserializes arraybuffer response and optionally tiles the array. - Args: - byte_content: The byte content of a response. - data_type: Numpy type to parse data with. + Args: + byte_content: The byte content of a response. + data_type: Numpy type to parse data with. - Returns: - Flat numpy array with the data. - """ - return np.frombuffer(byte_content, dtype=data_type) + Returns: + Flat numpy array with the data. + """ + return np.frombuffer(byte_content, dtype=data_type) diff --git a/tensorboard/plugins/pr_curve/metadata.py b/tensorboard/plugins/pr_curve/metadata.py index fa1a55b4b7..4ca7961d4f 100644 --- a/tensorboard/plugins/pr_curve/metadata.py +++ b/tensorboard/plugins/pr_curve/metadata.py @@ -24,7 +24,7 @@ logger = tb_logging.get_logger() -PLUGIN_NAME = 'pr_curves' +PLUGIN_NAME = "pr_curves" # Indices for obtaining various values from the tensor stored in a summary. TRUE_POSITIVES_INDEX = 0 @@ -38,46 +38,52 @@ # `PrCurvePluginData` proto. PROTO_VERSION = 0 + def create_summary_metadata(display_name, description, num_thresholds): - """Create a `summary_pb2.SummaryMetadata` proto for pr_curves plugin data. + """Create a `summary_pb2.SummaryMetadata` proto for pr_curves plugin data. - Arguments: - display_name: The display name used in TensorBoard. - description: The description to show in TensorBoard. - num_thresholds: The number of thresholds to use for PR curves. 
+ Arguments: + display_name: The display name used in TensorBoard. + description: The description to show in TensorBoard. + num_thresholds: The number of thresholds to use for PR curves. - Returns: - A `summary_pb2.SummaryMetadata` protobuf object. - """ - pr_curve_plugin_data = plugin_data_pb2.PrCurvePluginData( - version=PROTO_VERSION, num_thresholds=num_thresholds) - content = pr_curve_plugin_data.SerializeToString() - return summary_pb2.SummaryMetadata( - display_name=display_name, - summary_description=description, - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name=PLUGIN_NAME, - content=content)) + Returns: + A `summary_pb2.SummaryMetadata` protobuf object. + """ + pr_curve_plugin_data = plugin_data_pb2.PrCurvePluginData( + version=PROTO_VERSION, num_thresholds=num_thresholds + ) + content = pr_curve_plugin_data.SerializeToString() + return summary_pb2.SummaryMetadata( + display_name=display_name, + summary_description=description, + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, content=content + ), + ) def parse_plugin_metadata(content): - """Parse summary metadata to a Python object. + """Parse summary metadata to a Python object. - Arguments: - content: The `content` field of a `SummaryMetadata` proto - corresponding to the pr_curves plugin. + Arguments: + content: The `content` field of a `SummaryMetadata` proto + corresponding to the pr_curves plugin. - Returns: - A `PrCurvesPlugin` protobuf object. - """ - if not isinstance(content, bytes): - raise TypeError('Content type must be bytes') - result = plugin_data_pb2.PrCurvePluginData.FromString(content) - if result.version == 0: - return result - else: - logger.warn( - 'Unknown metadata version: %s. The latest version known to ' - 'this build of TensorBoard is %s; perhaps a newer build is ' - 'available?', result.version, PROTO_VERSION) - return result + Returns: + A `PrCurvesPlugin` protobuf object. 
+ """ + if not isinstance(content, bytes): + raise TypeError("Content type must be bytes") + result = plugin_data_pb2.PrCurvePluginData.FromString(content) + if result.version == 0: + return result + else: + logger.warn( + "Unknown metadata version: %s. The latest version known to " + "this build of TensorBoard is %s; perhaps a newer build is " + "available?", + result.version, + PROTO_VERSION, + ) + return result diff --git a/tensorboard/plugins/pr_curve/pr_curve_demo.py b/tensorboard/plugins/pr_curve/pr_curve_demo.py index 0cf00c3dc3..97643c12ee 100644 --- a/tensorboard/plugins/pr_curve/pr_curve_demo.py +++ b/tensorboard/plugins/pr_curve/pr_curve_demo.py @@ -42,195 +42,238 @@ tf.compat.v1.disable_v2_behavior() FLAGS = flags.FLAGS -flags.DEFINE_string('logdir', '/tmp/pr_curve_demo', - 'Directory into which to write TensorBoard data.') +flags.DEFINE_string( + "logdir", + "/tmp/pr_curve_demo", + "Directory into which to write TensorBoard data.", +) + +flags.DEFINE_integer( + "steps", 10, "Number of steps to generate for each PR curve." +) -flags.DEFINE_integer('steps', 10, - 'Number of steps to generate for each PR curve.') def start_runs( - logdir, - steps, - run_name, - thresholds, - mask_every_other_prediction=False): - """Generate a PR curve with precision and recall evenly weighted. - - Arguments: - logdir: The directory into which to store all the runs' data. - steps: The number of steps to run for. - run_name: The name of the run. - thresholds: The number of thresholds to use for PR curves. - mask_every_other_prediction: Whether to mask every other prediction by - alternating weights between 0 and 1. - """ - tf.compat.v1.reset_default_graph() - tf.compat.v1.set_random_seed(42) - - # Create a normal distribution layer used to generate true color labels. - distribution = tf.compat.v1.distributions.Normal(loc=0., scale=142.) - - # Sample the distribution to generate colors. Lets generate different numbers - # of each color. 
The first dimension is the count of examples. - - # The calls to sample() are given fixed random seed values that are "magic" - # in that they correspond to the default seeds for those ops when the PR - # curve test (which depends on this code) was written. We've pinned these - # instead of continuing to use the defaults since the defaults are based on - # node IDs from the sequence of nodes added to the graph, which can silently - # change when this code or any TF op implementations it uses are modified. - - # TODO(nickfelt): redo the PR curve test to avoid reliance on random seeds. - - # Generate reds. - number_of_reds = 100 - true_reds = tf.clip_by_value( - tf.concat([ - 255 - tf.abs(distribution.sample([number_of_reds, 1], seed=11)), - tf.abs(distribution.sample([number_of_reds, 2], seed=34)) - ], axis=1), - 0, 255) - - # Generate greens. - number_of_greens = 200 - true_greens = tf.clip_by_value( - tf.concat([ - tf.abs(distribution.sample([number_of_greens, 1], seed=61)), - 255 - tf.abs(distribution.sample([number_of_greens, 1], seed=82)), - tf.abs(distribution.sample([number_of_greens, 1], seed=105)) - ], axis=1), - 0, 255) - - # Generate blues. - number_of_blues = 150 - true_blues = tf.clip_by_value( - tf.concat([ - tf.abs(distribution.sample([number_of_blues, 2], seed=132)), - 255 - tf.abs(distribution.sample([number_of_blues, 1], seed=153)) - ], axis=1), - 0, 255) - - # Assign each color a vector of 3 booleans based on its true label. - labels = tf.concat([ - tf.tile(tf.constant([[True, False, False]]), (number_of_reds, 1)), - tf.tile(tf.constant([[False, True, False]]), (number_of_greens, 1)), - tf.tile(tf.constant([[False, False, True]]), (number_of_blues, 1)), - ], axis=0) - - # We introduce 3 normal distributions. They are used to predict whether a - # color falls under a certain class (based on distances from corners of the - # color triangle). The distributions vary per color. We have the distributions - # narrow over time. 
- initial_standard_deviations = [v + FLAGS.steps for v in (158, 200, 242)] - iteration = tf.compat.v1.placeholder(tf.int32, shape=[]) - red_predictor = tf.compat.v1.distributions.Normal( - loc=0., - scale=tf.cast( - initial_standard_deviations[0] - iteration, - dtype=tf.float32)) - green_predictor = tf.compat.v1.distributions.Normal( - loc=0., - scale=tf.cast( - initial_standard_deviations[1] - iteration, - dtype=tf.float32)) - blue_predictor = tf.compat.v1.distributions.Normal( - loc=0., - scale=tf.cast( - initial_standard_deviations[2] - iteration, - dtype=tf.float32)) - - # Make predictions (assign 3 probabilities to each color based on each color's - # distance to each of the 3 corners). We seek double the area in the right - # tail of the normal distribution. - examples = tf.concat([true_reds, true_greens, true_blues], axis=0) - probabilities_colors_are_red = (1 - red_predictor.cdf( - tf.norm(tensor=examples - tf.constant([255., 0, 0]), axis=1))) * 2 - probabilities_colors_are_green = (1 - green_predictor.cdf( - tf.norm(tensor=examples - tf.constant([0, 255., 0]), axis=1))) * 2 - probabilities_colors_are_blue = (1 - blue_predictor.cdf( - tf.norm(tensor=examples - tf.constant([0, 0, 255.]), axis=1))) * 2 - - predictions = ( - probabilities_colors_are_red, - probabilities_colors_are_green, - probabilities_colors_are_blue - ) - - # This is the crucial piece. We write data required for generating PR curves. - # We create 1 summary per class because we create 1 PR curve per class. - for i, color in enumerate(('red', 'green', 'blue')): - description = ('The probabilities used to create this PR curve are ' - 'generated from a normal distribution. Its standard ' - 'deviation is initially %0.0f and decreases over time.' % - initial_standard_deviations[i]) - - weights = None - if mask_every_other_prediction: - # Assign a weight of 0 to every even-indexed prediction. Odd-indexed - # predictions are assigned a default weight of 1. 
- consecutive_indices = tf.reshape( - tf.range(tf.size(input=predictions[i])), tf.shape(input=predictions[i])) - weights = tf.cast(consecutive_indices % 2, dtype=tf.float32) - - summary.op( - name=color, - labels=labels[:, i], - predictions=predictions[i], - num_thresholds=thresholds, - weights=weights, - display_name='classifying %s' % color, - description=description) - merged_summary_op = tf.compat.v1.summary.merge_all() - events_directory = os.path.join(logdir, run_name) - sess = tf.compat.v1.Session() - writer = tf.compat.v1.summary.FileWriter(events_directory, sess.graph) - - for step in xrange(steps): - feed_dict = { - iteration: step, - } - merged_summary = sess.run(merged_summary_op, feed_dict=feed_dict) - writer.add_summary(merged_summary, step) - - writer.close() + logdir, steps, run_name, thresholds, mask_every_other_prediction=False +): + """Generate a PR curve with precision and recall evenly weighted. + + Arguments: + logdir: The directory into which to store all the runs' data. + steps: The number of steps to run for. + run_name: The name of the run. + thresholds: The number of thresholds to use for PR curves. + mask_every_other_prediction: Whether to mask every other prediction by + alternating weights between 0 and 1. + """ + tf.compat.v1.reset_default_graph() + tf.compat.v1.set_random_seed(42) + + # Create a normal distribution layer used to generate true color labels. + distribution = tf.compat.v1.distributions.Normal(loc=0.0, scale=142.0) + + # Sample the distribution to generate colors. Lets generate different numbers + # of each color. The first dimension is the count of examples. + + # The calls to sample() are given fixed random seed values that are "magic" + # in that they correspond to the default seeds for those ops when the PR + # curve test (which depends on this code) was written. 
We've pinned these + # instead of continuing to use the defaults since the defaults are based on + # node IDs from the sequence of nodes added to the graph, which can silently + # change when this code or any TF op implementations it uses are modified. + + # TODO(nickfelt): redo the PR curve test to avoid reliance on random seeds. + + # Generate reds. + number_of_reds = 100 + true_reds = tf.clip_by_value( + tf.concat( + [ + 255 - tf.abs(distribution.sample([number_of_reds, 1], seed=11)), + tf.abs(distribution.sample([number_of_reds, 2], seed=34)), + ], + axis=1, + ), + 0, + 255, + ) + + # Generate greens. + number_of_greens = 200 + true_greens = tf.clip_by_value( + tf.concat( + [ + tf.abs(distribution.sample([number_of_greens, 1], seed=61)), + 255 + - tf.abs(distribution.sample([number_of_greens, 1], seed=82)), + tf.abs(distribution.sample([number_of_greens, 1], seed=105)), + ], + axis=1, + ), + 0, + 255, + ) + + # Generate blues. + number_of_blues = 150 + true_blues = tf.clip_by_value( + tf.concat( + [ + tf.abs(distribution.sample([number_of_blues, 2], seed=132)), + 255 + - tf.abs(distribution.sample([number_of_blues, 1], seed=153)), + ], + axis=1, + ), + 0, + 255, + ) + + # Assign each color a vector of 3 booleans based on its true label. + labels = tf.concat( + [ + tf.tile(tf.constant([[True, False, False]]), (number_of_reds, 1)), + tf.tile(tf.constant([[False, True, False]]), (number_of_greens, 1)), + tf.tile(tf.constant([[False, False, True]]), (number_of_blues, 1)), + ], + axis=0, + ) + + # We introduce 3 normal distributions. They are used to predict whether a + # color falls under a certain class (based on distances from corners of the + # color triangle). The distributions vary per color. We have the distributions + # narrow over time. 
+ initial_standard_deviations = [v + FLAGS.steps for v in (158, 200, 242)] + iteration = tf.compat.v1.placeholder(tf.int32, shape=[]) + red_predictor = tf.compat.v1.distributions.Normal( + loc=0.0, + scale=tf.cast( + initial_standard_deviations[0] - iteration, dtype=tf.float32 + ), + ) + green_predictor = tf.compat.v1.distributions.Normal( + loc=0.0, + scale=tf.cast( + initial_standard_deviations[1] - iteration, dtype=tf.float32 + ), + ) + blue_predictor = tf.compat.v1.distributions.Normal( + loc=0.0, + scale=tf.cast( + initial_standard_deviations[2] - iteration, dtype=tf.float32 + ), + ) + + # Make predictions (assign 3 probabilities to each color based on each color's + # distance to each of the 3 corners). We seek double the area in the right + # tail of the normal distribution. + examples = tf.concat([true_reds, true_greens, true_blues], axis=0) + probabilities_colors_are_red = ( + 1 + - red_predictor.cdf( + tf.norm(tensor=examples - tf.constant([255.0, 0, 0]), axis=1) + ) + ) * 2 + probabilities_colors_are_green = ( + 1 + - green_predictor.cdf( + tf.norm(tensor=examples - tf.constant([0, 255.0, 0]), axis=1) + ) + ) * 2 + probabilities_colors_are_blue = ( + 1 + - blue_predictor.cdf( + tf.norm(tensor=examples - tf.constant([0, 0, 255.0]), axis=1) + ) + ) * 2 + + predictions = ( + probabilities_colors_are_red, + probabilities_colors_are_green, + probabilities_colors_are_blue, + ) + + # This is the crucial piece. We write data required for generating PR curves. + # We create 1 summary per class because we create 1 PR curve per class. + for i, color in enumerate(("red", "green", "blue")): + description = ( + "The probabilities used to create this PR curve are " + "generated from a normal distribution. Its standard " + "deviation is initially %0.0f and decreases over time." + % initial_standard_deviations[i] + ) + + weights = None + if mask_every_other_prediction: + # Assign a weight of 0 to every even-indexed prediction. 
Odd-indexed + # predictions are assigned a default weight of 1. + consecutive_indices = tf.reshape( + tf.range(tf.size(input=predictions[i])), + tf.shape(input=predictions[i]), + ) + weights = tf.cast(consecutive_indices % 2, dtype=tf.float32) + + summary.op( + name=color, + labels=labels[:, i], + predictions=predictions[i], + num_thresholds=thresholds, + weights=weights, + display_name="classifying %s" % color, + description=description, + ) + merged_summary_op = tf.compat.v1.summary.merge_all() + events_directory = os.path.join(logdir, run_name) + sess = tf.compat.v1.Session() + writer = tf.compat.v1.summary.FileWriter(events_directory, sess.graph) + + for step in xrange(steps): + feed_dict = { + iteration: step, + } + merged_summary = sess.run(merged_summary_op, feed_dict=feed_dict) + writer.add_summary(merged_summary, step) + + writer.close() + def run_all(logdir, steps, thresholds, verbose=False): - """Generate PR curve summaries. - - Arguments: - logdir: The directory into which to store all the runs' data. - steps: The number of steps to run for. - verbose: Whether to print the names of runs into stdout during execution. - thresholds: The number of thresholds to use for PR curves. - """ - # First, we generate data for a PR curve that assigns even weights for - # predictions of all classes. - run_name = 'colors' - if verbose: - print('--- Running: %s' % run_name) - start_runs( - logdir=logdir, - steps=steps, - run_name=run_name, - thresholds=thresholds) - - # Next, we generate data for a PR curve that assigns arbitrary weights to - # predictions. - run_name = 'mask_every_other_prediction' - if verbose: - print('--- Running: %s' % run_name) - start_runs( - logdir=logdir, - steps=steps, - run_name=run_name, - thresholds=thresholds, - mask_every_other_prediction=True) + """Generate PR curve summaries. + + Arguments: + logdir: The directory into which to store all the runs' data. + steps: The number of steps to run for. 
+ verbose: Whether to print the names of runs into stdout during execution. + thresholds: The number of thresholds to use for PR curves. + """ + # First, we generate data for a PR curve that assigns even weights for + # predictions of all classes. + run_name = "colors" + if verbose: + print("--- Running: %s" % run_name) + start_runs( + logdir=logdir, steps=steps, run_name=run_name, thresholds=thresholds + ) + + # Next, we generate data for a PR curve that assigns arbitrary weights to + # predictions. + run_name = "mask_every_other_prediction" + if verbose: + print("--- Running: %s" % run_name) + start_runs( + logdir=logdir, + steps=steps, + run_name=run_name, + thresholds=thresholds, + mask_every_other_prediction=True, + ) + def main(unused_argv): - print('Saving output to %s.' % FLAGS.logdir) - run_all(FLAGS.logdir, FLAGS.steps, 50, verbose=True) - print('Done. Output saved to %s.' % FLAGS.logdir) + print("Saving output to %s." % FLAGS.logdir) + run_all(FLAGS.logdir, FLAGS.steps, 50, verbose=True) + print("Done. Output saved to %s." % FLAGS.logdir) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/tensorboard/plugins/pr_curve/pr_curves_plugin.py b/tensorboard/plugins/pr_curve/pr_curves_plugin.py index 589196af73..9168ce9f05 100644 --- a/tensorboard/plugins/pr_curve/pr_curves_plugin.py +++ b/tensorboard/plugins/pr_curve/pr_curves_plugin.py @@ -30,67 +30,73 @@ class PrCurvesPlugin(base_plugin.TBPlugin): - """A plugin that serves PR curves for individual classes.""" - - plugin_name = metadata.PLUGIN_NAME - - def __init__(self, context): - """Instantiates a PrCurvesPlugin. - Args: - context: A base_plugin.TBContext instance. A magic container that - TensorBoard uses to make objects available to the plugin. 
- """ - self._db_connection_provider = context.db_connection_provider - self._multiplexer = context.multiplexer - - @wrappers.Request.application - def pr_curves_route(self, request): - """A route that returns a JSON mapping between runs and PR curve data. - - Returns: - Given a tag and a comma-separated list of runs (both stored within GET - parameters), fetches a JSON object that maps between run name and objects - containing data required for PR curves for that run. Runs that either - cannot be found or that lack tags will be excluded from the response. - """ - runs = request.args.getlist('run') - if not runs: - return http_util.Respond( - request, 'No runs provided when fetching PR curve data', 400) - - tag = request.args.get('tag') - if not tag: - return http_util.Respond( - request, 'No tag provided when fetching PR curve data', 400) - - try: - response = http_util.Respond( - request, self.pr_curves_impl(runs, tag), 'application/json') - except ValueError as e: - return http_util.Respond(request, str(e), 'text/plain', 400) - - return response - - def pr_curves_impl(self, runs, tag): - """Creates the JSON object for the PR curves response for a run-tag combo. - - Arguments: - runs: A list of runs to fetch the curves for. - tag: The tag to fetch the curves for. - - Raises: - ValueError: If no PR curves could be fetched for a run and tag. - - Returns: - The JSON object for the PR curves route response. - """ - if self._db_connection_provider: - # Serve data from the database. - db = self._db_connection_provider() - - # We select for steps greater than -1 because the writer inserts - # placeholder rows en masse. The check for step filters out those rows. - cursor = db.execute(''' + """A plugin that serves PR curves for individual classes.""" + + plugin_name = metadata.PLUGIN_NAME + + def __init__(self, context): + """Instantiates a PrCurvesPlugin. + + Args: + context: A base_plugin.TBContext instance. 
A magic container that + TensorBoard uses to make objects available to the plugin. + """ + self._db_connection_provider = context.db_connection_provider + self._multiplexer = context.multiplexer + + @wrappers.Request.application + def pr_curves_route(self, request): + """A route that returns a JSON mapping between runs and PR curve data. + + Returns: + Given a tag and a comma-separated list of runs (both stored within GET + parameters), fetches a JSON object that maps between run name and objects + containing data required for PR curves for that run. Runs that either + cannot be found or that lack tags will be excluded from the response. + """ + runs = request.args.getlist("run") + if not runs: + return http_util.Respond( + request, "No runs provided when fetching PR curve data", 400 + ) + + tag = request.args.get("tag") + if not tag: + return http_util.Respond( + request, "No tag provided when fetching PR curve data", 400 + ) + + try: + response = http_util.Respond( + request, self.pr_curves_impl(runs, tag), "application/json" + ) + except ValueError as e: + return http_util.Respond(request, str(e), "text/plain", 400) + + return response + + def pr_curves_impl(self, runs, tag): + """Creates the JSON object for the PR curves response for a run-tag + combo. + + Arguments: + runs: A list of runs to fetch the curves for. + tag: The tag to fetch the curves for. + + Raises: + ValueError: If no PR curves could be fetched for a run and tag. + + Returns: + The JSON object for the PR curves route response. + """ + if self._db_connection_provider: + # Serve data from the database. + db = self._db_connection_provider() + + # We select for steps greater than -1 because the writer inserts + # placeholder rows en masse. The check for step filters out those rows. + cursor = db.execute( + """ SELECT Runs.run_name, Tensors.step, @@ -110,78 +116,102 @@ def pr_curves_impl(self, runs, tag): AND Tags.plugin_name = ? 
AND Tensors.step > -1 ORDER BY Tensors.step - ''' % ','.join(['?'] * len(runs)), runs + [tag, metadata.PLUGIN_NAME]) - response_mapping = {} - for (run, step, wall_time, data, dtype, shape, plugin_data) in cursor: - if run not in response_mapping: - response_mapping[run] = [] - buf = np.frombuffer(data, dtype=tf.DType(dtype).as_numpy_dtype) - data_array = buf.reshape([int(i) for i in shape.split(',')]) - plugin_data_proto = plugin_data_pb2.PrCurvePluginData() - string_buffer = np.frombuffer(plugin_data, dtype=np.dtype('b')) - plugin_data_proto.ParseFromString(tf.compat.as_bytes( - string_buffer.tostring())) - thresholds = self._compute_thresholds(plugin_data_proto.num_thresholds) - entry = self._make_pr_entry(step, wall_time, data_array, thresholds) - response_mapping[run].append(entry) - else: - # Serve data from events files. - response_mapping = {} - for run in runs: - try: - tensor_events = self._multiplexer.Tensors(run, tag) - except KeyError: - raise ValueError( - 'No PR curves could be found for run %r and tag %r' % (run, tag)) - - content = self._multiplexer.SummaryMetadata( - run, tag).plugin_data.content - pr_curve_data = metadata.parse_plugin_metadata(content) - thresholds = self._compute_thresholds(pr_curve_data.num_thresholds) - response_mapping[run] = [ - self._process_tensor_event(e, thresholds) for e in tensor_events] - return response_mapping - - def _compute_thresholds(self, num_thresholds): - """Computes a list of specific thresholds from the number of thresholds. - - Args: - num_thresholds: The number of thresholds. - - Returns: - A list of specific thresholds (floats). - """ - return [float(v) / num_thresholds for v in range(1, num_thresholds + 1)] - - @wrappers.Request.application - def tags_route(self, request): - """A route (HTTP handler) that returns a response with tags. - - Returns: - A response that contains a JSON object. The keys of the object - are all the runs. 
Each run is mapped to a (potentially empty) dictionary - whose keys are tags associated with run and whose values are metadata - (dictionaries). - - The metadata dictionaries contain 2 keys: - - displayName: For the display name used atop visualizations in - TensorBoard. - - description: The description that appears near visualizations upon the - user hovering over a certain icon. - """ - return http_util.Respond( - request, self.tags_impl(), 'application/json') - - def tags_impl(self): - """Creates the JSON object for the tags route response. - - Returns: - The JSON object for the tags route response. - """ - if self._db_connection_provider: - # Read tags from the database. - db = self._db_connection_provider() - cursor = db.execute(''' + """ + % ",".join(["?"] * len(runs)), + runs + [tag, metadata.PLUGIN_NAME], + ) + response_mapping = {} + for ( + run, + step, + wall_time, + data, + dtype, + shape, + plugin_data, + ) in cursor: + if run not in response_mapping: + response_mapping[run] = [] + buf = np.frombuffer(data, dtype=tf.DType(dtype).as_numpy_dtype) + data_array = buf.reshape([int(i) for i in shape.split(",")]) + plugin_data_proto = plugin_data_pb2.PrCurvePluginData() + string_buffer = np.frombuffer(plugin_data, dtype=np.dtype("b")) + plugin_data_proto.ParseFromString( + tf.compat.as_bytes(string_buffer.tostring()) + ) + thresholds = self._compute_thresholds( + plugin_data_proto.num_thresholds + ) + entry = self._make_pr_entry( + step, wall_time, data_array, thresholds + ) + response_mapping[run].append(entry) + else: + # Serve data from events files. 
+ response_mapping = {} + for run in runs: + try: + tensor_events = self._multiplexer.Tensors(run, tag) + except KeyError: + raise ValueError( + "No PR curves could be found for run %r and tag %r" + % (run, tag) + ) + + content = self._multiplexer.SummaryMetadata( + run, tag + ).plugin_data.content + pr_curve_data = metadata.parse_plugin_metadata(content) + thresholds = self._compute_thresholds( + pr_curve_data.num_thresholds + ) + response_mapping[run] = [ + self._process_tensor_event(e, thresholds) + for e in tensor_events + ] + return response_mapping + + def _compute_thresholds(self, num_thresholds): + """Computes a list of specific thresholds from the number of + thresholds. + + Args: + num_thresholds: The number of thresholds. + + Returns: + A list of specific thresholds (floats). + """ + return [float(v) / num_thresholds for v in range(1, num_thresholds + 1)] + + @wrappers.Request.application + def tags_route(self, request): + """A route (HTTP handler) that returns a response with tags. + + Returns: + A response that contains a JSON object. The keys of the object + are all the runs. Each run is mapped to a (potentially empty) dictionary + whose keys are tags associated with run and whose values are metadata + (dictionaries). + + The metadata dictionaries contain 2 keys: + - displayName: For the display name used atop visualizations in + TensorBoard. + - description: The description that appears near visualizations upon the + user hovering over a certain icon. + """ + return http_util.Respond(request, self.tags_impl(), "application/json") + + def tags_impl(self): + """Creates the JSON object for the tags route response. + + Returns: + The JSON object for the tags route response. + """ + if self._db_connection_provider: + # Read tags from the database. + db = self._db_connection_provider() + cursor = db.execute( + """ SELECT Tags.tag_name, Tags.display_name, @@ -191,55 +221,67 @@ def tags_impl(self): ON Tags.run_id = Runs.run_id WHERE Tags.plugin_name = ? 
- ''', (metadata.PLUGIN_NAME,)) - result = {} - for (tag_name, display_name, run_name) in cursor: - if run_name not in result: - result[run_name] = {} - result[run_name][tag_name] = { - 'displayName': display_name, - # TODO(chihuahua): Populate the description. Currently, the tags - # table does not link with the description table. - 'description': '', - } - else: - # Read tags from events files. - runs = self._multiplexer.Runs() - result = {run: {} for run in runs} - - mapping = self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) - for (run, tag_to_content) in six.iteritems(mapping): - for (tag, _) in six.iteritems(tag_to_content): - summary_metadata = self._multiplexer.SummaryMetadata(run, tag) - result[run][tag] = {'displayName': summary_metadata.display_name, - 'description': plugin_util.markdown_to_safe_html( - summary_metadata.summary_description)} - - return result - - @wrappers.Request.application - def available_time_entries_route(self, request): - """Gets a dict mapping run to a list of time entries. - Returns: - A dict with string keys (all runs with PR curve data). The values of the - dict are lists of time entries (consisting of the fields below) to be - used in populating values within time sliders. - """ - return http_util.Respond( - request, self.available_time_entries_impl(), 'application/json') - - def available_time_entries_impl(self): - """Creates the JSON object for the available time entries route response. - - Returns: - The JSON object for the available time entries route response. - """ - result = {} - if self._db_connection_provider: - db = self._db_connection_provider() - # For each run, pick a tag. - cursor = db.execute( - ''' + """, + (metadata.PLUGIN_NAME,), + ) + result = {} + for (tag_name, display_name, run_name) in cursor: + if run_name not in result: + result[run_name] = {} + result[run_name][tag_name] = { + "displayName": display_name, + # TODO(chihuahua): Populate the description. 
Currently, the tags + # table does not link with the description table. + "description": "", + } + else: + # Read tags from events files. + runs = self._multiplexer.Runs() + result = {run: {} for run in runs} + + mapping = self._multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME + ) + for (run, tag_to_content) in six.iteritems(mapping): + for (tag, _) in six.iteritems(tag_to_content): + summary_metadata = self._multiplexer.SummaryMetadata( + run, tag + ) + result[run][tag] = { + "displayName": summary_metadata.display_name, + "description": plugin_util.markdown_to_safe_html( + summary_metadata.summary_description + ), + } + + return result + + @wrappers.Request.application + def available_time_entries_route(self, request): + """Gets a dict mapping run to a list of time entries. + + Returns: + A dict with string keys (all runs with PR curve data). The values of the + dict are lists of time entries (consisting of the fields below) to be + used in populating values within time sliders. + """ + return http_util.Respond( + request, self.available_time_entries_impl(), "application/json" + ) + + def available_time_entries_impl(self): + """Creates the JSON object for the available time entries route + response. + + Returns: + The JSON object for the available time entries route response. + """ + result = {} + if self._db_connection_provider: + db = self._db_connection_provider() + # For each run, pick a tag. + cursor = db.execute( + """ SELECT TagPickingTable.run_name, Tensors.step, @@ -259,147 +301,165 @@ def available_time_entries_impl(self): ON Tensors.series = TagPickingTable.tag_id WHERE Tensors.step IS NOT NULL ORDER BY Tensors.step - ''', (metadata.PLUGIN_NAME,)) - for (run, step, wall_time) in cursor: - if run not in result: - result[run] = [] - result[run].append(self._create_time_entry(step, wall_time)) - else: - # Read data from disk. 
- all_runs = self._multiplexer.PluginRunToTagToContent( - metadata.PLUGIN_NAME) - for run, tag_to_content in all_runs.items(): - if not tag_to_content: - # This run lacks data for this plugin. - continue - # Just use the list of tensor events for any of the tags to determine - # the steps to list for the run. The steps are often the same across - # tags for each run, albeit the user may elect to sample certain tags - # differently within the same run. If the latter occurs, TensorBoard - # will show the actual step of each tag atop the card for the tag. - tensor_events = self._multiplexer.Tensors( - run, min(six.iterkeys(tag_to_content))) - result[run] = [self._create_time_entry(e.step, e.wall_time) - for e in tensor_events] - return result - - def _create_time_entry(self, step, wall_time): - """Creates a time entry given a tensor event. - - Arguments: - step: The step for the time entry. - wall_time: The wall time for the time entry. - - Returns: - A JSON-able time entry to be passed to the frontend in order to construct - the slider. - """ - return { - 'step': step, - 'wall_time': wall_time, - } - - def get_plugin_apps(self): - """Gets all routes offered by the plugin. - - Returns: - A dictionary mapping URL path to route that handles it. - """ - return { - '/tags': self.tags_route, - '/pr_curves': self.pr_curves_route, - '/available_time_entries': self.available_time_entries_route, - } - - def is_active(self): - """Determines whether this plugin is active. - - This plugin is active only if PR curve summary data is read by TensorBoard. - - Returns: - Whether this plugin is active. - """ - if self._db_connection_provider: - # The plugin is active if one relevant tag can be found in the database. 
- db = self._db_connection_provider() - cursor = db.execute( - ''' + """, + (metadata.PLUGIN_NAME,), + ) + for (run, step, wall_time) in cursor: + if run not in result: + result[run] = [] + result[run].append(self._create_time_entry(step, wall_time)) + else: + # Read data from disk. + all_runs = self._multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME + ) + for run, tag_to_content in all_runs.items(): + if not tag_to_content: + # This run lacks data for this plugin. + continue + # Just use the list of tensor events for any of the tags to determine + # the steps to list for the run. The steps are often the same across + # tags for each run, albeit the user may elect to sample certain tags + # differently within the same run. If the latter occurs, TensorBoard + # will show the actual step of each tag atop the card for the tag. + tensor_events = self._multiplexer.Tensors( + run, min(six.iterkeys(tag_to_content)) + ) + result[run] = [ + self._create_time_entry(e.step, e.wall_time) + for e in tensor_events + ] + return result + + def _create_time_entry(self, step, wall_time): + """Creates a time entry given a tensor event. + + Arguments: + step: The step for the time entry. + wall_time: The wall time for the time entry. + + Returns: + A JSON-able time entry to be passed to the frontend in order to construct + the slider. + """ + return { + "step": step, + "wall_time": wall_time, + } + + def get_plugin_apps(self): + """Gets all routes offered by the plugin. + + Returns: + A dictionary mapping URL path to route that handles it. + """ + return { + "/tags": self.tags_route, + "/pr_curves": self.pr_curves_route, + "/available_time_entries": self.available_time_entries_route, + } + + def is_active(self): + """Determines whether this plugin is active. + + This plugin is active only if PR curve summary data is read by TensorBoard. + + Returns: + Whether this plugin is active. 
+ """ + if self._db_connection_provider: + # The plugin is active if one relevant tag can be found in the database. + db = self._db_connection_provider() + cursor = db.execute( + """ SELECT 1 FROM Tags WHERE Tags.plugin_name = ? LIMIT 1 - ''', - (metadata.PLUGIN_NAME,)) - return bool(list(cursor)) - - if not self._multiplexer: - return False - - all_runs = self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) - - # The plugin is active if any of the runs has a tag relevant to the plugin. - return any(six.itervalues(all_runs)) - - def frontend_metadata(self): - return base_plugin.FrontendMetadata( - element_name='tf-pr-curve-dashboard', - tab_name='PR Curves', - ) - - def _process_tensor_event(self, event, thresholds): - """Converts a TensorEvent into a dict that encapsulates information on it. - - Args: - event: The TensorEvent to convert. - thresholds: An array of floats that ranges from 0 to 1 (in that - direction and inclusive of 0 and 1). - - Returns: - A JSON-able dictionary of PR curve data for 1 step. - """ - return self._make_pr_entry( - event.step, - event.wall_time, - tensor_util.make_ndarray(event.tensor_proto), - thresholds) - - def _make_pr_entry(self, step, wall_time, data_array, thresholds): - """Creates an entry for PR curve data. Each entry corresponds to 1 step. - - Args: - step: The step. - wall_time: The wall time. - data_array: A numpy array of PR curve data stored in the summary format. - thresholds: An array of floating point thresholds. - - Returns: - A PR curve entry. - """ - # Trim entries for which TP + FP = 0 (precision is undefined) at the tail of - # the data. 
- true_positives = [int(v) for v in data_array[metadata.TRUE_POSITIVES_INDEX]] - false_positives = [ - int(v) for v in data_array[metadata.FALSE_POSITIVES_INDEX]] - tp_index = metadata.TRUE_POSITIVES_INDEX - fp_index = metadata.FALSE_POSITIVES_INDEX - positives = data_array[[tp_index, fp_index], :].astype(int).sum(axis=0) - end_index_inclusive = len(positives) - 1 - while end_index_inclusive > 0 and positives[end_index_inclusive] == 0: - end_index_inclusive -= 1 - end_index = end_index_inclusive + 1 - - return { - 'wall_time': wall_time, - 'step': step, - 'precision': data_array[metadata.PRECISION_INDEX, :end_index].tolist(), - 'recall': data_array[metadata.RECALL_INDEX, :end_index].tolist(), - 'true_positives': true_positives[:end_index], - 'false_positives': false_positives[:end_index], - 'true_negatives': - [int(v) for v in - data_array[metadata.TRUE_NEGATIVES_INDEX][:end_index]], - 'false_negatives': - [int(v) for v in - data_array[metadata.FALSE_NEGATIVES_INDEX][:end_index]], - 'thresholds': thresholds[:end_index], - } + """, + (metadata.PLUGIN_NAME,), + ) + return bool(list(cursor)) + + if not self._multiplexer: + return False + + all_runs = self._multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME + ) + + # The plugin is active if any of the runs has a tag relevant to the plugin. + return any(six.itervalues(all_runs)) + + def frontend_metadata(self): + return base_plugin.FrontendMetadata( + element_name="tf-pr-curve-dashboard", tab_name="PR Curves", + ) + + def _process_tensor_event(self, event, thresholds): + """Converts a TensorEvent into a dict that encapsulates information on + it. + + Args: + event: The TensorEvent to convert. + thresholds: An array of floats that ranges from 0 to 1 (in that + direction and inclusive of 0 and 1). + + Returns: + A JSON-able dictionary of PR curve data for 1 step. 
+ """ + return self._make_pr_entry( + event.step, + event.wall_time, + tensor_util.make_ndarray(event.tensor_proto), + thresholds, + ) + + def _make_pr_entry(self, step, wall_time, data_array, thresholds): + """Creates an entry for PR curve data. Each entry corresponds to 1 + step. + + Args: + step: The step. + wall_time: The wall time. + data_array: A numpy array of PR curve data stored in the summary format. + thresholds: An array of floating point thresholds. + + Returns: + A PR curve entry. + """ + # Trim entries for which TP + FP = 0 (precision is undefined) at the tail of + # the data. + true_positives = [ + int(v) for v in data_array[metadata.TRUE_POSITIVES_INDEX] + ] + false_positives = [ + int(v) for v in data_array[metadata.FALSE_POSITIVES_INDEX] + ] + tp_index = metadata.TRUE_POSITIVES_INDEX + fp_index = metadata.FALSE_POSITIVES_INDEX + positives = data_array[[tp_index, fp_index], :].astype(int).sum(axis=0) + end_index_inclusive = len(positives) - 1 + while end_index_inclusive > 0 and positives[end_index_inclusive] == 0: + end_index_inclusive -= 1 + end_index = end_index_inclusive + 1 + + return { + "wall_time": wall_time, + "step": step, + "precision": data_array[ + metadata.PRECISION_INDEX, :end_index + ].tolist(), + "recall": data_array[metadata.RECALL_INDEX, :end_index].tolist(), + "true_positives": true_positives[:end_index], + "false_positives": false_positives[:end_index], + "true_negatives": [ + int(v) + for v in data_array[metadata.TRUE_NEGATIVES_INDEX][:end_index] + ], + "false_negatives": [ + int(v) + for v in data_array[metadata.FALSE_NEGATIVES_INDEX][:end_index] + ], + "thresholds": thresholds[:end_index], + } diff --git a/tensorboard/plugins/pr_curve/pr_curves_plugin_test.py b/tensorboard/plugins/pr_curve/pr_curves_plugin_test.py index afe1c9a893..077bc0c4cb 100644 --- a/tensorboard/plugins/pr_curve/pr_curves_plugin_test.py +++ b/tensorboard/plugins/pr_curve/pr_curves_plugin_test.py @@ -26,7 +26,9 @@ import six import tensorflow as tf -from 
tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.pr_curve import pr_curve_demo from tensorboard.plugins.pr_curve import pr_curves_plugin @@ -35,272 +37,314 @@ # are small. The default relative error (rtol) of 1e-7 yields many undesired # test failures. assert_allclose = functools.partial( - np.testing.assert_allclose, rtol=0, atol=1e-7) + np.testing.assert_allclose, rtol=0, atol=1e-7 +) class PrCurvesPluginTest(tf.test.TestCase): - - def setUp(self): - super(PrCurvesPluginTest, self).setUp() - logdir = os.path.join(self.get_temp_dir(), 'logdir') - - # Generate data. - pr_curve_demo.run_all( - logdir=logdir, - steps=3, - thresholds=5, - verbose=False) - - # Create a multiplexer for reading the data we just wrote. - multiplexer = event_multiplexer.EventMultiplexer() - multiplexer.AddRunsFromDirectory(logdir) - multiplexer.Reload() - - context = base_plugin.TBContext(logdir=logdir, multiplexer=multiplexer) - self.plugin = pr_curves_plugin.PrCurvesPlugin(context) - - def validatePrCurveEntry( - self, - expected_step, - expected_precision, - expected_recall, - expected_true_positives, - expected_false_positives, - expected_true_negatives, - expected_false_negatives, - expected_thresholds, - pr_curve_entry): - """Checks that the values stored within a tensor are correct. - - Args: - expected_step: The expected step. - expected_precision: A list of float values. - expected_recall: A list of float values. - expected_true_positives: A list of int values. - expected_false_positives: A list of int values. - expected_true_negatives: A list of int values. - expected_false_negatives: A list of int values. - expected_thresholds: A list of floats ranging from 0 to 1. - pr_curve_entry: The PR curve entry to evaluate. 
- """ - self.assertEqual(expected_step, pr_curve_entry['step']) - assert_allclose(expected_precision, pr_curve_entry['precision']) - assert_allclose(expected_recall, pr_curve_entry['recall']) - self.assertListEqual( - expected_true_positives, pr_curve_entry['true_positives']) - self.assertListEqual( - expected_false_positives, pr_curve_entry['false_positives']) - self.assertListEqual( - expected_true_negatives, pr_curve_entry['true_negatives']) - self.assertListEqual( - expected_false_negatives, pr_curve_entry['false_negatives']) - assert_allclose(expected_thresholds, pr_curve_entry['thresholds']) - - def computeCorrectDescription(self, standard_deviation): - """Generates a correct description. - - Arguments: - standard_deviation: An integer standard deviation value. - - Returns: - The correct description given a standard deviation value. - """ - description = ('

The probabilities used to create this PR curve are ' - 'generated from a normal distribution. Its standard ' - 'deviation is initially %d and decreases' - ' over time.

') % standard_deviation - return description - - def testRoutesProvided(self): - """Tests that the plugin offers the correct routes.""" - routes = self.plugin.get_plugin_apps() - self.assertIsInstance(routes['/tags'], collections.Callable) - self.assertIsInstance(routes['/pr_curves'], collections.Callable) - self.assertIsInstance( - routes['/available_time_entries'], collections.Callable) - - def testTagsProvided(self): - """Tests that tags are provided.""" - tags_response = self.plugin.tags_impl() - - # Assert that the runs are right. - self.assertItemsEqual( - ['colors', 'mask_every_other_prediction'], list(tags_response.keys())) - - # Assert that the tags for each run are correct. - self.assertItemsEqual( - ['red/pr_curves', 'green/pr_curves', 'blue/pr_curves'], - list(tags_response['colors'].keys())) - self.assertItemsEqual( - ['red/pr_curves', 'green/pr_curves', 'blue/pr_curves'], - list(tags_response['mask_every_other_prediction'].keys())) - - # Verify the data for each run-tag combination. 
- self.assertDictEqual({ - 'displayName': 'classifying red', - 'description': self.computeCorrectDescription(168), - }, tags_response['colors']['red/pr_curves']) - self.assertDictEqual({ - 'displayName': 'classifying green', - 'description': self.computeCorrectDescription(210), - }, tags_response['colors']['green/pr_curves']) - self.assertDictEqual({ - 'displayName': 'classifying blue', - 'description': self.computeCorrectDescription(252), - }, tags_response['colors']['blue/pr_curves']) - self.assertDictEqual({ - 'displayName': 'classifying red', - 'description': self.computeCorrectDescription(168), - }, tags_response['mask_every_other_prediction']['red/pr_curves']) - self.assertDictEqual({ - 'displayName': 'classifying green', - 'description': self.computeCorrectDescription(210), - }, tags_response['mask_every_other_prediction']['green/pr_curves']) - self.assertDictEqual({ - 'displayName': 'classifying blue', - 'description': self.computeCorrectDescription(252), - }, tags_response['mask_every_other_prediction']['blue/pr_curves']) - - def testAvailableSteps(self): - """Tests that runs are mapped to correct available steps.""" - # Test that all runs are within the keys of the mapping. - response = self.plugin.available_time_entries_impl() - self.assertItemsEqual( - ['colors', 'mask_every_other_prediction'], list(response.keys())) - - # TODO(chizeng): Find a means of testing the wall time and relative time. - # The wall time written to disk is computed within TensorFlow C++. 
- entries = response['colors'] - entry = entries[0] - self.assertEqual(0, entry['step']) - self.assertIn('wall_time', entry) - entry = entries[1] - self.assertEqual(1, entry['step']) - self.assertIn('wall_time', entry) - entry = entries[2] - self.assertEqual(2, entry['step']) - self.assertIn('wall_time', entry) - - entries = response['mask_every_other_prediction'] - entry = entries[0] - self.assertEqual(0, entry['step']) - self.assertIn('wall_time', entry) - entry = entries[1] - self.assertEqual(1, entry['step']) - self.assertIn('wall_time', entry) - entry = entries[2] - self.assertEqual(2, entry['step']) - self.assertIn('wall_time', entry) - - def testPrCurvesDataCorrect(self): - """Tests that responses for PR curves for run-tag combos are correct.""" - pr_curves_response = self.plugin.pr_curves_impl( - ['colors', 'mask_every_other_prediction'], 'blue/pr_curves') - - # Assert that the runs are correct. - self.assertItemsEqual( - ['colors', 'mask_every_other_prediction'], - list(pr_curves_response.keys())) - - # Assert that PR curve data is correct for the colors run. 
- entries = pr_curves_response['colors'] - self.assertEqual(3, len(entries)) - self.validatePrCurveEntry( - expected_step=0, - expected_precision=[0.3333333, 0.3853211, 0.5421687, 0.75], - expected_recall=[1.0, 0.84, 0.3, 0.04], - expected_true_positives=[150, 126, 45, 6], - expected_false_positives=[300, 201, 38, 2], - expected_true_negatives=[0, 99, 262, 298], - expected_false_negatives=[0, 24, 105, 144], - expected_thresholds=[0.2, 0.4, 0.6, 0.8], - pr_curve_entry=entries[0]) - self.validatePrCurveEntry( - expected_step=1, - expected_precision=[0.3333333, 0.3855422, 0.5357143, 0.4], - expected_recall=[1.0, 0.8533334, 0.3, 0.0266667], - expected_true_positives=[150, 128, 45, 4], - expected_false_positives=[300, 204, 39, 6], - expected_true_negatives=[0, 96, 261, 294], - expected_false_negatives=[0, 22, 105, 146], - expected_thresholds=[0.2, 0.4, 0.6, 0.8], - pr_curve_entry=entries[1]) - self.validatePrCurveEntry( - expected_step=2, - expected_precision=[0.3333333, 0.3934426, 0.5064935, 0.6666667], - expected_recall=[1.0, 0.8, 0.26, 0.0266667], - expected_true_positives=[150, 120, 39, 4], - expected_false_positives=[300, 185, 38, 2], - expected_true_negatives=[0, 115, 262, 298], - expected_false_negatives=[0, 30, 111, 146], - expected_thresholds=[0.2, 0.4, 0.6, 0.8], - pr_curve_entry=entries[2]) - - # Assert that PR curve data is correct for the mask_every_other_prediction - # run. 
- entries = pr_curves_response['mask_every_other_prediction'] - self.assertEqual(3, len(entries)) - self.validatePrCurveEntry( - expected_step=0, - expected_precision=[0.3333333, 0.3786982, 0.5384616, 1.0], - expected_recall=[1.0, 0.8533334, 0.28, 0.0666667], - expected_true_positives=[75, 64, 21, 5], - expected_false_positives=[150, 105, 18, 0], - expected_true_negatives=[0, 45, 132, 150], - expected_false_negatives=[0, 11, 54, 70], - expected_thresholds=[0.2, 0.4, 0.6, 0.8], - pr_curve_entry=entries[0]) - self.validatePrCurveEntry( - expected_step=1, - expected_precision=[0.3333333, 0.3850932, 0.5, 0.25], - expected_recall=[1.0, 0.8266667, 0.28, 0.0133333], - expected_true_positives=[75, 62, 21, 1], - expected_false_positives=[150, 99, 21, 3], - expected_true_negatives=[0, 51, 129, 147], - expected_false_negatives=[0, 13, 54, 74], - expected_thresholds=[0.2, 0.4, 0.6, 0.8], - pr_curve_entry=entries[1]) - self.validatePrCurveEntry( - expected_step=2, - expected_precision=[0.3333333, 0.3986928, 0.4444444, 0.6666667], - expected_recall=[1.0, 0.8133333, 0.2133333, 0.0266667], - expected_true_positives=[75, 61, 16, 2], - expected_false_positives=[150, 92, 20, 1], - expected_true_negatives=[0, 58, 130, 149], - expected_false_negatives=[0, 14, 59, 73], - expected_thresholds=[0.2, 0.4, 0.6, 0.8], - pr_curve_entry=entries[2]) - - def testPrCurvesRaisesValueErrorWhenNoData(self): - """Tests that the method for obtaining PR curve data raises a ValueError. - - The handler should raise a ValueError when no PR curve data can be found - for a certain run-tag combination. 
- """ - with six.assertRaisesRegex( - self, ValueError, r'No PR curves could be found'): - self.plugin.pr_curves_impl(['colors'], 'non_existent_tag') - - with six.assertRaisesRegex( - self, ValueError, r'No PR curves could be found'): - self.plugin.pr_curves_impl(['non_existent_run'], 'blue/pr_curves') - - def testPluginIsNotActive(self): - """Tests that the plugin is inactive when no relevant data exists.""" - empty_logdir = os.path.join(self.get_temp_dir(), 'empty_logdir') - multiplexer = event_multiplexer.EventMultiplexer() - multiplexer.AddRunsFromDirectory(empty_logdir) - multiplexer.Reload() - context = base_plugin.TBContext( - logdir=empty_logdir, multiplexer=multiplexer) - plugin = pr_curves_plugin.PrCurvesPlugin(context) - self.assertFalse(plugin.is_active()) - - def testPluginIsActive(self): - """Tests that the plugin is active when relevant data exists.""" - # The set up for this test generates relevant data. - self.assertTrue(self.plugin.is_active()) - - -if __name__ == '__main__': - tf.test.main() + def setUp(self): + super(PrCurvesPluginTest, self).setUp() + logdir = os.path.join(self.get_temp_dir(), "logdir") + + # Generate data. + pr_curve_demo.run_all( + logdir=logdir, steps=3, thresholds=5, verbose=False + ) + + # Create a multiplexer for reading the data we just wrote. + multiplexer = event_multiplexer.EventMultiplexer() + multiplexer.AddRunsFromDirectory(logdir) + multiplexer.Reload() + + context = base_plugin.TBContext(logdir=logdir, multiplexer=multiplexer) + self.plugin = pr_curves_plugin.PrCurvesPlugin(context) + + def validatePrCurveEntry( + self, + expected_step, + expected_precision, + expected_recall, + expected_true_positives, + expected_false_positives, + expected_true_negatives, + expected_false_negatives, + expected_thresholds, + pr_curve_entry, + ): + """Checks that the values stored within a tensor are correct. + + Args: + expected_step: The expected step. + expected_precision: A list of float values. 
+ expected_recall: A list of float values. + expected_true_positives: A list of int values. + expected_false_positives: A list of int values. + expected_true_negatives: A list of int values. + expected_false_negatives: A list of int values. + expected_thresholds: A list of floats ranging from 0 to 1. + pr_curve_entry: The PR curve entry to evaluate. + """ + self.assertEqual(expected_step, pr_curve_entry["step"]) + assert_allclose(expected_precision, pr_curve_entry["precision"]) + assert_allclose(expected_recall, pr_curve_entry["recall"]) + self.assertListEqual( + expected_true_positives, pr_curve_entry["true_positives"] + ) + self.assertListEqual( + expected_false_positives, pr_curve_entry["false_positives"] + ) + self.assertListEqual( + expected_true_negatives, pr_curve_entry["true_negatives"] + ) + self.assertListEqual( + expected_false_negatives, pr_curve_entry["false_negatives"] + ) + assert_allclose(expected_thresholds, pr_curve_entry["thresholds"]) + + def computeCorrectDescription(self, standard_deviation): + """Generates a correct description. + + Arguments: + standard_deviation: An integer standard deviation value. + + Returns: + The correct description given a standard deviation value. + """ + description = ( + "

The probabilities used to create this PR curve are " + "generated from a normal distribution. Its standard " + "deviation is initially %d and decreases" + " over time.

" + ) % standard_deviation + return description + + def testRoutesProvided(self): + """Tests that the plugin offers the correct routes.""" + routes = self.plugin.get_plugin_apps() + self.assertIsInstance(routes["/tags"], collections.Callable) + self.assertIsInstance(routes["/pr_curves"], collections.Callable) + self.assertIsInstance( + routes["/available_time_entries"], collections.Callable + ) + + def testTagsProvided(self): + """Tests that tags are provided.""" + tags_response = self.plugin.tags_impl() + + # Assert that the runs are right. + self.assertItemsEqual( + ["colors", "mask_every_other_prediction"], + list(tags_response.keys()), + ) + + # Assert that the tags for each run are correct. + self.assertItemsEqual( + ["red/pr_curves", "green/pr_curves", "blue/pr_curves"], + list(tags_response["colors"].keys()), + ) + self.assertItemsEqual( + ["red/pr_curves", "green/pr_curves", "blue/pr_curves"], + list(tags_response["mask_every_other_prediction"].keys()), + ) + + # Verify the data for each run-tag combination. 
+ self.assertDictEqual( + { + "displayName": "classifying red", + "description": self.computeCorrectDescription(168), + }, + tags_response["colors"]["red/pr_curves"], + ) + self.assertDictEqual( + { + "displayName": "classifying green", + "description": self.computeCorrectDescription(210), + }, + tags_response["colors"]["green/pr_curves"], + ) + self.assertDictEqual( + { + "displayName": "classifying blue", + "description": self.computeCorrectDescription(252), + }, + tags_response["colors"]["blue/pr_curves"], + ) + self.assertDictEqual( + { + "displayName": "classifying red", + "description": self.computeCorrectDescription(168), + }, + tags_response["mask_every_other_prediction"]["red/pr_curves"], + ) + self.assertDictEqual( + { + "displayName": "classifying green", + "description": self.computeCorrectDescription(210), + }, + tags_response["mask_every_other_prediction"]["green/pr_curves"], + ) + self.assertDictEqual( + { + "displayName": "classifying blue", + "description": self.computeCorrectDescription(252), + }, + tags_response["mask_every_other_prediction"]["blue/pr_curves"], + ) + + def testAvailableSteps(self): + """Tests that runs are mapped to correct available steps.""" + # Test that all runs are within the keys of the mapping. + response = self.plugin.available_time_entries_impl() + self.assertItemsEqual( + ["colors", "mask_every_other_prediction"], list(response.keys()) + ) + + # TODO(chizeng): Find a means of testing the wall time and relative time. + # The wall time written to disk is computed within TensorFlow C++. 
+ entries = response["colors"] + entry = entries[0] + self.assertEqual(0, entry["step"]) + self.assertIn("wall_time", entry) + entry = entries[1] + self.assertEqual(1, entry["step"]) + self.assertIn("wall_time", entry) + entry = entries[2] + self.assertEqual(2, entry["step"]) + self.assertIn("wall_time", entry) + + entries = response["mask_every_other_prediction"] + entry = entries[0] + self.assertEqual(0, entry["step"]) + self.assertIn("wall_time", entry) + entry = entries[1] + self.assertEqual(1, entry["step"]) + self.assertIn("wall_time", entry) + entry = entries[2] + self.assertEqual(2, entry["step"]) + self.assertIn("wall_time", entry) + + def testPrCurvesDataCorrect(self): + """Tests that responses for PR curves for run-tag combos are + correct.""" + pr_curves_response = self.plugin.pr_curves_impl( + ["colors", "mask_every_other_prediction"], "blue/pr_curves" + ) + + # Assert that the runs are correct. + self.assertItemsEqual( + ["colors", "mask_every_other_prediction"], + list(pr_curves_response.keys()), + ) + + # Assert that PR curve data is correct for the colors run. 
+ entries = pr_curves_response["colors"] + self.assertEqual(3, len(entries)) + self.validatePrCurveEntry( + expected_step=0, + expected_precision=[0.3333333, 0.3853211, 0.5421687, 0.75], + expected_recall=[1.0, 0.84, 0.3, 0.04], + expected_true_positives=[150, 126, 45, 6], + expected_false_positives=[300, 201, 38, 2], + expected_true_negatives=[0, 99, 262, 298], + expected_false_negatives=[0, 24, 105, 144], + expected_thresholds=[0.2, 0.4, 0.6, 0.8], + pr_curve_entry=entries[0], + ) + self.validatePrCurveEntry( + expected_step=1, + expected_precision=[0.3333333, 0.3855422, 0.5357143, 0.4], + expected_recall=[1.0, 0.8533334, 0.3, 0.0266667], + expected_true_positives=[150, 128, 45, 4], + expected_false_positives=[300, 204, 39, 6], + expected_true_negatives=[0, 96, 261, 294], + expected_false_negatives=[0, 22, 105, 146], + expected_thresholds=[0.2, 0.4, 0.6, 0.8], + pr_curve_entry=entries[1], + ) + self.validatePrCurveEntry( + expected_step=2, + expected_precision=[0.3333333, 0.3934426, 0.5064935, 0.6666667], + expected_recall=[1.0, 0.8, 0.26, 0.0266667], + expected_true_positives=[150, 120, 39, 4], + expected_false_positives=[300, 185, 38, 2], + expected_true_negatives=[0, 115, 262, 298], + expected_false_negatives=[0, 30, 111, 146], + expected_thresholds=[0.2, 0.4, 0.6, 0.8], + pr_curve_entry=entries[2], + ) + + # Assert that PR curve data is correct for the mask_every_other_prediction + # run. 
+ entries = pr_curves_response["mask_every_other_prediction"] + self.assertEqual(3, len(entries)) + self.validatePrCurveEntry( + expected_step=0, + expected_precision=[0.3333333, 0.3786982, 0.5384616, 1.0], + expected_recall=[1.0, 0.8533334, 0.28, 0.0666667], + expected_true_positives=[75, 64, 21, 5], + expected_false_positives=[150, 105, 18, 0], + expected_true_negatives=[0, 45, 132, 150], + expected_false_negatives=[0, 11, 54, 70], + expected_thresholds=[0.2, 0.4, 0.6, 0.8], + pr_curve_entry=entries[0], + ) + self.validatePrCurveEntry( + expected_step=1, + expected_precision=[0.3333333, 0.3850932, 0.5, 0.25], + expected_recall=[1.0, 0.8266667, 0.28, 0.0133333], + expected_true_positives=[75, 62, 21, 1], + expected_false_positives=[150, 99, 21, 3], + expected_true_negatives=[0, 51, 129, 147], + expected_false_negatives=[0, 13, 54, 74], + expected_thresholds=[0.2, 0.4, 0.6, 0.8], + pr_curve_entry=entries[1], + ) + self.validatePrCurveEntry( + expected_step=2, + expected_precision=[0.3333333, 0.3986928, 0.4444444, 0.6666667], + expected_recall=[1.0, 0.8133333, 0.2133333, 0.0266667], + expected_true_positives=[75, 61, 16, 2], + expected_false_positives=[150, 92, 20, 1], + expected_true_negatives=[0, 58, 130, 149], + expected_false_negatives=[0, 14, 59, 73], + expected_thresholds=[0.2, 0.4, 0.6, 0.8], + pr_curve_entry=entries[2], + ) + + def testPrCurvesRaisesValueErrorWhenNoData(self): + """Tests that the method for obtaining PR curve data raises a + ValueError. + + The handler should raise a ValueError when no PR curve data can + be found for a certain run-tag combination. 
+ """ + with six.assertRaisesRegex( + self, ValueError, r"No PR curves could be found" + ): + self.plugin.pr_curves_impl(["colors"], "non_existent_tag") + + with six.assertRaisesRegex( + self, ValueError, r"No PR curves could be found" + ): + self.plugin.pr_curves_impl(["non_existent_run"], "blue/pr_curves") + + def testPluginIsNotActive(self): + """Tests that the plugin is inactive when no relevant data exists.""" + empty_logdir = os.path.join(self.get_temp_dir(), "empty_logdir") + multiplexer = event_multiplexer.EventMultiplexer() + multiplexer.AddRunsFromDirectory(empty_logdir) + multiplexer.Reload() + context = base_plugin.TBContext( + logdir=empty_logdir, multiplexer=multiplexer + ) + plugin = pr_curves_plugin.PrCurvesPlugin(context) + self.assertFalse(plugin.is_active()) + + def testPluginIsActive(self): + """Tests that the plugin is active when relevant data exists.""" + # The set up for this test generates relevant data. + self.assertTrue(self.plugin.is_active()) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/pr_curve/summary.py b/tensorboard/plugins/pr_curve/summary.py index 76c113f26e..be929bb6e2 100644 --- a/tensorboard/plugins/pr_curve/summary.py +++ b/tensorboard/plugins/pr_curve/summary.py @@ -34,6 +34,7 @@ # The default number of thresholds. _DEFAULT_NUM_THRESHOLDS = 201 + def op( name, labels, @@ -42,307 +43,328 @@ def op( weights=None, display_name=None, description=None, - collections=None): - """Create a PR curve summary op for a single binary classifier. - - Computes true/false positive/negative values for the given `predictions` - against the ground truth `labels`, against a list of evenly distributed - threshold values in `[0, 1]` of length `num_thresholds`. - - Each number in `predictions`, a float in `[0, 1]`, is compared with its - corresponding boolean label in `labels`, and counts as a single tp/fp/tn/fn - value at each threshold. 
This is then multiplied with `weights` which can be - used to reweight certain values, or more commonly used for masking values. - - Args: - name: A tag attached to the summary. Used by TensorBoard for organization. - labels: The ground truth values. A Tensor of `bool` values with arbitrary - shape. - predictions: A float32 `Tensor` whose values are in the range `[0, 1]`. - Dimensions must match those of `labels`. - num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to - compute PR metrics for. Should be `>= 2`. This value should be a - constant integer value, not a Tensor that stores an integer. - weights: Optional float32 `Tensor`. Individual counts are multiplied by this - value. This tensor must be either the same shape as or broadcastable to - the `labels` tensor. - display_name: Optional name for this summary in TensorBoard, as a - constant `str`. Defaults to `name`. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. - collections: Optional list of graph collections keys. The new - summary op is added to these collections. Defaults to - `[Graph Keys.SUMMARIES]`. - - Returns: - A summary operation for use in a TensorFlow graph. The float32 tensor - produced by the summary operation is of dimension (6, num_thresholds). The - first dimension (of length 6) is of the order: true positives, - false positives, true negatives, false negatives, precision, recall. - - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - - if num_thresholds is None: - num_thresholds = _DEFAULT_NUM_THRESHOLDS - - if weights is None: - weights = 1.0 - - dtype = predictions.dtype - - with tf.name_scope(name, values=[labels, predictions, weights]): - tf.assert_type(labels, tf.bool) - # We cast to float to ensure we have 0.0 or 1.0. - f_labels = tf.cast(labels, dtype) - # Ensure predictions are all in range [0.0, 1.0]. 
- predictions = tf.minimum(1.0, tf.maximum(0.0, predictions)) - # Get weighted true/false labels. - true_labels = f_labels * weights - false_labels = (1.0 - f_labels) * weights - - # Before we begin, flatten predictions. - predictions = tf.reshape(predictions, [-1]) - - # Shape the labels so they are broadcast-able for later multiplication. - true_labels = tf.reshape(true_labels, [-1, 1]) - false_labels = tf.reshape(false_labels, [-1, 1]) - - # To compute TP/FP/TN/FN, we are measuring a binary classifier - # C(t) = (predictions >= t) - # at each threshold 't'. So we have - # TP(t) = sum( C(t) * true_labels ) - # FP(t) = sum( C(t) * false_labels ) - # - # But, computing C(t) requires computation for each t. To make it fast, - # observe that C(t) is a cumulative integral, and so if we have - # thresholds = [t_0, ..., t_{n-1}]; t_0 < ... < t_{n-1} - # where n = num_thresholds, and if we can compute the bucket function - # B(i) = Sum( (predictions == t), t_i <= t < t{i+1} ) - # then we get - # C(t_i) = sum( B(j), j >= i ) - # which is the reversed cumulative sum in tf.cumsum(). - # - # We can compute B(i) efficiently by taking advantage of the fact that - # our thresholds are evenly distributed, in that - # width = 1.0 / (num_thresholds - 1) - # thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0] - # Given a prediction value p, we can map it to its bucket by - # bucket_index(p) = floor( p * (num_thresholds - 1) ) - # so we can use tf.scatter_add() to update the buckets in one pass. - - # Compute the bucket indices for each prediction value. - bucket_indices = tf.cast( - tf.floor(predictions * (num_thresholds - 1)), tf.int32) - - # Bucket predictions. - tp_buckets = tf.reduce_sum( - input_tensor=tf.one_hot(bucket_indices, depth=num_thresholds) * true_labels, - axis=0) - fp_buckets = tf.reduce_sum( - input_tensor=tf.one_hot(bucket_indices, depth=num_thresholds) * false_labels, - axis=0) - - # Set up the cumulative sums to compute the actual metrics. 
- tp = tf.cumsum(tp_buckets, reverse=True, name='tp') - fp = tf.cumsum(fp_buckets, reverse=True, name='fp') - # fn = sum(true_labels) - tp - # = sum(tp_buckets) - tp - # = tp[0] - tp - # Similarly, - # tn = fp[0] - fp + collections=None, +): + """Create a PR curve summary op for a single binary classifier. + + Computes true/false positive/negative values for the given `predictions` + against the ground truth `labels`, against a list of evenly distributed + threshold values in `[0, 1]` of length `num_thresholds`. + + Each number in `predictions`, a float in `[0, 1]`, is compared with its + corresponding boolean label in `labels`, and counts as a single tp/fp/tn/fn + value at each threshold. This is then multiplied with `weights` which can be + used to reweight certain values, or more commonly used for masking values. + + Args: + name: A tag attached to the summary. Used by TensorBoard for organization. + labels: The ground truth values. A Tensor of `bool` values with arbitrary + shape. + predictions: A float32 `Tensor` whose values are in the range `[0, 1]`. + Dimensions must match those of `labels`. + num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to + compute PR metrics for. Should be `>= 2`. This value should be a + constant integer value, not a Tensor that stores an integer. + weights: Optional float32 `Tensor`. Individual counts are multiplied by this + value. This tensor must be either the same shape as or broadcastable to + the `labels` tensor. + display_name: Optional name for this summary in TensorBoard, as a + constant `str`. Defaults to `name`. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. + collections: Optional list of graph collections keys. The new + summary op is added to these collections. Defaults to + `[Graph Keys.SUMMARIES]`. + + Returns: + A summary operation for use in a TensorFlow graph. 
The float32 tensor + produced by the summary operation is of dimension (6, num_thresholds). The + first dimension (of length 6) is of the order: true positives, + false positives, true negatives, false negatives, precision, recall. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. + import tensorflow.compat.v1 as tf + + if num_thresholds is None: + num_thresholds = _DEFAULT_NUM_THRESHOLDS + + if weights is None: + weights = 1.0 + + dtype = predictions.dtype + + with tf.name_scope(name, values=[labels, predictions, weights]): + tf.assert_type(labels, tf.bool) + # We cast to float to ensure we have 0.0 or 1.0. + f_labels = tf.cast(labels, dtype) + # Ensure predictions are all in range [0.0, 1.0]. + predictions = tf.minimum(1.0, tf.maximum(0.0, predictions)) + # Get weighted true/false labels. + true_labels = f_labels * weights + false_labels = (1.0 - f_labels) * weights + + # Before we begin, flatten predictions. + predictions = tf.reshape(predictions, [-1]) + + # Shape the labels so they are broadcast-able for later multiplication. + true_labels = tf.reshape(true_labels, [-1, 1]) + false_labels = tf.reshape(false_labels, [-1, 1]) + + # To compute TP/FP/TN/FN, we are measuring a binary classifier + # C(t) = (predictions >= t) + # at each threshold 't'. So we have + # TP(t) = sum( C(t) * true_labels ) + # FP(t) = sum( C(t) * false_labels ) + # + # But, computing C(t) requires computation for each t. To make it fast, + # observe that C(t) is a cumulative integral, and so if we have + # thresholds = [t_0, ..., t_{n-1}]; t_0 < ... < t_{n-1} + # where n = num_thresholds, and if we can compute the bucket function + # B(i) = Sum( (predictions == t), t_i <= t < t{i+1} ) + # then we get + # C(t_i) = sum( B(j), j >= i ) + # which is the reversed cumulative sum in tf.cumsum(). 
+ # + # We can compute B(i) efficiently by taking advantage of the fact that + # our thresholds are evenly distributed, in that + # width = 1.0 / (num_thresholds - 1) + # thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0] + # Given a prediction value p, we can map it to its bucket by + # bucket_index(p) = floor( p * (num_thresholds - 1) ) + # so we can use tf.scatter_add() to update the buckets in one pass. + + # Compute the bucket indices for each prediction value. + bucket_indices = tf.cast( + tf.floor(predictions * (num_thresholds - 1)), tf.int32 + ) + + # Bucket predictions. + tp_buckets = tf.reduce_sum( + input_tensor=tf.one_hot(bucket_indices, depth=num_thresholds) + * true_labels, + axis=0, + ) + fp_buckets = tf.reduce_sum( + input_tensor=tf.one_hot(bucket_indices, depth=num_thresholds) + * false_labels, + axis=0, + ) + + # Set up the cumulative sums to compute the actual metrics. + tp = tf.cumsum(tp_buckets, reverse=True, name="tp") + fp = tf.cumsum(fp_buckets, reverse=True, name="fp") + # fn = sum(true_labels) - tp + # = sum(tp_buckets) - tp + # = tp[0] - tp + # Similarly, + # tn = fp[0] - fp + tn = fp[0] - fp + fn = tp[0] - tp + + precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp) + recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn) + + return _create_tensor_summary( + name, + tp, + fp, + tn, + fn, + precision, + recall, + num_thresholds, + display_name, + description, + collections, + ) + + +def pb( + name, + labels, + predictions, + num_thresholds=None, + weights=None, + display_name=None, + description=None, +): + """Create a PR curves summary protobuf. + + Arguments: + name: A name for the generated node. Will also serve as a series name in + TensorBoard. + labels: The ground truth values. A bool numpy array. + predictions: A float32 numpy array whose values are in the range `[0, 1]`. + Dimensions must match those of `labels`. + num_thresholds: Optional number of thresholds, evenly distributed in + `[0, 1]`, to compute PR metrics for. 
When provided, should be an int of + value at least 2. Defaults to 201. + weights: Optional float or float32 numpy array. Individual counts are + multiplied by this value. This tensor must be either the same shape as + or broadcastable to the `labels` numpy array. + display_name: Optional name for this summary in TensorBoard, as a `str`. + Defaults to `name`. + description: Optional long-form description for this summary, as a `str`. + Markdown is supported. Defaults to empty. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. + import tensorflow.compat.v1 as tf + + if num_thresholds is None: + num_thresholds = _DEFAULT_NUM_THRESHOLDS + + if weights is None: + weights = 1.0 + + # Compute bins of true positives and false positives. + bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1))) + float_labels = labels.astype(np.float) + histogram_range = (0, num_thresholds - 1) + tp_buckets, _ = np.histogram( + bucket_indices, + bins=num_thresholds, + range=histogram_range, + weights=float_labels * weights, + ) + fp_buckets, _ = np.histogram( + bucket_indices, + bins=num_thresholds, + range=histogram_range, + weights=(1.0 - float_labels) * weights, + ) + + # Obtain the reverse cumulative sum. + tp = np.cumsum(tp_buckets[::-1])[::-1] + fp = np.cumsum(fp_buckets[::-1])[::-1] tn = fp[0] - fp fn = tp[0] - tp + precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp) + recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn) - precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp) - recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn) - - return _create_tensor_summary( + return raw_data_pb( name, - tp, - fp, - tn, - fn, - precision, - recall, - num_thresholds, - display_name, - description, - collections) - -def pb(name, - labels, - predictions, - num_thresholds=None, - weights=None, - display_name=None, - description=None): - """Create a PR curves summary protobuf. - - Arguments: - name: A name for the generated node. 
Will also serve as a series name in - TensorBoard. - labels: The ground truth values. A bool numpy array. - predictions: A float32 numpy array whose values are in the range `[0, 1]`. - Dimensions must match those of `labels`. - num_thresholds: Optional number of thresholds, evenly distributed in - `[0, 1]`, to compute PR metrics for. When provided, should be an int of - value at least 2. Defaults to 201. - weights: Optional float or float32 numpy array. Individual counts are - multiplied by this value. This tensor must be either the same shape as - or broadcastable to the `labels` numpy array. - display_name: Optional name for this summary in TensorBoard, as a `str`. - Defaults to `name`. - description: Optional long-form description for this summary, as a `str`. - Markdown is supported. Defaults to empty. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - - if num_thresholds is None: - num_thresholds = _DEFAULT_NUM_THRESHOLDS - - if weights is None: - weights = 1.0 - - # Compute bins of true positives and false positives. - bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1))) - float_labels = labels.astype(np.float) - histogram_range = (0, num_thresholds - 1) - tp_buckets, _ = np.histogram( - bucket_indices, - bins=num_thresholds, - range=histogram_range, - weights=float_labels * weights) - fp_buckets, _ = np.histogram( - bucket_indices, - bins=num_thresholds, - range=histogram_range, - weights=(1.0 - float_labels) * weights) - - # Obtain the reverse cumulative sum. 
- tp = np.cumsum(tp_buckets[::-1])[::-1] - fp = np.cumsum(fp_buckets[::-1])[::-1] - tn = fp[0] - fp - fn = tp[0] - tp - precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp) - recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn) - - return raw_data_pb(name, - true_positive_counts=tp, - false_positive_counts=fp, - true_negative_counts=tn, - false_negative_counts=fn, - precision=precision, - recall=recall, - num_thresholds=num_thresholds, - display_name=display_name, - description=description) - -def streaming_op(name, - labels, - predictions, - num_thresholds=None, - weights=None, - metrics_collections=None, - updates_collections=None, - display_name=None, - description=None): - """Computes a precision-recall curve summary across batches of data. - - This function is similar to op() above, but can be used to compute the PR - curve across multiple batches of labels and predictions, in the same style - as the metrics found in tf.metrics. - - This function creates multiple local variables for storing true positives, - true negative, etc. accumulated over each batch of data, and uses these local - variables for computing the final PR curve summary. These variables can be - updated with the returned update_op. - - Args: - name: A tag attached to the summary. Used by TensorBoard for organization. - labels: The ground truth values, a `Tensor` whose dimensions must match - `predictions`. Will be cast to `bool`. - predictions: A floating point `Tensor` of arbitrary shape and whose values - are in the range `[0, 1]`. - num_thresholds: The number of evenly spaced thresholds to generate for - computing the PR curve. Defaults to 201. - weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions must - be either `1`, or the same as the corresponding `labels` dimension). - metrics_collections: An optional list of collections that `auc` should be - added to. 
- updates_collections: An optional list of collections that `update_op` should - be added to. - display_name: Optional name for this summary in TensorBoard, as a - constant `str`. Defaults to `name`. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. - - Returns: - pr_curve: A string `Tensor` containing a single value: the - serialized PR curve Tensor summary. The summary contains a - float32 `Tensor` of dimension (6, num_thresholds). The first - dimension (of length 6) is of the order: true positives, false - positives, true negatives, false negatives, precision, recall. - update_op: An operation that updates the summary with the latest data. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - - if num_thresholds is None: - num_thresholds = _DEFAULT_NUM_THRESHOLDS - - thresholds = [i / float(num_thresholds - 1) - for i in range(num_thresholds)] - - with tf.name_scope(name, values=[labels, predictions, weights]): - tp, update_tp = tf.metrics.true_positives_at_thresholds( - labels=labels, - predictions=predictions, - thresholds=thresholds, - weights=weights) - fp, update_fp = tf.metrics.false_positives_at_thresholds( - labels=labels, - predictions=predictions, - thresholds=thresholds, - weights=weights) - tn, update_tn = tf.metrics.true_negatives_at_thresholds( - labels=labels, - predictions=predictions, - thresholds=thresholds, - weights=weights) - fn, update_fn = tf.metrics.false_negatives_at_thresholds( - labels=labels, - predictions=predictions, - thresholds=thresholds, - weights=weights) - - def compute_summary(tp, fp, tn, fn, collections): - precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp) - recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn) - - return _create_tensor_summary( - name, - tp, - fp, - tn, - fn, - precision, - recall, - num_thresholds, - display_name, - description, - collections) - - pr_curve = 
compute_summary(tp, fp, tn, fn, metrics_collections) - update_op = tf.group(update_tp, update_fp, update_tn, update_fn) - if updates_collections: - for collection in updates_collections: - tf.add_to_collection(collection, update_op) - - return pr_curve, update_op + true_positive_counts=tp, + false_positive_counts=fp, + true_negative_counts=tn, + false_negative_counts=fn, + precision=precision, + recall=recall, + num_thresholds=num_thresholds, + display_name=display_name, + description=description, + ) + + +def streaming_op( + name, + labels, + predictions, + num_thresholds=None, + weights=None, + metrics_collections=None, + updates_collections=None, + display_name=None, + description=None, +): + """Computes a precision-recall curve summary across batches of data. + + This function is similar to op() above, but can be used to compute the PR + curve across multiple batches of labels and predictions, in the same style + as the metrics found in tf.metrics. + + This function creates multiple local variables for storing true positives, + true negative, etc. accumulated over each batch of data, and uses these local + variables for computing the final PR curve summary. These variables can be + updated with the returned update_op. + + Args: + name: A tag attached to the summary. Used by TensorBoard for organization. + labels: The ground truth values, a `Tensor` whose dimensions must match + `predictions`. Will be cast to `bool`. + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + num_thresholds: The number of evenly spaced thresholds to generate for + computing the PR curve. Defaults to 201. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that `auc` should be + added to. 
+ updates_collections: An optional list of collections that `update_op` should + be added to. + display_name: Optional name for this summary in TensorBoard, as a + constant `str`. Defaults to `name`. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. + + Returns: + pr_curve: A string `Tensor` containing a single value: the + serialized PR curve Tensor summary. The summary contains a + float32 `Tensor` of dimension (6, num_thresholds). The first + dimension (of length 6) is of the order: true positives, false + positives, true negatives, false negatives, precision, recall. + update_op: An operation that updates the summary with the latest data. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. + import tensorflow.compat.v1 as tf + + if num_thresholds is None: + num_thresholds = _DEFAULT_NUM_THRESHOLDS + + thresholds = [i / float(num_thresholds - 1) for i in range(num_thresholds)] + + with tf.name_scope(name, values=[labels, predictions, weights]): + tp, update_tp = tf.metrics.true_positives_at_thresholds( + labels=labels, + predictions=predictions, + thresholds=thresholds, + weights=weights, + ) + fp, update_fp = tf.metrics.false_positives_at_thresholds( + labels=labels, + predictions=predictions, + thresholds=thresholds, + weights=weights, + ) + tn, update_tn = tf.metrics.true_negatives_at_thresholds( + labels=labels, + predictions=predictions, + thresholds=thresholds, + weights=weights, + ) + fn, update_fn = tf.metrics.false_negatives_at_thresholds( + labels=labels, + predictions=predictions, + thresholds=thresholds, + weights=weights, + ) + + def compute_summary(tp, fp, tn, fn, collections): + precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp) + recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn) + + return _create_tensor_summary( + name, + tp, + fp, + tn, + fn, + precision, + recall, + num_thresholds, + display_name, + description, + collections, + ) + + 
pr_curve = compute_summary(tp, fp, tn, fn, metrics_collections) + update_op = tf.group(update_tp, update_fp, update_tn, update_fn) + if updates_collections: + for collection in updates_collections: + tf.add_to_collection(collection, update_op) + + return pr_curve, update_op + def raw_data_op( name, @@ -355,75 +377,81 @@ def raw_data_op( num_thresholds=None, display_name=None, description=None, - collections=None): - """Create an op that collects data for visualizing PR curves. - - Unlike the op above, this one avoids computing precision, recall, and the - intermediate counts. Instead, it accepts those tensors as arguments and - relies on the caller to ensure that the calculations are correct (and the - counts yield the provided precision and recall values). - - This op is useful when a caller seeks to compute precision and recall - differently but still use the PR curves plugin. - - Args: - name: A tag attached to the summary. Used by TensorBoard for organization. - true_positive_counts: A rank-1 tensor of true positive counts. Must contain - `num_thresholds` elements and be castable to float32. Values correspond - to thresholds that increase from left to right (from 0 to 1). - false_positive_counts: A rank-1 tensor of false positive counts. Must - contain `num_thresholds` elements and be castable to float32. Values - correspond to thresholds that increase from left to right (from 0 to 1). - true_negative_counts: A rank-1 tensor of true negative counts. Must contain - `num_thresholds` elements and be castable to float32. Values - correspond to thresholds that increase from left to right (from 0 to 1). - false_negative_counts: A rank-1 tensor of false negative counts. Must - contain `num_thresholds` elements and be castable to float32. Values - correspond to thresholds that increase from left to right (from 0 to 1). - precision: A rank-1 tensor of precision values. Must contain - `num_thresholds` elements and be castable to float32. 
Values correspond - to thresholds that increase from left to right (from 0 to 1). - recall: A rank-1 tensor of recall values. Must contain `num_thresholds` - elements and be castable to float32. Values correspond to thresholds - that increase from left to right (from 0 to 1). - num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to - compute PR metrics for. Should be `>= 2`. This value should be a - constant integer value, not a Tensor that stores an integer. - display_name: Optional name for this summary in TensorBoard, as a - constant `str`. Defaults to `name`. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. - collections: Optional list of graph collections keys. The new - summary op is added to these collections. Defaults to - `[Graph Keys.SUMMARIES]`. - - Returns: - A summary operation for use in a TensorFlow graph. See docs for the `op` - method for details on the float32 tensor produced by this summary. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - - with tf.name_scope(name, values=[ - true_positive_counts, - false_positive_counts, - true_negative_counts, - false_negative_counts, - precision, - recall, - ]): - return _create_tensor_summary( + collections=None, +): + """Create an op that collects data for visualizing PR curves. + + Unlike the op above, this one avoids computing precision, recall, and the + intermediate counts. Instead, it accepts those tensors as arguments and + relies on the caller to ensure that the calculations are correct (and the + counts yield the provided precision and recall values). + + This op is useful when a caller seeks to compute precision and recall + differently but still use the PR curves plugin. + + Args: + name: A tag attached to the summary. Used by TensorBoard for organization. + true_positive_counts: A rank-1 tensor of true positive counts. 
Must contain + `num_thresholds` elements and be castable to float32. Values correspond + to thresholds that increase from left to right (from 0 to 1). + false_positive_counts: A rank-1 tensor of false positive counts. Must + contain `num_thresholds` elements and be castable to float32. Values + correspond to thresholds that increase from left to right (from 0 to 1). + true_negative_counts: A rank-1 tensor of true negative counts. Must contain + `num_thresholds` elements and be castable to float32. Values + correspond to thresholds that increase from left to right (from 0 to 1). + false_negative_counts: A rank-1 tensor of false negative counts. Must + contain `num_thresholds` elements and be castable to float32. Values + correspond to thresholds that increase from left to right (from 0 to 1). + precision: A rank-1 tensor of precision values. Must contain + `num_thresholds` elements and be castable to float32. Values correspond + to thresholds that increase from left to right (from 0 to 1). + recall: A rank-1 tensor of recall values. Must contain `num_thresholds` + elements and be castable to float32. Values correspond to thresholds + that increase from left to right (from 0 to 1). + num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to + compute PR metrics for. Should be `>= 2`. This value should be a + constant integer value, not a Tensor that stores an integer. + display_name: Optional name for this summary in TensorBoard, as a + constant `str`. Defaults to `name`. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. + collections: Optional list of graph collections keys. The new + summary op is added to these collections. Defaults to + `[Graph Keys.SUMMARIES]`. + + Returns: + A summary operation for use in a TensorFlow graph. See docs for the `op` + method for details on the float32 tensor produced by this summary. 
+ """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. + import tensorflow.compat.v1 as tf + + with tf.name_scope( name, - true_positive_counts, - false_positive_counts, - true_negative_counts, - false_negative_counts, - precision, - recall, - num_thresholds, - display_name, - description, - collections) + values=[ + true_positive_counts, + false_positive_counts, + true_negative_counts, + false_negative_counts, + precision, + recall, + ], + ): + return _create_tensor_summary( + name, + true_positive_counts, + false_positive_counts, + true_negative_counts, + false_negative_counts, + precision, + recall, + num_thresholds, + display_name, + description, + collections, + ) + def raw_data_pb( name, @@ -435,58 +463,65 @@ def raw_data_pb( recall, num_thresholds=None, display_name=None, - description=None): - """Create a PR curves summary protobuf from raw data values. - - Args: - name: A tag attached to the summary. Used by TensorBoard for organization. - true_positive_counts: A rank-1 numpy array of true positive counts. Must - contain `num_thresholds` elements and be castable to float32. - false_positive_counts: A rank-1 numpy array of false positive counts. Must - contain `num_thresholds` elements and be castable to float32. - true_negative_counts: A rank-1 numpy array of true negative counts. Must - contain `num_thresholds` elements and be castable to float32. - false_negative_counts: A rank-1 numpy array of false negative counts. Must - contain `num_thresholds` elements and be castable to float32. - precision: A rank-1 numpy array of precision values. Must contain - `num_thresholds` elements and be castable to float32. - recall: A rank-1 numpy array of recall values. Must contain `num_thresholds` - elements and be castable to float32. - num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to - compute PR metrics for. Should be an int `>= 2`. - display_name: Optional name for this summary in TensorBoard, as a `str`. 
- Defaults to `name`. - description: Optional long-form description for this summary, as a `str`. - Markdown is supported. Defaults to empty. - - Returns: - A summary operation for use in a TensorFlow graph. See docs for the `op` - method for details on the float32 tensor produced by this summary. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - - if display_name is None: - display_name = name - summary_metadata = metadata.create_summary_metadata( - display_name=display_name if display_name is not None else name, - description=description or '', - num_thresholds=num_thresholds) - tf_summary_metadata = tf.SummaryMetadata.FromString( - summary_metadata.SerializeToString()) - summary = tf.Summary() - data = np.stack( - (true_positive_counts, - false_positive_counts, - true_negative_counts, - false_negative_counts, - precision, - recall)) - tensor = tf.make_tensor_proto(np.float32(data), dtype=tf.float32) - summary.value.add(tag='%s/pr_curves' % name, - metadata=tf_summary_metadata, - tensor=tensor) - return summary + description=None, +): + """Create a PR curves summary protobuf from raw data values. + + Args: + name: A tag attached to the summary. Used by TensorBoard for organization. + true_positive_counts: A rank-1 numpy array of true positive counts. Must + contain `num_thresholds` elements and be castable to float32. + false_positive_counts: A rank-1 numpy array of false positive counts. Must + contain `num_thresholds` elements and be castable to float32. + true_negative_counts: A rank-1 numpy array of true negative counts. Must + contain `num_thresholds` elements and be castable to float32. + false_negative_counts: A rank-1 numpy array of false negative counts. Must + contain `num_thresholds` elements and be castable to float32. + precision: A rank-1 numpy array of precision values. Must contain + `num_thresholds` elements and be castable to float32. 
+ recall: A rank-1 numpy array of recall values. Must contain `num_thresholds` + elements and be castable to float32. + num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to + compute PR metrics for. Should be an int `>= 2`. + display_name: Optional name for this summary in TensorBoard, as a `str`. + Defaults to `name`. + description: Optional long-form description for this summary, as a `str`. + Markdown is supported. Defaults to empty. + + Returns: + A summary operation for use in a TensorFlow graph. See docs for the `op` + method for details on the float32 tensor produced by this summary. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. + import tensorflow.compat.v1 as tf + + if display_name is None: + display_name = name + summary_metadata = metadata.create_summary_metadata( + display_name=display_name if display_name is not None else name, + description=description or "", + num_thresholds=num_thresholds, + ) + tf_summary_metadata = tf.SummaryMetadata.FromString( + summary_metadata.SerializeToString() + ) + summary = tf.Summary() + data = np.stack( + ( + true_positive_counts, + false_positive_counts, + true_negative_counts, + false_negative_counts, + precision, + recall, + ) + ) + tensor = tf.make_tensor_proto(np.float32(data), dtype=tf.float32) + summary.value.add( + tag="%s/pr_curves" % name, metadata=tf_summary_metadata, tensor=tensor + ) + return summary + def _create_tensor_summary( name, @@ -499,40 +534,46 @@ def _create_tensor_summary( num_thresholds=None, display_name=None, description=None, - collections=None): - """A private helper method for generating a tensor summary. - - We use a helper method instead of having `op` directly call `raw_data_op` - to prevent the scope of `raw_data_op` from being embedded within `op`. - - Arguments are the same as for raw_data_op. - - Returns: - A tensor summary that collects data for PR curves. 
- """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - - # Store the number of thresholds within the summary metadata because - # that value is constant for all pr curve summaries with the same tag. - summary_metadata = metadata.create_summary_metadata( - display_name=display_name if display_name is not None else name, - description=description or '', - num_thresholds=num_thresholds) - - # Store values within a tensor. We store them in the order: - # true positives, false positives, true negatives, false - # negatives, precision, and recall. - combined_data = tf.stack([ - tf.cast(true_positive_counts, tf.float32), - tf.cast(false_positive_counts, tf.float32), - tf.cast(true_negative_counts, tf.float32), - tf.cast(false_negative_counts, tf.float32), - tf.cast(precision, tf.float32), - tf.cast(recall, tf.float32)]) - - return tf.summary.tensor_summary( - name='pr_curves', - tensor=combined_data, - collections=collections, - summary_metadata=summary_metadata) + collections=None, +): + """A private helper method for generating a tensor summary. + + We use a helper method instead of having `op` directly call `raw_data_op` + to prevent the scope of `raw_data_op` from being embedded within `op`. + + Arguments are the same as for raw_data_op. + + Returns: + A tensor summary that collects data for PR curves. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. + import tensorflow.compat.v1 as tf + + # Store the number of thresholds within the summary metadata because + # that value is constant for all pr curve summaries with the same tag. + summary_metadata = metadata.create_summary_metadata( + display_name=display_name if display_name is not None else name, + description=description or "", + num_thresholds=num_thresholds, + ) + + # Store values within a tensor. We store them in the order: + # true positives, false positives, true negatives, false + # negatives, precision, and recall. 
+ combined_data = tf.stack( + [ + tf.cast(true_positive_counts, tf.float32), + tf.cast(false_positive_counts, tf.float32), + tf.cast(true_negative_counts, tf.float32), + tf.cast(false_negative_counts, tf.float32), + tf.cast(precision, tf.float32), + tf.cast(recall, tf.float32), + ] + ) + + return tf.summary.tensor_summary( + name="pr_curves", + tensor=combined_data, + collections=collections, + summary_metadata=summary_metadata, + ) diff --git a/tensorboard/plugins/pr_curve/summary_test.py b/tensorboard/plugins/pr_curve/summary_test.py index 8620ece2c2..ff9c311591 100644 --- a/tensorboard/plugins/pr_curve/summary_test.py +++ b/tensorboard/plugins/pr_curve/summary_test.py @@ -32,408 +32,451 @@ class PrCurveTest(tf.test.TestCase): - - def setUp(self): - super(PrCurveTest, self).setUp() - tf.compat.v1.reset_default_graph() - np.random.seed(42) - - def pb_via_op(self, summary_op, feed_dict=None): - with tf.compat.v1.Session() as sess: - actual_pbtxt = sess.run(summary_op, feed_dict=feed_dict or {}) - actual_proto = summary_pb2.Summary() - actual_proto.ParseFromString(actual_pbtxt) - return actual_proto - - def normalize_summary_pb(self, pb): - """Pass `pb`'s `TensorProto` through a marshalling roundtrip. - `TensorProto`s can be equal in value even if they are not identical - in representation, because data can be stored in either the - `tensor_content` field or the `${dtype}_value` field. This - normalization ensures a canonical form, and should be used before - comparing two `Summary`s for equality. - """ - result = summary_pb2.Summary() - if not isinstance(pb, summary_pb2.Summary): - # pb can come from `pb_via_op` which creates a TB Summary. 
- pb = test_util.ensure_tb_summary_proto(pb) - result.MergeFrom(pb) - for value in result.value: - if value.HasField('tensor'): - new_tensor = tensor_util.make_tensor_proto( - tensor_util.make_ndarray(value.tensor)) - value.ClearField('tensor') - value.tensor.MergeFrom(new_tensor) - return result - - def compute_and_check_summary_pb(self, - name, - labels, - predictions, - num_thresholds, - weights=None, - display_name=None, - description=None, - feed_dict=None): - """Use both `op` and `pb` to get a summary, asserting equality. - Returns: - a `Summary` protocol buffer - """ - labels_tensor = tf.constant(labels) - predictions_tensor = tf.constant(predictions) - weights_tensor = None if weights is None else tf.constant(weights) - op = summary.op( - name=name, - labels=labels_tensor, - predictions=predictions_tensor, - num_thresholds=num_thresholds, - weights=weights_tensor, - display_name=display_name, - description=description) - pb = self.normalize_summary_pb(summary.pb( - name=name, - labels=labels, - predictions=predictions, - num_thresholds=num_thresholds, - weights=weights, - display_name=display_name, - description=description)) - pb_via_op = self.normalize_summary_pb( - self.pb_via_op(op, feed_dict=feed_dict)) - self.assertProtoEquals(pb, pb_via_op) - return pb - - def verify_float_arrays_are_equal(self, expected, actual): - # We use an absolute error instead of a relative one because the expected - # values are small. The default relative error (trol) of 1e-7 yields many - # undesired test failures. 
- np.testing.assert_allclose( - expected, actual, rtol=0, atol=1e-7) - - def test_metadata(self): - pb = self.compute_and_check_summary_pb( - name='foo', - labels=np.array([True]), - predictions=np.float32([0.42]), - num_thresholds=3) - summary_metadata = pb.value[0].metadata - plugin_data = summary_metadata.plugin_data - self.assertEqual('foo', summary_metadata.display_name) - self.assertEqual('', summary_metadata.summary_description) - self.assertEqual(metadata.PLUGIN_NAME, plugin_data.plugin_name) - plugin_data = metadata.parse_plugin_metadata( - summary_metadata.plugin_data.content) - self.assertEqual(3, plugin_data.num_thresholds) - - def test_all_true_positives(self): - pb = self.compute_and_check_summary_pb( - name='foo', - labels=np.array([True]), - predictions=np.float32([1]), - num_thresholds=3) - expected = [ - [1.0, 1.0, 1.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - ] - values = tensor_util.make_ndarray(pb.value[0].tensor) - self.verify_float_arrays_are_equal(expected, values) - - def test_all_true_negatives(self): - pb = self.compute_and_check_summary_pb( - name='foo', - labels=np.array([False]), - predictions=np.float32([0]), - num_thresholds=3) - expected = [ - [0.0, 0.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 1.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ] - values = tensor_util.make_ndarray(pb.value[0].tensor) - self.verify_float_arrays_are_equal(expected, values) - - def test_all_false_positives(self): - pb = self.compute_and_check_summary_pb( - name='foo', - labels=np.array([False]), - predictions=np.float32([1]), - num_thresholds=3) - expected = [ - [0.0, 0.0, 0.0], - [1.0, 1.0, 1.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ] - values = tensor_util.make_ndarray(pb.value[0].tensor) - self.verify_float_arrays_are_equal(expected, values) - - def test_all_false_negatives(self): - pb = self.compute_and_check_summary_pb( - name='foo', - 
labels=np.array([True]), - predictions=np.float32([0]), - num_thresholds=3) - expected = [ - [1.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 1.0, 1.0], - [1.0, 0.0, 0.0], - [1.0, 0.0, 0.0], - ] - values = tensor_util.make_ndarray(pb.value[0].tensor) - self.verify_float_arrays_are_equal(expected, values) - - def test_many_values(self): - pb = self.compute_and_check_summary_pb( - name='foo', - labels=np.array([True, False, False, True, True, True]), - predictions=np.float32([0.2, 0.3, 0.4, 0.6, 0.7, 0.8]), - num_thresholds=3) - expected = [ - [4.0, 3.0, 0.0], - [2.0, 0.0, 0.0], - [0.0, 2.0, 2.0], - [0.0, 1.0, 4.0], - [2.0 / 3.0, 1.0, 0.0], - [1.0, 0.75, 0.0], - ] - values = tensor_util.make_ndarray(pb.value[0].tensor) - self.verify_float_arrays_are_equal(expected, values) - - def test_many_values_with_weights(self): - pb = self.compute_and_check_summary_pb( - name='foo', - labels=np.array([True, False, False, True, True, True]), - predictions=np.float32([0.2, 0.3, 0.4, 0.6, 0.7, 0.8]), - num_thresholds=3, - weights=np.float32([0.0, 0.5, 2.0, 0.0, 0.5, 1.0])) - expected = [ - [1.5, 1.5, 0.0], - [2.5, 0.0, 0.0], - [0.0, 2.5, 2.5], - [0.0, 0.0, 1.5], - [0.375, 1.0, 0.0], - [1.0, 1.0, 0.0] - ] - values = tensor_util.make_ndarray(pb.value[0].tensor) - self.verify_float_arrays_are_equal(expected, values) - - def test_exhaustive_random_values(self): - # Most other tests use small and crafted predictions and labels. - # This test exhaustively generates many data points. 
- data_points = 420 - pb = self.compute_and_check_summary_pb( - name='foo', - labels=np.random.uniform(size=(data_points,)) > 0.5, - predictions=np.float32(np.random.uniform(size=(data_points,))), - num_thresholds=5) - expected = [ - [218.0, 162.0, 111.0, 55.0, 0.0], - [202.0, 148.0, 98.0, 51.0, 0.0], - [0.0, 54.0, 104.0, 151.0, 202.0], - [0.0, 56.0, 107.0, 163.0, 218.0], - [0.5190476, 0.5225806, 0.5311005, 0.5188679, 0.0], - [1.0, 0.7431192, 0.5091743, 0.2522936, 0.0] - ] - values = tensor_util.make_ndarray(pb.value[0].tensor) - self.verify_float_arrays_are_equal(expected, values) - - def test_counts_below_1(self): - """Tests support for counts below 1. - - Certain weights cause TP, FP, TN, FN counts to be below 1. - """ - pb = self.compute_and_check_summary_pb( - name='foo', - labels=np.array([True, False, False, True, True, True]), - predictions=np.float32([0.2, 0.3, 0.4, 0.6, 0.7, 0.8]), - num_thresholds=3, - weights=np.float32([0.0, 0.1, 0.2, 0.1, 0.1, 0.0])) - expected = [ - [0.2, 0.2, 0.0], - [0.3, 0.0, 0.0], - [0.0, 0.3, 0.3], - [0.0, 0.0, 0.2], - [0.4, 1.0, 0.0], - [1.0, 1.0, 0.0] - ] - values = tensor_util.make_ndarray(pb.value[0].tensor) - self.verify_float_arrays_are_equal(expected, values) - - def test_raw_data(self): - # We pass these raw counts and precision/recall values. - name = 'foo' - true_positive_counts = [75, 64, 21, 5, 0] - false_positive_counts = [150, 105, 18, 0, 0] - true_negative_counts = [0, 45, 132, 150, 150] - false_negative_counts = [0, 11, 54, 70, 75] - precision = [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0] - recall = [1.0, 0.8533334, 0.28, 0.0666667, 0.0] - num_thresholds = 5 - display_name = 'some_raw_values' - description = 'We passed raw values into a summary op.' 
- - op = summary.raw_data_op( - name=name, - true_positive_counts=tf.constant(true_positive_counts), - false_positive_counts=tf.constant(false_positive_counts), - true_negative_counts=tf.constant(true_negative_counts), - false_negative_counts=tf.constant(false_negative_counts), - precision=tf.constant(precision), - recall=tf.constant(recall), - num_thresholds=num_thresholds, - display_name=display_name, - description=description) - pb_via_op = self.normalize_summary_pb(self.pb_via_op(op)) - - # Call the corresponding method that is decoupled from TensorFlow. - pb = self.normalize_summary_pb(summary.raw_data_pb( - name=name, - true_positive_counts=true_positive_counts, - false_positive_counts=false_positive_counts, - true_negative_counts=true_negative_counts, - false_negative_counts=false_negative_counts, - precision=precision, - recall=recall, - num_thresholds=num_thresholds, - display_name=display_name, - description=description)) - - # The 2 methods above should write summaries with the same data. - self.assertProtoEquals(pb, pb_via_op) - - # Test the metadata. - summary_metadata = pb.value[0].metadata - self.assertEqual('some_raw_values', summary_metadata.display_name) - self.assertEqual( - 'We passed raw values into a summary op.', - summary_metadata.summary_description) - self.assertEqual( - metadata.PLUGIN_NAME, summary_metadata.plugin_data.plugin_name) - - plugin_data = metadata.parse_plugin_metadata( - summary_metadata.plugin_data.content) - self.assertEqual(5, plugin_data.num_thresholds) - - # Test the summary contents. - values = tensor_util.make_ndarray(pb.value[0].tensor) - self.verify_float_arrays_are_equal([ - [75.0, 64.0, 21.0, 5.0, 0.0], # True positives. - [150.0, 105.0, 18.0, 0.0, 0.0], # False positives. - [0.0, 45.0, 132.0, 150.0, 150.0], # True negatives. - [0.0, 11.0, 54.0, 70.0, 75.0], # False negatives. - [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0], # Precision. - [1.0, 0.8533334, 0.28, 0.0666667, 0.0], # Recall. 
- ], values) + def setUp(self): + super(PrCurveTest, self).setUp() + tf.compat.v1.reset_default_graph() + np.random.seed(42) + + def pb_via_op(self, summary_op, feed_dict=None): + with tf.compat.v1.Session() as sess: + actual_pbtxt = sess.run(summary_op, feed_dict=feed_dict or {}) + actual_proto = summary_pb2.Summary() + actual_proto.ParseFromString(actual_pbtxt) + return actual_proto + + def normalize_summary_pb(self, pb): + """Pass `pb`'s `TensorProto` through a marshalling roundtrip. + + `TensorProto`s can be equal in value even if they are not + identical in representation, because data can be stored in + either the `tensor_content` field or the `${dtype}_value` field. + This normalization ensures a canonical form, and should be used + before comparing two `Summary`s for equality. + """ + result = summary_pb2.Summary() + if not isinstance(pb, summary_pb2.Summary): + # pb can come from `pb_via_op` which creates a TB Summary. + pb = test_util.ensure_tb_summary_proto(pb) + result.MergeFrom(pb) + for value in result.value: + if value.HasField("tensor"): + new_tensor = tensor_util.make_tensor_proto( + tensor_util.make_ndarray(value.tensor) + ) + value.ClearField("tensor") + value.tensor.MergeFrom(new_tensor) + return result + + def compute_and_check_summary_pb( + self, + name, + labels, + predictions, + num_thresholds, + weights=None, + display_name=None, + description=None, + feed_dict=None, + ): + """Use both `op` and `pb` to get a summary, asserting equality. 
+ + Returns: + a `Summary` protocol buffer + """ + labels_tensor = tf.constant(labels) + predictions_tensor = tf.constant(predictions) + weights_tensor = None if weights is None else tf.constant(weights) + op = summary.op( + name=name, + labels=labels_tensor, + predictions=predictions_tensor, + num_thresholds=num_thresholds, + weights=weights_tensor, + display_name=display_name, + description=description, + ) + pb = self.normalize_summary_pb( + summary.pb( + name=name, + labels=labels, + predictions=predictions, + num_thresholds=num_thresholds, + weights=weights, + display_name=display_name, + description=description, + ) + ) + pb_via_op = self.normalize_summary_pb( + self.pb_via_op(op, feed_dict=feed_dict) + ) + self.assertProtoEquals(pb, pb_via_op) + return pb + + def verify_float_arrays_are_equal(self, expected, actual): + # We use an absolute error instead of a relative one because the expected + # values are small. The default relative error (trol) of 1e-7 yields many + # undesired test failures. 
+ np.testing.assert_allclose(expected, actual, rtol=0, atol=1e-7) + + def test_metadata(self): + pb = self.compute_and_check_summary_pb( + name="foo", + labels=np.array([True]), + predictions=np.float32([0.42]), + num_thresholds=3, + ) + summary_metadata = pb.value[0].metadata + plugin_data = summary_metadata.plugin_data + self.assertEqual("foo", summary_metadata.display_name) + self.assertEqual("", summary_metadata.summary_description) + self.assertEqual(metadata.PLUGIN_NAME, plugin_data.plugin_name) + plugin_data = metadata.parse_plugin_metadata( + summary_metadata.plugin_data.content + ) + self.assertEqual(3, plugin_data.num_thresholds) + + def test_all_true_positives(self): + pb = self.compute_and_check_summary_pb( + name="foo", + labels=np.array([True]), + predictions=np.float32([1]), + num_thresholds=3, + ) + expected = [ + [1.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ] + values = tensor_util.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_all_true_negatives(self): + pb = self.compute_and_check_summary_pb( + name="foo", + labels=np.array([False]), + predictions=np.float32([0]), + num_thresholds=3, + ) + expected = [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + values = tensor_util.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_all_false_positives(self): + pb = self.compute_and_check_summary_pb( + name="foo", + labels=np.array([False]), + predictions=np.float32([1]), + num_thresholds=3, + ) + expected = [ + [0.0, 0.0, 0.0], + [1.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + values = tensor_util.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_all_false_negatives(self): + pb = self.compute_and_check_summary_pb( + 
name="foo", + labels=np.array([True]), + predictions=np.float32([0]), + num_thresholds=3, + ) + expected = [ + [1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 1.0, 1.0], + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + ] + values = tensor_util.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_many_values(self): + pb = self.compute_and_check_summary_pb( + name="foo", + labels=np.array([True, False, False, True, True, True]), + predictions=np.float32([0.2, 0.3, 0.4, 0.6, 0.7, 0.8]), + num_thresholds=3, + ) + expected = [ + [4.0, 3.0, 0.0], + [2.0, 0.0, 0.0], + [0.0, 2.0, 2.0], + [0.0, 1.0, 4.0], + [2.0 / 3.0, 1.0, 0.0], + [1.0, 0.75, 0.0], + ] + values = tensor_util.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_many_values_with_weights(self): + pb = self.compute_and_check_summary_pb( + name="foo", + labels=np.array([True, False, False, True, True, True]), + predictions=np.float32([0.2, 0.3, 0.4, 0.6, 0.7, 0.8]), + num_thresholds=3, + weights=np.float32([0.0, 0.5, 2.0, 0.0, 0.5, 1.0]), + ) + expected = [ + [1.5, 1.5, 0.0], + [2.5, 0.0, 0.0], + [0.0, 2.5, 2.5], + [0.0, 0.0, 1.5], + [0.375, 1.0, 0.0], + [1.0, 1.0, 0.0], + ] + values = tensor_util.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_exhaustive_random_values(self): + # Most other tests use small and crafted predictions and labels. + # This test exhaustively generates many data points. 
+ data_points = 420 + pb = self.compute_and_check_summary_pb( + name="foo", + labels=np.random.uniform(size=(data_points,)) > 0.5, + predictions=np.float32(np.random.uniform(size=(data_points,))), + num_thresholds=5, + ) + expected = [ + [218.0, 162.0, 111.0, 55.0, 0.0], + [202.0, 148.0, 98.0, 51.0, 0.0], + [0.0, 54.0, 104.0, 151.0, 202.0], + [0.0, 56.0, 107.0, 163.0, 218.0], + [0.5190476, 0.5225806, 0.5311005, 0.5188679, 0.0], + [1.0, 0.7431192, 0.5091743, 0.2522936, 0.0], + ] + values = tensor_util.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_counts_below_1(self): + """Tests support for counts below 1. + + Certain weights cause TP, FP, TN, FN counts to be below 1. + """ + pb = self.compute_and_check_summary_pb( + name="foo", + labels=np.array([True, False, False, True, True, True]), + predictions=np.float32([0.2, 0.3, 0.4, 0.6, 0.7, 0.8]), + num_thresholds=3, + weights=np.float32([0.0, 0.1, 0.2, 0.1, 0.1, 0.0]), + ) + expected = [ + [0.2, 0.2, 0.0], + [0.3, 0.0, 0.0], + [0.0, 0.3, 0.3], + [0.0, 0.0, 0.2], + [0.4, 1.0, 0.0], + [1.0, 1.0, 0.0], + ] + values = tensor_util.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal(expected, values) + + def test_raw_data(self): + # We pass these raw counts and precision/recall values. + name = "foo" + true_positive_counts = [75, 64, 21, 5, 0] + false_positive_counts = [150, 105, 18, 0, 0] + true_negative_counts = [0, 45, 132, 150, 150] + false_negative_counts = [0, 11, 54, 70, 75] + precision = [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0] + recall = [1.0, 0.8533334, 0.28, 0.0666667, 0.0] + num_thresholds = 5 + display_name = "some_raw_values" + description = "We passed raw values into a summary op." 
+ + op = summary.raw_data_op( + name=name, + true_positive_counts=tf.constant(true_positive_counts), + false_positive_counts=tf.constant(false_positive_counts), + true_negative_counts=tf.constant(true_negative_counts), + false_negative_counts=tf.constant(false_negative_counts), + precision=tf.constant(precision), + recall=tf.constant(recall), + num_thresholds=num_thresholds, + display_name=display_name, + description=description, + ) + pb_via_op = self.normalize_summary_pb(self.pb_via_op(op)) + + # Call the corresponding method that is decoupled from TensorFlow. + pb = self.normalize_summary_pb( + summary.raw_data_pb( + name=name, + true_positive_counts=true_positive_counts, + false_positive_counts=false_positive_counts, + true_negative_counts=true_negative_counts, + false_negative_counts=false_negative_counts, + precision=precision, + recall=recall, + num_thresholds=num_thresholds, + display_name=display_name, + description=description, + ) + ) + + # The 2 methods above should write summaries with the same data. + self.assertProtoEquals(pb, pb_via_op) + + # Test the metadata. + summary_metadata = pb.value[0].metadata + self.assertEqual("some_raw_values", summary_metadata.display_name) + self.assertEqual( + "We passed raw values into a summary op.", + summary_metadata.summary_description, + ) + self.assertEqual( + metadata.PLUGIN_NAME, summary_metadata.plugin_data.plugin_name + ) + + plugin_data = metadata.parse_plugin_metadata( + summary_metadata.plugin_data.content + ) + self.assertEqual(5, plugin_data.num_thresholds) + + # Test the summary contents. + values = tensor_util.make_ndarray(pb.value[0].tensor) + self.verify_float_arrays_are_equal( + [ + [75.0, 64.0, 21.0, 5.0, 0.0], # True positives. + [150.0, 105.0, 18.0, 0.0, 0.0], # False positives. + [0.0, 45.0, 132.0, 150.0, 150.0], # True negatives. + [0.0, 11.0, 54.0, 70.0, 75.0], # False negatives. + [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0], # Precision. + [1.0, 0.8533334, 0.28, 0.0666667, 0.0], # Recall. 
+ ], + values, + ) class StreamingOpTest(tf.test.TestCase): - - def setUp(self): - super(StreamingOpTest, self).setUp() - tf.compat.v1.reset_default_graph() - np.random.seed(1) - - def pb_via_op(self, summary_op): - actual_pbtxt = summary_op.eval() - actual_proto = summary_pb2.Summary() - actual_proto.ParseFromString(actual_pbtxt) - return actual_proto - - def tensor_via_op(self, summary_op): - actual_pbtxt = summary_op.eval() - actual_proto = summary_pb2.Summary() - actual_proto.ParseFromString(actual_pbtxt) - return actual_proto - - def test_matches_op(self): - predictions = tf.constant([0.2, 0.4, 0.5, 0.6, 0.8], dtype=tf.float32) - labels = tf.constant([False, True, True, False, True], dtype=tf.bool) - - pr_curve, update_op = summary.streaming_op(name='pr_curve', - predictions=predictions, - labels=labels, - num_thresholds=10) - expected_pr_curve = summary.op(name='pr_curve', - predictions=predictions, - labels=labels, - num_thresholds=10) - with self.test_session() as sess: - sess.run(tf.compat.v1.local_variables_initializer()) - sess.run([update_op]) - - proto = self.pb_via_op(pr_curve) - expected_proto = self.pb_via_op(expected_pr_curve) - - # Need to detect and fix the automatic _1 appended to second namespace. 
- self.assertEqual(proto.value[0].tag, 'pr_curve/pr_curves') - self.assertEqual(expected_proto.value[0].tag, 'pr_curve_1/pr_curves') - expected_proto.value[0].tag = 'pr_curve/pr_curves' - - self.assertProtoEquals(expected_proto, proto) - - def test_matches_op_with_updates(self): - predictions = tf.constant([0.2, 0.4, 0.5, 0.6, 0.8], dtype=tf.float32) - labels = tf.constant([False, True, True, False, True], dtype=tf.bool) - pr_curve, update_op = summary.streaming_op(name='pr_curve', - predictions=predictions, - labels=labels, - num_thresholds=10) - - complete_predictions = tf.tile(predictions, [3]) - complete_labels = tf.tile(labels, [3]) - expected_pr_curve = summary.op(name='pr_curve', - predictions=complete_predictions, - labels=complete_labels, - num_thresholds=10) - with self.test_session() as sess: - sess.run(tf.compat.v1.local_variables_initializer()) - sess.run([update_op]) - sess.run([update_op]) - sess.run([update_op]) - - proto = self.pb_via_op(pr_curve) - expected_proto = self.pb_via_op(expected_pr_curve) - - # Need to detect and fix the automatic _1 appended to second namespace. - self.assertEqual(proto.value[0].tag, 'pr_curve/pr_curves') - self.assertEqual(expected_proto.value[0].tag, 'pr_curve_1/pr_curves') - expected_proto.value[0].tag = 'pr_curve/pr_curves' - - self.assertProtoEquals(expected_proto, proto) - - def test_only_1_summary_generated(self): - """Tests that the streaming op only generates 1 summary for PR curves. - - This test was made in response to a bug in which calling the streaming op - actually introduced 2 tags. 
- """ - predictions = tf.constant([0.2, 0.4, 0.5, 0.6, 0.8], dtype=tf.float32) - labels = tf.constant([False, True, True, False, True], dtype=tf.bool) - _, update_op = summary.streaming_op(name='pr_curve', - predictions=predictions, - labels=labels, - num_thresholds=10) - with self.test_session() as sess: - sess.run(tf.compat.v1.local_variables_initializer()) - sess.run(update_op) - summary_proto = summary_pb2.Summary() - summary_proto.ParseFromString(sess.run(tf.compat.v1.summary.merge_all())) - - tags = [v.tag for v in summary_proto.value] - # Only 1 tag should have been introduced. - self.assertEqual(['pr_curve/pr_curves'], tags) + def setUp(self): + super(StreamingOpTest, self).setUp() + tf.compat.v1.reset_default_graph() + np.random.seed(1) + + def pb_via_op(self, summary_op): + actual_pbtxt = summary_op.eval() + actual_proto = summary_pb2.Summary() + actual_proto.ParseFromString(actual_pbtxt) + return actual_proto + + def tensor_via_op(self, summary_op): + actual_pbtxt = summary_op.eval() + actual_proto = summary_pb2.Summary() + actual_proto.ParseFromString(actual_pbtxt) + return actual_proto + + def test_matches_op(self): + predictions = tf.constant([0.2, 0.4, 0.5, 0.6, 0.8], dtype=tf.float32) + labels = tf.constant([False, True, True, False, True], dtype=tf.bool) + + pr_curve, update_op = summary.streaming_op( + name="pr_curve", + predictions=predictions, + labels=labels, + num_thresholds=10, + ) + expected_pr_curve = summary.op( + name="pr_curve", + predictions=predictions, + labels=labels, + num_thresholds=10, + ) + with self.test_session() as sess: + sess.run(tf.compat.v1.local_variables_initializer()) + sess.run([update_op]) + + proto = self.pb_via_op(pr_curve) + expected_proto = self.pb_via_op(expected_pr_curve) + + # Need to detect and fix the automatic _1 appended to second namespace. 
+ self.assertEqual(proto.value[0].tag, "pr_curve/pr_curves") + self.assertEqual( + expected_proto.value[0].tag, "pr_curve_1/pr_curves" + ) + expected_proto.value[0].tag = "pr_curve/pr_curves" + + self.assertProtoEquals(expected_proto, proto) + + def test_matches_op_with_updates(self): + predictions = tf.constant([0.2, 0.4, 0.5, 0.6, 0.8], dtype=tf.float32) + labels = tf.constant([False, True, True, False, True], dtype=tf.bool) + pr_curve, update_op = summary.streaming_op( + name="pr_curve", + predictions=predictions, + labels=labels, + num_thresholds=10, + ) + + complete_predictions = tf.tile(predictions, [3]) + complete_labels = tf.tile(labels, [3]) + expected_pr_curve = summary.op( + name="pr_curve", + predictions=complete_predictions, + labels=complete_labels, + num_thresholds=10, + ) + with self.test_session() as sess: + sess.run(tf.compat.v1.local_variables_initializer()) + sess.run([update_op]) + sess.run([update_op]) + sess.run([update_op]) + + proto = self.pb_via_op(pr_curve) + expected_proto = self.pb_via_op(expected_pr_curve) + + # Need to detect and fix the automatic _1 appended to second namespace. + self.assertEqual(proto.value[0].tag, "pr_curve/pr_curves") + self.assertEqual( + expected_proto.value[0].tag, "pr_curve_1/pr_curves" + ) + expected_proto.value[0].tag = "pr_curve/pr_curves" + + self.assertProtoEquals(expected_proto, proto) + + def test_only_1_summary_generated(self): + """Tests that the streaming op only generates 1 summary for PR curves. + + This test was made in response to a bug in which calling the + streaming op actually introduced 2 tags. 
+ """ + predictions = tf.constant([0.2, 0.4, 0.5, 0.6, 0.8], dtype=tf.float32) + labels = tf.constant([False, True, True, False, True], dtype=tf.bool) + _, update_op = summary.streaming_op( + name="pr_curve", + predictions=predictions, + labels=labels, + num_thresholds=10, + ) + with self.test_session() as sess: + sess.run(tf.compat.v1.local_variables_initializer()) + sess.run(update_op) + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString( + sess.run(tf.compat.v1.summary.merge_all()) + ) + + tags = [v.tag for v in summary_proto.value] + # Only 1 tag should have been introduced. + self.assertEqual(["pr_curve/pr_curves"], tags) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/plugins/profile/profile_demo.py b/tensorboard/plugins/profile/profile_demo.py index 99b13df537..2efcf5c969 100644 --- a/tensorboard/plugins/profile/profile_demo.py +++ b/tensorboard/plugins/profile/profile_demo.py @@ -42,71 +42,78 @@ # Directory into which to write tensorboard data. -LOGDIR = '/tmp/profile_demo' +LOGDIR = "/tmp/profile_demo" # Suffix for the empty eventfile to write. Should be kept in sync with TF # profiler kProfileEmptySuffix constant defined in: # tensorflow/core/profiler/rpc/client/capture_profile.cc. -EVENT_FILE_SUFFIX = '.profile-empty' +EVENT_FILE_SUFFIX = ".profile-empty" def _maybe_create_directory(directory): - try: - os.makedirs(directory) - except OSError: - print('Directory %s already exists.' %directory) + try: + os.makedirs(directory) + except OSError: + print("Directory %s already exists." % directory) def write_empty_event_file(logdir): - w = tf.compat.v2.summary.create_file_writer( - logdir, filename_suffix=EVENT_FILE_SUFFIX) - w.close() + w = tf.compat.v2.summary.create_file_writer( + logdir, filename_suffix=EVENT_FILE_SUFFIX + ) + w.close() def dump_data(logdir): - """Dumps plugin data to the log directory.""" - # Create a tfevents file in the logdir so it is detected as a run. 
- write_empty_event_file(logdir) - - plugin_logdir = plugin_asset_util.PluginDirectory( - logdir, profile_plugin.ProfilePlugin.plugin_name) - _maybe_create_directory(plugin_logdir) - - for run in profile_demo_data.RUNS: - run_dir = os.path.join(plugin_logdir, run) + """Dumps plugin data to the log directory.""" + # Create a tfevents file in the logdir so it is detected as a run. + write_empty_event_file(logdir) + + plugin_logdir = plugin_asset_util.PluginDirectory( + logdir, profile_plugin.ProfilePlugin.plugin_name + ) + _maybe_create_directory(plugin_logdir) + + for run in profile_demo_data.RUNS: + run_dir = os.path.join(plugin_logdir, run) + _maybe_create_directory(run_dir) + if run in profile_demo_data.TRACES: + with open(os.path.join(run_dir, "trace"), "w") as f: + proto = trace_events_pb2.Trace() + text_format.Merge(profile_demo_data.TRACES[run], proto) + f.write(proto.SerializeToString()) + + if run not in profile_demo_data.TRACE_ONLY: + shutil.copyfile( + "tensorboard/plugins/profile/profile_demo.op_profile.json", + os.path.join(run_dir, "op_profile.json"), + ) + shutil.copyfile( + "tensorboard/plugins/profile/profile_demo.memory_viewer.json", + os.path.join(run_dir, "memory_viewer.json"), + ) + shutil.copyfile( + "tensorboard/plugins/profile/profile_demo.pod_viewer.json", + os.path.join(run_dir, "pod_viewer.json"), + ) + shutil.copyfile( + "tensorboard/plugins/profile/profile_demo.google_chart_demo.json", + os.path.join(run_dir, "google_chart_demo.json"), + ) + + # Unsupported tool data should not be displayed. 
+ run_dir = os.path.join(plugin_logdir, "empty") _maybe_create_directory(run_dir) - if run in profile_demo_data.TRACES: - with open(os.path.join(run_dir, 'trace'), 'w') as f: - proto = trace_events_pb2.Trace() - text_format.Merge(profile_demo_data.TRACES[run], proto) - f.write(proto.SerializeToString()) - - if run not in profile_demo_data.TRACE_ONLY: - shutil.copyfile('tensorboard/plugins/profile/profile_demo.op_profile.json', - os.path.join(run_dir, 'op_profile.json')) - shutil.copyfile( - 'tensorboard/plugins/profile/profile_demo.memory_viewer.json', - os.path.join(run_dir, 'memory_viewer.json')) - shutil.copyfile( - 'tensorboard/plugins/profile/profile_demo.pod_viewer.json', - os.path.join(run_dir, 'pod_viewer.json')) - shutil.copyfile( - 'tensorboard/plugins/profile/profile_demo.google_chart_demo.json', - os.path.join(run_dir, 'google_chart_demo.json')) - - # Unsupported tool data should not be displayed. - run_dir = os.path.join(plugin_logdir, 'empty') - _maybe_create_directory(run_dir) - with open(os.path.join(run_dir, 'unsupported'), 'w') as f: - f.write('unsupported data') + with open(os.path.join(run_dir, "unsupported"), "w") as f: + f.write("unsupported data") def main(unused_argv): - print('Saving output to %s.' % LOGDIR) - dump_data(LOGDIR) - print('Done. Output saved to %s.' % LOGDIR) + print("Saving output to %s." % LOGDIR) + dump_data(LOGDIR) + print("Done. Output saved to %s." % LOGDIR) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/tensorboard/plugins/profile/profile_demo_data.py b/tensorboard/plugins/profile/profile_demo_data.py index 2d8c93adde..d428d8da1a 100644 --- a/tensorboard/plugins/profile/profile_demo_data.py +++ b/tensorboard/plugins/profile/profile_demo_data.py @@ -12,21 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Demo data for the profile dashboard""" +"""Demo data for the profile dashboard.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -RUNS = ['foo', 'bar'] +RUNS = ["foo", "bar"] -TRACE_ONLY = ['foo'] +TRACE_ONLY = ["foo"] TRACES = {} -TRACES['foo'] = """ +TRACES[ + "foo" +] = """ devices { key: 2 value { name: 'Foo2' device_id: 2 @@ -61,7 +63,9 @@ """ -TRACES['bar'] = """ +TRACES[ + "bar" +] = """ devices { key: 2 value { name: 'Bar2' device_id: 2 diff --git a/tensorboard/plugins/profile/profile_plugin.py b/tensorboard/plugins/profile/profile_plugin.py index cf7bab56eb..6258e5eac8 100644 --- a/tensorboard/plugins/profile/profile_plugin.py +++ b/tensorboard/plugins/profile/profile_plugin.py @@ -38,443 +38,504 @@ logger = tb_logging.get_logger() # The prefix of routes provided by this plugin. -PLUGIN_NAME = 'profile' +PLUGIN_NAME = "profile" # HTTP routes -DATA_ROUTE = '/data' -TOOLS_ROUTE = '/tools' -HOSTS_ROUTE = '/hosts' -CAPTURE_ROUTE = '/capture_profile' +DATA_ROUTE = "/data" +TOOLS_ROUTE = "/tools" +HOSTS_ROUTE = "/hosts" +CAPTURE_ROUTE = "/capture_profile" # Available profiling tools -> file name of the tool data. 
-_FILE_NAME = 'TOOL_FILE_NAME' +_FILE_NAME = "TOOL_FILE_NAME" TOOLS = { - 'trace_viewer': 'trace', - 'trace_viewer@': 'tracetable', #streaming traceviewer - 'op_profile': 'op_profile.json', - 'input_pipeline_analyzer': 'input_pipeline.json', - 'overview_page': 'overview_page.json', - 'memory_viewer': 'memory_viewer.json', - 'pod_viewer': 'pod_viewer.json', - 'google_chart_demo': 'google_chart_demo.json', + "trace_viewer": "trace", + "trace_viewer@": "tracetable", # streaming traceviewer + "op_profile": "op_profile.json", + "input_pipeline_analyzer": "input_pipeline.json", + "overview_page": "overview_page.json", + "memory_viewer": "memory_viewer.json", + "pod_viewer": "pod_viewer.json", + "google_chart_demo": "google_chart_demo.json", } # Tools that consume raw data. -_RAW_DATA_TOOLS = frozenset(['input_pipeline_analyzer', - 'op_profile', - 'overview_page', - 'memory_viewer', - 'pod_viewer', - 'google_chart_demo',]) +_RAW_DATA_TOOLS = frozenset( + [ + "input_pipeline_analyzer", + "op_profile", + "overview_page", + "memory_viewer", + "pod_viewer", + "google_chart_demo", + ] +) + def process_raw_trace(raw_trace): - """Processes raw trace data and returns the UI data.""" - trace = trace_events_pb2.Trace() - trace.ParseFromString(raw_trace) - return ''.join(trace_events_json.TraceEventsJsonStream(trace)) + """Processes raw trace data and returns the UI data.""" + trace = trace_events_pb2.Trace() + trace.ParseFromString(raw_trace) + return "".join(trace_events_json.TraceEventsJsonStream(trace)) + def get_worker_list(cluster_resolver): - """Parses TPU workers list from the cluster resolver.""" - cluster_spec = cluster_resolver.cluster_spec() - task_indices = cluster_spec.task_indices('worker') - worker_list = [ - cluster_spec.task_address('worker', i).split(':')[0] for i in task_indices - ] - return ','.join(worker_list) + """Parses TPU workers list from the cluster resolver.""" + cluster_spec = cluster_resolver.cluster_spec() + task_indices = 
cluster_spec.task_indices("worker") + worker_list = [ + cluster_spec.task_address("worker", i).split(":")[0] + for i in task_indices + ] + return ",".join(worker_list) + class ProfilePlugin(base_plugin.TBPlugin): - """Profile Plugin for TensorBoard.""" - - plugin_name = PLUGIN_NAME - - def __init__(self, context): - """Constructs a profiler plugin for TensorBoard. - - This plugin adds handlers for performance-related frontends. - - Args: - context: A base_plugin.TBContext instance. - """ - self.logdir = context.logdir - self.multiplexer = context.multiplexer - self.plugin_logdir = plugin_asset_util.PluginDirectory( - self.logdir, PLUGIN_NAME) - self.stub = None - self.master_tpu_unsecure_channel = context.flags.master_tpu_unsecure_channel - - # Whether the plugin is active. This is an expensive computation, so we - # compute this asynchronously and cache positive results indefinitely. - self._is_active = False - # Lock to ensure at most one thread computes _is_active at a time. - self._is_active_lock = threading.Lock() - - def is_active(self): - """Whether this plugin is active and has any profile data to show. - - Detecting profile data is expensive, so this process runs asynchronously - and the value reported by this method is the cached value and may be stale. - - Returns: - Whether any run has profile data. - """ - # If we are already active, we remain active and don't recompute this. - # Otherwise, try to acquire the lock without blocking; if we get it and - # we're still not active, launch a thread to check if we're active and - # release the lock once the computation is finished. Either way, this - # thread returns the current cached value to avoid blocking. 
- if not self._is_active and self._is_active_lock.acquire(False): - if self._is_active: - self._is_active_lock.release() - else: - def compute_is_active(): - self._is_active = any(self.generate_run_to_tools()) - self._is_active_lock.release() - new_thread = threading.Thread( - target=compute_is_active, - name='ProfilePluginIsActiveThread') - new_thread.start() - return self._is_active - - def frontend_metadata(self): - # TODO(#2338): Keep this in sync with the `registerDashboard` call - # on the frontend until that call is removed. - return base_plugin.FrontendMetadata( - element_name='tf-profile-dashboard', - disable_reload=True, - ) - - def start_grpc_stub_if_necessary(self): - # We will enable streaming trace viewer on two conditions: - # 1. user specify the flags master_tpu_unsecure_channel to the ip address of - # as "master" TPU. grpc will be used to fetch streaming trace data. - # 2. the logdir is on google cloud storage. - if self.master_tpu_unsecure_channel and self.logdir.startswith('gs://'): - if self.stub is None: - import grpc - from tensorflow.python.tpu.profiler import profiler_analysis_pb2_grpc - # Workaround the grpc's 4MB message limitation. - gigabyte = 1024 * 1024 * 1024 - options = [('grpc.max_message_length', gigabyte), - ('grpc.max_send_message_length', gigabyte), - ('grpc.max_receive_message_length', gigabyte)] - tpu_profiler_port = self.master_tpu_unsecure_channel + ':8466' - channel = grpc.insecure_channel(tpu_profiler_port, options) - self.stub = profiler_analysis_pb2_grpc.ProfileAnalysisStub(channel) - - def _run_dir(self, run): - """Helper that maps a frontend run name to a profile "run" directory. 
- - The frontend run name consists of the TensorBoard run name (aka the relative - path from the logdir root to the directory containing the data) path-joined - to the Profile plugin's "run" concept (which is a subdirectory of the - plugins/profile directory representing an individual run of the tool), with - the special case that TensorBoard run is the logdir root (which is the run - named '.') then only the Profile plugin "run" name is used, for backwards - compatibility. - - To convert back to the actual run directory, we apply the following - transformation: - - If the run name doesn't contain '/', prepend './' - - Split on the rightmost instance of '/' - - Assume the left side is a TensorBoard run name and map it to a directory - path using EventMultiplexer.RunPaths(), then map that to the profile - plugin directory via PluginDirectory() - - Assume the right side is a Profile plugin "run" and path-join it to - the preceding path to get the final directory - - Args: - run: the frontend run name, as described above, e.g. train/run1. - - Returns: - The resolved directory path, e.g. /logdir/train/plugins/profile/run1. - """ - run = run.rstrip('/') - if '/' not in run: - run = './' + run - tb_run_name, _, profile_run_name = run.rpartition('/') - tb_run_directory = self.multiplexer.RunPaths().get(tb_run_name) - if tb_run_directory is None: - # Check if logdir is a directory to handle case where it's actually a - # multipart directory spec, which this plugin does not support. - if tb_run_name == '.' and tf.io.gfile.isdir(self.logdir): - tb_run_directory = self.logdir - else: - raise RuntimeError("No matching run directory for run %s" % run) - plugin_directory = plugin_asset_util.PluginDirectory( - tb_run_directory, PLUGIN_NAME) - return os.path.join(plugin_directory, profile_run_name) - - def generate_run_to_tools(self): - """Generator for pairs of "run name" and a list of tools for that run. 
- - The "run name" here is a "frontend run name" - see _run_dir() for the - definition of a "frontend run name" and how it maps to a directory of - profile data for a specific profile "run". The profile plugin concept of - "run" is different from the normal TensorBoard run; each run in this case - represents a single instance of profile data collection, more similar to a - "step" of data in typical TensorBoard semantics. These runs reside in - subdirectories of the plugins/profile directory within any regular - TensorBoard run directory (defined as a subdirectory of the logdir that - contains at least one tfevents file) or within the logdir root directory - itself (even if it contains no tfevents file and would thus not be - considered a normal TensorBoard run, for backwards compatibility). - - Within those "profile run directories", there are files in the directory - that correspond to different profiling tools. The file that contains profile - for a specific tool "x" will have a suffix name TOOLS["x"]. - - Example: - logs/ - plugins/ - profile/ + """Profile Plugin for TensorBoard.""" + + plugin_name = PLUGIN_NAME + + def __init__(self, context): + """Constructs a profiler plugin for TensorBoard. + + This plugin adds handlers for performance-related frontends. + + Args: + context: A base_plugin.TBContext instance. + """ + self.logdir = context.logdir + self.multiplexer = context.multiplexer + self.plugin_logdir = plugin_asset_util.PluginDirectory( + self.logdir, PLUGIN_NAME + ) + self.stub = None + self.master_tpu_unsecure_channel = ( + context.flags.master_tpu_unsecure_channel + ) + + # Whether the plugin is active. This is an expensive computation, so we + # compute this asynchronously and cache positive results indefinitely. + self._is_active = False + # Lock to ensure at most one thread computes _is_active at a time. + self._is_active_lock = threading.Lock() + + def is_active(self): + """Whether this plugin is active and has any profile data to show. 
+ + Detecting profile data is expensive, so this process runs asynchronously + and the value reported by this method is the cached value and may be stale. + + Returns: + Whether any run has profile data. + """ + # If we are already active, we remain active and don't recompute this. + # Otherwise, try to acquire the lock without blocking; if we get it and + # we're still not active, launch a thread to check if we're active and + # release the lock once the computation is finished. Either way, this + # thread returns the current cached value to avoid blocking. + if not self._is_active and self._is_active_lock.acquire(False): + if self._is_active: + self._is_active_lock.release() + else: + + def compute_is_active(): + self._is_active = any(self.generate_run_to_tools()) + self._is_active_lock.release() + + new_thread = threading.Thread( + target=compute_is_active, name="ProfilePluginIsActiveThread" + ) + new_thread.start() + return self._is_active + + def frontend_metadata(self): + # TODO(#2338): Keep this in sync with the `registerDashboard` call + # on the frontend until that call is removed. + return base_plugin.FrontendMetadata( + element_name="tf-profile-dashboard", disable_reload=True, + ) + + def start_grpc_stub_if_necessary(self): + # We will enable streaming trace viewer on two conditions: + # 1. user specify the flags master_tpu_unsecure_channel to the ip address of + # as "master" TPU. grpc will be used to fetch streaming trace data. + # 2. the logdir is on google cloud storage. + if self.master_tpu_unsecure_channel and self.logdir.startswith("gs://"): + if self.stub is None: + import grpc + from tensorflow.python.tpu.profiler import ( + profiler_analysis_pb2_grpc, + ) + + # Workaround the grpc's 4MB message limitation. 
+ gigabyte = 1024 * 1024 * 1024 + options = [ + ("grpc.max_message_length", gigabyte), + ("grpc.max_send_message_length", gigabyte), + ("grpc.max_receive_message_length", gigabyte), + ] + tpu_profiler_port = self.master_tpu_unsecure_channel + ":8466" + channel = grpc.insecure_channel(tpu_profiler_port, options) + self.stub = profiler_analysis_pb2_grpc.ProfileAnalysisStub( + channel + ) + + def _run_dir(self, run): + """Helper that maps a frontend run name to a profile "run" directory. + + The frontend run name consists of the TensorBoard run name (aka the relative + path from the logdir root to the directory containing the data) path-joined + to the Profile plugin's "run" concept (which is a subdirectory of the + plugins/profile directory representing an individual run of the tool), with + the special case that TensorBoard run is the logdir root (which is the run + named '.') then only the Profile plugin "run" name is used, for backwards + compatibility. + + To convert back to the actual run directory, we apply the following + transformation: + - If the run name doesn't contain '/', prepend './' + - Split on the rightmost instance of '/' + - Assume the left side is a TensorBoard run name and map it to a directory + path using EventMultiplexer.RunPaths(), then map that to the profile + plugin directory via PluginDirectory() + - Assume the right side is a Profile plugin "run" and path-join it to + the preceding path to get the final directory + + Args: + run: the frontend run name, as described above, e.g. train/run1. + + Returns: + The resolved directory path, e.g. /logdir/train/plugins/profile/run1. + """ + run = run.rstrip("/") + if "/" not in run: + run = "./" + run + tb_run_name, _, profile_run_name = run.rpartition("/") + tb_run_directory = self.multiplexer.RunPaths().get(tb_run_name) + if tb_run_directory is None: + # Check if logdir is a directory to handle case where it's actually a + # multipart directory spec, which this plugin does not support. 
+ if tb_run_name == "." and tf.io.gfile.isdir(self.logdir): + tb_run_directory = self.logdir + else: + raise RuntimeError("No matching run directory for run %s" % run) + plugin_directory = plugin_asset_util.PluginDirectory( + tb_run_directory, PLUGIN_NAME + ) + return os.path.join(plugin_directory, profile_run_name) + + def generate_run_to_tools(self): + """Generator for pairs of "run name" and a list of tools for that run. + + The "run name" here is a "frontend run name" - see _run_dir() for the + definition of a "frontend run name" and how it maps to a directory of + profile data for a specific profile "run". The profile plugin concept of + "run" is different from the normal TensorBoard run; each run in this case + represents a single instance of profile data collection, more similar to a + "step" of data in typical TensorBoard semantics. These runs reside in + subdirectories of the plugins/profile directory within any regular + TensorBoard run directory (defined as a subdirectory of the logdir that + contains at least one tfevents file) or within the logdir root directory + itself (even if it contains no tfevents file and would thus not be + considered a normal TensorBoard run, for backwards compatibility). + + Within those "profile run directories", there are files in the directory + that correspond to different profiling tools. The file that contains profile + for a specific tool "x" will have a suffix name TOOLS["x"]. + + Example: + logs/ + plugins/ + profile/ + run1/ + hostA.trace + train/ + events.out.tfevents.foo + plugins/ + profile/ + run1/ + hostA.trace + hostB.trace + run2/ + hostA.trace + validation/ + events.out.tfevents.foo + plugins/ + profile/ + run1/ + hostA.trace + + Yields: + A sequence of tuples mapping "frontend run names" to lists of tool names + available for those runs. 
For the above example, this would be: + + ("run1", ["trace_viewer"]) + ("train/run1", ["trace_viewer"]) + ("train/run2", ["trace_viewer"]) + ("validation/run1", ["trace_viewer"]) + """ + self.start_grpc_stub_if_necessary() + + plugin_assets = self.multiplexer.PluginAssets(PLUGIN_NAME) + tb_run_names_to_dirs = self.multiplexer.RunPaths() + + # Ensure that we also check the root logdir, even if it isn't a recognized + # TensorBoard run (i.e. has no tfevents file directly under it), to remain + # backwards compatible with previously profile plugin behavior. Note that we + # check if logdir is a directory to handle case where it's actually a + # multipart directory spec, which this plugin does not support. + if "." not in plugin_assets and tf.io.gfile.isdir(self.logdir): + tb_run_names_to_dirs["."] = self.logdir + plugin_assets["."] = plugin_asset_util.ListAssets( + self.logdir, PLUGIN_NAME + ) + + for tb_run_name, profile_runs in six.iteritems(plugin_assets): + tb_run_dir = tb_run_names_to_dirs[tb_run_name] + tb_plugin_dir = plugin_asset_util.PluginDirectory( + tb_run_dir, PLUGIN_NAME + ) + for profile_run in profile_runs: + # Remove trailing slash; some filesystem implementations emit this. 
+ profile_run = profile_run.rstrip("/") + if tb_run_name == ".": + frontend_run = profile_run + else: + frontend_run = "/".join([tb_run_name, profile_run]) + profile_run_dir = os.path.join(tb_plugin_dir, profile_run) + if tf.io.gfile.isdir(profile_run_dir): + yield frontend_run, self._get_active_tools(profile_run_dir) + + def _get_active_tools(self, profile_run_dir): + tools = [] + for tool in TOOLS: + tool_pattern = "*" + TOOLS[tool] + path = os.path.join(profile_run_dir, tool_pattern) + try: + files = tf.io.gfile.glob(path) + if len(files) >= 1: + tools.append(tool) + except tf.errors.OpError as e: + logger.warn( + "Cannot read asset directory: %s, OpError %s", + profile_run_dir, + e, + ) + if "trace_viewer@" in tools: + # streaming trace viewer always override normal trace viewer. + # the trailing '@' is to inform tf-profile-dashboard.html and + # tf-trace-viewer.html that stream trace viewer should be used. + removed_tool = ( + "trace_viewer@" if self.stub is None else "trace_viewer" + ) + if removed_tool in tools: + tools.remove(removed_tool) + tools.sort() + op = "overview_page" + if op in tools: + # keep overview page at the top of the list + tools.remove(op) + tools.insert(0, op) + return tools + + @wrappers.Request.application + def tools_route(self, request): + run_to_tools = dict(self.generate_run_to_tools()) + return http_util.Respond(request, run_to_tools, "application/json") + + def host_impl(self, run, tool): + """Returns available hosts for the run and tool in the log directory. + + In the plugin log directory, each directory contains profile data for a + single run (identified by the directory name), and files in the run + directory contains data for different tools and hosts. The file that + contains profile for a specific tool "x" will have a prefix name TOOLS["x"]. 
+ + Example: + log/ run1/ - hostA.trace - train/ - events.out.tfevents.foo - plugins/ - profile/ - run1/ - hostA.trace - hostB.trace - run2/ - hostA.trace - validation/ - events.out.tfevents.foo - plugins/ - profile/ - run1/ - hostA.trace - - Yields: - A sequence of tuples mapping "frontend run names" to lists of tool names - available for those runs. For the above example, this would be: - - ("run1", ["trace_viewer"]) - ("train/run1", ["trace_viewer"]) - ("train/run2", ["trace_viewer"]) - ("validation/run1", ["trace_viewer"]) - """ - self.start_grpc_stub_if_necessary() - - plugin_assets = self.multiplexer.PluginAssets(PLUGIN_NAME) - tb_run_names_to_dirs = self.multiplexer.RunPaths() - - # Ensure that we also check the root logdir, even if it isn't a recognized - # TensorBoard run (i.e. has no tfevents file directly under it), to remain - # backwards compatible with previously profile plugin behavior. Note that we - # check if logdir is a directory to handle case where it's actually a - # multipart directory spec, which this plugin does not support. - if '.' not in plugin_assets and tf.io.gfile.isdir(self.logdir): - tb_run_names_to_dirs['.'] = self.logdir - plugin_assets['.'] = plugin_asset_util.ListAssets( - self.logdir, PLUGIN_NAME) - - for tb_run_name, profile_runs in six.iteritems(plugin_assets): - tb_run_dir = tb_run_names_to_dirs[tb_run_name] - tb_plugin_dir = plugin_asset_util.PluginDirectory( - tb_run_dir, PLUGIN_NAME) - for profile_run in profile_runs: - # Remove trailing slash; some filesystem implementations emit this. 
- profile_run = profile_run.rstrip('/') - if tb_run_name == '.': - frontend_run = profile_run - else: - frontend_run = '/'.join([tb_run_name, profile_run]) - profile_run_dir = os.path.join(tb_plugin_dir, profile_run) - if tf.io.gfile.isdir(profile_run_dir): - yield frontend_run, self._get_active_tools(profile_run_dir) - - def _get_active_tools(self, profile_run_dir): - tools = [] - for tool in TOOLS: - tool_pattern = '*' + TOOLS[tool] - path = os.path.join(profile_run_dir, tool_pattern) - try: - files = tf.io.gfile.glob(path) - if len(files) >= 1: - tools.append(tool) - except tf.errors.OpError as e: - logger.warn("Cannot read asset directory: %s, OpError %s", - profile_run_dir, e) - if 'trace_viewer@' in tools: - # streaming trace viewer always override normal trace viewer. - # the trailing '@' is to inform tf-profile-dashboard.html and - # tf-trace-viewer.html that stream trace viewer should be used. - removed_tool = 'trace_viewer@' if self.stub is None else 'trace_viewer' - if removed_tool in tools: - tools.remove(removed_tool) - tools.sort() - op = 'overview_page' - if op in tools: - # keep overview page at the top of the list - tools.remove(op) - tools.insert(0, op) - return tools - - @wrappers.Request.application - def tools_route(self, request): - run_to_tools = dict(self.generate_run_to_tools()) - return http_util.Respond(request, run_to_tools, 'application/json') - - def host_impl(self, run, tool): - """Returns available hosts for the run and tool in the log directory. - - In the plugin log directory, each directory contains profile data for a - single run (identified by the directory name), and files in the run - directory contains data for different tools and hosts. The file that - contains profile for a specific tool "x" will have a prefix name TOOLS["x"]. - - Example: - log/ - run1/ - plugins/ - profile/ - host1.trace - host2.trace - run2/ - plugins/ - profile/ - host1.trace - host2.trace - - Returns: - A list of host names e.g. 
- {"host1", "host2", "host3"} for the example. - """ - hosts = {} - run_dir = self._run_dir(run) - if not run_dir: - logger.warn("Cannot find asset directory for: %s", run) - return hosts - tool_pattern = '*' + TOOLS[tool] - try: - files = tf.io.gfile.glob(os.path.join(run_dir, tool_pattern)) - hosts = [os.path.basename(f).replace(TOOLS[tool], '') for f in files] - except tf.errors.OpError as e: - logger.warn("Cannot read asset directory: %s, OpError %s", - run_dir, e) - return hosts - - - @wrappers.Request.application - def hosts_route(self, request): - run = request.args.get('run') - tool = request.args.get('tag') - hosts = self.host_impl(run, tool) - return http_util.Respond(request, hosts, 'application/json') - - def data_impl(self, request): - """Retrieves and processes the tool data for a run and a host. - - Args: - request: XMLHttpRequest - - Returns: - A string that can be served to the frontend tool or None if tool, - run or host is invalid. - """ - run = request.args.get('run') - tool = request.args.get('tag') - host = request.args.get('host') - run_dir = self._run_dir(run) - # Profile plugin "run" is the last component of run dir. 
- profile_run = os.path.basename(run_dir) - - if tool not in TOOLS: - return None - - self.start_grpc_stub_if_necessary() - if tool == 'trace_viewer@' and self.stub is not None: - from tensorflow.core.profiler import profiler_analysis_pb2 - grpc_request = profiler_analysis_pb2.ProfileSessionDataRequest() - grpc_request.repository_root = os.path.dirname(run_dir) - grpc_request.session_id = profile_run - grpc_request.tool_name = 'trace_viewer' - # Remove the trailing dot if present - grpc_request.host_name = host.rstrip('.') - - grpc_request.parameters['resolution'] = request.args.get('resolution') - if request.args.get('start_time_ms') is not None: - grpc_request.parameters['start_time_ms'] = request.args.get( - 'start_time_ms') - if request.args.get('end_time_ms') is not None: - grpc_request.parameters['end_time_ms'] = request.args.get('end_time_ms') - grpc_response = self.stub.GetSessionToolData(grpc_request) - return grpc_response.output - - if tool not in TOOLS: - return None - tool_name = str(host) + TOOLS[tool] - asset_path = os.path.join(run_dir, tool_name) - raw_data = None - try: - with tf.io.gfile.GFile(asset_path, 'rb') as f: - raw_data = f.read() - except tf.errors.NotFoundError: - logger.warn('Asset path %s not found', asset_path) - except tf.errors.OpError as e: - logger.warn("Couldn't read asset path: %s, OpError %s", asset_path, e) - - if raw_data is None: - return None - if tool == 'trace_viewer': - return process_raw_trace(raw_data) - if tool in _RAW_DATA_TOOLS: - return raw_data - return None - - @wrappers.Request.application - def data_route(self, request): - # params - # request: XMLHTTPRequest. 
- data = self.data_impl(request) - if data is None: - return http_util.Respond(request, '404 Not Found', 'text/plain', code=404) - return http_util.Respond(request, data, 'application/json') - - @wrappers.Request.application - def capture_route(self, request): - service_addr = request.args.get('service_addr') - duration = int(request.args.get('duration', '1000')) - is_tpu_name = request.args.get('is_tpu_name') == 'true' - worker_list = request.args.get('worker_list') - include_dataset_ops = request.args.get('include_dataset_ops') == 'true' - num_tracing_attempts = int(request.args.get('num_retry', '0')) + 1 - - if is_tpu_name: - try: - tpu_cluster_resolver = ( - tf.distribute.cluster_resolver.TPUClusterResolver(service_addr)) - master_grpc_addr = tpu_cluster_resolver.get_master() - except (ImportError, RuntimeError) as err: - return http_util.Respond(request, {'error': err.message}, - 'application/json', code=200) - except (ValueError, TypeError): - return http_util.Respond(request, - {'error': 'no TPUs with the specified names exist.'}, - 'application/json', code=200) - if not worker_list: - worker_list = get_worker_list(tpu_cluster_resolver) - # TPU cluster resolver always returns port 8470. Replace it with 8466 - # on which profiler service is running. - master_ip = master_grpc_addr.replace('grpc://', '').replace(':8470', '') - service_addr = master_ip + ':8466' - # Set the master TPU for streaming trace viewer. - self.master_tpu_unsecure_channel = master_ip - try: - profiler_client.start_tracing(service_addr, self.logdir, duration, - worker_list, include_dataset_ops, num_tracing_attempts) - return http_util.Respond( - request, {'result': 'Capture profile successfully. 
Please refresh.'}, - 'application/json') - except tf.errors.UnavailableError: - return http_util.Respond(request, {'error': 'empty trace result.'}, - 'application/json', code=200) - - def get_plugin_apps(self): - return { - TOOLS_ROUTE: self.tools_route, - HOSTS_ROUTE: self.hosts_route, - DATA_ROUTE: self.data_route, - CAPTURE_ROUTE: self.capture_route, - } + plugins/ + profile/ + host1.trace + host2.trace + run2/ + plugins/ + profile/ + host1.trace + host2.trace + + Returns: + A list of host names e.g. + {"host1", "host2", "host3"} for the example. + """ + hosts = {} + run_dir = self._run_dir(run) + if not run_dir: + logger.warn("Cannot find asset directory for: %s", run) + return hosts + tool_pattern = "*" + TOOLS[tool] + try: + files = tf.io.gfile.glob(os.path.join(run_dir, tool_pattern)) + hosts = [ + os.path.basename(f).replace(TOOLS[tool], "") for f in files + ] + except tf.errors.OpError as e: + logger.warn( + "Cannot read asset directory: %s, OpError %s", run_dir, e + ) + return hosts + + @wrappers.Request.application + def hosts_route(self, request): + run = request.args.get("run") + tool = request.args.get("tag") + hosts = self.host_impl(run, tool) + return http_util.Respond(request, hosts, "application/json") + + def data_impl(self, request): + """Retrieves and processes the tool data for a run and a host. + + Args: + request: XMLHttpRequest + + Returns: + A string that can be served to the frontend tool or None if tool, + run or host is invalid. + """ + run = request.args.get("run") + tool = request.args.get("tag") + host = request.args.get("host") + run_dir = self._run_dir(run) + # Profile plugin "run" is the last component of run dir. 
+ profile_run = os.path.basename(run_dir) + + if tool not in TOOLS: + return None + + self.start_grpc_stub_if_necessary() + if tool == "trace_viewer@" and self.stub is not None: + from tensorflow.core.profiler import profiler_analysis_pb2 + + grpc_request = profiler_analysis_pb2.ProfileSessionDataRequest() + grpc_request.repository_root = os.path.dirname(run_dir) + grpc_request.session_id = profile_run + grpc_request.tool_name = "trace_viewer" + # Remove the trailing dot if present + grpc_request.host_name = host.rstrip(".") + + grpc_request.parameters["resolution"] = request.args.get( + "resolution" + ) + if request.args.get("start_time_ms") is not None: + grpc_request.parameters["start_time_ms"] = request.args.get( + "start_time_ms" + ) + if request.args.get("end_time_ms") is not None: + grpc_request.parameters["end_time_ms"] = request.args.get( + "end_time_ms" + ) + grpc_response = self.stub.GetSessionToolData(grpc_request) + return grpc_response.output + + if tool not in TOOLS: + return None + tool_name = str(host) + TOOLS[tool] + asset_path = os.path.join(run_dir, tool_name) + raw_data = None + try: + with tf.io.gfile.GFile(asset_path, "rb") as f: + raw_data = f.read() + except tf.errors.NotFoundError: + logger.warn("Asset path %s not found", asset_path) + except tf.errors.OpError as e: + logger.warn( + "Couldn't read asset path: %s, OpError %s", asset_path, e + ) + + if raw_data is None: + return None + if tool == "trace_viewer": + return process_raw_trace(raw_data) + if tool in _RAW_DATA_TOOLS: + return raw_data + return None + + @wrappers.Request.application + def data_route(self, request): + # params + # request: XMLHTTPRequest. 
+ data = self.data_impl(request) + if data is None: + return http_util.Respond( + request, "404 Not Found", "text/plain", code=404 + ) + return http_util.Respond(request, data, "application/json") + + @wrappers.Request.application + def capture_route(self, request): + service_addr = request.args.get("service_addr") + duration = int(request.args.get("duration", "1000")) + is_tpu_name = request.args.get("is_tpu_name") == "true" + worker_list = request.args.get("worker_list") + include_dataset_ops = request.args.get("include_dataset_ops") == "true" + num_tracing_attempts = int(request.args.get("num_retry", "0")) + 1 + + if is_tpu_name: + try: + tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( + service_addr + ) + master_grpc_addr = tpu_cluster_resolver.get_master() + except (ImportError, RuntimeError) as err: + return http_util.Respond( + request, + {"error": err.message}, + "application/json", + code=200, + ) + except (ValueError, TypeError): + return http_util.Respond( + request, + {"error": "no TPUs with the specified names exist."}, + "application/json", + code=200, + ) + if not worker_list: + worker_list = get_worker_list(tpu_cluster_resolver) + # TPU cluster resolver always returns port 8470. Replace it with 8466 + # on which profiler service is running. + master_ip = master_grpc_addr.replace("grpc://", "").replace( + ":8470", "" + ) + service_addr = master_ip + ":8466" + # Set the master TPU for streaming trace viewer. + self.master_tpu_unsecure_channel = master_ip + try: + profiler_client.start_tracing( + service_addr, + self.logdir, + duration, + worker_list, + include_dataset_ops, + num_tracing_attempts, + ) + return http_util.Respond( + request, + {"result": "Capture profile successfully. 
Please refresh."}, + "application/json", + ) + except tf.errors.UnavailableError: + return http_util.Respond( + request, + {"error": "empty trace result."}, + "application/json", + code=200, + ) + + def get_plugin_apps(self): + return { + TOOLS_ROUTE: self.tools_route, + HOSTS_ROUTE: self.hosts_route, + DATA_ROUTE: self.data_route, + CAPTURE_ROUTE: self.capture_route, + } diff --git a/tensorboard/plugins/profile/profile_plugin_loader.py b/tensorboard/plugins/profile/profile_plugin_loader.py index cdd4ee856f..6dbeec6625 100644 --- a/tensorboard/plugins/profile/profile_plugin_loader.py +++ b/tensorboard/plugins/profile/profile_plugin_loader.py @@ -22,40 +22,43 @@ class ProfilePluginLoader(base_plugin.TBLoader): - """ProfilePlugin factory. - - This class checks for `tensorflow` install and dependency. - """ - - def define_flags(self, parser): - group = parser.add_argument_group('profile plugin') - group.add_argument( - '--master_tpu_unsecure_channel', - metavar='ADDR', - type=str, - default='', - help='''\ + """ProfilePlugin factory. + + This class checks for `tensorflow` install and dependency. + """ + + def define_flags(self, parser): + group = parser.add_argument_group("profile plugin") + group.add_argument( + "--master_tpu_unsecure_channel", + metavar="ADDR", + type=str, + default="", + help="""\ IP address of "master tpu", used for getting streaming trace data through tpu profiler analysis grpc. The grpc channel is not secured.\ -''') +""", + ) - def load(self, context): - """Returns the plugin, if possible. + def load(self, context): + """Returns the plugin, if possible. - Args: - context: The TBContext flags. + Args: + context: The TBContext flags. - Returns: - A ProfilePlugin instance or None if it couldn't be loaded. 
- """ - try: - # pylint: disable=unused-import - import tensorflow - # Available in TensorFlow 1.14 or later, so do import check - # pylint: disable=unused-import - from tensorflow.python.eager import profiler_client - except ImportError: - return - - from tensorboard.plugins.profile.profile_plugin import ProfilePlugin - return ProfilePlugin(context) + Returns: + A ProfilePlugin instance or None if it couldn't be loaded. + """ + try: + # pylint: disable=unused-import + import tensorflow + + # Available in TensorFlow 1.14 or later, so do import check + # pylint: disable=unused-import + from tensorflow.python.eager import profiler_client + except ImportError: + return + + from tensorboard.plugins.profile.profile_plugin import ProfilePlugin + + return ProfilePlugin(context) diff --git a/tensorboard/plugins/profile/profile_plugin_test.py b/tensorboard/plugins/profile/profile_plugin_test.py index ce82b58df2..f69951d4a8 100644 --- a/tensorboard/plugins/profile/profile_plugin_test.py +++ b/tensorboard/plugins/profile/profile_plugin_test.py @@ -25,7 +25,9 @@ from werkzeug import Request from tensorboard.backend.event_processing import plugin_asset_util -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.profile import profile_plugin from tensorboard.plugins.profile import trace_events_pb2 @@ -35,44 +37,33 @@ class FakeFlags(object): - def __init__( - self, - logdir, - master_tpu_unsecure_channel=''): - self.logdir = logdir - self.master_tpu_unsecure_channel = master_tpu_unsecure_channel + def __init__(self, logdir, master_tpu_unsecure_channel=""): + self.logdir = logdir + self.master_tpu_unsecure_channel = master_tpu_unsecure_channel RUN_TO_TOOLS = { - 'foo': ['trace_viewer'], - 'bar': ['unsupported'], - 'baz': ['trace_viewer'], - 'empty': [], + "foo": 
["trace_viewer"], + "bar": ["unsupported"], + "baz": ["trace_viewer"], + "empty": [], } RUN_TO_HOSTS = { - 'foo': ['host0', 'host1'], - 'bar': ['host1'], - 'baz': ['host2'], - 'empty': [], + "foo": ["host0", "host1"], + "bar": ["host1"], + "baz": ["host2"], + "empty": [], } EXPECTED_TRACE_DATA = dict( - displayTimeUnit='ns', - metadata={'highres-ticks': True}, + displayTimeUnit="ns", + metadata={"highres-ticks": True}, traceEvents=[ - dict( - ph='M', - pid=0, - name='process_name', - args=dict(name='foo')), - dict( - ph='M', - pid=0, - name='process_sort_index', - args=dict(sort_index=0)), + dict(ph="M", pid=0, name="process_name", args=dict(name="foo")), + dict(ph="M", pid=0, name="process_sort_index", args=dict(sort_index=0)), dict(), ], ) @@ -81,173 +72,201 @@ def __init__( # Suffix for the empty eventfile to write. Should be kept in sync with TF # profiler kProfileEmptySuffix constant defined in: # tensorflow/core/profiler/rpc/client/capture_profile.cc. -EVENT_FILE_SUFFIX = '.profile-empty' +EVENT_FILE_SUFFIX = ".profile-empty" def generate_testdata(logdir): - plugin_logdir = plugin_asset_util.PluginDirectory( - logdir, profile_plugin.ProfilePlugin.plugin_name) - os.makedirs(plugin_logdir) - for run in RUN_TO_TOOLS: - run_dir = os.path.join(plugin_logdir, run) - os.mkdir(run_dir) - for tool in RUN_TO_TOOLS[run]: - if tool not in profile_plugin.TOOLS: - continue - for host in RUN_TO_HOSTS[run]: - file_name = host + profile_plugin.TOOLS[tool] - tool_file = os.path.join(run_dir, file_name) - if tool == 'trace_viewer': - trace = trace_events_pb2.Trace() - trace.devices[0].name = run - data = trace.SerializeToString() - else: - data = tool - with open(tool_file, 'wb') as f: - f.write(data) - with open(os.path.join(plugin_logdir, 'noise'), 'w') as f: - f.write('Not a dir, not a run.') + plugin_logdir = plugin_asset_util.PluginDirectory( + logdir, profile_plugin.ProfilePlugin.plugin_name + ) + os.makedirs(plugin_logdir) + for run in RUN_TO_TOOLS: + run_dir = 
os.path.join(plugin_logdir, run) + os.mkdir(run_dir) + for tool in RUN_TO_TOOLS[run]: + if tool not in profile_plugin.TOOLS: + continue + for host in RUN_TO_HOSTS[run]: + file_name = host + profile_plugin.TOOLS[tool] + tool_file = os.path.join(run_dir, file_name) + if tool == "trace_viewer": + trace = trace_events_pb2.Trace() + trace.devices[0].name = run + data = trace.SerializeToString() + else: + data = tool + with open(tool_file, "wb") as f: + f.write(data) + with open(os.path.join(plugin_logdir, "noise"), "w") as f: + f.write("Not a dir, not a run.") def write_empty_event_file(logdir): - w = tf.compat.v2.summary.create_file_writer( - logdir, filename_suffix=EVENT_FILE_SUFFIX) - w.close() + w = tf.compat.v2.summary.create_file_writer( + logdir, filename_suffix=EVENT_FILE_SUFFIX + ) + w.close() class ProfilePluginTest(tf.test.TestCase): - - def setUp(self): - self.logdir = self.get_temp_dir() - self.multiplexer = event_multiplexer.EventMultiplexer() - self.multiplexer.AddRunsFromDirectory(self.logdir) - context = base_plugin.TBContext( - logdir=self.logdir, - multiplexer=self.multiplexer, - flags=FakeFlags(self.logdir)) - self.plugin = profile_plugin.ProfilePlugin(context) - self.apps = self.plugin.get_plugin_apps() - - def testRuns_logdirWithoutEventFile(self): - generate_testdata(self.logdir) - self.multiplexer.Reload() - runs = dict(self.plugin.generate_run_to_tools()) - self.assertItemsEqual(runs.keys(), RUN_TO_TOOLS.keys()) - self.assertItemsEqual(runs['foo'], RUN_TO_TOOLS['foo']) - self.assertItemsEqual(runs['bar'], []) - self.assertItemsEqual(runs['empty'], []) - - def testRuns_logdirWithEventFIle(self): - write_empty_event_file(self.logdir) - generate_testdata(self.logdir) - self.multiplexer.Reload() - runs = dict(self.plugin.generate_run_to_tools()) - self.assertItemsEqual(runs.keys(), RUN_TO_TOOLS.keys()) - - def testRuns_withSubdirectories(self): - subdir_a = os.path.join(self.logdir, 'a') - subdir_b = os.path.join(self.logdir, 'b') - subdir_b_c = 
os.path.join(subdir_b, 'c') - generate_testdata(self.logdir) - generate_testdata(subdir_a) - generate_testdata(subdir_b) - generate_testdata(subdir_b_c) - write_empty_event_file(self.logdir) - write_empty_event_file(subdir_a) - # Skip writing an event file for subdir_b - write_empty_event_file(subdir_b_c) - self.multiplexer.AddRunsFromDirectory(self.logdir) - self.multiplexer.Reload() - runs = dict(self.plugin.generate_run_to_tools()) - # Expect runs for the logdir root, 'a', and 'b/c' but not for 'b' - # because it doesn't contain a tfevents file. - expected = list(RUN_TO_TOOLS.keys()) - expected.extend('a/' + run for run in RUN_TO_TOOLS.keys()) - expected.extend('b/c/' + run for run in RUN_TO_TOOLS.keys()) - self.assertItemsEqual(runs.keys(), expected) - - def makeRequest(self, run, tag, host): - req = Request({}) - req.args = {'run': run, 'tag': tag, 'host': host,} - return req - - def testHosts(self): - generate_testdata(self.logdir) - subdir_a = os.path.join(self.logdir, 'a') - generate_testdata(subdir_a) - write_empty_event_file(subdir_a) - self.multiplexer.AddRunsFromDirectory(self.logdir) - self.multiplexer.Reload() - hosts = self.plugin.host_impl('foo', 'trace_viewer') - self.assertItemsEqual(['host0', 'host1'], sorted(hosts)) - hosts_a = self.plugin.host_impl('a/foo', 'trace_viewer') - self.assertItemsEqual(['host0', 'host1'], sorted(hosts_a)) - - def testData(self): - generate_testdata(self.logdir) - subdir_a = os.path.join(self.logdir, 'a') - generate_testdata(subdir_a) - write_empty_event_file(subdir_a) - self.multiplexer.AddRunsFromDirectory(self.logdir) - self.multiplexer.Reload() - trace = json.loads(self.plugin.data_impl( - self.makeRequest('foo', 'trace_viewer', 'host0'))) - self.assertEqual(trace, EXPECTED_TRACE_DATA) - trace_a = json.loads(self.plugin.data_impl( - self.makeRequest('a/foo', 'trace_viewer', 'host0'))) - self.assertEqual(trace_a, EXPECTED_TRACE_DATA) - - # Invalid tool/run. 
- self.assertEqual(None, self.plugin.data_impl( - self.makeRequest('foo', 'nonono', 'host0'))) - self.assertEqual(None, self.plugin.data_impl( - self.makeRequest('foo', 'trace_viewer', ''))) - self.assertEqual(None, self.plugin.data_impl( - self.makeRequest('bar', 'unsupported', 'host1'))) - self.assertEqual(None, self.plugin.data_impl( - self.makeRequest('empty', 'trace_viewer', ''))) - self.assertEqual(None, self.plugin.data_impl( - self.makeRequest('a', 'trace_viewer', ''))) - - def testActive(self): - def wait_for_thread(): - with self.plugin._is_active_lock: - pass - # Launch thread to check if active. - self.plugin.is_active() - wait_for_thread() - # Should be false since there's no data yet. - self.assertFalse(self.plugin.is_active()) - wait_for_thread() - generate_testdata(self.logdir) - self.multiplexer.Reload() - # Launch a new thread to check if active. - self.plugin.is_active() - wait_for_thread() - # Now that there's data, this should be active. - self.assertTrue(self.plugin.is_active()) - - def testActive_subdirectoryOnly(self): - def wait_for_thread(): - with self.plugin._is_active_lock: - pass - # Launch thread to check if active. - self.plugin.is_active() - wait_for_thread() - # Should be false since there's no data yet. - self.assertFalse(self.plugin.is_active()) - wait_for_thread() - subdir_a = os.path.join(self.logdir, 'a') - generate_testdata(subdir_a) - write_empty_event_file(subdir_a) - self.multiplexer.AddRunsFromDirectory(self.logdir) - self.multiplexer.Reload() - # Launch a new thread to check if active. - self.plugin.is_active() - wait_for_thread() - # Now that there's data, this should be active. 
- self.assertTrue(self.plugin.is_active()) - - -if __name__ == '__main__': - tf.test.main() + def setUp(self): + self.logdir = self.get_temp_dir() + self.multiplexer = event_multiplexer.EventMultiplexer() + self.multiplexer.AddRunsFromDirectory(self.logdir) + context = base_plugin.TBContext( + logdir=self.logdir, + multiplexer=self.multiplexer, + flags=FakeFlags(self.logdir), + ) + self.plugin = profile_plugin.ProfilePlugin(context) + self.apps = self.plugin.get_plugin_apps() + + def testRuns_logdirWithoutEventFile(self): + generate_testdata(self.logdir) + self.multiplexer.Reload() + runs = dict(self.plugin.generate_run_to_tools()) + self.assertItemsEqual(runs.keys(), RUN_TO_TOOLS.keys()) + self.assertItemsEqual(runs["foo"], RUN_TO_TOOLS["foo"]) + self.assertItemsEqual(runs["bar"], []) + self.assertItemsEqual(runs["empty"], []) + + def testRuns_logdirWithEventFIle(self): + write_empty_event_file(self.logdir) + generate_testdata(self.logdir) + self.multiplexer.Reload() + runs = dict(self.plugin.generate_run_to_tools()) + self.assertItemsEqual(runs.keys(), RUN_TO_TOOLS.keys()) + + def testRuns_withSubdirectories(self): + subdir_a = os.path.join(self.logdir, "a") + subdir_b = os.path.join(self.logdir, "b") + subdir_b_c = os.path.join(subdir_b, "c") + generate_testdata(self.logdir) + generate_testdata(subdir_a) + generate_testdata(subdir_b) + generate_testdata(subdir_b_c) + write_empty_event_file(self.logdir) + write_empty_event_file(subdir_a) + # Skip writing an event file for subdir_b + write_empty_event_file(subdir_b_c) + self.multiplexer.AddRunsFromDirectory(self.logdir) + self.multiplexer.Reload() + runs = dict(self.plugin.generate_run_to_tools()) + # Expect runs for the logdir root, 'a', and 'b/c' but not for 'b' + # because it doesn't contain a tfevents file. 
+ expected = list(RUN_TO_TOOLS.keys()) + expected.extend("a/" + run for run in RUN_TO_TOOLS.keys()) + expected.extend("b/c/" + run for run in RUN_TO_TOOLS.keys()) + self.assertItemsEqual(runs.keys(), expected) + + def makeRequest(self, run, tag, host): + req = Request({}) + req.args = { + "run": run, + "tag": tag, + "host": host, + } + return req + + def testHosts(self): + generate_testdata(self.logdir) + subdir_a = os.path.join(self.logdir, "a") + generate_testdata(subdir_a) + write_empty_event_file(subdir_a) + self.multiplexer.AddRunsFromDirectory(self.logdir) + self.multiplexer.Reload() + hosts = self.plugin.host_impl("foo", "trace_viewer") + self.assertItemsEqual(["host0", "host1"], sorted(hosts)) + hosts_a = self.plugin.host_impl("a/foo", "trace_viewer") + self.assertItemsEqual(["host0", "host1"], sorted(hosts_a)) + + def testData(self): + generate_testdata(self.logdir) + subdir_a = os.path.join(self.logdir, "a") + generate_testdata(subdir_a) + write_empty_event_file(subdir_a) + self.multiplexer.AddRunsFromDirectory(self.logdir) + self.multiplexer.Reload() + trace = json.loads( + self.plugin.data_impl( + self.makeRequest("foo", "trace_viewer", "host0") + ) + ) + self.assertEqual(trace, EXPECTED_TRACE_DATA) + trace_a = json.loads( + self.plugin.data_impl( + self.makeRequest("a/foo", "trace_viewer", "host0") + ) + ) + self.assertEqual(trace_a, EXPECTED_TRACE_DATA) + + # Invalid tool/run. 
+ self.assertEqual( + None, + self.plugin.data_impl(self.makeRequest("foo", "nonono", "host0")), + ) + self.assertEqual( + None, + self.plugin.data_impl(self.makeRequest("foo", "trace_viewer", "")), + ) + self.assertEqual( + None, + self.plugin.data_impl( + self.makeRequest("bar", "unsupported", "host1") + ), + ) + self.assertEqual( + None, + self.plugin.data_impl( + self.makeRequest("empty", "trace_viewer", "") + ), + ) + self.assertEqual( + None, + self.plugin.data_impl(self.makeRequest("a", "trace_viewer", "")), + ) + + def testActive(self): + def wait_for_thread(): + with self.plugin._is_active_lock: + pass + + # Launch thread to check if active. + self.plugin.is_active() + wait_for_thread() + # Should be false since there's no data yet. + self.assertFalse(self.plugin.is_active()) + wait_for_thread() + generate_testdata(self.logdir) + self.multiplexer.Reload() + # Launch a new thread to check if active. + self.plugin.is_active() + wait_for_thread() + # Now that there's data, this should be active. + self.assertTrue(self.plugin.is_active()) + + def testActive_subdirectoryOnly(self): + def wait_for_thread(): + with self.plugin._is_active_lock: + pass + + # Launch thread to check if active. + self.plugin.is_active() + wait_for_thread() + # Should be false since there's no data yet. + self.assertFalse(self.plugin.is_active()) + wait_for_thread() + subdir_a = os.path.join(self.logdir, "a") + generate_testdata(subdir_a) + write_empty_event_file(subdir_a) + self.multiplexer.AddRunsFromDirectory(self.logdir) + self.multiplexer.Reload() + # Launch a new thread to check if active. + self.plugin.is_active() + wait_for_thread() + # Now that there's data, this should be active. 
+ self.assertTrue(self.plugin.is_active()) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/profile/trace_events_json.py b/tensorboard/plugins/profile/trace_events_json.py index 3305da2b84..2ec89fa4b9 100644 --- a/tensorboard/plugins/profile/trace_events_json.py +++ b/tensorboard/plugins/profile/trace_events_json.py @@ -23,84 +23,90 @@ import six # Values for type (ph) and s (scope) parameters in catapult trace format. -_TYPE_METADATA = 'M' -_TYPE_COMPLETE = 'X' -_TYPE_INSTANT = 'i' -_SCOPE_THREAD = 't' +_TYPE_METADATA = "M" +_TYPE_COMPLETE = "X" +_TYPE_INSTANT = "i" +_SCOPE_THREAD = "t" class TraceEventsJsonStream(object): - """A streaming trace file in the format expected by catapult trace viewer. + """A streaming trace file in the format expected by catapult trace viewer. - Iterating over this yields a sequence of string chunks, so it is suitable for - returning in a werkzeug Response. - """ + Iterating over this yields a sequence of string chunks, so it is + suitable for returning in a werkzeug Response. + """ - def __init__(self, proto): - """Create an iterable JSON stream over the supplied Trace. + def __init__(self, proto): + """Create an iterable JSON stream over the supplied Trace. 
- Args: - proto: a tensorboard.profile.Trace protobuf - """ - self._proto = proto + Args: + proto: a tensorboard.profile.Trace protobuf + """ + self._proto = proto - def _events(self): - """Iterator over all catapult trace events, as python values.""" - for did, device in sorted(six.iteritems(self._proto.devices)): - if device.name: - yield dict( - ph=_TYPE_METADATA, - pid=did, - name='process_name', - args=dict(name=device.name)) - yield dict( - ph=_TYPE_METADATA, - pid=did, - name='process_sort_index', - args=dict(sort_index=did)) - for rid, resource in sorted(six.iteritems(device.resources)): - if resource.name: - yield dict( - ph=_TYPE_METADATA, - pid=did, - tid=rid, - name='thread_name', - args=dict(name=resource.name)) - yield dict( - ph=_TYPE_METADATA, - pid=did, - tid=rid, - name='thread_sort_index', - args=dict(sort_index=rid)) - # TODO(sammccall): filtering and downsampling? - for event in self._proto.trace_events: - yield self._event(event) + def _events(self): + """Iterator over all catapult trace events, as python values.""" + for did, device in sorted(six.iteritems(self._proto.devices)): + if device.name: + yield dict( + ph=_TYPE_METADATA, + pid=did, + name="process_name", + args=dict(name=device.name), + ) + yield dict( + ph=_TYPE_METADATA, + pid=did, + name="process_sort_index", + args=dict(sort_index=did), + ) + for rid, resource in sorted(six.iteritems(device.resources)): + if resource.name: + yield dict( + ph=_TYPE_METADATA, + pid=did, + tid=rid, + name="thread_name", + args=dict(name=resource.name), + ) + yield dict( + ph=_TYPE_METADATA, + pid=did, + tid=rid, + name="thread_sort_index", + args=dict(sort_index=rid), + ) + # TODO(sammccall): filtering and downsampling? 
+ for event in self._proto.trace_events: + yield self._event(event) - def _event(self, event): - """Converts a TraceEvent proto into a catapult trace event python value.""" - result = dict( - pid=event.device_id, - tid=event.resource_id, - name=event.name, - ts=event.timestamp_ps / 1000000.0) - if event.duration_ps: - result['ph'] = _TYPE_COMPLETE - result['dur'] = event.duration_ps / 1000000.0 - else: - result['ph'] = _TYPE_INSTANT - result['s'] = _SCOPE_THREAD - for key in dict(event.args): - if 'args' not in result: - result['args'] = {} - result['args'][key] = event.args[key] - return result + def _event(self, event): + """Converts a TraceEvent proto into a catapult trace event python + value.""" + result = dict( + pid=event.device_id, + tid=event.resource_id, + name=event.name, + ts=event.timestamp_ps / 1000000.0, + ) + if event.duration_ps: + result["ph"] = _TYPE_COMPLETE + result["dur"] = event.duration_ps / 1000000.0 + else: + result["ph"] = _TYPE_INSTANT + result["s"] = _SCOPE_THREAD + for key in dict(event.args): + if "args" not in result: + result["args"] = {} + result["args"][key] = event.args[key] + return result - def __iter__(self): - """Returns an iterator of string chunks of a complete JSON document.""" - yield '{"displayTimeUnit":"ns","metadata":{"highres-ticks":true},\n' - yield '"traceEvents":[\n' - for event in self._events(): - yield json.dumps(event) - yield ',\n' - # Add one fake event to avoid dealing with no-trailing-comma rule. - yield '{}]}\n' + def __iter__(self): + """Returns an iterator of string chunks of a complete JSON document.""" + yield '{"displayTimeUnit":"ns","metadata":{"highres-ticks":true},\n' + yield '"traceEvents":[\n' + for event in self._events(): + yield json.dumps(event) + yield ",\n" + # Add one fake event to avoid dealing with no-trailing-comma rule. 
+ yield "{}]}\n" diff --git a/tensorboard/plugins/profile/trace_events_json_test.py b/tensorboard/plugins/profile/trace_events_json_test.py index ca47e26261..1157a5c3c7 100644 --- a/tensorboard/plugins/profile/trace_events_json_test.py +++ b/tensorboard/plugins/profile/trace_events_json_test.py @@ -28,15 +28,17 @@ class TraceEventsJsonStreamTest(tf.test.TestCase): + def convert(self, proto_text): + proto = trace_events_pb2.Trace() + text_format.Merge(proto_text, proto) + return json.loads( + "".join(trace_events_json.TraceEventsJsonStream(proto)) + ) - def convert(self, proto_text): - proto = trace_events_pb2.Trace() - text_format.Merge(proto_text, proto) - return json.loads(''.join(trace_events_json.TraceEventsJsonStream(proto))) - - def testJsonConversion(self): - self.assertEqual( - self.convert(""" + def testJsonConversion(self): + self.assertEqual( + self.convert( + """ devices { key: 2 value { name: 'D2' device_id: 2 @@ -69,59 +71,73 @@ def testJsonConversion(self): name: "E2.2.1" timestamp_ps: 105000 } - """), - dict( - displayTimeUnit='ns', - metadata={'highres-ticks': True}, - traceEvents=[ - dict(ph='M', pid=1, name='process_name', args=dict(name='D1')), - dict( - ph='M', - pid=1, - name='process_sort_index', - args=dict(sort_index=1)), - dict( - ph='M', - pid=1, - tid=2, - name='thread_name', - args=dict(name='R1.2')), - dict( - ph='M', - pid=1, - tid=2, - name='thread_sort_index', - args=dict(sort_index=2)), - dict(ph='M', pid=2, name='process_name', args=dict(name='D2')), - dict( - ph='M', - pid=2, - name='process_sort_index', - args=dict(sort_index=2)), - dict( - ph='M', - pid=2, - tid=2, - name='thread_name', - args=dict(name='R2.2')), - dict( - ph='M', - pid=2, - tid=2, - name='thread_sort_index', - args=dict(sort_index=2)), - dict( - ph='X', - pid=1, - tid=2, - name='E1.2.1', - ts=0.1, - dur=0.01, - args=dict(label='E1.2.1', extra='extra info')), - dict(ph='i', pid=2, tid=2, name='E2.2.1', ts=0.105, s='t'), - {}, - ])) + """ + ), + dict( + 
displayTimeUnit="ns", + metadata={"highres-ticks": True}, + traceEvents=[ + dict( + ph="M", pid=1, name="process_name", args=dict(name="D1") + ), + dict( + ph="M", + pid=1, + name="process_sort_index", + args=dict(sort_index=1), + ), + dict( + ph="M", + pid=1, + tid=2, + name="thread_name", + args=dict(name="R1.2"), + ), + dict( + ph="M", + pid=1, + tid=2, + name="thread_sort_index", + args=dict(sort_index=2), + ), + dict( + ph="M", pid=2, name="process_name", args=dict(name="D2") + ), + dict( + ph="M", + pid=2, + name="process_sort_index", + args=dict(sort_index=2), + ), + dict( + ph="M", + pid=2, + tid=2, + name="thread_name", + args=dict(name="R2.2"), + ), + dict( + ph="M", + pid=2, + tid=2, + name="thread_sort_index", + args=dict(sort_index=2), + ), + dict( + ph="X", + pid=1, + tid=2, + name="E1.2.1", + ts=0.1, + dur=0.01, + args=dict(label="E1.2.1", extra="extra info"), + ), + dict(ph="i", pid=2, tid=2, name="E2.2.1", ts=0.105, s="t"), + {}, + ], + ), + ) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/projector/__init__.py b/tensorboard/plugins/projector/__init__.py index 47e0008f56..af0c4e92b0 100644 --- a/tensorboard/plugins/projector/__init__.py +++ b/tensorboard/plugins/projector/__init__.py @@ -36,30 +36,30 @@ def visualize_embeddings(logdir, config): - """Stores a config file used by the embedding projector. + """Stores a config file used by the embedding projector. - Args: - logdir: Directory into which to store the config file, as a `str`. - For compatibility, can also be a `tf.compat.v1.summary.FileWriter` - object open at the desired logdir. - config: `tf.contrib.tensorboard.plugins.projector.ProjectorConfig` - proto that holds the configuration for the projector such as paths to - checkpoint files and metadata files for the embeddings. If - `config.model_checkpoint_path` is none, it defaults to the - `logdir` used by the summary_writer. 
+ Args: + logdir: Directory into which to store the config file, as a `str`. + For compatibility, can also be a `tf.compat.v1.summary.FileWriter` + object open at the desired logdir. + config: `tf.contrib.tensorboard.plugins.projector.ProjectorConfig` + proto that holds the configuration for the projector such as paths to + checkpoint files and metadata files for the embeddings. If + `config.model_checkpoint_path` is none, it defaults to the + `logdir` used by the summary_writer. - Raises: - ValueError: If the summary writer does not have a `logdir`. - """ - # Convert from `tf.compat.v1.summary.FileWriter` if necessary. - logdir = getattr(logdir, 'get_logdir', lambda: logdir)() + Raises: + ValueError: If the summary writer does not have a `logdir`. + """ + # Convert from `tf.compat.v1.summary.FileWriter` if necessary. + logdir = getattr(logdir, "get_logdir", lambda: logdir)() - # Sanity checks. - if logdir is None: - raise ValueError('Expected logdir to be a path, but got None') + # Sanity checks. + if logdir is None: + raise ValueError("Expected logdir to be a path, but got None") - # Saving the config file in the logdir. - config_pbtxt = _text_format.MessageToString(config) - path = os.path.join(logdir, _projector_plugin.PROJECTOR_FILENAME) - with tf.io.gfile.GFile(path, 'w') as f: - f.write(config_pbtxt) + # Saving the config file in the logdir. 
+ config_pbtxt = _text_format.MessageToString(config) + path = os.path.join(logdir, _projector_plugin.PROJECTOR_FILENAME) + with tf.io.gfile.GFile(path, "w") as f: + f.write(config_pbtxt) diff --git a/tensorboard/plugins/projector/projector_api_test.py b/tensorboard/plugins/projector/projector_api_test.py index 417c5839a0..e05f99b915 100644 --- a/tensorboard/plugins/projector/projector_api_test.py +++ b/tensorboard/plugins/projector/projector_api_test.py @@ -28,55 +28,58 @@ from tensorboard.plugins import projector from tensorboard.util import test_util -def create_dummy_config(): - return projector.ProjectorConfig( - model_checkpoint_path='test', - embeddings = [ - projector.EmbeddingInfo( - tensor_name='tensor1', - metadata_path='metadata1', - ), - ], - ) - -class ProjectorApiTest(tf.test.TestCase): - - def test_visualize_embeddings_with_logdir(self): - logdir = self.get_temp_dir() - config = create_dummy_config() - projector.visualize_embeddings(logdir, config) - - # Read the configurations from disk and make sure it matches the original. - with tf.io.gfile.GFile(os.path.join(logdir, 'projector_config.pbtxt')) as f: - config2 = projector.ProjectorConfig() - text_format.Parse(f.read(), config2) - - self.assertEqual(config, config2) - def test_visualize_embeddings_with_file_writer(self): - if tf.__version__ == "stub": - self.skipTest("Requires TensorFlow for FileWriter") - logdir = self.get_temp_dir() - config = create_dummy_config() - - with tf.compat.v1.Graph().as_default(): - with test_util.FileWriterCache.get(logdir) as writer: - projector.visualize_embeddings(writer, config) - - # Read the configurations from disk and make sure it matches the original. 
- with tf.io.gfile.GFile(os.path.join(logdir, 'projector_config.pbtxt')) as f: - config2 = projector.ProjectorConfig() - text_format.Parse(f.read(), config2) - - self.assertEqual(config, config2) - - def test_visualize_embeddings_no_logdir(self): - with six.assertRaisesRegex( - self, - ValueError, - "Expected logdir to be a path, but got None"): - projector.visualize_embeddings(None, create_dummy_config()) +def create_dummy_config(): + return projector.ProjectorConfig( + model_checkpoint_path="test", + embeddings=[ + projector.EmbeddingInfo( + tensor_name="tensor1", metadata_path="metadata1", + ), + ], + ) -if __name__ == '__main__': - tf.test.main() +class ProjectorApiTest(tf.test.TestCase): + def test_visualize_embeddings_with_logdir(self): + logdir = self.get_temp_dir() + config = create_dummy_config() + projector.visualize_embeddings(logdir, config) + + # Read the configurations from disk and make sure it matches the original. + with tf.io.gfile.GFile( + os.path.join(logdir, "projector_config.pbtxt") + ) as f: + config2 = projector.ProjectorConfig() + text_format.Parse(f.read(), config2) + + self.assertEqual(config, config2) + + def test_visualize_embeddings_with_file_writer(self): + if tf.__version__ == "stub": + self.skipTest("Requires TensorFlow for FileWriter") + logdir = self.get_temp_dir() + config = create_dummy_config() + + with tf.compat.v1.Graph().as_default(): + with test_util.FileWriterCache.get(logdir) as writer: + projector.visualize_embeddings(writer, config) + + # Read the configurations from disk and make sure it matches the original. 
+ with tf.io.gfile.GFile( + os.path.join(logdir, "projector_config.pbtxt") + ) as f: + config2 = projector.ProjectorConfig() + text_format.Parse(f.read(), config2) + + self.assertEqual(config, config2) + + def test_visualize_embeddings_no_logdir(self): + with six.assertRaisesRegex( + self, ValueError, "Expected logdir to be a path, but got None" + ): + projector.visualize_embeddings(None, create_dummy_config()) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/projector/projector_plugin.py b/tensorboard/plugins/projector/projector_plugin.py index 11c36a4b64..a53a538fa4 100644 --- a/tensorboard/plugins/projector/projector_plugin.py +++ b/tensorboard/plugins/projector/projector_plugin.py @@ -42,653 +42,745 @@ logger = tb_logging.get_logger() # The prefix of routes provided by this plugin. -_PLUGIN_PREFIX_ROUTE = 'projector' +_PLUGIN_PREFIX_ROUTE = "projector" # FYI - the PROJECTOR_FILENAME is hardcoded in the visualize_embeddings # method in tf.contrib.tensorboard.plugins.projector module. # TODO(@decentralion): Fix duplication when we find a permanent home for the # projector module. -PROJECTOR_FILENAME = 'projector_config.pbtxt' -_PLUGIN_NAME = 'org_tensorflow_tensorboard_projector' -_PLUGINS_DIR = 'plugins' +PROJECTOR_FILENAME = "projector_config.pbtxt" +_PLUGIN_NAME = "org_tensorflow_tensorboard_projector" +_PLUGINS_DIR = "plugins" # Number of tensors in the LRU cache. _TENSOR_CACHE_CAPACITY = 1 # HTTP routes. 
-CONFIG_ROUTE = '/info' -TENSOR_ROUTE = '/tensor' -METADATA_ROUTE = '/metadata' -RUNS_ROUTE = '/runs' -BOOKMARKS_ROUTE = '/bookmarks' -SPRITE_IMAGE_ROUTE = '/sprite_image' +CONFIG_ROUTE = "/info" +TENSOR_ROUTE = "/tensor" +METADATA_ROUTE = "/metadata" +RUNS_ROUTE = "/runs" +BOOKMARKS_ROUTE = "/bookmarks" +SPRITE_IMAGE_ROUTE = "/sprite_image" _IMGHDR_TO_MIMETYPE = { - 'bmp': 'image/bmp', - 'gif': 'image/gif', - 'jpeg': 'image/jpeg', - 'png': 'image/png' + "bmp": "image/bmp", + "gif": "image/gif", + "jpeg": "image/jpeg", + "png": "image/png", } -_DEFAULT_IMAGE_MIMETYPE = 'application/octet-stream' +_DEFAULT_IMAGE_MIMETYPE = "application/octet-stream" class LRUCache(object): - """LRU cache. Used for storing the last used tensor.""" + """LRU cache. - def __init__(self, size): - if size < 1: - raise ValueError('The cache size must be >=1') - self._size = size - self._dict = collections.OrderedDict() - - def get(self, key): - try: - value = self._dict.pop(key) - self._dict[key] = value - return value - except KeyError: - return None - - def set(self, key, value): - if value is None: - raise ValueError('value must be != None') - try: - self._dict.pop(key) - except KeyError: - if len(self._dict) >= self._size: - self._dict.popitem(last=False) - self._dict[key] = value - - -class EmbeddingMetadata(object): - """Metadata container for an embedding. - - The metadata holds different columns with values used for visualization - (color by, label by) in the "Embeddings" tab in TensorBoard. - """ + Used for storing the last used tensor. + """ - def __init__(self, num_points): - """Constructs a metadata for an embedding of the specified size. + def __init__(self, size): + if size < 1: + raise ValueError("The cache size must be >=1") + self._size = size + self._dict = collections.OrderedDict() - Args: - num_points: Number of points in the embedding. 
- """ - self.num_points = num_points - self.column_names = [] - self.name_to_values = {} + def get(self, key): + try: + value = self._dict.pop(key) + self._dict[key] = value + return value + except KeyError: + return None + + def set(self, key, value): + if value is None: + raise ValueError("value must be != None") + try: + self._dict.pop(key) + except KeyError: + if len(self._dict) >= self._size: + self._dict.popitem(last=False) + self._dict[key] = value - def add_column(self, column_name, column_values): - """Adds a named column of metadata values. - Args: - column_name: Name of the column. - column_values: 1D array/list/iterable holding the column values. Must be - of length `num_points`. The i-th value corresponds to the i-th point. +class EmbeddingMetadata(object): + """Metadata container for an embedding. - Raises: - ValueError: If `column_values` is not 1D array, or of length `num_points`, - or the `name` is already used. + The metadata holds different columns with values used for + visualization (color by, label by) in the "Embeddings" tab in + TensorBoard. """ - # Sanity checks. - if isinstance(column_values, list) and isinstance(column_values[0], list): - raise ValueError('"column_values" must be a flat list, but we detected ' - 'that its first entry is a list') - - if isinstance(column_values, np.ndarray) and column_values.ndim != 1: - raise ValueError('"column_values" should be of rank 1, ' - 'but is of rank %d' % column_values.ndim) - if len(column_values) != self.num_points: - raise ValueError('"column_values" should be of length %d, but is of ' - 'length %d' % (self.num_points, len(column_values))) - if column_name in self.name_to_values: - raise ValueError('The column name "%s" is already used' % column_name) - self.column_names.append(column_name) - self.name_to_values[column_name] = column_values + def __init__(self, num_points): + """Constructs a metadata for an embedding of the specified size. 
+ + Args: + num_points: Number of points in the embedding. + """ + self.num_points = num_points + self.column_names = [] + self.name_to_values = {} + + def add_column(self, column_name, column_values): + """Adds a named column of metadata values. + + Args: + column_name: Name of the column. + column_values: 1D array/list/iterable holding the column values. Must be + of length `num_points`. The i-th value corresponds to the i-th point. + + Raises: + ValueError: If `column_values` is not 1D array, or of length `num_points`, + or the `name` is already used. + """ + # Sanity checks. + if isinstance(column_values, list) and isinstance( + column_values[0], list + ): + raise ValueError( + '"column_values" must be a flat list, but we detected ' + "that its first entry is a list" + ) + + if isinstance(column_values, np.ndarray) and column_values.ndim != 1: + raise ValueError( + '"column_values" should be of rank 1, ' + "but is of rank %d" % column_values.ndim + ) + if len(column_values) != self.num_points: + raise ValueError( + '"column_values" should be of length %d, but is of ' + "length %d" % (self.num_points, len(column_values)) + ) + if column_name in self.name_to_values: + raise ValueError( + 'The column name "%s" is already used' % column_name + ) + + self.column_names.append(column_name) + self.name_to_values[column_name] = column_values def _read_tensor_tsv_file(fpath): - with tf.io.gfile.GFile(fpath, 'r') as f: - tensor = [] - for line in f: - line = line.rstrip('\n') - if line: - tensor.append(list(map(float, line.split('\t')))) - return np.array(tensor, dtype='float32') + with tf.io.gfile.GFile(fpath, "r") as f: + tensor = [] + for line in f: + line = line.rstrip("\n") + if line: + tensor.append(list(map(float, line.split("\t")))) + return np.array(tensor, dtype="float32") def _assets_dir_to_logdir(assets_dir): - sub_path = os.path.sep + _PLUGINS_DIR + os.path.sep - if sub_path in assets_dir: - two_parents_up = os.pardir + os.path.sep + os.pardir - return 
os.path.abspath(os.path.join(assets_dir, two_parents_up)) - return assets_dir + sub_path = os.path.sep + _PLUGINS_DIR + os.path.sep + if sub_path in assets_dir: + two_parents_up = os.pardir + os.path.sep + os.pardir + return os.path.abspath(os.path.join(assets_dir, two_parents_up)) + return assets_dir def _latest_checkpoints_changed(configs, run_path_pairs): - """Returns true if the latest checkpoint has changed in any of the runs.""" - for run_name, assets_dir in run_path_pairs: - if run_name not in configs: - config = ProjectorConfig() - config_fpath = os.path.join(assets_dir, PROJECTOR_FILENAME) - if tf.io.gfile.exists(config_fpath): - with tf.io.gfile.GFile(config_fpath, 'r') as f: - file_content = f.read() - text_format.Merge(file_content, config) - else: - config = configs[run_name] - - # See if you can find a checkpoint file in the logdir. - logdir = _assets_dir_to_logdir(assets_dir) - ckpt_path = _find_latest_checkpoint(logdir) - if not ckpt_path: - continue - if config.model_checkpoint_path != ckpt_path: - return True - return False + """Returns true if the latest checkpoint has changed in any of the runs.""" + for run_name, assets_dir in run_path_pairs: + if run_name not in configs: + config = ProjectorConfig() + config_fpath = os.path.join(assets_dir, PROJECTOR_FILENAME) + if tf.io.gfile.exists(config_fpath): + with tf.io.gfile.GFile(config_fpath, "r") as f: + file_content = f.read() + text_format.Merge(file_content, config) + else: + config = configs[run_name] + + # See if you can find a checkpoint file in the logdir. + logdir = _assets_dir_to_logdir(assets_dir) + ckpt_path = _find_latest_checkpoint(logdir) + if not ckpt_path: + continue + if config.model_checkpoint_path != ckpt_path: + return True + return False def _parse_positive_int_param(request, param_name): - """Parses and asserts a positive (>0) integer query parameter. - - Args: - request: The Werkzeug Request object - param_name: Name of the parameter. 
- - Returns: - Param, or None, or -1 if parameter is not a positive integer. - """ - param = request.args.get(param_name) - if not param: - return None - try: - param = int(param) - if param <= 0: - raise ValueError() - return param - except ValueError: - return -1 + """Parses and asserts a positive (>0) integer query parameter. + + Args: + request: The Werkzeug Request object + param_name: Name of the parameter. + + Returns: + Param, or None, or -1 if parameter is not a positive integer. + """ + param = request.args.get(param_name) + if not param: + return None + try: + param = int(param) + if param <= 0: + raise ValueError() + return param + except ValueError: + return -1 def _rel_to_abs_asset_path(fpath, config_fpath): - fpath = os.path.expanduser(fpath) - if not os.path.isabs(fpath): - return os.path.join(os.path.dirname(config_fpath), fpath) - return fpath + fpath = os.path.expanduser(fpath) + if not os.path.isabs(fpath): + return os.path.join(os.path.dirname(config_fpath), fpath) + return fpath def _using_tf(): - """Return true if we're not using the fake TF API stub implementation.""" - return tf.__version__ != 'stub' + """Return true if we're not using the fake TF API stub implementation.""" + return tf.__version__ != "stub" class ProjectorPlugin(base_plugin.TBPlugin): - """Embedding projector.""" - - plugin_name = _PLUGIN_PREFIX_ROUTE - - def __init__(self, context): - """Instantiates ProjectorPlugin via TensorBoard core. - - Args: - context: A base_plugin.TBContext instance. - """ - self.multiplexer = context.multiplexer - self.logdir = context.logdir - self._handlers = None - self.readers = {} - self.run_paths = None - self._configs = {} - self.old_num_run_paths = None - self.config_fpaths = None - self.tensor_cache = LRUCache(_TENSOR_CACHE_CAPACITY) - - # Whether the plugin is active (has meaningful data to process and serve). - # Once the plugin is deemed active, we no longer re-compute the value - # because doing so is potentially expensive. 
- self._is_active = False - - # The running thread that is currently determining whether the plugin is - # active. If such a thread exists, do not start a duplicate thread. - self._thread_for_determining_is_active = None - - if self.multiplexer: - self.run_paths = self.multiplexer.RunPaths() - - def get_plugin_apps(self): - self._handlers = { - RUNS_ROUTE: self._serve_runs, - CONFIG_ROUTE: self._serve_config, - TENSOR_ROUTE: self._serve_tensor, - METADATA_ROUTE: self._serve_metadata, - BOOKMARKS_ROUTE: self._serve_bookmarks, - SPRITE_IMAGE_ROUTE: self._serve_sprite_image, - '/index.js': - functools.partial( + """Embedding projector.""" + + plugin_name = _PLUGIN_PREFIX_ROUTE + + def __init__(self, context): + """Instantiates ProjectorPlugin via TensorBoard core. + + Args: + context: A base_plugin.TBContext instance. + """ + self.multiplexer = context.multiplexer + self.logdir = context.logdir + self._handlers = None + self.readers = {} + self.run_paths = None + self._configs = {} + self.old_num_run_paths = None + self.config_fpaths = None + self.tensor_cache = LRUCache(_TENSOR_CACHE_CAPACITY) + + # Whether the plugin is active (has meaningful data to process and serve). + # Once the plugin is deemed active, we no longer re-compute the value + # because doing so is potentially expensive. + self._is_active = False + + # The running thread that is currently determining whether the plugin is + # active. If such a thread exists, do not start a duplicate thread. 
+ self._thread_for_determining_is_active = None + + if self.multiplexer: + self.run_paths = self.multiplexer.RunPaths() + + def get_plugin_apps(self): + self._handlers = { + RUNS_ROUTE: self._serve_runs, + CONFIG_ROUTE: self._serve_config, + TENSOR_ROUTE: self._serve_tensor, + METADATA_ROUTE: self._serve_metadata, + BOOKMARKS_ROUTE: self._serve_bookmarks, + SPRITE_IMAGE_ROUTE: self._serve_sprite_image, + "/index.js": functools.partial( self._serve_file, - os.path.join('tf_projector_plugin', 'index.js')), - '/projector_binary.html': - functools.partial( + os.path.join("tf_projector_plugin", "index.js"), + ), + "/projector_binary.html": functools.partial( self._serve_file, - os.path.join('tf_projector_plugin', 'projector_binary.html')), - '/projector_binary.js': - functools.partial( + os.path.join("tf_projector_plugin", "projector_binary.html"), + ), + "/projector_binary.js": functools.partial( self._serve_file, - os.path.join('tf_projector_plugin', 'projector_binary.js')), - } - return self._handlers - - def is_active(self): - """Determines whether this plugin is active. - - This plugin is only active if any run has an embedding. - - Returns: - Whether any run has embedding data to show in the projector. - """ - if not self.multiplexer: - return False - - if self._is_active: - # We have already determined that the projector plugin should be active. - # Do not re-compute that. We have no reason to later set this plugin to be - # inactive. - return True - - if self._thread_for_determining_is_active: - # We are currently determining whether the plugin is active. Do not start - # a separate thread. - return self._is_active - - # The plugin is currently not active. The frontend might check again later. - # For now, spin off a separate thread to determine whether the plugin is - # active. 
- new_thread = threading.Thread( - target=self._determine_is_active, - name='ProjectorPluginIsActiveThread') - self._thread_for_determining_is_active = new_thread - new_thread.start() - return False - - def frontend_metadata(self): - return base_plugin.FrontendMetadata( - es_module_path='/index.js', - disable_reload=True, - ) - - def _determine_is_active(self): - """Determines whether the plugin is active. - - This method is run in a separate thread so that the plugin can offer an - immediate response to whether it is active and determine whether it should - be active in a separate thread. - """ - if self.configs: - self._is_active = True - self._thread_for_determining_is_active = None - - @property - def configs(self): - """Returns a map of run paths to `ProjectorConfig` protos.""" - run_path_pairs = list(self.run_paths.items()) - self._append_plugin_asset_directories(run_path_pairs) - # If there are no summary event files, the projector should still work, - # treating the `logdir` as the model checkpoint directory. - if not run_path_pairs: - run_path_pairs.append(('.', self.logdir)) - if (self._run_paths_changed() or - _latest_checkpoints_changed(self._configs, run_path_pairs)): - self.readers = {} - self._configs, self.config_fpaths = self._read_latest_config_files( - run_path_pairs) - self._augment_configs_with_checkpoint_info() - return self._configs - - def _run_paths_changed(self): - num_run_paths = len(list(self.run_paths.keys())) - if num_run_paths != self.old_num_run_paths: - self.old_num_run_paths = num_run_paths - return True - return False - - def _augment_configs_with_checkpoint_info(self): - for run, config in self._configs.items(): - for embedding in config.embeddings: - # Normalize the name of the embeddings. - if embedding.tensor_name.endswith(':0'): - embedding.tensor_name = embedding.tensor_name[:-2] - # Find the size of embeddings associated with a tensors file. 
- if embedding.tensor_path and not embedding.tensor_shape: - fpath = _rel_to_abs_asset_path(embedding.tensor_path, - self.config_fpaths[run]) - tensor = self.tensor_cache.get((run, embedding.tensor_name)) - if tensor is None: - tensor = _read_tensor_tsv_file(fpath) - self.tensor_cache.set((run, embedding.tensor_name), tensor) - embedding.tensor_shape.extend([len(tensor), len(tensor[0])]) - - reader = self._get_reader_for_run(run) - if not reader: - continue - # Augment the configuration with the tensors in the checkpoint file. - special_embedding = None - if config.embeddings and not config.embeddings[0].tensor_name: - special_embedding = config.embeddings[0] - config.embeddings.remove(special_embedding) - var_map = reader.get_variable_to_shape_map() - for tensor_name, tensor_shape in var_map.items(): - if len(tensor_shape) != 2: - continue - embedding = self._get_embedding(tensor_name, config) - if not embedding: - embedding = config.embeddings.add() - embedding.tensor_name = tensor_name - if special_embedding: - embedding.metadata_path = special_embedding.metadata_path - embedding.bookmarks_path = special_embedding.bookmarks_path - if not embedding.tensor_shape: - embedding.tensor_shape.extend(tensor_shape) - - # Remove configs that do not have any valid (2D) tensors. 
- runs_to_remove = [] - for run, config in self._configs.items(): - if not config.embeddings: - runs_to_remove.append(run) - for run in runs_to_remove: - del self._configs[run] - del self.config_fpaths[run] - - def _read_latest_config_files(self, run_path_pairs): - """Reads and returns the projector config files in every run directory.""" - configs = {} - config_fpaths = {} - for run_name, assets_dir in run_path_pairs: - config = ProjectorConfig() - config_fpath = os.path.join(assets_dir, PROJECTOR_FILENAME) - if tf.io.gfile.exists(config_fpath): - with tf.io.gfile.GFile(config_fpath, 'r') as f: - file_content = f.read() - text_format.Merge(file_content, config) - has_tensor_files = False - for embedding in config.embeddings: - if embedding.tensor_path: - if not embedding.tensor_name: - embedding.tensor_name = os.path.basename(embedding.tensor_path) - has_tensor_files = True - break - - if not config.model_checkpoint_path: - # See if you can find a checkpoint file in the logdir. - logdir = _assets_dir_to_logdir(assets_dir) - ckpt_path = _find_latest_checkpoint(logdir) - if not ckpt_path and not has_tensor_files: - continue - if ckpt_path: - config.model_checkpoint_path = ckpt_path - - # Sanity check for the checkpoint file existing. 
- if (config.model_checkpoint_path and _using_tf() and - not tf.io.gfile.glob(config.model_checkpoint_path + '*')): - logger.warn('Checkpoint file "%s" not found', - config.model_checkpoint_path) - continue - configs[run_name] = config - config_fpaths[run_name] = config_fpath - return configs, config_fpaths - - def _get_reader_for_run(self, run): - if run in self.readers: - return self.readers[run] - - config = self._configs[run] - reader = None - if config.model_checkpoint_path and _using_tf(): - try: - reader = tf.train.load_checkpoint(config.model_checkpoint_path) - except Exception: # pylint: disable=broad-except - logger.warn('Failed reading "%s"', config.model_checkpoint_path) - self.readers[run] = reader - return reader - - def _get_metadata_file_for_tensor(self, tensor_name, config): - embedding_info = self._get_embedding(tensor_name, config) - if embedding_info: - return embedding_info.metadata_path - return None - - def _get_bookmarks_file_for_tensor(self, tensor_name, config): - embedding_info = self._get_embedding(tensor_name, config) - if embedding_info: - return embedding_info.bookmarks_path - return None - - def _canonical_tensor_name(self, tensor_name): - if ':' not in tensor_name: - return tensor_name + ':0' - else: - return tensor_name - - def _get_embedding(self, tensor_name, config): - if not config.embeddings: - return None - for info in config.embeddings: - if (self._canonical_tensor_name(info.tensor_name) == - self._canonical_tensor_name(tensor_name)): - return info - return None - - def _append_plugin_asset_directories(self, run_path_pairs): - for run, assets in self.multiplexer.PluginAssets(_PLUGIN_NAME).items(): - if PROJECTOR_FILENAME not in assets: - continue - assets_dir = os.path.join(self.run_paths[run], _PLUGINS_DIR, _PLUGIN_NAME) - assets_path_pair = (run, os.path.abspath(assets_dir)) - run_path_pairs.append(assets_path_pair) - - @wrappers.Request.application - def _serve_file(self, file_path, request): - """Returns a resource 
file.""" - res_path = os.path.join(os.path.dirname(__file__), file_path) - with open(res_path, 'rb') as read_file: - mimetype = mimetypes.guess_type(file_path)[0] - return Respond(request, read_file.read(), content_type=mimetype) - - @wrappers.Request.application - def _serve_runs(self, request): - """Returns a list of runs that have embeddings.""" - return Respond(request, list(self.configs.keys()), 'application/json') - - @wrappers.Request.application - def _serve_config(self, request): - run = request.args.get('run') - if run is None: - return Respond(request, 'query parameter "run" is required', 'text/plain', - 400) - if run not in self.configs: - return Respond(request, 'Unknown run: "%s"' % run, 'text/plain', 400) - - config = self.configs[run] - return Respond(request, - json_format.MessageToJson(config), 'application/json') - - @wrappers.Request.application - def _serve_metadata(self, request): - run = request.args.get('run') - if run is None: - return Respond(request, 'query parameter "run" is required', 'text/plain', - 400) - - name = request.args.get('name') - if name is None: - return Respond(request, 'query parameter "name" is required', - 'text/plain', 400) - - num_rows = _parse_positive_int_param(request, 'num_rows') - if num_rows == -1: - return Respond(request, 'query parameter num_rows must be integer > 0', - 'text/plain', 400) - - if run not in self.configs: - return Respond(request, 'Unknown run: "%s"' % run, 'text/plain', 400) - - config = self.configs[run] - fpath = self._get_metadata_file_for_tensor(name, config) - if not fpath: - return Respond( - request, - 'No metadata file found for tensor "%s" in the config file "%s"' % - (name, self.config_fpaths[run]), 'text/plain', 400) - fpath = _rel_to_abs_asset_path(fpath, self.config_fpaths[run]) - if not tf.io.gfile.exists(fpath) or tf.io.gfile.isdir(fpath): - return Respond(request, '"%s" not found, or is not a file' % fpath, - 'text/plain', 400) - - num_header_rows = 0 - with 
tf.io.gfile.GFile(fpath, 'r') as f: - lines = [] - # Stream reading the file with early break in case the file doesn't fit in - # memory. - for line in f: - lines.append(line) - if len(lines) == 1 and '\t' in lines[0]: - num_header_rows = 1 - if num_rows and len(lines) >= num_rows + num_header_rows: - break - return Respond(request, ''.join(lines), 'text/plain') - - @wrappers.Request.application - def _serve_tensor(self, request): - run = request.args.get('run') - if run is None: - return Respond(request, 'query parameter "run" is required', 'text/plain', - 400) - - name = request.args.get('name') - if name is None: - return Respond(request, 'query parameter "name" is required', - 'text/plain', 400) - - num_rows = _parse_positive_int_param(request, 'num_rows') - if num_rows == -1: - return Respond(request, 'query parameter num_rows must be integer > 0', - 'text/plain', 400) - - if run not in self.configs: - return Respond(request, 'Unknown run: "%s"' % run, 'text/plain', 400) - - config = self.configs[run] - - tensor = self.tensor_cache.get((run, name)) - if tensor is None: - # See if there is a tensor file in the config. 
- embedding = self._get_embedding(name, config) - - if embedding and embedding.tensor_path: - fpath = _rel_to_abs_asset_path(embedding.tensor_path, - self.config_fpaths[run]) - if not tf.io.gfile.exists(fpath): - return Respond(request, - 'Tensor file "%s" does not exist' % fpath, - 'text/plain', 400) - tensor = _read_tensor_tsv_file(fpath) - else: - reader = self._get_reader_for_run(run) - if not reader or not reader.has_tensor(name): - return Respond(request, - 'Tensor "%s" not found in checkpoint dir "%s"' % - (name, config.model_checkpoint_path), 'text/plain', - 400) - try: - tensor = reader.get_tensor(name) - except tf.errors.InvalidArgumentError as e: - return Respond(request, str(e), 'text/plain', 400) - - self.tensor_cache.set((run, name), tensor) - - if num_rows: - tensor = tensor[:num_rows] - if tensor.dtype != 'float32': - tensor = tensor.astype(dtype='float32', copy=False) - data_bytes = tensor.tobytes() - return Respond(request, data_bytes, 'application/octet-stream') - - @wrappers.Request.application - def _serve_bookmarks(self, request): - run = request.args.get('run') - if not run: - return Respond(request, 'query parameter "run" is required', 'text/plain', - 400) - - name = request.args.get('name') - if name is None: - return Respond(request, 'query parameter "name" is required', - 'text/plain', 400) - - if run not in self.configs: - return Respond(request, 'Unknown run: "%s"' % run, 'text/plain', 400) - - config = self.configs[run] - fpath = self._get_bookmarks_file_for_tensor(name, config) - if not fpath: - return Respond( - request, - 'No bookmarks file found for tensor "%s" in the config file "%s"' % - (name, self.config_fpaths[run]), 'text/plain', 400) - fpath = _rel_to_abs_asset_path(fpath, self.config_fpaths[run]) - if not tf.io.gfile.exists(fpath) or tf.io.gfile.isdir(fpath): - return Respond(request, '"%s" not found, or is not a file' % fpath, - 'text/plain', 400) - - bookmarks_json = None - with tf.io.gfile.GFile(fpath, 'rb') as f: - 
bookmarks_json = f.read() - return Respond(request, bookmarks_json, 'application/json') - - @wrappers.Request.application - def _serve_sprite_image(self, request): - run = request.args.get('run') - if not run: - return Respond(request, 'query parameter "run" is required', 'text/plain', - 400) - - name = request.args.get('name') - if name is None: - return Respond(request, 'query parameter "name" is required', - 'text/plain', 400) - - if run not in self.configs: - return Respond(request, 'Unknown run: "%s"' % run, 'text/plain', 400) - - config = self.configs[run] - embedding_info = self._get_embedding(name, config) - - if not embedding_info or not embedding_info.sprite.image_path: - return Respond( - request, - 'No sprite image file found for tensor "%s" in the config file "%s"' % - (name, self.config_fpaths[run]), 'text/plain', 400) - - fpath = os.path.expanduser(embedding_info.sprite.image_path) - fpath = _rel_to_abs_asset_path(fpath, self.config_fpaths[run]) - if not tf.io.gfile.exists(fpath) or tf.io.gfile.isdir(fpath): - return Respond(request, '"%s" does not exist or is directory' % fpath, - 'text/plain', 400) - f = tf.io.gfile.GFile(fpath, 'rb') - encoded_image_string = f.read() - f.close() - image_type = imghdr.what(None, encoded_image_string) - mime_type = _IMGHDR_TO_MIMETYPE.get(image_type, _DEFAULT_IMAGE_MIMETYPE) - return Respond(request, encoded_image_string, mime_type) + os.path.join("tf_projector_plugin", "projector_binary.js"), + ), + } + return self._handlers + + def is_active(self): + """Determines whether this plugin is active. + + This plugin is only active if any run has an embedding. + + Returns: + Whether any run has embedding data to show in the projector. + """ + if not self.multiplexer: + return False + + if self._is_active: + # We have already determined that the projector plugin should be active. + # Do not re-compute that. We have no reason to later set this plugin to be + # inactive. 
+ return True + + if self._thread_for_determining_is_active: + # We are currently determining whether the plugin is active. Do not start + # a separate thread. + return self._is_active + + # The plugin is currently not active. The frontend might check again later. + # For now, spin off a separate thread to determine whether the plugin is + # active. + new_thread = threading.Thread( + target=self._determine_is_active, + name="ProjectorPluginIsActiveThread", + ) + self._thread_for_determining_is_active = new_thread + new_thread.start() + return False + + def frontend_metadata(self): + return base_plugin.FrontendMetadata( + es_module_path="/index.js", disable_reload=True, + ) + + def _determine_is_active(self): + """Determines whether the plugin is active. + + This method is run in a separate thread so that the plugin can + offer an immediate response to whether it is active and + determine whether it should be active in a separate thread. + """ + if self.configs: + self._is_active = True + self._thread_for_determining_is_active = None + + @property + def configs(self): + """Returns a map of run paths to `ProjectorConfig` protos.""" + run_path_pairs = list(self.run_paths.items()) + self._append_plugin_asset_directories(run_path_pairs) + # If there are no summary event files, the projector should still work, + # treating the `logdir` as the model checkpoint directory. 
+ if not run_path_pairs: + run_path_pairs.append((".", self.logdir)) + if self._run_paths_changed() or _latest_checkpoints_changed( + self._configs, run_path_pairs + ): + self.readers = {} + self._configs, self.config_fpaths = self._read_latest_config_files( + run_path_pairs + ) + self._augment_configs_with_checkpoint_info() + return self._configs + + def _run_paths_changed(self): + num_run_paths = len(list(self.run_paths.keys())) + if num_run_paths != self.old_num_run_paths: + self.old_num_run_paths = num_run_paths + return True + return False + + def _augment_configs_with_checkpoint_info(self): + for run, config in self._configs.items(): + for embedding in config.embeddings: + # Normalize the name of the embeddings. + if embedding.tensor_name.endswith(":0"): + embedding.tensor_name = embedding.tensor_name[:-2] + # Find the size of embeddings associated with a tensors file. + if embedding.tensor_path and not embedding.tensor_shape: + fpath = _rel_to_abs_asset_path( + embedding.tensor_path, self.config_fpaths[run] + ) + tensor = self.tensor_cache.get((run, embedding.tensor_name)) + if tensor is None: + tensor = _read_tensor_tsv_file(fpath) + self.tensor_cache.set( + (run, embedding.tensor_name), tensor + ) + embedding.tensor_shape.extend([len(tensor), len(tensor[0])]) + + reader = self._get_reader_for_run(run) + if not reader: + continue + # Augment the configuration with the tensors in the checkpoint file. 
+ special_embedding = None + if config.embeddings and not config.embeddings[0].tensor_name: + special_embedding = config.embeddings[0] + config.embeddings.remove(special_embedding) + var_map = reader.get_variable_to_shape_map() + for tensor_name, tensor_shape in var_map.items(): + if len(tensor_shape) != 2: + continue + embedding = self._get_embedding(tensor_name, config) + if not embedding: + embedding = config.embeddings.add() + embedding.tensor_name = tensor_name + if special_embedding: + embedding.metadata_path = ( + special_embedding.metadata_path + ) + embedding.bookmarks_path = ( + special_embedding.bookmarks_path + ) + if not embedding.tensor_shape: + embedding.tensor_shape.extend(tensor_shape) + + # Remove configs that do not have any valid (2D) tensors. + runs_to_remove = [] + for run, config in self._configs.items(): + if not config.embeddings: + runs_to_remove.append(run) + for run in runs_to_remove: + del self._configs[run] + del self.config_fpaths[run] + + def _read_latest_config_files(self, run_path_pairs): + """Reads and returns the projector config files in every run + directory.""" + configs = {} + config_fpaths = {} + for run_name, assets_dir in run_path_pairs: + config = ProjectorConfig() + config_fpath = os.path.join(assets_dir, PROJECTOR_FILENAME) + if tf.io.gfile.exists(config_fpath): + with tf.io.gfile.GFile(config_fpath, "r") as f: + file_content = f.read() + text_format.Merge(file_content, config) + has_tensor_files = False + for embedding in config.embeddings: + if embedding.tensor_path: + if not embedding.tensor_name: + embedding.tensor_name = os.path.basename( + embedding.tensor_path + ) + has_tensor_files = True + break + + if not config.model_checkpoint_path: + # See if you can find a checkpoint file in the logdir. 
+ logdir = _assets_dir_to_logdir(assets_dir) + ckpt_path = _find_latest_checkpoint(logdir) + if not ckpt_path and not has_tensor_files: + continue + if ckpt_path: + config.model_checkpoint_path = ckpt_path + + # Sanity check for the checkpoint file existing. + if ( + config.model_checkpoint_path + and _using_tf() + and not tf.io.gfile.glob(config.model_checkpoint_path + "*") + ): + logger.warn( + 'Checkpoint file "%s" not found', + config.model_checkpoint_path, + ) + continue + configs[run_name] = config + config_fpaths[run_name] = config_fpath + return configs, config_fpaths + + def _get_reader_for_run(self, run): + if run in self.readers: + return self.readers[run] + + config = self._configs[run] + reader = None + if config.model_checkpoint_path and _using_tf(): + try: + reader = tf.train.load_checkpoint(config.model_checkpoint_path) + except Exception: # pylint: disable=broad-except + logger.warn('Failed reading "%s"', config.model_checkpoint_path) + self.readers[run] = reader + return reader + + def _get_metadata_file_for_tensor(self, tensor_name, config): + embedding_info = self._get_embedding(tensor_name, config) + if embedding_info: + return embedding_info.metadata_path + return None + + def _get_bookmarks_file_for_tensor(self, tensor_name, config): + embedding_info = self._get_embedding(tensor_name, config) + if embedding_info: + return embedding_info.bookmarks_path + return None + + def _canonical_tensor_name(self, tensor_name): + if ":" not in tensor_name: + return tensor_name + ":0" + else: + return tensor_name + + def _get_embedding(self, tensor_name, config): + if not config.embeddings: + return None + for info in config.embeddings: + if self._canonical_tensor_name( + info.tensor_name + ) == self._canonical_tensor_name(tensor_name): + return info + return None + + def _append_plugin_asset_directories(self, run_path_pairs): + for run, assets in self.multiplexer.PluginAssets(_PLUGIN_NAME).items(): + if PROJECTOR_FILENAME not in assets: + continue + 
assets_dir = os.path.join( + self.run_paths[run], _PLUGINS_DIR, _PLUGIN_NAME + ) + assets_path_pair = (run, os.path.abspath(assets_dir)) + run_path_pairs.append(assets_path_pair) + + @wrappers.Request.application + def _serve_file(self, file_path, request): + """Returns a resource file.""" + res_path = os.path.join(os.path.dirname(__file__), file_path) + with open(res_path, "rb") as read_file: + mimetype = mimetypes.guess_type(file_path)[0] + return Respond(request, read_file.read(), content_type=mimetype) + + @wrappers.Request.application + def _serve_runs(self, request): + """Returns a list of runs that have embeddings.""" + return Respond(request, list(self.configs.keys()), "application/json") + + @wrappers.Request.application + def _serve_config(self, request): + run = request.args.get("run") + if run is None: + return Respond( + request, 'query parameter "run" is required', "text/plain", 400 + ) + if run not in self.configs: + return Respond( + request, 'Unknown run: "%s"' % run, "text/plain", 400 + ) + + config = self.configs[run] + return Respond( + request, json_format.MessageToJson(config), "application/json" + ) + + @wrappers.Request.application + def _serve_metadata(self, request): + run = request.args.get("run") + if run is None: + return Respond( + request, 'query parameter "run" is required', "text/plain", 400 + ) + + name = request.args.get("name") + if name is None: + return Respond( + request, 'query parameter "name" is required', "text/plain", 400 + ) + + num_rows = _parse_positive_int_param(request, "num_rows") + if num_rows == -1: + return Respond( + request, + "query parameter num_rows must be integer > 0", + "text/plain", + 400, + ) + + if run not in self.configs: + return Respond( + request, 'Unknown run: "%s"' % run, "text/plain", 400 + ) + + config = self.configs[run] + fpath = self._get_metadata_file_for_tensor(name, config) + if not fpath: + return Respond( + request, + 'No metadata file found for tensor "%s" in the config file "%s"' + % 
(name, self.config_fpaths[run]), + "text/plain", + 400, + ) + fpath = _rel_to_abs_asset_path(fpath, self.config_fpaths[run]) + if not tf.io.gfile.exists(fpath) or tf.io.gfile.isdir(fpath): + return Respond( + request, + '"%s" not found, or is not a file' % fpath, + "text/plain", + 400, + ) + + num_header_rows = 0 + with tf.io.gfile.GFile(fpath, "r") as f: + lines = [] + # Stream reading the file with early break in case the file doesn't fit in + # memory. + for line in f: + lines.append(line) + if len(lines) == 1 and "\t" in lines[0]: + num_header_rows = 1 + if num_rows and len(lines) >= num_rows + num_header_rows: + break + return Respond(request, "".join(lines), "text/plain") + + @wrappers.Request.application + def _serve_tensor(self, request): + run = request.args.get("run") + if run is None: + return Respond( + request, 'query parameter "run" is required', "text/plain", 400 + ) + + name = request.args.get("name") + if name is None: + return Respond( + request, 'query parameter "name" is required', "text/plain", 400 + ) + + num_rows = _parse_positive_int_param(request, "num_rows") + if num_rows == -1: + return Respond( + request, + "query parameter num_rows must be integer > 0", + "text/plain", + 400, + ) + + if run not in self.configs: + return Respond( + request, 'Unknown run: "%s"' % run, "text/plain", 400 + ) + + config = self.configs[run] + + tensor = self.tensor_cache.get((run, name)) + if tensor is None: + # See if there is a tensor file in the config. 
+ embedding = self._get_embedding(name, config) + + if embedding and embedding.tensor_path: + fpath = _rel_to_abs_asset_path( + embedding.tensor_path, self.config_fpaths[run] + ) + if not tf.io.gfile.exists(fpath): + return Respond( + request, + 'Tensor file "%s" does not exist' % fpath, + "text/plain", + 400, + ) + tensor = _read_tensor_tsv_file(fpath) + else: + reader = self._get_reader_for_run(run) + if not reader or not reader.has_tensor(name): + return Respond( + request, + 'Tensor "%s" not found in checkpoint dir "%s"' + % (name, config.model_checkpoint_path), + "text/plain", + 400, + ) + try: + tensor = reader.get_tensor(name) + except tf.errors.InvalidArgumentError as e: + return Respond(request, str(e), "text/plain", 400) + + self.tensor_cache.set((run, name), tensor) + + if num_rows: + tensor = tensor[:num_rows] + if tensor.dtype != "float32": + tensor = tensor.astype(dtype="float32", copy=False) + data_bytes = tensor.tobytes() + return Respond(request, data_bytes, "application/octet-stream") + + @wrappers.Request.application + def _serve_bookmarks(self, request): + run = request.args.get("run") + if not run: + return Respond( + request, 'query parameter "run" is required', "text/plain", 400 + ) + + name = request.args.get("name") + if name is None: + return Respond( + request, 'query parameter "name" is required', "text/plain", 400 + ) + + if run not in self.configs: + return Respond( + request, 'Unknown run: "%s"' % run, "text/plain", 400 + ) + + config = self.configs[run] + fpath = self._get_bookmarks_file_for_tensor(name, config) + if not fpath: + return Respond( + request, + 'No bookmarks file found for tensor "%s" in the config file "%s"' + % (name, self.config_fpaths[run]), + "text/plain", + 400, + ) + fpath = _rel_to_abs_asset_path(fpath, self.config_fpaths[run]) + if not tf.io.gfile.exists(fpath) or tf.io.gfile.isdir(fpath): + return Respond( + request, + '"%s" not found, or is not a file' % fpath, + "text/plain", + 400, + ) + + bookmarks_json = 
None + with tf.io.gfile.GFile(fpath, "rb") as f: + bookmarks_json = f.read() + return Respond(request, bookmarks_json, "application/json") + + @wrappers.Request.application + def _serve_sprite_image(self, request): + run = request.args.get("run") + if not run: + return Respond( + request, 'query parameter "run" is required', "text/plain", 400 + ) + + name = request.args.get("name") + if name is None: + return Respond( + request, 'query parameter "name" is required', "text/plain", 400 + ) + + if run not in self.configs: + return Respond( + request, 'Unknown run: "%s"' % run, "text/plain", 400 + ) + + config = self.configs[run] + embedding_info = self._get_embedding(name, config) + + if not embedding_info or not embedding_info.sprite.image_path: + return Respond( + request, + 'No sprite image file found for tensor "%s" in the config file "%s"' + % (name, self.config_fpaths[run]), + "text/plain", + 400, + ) + + fpath = os.path.expanduser(embedding_info.sprite.image_path) + fpath = _rel_to_abs_asset_path(fpath, self.config_fpaths[run]) + if not tf.io.gfile.exists(fpath) or tf.io.gfile.isdir(fpath): + return Respond( + request, + '"%s" does not exist or is directory' % fpath, + "text/plain", + 400, + ) + f = tf.io.gfile.GFile(fpath, "rb") + encoded_image_string = f.read() + f.close() + image_type = imghdr.what(None, encoded_image_string) + mime_type = _IMGHDR_TO_MIMETYPE.get(image_type, _DEFAULT_IMAGE_MIMETYPE) + return Respond(request, encoded_image_string, mime_type) def _find_latest_checkpoint(dir_path): - if not _using_tf(): - return None - try: - ckpt_path = tf.train.latest_checkpoint(dir_path) - if not ckpt_path: - # Check the parent directory. - ckpt_path = tf.train.latest_checkpoint(os.path.join(dir_path, os.pardir)) - return ckpt_path - except tf.errors.NotFoundError: - return None + if not _using_tf(): + return None + try: + ckpt_path = tf.train.latest_checkpoint(dir_path) + if not ckpt_path: + # Check the parent directory. 
+ ckpt_path = tf.train.latest_checkpoint( + os.path.join(dir_path, os.pardir) + ) + return ckpt_path + except tf.errors.NotFoundError: + return None diff --git a/tensorboard/plugins/projector/projector_plugin_test.py b/tensorboard/plugins/projector/projector_plugin_test.py index 4907bd6144..a6b6932353 100644 --- a/tensorboard/plugins/projector/projector_plugin_test.py +++ b/tensorboard/plugins/projector/projector_plugin_test.py @@ -33,7 +33,9 @@ from google.protobuf import text_format from tensorboard.backend import application -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.compat.proto import event_pb2 from tensorboard.compat.proto import summary_pb2 from tensorboard.compat import tf as tf_compat @@ -44,378 +46,393 @@ tf.compat.v1.disable_v2_behavior() -USING_REAL_TF = tf_compat.__version__ != 'stub' +USING_REAL_TF = tf_compat.__version__ != "stub" class ProjectorAppTest(tf.test.TestCase): - - def __init__(self, *args, **kwargs): - super(ProjectorAppTest, self).__init__(*args, **kwargs) - self.logdir = None - self.plugin = None - self.server = None - - def setUp(self): - self.log_dir = self.get_temp_dir() - - def testRunsWithValidCheckpoint(self): - self._GenerateProjectorTestData() - self._SetupWSGIApp() - run_json = self._GetJson('/data/plugin/projector/runs') - if USING_REAL_TF: + def __init__(self, *args, **kwargs): + super(ProjectorAppTest, self).__init__(*args, **kwargs) + self.logdir = None + self.plugin = None + self.server = None + + def setUp(self): + self.log_dir = self.get_temp_dir() + + def testRunsWithValidCheckpoint(self): + self._GenerateProjectorTestData() + self._SetupWSGIApp() + run_json = self._GetJson("/data/plugin/projector/runs") + if USING_REAL_TF: + self.assertTrue(run_json) + else: + self.assertFalse(run_json) + + def testRunsWithNoCheckpoint(self): + self._SetupWSGIApp() 
+ run_json = self._GetJson("/data/plugin/projector/runs") + self.assertEqual(run_json, []) + + def testRunsWithInvalidModelCheckpointPath(self): + checkpoint_file = os.path.join(self.log_dir, "checkpoint") + f = open(checkpoint_file, "w") + f.write('model_checkpoint_path: "does_not_exist"\n') + f.write('all_model_checkpoint_paths: "does_not_exist"\n') + f.close() + self._SetupWSGIApp() + + run_json = self._GetJson("/data/plugin/projector/runs") + self.assertEqual(run_json, []) + + # TODO(#2007): Cleanly separate out projector tests that require real TF + @unittest.skipUnless(USING_REAL_TF, "Test only passes when using real TF") + def testRunsWithInvalidModelCheckpointPathInConfig(self): + config_path = os.path.join(self.log_dir, "projector_config.pbtxt") + config = projector_config_pb2.ProjectorConfig() + config.model_checkpoint_path = "does_not_exist" + embedding = config.embeddings.add() + embedding.tensor_name = "var1" + with tf.io.gfile.GFile(config_path, "w") as f: + f.write(text_format.MessageToString(config)) + self._SetupWSGIApp() + + run_json = self._GetJson("/data/plugin/projector/runs") + self.assertEqual(run_json, []) + + # TODO(#2007): Cleanly separate out projector tests that require real TF + @unittest.skipUnless(USING_REAL_TF, "Test only passes when using real TF") + def testInfoWithValidCheckpointNoEventsData(self): + self._GenerateProjectorTestData() + self._SetupWSGIApp() + + info_json = self._GetJson("/data/plugin/projector/info?run=.") + self.assertItemsEqual( + info_json["embeddings"], + [ + { + "tensorShape": [1, 2], + "tensorName": "var1", + "bookmarksPath": "bookmarks.json", + }, + {"tensorShape": [10, 10], "tensorName": "var2"}, + {"tensorShape": [100, 100], "tensorName": "var3"}, + ], + ) + + # TODO(#2007): Cleanly separate out projector tests that require real TF + @unittest.skipUnless(USING_REAL_TF, "Test only passes when using real TF") + def testInfoWithValidCheckpointAndEventsData(self): + self._GenerateProjectorTestData() + 
self._GenerateEventsData() + self._SetupWSGIApp() + + run_json = self._GetJson("/data/plugin/projector/runs") self.assertTrue(run_json) - else: - self.assertFalse(run_json) - - def testRunsWithNoCheckpoint(self): - self._SetupWSGIApp() - run_json = self._GetJson('/data/plugin/projector/runs') - self.assertEqual(run_json, []) - - def testRunsWithInvalidModelCheckpointPath(self): - checkpoint_file = os.path.join(self.log_dir, 'checkpoint') - f = open(checkpoint_file, 'w') - f.write('model_checkpoint_path: "does_not_exist"\n') - f.write('all_model_checkpoint_paths: "does_not_exist"\n') - f.close() - self._SetupWSGIApp() - - run_json = self._GetJson('/data/plugin/projector/runs') - self.assertEqual(run_json, []) - - # TODO(#2007): Cleanly separate out projector tests that require real TF - @unittest.skipUnless(USING_REAL_TF, 'Test only passes when using real TF') - def testRunsWithInvalidModelCheckpointPathInConfig(self): - config_path = os.path.join(self.log_dir, 'projector_config.pbtxt') - config = projector_config_pb2.ProjectorConfig() - config.model_checkpoint_path = 'does_not_exist' - embedding = config.embeddings.add() - embedding.tensor_name = 'var1' - with tf.io.gfile.GFile(config_path, 'w') as f: - f.write(text_format.MessageToString(config)) - self._SetupWSGIApp() - - run_json = self._GetJson('/data/plugin/projector/runs') - self.assertEqual(run_json, []) - - # TODO(#2007): Cleanly separate out projector tests that require real TF - @unittest.skipUnless(USING_REAL_TF, 'Test only passes when using real TF') - def testInfoWithValidCheckpointNoEventsData(self): - self._GenerateProjectorTestData() - self._SetupWSGIApp() - - info_json = self._GetJson('/data/plugin/projector/info?run=.') - self.assertItemsEqual(info_json['embeddings'], [{ - 'tensorShape': [1, 2], - 'tensorName': 'var1', - 'bookmarksPath': 'bookmarks.json' - }, { - 'tensorShape': [10, 10], - 'tensorName': 'var2' - }, { - 'tensorShape': [100, 100], - 'tensorName': 'var3' - }]) - - # TODO(#2007): 
Cleanly separate out projector tests that require real TF - @unittest.skipUnless(USING_REAL_TF, 'Test only passes when using real TF') - def testInfoWithValidCheckpointAndEventsData(self): - self._GenerateProjectorTestData() - self._GenerateEventsData() - self._SetupWSGIApp() - - run_json = self._GetJson('/data/plugin/projector/runs') - self.assertTrue(run_json) - run = run_json[0] - info_json = self._GetJson('/data/plugin/projector/info?run=%s' % run) - self.assertItemsEqual(info_json['embeddings'], [{ - 'tensorShape': [1, 2], - 'tensorName': 'var1', - 'bookmarksPath': 'bookmarks.json' - }, { - 'tensorShape': [10, 10], - 'tensorName': 'var2' - }, { - 'tensorShape': [100, 100], - 'tensorName': 'var3' - }]) - - # TODO(#2007): Cleanly separate out projector tests that require real TF - @unittest.skipUnless(USING_REAL_TF, 'Test only passes when using real TF') - def testTensorWithValidCheckpoint(self): - self._GenerateProjectorTestData() - self._SetupWSGIApp() - - url = '/data/plugin/projector/tensor?run=.&name=var1' - tensor_bytes = self._Get(url).data - expected_tensor = np.array([[6, 6]], dtype=np.float32) - self._AssertTensorResponse(tensor_bytes, expected_tensor) - - def testBookmarksRequestMissingRunAndName(self): - self._GenerateProjectorTestData() - self._SetupWSGIApp() - - url = '/data/plugin/projector/bookmarks' - self.assertEqual(self._Get(url).status_code, 400) - - def testBookmarksRequestMissingName(self): - self._GenerateProjectorTestData() - self._SetupWSGIApp() - - url = '/data/plugin/projector/bookmarks?run=.' 
- self.assertEqual(self._Get(url).status_code, 400) - - def testBookmarksRequestMissingRun(self): - self._GenerateProjectorTestData() - self._SetupWSGIApp() - - url = '/data/plugin/projector/bookmarks?name=var1' - self.assertEqual(self._Get(url).status_code, 400) - - def testBookmarksUnknownRun(self): - self._GenerateProjectorTestData() - self._SetupWSGIApp() - - url = '/data/plugin/projector/bookmarks?run=unknown&name=var1' - self.assertEqual(self._Get(url).status_code, 400) - - def testBookmarksUnknownName(self): - self._GenerateProjectorTestData() - self._SetupWSGIApp() - - url = '/data/plugin/projector/bookmarks?run=.&name=unknown' - self.assertEqual(self._Get(url).status_code, 400) - - # TODO(#2007): Cleanly separate out projector tests that require real TF - @unittest.skipUnless(USING_REAL_TF, 'Test only passes when using real TF') - def testBookmarks(self): - self._GenerateProjectorTestData() - self._SetupWSGIApp() - - url = '/data/plugin/projector/bookmarks?run=.&name=var1' - bookmark = self._GetJson(url) - self.assertEqual(bookmark, {'a': 'b'}) - - def testEndpointsNoAssets(self): - g = tf.Graph() - - with test_util.FileWriterCache.get(self.log_dir) as writer: - writer.add_graph(g) - - self._SetupWSGIApp() - run_json = self._GetJson('/data/plugin/projector/runs') - self.assertEqual(run_json, []) - - def _AssertTensorResponse(self, tensor_bytes, expected_tensor): - tensor = np.reshape(np.fromstring(tensor_bytes, dtype=np.float32), - expected_tensor.shape) - self.assertTrue(np.array_equal(tensor, expected_tensor)) - - # TODO(#2007): Cleanly separate out projector tests that require real TF - @unittest.skipUnless(USING_REAL_TF, 'Test only passes when using real TF') - def testPluginIsActive(self): - self._GenerateProjectorTestData() - self._SetupWSGIApp() - - patcher = tf.compat.v1.test.mock.patch('threading.Thread.start', autospec=True) - mock = patcher.start() - self.addCleanup(patcher.stop) - - # The projector plugin has not yet determined whether it is 
active, but it - # should now start a thread to determine that. - self.assertFalse(self.plugin.is_active()) - thread = self.plugin._thread_for_determining_is_active - mock.assert_called_once_with(thread) - - # The logic has not finished running yet, so the plugin should still not - # have deemed itself to be active. - self.assertFalse(self.plugin.is_active()) - mock.assert_called_once_with(thread) - - self.plugin._thread_for_determining_is_active.run() - - # The plugin later finds that embedding data is available. - self.assertTrue(self.plugin.is_active()) - - # Subsequent calls to is_active should not start a new thread. The mock - # should only have been called once throughout this test. - self.assertTrue(self.plugin.is_active()) - mock.assert_called_once_with(thread) - - def testPluginIsNotActive(self): - self._SetupWSGIApp() - - # The is_active method makes use of a separate thread, so we mock threading - # behavior to make this test deterministic. - patcher = tf.compat.v1.test.mock.patch('threading.Thread.start', autospec=True) - mock = patcher.start() - self.addCleanup(patcher.stop) - - # The projector plugin has not yet determined whether it is active, but it - # should now start a thread to determine that. - self.assertFalse(self.plugin.is_active()) - mock.assert_called_once_with(self.plugin._thread_for_determining_is_active) - - self.plugin._thread_for_determining_is_active.run() - - # The plugin later finds that embedding data is not available. - self.assertFalse(self.plugin.is_active()) - - # Furthermore, the plugin should have spawned a new thread to check whether - # it is active (because it might now be active even though it had not been - # beforehand), so the mock should now be called twice. 
- self.assertEqual(2, mock.call_count) - - def _SetupWSGIApp(self): - multiplexer = event_multiplexer.EventMultiplexer( - size_guidance=application.DEFAULT_SIZE_GUIDANCE, - purge_orphaned_data=True) - context = base_plugin.TBContext( - logdir=self.log_dir, multiplexer=multiplexer) - self.plugin = projector_plugin.ProjectorPlugin(context) - wsgi_app = application.TensorBoardWSGI([self.plugin]) - self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) - - def _Get(self, path): - return self.server.get(path) - - def _GetJson(self, path): - response = self.server.get(path) - data = response.data - if response.headers.get('Content-Encoding') == 'gzip': - data = gzip.GzipFile('', 'rb', 9, io.BytesIO(data)).read() - return json.loads(data.decode('utf-8')) - - def _GenerateEventsData(self): - with test_util.FileWriterCache.get(self.log_dir) as fw: - event = event_pb2.Event( - wall_time=1, - step=1, - summary=summary_pb2.Summary( - value=[summary_pb2.Summary.Value(tag='s1', simple_value=0)])) - fw.add_event(event) - - def _GenerateProjectorTestData(self): - config_path = os.path.join(self.log_dir, 'projector_config.pbtxt') - config = projector_config_pb2.ProjectorConfig() - embedding = config.embeddings.add() - # Add an embedding by its canonical tensor name. - embedding.tensor_name = 'var1:0' - - with tf.io.gfile.GFile(os.path.join(self.log_dir, 'bookmarks.json'), 'w') as f: - f.write('{"a": "b"}') - embedding.bookmarks_path = 'bookmarks.json' - - config_pbtxt = text_format.MessageToString(config) - with tf.io.gfile.GFile(config_path, 'w') as f: - f.write(config_pbtxt) - - # Write a checkpoint with some dummy variables. 
- with tf.Graph().as_default(): - sess = tf.compat.v1.Session() - checkpoint_path = os.path.join(self.log_dir, 'model') - tf.compat.v1.get_variable('var1', - initializer=tf.constant(np.full([1, 2], 6.0))) - tf.compat.v1.get_variable('var2', [10, 10]) - tf.compat.v1.get_variable('var3', [100, 100]) - sess.run(tf.compat.v1.global_variables_initializer()) - saver = tf.compat.v1.train.Saver(write_version=tf.compat.v1.train.SaverDef.V1) - saver.save(sess, checkpoint_path) + run = run_json[0] + info_json = self._GetJson("/data/plugin/projector/info?run=%s" % run) + self.assertItemsEqual( + info_json["embeddings"], + [ + { + "tensorShape": [1, 2], + "tensorName": "var1", + "bookmarksPath": "bookmarks.json", + }, + {"tensorShape": [10, 10], "tensorName": "var2"}, + {"tensorShape": [100, 100], "tensorName": "var3"}, + ], + ) + + # TODO(#2007): Cleanly separate out projector tests that require real TF + @unittest.skipUnless(USING_REAL_TF, "Test only passes when using real TF") + def testTensorWithValidCheckpoint(self): + self._GenerateProjectorTestData() + self._SetupWSGIApp() + + url = "/data/plugin/projector/tensor?run=.&name=var1" + tensor_bytes = self._Get(url).data + expected_tensor = np.array([[6, 6]], dtype=np.float32) + self._AssertTensorResponse(tensor_bytes, expected_tensor) + + def testBookmarksRequestMissingRunAndName(self): + self._GenerateProjectorTestData() + self._SetupWSGIApp() + + url = "/data/plugin/projector/bookmarks" + self.assertEqual(self._Get(url).status_code, 400) + + def testBookmarksRequestMissingName(self): + self._GenerateProjectorTestData() + self._SetupWSGIApp() + + url = "/data/plugin/projector/bookmarks?run=." 
+ self.assertEqual(self._Get(url).status_code, 400) + + def testBookmarksRequestMissingRun(self): + self._GenerateProjectorTestData() + self._SetupWSGIApp() + + url = "/data/plugin/projector/bookmarks?name=var1" + self.assertEqual(self._Get(url).status_code, 400) + + def testBookmarksUnknownRun(self): + self._GenerateProjectorTestData() + self._SetupWSGIApp() + + url = "/data/plugin/projector/bookmarks?run=unknown&name=var1" + self.assertEqual(self._Get(url).status_code, 400) + + def testBookmarksUnknownName(self): + self._GenerateProjectorTestData() + self._SetupWSGIApp() + + url = "/data/plugin/projector/bookmarks?run=.&name=unknown" + self.assertEqual(self._Get(url).status_code, 400) + + # TODO(#2007): Cleanly separate out projector tests that require real TF + @unittest.skipUnless(USING_REAL_TF, "Test only passes when using real TF") + def testBookmarks(self): + self._GenerateProjectorTestData() + self._SetupWSGIApp() + + url = "/data/plugin/projector/bookmarks?run=.&name=var1" + bookmark = self._GetJson(url) + self.assertEqual(bookmark, {"a": "b"}) + + def testEndpointsNoAssets(self): + g = tf.Graph() + + with test_util.FileWriterCache.get(self.log_dir) as writer: + writer.add_graph(g) + + self._SetupWSGIApp() + run_json = self._GetJson("/data/plugin/projector/runs") + self.assertEqual(run_json, []) + + def _AssertTensorResponse(self, tensor_bytes, expected_tensor): + tensor = np.reshape( + np.fromstring(tensor_bytes, dtype=np.float32), expected_tensor.shape + ) + self.assertTrue(np.array_equal(tensor, expected_tensor)) + + # TODO(#2007): Cleanly separate out projector tests that require real TF + @unittest.skipUnless(USING_REAL_TF, "Test only passes when using real TF") + def testPluginIsActive(self): + self._GenerateProjectorTestData() + self._SetupWSGIApp() + + patcher = tf.compat.v1.test.mock.patch( + "threading.Thread.start", autospec=True + ) + mock = patcher.start() + self.addCleanup(patcher.stop) + + # The projector plugin has not yet determined 
whether it is active, but it + # should now start a thread to determine that. + self.assertFalse(self.plugin.is_active()) + thread = self.plugin._thread_for_determining_is_active + mock.assert_called_once_with(thread) + + # The logic has not finished running yet, so the plugin should still not + # have deemed itself to be active. + self.assertFalse(self.plugin.is_active()) + mock.assert_called_once_with(thread) + + self.plugin._thread_for_determining_is_active.run() + + # The plugin later finds that embedding data is available. + self.assertTrue(self.plugin.is_active()) + + # Subsequent calls to is_active should not start a new thread. The mock + # should only have been called once throughout this test. + self.assertTrue(self.plugin.is_active()) + mock.assert_called_once_with(thread) + + def testPluginIsNotActive(self): + self._SetupWSGIApp() + + # The is_active method makes use of a separate thread, so we mock threading + # behavior to make this test deterministic. + patcher = tf.compat.v1.test.mock.patch( + "threading.Thread.start", autospec=True + ) + mock = patcher.start() + self.addCleanup(patcher.stop) + + # The projector plugin has not yet determined whether it is active, but it + # should now start a thread to determine that. + self.assertFalse(self.plugin.is_active()) + mock.assert_called_once_with( + self.plugin._thread_for_determining_is_active + ) + + self.plugin._thread_for_determining_is_active.run() + + # The plugin later finds that embedding data is not available. + self.assertFalse(self.plugin.is_active()) + + # Furthermore, the plugin should have spawned a new thread to check whether + # it is active (because it might now be active even though it had not been + # beforehand), so the mock should now be called twice. 
+ self.assertEqual(2, mock.call_count) + + def _SetupWSGIApp(self): + multiplexer = event_multiplexer.EventMultiplexer( + size_guidance=application.DEFAULT_SIZE_GUIDANCE, + purge_orphaned_data=True, + ) + context = base_plugin.TBContext( + logdir=self.log_dir, multiplexer=multiplexer + ) + self.plugin = projector_plugin.ProjectorPlugin(context) + wsgi_app = application.TensorBoardWSGI([self.plugin]) + self.server = werkzeug_test.Client(wsgi_app, wrappers.BaseResponse) + + def _Get(self, path): + return self.server.get(path) + + def _GetJson(self, path): + response = self.server.get(path) + data = response.data + if response.headers.get("Content-Encoding") == "gzip": + data = gzip.GzipFile("", "rb", 9, io.BytesIO(data)).read() + return json.loads(data.decode("utf-8")) + + def _GenerateEventsData(self): + with test_util.FileWriterCache.get(self.log_dir) as fw: + event = event_pb2.Event( + wall_time=1, + step=1, + summary=summary_pb2.Summary( + value=[summary_pb2.Summary.Value(tag="s1", simple_value=0)] + ), + ) + fw.add_event(event) + + def _GenerateProjectorTestData(self): + config_path = os.path.join(self.log_dir, "projector_config.pbtxt") + config = projector_config_pb2.ProjectorConfig() + embedding = config.embeddings.add() + # Add an embedding by its canonical tensor name. + embedding.tensor_name = "var1:0" + + with tf.io.gfile.GFile( + os.path.join(self.log_dir, "bookmarks.json"), "w" + ) as f: + f.write('{"a": "b"}') + embedding.bookmarks_path = "bookmarks.json" + + config_pbtxt = text_format.MessageToString(config) + with tf.io.gfile.GFile(config_path, "w") as f: + f.write(config_pbtxt) + + # Write a checkpoint with some dummy variables. 
+ with tf.Graph().as_default(): + sess = tf.compat.v1.Session() + checkpoint_path = os.path.join(self.log_dir, "model") + tf.compat.v1.get_variable( + "var1", initializer=tf.constant(np.full([1, 2], 6.0)) + ) + tf.compat.v1.get_variable("var2", [10, 10]) + tf.compat.v1.get_variable("var3", [100, 100]) + sess.run(tf.compat.v1.global_variables_initializer()) + saver = tf.compat.v1.train.Saver( + write_version=tf.compat.v1.train.SaverDef.V1 + ) + saver.save(sess, checkpoint_path) class MetadataColumnsTest(tf.test.TestCase): - - def testLengthDoesNotMatch(self): - metadata = projector_plugin.EmbeddingMetadata(10) - - with self.assertRaises(ValueError): - metadata.add_column('Labels', [''] * 11) - - def testValuesNot1D(self): - metadata = projector_plugin.EmbeddingMetadata(3) - values = np.array([[1, 2, 3]]) - - with self.assertRaises(ValueError): - metadata.add_column('Labels', values) - - def testMultipleColumnsRetrieval(self): - metadata = projector_plugin.EmbeddingMetadata(3) - metadata.add_column('Sizes', [1, 2, 3]) - metadata.add_column('Labels', ['a', 'b', 'c']) - self.assertEqual(metadata.column_names, ['Sizes', 'Labels']) - self.assertEqual(metadata.name_to_values['Labels'], ['a', 'b', 'c']) - self.assertEqual(metadata.name_to_values['Sizes'], [1, 2, 3]) - - def testValuesAreListofLists(self): - metadata = projector_plugin.EmbeddingMetadata(3) - values = [[1, 2, 3], [4, 5, 6]] - with self.assertRaises(ValueError): - metadata.add_column('Labels', values) - - def testStringListRetrieval(self): - metadata = projector_plugin.EmbeddingMetadata(3) - metadata.add_column('Labels', ['a', 'B', 'c']) - self.assertEqual(metadata.name_to_values['Labels'], ['a', 'B', 'c']) - self.assertEqual(metadata.column_names, ['Labels']) - - def testNumericListRetrieval(self): - metadata = projector_plugin.EmbeddingMetadata(3) - metadata.add_column('Labels', [1, 2, 3]) - self.assertEqual(metadata.name_to_values['Labels'], [1, 2, 3]) - - def testNumericNdArrayRetrieval(self): - metadata 
= projector_plugin.EmbeddingMetadata(3) - metadata.add_column('Labels', np.array([1, 2, 3])) - self.assertEqual(metadata.name_to_values['Labels'].tolist(), [1, 2, 3]) - - def testStringNdArrayRetrieval(self): - metadata = projector_plugin.EmbeddingMetadata(2) - metadata.add_column('Labels', np.array(['a', 'b'])) - self.assertEqual(metadata.name_to_values['Labels'].tolist(), ['a', 'b']) - - def testDuplicateColumnName(self): - metadata = projector_plugin.EmbeddingMetadata(2) - metadata.add_column('Labels', np.array(['a', 'b'])) - with self.assertRaises(ValueError): - metadata.add_column('Labels', np.array(['a', 'b'])) + def testLengthDoesNotMatch(self): + metadata = projector_plugin.EmbeddingMetadata(10) + + with self.assertRaises(ValueError): + metadata.add_column("Labels", [""] * 11) + + def testValuesNot1D(self): + metadata = projector_plugin.EmbeddingMetadata(3) + values = np.array([[1, 2, 3]]) + + with self.assertRaises(ValueError): + metadata.add_column("Labels", values) + + def testMultipleColumnsRetrieval(self): + metadata = projector_plugin.EmbeddingMetadata(3) + metadata.add_column("Sizes", [1, 2, 3]) + metadata.add_column("Labels", ["a", "b", "c"]) + self.assertEqual(metadata.column_names, ["Sizes", "Labels"]) + self.assertEqual(metadata.name_to_values["Labels"], ["a", "b", "c"]) + self.assertEqual(metadata.name_to_values["Sizes"], [1, 2, 3]) + + def testValuesAreListofLists(self): + metadata = projector_plugin.EmbeddingMetadata(3) + values = [[1, 2, 3], [4, 5, 6]] + with self.assertRaises(ValueError): + metadata.add_column("Labels", values) + + def testStringListRetrieval(self): + metadata = projector_plugin.EmbeddingMetadata(3) + metadata.add_column("Labels", ["a", "B", "c"]) + self.assertEqual(metadata.name_to_values["Labels"], ["a", "B", "c"]) + self.assertEqual(metadata.column_names, ["Labels"]) + + def testNumericListRetrieval(self): + metadata = projector_plugin.EmbeddingMetadata(3) + metadata.add_column("Labels", [1, 2, 3]) + 
self.assertEqual(metadata.name_to_values["Labels"], [1, 2, 3]) + + def testNumericNdArrayRetrieval(self): + metadata = projector_plugin.EmbeddingMetadata(3) + metadata.add_column("Labels", np.array([1, 2, 3])) + self.assertEqual(metadata.name_to_values["Labels"].tolist(), [1, 2, 3]) + + def testStringNdArrayRetrieval(self): + metadata = projector_plugin.EmbeddingMetadata(2) + metadata.add_column("Labels", np.array(["a", "b"])) + self.assertEqual(metadata.name_to_values["Labels"].tolist(), ["a", "b"]) + + def testDuplicateColumnName(self): + metadata = projector_plugin.EmbeddingMetadata(2) + metadata.add_column("Labels", np.array(["a", "b"])) + with self.assertRaises(ValueError): + metadata.add_column("Labels", np.array(["a", "b"])) class LRUCacheTest(tf.test.TestCase): - - def testInvalidSize(self): - with self.assertRaises(ValueError): - projector_plugin.LRUCache(0) - - def testSimpleGetAndSet(self): - cache = projector_plugin.LRUCache(1) - value = cache.get('a') - self.assertIsNone(value) - cache.set('a', 10) - self.assertEqual(cache.get('a'), 10) - - def testErrorsWhenSettingNoneAsValue(self): - cache = projector_plugin.LRUCache(1) - with self.assertRaises(ValueError): - cache.set('a', None) - - def testLRUReplacementPolicy(self): - cache = projector_plugin.LRUCache(2) - cache.set('a', 1) - cache.set('b', 2) - cache.set('c', 3) - self.assertIsNone(cache.get('a')) - self.assertEqual(cache.get('b'), 2) - self.assertEqual(cache.get('c'), 3) - - # Make 'b' the most recently used. - cache.get('b') - cache.set('d', 4) - - # Make sure 'c' got replaced with 'd'. 
- self.assertIsNone(cache.get('c')) - self.assertEqual(cache.get('b'), 2) - self.assertEqual(cache.get('d'), 4) - - -if __name__ == '__main__': - tf.test.main() + def testInvalidSize(self): + with self.assertRaises(ValueError): + projector_plugin.LRUCache(0) + + def testSimpleGetAndSet(self): + cache = projector_plugin.LRUCache(1) + value = cache.get("a") + self.assertIsNone(value) + cache.set("a", 10) + self.assertEqual(cache.get("a"), 10) + + def testErrorsWhenSettingNoneAsValue(self): + cache = projector_plugin.LRUCache(1) + with self.assertRaises(ValueError): + cache.set("a", None) + + def testLRUReplacementPolicy(self): + cache = projector_plugin.LRUCache(2) + cache.set("a", 1) + cache.set("b", 2) + cache.set("c", 3) + self.assertIsNone(cache.get("a")) + self.assertEqual(cache.get("b"), 2) + self.assertEqual(cache.get("c"), 3) + + # Make 'b' the most recently used. + cache.get("b") + cache.set("d", 4) + + # Make sure 'c' got replaced with 'd'. + self.assertIsNone(cache.get("c")) + self.assertEqual(cache.get("b"), 2) + self.assertEqual(cache.get("d"), 4) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/scalar/metadata.py b/tensorboard/plugins/scalar/metadata.py index 2d1e273568..1b3903807d 100644 --- a/tensorboard/plugins/scalar/metadata.py +++ b/tensorboard/plugins/scalar/metadata.py @@ -24,7 +24,7 @@ logger = tb_logging.get_logger() -PLUGIN_NAME = 'scalars' +PLUGIN_NAME = "scalars" # The most recent value for the `version` field of the # `ScalarPluginData` proto. @@ -32,39 +32,43 @@ def create_summary_metadata(display_name, description): - """Create a `summary_pb2.SummaryMetadata` proto for scalar plugin data. + """Create a `summary_pb2.SummaryMetadata` proto for scalar plugin data. - Returns: - A `summary_pb2.SummaryMetadata` protobuf object. 
- """ - content = plugin_data_pb2.ScalarPluginData(version=PROTO_VERSION) - metadata = summary_pb2.SummaryMetadata( - display_name=display_name, - summary_description=description, - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name=PLUGIN_NAME, - content=content.SerializeToString())) - return metadata + Returns: + A `summary_pb2.SummaryMetadata` protobuf object. + """ + content = plugin_data_pb2.ScalarPluginData(version=PROTO_VERSION) + metadata = summary_pb2.SummaryMetadata( + display_name=display_name, + summary_description=description, + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, content=content.SerializeToString() + ), + ) + return metadata def parse_plugin_metadata(content): - """Parse summary metadata to a Python object. + """Parse summary metadata to a Python object. - Arguments: - content: The `content` field of a `SummaryMetadata` proto - corresponding to the scalar plugin. + Arguments: + content: The `content` field of a `SummaryMetadata` proto + corresponding to the scalar plugin. - Returns: - A `ScalarPluginData` protobuf object. - """ - if not isinstance(content, bytes): - raise TypeError('Content type must be bytes') - result = plugin_data_pb2.ScalarPluginData.FromString(content) - if result.version == 0: - return result - else: - logger.warn( - 'Unknown metadata version: %s. The latest version known to ' - 'this build of TensorBoard is %s; perhaps a newer build is ' - 'available?', result.version, PROTO_VERSION) - return result + Returns: + A `ScalarPluginData` protobuf object. + """ + if not isinstance(content, bytes): + raise TypeError("Content type must be bytes") + result = plugin_data_pb2.ScalarPluginData.FromString(content) + if result.version == 0: + return result + else: + logger.warn( + "Unknown metadata version: %s. 
The latest version known to " + "this build of TensorBoard is %s; perhaps a newer build is " + "available?", + result.version, + PROTO_VERSION, + ) + return result diff --git a/tensorboard/plugins/scalar/scalars_demo.py b/tensorboard/plugins/scalar/scalars_demo.py index e16316075f..b561f1e8a3 100644 --- a/tensorboard/plugins/scalar/scalars_demo.py +++ b/tensorboard/plugins/scalar/scalars_demo.py @@ -26,119 +26,138 @@ from tensorboard.plugins.scalar import summary # Directory into which to write tensorboard data. -LOGDIR = '/tmp/scalars_demo' +LOGDIR = "/tmp/scalars_demo" # Duration of the simulation. STEPS = 1000 -def run(logdir, run_name, - initial_temperature, ambient_temperature, heat_coefficient): - """Run a temperature simulation. - - This will simulate an object at temperature `initial_temperature` - sitting at rest in a large room at temperature `ambient_temperature`. - The object has some intrinsic `heat_coefficient`, which indicates - how much thermal conductivity it has: for instance, metals have high - thermal conductivity, while the thermal conductivity of water is low. - - Over time, the object's temperature will adjust to match the - temperature of its environment. We'll track the object's temperature, - how far it is from the room's temperature, and how much it changes at - each time step. - - Arguments: - logdir: the top-level directory into which to write summary data - run_name: the name of this run; will be created as a subdirectory - under logdir - initial_temperature: float; the object's initial temperature - ambient_temperature: float; the temperature of the enclosing room - heat_coefficient: float; a measure of the object's thermal - conductivity - """ - tf.compat.v1.reset_default_graph() - tf.compat.v1.set_random_seed(0) - - with tf.name_scope('temperature'): - # Create a mutable variable to hold the object's temperature, and - # create a scalar summary to track its value over time. 
The name of - # the summary will appear as "temperature/current" due to the - # name-scope above. - temperature = tf.Variable(tf.constant(initial_temperature), - name='temperature') - summary.op('current', temperature, - display_name='Temperature', - description='The temperature of the object under ' - 'simulation, in Kelvins.') - - # Compute how much the object's temperature differs from that of its - # environment, and track this, too: likewise, as - # "temperature/difference_to_ambient". - ambient_difference = temperature - ambient_temperature - summary.op('difference_to_ambient', ambient_difference, - display_name='Difference to ambient temperature', - description='The difference between the ambient ' - 'temperature and the temperature of the ' - 'object under simulation, in Kelvins.') - - # Newton suggested that the rate of change of the temperature of an - # object is directly proportional to this `ambient_difference` above, - # where the proportionality constant is what we called the heat - # coefficient. But in real life, not everything is quite so clean, so - # we'll add in some noise. (The value of 50 is arbitrary, chosen to - # make the data look somewhat interesting. :-) ) - noise = 50 * tf.random.normal([]) - delta = -heat_coefficient * (ambient_difference + noise) - summary.op('delta', delta, - description='The change in temperature from the previous ' - 'step, in Kelvins.') - - # Collect all the scalars that we want to keep track of. - summ = tf.compat.v1.summary.merge_all() - - # Now, augment the current temperature by this delta that we computed, - # blocking the assignment on summary collection to avoid race conditions - # and ensure that the summary always reports the pre-update value. 
- with tf.control_dependencies([summ]): - update_step = temperature.assign_add(delta) - - sess = tf.compat.v1.Session() - writer = tf.summary.FileWriter(os.path.join(logdir, run_name)) - writer.add_graph(sess.graph) - sess.run(tf.compat.v1.global_variables_initializer()) - for step in xrange(STEPS): - # By asking TensorFlow to compute the update step, we force it to - # change the value of the temperature variable. We don't actually - # care about this value, so we discard it; instead, we grab the - # summary data computed along the way. - (s, _) = sess.run([summ, update_step]) - writer.add_summary(s, global_step=step) - writer.close() +def run( + logdir, run_name, initial_temperature, ambient_temperature, heat_coefficient +): + """Run a temperature simulation. + + This will simulate an object at temperature `initial_temperature` + sitting at rest in a large room at temperature `ambient_temperature`. + The object has some intrinsic `heat_coefficient`, which indicates + how much thermal conductivity it has: for instance, metals have high + thermal conductivity, while the thermal conductivity of water is low. + + Over time, the object's temperature will adjust to match the + temperature of its environment. We'll track the object's temperature, + how far it is from the room's temperature, and how much it changes at + each time step. + + Arguments: + logdir: the top-level directory into which to write summary data + run_name: the name of this run; will be created as a subdirectory + under logdir + initial_temperature: float; the object's initial temperature + ambient_temperature: float; the temperature of the enclosing room + heat_coefficient: float; a measure of the object's thermal + conductivity + """ + tf.compat.v1.reset_default_graph() + tf.compat.v1.set_random_seed(0) + + with tf.name_scope("temperature"): + # Create a mutable variable to hold the object's temperature, and + # create a scalar summary to track its value over time. 
The name of + # the summary will appear as "temperature/current" due to the + # name-scope above. + temperature = tf.Variable( + tf.constant(initial_temperature), name="temperature" + ) + summary.op( + "current", + temperature, + display_name="Temperature", + description="The temperature of the object under " + "simulation, in Kelvins.", + ) + + # Compute how much the object's temperature differs from that of its + # environment, and track this, too: likewise, as + # "temperature/difference_to_ambient". + ambient_difference = temperature - ambient_temperature + summary.op( + "difference_to_ambient", + ambient_difference, + display_name="Difference to ambient temperature", + description="The difference between the ambient " + "temperature and the temperature of the " + "object under simulation, in Kelvins.", + ) + + # Newton suggested that the rate of change of the temperature of an + # object is directly proportional to this `ambient_difference` above, + # where the proportionality constant is what we called the heat + # coefficient. But in real life, not everything is quite so clean, so + # we'll add in some noise. (The value of 50 is arbitrary, chosen to + # make the data look somewhat interesting. :-) ) + noise = 50 * tf.random.normal([]) + delta = -heat_coefficient * (ambient_difference + noise) + summary.op( + "delta", + delta, + description="The change in temperature from the previous " + "step, in Kelvins.", + ) + + # Collect all the scalars that we want to keep track of. + summ = tf.compat.v1.summary.merge_all() + + # Now, augment the current temperature by this delta that we computed, + # blocking the assignment on summary collection to avoid race conditions + # and ensure that the summary always reports the pre-update value. 
+ with tf.control_dependencies([summ]): + update_step = temperature.assign_add(delta) + + sess = tf.compat.v1.Session() + writer = tf.summary.FileWriter(os.path.join(logdir, run_name)) + writer.add_graph(sess.graph) + sess.run(tf.compat.v1.global_variables_initializer()) + for step in xrange(STEPS): + # By asking TensorFlow to compute the update step, we force it to + # change the value of the temperature variable. We don't actually + # care about this value, so we discard it; instead, we grab the + # summary data computed along the way. + (s, _) = sess.run([summ, update_step]) + writer.add_summary(s, global_step=step) + writer.close() def run_all(logdir, verbose=False): - """Run simulations on a reasonable set of parameters. - - Arguments: - logdir: the directory into which to store all the runs' data - verbose: if true, print out each run's name as it begins - """ - for initial_temperature in [270.0, 310.0, 350.0]: - for final_temperature in [270.0, 310.0, 350.0]: - for heat_coefficient in [0.001, 0.005]: - run_name = 'temperature:t0=%g,tA=%g,kH=%g' % ( - initial_temperature, final_temperature, heat_coefficient) - if verbose: - print('--- Running: %s' % run_name) - run(logdir, run_name, - initial_temperature, final_temperature, heat_coefficient) + """Run simulations on a reasonable set of parameters. + + Arguments: + logdir: the directory into which to store all the runs' data + verbose: if true, print out each run's name as it begins + """ + for initial_temperature in [270.0, 310.0, 350.0]: + for final_temperature in [270.0, 310.0, 350.0]: + for heat_coefficient in [0.001, 0.005]: + run_name = "temperature:t0=%g,tA=%g,kH=%g" % ( + initial_temperature, + final_temperature, + heat_coefficient, + ) + if verbose: + print("--- Running: %s" % run_name) + run( + logdir, + run_name, + initial_temperature, + final_temperature, + heat_coefficient, + ) def main(unused_argv): - print('Saving output to %s.' % LOGDIR) - run_all(LOGDIR, verbose=True) - print('Done. 
Output saved to %s.' % LOGDIR) + print("Saving output to %s." % LOGDIR) + run_all(LOGDIR, verbose=True) + print("Done. Output saved to %s." % LOGDIR) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/tensorboard/plugins/scalar/scalars_plugin.py b/tensorboard/plugins/scalar/scalars_plugin.py index 4314b65b61..4004abbaed 100644 --- a/tensorboard/plugins/scalar/scalars_plugin.py +++ b/tensorboard/plugins/scalar/scalars_plugin.py @@ -41,83 +41,93 @@ class OutputFormat(object): - """An enum used to list the valid output formats for API calls.""" - JSON = 'json' - CSV = 'csv' + """An enum used to list the valid output formats for API calls.""" + + JSON = "json" + CSV = "csv" class ScalarsPlugin(base_plugin.TBPlugin): - """Scalars Plugin for TensorBoard.""" - - plugin_name = metadata.PLUGIN_NAME - - def __init__(self, context): - """Instantiates ScalarsPlugin via TensorBoard core. - - Args: - context: A base_plugin.TBContext instance. - """ - self._multiplexer = context.multiplexer - self._db_connection_provider = context.db_connection_provider - if context.flags and context.flags.generic_data == 'true': - self._data_provider = context.data_provider - else: - self._data_provider = None - - def get_plugin_apps(self): - return { - '/scalars': self.scalars_route, - '/tags': self.tags_route, - } - - def is_active(self): - """The scalars plugin is active iff any run has at least one scalar tag.""" - if self._data_provider: - # We don't have an experiment ID, and modifying the backend core - # to provide one would break backward compatibility. Hack for now. - return True - - if self._db_connection_provider: - # The plugin is active if one relevant tag can be found in the database. - db = self._db_connection_provider() - cursor = db.execute(''' + """Scalars Plugin for TensorBoard.""" + + plugin_name = metadata.PLUGIN_NAME + + def __init__(self, context): + """Instantiates ScalarsPlugin via TensorBoard core. 
+ + Args: + context: A base_plugin.TBContext instance. + """ + self._multiplexer = context.multiplexer + self._db_connection_provider = context.db_connection_provider + if context.flags and context.flags.generic_data == "true": + self._data_provider = context.data_provider + else: + self._data_provider = None + + def get_plugin_apps(self): + return { + "/scalars": self.scalars_route, + "/tags": self.tags_route, + } + + def is_active(self): + """The scalars plugin is active iff any run has at least one scalar + tag.""" + if self._data_provider: + # We don't have an experiment ID, and modifying the backend core + # to provide one would break backward compatibility. Hack for now. + return True + + if self._db_connection_provider: + # The plugin is active if one relevant tag can be found in the database. + db = self._db_connection_provider() + cursor = db.execute( + """ SELECT 1 FROM Tags WHERE Tags.plugin_name = ? LIMIT 1 - ''', (metadata.PLUGIN_NAME,)) - return bool(list(cursor)) - - if not self._multiplexer: - return False - - return bool(self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME)) - - def frontend_metadata(self): - return base_plugin.FrontendMetadata(element_name='tf-scalar-dashboard') - - def index_impl(self, experiment=None): - """Return {runName: {tagName: {displayName: ..., description: ...}}}.""" - if self._data_provider: - mapping = self._data_provider.list_scalars( - experiment_id=experiment, - plugin_name=metadata.PLUGIN_NAME, - ) - result = {run: {} for run in mapping} - for (run, tag_to_content) in six.iteritems(mapping): - for (tag, metadatum) in six.iteritems(tag_to_content): - description = plugin_util.markdown_to_safe_html(metadatum.description) - result[run][tag] = { - 'displayName': metadatum.display_name, - 'description': description, - } - return result - - if self._db_connection_provider: - # Read tags from the database. 
- db = self._db_connection_provider() - cursor = db.execute(''' + """, + (metadata.PLUGIN_NAME,), + ) + return bool(list(cursor)) + + if not self._multiplexer: + return False + + return bool( + self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) + ) + + def frontend_metadata(self): + return base_plugin.FrontendMetadata(element_name="tf-scalar-dashboard") + + def index_impl(self, experiment=None): + """Return {runName: {tagName: {displayName: ..., description: + ...}}}.""" + if self._data_provider: + mapping = self._data_provider.list_scalars( + experiment_id=experiment, plugin_name=metadata.PLUGIN_NAME, + ) + result = {run: {} for run in mapping} + for (run, tag_to_content) in six.iteritems(mapping): + for (tag, metadatum) in six.iteritems(tag_to_content): + description = plugin_util.markdown_to_safe_html( + metadatum.description + ) + result[run][tag] = { + "displayName": metadatum.display_name, + "description": description, + } + return result + + if self._db_connection_provider: + # Read tags from the database. + db = self._db_connection_provider() + cursor = db.execute( + """ SELECT Tags.tag_name, Tags.display_name, @@ -127,54 +137,62 @@ def index_impl(self, experiment=None): ON Tags.run_id = Runs.run_id WHERE Tags.plugin_name = ? - ''', (metadata.PLUGIN_NAME,)) - result = collections.defaultdict(dict) - for row in cursor: - tag_name, display_name, run_name = row - result[run_name][tag_name] = { - 'displayName': display_name, - # TODO(chihuahua): Populate the description. Currently, the tags - # table does not link with the description table. 
- 'description': '', - } - return result - - result = collections.defaultdict(lambda: {}) - mapping = self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) - for (run, tag_to_content) in six.iteritems(mapping): - for (tag, content) in six.iteritems(tag_to_content): - content = metadata.parse_plugin_metadata(content) - summary_metadata = self._multiplexer.SummaryMetadata(run, tag) - result[run][tag] = {'displayName': summary_metadata.display_name, - 'description': plugin_util.markdown_to_safe_html( - summary_metadata.summary_description)} - - return result - - def scalars_impl(self, tag, run, experiment, output_format): - """Result of the form `(body, mime_type)`.""" - if self._data_provider: - # Downsample reads to 1000 scalars per time series, which is the - # default size guidance for scalars under the multiplexer loading - # logic. - SAMPLE_COUNT = 1000 - all_scalars = self._data_provider.read_scalars( - experiment_id=experiment, - plugin_name=metadata.PLUGIN_NAME, - downsample=SAMPLE_COUNT, - run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]), - ) - scalars = all_scalars.get(run, {}).get(tag, None) - if scalars is None: - raise errors.NotFoundError( - 'No scalar data for run=%r, tag=%r' % (run, tag) + """, + (metadata.PLUGIN_NAME,), + ) + result = collections.defaultdict(dict) + for row in cursor: + tag_name, display_name, run_name = row + result[run_name][tag_name] = { + "displayName": display_name, + # TODO(chihuahua): Populate the description. Currently, the tags + # table does not link with the description table. + "description": "", + } + return result + + result = collections.defaultdict(lambda: {}) + mapping = self._multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME ) - values = [(x.wall_time, x.step, x.value) for x in scalars] - elif self._db_connection_provider: - db = self._db_connection_provider() - # We select for steps greater than -1 because the writer inserts - # placeholder rows en masse. 
The check for step filters out those rows. - cursor = db.execute(''' + for (run, tag_to_content) in six.iteritems(mapping): + for (tag, content) in six.iteritems(tag_to_content): + content = metadata.parse_plugin_metadata(content) + summary_metadata = self._multiplexer.SummaryMetadata(run, tag) + result[run][tag] = { + "displayName": summary_metadata.display_name, + "description": plugin_util.markdown_to_safe_html( + summary_metadata.summary_description + ), + } + + return result + + def scalars_impl(self, tag, run, experiment, output_format): + """Result of the form `(body, mime_type)`.""" + if self._data_provider: + # Downsample reads to 1000 scalars per time series, which is the + # default size guidance for scalars under the multiplexer loading + # logic. + SAMPLE_COUNT = 1000 + all_scalars = self._data_provider.read_scalars( + experiment_id=experiment, + plugin_name=metadata.PLUGIN_NAME, + downsample=SAMPLE_COUNT, + run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]), + ) + scalars = all_scalars.get(run, {}).get(tag, None) + if scalars is None: + raise errors.NotFoundError( + "No scalar data for run=%r, tag=%r" % (run, tag) + ) + values = [(x.wall_time, x.step, x.value) for x in scalars] + elif self._db_connection_provider: + db = self._db_connection_provider() + # We select for steps greater than -1 because the writer inserts + # placeholder rows en masse. The check for step filters out those rows. 
+ cursor = db.execute( + """ SELECT Tensors.step, Tensors.computed_time, @@ -195,56 +213,73 @@ def scalars_impl(self, tag, run, experiment, output_format): AND Tensors.shape = '' AND Tensors.step > -1 ORDER BY Tensors.step - ''', dict(exp=experiment, run=run, tag=tag, plugin=metadata.PLUGIN_NAME)) - values = [(wall_time, step, self._get_value(data, dtype_enum)) - for (step, wall_time, data, dtype_enum) in cursor] - else: - try: - tensor_events = self._multiplexer.Tensors(run, tag) - except KeyError: - raise errors.NotFoundError( - 'No scalar data for run=%r, tag=%r' % (run, tag) + """, + dict( + exp=experiment, + run=run, + tag=tag, + plugin=metadata.PLUGIN_NAME, + ), + ) + values = [ + (wall_time, step, self._get_value(data, dtype_enum)) + for (step, wall_time, data, dtype_enum) in cursor + ] + else: + try: + tensor_events = self._multiplexer.Tensors(run, tag) + except KeyError: + raise errors.NotFoundError( + "No scalar data for run=%r, tag=%r" % (run, tag) + ) + values = [ + ( + tensor_event.wall_time, + tensor_event.step, + tensor_util.make_ndarray(tensor_event.tensor_proto).item(), + ) + for tensor_event in tensor_events + ] + + if output_format == OutputFormat.CSV: + string_io = StringIO() + writer = csv.writer(string_io) + writer.writerow(["Wall time", "Step", "Value"]) + writer.writerows(values) + return (string_io.getvalue(), "text/csv") + else: + return (values, "application/json") + + def _get_value(self, scalar_data_blob, dtype_enum): + """Obtains value for scalar event given blob and dtype enum. + + Args: + scalar_data_blob: The blob obtained from the database. + dtype_enum: The enum representing the dtype. + + Returns: + The scalar value. 
+ """ + tensorflow_dtype = tf.DType(dtype_enum) + buf = np.frombuffer( + scalar_data_blob, dtype=tensorflow_dtype.as_numpy_dtype + ) + return np.asscalar(buf) + + @wrappers.Request.application + def tags_route(self, request): + experiment = plugin_util.experiment_id(request.environ) + index = self.index_impl(experiment=experiment) + return http_util.Respond(request, index, "application/json") + + @wrappers.Request.application + def scalars_route(self, request): + """Given a tag and single run, return array of ScalarEvents.""" + tag = request.args.get("tag") + run = request.args.get("run") + experiment = plugin_util.experiment_id(request.environ) + output_format = request.args.get("format") + (body, mime_type) = self.scalars_impl( + tag, run, experiment, output_format ) - values = [(tensor_event.wall_time, - tensor_event.step, - tensor_util.make_ndarray(tensor_event.tensor_proto).item()) - for tensor_event in tensor_events] - - if output_format == OutputFormat.CSV: - string_io = StringIO() - writer = csv.writer(string_io) - writer.writerow(['Wall time', 'Step', 'Value']) - writer.writerows(values) - return (string_io.getvalue(), 'text/csv') - else: - return (values, 'application/json') - - def _get_value(self, scalar_data_blob, dtype_enum): - """Obtains value for scalar event given blob and dtype enum. - - Args: - scalar_data_blob: The blob obtained from the database. - dtype_enum: The enum representing the dtype. - - Returns: - The scalar value. 
- """ - tensorflow_dtype = tf.DType(dtype_enum) - buf = np.frombuffer(scalar_data_blob, dtype=tensorflow_dtype.as_numpy_dtype) - return np.asscalar(buf) - - @wrappers.Request.application - def tags_route(self, request): - experiment = plugin_util.experiment_id(request.environ) - index = self.index_impl(experiment=experiment) - return http_util.Respond(request, index, 'application/json') - - @wrappers.Request.application - def scalars_route(self, request): - """Given a tag and single run, return array of ScalarEvents.""" - tag = request.args.get('tag') - run = request.args.get('run') - experiment = plugin_util.experiment_id(request.environ) - output_format = request.args.get('format') - (body, mime_type) = self.scalars_impl(tag, run, experiment, output_format) - return http_util.Respond(request, body, mime_type) + return http_util.Respond(request, body, mime_type) diff --git a/tensorboard/plugins/scalar/scalars_plugin_test.py b/tensorboard/plugins/scalar/scalars_plugin_test.py index 60840c2b11..30bd7f1574 100644 --- a/tensorboard/plugins/scalar/scalars_plugin_test.py +++ b/tensorboard/plugins/scalar/scalars_plugin_test.py @@ -32,8 +32,12 @@ from tensorboard import errors from tensorboard.backend import application from tensorboard.backend.event_processing import data_provider -from tensorboard.backend.event_processing import plugin_event_accumulator as event_accumulator -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_accumulator as event_accumulator, +) +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.core import core_plugin from tensorboard.plugins.scalar import metadata @@ -47,255 +51,304 @@ class ScalarsPluginTest(tf.test.TestCase): - _STEPS = 9 - - _LEGACY_SCALAR_TAG = 'ancient-values' - _SCALAR_TAG = 'simple-values' - 
_HISTOGRAM_TAG = 'complicated-values' - - _DISPLAY_NAME = 'Walrus population' - _DESCRIPTION = 'the *most* valuable statistic' - _HTML_DESCRIPTION = '

the most valuable statistic

' - - _RUN_WITH_LEGACY_SCALARS = '_RUN_WITH_LEGACY_SCALARS' - _RUN_WITH_SCALARS = '_RUN_WITH_SCALARS' - _RUN_WITH_HISTOGRAM = '_RUN_WITH_HISTOGRAM' - - def __init__(self, *args, **kwargs): - super(ScalarsPluginTest, self).__init__(*args, **kwargs) - self.plugin = None # used by DB tests only - - def load_runs(self, run_names): - logdir = self.get_temp_dir() - for run_name in run_names: - self.generate_run(logdir, run_name) - multiplexer = event_multiplexer.EventMultiplexer(size_guidance={ - # don't truncate my test data, please - event_accumulator.TENSORS: self._STEPS, - }) - multiplexer.AddRunsFromDirectory(logdir) - multiplexer.Reload() - return (logdir, multiplexer) - - def set_up_db(self): - self.db_path = os.path.join(self.get_temp_dir(), 'db.db') - self.db_uri = 'sqlite:' + self.db_path - db_connection_provider = application.create_sqlite_connection_provider( - self.db_uri) - context = base_plugin.TBContext( - db_connection_provider=db_connection_provider, - db_uri=self.db_uri) - self.core_plugin = core_plugin.CorePlugin(context) - self.plugin = scalars_plugin.ScalarsPlugin(context) - - def generate_run_to_db(self, experiment_name, run_name): - # This method uses `tf.contrib.summary`, and so must only be invoked - # when TensorFlow 1.x is installed. 
- tf.compat.v1.reset_default_graph() - with tf.compat.v1.Graph().as_default(): - global_step = tf.compat.v1.placeholder(tf.int64) - db_writer = tf.contrib.summary.create_db_writer( - db_uri=self.db_path, - experiment_name=experiment_name, - run_name=run_name, - user_name='user') - with db_writer.as_default(), tf.contrib.summary.always_record_summaries(): - tf.contrib.summary.scalar(self._SCALAR_TAG, 42, step=global_step) - flush_op = tf.contrib.summary.flush(db_writer._resource) - with tf.compat.v1.Session() as sess: - sess.run(tf.contrib.summary.summary_writer_initializer_op()) - summaries = tf.contrib.summary.all_summary_ops() - for step in xrange(self._STEPS): - feed_dict = {global_step: step} - sess.run(summaries, feed_dict=feed_dict) - sess.run(flush_op) - - def with_runs(run_names): - """Run a test with a bare multiplexer and with a `data_provider`. - - The decorated function will receive an initialized `ScalarsPlugin` - object as its first positional argument. - """ - def decorator(fn): - @functools.wraps(fn) - def wrapper(self, *args, **kwargs): - (logdir, multiplexer) = self.load_runs(run_names) - with self.subTest('bare multiplexer'): - ctx = base_plugin.TBContext(logdir=logdir, multiplexer=multiplexer) - fn(self, scalars_plugin.ScalarsPlugin(ctx), *args, **kwargs) - with self.subTest('generic data provider'): - flags = argparse.Namespace(generic_data='true') - provider = data_provider.MultiplexerDataProvider(multiplexer, logdir) - ctx = base_plugin.TBContext( - flags=flags, - logdir=logdir, - multiplexer=multiplexer, - data_provider=provider, - ) - fn(self, scalars_plugin.ScalarsPlugin(ctx), *args, **kwargs) - return wrapper - return decorator - - def generate_run(self, logdir, run_name): - subdir = os.path.join(logdir, run_name) - with test_util.FileWriterCache.get(subdir) as writer: - for step in xrange(self._STEPS): - data = [1 + step, 2 + step, 3 + step] - if run_name == self._RUN_WITH_LEGACY_SCALARS: - summ = tf.compat.v1.summary.scalar( - 
self._LEGACY_SCALAR_TAG, tf.reduce_mean(data), - ).numpy() - elif run_name == self._RUN_WITH_SCALARS: - summ = summary.op( - self._SCALAR_TAG, - tf.reduce_sum(data), - display_name=self._DISPLAY_NAME, - description=self._DESCRIPTION, - ).numpy() - elif run_name == self._RUN_WITH_HISTOGRAM: - summ = tf.compat.v1.summary.histogram( - self._HISTOGRAM_TAG, data - ).numpy() - else: - assert False, 'Invalid run name: %r' % run_name - writer.add_summary(summ, global_step=step) - - @with_runs([]) - def testRoutesProvided(self, plugin): - """Tests that the plugin offers the correct routes.""" - routes = plugin.get_plugin_apps() - self.assertIsInstance(routes['/scalars'], collections.Callable) - self.assertIsInstance(routes['/tags'], collections.Callable) - - @with_runs([_RUN_WITH_LEGACY_SCALARS, _RUN_WITH_SCALARS, _RUN_WITH_HISTOGRAM]) - def test_index(self, plugin): - self.assertEqual({ - self._RUN_WITH_LEGACY_SCALARS: { - self._LEGACY_SCALAR_TAG: { - 'displayName': self._LEGACY_SCALAR_TAG, - 'description': '', - }, - }, - self._RUN_WITH_SCALARS: { - '%s/scalar_summary' % self._SCALAR_TAG: { - 'displayName': self._DISPLAY_NAME, - 'description': self._HTML_DESCRIPTION, + _STEPS = 9 + + _LEGACY_SCALAR_TAG = "ancient-values" + _SCALAR_TAG = "simple-values" + _HISTOGRAM_TAG = "complicated-values" + + _DISPLAY_NAME = "Walrus population" + _DESCRIPTION = "the *most* valuable statistic" + _HTML_DESCRIPTION = "

the most valuable statistic

" + + _RUN_WITH_LEGACY_SCALARS = "_RUN_WITH_LEGACY_SCALARS" + _RUN_WITH_SCALARS = "_RUN_WITH_SCALARS" + _RUN_WITH_HISTOGRAM = "_RUN_WITH_HISTOGRAM" + + def __init__(self, *args, **kwargs): + super(ScalarsPluginTest, self).__init__(*args, **kwargs) + self.plugin = None # used by DB tests only + + def load_runs(self, run_names): + logdir = self.get_temp_dir() + for run_name in run_names: + self.generate_run(logdir, run_name) + multiplexer = event_multiplexer.EventMultiplexer( + size_guidance={ + # don't truncate my test data, please + event_accumulator.TENSORS: self._STEPS, + } + ) + multiplexer.AddRunsFromDirectory(logdir) + multiplexer.Reload() + return (logdir, multiplexer) + + def set_up_db(self): + self.db_path = os.path.join(self.get_temp_dir(), "db.db") + self.db_uri = "sqlite:" + self.db_path + db_connection_provider = application.create_sqlite_connection_provider( + self.db_uri + ) + context = base_plugin.TBContext( + db_connection_provider=db_connection_provider, db_uri=self.db_uri + ) + self.core_plugin = core_plugin.CorePlugin(context) + self.plugin = scalars_plugin.ScalarsPlugin(context) + + def generate_run_to_db(self, experiment_name, run_name): + # This method uses `tf.contrib.summary`, and so must only be invoked + # when TensorFlow 1.x is installed. 
+ tf.compat.v1.reset_default_graph() + with tf.compat.v1.Graph().as_default(): + global_step = tf.compat.v1.placeholder(tf.int64) + db_writer = tf.contrib.summary.create_db_writer( + db_uri=self.db_path, + experiment_name=experiment_name, + run_name=run_name, + user_name="user", + ) + with db_writer.as_default(), tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar( + self._SCALAR_TAG, 42, step=global_step + ) + flush_op = tf.contrib.summary.flush(db_writer._resource) + with tf.compat.v1.Session() as sess: + sess.run(tf.contrib.summary.summary_writer_initializer_op()) + summaries = tf.contrib.summary.all_summary_ops() + for step in xrange(self._STEPS): + feed_dict = {global_step: step} + sess.run(summaries, feed_dict=feed_dict) + sess.run(flush_op) + + def with_runs(run_names): + """Run a test with a bare multiplexer and with a `data_provider`. + + The decorated function will receive an initialized + `ScalarsPlugin` object as its first positional argument. + """ + + def decorator(fn): + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + (logdir, multiplexer) = self.load_runs(run_names) + with self.subTest("bare multiplexer"): + ctx = base_plugin.TBContext( + logdir=logdir, multiplexer=multiplexer + ) + fn(self, scalars_plugin.ScalarsPlugin(ctx), *args, **kwargs) + with self.subTest("generic data provider"): + flags = argparse.Namespace(generic_data="true") + provider = data_provider.MultiplexerDataProvider( + multiplexer, logdir + ) + ctx = base_plugin.TBContext( + flags=flags, + logdir=logdir, + multiplexer=multiplexer, + data_provider=provider, + ) + fn(self, scalars_plugin.ScalarsPlugin(ctx), *args, **kwargs) + + return wrapper + + return decorator + + def generate_run(self, logdir, run_name): + subdir = os.path.join(logdir, run_name) + with test_util.FileWriterCache.get(subdir) as writer: + for step in xrange(self._STEPS): + data = [1 + step, 2 + step, 3 + step] + if run_name == self._RUN_WITH_LEGACY_SCALARS: + summ = 
tf.compat.v1.summary.scalar( + self._LEGACY_SCALAR_TAG, tf.reduce_mean(data), + ).numpy() + elif run_name == self._RUN_WITH_SCALARS: + summ = summary.op( + self._SCALAR_TAG, + tf.reduce_sum(data), + display_name=self._DISPLAY_NAME, + description=self._DESCRIPTION, + ).numpy() + elif run_name == self._RUN_WITH_HISTOGRAM: + summ = tf.compat.v1.summary.histogram( + self._HISTOGRAM_TAG, data + ).numpy() + else: + assert False, "Invalid run name: %r" % run_name + writer.add_summary(summ, global_step=step) + + @with_runs([]) + def testRoutesProvided(self, plugin): + """Tests that the plugin offers the correct routes.""" + routes = plugin.get_plugin_apps() + self.assertIsInstance(routes["/scalars"], collections.Callable) + self.assertIsInstance(routes["/tags"], collections.Callable) + + @with_runs( + [_RUN_WITH_LEGACY_SCALARS, _RUN_WITH_SCALARS, _RUN_WITH_HISTOGRAM] + ) + def test_index(self, plugin): + self.assertEqual( + { + self._RUN_WITH_LEGACY_SCALARS: { + self._LEGACY_SCALAR_TAG: { + "displayName": self._LEGACY_SCALAR_TAG, + "description": "", + }, + }, + self._RUN_WITH_SCALARS: { + "%s/scalar_summary" + % self._SCALAR_TAG: { + "displayName": self._DISPLAY_NAME, + "description": self._HTML_DESCRIPTION, + }, + }, + # _RUN_WITH_HISTOGRAM omitted: No scalar data. }, - }, - # _RUN_WITH_HISTOGRAM omitted: No scalar data. 
- }, plugin.index_impl('eid')) - - @with_runs([_RUN_WITH_LEGACY_SCALARS, _RUN_WITH_SCALARS, _RUN_WITH_HISTOGRAM]) - def _test_scalars_json(self, plugin, run_name, tag_name, should_work=True): - if should_work: - (data, mime_type) = plugin.scalars_impl( - tag_name, run_name, 'eid', scalars_plugin.OutputFormat.JSON) - self.assertEqual('application/json', mime_type) - self.assertEqual(len(data), self._STEPS) - else: - with self.assertRaises(errors.NotFoundError): - plugin.scalars_impl( - self._SCALAR_TAG, run_name, 'eid', scalars_plugin.OutputFormat.JSON + plugin.index_impl("eid"), + ) + + @with_runs( + [_RUN_WITH_LEGACY_SCALARS, _RUN_WITH_SCALARS, _RUN_WITH_HISTOGRAM] + ) + def _test_scalars_json(self, plugin, run_name, tag_name, should_work=True): + if should_work: + (data, mime_type) = plugin.scalars_impl( + tag_name, run_name, "eid", scalars_plugin.OutputFormat.JSON + ) + self.assertEqual("application/json", mime_type) + self.assertEqual(len(data), self._STEPS) + else: + with self.assertRaises(errors.NotFoundError): + plugin.scalars_impl( + self._SCALAR_TAG, + run_name, + "eid", + scalars_plugin.OutputFormat.JSON, + ) + + @with_runs( + [_RUN_WITH_LEGACY_SCALARS, _RUN_WITH_SCALARS, _RUN_WITH_HISTOGRAM] + ) + def _test_scalars_csv(self, plugin, run_name, tag_name, should_work=True): + if should_work: + (data, mime_type) = plugin.scalars_impl( + tag_name, run_name, "eid", scalars_plugin.OutputFormat.CSV + ) + self.assertEqual("text/csv", mime_type) + s = StringIO(data) + reader = csv.reader(s) + self.assertEqual(["Wall time", "Step", "Value"], next(reader)) + self.assertEqual(len(list(reader)), self._STEPS) + else: + with self.assertRaises(errors.NotFoundError): + plugin.scalars_impl( + self._SCALAR_TAG, + run_name, + "eid", + scalars_plugin.OutputFormat.CSV, + ) + + def test_scalars_json_with_legacy_scalars(self): + self._test_scalars_json( + self._RUN_WITH_LEGACY_SCALARS, self._LEGACY_SCALAR_TAG + ) + + def test_scalars_json_with_scalars(self): + 
self._test_scalars_json( + self._RUN_WITH_SCALARS, "%s/scalar_summary" % self._SCALAR_TAG + ) + + def test_scalars_json_with_histogram(self): + self._test_scalars_json( + self._RUN_WITH_HISTOGRAM, self._HISTOGRAM_TAG, should_work=False ) - @with_runs([_RUN_WITH_LEGACY_SCALARS, _RUN_WITH_SCALARS, _RUN_WITH_HISTOGRAM]) - def _test_scalars_csv(self, plugin, run_name, tag_name, should_work=True): - if should_work: - (data, mime_type) = plugin.scalars_impl( - tag_name, run_name, 'eid', scalars_plugin.OutputFormat.CSV) - self.assertEqual('text/csv', mime_type) - s = StringIO(data) - reader = csv.reader(s) - self.assertEqual(['Wall time', 'Step', 'Value'], next(reader)) - self.assertEqual(len(list(reader)), self._STEPS) - else: - with self.assertRaises(errors.NotFoundError): - plugin.scalars_impl( - self._SCALAR_TAG, run_name, 'eid', scalars_plugin.OutputFormat.CSV + def test_scalars_csv_with_legacy_scalars(self): + self._test_scalars_csv( + self._RUN_WITH_LEGACY_SCALARS, self._LEGACY_SCALAR_TAG ) - def test_scalars_json_with_legacy_scalars(self): - self._test_scalars_json(self._RUN_WITH_LEGACY_SCALARS, - self._LEGACY_SCALAR_TAG) - - def test_scalars_json_with_scalars(self): - self._test_scalars_json(self._RUN_WITH_SCALARS, - '%s/scalar_summary' % self._SCALAR_TAG) - - def test_scalars_json_with_histogram(self): - self._test_scalars_json(self._RUN_WITH_HISTOGRAM, self._HISTOGRAM_TAG, - should_work=False) - - def test_scalars_csv_with_legacy_scalars(self): - self._test_scalars_csv(self._RUN_WITH_LEGACY_SCALARS, - self._LEGACY_SCALAR_TAG) - - def test_scalars_csv_with_scalars(self): - self._test_scalars_csv(self._RUN_WITH_SCALARS, - '%s/scalar_summary' % self._SCALAR_TAG) - - def test_scalars_csv_with_histogram(self): - self._test_scalars_csv(self._RUN_WITH_HISTOGRAM, self._HISTOGRAM_TAG, - should_work=False) - - @with_runs([_RUN_WITH_LEGACY_SCALARS]) - def test_active_with_legacy_scalars(self, plugin): - self.assertTrue(plugin.is_active()) - - 
@with_runs([_RUN_WITH_SCALARS]) - def test_active_with_scalars(self, plugin): - self.assertTrue(plugin.is_active()) - - @with_runs([_RUN_WITH_HISTOGRAM]) - def test_active_with_histogram(self, plugin): - if plugin._data_provider: - # Hack, for now. - self.assertTrue(plugin.is_active()) - else: - self.assertFalse(plugin.is_active()) - - @with_runs([_RUN_WITH_LEGACY_SCALARS, _RUN_WITH_SCALARS, _RUN_WITH_HISTOGRAM]) - def test_active_with_all(self, plugin): - self.assertTrue(plugin.is_active()) - - @test_util.run_v1_only('Requires contrib for db writer') - def test_scalars_db_without_exp(self): - self.set_up_db() - self.generate_run_to_db('exp1', self._RUN_WITH_SCALARS) - - (data, mime_type) = self.plugin.scalars_impl( - self._SCALAR_TAG, self._RUN_WITH_SCALARS, 'eid', - scalars_plugin.OutputFormat.JSON) - self.assertEqual('application/json', mime_type) - # When querying DB-based backend without an experiment id, it returns all - # scalars without an experiment id. Such scalar can only be generated using - # raw SQL queries though. - self.assertEqual(len(data), 0) - - @test_util.run_v1_only('Requires contrib for db writer') - def test_scalars_db_filter_by_experiment(self): - self.set_up_db() - self.generate_run_to_db('exp1', self._RUN_WITH_SCALARS) - all_exps = self.core_plugin.list_experiments_impl() - exp1 = next((x for x in all_exps if x.get('name') == 'exp1'), {}) - - (data, mime_type) = self.plugin.scalars_impl( - self._SCALAR_TAG, self._RUN_WITH_SCALARS, exp1.get('id'), - scalars_plugin.OutputFormat.JSON) - self.assertEqual('application/json', mime_type) - self.assertEqual(len(data), self._STEPS) - - @test_util.run_v1_only('Requires contrib for db writer') - def test_scalars_db_no_match(self): - self.set_up_db() - self.generate_run_to_db('exp1', self._RUN_WITH_SCALARS) - - # experiment_id is a number but we passed a string here. 
- (data, mime_type) = self.plugin.scalars_impl( - self._SCALAR_TAG, self._RUN_WITH_SCALARS, 'random_exp_id', - scalars_plugin.OutputFormat.JSON) - self.assertEqual('application/json', mime_type) - self.assertEqual(len(data), 0) - -if __name__ == '__main__': - tf.test.main() + def test_scalars_csv_with_scalars(self): + self._test_scalars_csv( + self._RUN_WITH_SCALARS, "%s/scalar_summary" % self._SCALAR_TAG + ) + + def test_scalars_csv_with_histogram(self): + self._test_scalars_csv( + self._RUN_WITH_HISTOGRAM, self._HISTOGRAM_TAG, should_work=False + ) + + @with_runs([_RUN_WITH_LEGACY_SCALARS]) + def test_active_with_legacy_scalars(self, plugin): + self.assertTrue(plugin.is_active()) + + @with_runs([_RUN_WITH_SCALARS]) + def test_active_with_scalars(self, plugin): + self.assertTrue(plugin.is_active()) + + @with_runs([_RUN_WITH_HISTOGRAM]) + def test_active_with_histogram(self, plugin): + if plugin._data_provider: + # Hack, for now. + self.assertTrue(plugin.is_active()) + else: + self.assertFalse(plugin.is_active()) + + @with_runs( + [_RUN_WITH_LEGACY_SCALARS, _RUN_WITH_SCALARS, _RUN_WITH_HISTOGRAM] + ) + def test_active_with_all(self, plugin): + self.assertTrue(plugin.is_active()) + + @test_util.run_v1_only("Requires contrib for db writer") + def test_scalars_db_without_exp(self): + self.set_up_db() + self.generate_run_to_db("exp1", self._RUN_WITH_SCALARS) + + (data, mime_type) = self.plugin.scalars_impl( + self._SCALAR_TAG, + self._RUN_WITH_SCALARS, + "eid", + scalars_plugin.OutputFormat.JSON, + ) + self.assertEqual("application/json", mime_type) + # When querying DB-based backend without an experiment id, it returns all + # scalars without an experiment id. Such scalar can only be generated using + # raw SQL queries though. 
+ self.assertEqual(len(data), 0) + + @test_util.run_v1_only("Requires contrib for db writer") + def test_scalars_db_filter_by_experiment(self): + self.set_up_db() + self.generate_run_to_db("exp1", self._RUN_WITH_SCALARS) + all_exps = self.core_plugin.list_experiments_impl() + exp1 = next((x for x in all_exps if x.get("name") == "exp1"), {}) + + (data, mime_type) = self.plugin.scalars_impl( + self._SCALAR_TAG, + self._RUN_WITH_SCALARS, + exp1.get("id"), + scalars_plugin.OutputFormat.JSON, + ) + self.assertEqual("application/json", mime_type) + self.assertEqual(len(data), self._STEPS) + + @test_util.run_v1_only("Requires contrib for db writer") + def test_scalars_db_no_match(self): + self.set_up_db() + self.generate_run_to_db("exp1", self._RUN_WITH_SCALARS) + + # experiment_id is a number but we passed a string here. + (data, mime_type) = self.plugin.scalars_impl( + self._SCALAR_TAG, + self._RUN_WITH_SCALARS, + "random_exp_id", + scalars_plugin.OutputFormat.JSON, + ) + self.assertEqual("application/json", mime_type) + self.assertEqual(len(data), 0) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/scalar/summary.py b/tensorboard/plugins/scalar/summary.py index e391971896..b1de180302 100644 --- a/tensorboard/plugins/scalar/summary.py +++ b/tensorboard/plugins/scalar/summary.py @@ -14,7 +14,8 @@ # ============================================================================== """Scalar summaries and TensorFlow operations to create them. -A scalar summary stores a single floating-point value, as a rank-0 tensor. +A scalar summary stores a single floating-point value, as a rank-0 +tensor. """ from __future__ import absolute_import @@ -32,78 +33,82 @@ scalar_pb = summary_v2.scalar_pb -def op(name, - data, - display_name=None, - description=None, - collections=None): - """Create a legacy scalar summary op. - - Arguments: - name: A unique name for the generated summary node. - data: A real numeric rank-0 `Tensor`. 
Must have `dtype` castable - to `float32`. - display_name: Optional name for this summary in TensorBoard, as a - constant `str`. Defaults to `name`. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. - collections: Optional list of graph collections keys. The new - summary op is added to these collections. Defaults to - `[Graph Keys.SUMMARIES]`. - - Returns: - A TensorFlow summary op. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - - if display_name is None: - display_name = name - summary_metadata = metadata.create_summary_metadata( - display_name=display_name, description=description) - with tf.name_scope(name): - with tf.control_dependencies([tf.assert_scalar(data)]): - return tf.summary.tensor_summary(name='scalar_summary', - tensor=tf.cast(data, tf.float32), - collections=collections, - summary_metadata=summary_metadata) +def op(name, data, display_name=None, description=None, collections=None): + """Create a legacy scalar summary op. + + Arguments: + name: A unique name for the generated summary node. + data: A real numeric rank-0 `Tensor`. Must have `dtype` castable + to `float32`. + display_name: Optional name for this summary in TensorBoard, as a + constant `str`. Defaults to `name`. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. + collections: Optional list of graph collections keys. The new + summary op is added to these collections. Defaults to + `[Graph Keys.SUMMARIES]`. + + Returns: + A TensorFlow summary op. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. 
+ import tensorflow.compat.v1 as tf + + if display_name is None: + display_name = name + summary_metadata = metadata.create_summary_metadata( + display_name=display_name, description=description + ) + with tf.name_scope(name): + with tf.control_dependencies([tf.assert_scalar(data)]): + return tf.summary.tensor_summary( + name="scalar_summary", + tensor=tf.cast(data, tf.float32), + collections=collections, + summary_metadata=summary_metadata, + ) def pb(name, data, display_name=None, description=None): - """Create a legacy scalar summary protobuf. - - Arguments: - name: A unique name for the generated summary, including any desired - name scopes. - data: A rank-0 `np.array` or array-like form (so raw `int`s and - `float`s are fine, too). - display_name: Optional name for this summary in TensorBoard, as a - `str`. Defaults to `name`. - description: Optional long-form description for this summary, as a - `str`. Markdown is supported. Defaults to empty. - - Returns: - A `tf.Summary` protobuf object. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - - data = np.array(data) - if data.shape != (): - raise ValueError('Expected scalar shape for data, saw shape: %s.' - % data.shape) - if data.dtype.kind not in ('b', 'i', 'u', 'f'): # bool, int, uint, float - raise ValueError('Cast %s to float is not supported' % data.dtype.name) - tensor = tf.make_tensor_proto(data.astype(np.float32)) - - if display_name is None: - display_name = name - summary_metadata = metadata.create_summary_metadata( - display_name=display_name, description=description) - tf_summary_metadata = tf.SummaryMetadata.FromString( - summary_metadata.SerializeToString()) - summary = tf.Summary() - summary.value.add(tag='%s/scalar_summary' % name, - metadata=tf_summary_metadata, - tensor=tensor) - return summary + """Create a legacy scalar summary protobuf. 
+ + Arguments: + name: A unique name for the generated summary, including any desired + name scopes. + data: A rank-0 `np.array` or array-like form (so raw `int`s and + `float`s are fine, too). + display_name: Optional name for this summary in TensorBoard, as a + `str`. Defaults to `name`. + description: Optional long-form description for this summary, as a + `str`. Markdown is supported. Defaults to empty. + + Returns: + A `tf.Summary` protobuf object. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. + import tensorflow.compat.v1 as tf + + data = np.array(data) + if data.shape != (): + raise ValueError( + "Expected scalar shape for data, saw shape: %s." % data.shape + ) + if data.dtype.kind not in ("b", "i", "u", "f"): # bool, int, uint, float + raise ValueError("Cast %s to float is not supported" % data.dtype.name) + tensor = tf.make_tensor_proto(data.astype(np.float32)) + + if display_name is None: + display_name = name + summary_metadata = metadata.create_summary_metadata( + display_name=display_name, description=description + ) + tf_summary_metadata = tf.SummaryMetadata.FromString( + summary_metadata.SerializeToString() + ) + summary = tf.Summary() + summary.value.add( + tag="%s/scalar_summary" % name, + metadata=tf_summary_metadata, + tensor=tensor, + ) + return summary diff --git a/tensorboard/plugins/scalar/summary_test.py b/tensorboard/plugins/scalar/summary_test.py index 3be3a4190f..754591209e 100644 --- a/tensorboard/plugins/scalar/summary_test.py +++ b/tensorboard/plugins/scalar/summary_test.py @@ -34,171 +34,179 @@ from tensorboard.util import tensor_util try: - tf2.__version__ # Force lazy import to resolve + tf2.__version__ # Force lazy import to resolve except ImportError: - tf2 = None + tf2 = None try: - tf.compat.v1.enable_eager_execution() + tf.compat.v1.enable_eager_execution() except AttributeError: - # TF 2.0 doesn't have this symbol because eager is the default. 
- pass + # TF 2.0 doesn't have this symbol because eager is the default. + pass class SummaryBaseTest(object): - - def scalar(self, *args, **kwargs): - raise NotImplementedError() - - def test_tag(self): - self.assertEqual('a', self.scalar('a', 1).value[0].tag) - self.assertEqual('a/b', self.scalar('a/b', 1).value[0].tag) - - def test_metadata(self): - pb = self.scalar('a', 1.13) - summary_metadata = pb.value[0].metadata - plugin_data = summary_metadata.plugin_data - self.assertEqual(summary_metadata.summary_description, '') - self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME) - content = summary_metadata.plugin_data.content - # There's no content, so successfully parsing is fine. - metadata.parse_plugin_metadata(content) - - def test_explicit_description(self): - description = 'The first letter of the alphabet.' - pb = self.scalar('a', 1.13, description=description) - summary_metadata = pb.value[0].metadata - self.assertEqual(summary_metadata.summary_description, description) - plugin_data = summary_metadata.plugin_data - self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME) - content = summary_metadata.plugin_data.content - # There's no content, so successfully parsing is fine. - metadata.parse_plugin_metadata(content) - - def test_float_value(self): - pb = self.scalar('a', 1.13) - value = tensor_util.make_ndarray(pb.value[0].tensor).item() - self.assertEqual(float, type(value)) - self.assertNear(1.13, value, 1e-6) - - def test_int_value(self): - # ints should be valid, but converted to floats. - pb = self.scalar('a', 113) - value = tensor_util.make_ndarray(pb.value[0].tensor).item() - self.assertEqual(float, type(value)) - self.assertNear(113.0, value, 1e-6) - - def test_bool_value(self): - # bools should be valid, but converted to floats. 
- pb = self.scalar('a', True) - value = tensor_util.make_ndarray(pb.value[0].tensor).item() - self.assertEqual(float, type(value)) - self.assertEqual(1.0, value) - - def test_string_value(self): - # Use str.* in regex because PY3 numpy refers to string arrays using - # length-dependent type names in the format "str%d" % (32 * len(str)). - with six.assertRaisesRegex(self, (ValueError, tf.errors.UnimplementedError), - r'Cast str.*float'): - self.scalar('a', np.array('113')) - - def test_requires_rank_0(self): - with six.assertRaisesRegex(self, ValueError, r'Expected scalar shape'): - self.scalar('a', np.array([1, 1, 3])) + def scalar(self, *args, **kwargs): + raise NotImplementedError() + + def test_tag(self): + self.assertEqual("a", self.scalar("a", 1).value[0].tag) + self.assertEqual("a/b", self.scalar("a/b", 1).value[0].tag) + + def test_metadata(self): + pb = self.scalar("a", 1.13) + summary_metadata = pb.value[0].metadata + plugin_data = summary_metadata.plugin_data + self.assertEqual(summary_metadata.summary_description, "") + self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME) + content = summary_metadata.plugin_data.content + # There's no content, so successfully parsing is fine. + metadata.parse_plugin_metadata(content) + + def test_explicit_description(self): + description = "The first letter of the alphabet." + pb = self.scalar("a", 1.13, description=description) + summary_metadata = pb.value[0].metadata + self.assertEqual(summary_metadata.summary_description, description) + plugin_data = summary_metadata.plugin_data + self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME) + content = summary_metadata.plugin_data.content + # There's no content, so successfully parsing is fine. 
+ metadata.parse_plugin_metadata(content) + + def test_float_value(self): + pb = self.scalar("a", 1.13) + value = tensor_util.make_ndarray(pb.value[0].tensor).item() + self.assertEqual(float, type(value)) + self.assertNear(1.13, value, 1e-6) + + def test_int_value(self): + # ints should be valid, but converted to floats. + pb = self.scalar("a", 113) + value = tensor_util.make_ndarray(pb.value[0].tensor).item() + self.assertEqual(float, type(value)) + self.assertNear(113.0, value, 1e-6) + + def test_bool_value(self): + # bools should be valid, but converted to floats. + pb = self.scalar("a", True) + value = tensor_util.make_ndarray(pb.value[0].tensor).item() + self.assertEqual(float, type(value)) + self.assertEqual(1.0, value) + + def test_string_value(self): + # Use str.* in regex because PY3 numpy refers to string arrays using + # length-dependent type names in the format "str%d" % (32 * len(str)). + with six.assertRaisesRegex( + self, (ValueError, tf.errors.UnimplementedError), r"Cast str.*float" + ): + self.scalar("a", np.array("113")) + + def test_requires_rank_0(self): + with six.assertRaisesRegex(self, ValueError, r"Expected scalar shape"): + self.scalar("a", np.array([1, 1, 3])) class SummaryV1PbTest(SummaryBaseTest, tf.test.TestCase): - def scalar(self, *args, **kwargs): - return summary.pb(*args, **kwargs) + def scalar(self, *args, **kwargs): + return summary.pb(*args, **kwargs) - def test_tag(self): - self.assertEqual('a/scalar_summary', self.scalar('a', 1).value[0].tag) - self.assertEqual('a/b/scalar_summary', self.scalar('a/b', 1).value[0].tag) + def test_tag(self): + self.assertEqual("a/scalar_summary", self.scalar("a", 1).value[0].tag) + self.assertEqual( + "a/b/scalar_summary", self.scalar("a/b", 1).value[0].tag + ) class SummaryV1OpTest(SummaryBaseTest, tf.test.TestCase): - def scalar(self, *args, **kwargs): - return summary_pb2.Summary.FromString(summary.op(*args, **kwargs).numpy()) + def scalar(self, *args, **kwargs): + return 
summary_pb2.Summary.FromString( + summary.op(*args, **kwargs).numpy() + ) - def test_tag(self): - self.assertEqual('a/scalar_summary', self.scalar('a', 1).value[0].tag) - self.assertEqual('a/b/scalar_summary', self.scalar('a/b', 1).value[0].tag) + def test_tag(self): + self.assertEqual("a/scalar_summary", self.scalar("a", 1).value[0].tag) + self.assertEqual( + "a/b/scalar_summary", self.scalar("a/b", 1).value[0].tag + ) - def test_scoped_tag(self): - with tf.name_scope('scope'): - self.assertEqual('scope/a/scalar_summary', - self.scalar('a', 1).value[0].tag) + def test_scoped_tag(self): + with tf.name_scope("scope"): + self.assertEqual( + "scope/a/scalar_summary", self.scalar("a", 1).value[0].tag + ) class SummaryV2PbTest(SummaryBaseTest, tf.test.TestCase): - def scalar(self, *args, **kwargs): - return summary.scalar_pb(*args, **kwargs) + def scalar(self, *args, **kwargs): + return summary.scalar_pb(*args, **kwargs) class SummaryV2OpTest(SummaryBaseTest, tf.test.TestCase): - - def setUp(self): - super(SummaryV2OpTest, self).setUp() - if tf2 is None: - self.skipTest('v2 summary API not available') - - def scalar(self, *args, **kwargs): - return self.scalar_event(*args, **kwargs).summary - - def scalar_event(self, *args, **kwargs): - self.write_scalar_event(*args, **kwargs) - event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), '*'))) - self.assertEqual(len(event_files), 1) - events = list(tf.compat.v1.train.summary_iterator(event_files[0])) - # Expect a boilerplate event for the file_version, then the summary one. - self.assertEqual(len(events), 2) - # Delete the event file to reset to an empty directory for later calls. - # TODO(nickfelt): use a unique subdirectory per writer instead. 
- os.remove(event_files[0]) - return events[1] - - def write_scalar_event(self, *args, **kwargs): - kwargs.setdefault('step', 1) - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - summary.scalar(*args, **kwargs) - writer.close() - - def test_scoped_tag(self): - with tf.name_scope('scope'): - self.assertEqual('scope/a', self.scalar('a', 1).value[0].tag) - - def test_step(self): - event = self.scalar_event('a', 1.0, step=333) - self.assertEqual(333, event.step) - - def test_default_step(self): - try: - tf2.summary.experimental.set_step(333) - # TODO(nickfelt): change test logic so we can just omit `step` entirely. - event = self.scalar_event('a', 1.0, step=None) - self.assertEqual(333, event.step) - finally: - # Reset to default state for other tests. - tf2.summary.experimental.set_step(None) + def setUp(self): + super(SummaryV2OpTest, self).setUp() + if tf2 is None: + self.skipTest("v2 summary API not available") + + def scalar(self, *args, **kwargs): + return self.scalar_event(*args, **kwargs).summary + + def scalar_event(self, *args, **kwargs): + self.write_scalar_event(*args, **kwargs) + event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), "*"))) + self.assertEqual(len(event_files), 1) + events = list(tf.compat.v1.train.summary_iterator(event_files[0])) + # Expect a boilerplate event for the file_version, then the summary one. + self.assertEqual(len(events), 2) + # Delete the event file to reset to an empty directory for later calls. + # TODO(nickfelt): use a unique subdirectory per writer instead. 
+ os.remove(event_files[0]) + return events[1] + + def write_scalar_event(self, *args, **kwargs): + kwargs.setdefault("step", 1) + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + summary.scalar(*args, **kwargs) + writer.close() + + def test_scoped_tag(self): + with tf.name_scope("scope"): + self.assertEqual("scope/a", self.scalar("a", 1).value[0].tag) + + def test_step(self): + event = self.scalar_event("a", 1.0, step=333) + self.assertEqual(333, event.step) + + def test_default_step(self): + try: + tf2.summary.experimental.set_step(333) + # TODO(nickfelt): change test logic so we can just omit `step` entirely. + event = self.scalar_event("a", 1.0, step=None) + self.assertEqual(333, event.step) + finally: + # Reset to default state for other tests. + tf2.summary.experimental.set_step(None) class SummaryV2OpGraphTest(SummaryV2OpTest, tf.test.TestCase): - def write_scalar_event(self, *args, **kwargs): - kwargs.setdefault('step', 1) - # Hack to extract current scope since there's no direct API for it. - with tf.name_scope('_') as temp_scope: - scope = temp_scope.rstrip('/_') - @tf2.function - def graph_fn(): - # Recreate the active scope inside the defun since it won't propagate. - with tf.name_scope(scope): - summary.scalar(*args, **kwargs) - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - graph_fn() - writer.close() - - -if __name__ == '__main__': - tf.test.main() + def write_scalar_event(self, *args, **kwargs): + kwargs.setdefault("step", 1) + # Hack to extract current scope since there's no direct API for it. + with tf.name_scope("_") as temp_scope: + scope = temp_scope.rstrip("/_") + + @tf2.function + def graph_fn(): + # Recreate the active scope inside the defun since it won't propagate. 
+ with tf.name_scope(scope): + summary.scalar(*args, **kwargs) + + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + graph_fn() + writer.close() + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/scalar/summary_v2.py b/tensorboard/plugins/scalar/summary_v2.py index be5fc3b214..17aae4543a 100644 --- a/tensorboard/plugins/scalar/summary_v2.py +++ b/tensorboard/plugins/scalar/summary_v2.py @@ -14,7 +14,8 @@ # ============================================================================== """Scalar summaries and TensorFlow operations to create them, V2 versions. -A scalar summary stores a single floating-point value, as a rank-0 tensor. +A scalar summary stores a single floating-point value, as a rank-0 +tensor. """ from __future__ import absolute_import @@ -30,67 +31,70 @@ def scalar(name, data, step=None, description=None): - """Write a scalar summary. + """Write a scalar summary. - Arguments: - name: A name for this summary. The summary tag used for TensorBoard will - be this name prefixed by any active name scopes. - data: A real numeric scalar value, convertible to a `float32` Tensor. - step: Explicit `int64`-castable monotonic step value for this summary. If - omitted, this defaults to `tf.summary.experimental.get_step()`, which must - not be None. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will + be this name prefixed by any active name scopes. + data: A real numeric scalar value, convertible to a `float32` Tensor. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. 
- Returns: - True on success, or false if no summary was written because no default - summary writer was available. + Returns: + True on success, or false if no summary was written because no default + summary writer was available. - Raises: - ValueError: if a default writer exists, but no step was provided and - `tf.summary.experimental.get_step()` is None. - """ - summary_metadata = metadata.create_summary_metadata( - display_name=None, description=description) - # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback - summary_scope = ( - getattr(tf.summary.experimental, 'summary_scope', None) or - tf.summary.summary_scope) - with summary_scope( - name, 'scalar_summary', values=[data, step]) as (tag, _): - tf.debugging.assert_scalar(data) - return tf.summary.write(tag=tag, - tensor=tf.cast(data, tf.float32), - step=step, - metadata=summary_metadata) + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. + """ + summary_metadata = metadata.create_summary_metadata( + display_name=None, description=description + ) + # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback + summary_scope = ( + getattr(tf.summary.experimental, "summary_scope", None) + or tf.summary.summary_scope + ) + with summary_scope(name, "scalar_summary", values=[data, step]) as (tag, _): + tf.debugging.assert_scalar(data) + return tf.summary.write( + tag=tag, + tensor=tf.cast(data, tf.float32), + step=step, + metadata=summary_metadata, + ) def scalar_pb(tag, data, description=None): - """Create a scalar summary_pb2.Summary protobuf. + """Create a scalar summary_pb2.Summary protobuf. - Arguments: - tag: String tag for the summary. - data: A 0-dimensional `np.array` or a compatible python number type. - description: Optional long-form description for this summary, as a - `str`. Markdown is supported. Defaults to empty. + Arguments: + tag: String tag for the summary. 
+ data: A 0-dimensional `np.array` or a compatible python number type. + description: Optional long-form description for this summary, as a + `str`. Markdown is supported. Defaults to empty. - Raises: - ValueError: If the type or shape of the data is unsupported. + Raises: + ValueError: If the type or shape of the data is unsupported. - Returns: - A `summary_pb2.Summary` protobuf object. - """ - arr = np.array(data) - if arr.shape != (): - raise ValueError('Expected scalar shape for tensor, got shape: %s.' - % arr.shape) - if arr.dtype.kind not in ('b', 'i', 'u', 'f'): # bool, int, uint, float - raise ValueError('Cast %s to float is not supported' % arr.dtype.name) - tensor_proto = tensor_util.make_tensor_proto(arr.astype(np.float32)) - summary_metadata = metadata.create_summary_metadata( - display_name=None, description=description) - summary = summary_pb2.Summary() - summary.value.add(tag=tag, - metadata=summary_metadata, - tensor=tensor_proto) - return summary + Returns: + A `summary_pb2.Summary` protobuf object. + """ + arr = np.array(data) + if arr.shape != (): + raise ValueError( + "Expected scalar shape for tensor, got shape: %s." 
% arr.shape + ) + if arr.dtype.kind not in ("b", "i", "u", "f"): # bool, int, uint, float + raise ValueError("Cast %s to float is not supported" % arr.dtype.name) + tensor_proto = tensor_util.make_tensor_proto(arr.astype(np.float32)) + summary_metadata = metadata.create_summary_metadata( + display_name=None, description=description + ) + summary = summary_pb2.Summary() + summary.value.add(tag=tag, metadata=summary_metadata, tensor=tensor_proto) + return summary diff --git a/tensorboard/plugins/text/metadata.py b/tensorboard/plugins/text/metadata.py index 7d7225890d..4236fb07e2 100644 --- a/tensorboard/plugins/text/metadata.py +++ b/tensorboard/plugins/text/metadata.py @@ -24,7 +24,7 @@ logger = tb_logging.get_logger() -PLUGIN_NAME = 'text' +PLUGIN_NAME = "text" # The most recent value for the `version` field of the # `TextPluginData` proto. @@ -32,36 +32,42 @@ def create_summary_metadata(display_name, description): - """Create a `summary_pb2.SummaryMetadata` proto for text plugin data. - Returns: - A `summary_pb2.SummaryMetadata` protobuf object. - """ - content = plugin_data_pb2.TextPluginData(version=PROTO_VERSION) - metadata = summary_pb2.SummaryMetadata( - display_name=display_name, - summary_description=description, - plugin_data=summary_pb2.SummaryMetadata.PluginData( - plugin_name=PLUGIN_NAME, - content=content.SerializeToString())) - return metadata + """Create a `summary_pb2.SummaryMetadata` proto for text plugin data. + + Returns: + A `summary_pb2.SummaryMetadata` protobuf object. + """ + content = plugin_data_pb2.TextPluginData(version=PROTO_VERSION) + metadata = summary_pb2.SummaryMetadata( + display_name=display_name, + summary_description=description, + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, content=content.SerializeToString() + ), + ) + return metadata def parse_plugin_metadata(content): - """Parse summary metadata to a Python object. 
- Arguments: - content: The `content` field of a `SummaryMetadata` proto corresponding to - the text plugin. - Returns: - A `TextPluginData` protobuf object. - """ - if not isinstance(content, bytes): - raise TypeError('Content type must be bytes') - result = plugin_data_pb2.TextPluginData.FromString(content) - if result.version == 0: - return result - else: - logger.warn( - 'Unknown metadata version: %s. The latest version known to ' - 'this build of TensorBoard is %s; perhaps a newer build is ' - 'available?', result.version, PROTO_VERSION) - return result + """Parse summary metadata to a Python object. + + Arguments: + content: The `content` field of a `SummaryMetadata` proto corresponding to + the text plugin. + Returns: + A `TextPluginData` protobuf object. + """ + if not isinstance(content, bytes): + raise TypeError("Content type must be bytes") + result = plugin_data_pb2.TextPluginData.FromString(content) + if result.version == 0: + return result + else: + logger.warn( + "Unknown metadata version: %s. The latest version known to " + "this build of TensorBoard is %s; perhaps a newer build is " + "available?", + result.version, + PROTO_VERSION, + ) + return result diff --git a/tensorboard/plugins/text/summary.py b/tensorboard/plugins/text/summary.py index 4248c8fcb4..76f78be157 100644 --- a/tensorboard/plugins/text/summary.py +++ b/tensorboard/plugins/text/summary.py @@ -27,90 +27,93 @@ text_pb = summary_v2.text_pb -def op(name, - data, - display_name=None, - description=None, - collections=None): - """Create a legacy text summary op. - - Text data summarized via this plugin will be visible in the Text Dashboard - in TensorBoard. The standard TensorBoard Text Dashboard will render markdown - in the strings, and will automatically organize 1D and 2D tensors into tables. - If a tensor with more than 2 dimensions is provided, a 2D subarray will be - displayed along with a warning message. 
(Note that this behavior is not - intrinsic to the text summary API, but rather to the default TensorBoard text - plugin.) - - Args: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - data: A string-type Tensor to summarize. The text must be encoded in UTF-8. - display_name: Optional name for this summary in TensorBoard, as a - constant `str`. Defaults to `name`. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. - collections: Optional list of ops.GraphKeys. The collections to which to add - the summary. Defaults to [Graph Keys.SUMMARIES]. - - Returns: - A TensorSummary op that is configured so that TensorBoard will recognize - that it contains textual data. The TensorSummary is a scalar `Tensor` of - type `string` which contains `Summary` protobufs. - - Raises: - ValueError: If tensor has the wrong type. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - - if display_name is None: - display_name = name - summary_metadata = metadata.create_summary_metadata( - display_name=display_name, description=description) - with tf.name_scope(name): - with tf.control_dependencies([tf.assert_type(data, tf.string)]): - return tf.summary.tensor_summary(name='text_summary', - tensor=data, - collections=collections, - summary_metadata=summary_metadata) +def op(name, data, display_name=None, description=None, collections=None): + """Create a legacy text summary op. + + Text data summarized via this plugin will be visible in the Text Dashboard + in TensorBoard. The standard TensorBoard Text Dashboard will render markdown + in the strings, and will automatically organize 1D and 2D tensors into tables. + If a tensor with more than 2 dimensions is provided, a 2D subarray will be + displayed along with a warning message. 
(Note that this behavior is not + intrinsic to the text summary API, but rather to the default TensorBoard text + plugin.) + + Args: + name: A name for the generated node. Will also serve as a series name in + TensorBoard. + data: A string-type Tensor to summarize. The text must be encoded in UTF-8. + display_name: Optional name for this summary in TensorBoard, as a + constant `str`. Defaults to `name`. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. + collections: Optional list of ops.GraphKeys. The collections to which to add + the summary. Defaults to [Graph Keys.SUMMARIES]. + + Returns: + A TensorSummary op that is configured so that TensorBoard will recognize + that it contains textual data. The TensorSummary is a scalar `Tensor` of + type `string` which contains `Summary` protobufs. + + Raises: + ValueError: If tensor has the wrong type. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. + import tensorflow.compat.v1 as tf + + if display_name is None: + display_name = name + summary_metadata = metadata.create_summary_metadata( + display_name=display_name, description=description + ) + with tf.name_scope(name): + with tf.control_dependencies([tf.assert_type(data, tf.string)]): + return tf.summary.tensor_summary( + name="text_summary", + tensor=data, + collections=collections, + summary_metadata=summary_metadata, + ) def pb(name, data, display_name=None, description=None): - """Create a legacy text summary protobuf. - - Arguments: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - data: A Python bytestring (of type bytes), or Unicode string. Or a numpy - data array of those types. - display_name: Optional name for this summary in TensorBoard, as a - `str`. Defaults to `name`. - description: Optional long-form description for this summary, as a - `str`. Markdown is supported. Defaults to empty. 
- - Raises: - ValueError: If the type of the data is unsupported. - - Returns: - A `tf.Summary` protobuf object. - """ - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - - try: - tensor = tf.make_tensor_proto(data, dtype=tf.string) - except TypeError as e: - raise ValueError(e) - - if display_name is None: - display_name = name - summary_metadata = metadata.create_summary_metadata( - display_name=display_name, description=description) - tf_summary_metadata = tf.SummaryMetadata.FromString( - summary_metadata.SerializeToString()) - summary = tf.Summary() - summary.value.add(tag='%s/text_summary' % name, - metadata=tf_summary_metadata, - tensor=tensor) - return summary + """Create a legacy text summary protobuf. + + Arguments: + name: A name for the generated node. Will also serve as a series name in + TensorBoard. + data: A Python bytestring (of type bytes), or Unicode string. Or a numpy + data array of those types. + display_name: Optional name for this summary in TensorBoard, as a + `str`. Defaults to `name`. + description: Optional long-form description for this summary, as a + `str`. Markdown is supported. Defaults to empty. + + Raises: + ValueError: If the type of the data is unsupported. + + Returns: + A `tf.Summary` protobuf object. + """ + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. 
+ import tensorflow.compat.v1 as tf + + try: + tensor = tf.make_tensor_proto(data, dtype=tf.string) + except TypeError as e: + raise ValueError(e) + + if display_name is None: + display_name = name + summary_metadata = metadata.create_summary_metadata( + display_name=display_name, description=description + ) + tf_summary_metadata = tf.SummaryMetadata.FromString( + summary_metadata.SerializeToString() + ) + summary = tf.Summary() + summary.value.add( + tag="%s/text_summary" % name, + metadata=tf_summary_metadata, + tensor=tensor, + ) + return summary diff --git a/tensorboard/plugins/text/summary_test.py b/tensorboard/plugins/text/summary_test.py index 2de1b572df..ee9d5d4eb4 100644 --- a/tensorboard/plugins/text/summary_test.py +++ b/tensorboard/plugins/text/summary_test.py @@ -34,188 +34,203 @@ from tensorboard.util import tensor_util try: - tf2.__version__ # Force lazy import to resolve + tf2.__version__ # Force lazy import to resolve except ImportError: - tf2 = None + tf2 = None try: - tf.compat.v1.enable_eager_execution() + tf.compat.v1.enable_eager_execution() except AttributeError: - # TF 2.0 doesn't have this symbol because eager is the default. - pass + # TF 2.0 doesn't have this symbol because eager is the default. + pass class SummaryBaseTest(object): - - def text(self, *args, **kwargs): - raise NotImplementedError() - - def test_tag(self): - self.assertEqual('a', self.text('a', 'foo').value[0].tag) - self.assertEqual('a/b', self.text('a/b', 'foo').value[0].tag) - - def test_metadata(self): - pb = self.text('do', 'A deer. A female deer.') - summary_metadata = pb.value[0].metadata - plugin_data = summary_metadata.plugin_data - self.assertEqual(summary_metadata.summary_description, '') - self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME) - content = summary_metadata.plugin_data.content - # There's no content, so successfully parsing is fine. 
- metadata.parse_plugin_metadata(content) - - def test_explicit_description(self): - description = 'A whole step above do.' - pb = self.text('re', 'A drop of golden sun.', description=description) - summary_metadata = pb.value[0].metadata - self.assertEqual(summary_metadata.summary_description, description) - plugin_data = summary_metadata.plugin_data - self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME) - content = summary_metadata.plugin_data.content - # There's no content, so successfully parsing is fine. - metadata.parse_plugin_metadata(content) - - def test_bytes_value(self): - pb = self.text('mi', b'A name\xe2\x80\xa6I call myself') - value = tensor_util.make_ndarray(pb.value[0].tensor).item() - self.assertIsInstance(value, six.binary_type) - self.assertEqual(b'A name\xe2\x80\xa6I call myself', value) - - def test_unicode_value(self): - pb = self.text('mi', u'A name\u2026I call myself') - value = tensor_util.make_ndarray(pb.value[0].tensor).item() - self.assertIsInstance(value, six.binary_type) - self.assertEqual(b'A name\xe2\x80\xa6I call myself', value) - - def test_np_array_bytes_value(self): - pb = self.text( - 'fa', - np.array( - [[b'A', b'long', b'long'], [b'way', b'to', b'run \xe2\x80\xbc']])) - values = tensor_util.make_ndarray(pb.value[0].tensor).tolist() - self.assertEqual( - [[b'A', b'long', b'long'], [b'way', b'to', b'run \xe2\x80\xbc']], - values) - # Check that all entries are byte strings. - for vectors in values: - for value in vectors: + def text(self, *args, **kwargs): + raise NotImplementedError() + + def test_tag(self): + self.assertEqual("a", self.text("a", "foo").value[0].tag) + self.assertEqual("a/b", self.text("a/b", "foo").value[0].tag) + + def test_metadata(self): + pb = self.text("do", "A deer. 
A female deer.") + summary_metadata = pb.value[0].metadata + plugin_data = summary_metadata.plugin_data + self.assertEqual(summary_metadata.summary_description, "") + self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME) + content = summary_metadata.plugin_data.content + # There's no content, so successfully parsing is fine. + metadata.parse_plugin_metadata(content) + + def test_explicit_description(self): + description = "A whole step above do." + pb = self.text("re", "A drop of golden sun.", description=description) + summary_metadata = pb.value[0].metadata + self.assertEqual(summary_metadata.summary_description, description) + plugin_data = summary_metadata.plugin_data + self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME) + content = summary_metadata.plugin_data.content + # There's no content, so successfully parsing is fine. + metadata.parse_plugin_metadata(content) + + def test_bytes_value(self): + pb = self.text("mi", b"A name\xe2\x80\xa6I call myself") + value = tensor_util.make_ndarray(pb.value[0].tensor).item() self.assertIsInstance(value, six.binary_type) + self.assertEqual(b"A name\xe2\x80\xa6I call myself", value) - def test_np_array_unicode_value(self): - pb = self.text( - 'fa', - np.array( - [[u'A', u'long', u'long'], [u'way', u'to', u'run \u203C']])) - values = tensor_util.make_ndarray (pb.value[0].tensor).tolist() - self.assertEqual( - [[b'A', b'long', b'long'], [b'way', b'to', b'run \xe2\x80\xbc']], - values) - # Check that all entries are byte strings. 
- for vectors in values: - for value in vectors: + def test_unicode_value(self): + pb = self.text("mi", u"A name\u2026I call myself") + value = tensor_util.make_ndarray(pb.value[0].tensor).item() self.assertIsInstance(value, six.binary_type) - - def test_non_string_value(self): - with six.assertRaisesRegex(self, TypeError, r'must be of type.*string'): - self.text('la', np.array(range(42))) + self.assertEqual(b"A name\xe2\x80\xa6I call myself", value) + + def test_np_array_bytes_value(self): + pb = self.text( + "fa", + np.array( + [[b"A", b"long", b"long"], [b"way", b"to", b"run \xe2\x80\xbc"]] + ), + ) + values = tensor_util.make_ndarray(pb.value[0].tensor).tolist() + self.assertEqual( + [[b"A", b"long", b"long"], [b"way", b"to", b"run \xe2\x80\xbc"]], + values, + ) + # Check that all entries are byte strings. + for vectors in values: + for value in vectors: + self.assertIsInstance(value, six.binary_type) + + def test_np_array_unicode_value(self): + pb = self.text( + "fa", + np.array( + [[u"A", u"long", u"long"], [u"way", u"to", u"run \u203C"]] + ), + ) + values = tensor_util.make_ndarray(pb.value[0].tensor).tolist() + self.assertEqual( + [[b"A", b"long", b"long"], [b"way", b"to", b"run \xe2\x80\xbc"]], + values, + ) + # Check that all entries are byte strings. 
+ for vectors in values: + for value in vectors: + self.assertIsInstance(value, six.binary_type) + + def test_non_string_value(self): + with six.assertRaisesRegex(self, TypeError, r"must be of type.*string"): + self.text("la", np.array(range(42))) class SummaryV1PbTest(SummaryBaseTest, tf.test.TestCase): - def text(self, *args, **kwargs): - return summary.pb(*args, **kwargs) + def text(self, *args, **kwargs): + return summary.pb(*args, **kwargs) - def test_tag(self): - self.assertEqual('a/text_summary', self.text('a', 'foo').value[0].tag) - self.assertEqual('a/b/text_summary', self.text('a/b', 'foo').value[0].tag) + def test_tag(self): + self.assertEqual("a/text_summary", self.text("a", "foo").value[0].tag) + self.assertEqual( + "a/b/text_summary", self.text("a/b", "foo").value[0].tag + ) - def test_non_string_value(self): - with six.assertRaisesRegex(self, ValueError, - r'Expected binary or unicode string, got 0'): - self.text('la', np.array(range(42))) + def test_non_string_value(self): + with six.assertRaisesRegex( + self, ValueError, r"Expected binary or unicode string, got 0" + ): + self.text("la", np.array(range(42))) class SummaryV1OpTest(SummaryBaseTest, tf.test.TestCase): - def text(self, *args, **kwargs): - return summary_pb2.Summary.FromString(summary.op(*args, **kwargs).numpy()) + def text(self, *args, **kwargs): + return summary_pb2.Summary.FromString( + summary.op(*args, **kwargs).numpy() + ) - def test_tag(self): - self.assertEqual('a/text_summary', self.text('a', 'foo').value[0].tag) - self.assertEqual('a/b/text_summary', self.text('a/b', 'foo').value[0].tag) + def test_tag(self): + self.assertEqual("a/text_summary", self.text("a", "foo").value[0].tag) + self.assertEqual( + "a/b/text_summary", self.text("a/b", "foo").value[0].tag + ) - def test_scoped_tag(self): - with tf.name_scope('scope'): - self.assertEqual('scope/a/text_summary', - self.text('a', 'foo').value[0].tag) + def test_scoped_tag(self): + with tf.name_scope("scope"): + 
self.assertEqual( + "scope/a/text_summary", self.text("a", "foo").value[0].tag + ) class SummaryV2PbTest(SummaryBaseTest, tf.test.TestCase): - def text(self, *args, **kwargs): - return summary.text_pb(*args, **kwargs) + def text(self, *args, **kwargs): + return summary.text_pb(*args, **kwargs) class SummaryV2OpTest(SummaryBaseTest, tf.test.TestCase): - def setUp(self): - super(SummaryV2OpTest, self).setUp() - if tf2 is None: - self.skipTest('TF v2 summary API not available') - - def text(self, *args, **kwargs): - return self.text_event(*args, **kwargs).summary - - def text_event(self, *args, **kwargs): - self.write_text_event(*args, **kwargs) - event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), '*'))) - self.assertEqual(len(event_files), 1) - events = list(tf.compat.v1.train.summary_iterator(event_files[0])) - # Expect a boilerplate event for the file_version, then the summary one. - self.assertEqual(len(events), 2) - # Delete the event file to reset to an empty directory for later calls. - # TODO(nickfelt): use a unique subdirectory per writer instead. - os.remove(event_files[0]) - return events[1] - - def write_text_event(self, *args, **kwargs): - kwargs.setdefault('step', 1) - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - summary.text(*args, **kwargs) - writer.close() - - def test_scoped_tag(self): - with tf.name_scope('scope'): - self.assertEqual('scope/a', self.text('a', 'foo').value[0].tag) - - def test_step(self): - event = self.text_event('a', 'foo', step=333) - self.assertEqual(333, event.step) - - def test_default_step(self): - try: - tf2.summary.experimental.set_step(333) - # TODO(nickfelt): change test logic so we can just omit `step` entirely. - event = self.text_event('a', 'foo', step=None) - self.assertEqual(333, event.step) - finally: - # Reset to default state for other tests. 
- tf2.summary.experimental.set_step(None) + def setUp(self): + super(SummaryV2OpTest, self).setUp() + if tf2 is None: + self.skipTest("TF v2 summary API not available") + + def text(self, *args, **kwargs): + return self.text_event(*args, **kwargs).summary + + def text_event(self, *args, **kwargs): + self.write_text_event(*args, **kwargs) + event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), "*"))) + self.assertEqual(len(event_files), 1) + events = list(tf.compat.v1.train.summary_iterator(event_files[0])) + # Expect a boilerplate event for the file_version, then the summary one. + self.assertEqual(len(events), 2) + # Delete the event file to reset to an empty directory for later calls. + # TODO(nickfelt): use a unique subdirectory per writer instead. + os.remove(event_files[0]) + return events[1] + + def write_text_event(self, *args, **kwargs): + kwargs.setdefault("step", 1) + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + summary.text(*args, **kwargs) + writer.close() + + def test_scoped_tag(self): + with tf.name_scope("scope"): + self.assertEqual("scope/a", self.text("a", "foo").value[0].tag) + + def test_step(self): + event = self.text_event("a", "foo", step=333) + self.assertEqual(333, event.step) + + def test_default_step(self): + try: + tf2.summary.experimental.set_step(333) + # TODO(nickfelt): change test logic so we can just omit `step` entirely. + event = self.text_event("a", "foo", step=None) + self.assertEqual(333, event.step) + finally: + # Reset to default state for other tests. + tf2.summary.experimental.set_step(None) class SummaryV2OpGraphTest(SummaryV2OpTest, tf.test.TestCase): - def write_text_event(self, *args, **kwargs): - kwargs.setdefault('step', 1) - # Hack to extract current scope since there's no direct API for it. 
- with tf.name_scope('_') as temp_scope: - scope = temp_scope.rstrip('/_') - @tf2.function - def graph_fn(): - # Recreate the active scope inside the defun since it won't propagate. - with tf.name_scope(scope): - summary.text(*args, **kwargs) - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - graph_fn() - writer.close() - - -if __name__ == '__main__': - tf.test.main() + def write_text_event(self, *args, **kwargs): + kwargs.setdefault("step", 1) + # Hack to extract current scope since there's no direct API for it. + with tf.name_scope("_") as temp_scope: + scope = temp_scope.rstrip("/_") + + @tf2.function + def graph_fn(): + # Recreate the active scope inside the defun since it won't propagate. + with tf.name_scope(scope): + summary.text(*args, **kwargs) + + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + graph_fn() + writer.close() + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/plugins/text/summary_v2.py b/tensorboard/plugins/text/summary_v2.py index a5b076196c..5d14c84b24 100644 --- a/tensorboard/plugins/text/summary_v2.py +++ b/tensorboard/plugins/text/summary_v2.py @@ -27,63 +27,64 @@ def text(name, data, step=None, description=None): - """Write a text summary. + """Write a text summary. - Arguments: - name: A name for this summary. The summary tag used for TensorBoard will - be this name prefixed by any active name scopes. - data: A UTF-8 string tensor value. - step: Explicit `int64`-castable monotonic step value for this summary. If - omitted, this defaults to `tf.summary.experimental.get_step()`, which must - not be None. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will + be this name prefixed by any active name scopes. + data: A UTF-8 string tensor value. 
+ step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. - Returns: - True on success, or false if no summary was emitted because no default - summary writer was available. + Returns: + True on success, or false if no summary was emitted because no default + summary writer was available. - Raises: - ValueError: if a default writer exists, but no step was provided and - `tf.summary.experimental.get_step()` is None. - """ - summary_metadata = metadata.create_summary_metadata( - display_name=None, description=description) - # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback - summary_scope = ( - getattr(tf.summary.experimental, 'summary_scope', None) or - tf.summary.summary_scope) - with summary_scope( - name, 'text_summary', values=[data, step]) as (tag, _): - tf.debugging.assert_type(data, tf.string) - return tf.summary.write( - tag=tag, tensor=data, step=step, metadata=summary_metadata) + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. + """ + summary_metadata = metadata.create_summary_metadata( + display_name=None, description=description + ) + # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback + summary_scope = ( + getattr(tf.summary.experimental, "summary_scope", None) + or tf.summary.summary_scope + ) + with summary_scope(name, "text_summary", values=[data, step]) as (tag, _): + tf.debugging.assert_type(data, tf.string) + return tf.summary.write( + tag=tag, tensor=data, step=step, metadata=summary_metadata + ) def text_pb(tag, data, description=None): - """Create a text tf.Summary protobuf. + """Create a text tf.Summary protobuf. - Arguments: - tag: String tag for the summary. 
- data: A Python bytestring (of type bytes), a Unicode string, or a numpy data - array of those types. - description: Optional long-form description for this summary, as a `str`. - Markdown is supported. Defaults to empty. + Arguments: + tag: String tag for the summary. + data: A Python bytestring (of type bytes), a Unicode string, or a numpy data + array of those types. + description: Optional long-form description for this summary, as a `str`. + Markdown is supported. Defaults to empty. - Raises: - TypeError: If the type of the data is unsupported. + Raises: + TypeError: If the type of the data is unsupported. - Returns: - A `tf.Summary` protobuf object. - """ - try: - tensor = tensor_util.make_tensor_proto(data, dtype=np.object) - except TypeError as e: - raise TypeError('tensor must be of type string', e) - summary_metadata = metadata.create_summary_metadata( - display_name=None, description=description) - summary = summary_pb2.Summary() - summary.value.add(tag=tag, - metadata=summary_metadata, - tensor=tensor) - return summary + Returns: + A `tf.Summary` protobuf object. + """ + try: + tensor = tensor_util.make_tensor_proto(data, dtype=np.object) + except TypeError as e: + raise TypeError("tensor must be of type string", e) + summary_metadata = metadata.create_summary_metadata( + display_name=None, description=description + ) + summary = summary_pb2.Summary() + summary.value.add(tag=tag, metadata=summary_metadata, tensor=tensor) + return summary diff --git a/tensorboard/plugins/text/text_demo.py b/tensorboard/plugins/text/text_demo.py index c2e47f1641..eea016c7eb 100644 --- a/tensorboard/plugins/text/text_demo.py +++ b/tensorboard/plugins/text/text_demo.py @@ -28,108 +28,108 @@ logger = tb_logging.get_logger() # Directory into which to write tensorboard data. -LOGDIR = '/tmp/text_demo' +LOGDIR = "/tmp/text_demo" # Number of steps for which to write data. STEPS = 16 def simple_example(step): - # Text summaries log arbitrary text. 
This can be encoded with ASCII or - # UTF-8. Here's a simple example, wherein we greet the user on each - # step: - step_string = tf.as_string(step) - greeting = tf.strings.join(['Hello from step ', step_string, '!']) - tf.compat.v1.summary.text('greeting', greeting) + # Text summaries log arbitrary text. This can be encoded with ASCII or + # UTF-8. Here's a simple example, wherein we greet the user on each + # step: + step_string = tf.as_string(step) + greeting = tf.strings.join(["Hello from step ", step_string, "!"]) + tf.compat.v1.summary.text("greeting", greeting) def markdown_table(step): - # The text summary can also contain Markdown, including Markdown - # tables. Markdown tables look like this: - # - # | hello | there | - # |-------|-------| - # | this | is | - # | a | table | - # - # The leading and trailing pipes in each row are optional, and the text - # doesn't actually have to be neatly aligned, so we can create these - # pretty easily. Let's do so. - header_row = 'Pounds of chocolate | Happiness' - chocolate = tf.range(step) - happiness = tf.square(chocolate + 1) - chocolate_column = tf.as_string(chocolate) - happiness_column = tf.as_string(happiness) - table_rows = tf.strings.join([chocolate_column, " | ", happiness_column]) - table_body = tf.strings.reduce_join(inputs=table_rows, separator='\n') - table = tf.strings.join([header_row, "---|---", table_body], separator='\n') - preamble = 'We conducted an experiment and found the following data:\n\n' - result = tf.strings.join([preamble, table]) - tf.compat.v1.summary.text('chocolate_study', result) + # The text summary can also contain Markdown, including Markdown + # tables. Markdown tables look like this: + # + # | hello | there | + # |-------|-------| + # | this | is | + # | a | table | + # + # The leading and trailing pipes in each row are optional, and the text + # doesn't actually have to be neatly aligned, so we can create these + # pretty easily. Let's do so. 
+ header_row = "Pounds of chocolate | Happiness" + chocolate = tf.range(step) + happiness = tf.square(chocolate + 1) + chocolate_column = tf.as_string(chocolate) + happiness_column = tf.as_string(happiness) + table_rows = tf.strings.join([chocolate_column, " | ", happiness_column]) + table_body = tf.strings.reduce_join(inputs=table_rows, separator="\n") + table = tf.strings.join([header_row, "---|---", table_body], separator="\n") + preamble = "We conducted an experiment and found the following data:\n\n" + result = tf.strings.join([preamble, table]) + tf.compat.v1.summary.text("chocolate_study", result) def higher_order_tensors(step): - # We're not limited to passing scalar tensors to the summary - # operation. If we pass a rank-1 or rank-2 tensor, it'll be visualized - # as a table in TensorBoard. (For higher-ranked tensors, you'll see - # just a 2D slice of the data.) - # - # To demonstrate this, let's create a multiplication table. - - # First, we'll create the table body, a `step`-by-`step` array of - # strings. - numbers = tf.range(step) - numbers_row = tf.expand_dims(numbers, 0) # shape: [1, step] - numbers_column = tf.expand_dims(numbers, 1) # shape: [step, 1] - products = tf.matmul(numbers_column, numbers_row) # shape: [step, step] - table_body = tf.as_string(products) - - # Next, we'll create a header row and column, and a little - # multiplication sign to put in the corner. - bold_numbers = tf.strings.join(['**', tf.as_string(numbers), '**']) - bold_row = tf.expand_dims(bold_numbers, 0) - bold_column = tf.expand_dims(bold_numbers, 1) - corner_cell = tf.constant(u'\u00d7'.encode('utf-8')) # MULTIPLICATION SIGN - - # Now, we have to put the pieces together. Using `axis=0` stacks - # vertically; using `axis=1` juxtaposes horizontally. 
- table_body_and_top_row = tf.concat([bold_row, table_body], axis=0) - table_left_column = tf.concat([[[corner_cell]], bold_column], axis=0) - table_full = tf.concat([table_left_column, table_body_and_top_row], axis=1) - - # The result, `table_full`, is a rank-2 string tensor of shape - # `[step + 1, step + 1]`. We can pass it directly to the summary, and - # we'll get a nicely formatted table in TensorBoard. - tf.compat.v1.summary.text('multiplication_table', table_full) + # We're not limited to passing scalar tensors to the summary + # operation. If we pass a rank-1 or rank-2 tensor, it'll be visualized + # as a table in TensorBoard. (For higher-ranked tensors, you'll see + # just a 2D slice of the data.) + # + # To demonstrate this, let's create a multiplication table. + + # First, we'll create the table body, a `step`-by-`step` array of + # strings. + numbers = tf.range(step) + numbers_row = tf.expand_dims(numbers, 0) # shape: [1, step] + numbers_column = tf.expand_dims(numbers, 1) # shape: [step, 1] + products = tf.matmul(numbers_column, numbers_row) # shape: [step, step] + table_body = tf.as_string(products) + + # Next, we'll create a header row and column, and a little + # multiplication sign to put in the corner. + bold_numbers = tf.strings.join(["**", tf.as_string(numbers), "**"]) + bold_row = tf.expand_dims(bold_numbers, 0) + bold_column = tf.expand_dims(bold_numbers, 1) + corner_cell = tf.constant(u"\u00d7".encode("utf-8")) # MULTIPLICATION SIGN + + # Now, we have to put the pieces together. Using `axis=0` stacks + # vertically; using `axis=1` juxtaposes horizontally. + table_body_and_top_row = tf.concat([bold_row, table_body], axis=0) + table_left_column = tf.concat([[[corner_cell]], bold_column], axis=0) + table_full = tf.concat([table_left_column, table_body_and_top_row], axis=1) + + # The result, `table_full`, is a rank-2 string tensor of shape + # `[step + 1, step + 1]`. 
We can pass it directly to the summary, and + # we'll get a nicely formatted table in TensorBoard. + tf.compat.v1.summary.text("multiplication_table", table_full) def run_all(logdir): - tf.compat.v1.reset_default_graph() - step_placeholder = tf.compat.v1.placeholder(tf.int32) - - with tf.name_scope('simple_example'): - simple_example(step_placeholder) - with tf.name_scope('markdown_table'): - markdown_table(step_placeholder) - with tf.name_scope('higher_order_tensors'): - higher_order_tensors(step_placeholder) - all_summaries = tf.compat.v1.summary.merge_all() - - with tf.compat.v1.Session() as sess: - writer = tf.summary.FileWriter(logdir) - writer.add_graph(sess.graph) - for step in xrange(STEPS): - s = sess.run(all_summaries, feed_dict={step_placeholder: step}) - writer.add_summary(s, global_step=step) - writer.close() + tf.compat.v1.reset_default_graph() + step_placeholder = tf.compat.v1.placeholder(tf.int32) + + with tf.name_scope("simple_example"): + simple_example(step_placeholder) + with tf.name_scope("markdown_table"): + markdown_table(step_placeholder) + with tf.name_scope("higher_order_tensors"): + higher_order_tensors(step_placeholder) + all_summaries = tf.compat.v1.summary.merge_all() + + with tf.compat.v1.Session() as sess: + writer = tf.summary.FileWriter(logdir) + writer.add_graph(sess.graph) + for step in xrange(STEPS): + s = sess.run(all_summaries, feed_dict={step_placeholder: step}) + writer.add_summary(s, global_step=step) + writer.close() def main(unused_argv): - logging.set_verbosity(logging.INFO) - logger.info('Saving output to %s.' % LOGDIR) - run_all(LOGDIR) - logger.info('Done. Output saved to %s.' % LOGDIR) + logging.set_verbosity(logging.INFO) + logger.info("Saving output to %s." % LOGDIR) + run_all(LOGDIR) + logger.info("Done. Output saved to %s." 
% LOGDIR) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/tensorboard/plugins/text/text_plugin.py b/tensorboard/plugins/text/text_plugin.py index b97a5196c9..bdf4b8ad8d 100644 --- a/tensorboard/plugins/text/text_plugin.py +++ b/tensorboard/plugins/text/text_plugin.py @@ -24,6 +24,7 @@ # pylint: disable=g-bad-import-order # Necessary for an internal test with special behavior for numpy. import numpy as np + # pylint: enable=g-bad-import-order import six @@ -36,217 +37,230 @@ from tensorboard.util import tensor_util # HTTP routes -TAGS_ROUTE = '/tags' -TEXT_ROUTE = '/text' +TAGS_ROUTE = "/tags" +TEXT_ROUTE = "/text" -WARNING_TEMPLATE = textwrap.dedent("""\ +WARNING_TEMPLATE = textwrap.dedent( + """\ **Warning:** This text summary contained data of dimensionality %d, but only \ - 2d tables are supported. Showing a 2d slice of the data instead.""") - + 2d tables are supported. Showing a 2d slice of the data instead.""" +) -def make_table_row(contents, tag='td'): - """Given an iterable of string contents, make a table row. - Args: - contents: An iterable yielding strings. - tag: The tag to place contents in. Defaults to 'td', you might want 'th'. +def make_table_row(contents, tag="td"): + """Given an iterable of string contents, make a table row. - Returns: - A string containing the content strings, organized into a table row. + Args: + contents: An iterable yielding strings. + tag: The tag to place contents in. Defaults to 'td', you might want 'th'. - Example: make_table_row(['one', 'two', 'three']) == ''' - - one - two - three - ''' - """ - columns = ('<%s>%s\n' % (tag, s, tag) for s in contents) - return '\n' + ''.join(columns) + '\n' + Returns: + A string containing the content strings, organized into a table row. 
+ + Example: make_table_row(['one', 'two', 'three']) == ''' + + one + two + three + ''' + """ + columns = ("<%s>%s\n" % (tag, s, tag) for s in contents) + return "\n" + "".join(columns) + "\n" def make_table(contents, headers=None): - """Given a numpy ndarray of strings, concatenate them into a html table. - - Args: - contents: A np.ndarray of strings. May be 1d or 2d. In the 1d case, the - table is laid out vertically (i.e. row-major). - headers: A np.ndarray or list of string header names for the table. - - Returns: - A string containing all of the content strings, organized into a table. - - Raises: - ValueError: If contents is not a np.ndarray. - ValueError: If contents is not 1d or 2d. - ValueError: If contents is empty. - ValueError: If headers is present and not a list, tuple, or ndarray. - ValueError: If headers is not 1d. - ValueError: If number of elements in headers does not correspond to number - of columns in contents. - """ - if not isinstance(contents, np.ndarray): - raise ValueError('make_table contents must be a numpy ndarray') - - if contents.ndim not in [1, 2]: - raise ValueError('make_table requires a 1d or 2d numpy array, was %dd' % - contents.ndim) - - if headers: - if isinstance(headers, (list, tuple)): - headers = np.array(headers) - if not isinstance(headers, np.ndarray): - raise ValueError('Could not convert headers %s into np.ndarray' % headers) - if headers.ndim != 1: - raise ValueError('Headers must be 1d, is %dd' % headers.ndim) - expected_n_columns = contents.shape[1] if contents.ndim == 2 else 1 - if headers.shape[0] != expected_n_columns: - raise ValueError('Number of headers %d must match number of columns %d' % - (headers.shape[0], expected_n_columns)) - header = '\n%s\n' % make_table_row(headers, tag='th') - else: - header = '' - - n_rows = contents.shape[0] - if contents.ndim == 1: - # If it's a vector, we need to wrap each element in a new list, otherwise - # we would turn the string itself into a row (see test code) - rows = 
(make_table_row([contents[i]]) for i in range(n_rows)) - else: - rows = (make_table_row(contents[i, :]) for i in range(n_rows)) - - return '\n%s\n%s\n
' % (header, ''.join(rows)) - - -def reduce_to_2d(arr): - """Given a np.npdarray with nDims > 2, reduce it to 2d. - - It does this by selecting the zeroth coordinate for every dimension greater - than two. + """Given a numpy ndarray of strings, concatenate them into a html table. - Args: - arr: a numpy ndarray of dimension at least 2. - - Returns: - A two-dimensional subarray from the input array. - - Raises: - ValueError: If the argument is not a numpy ndarray, or the dimensionality - is too low. - """ - if not isinstance(arr, np.ndarray): - raise ValueError('reduce_to_2d requires a numpy.ndarray') - - ndims = len(arr.shape) - if ndims < 2: - raise ValueError('reduce_to_2d requires an array of dimensionality >=2') - # slice(None) is equivalent to `:`, so we take arr[0,0,...0,:,:] - slices = ([0] * (ndims - 2)) + [slice(None), slice(None)] - return arr[slices] - - -def text_array_to_html(text_arr): - """Take a numpy.ndarray containing strings, and convert it into html. + Args: + contents: A np.ndarray of strings. May be 1d or 2d. In the 1d case, the + table is laid out vertically (i.e. row-major). + headers: A np.ndarray or list of string header names for the table. - If the ndarray contains a single scalar string, that string is converted to - html via our sanitized markdown parser. If it contains an array of strings, - the strings are individually converted to html and then composed into a table - using make_table. If the array contains dimensionality greater than 2, - all but two of the dimensions are removed, and a warning message is prefixed - to the table. + Returns: + A string containing all of the content strings, organized into a table. + + Raises: + ValueError: If contents is not a np.ndarray. + ValueError: If contents is not 1d or 2d. + ValueError: If contents is empty. + ValueError: If headers is present and not a list, tuple, or ndarray. + ValueError: If headers is not 1d. 
+ ValueError: If number of elements in headers does not correspond to number + of columns in contents. + """ + if not isinstance(contents, np.ndarray): + raise ValueError("make_table contents must be a numpy ndarray") + + if contents.ndim not in [1, 2]: + raise ValueError( + "make_table requires a 1d or 2d numpy array, was %dd" + % contents.ndim + ) + + if headers: + if isinstance(headers, (list, tuple)): + headers = np.array(headers) + if not isinstance(headers, np.ndarray): + raise ValueError( + "Could not convert headers %s into np.ndarray" % headers + ) + if headers.ndim != 1: + raise ValueError("Headers must be 1d, is %dd" % headers.ndim) + expected_n_columns = contents.shape[1] if contents.ndim == 2 else 1 + if headers.shape[0] != expected_n_columns: + raise ValueError( + "Number of headers %d must match number of columns %d" + % (headers.shape[0], expected_n_columns) + ) + header = "\n%s\n" % make_table_row(headers, tag="th") + else: + header = "" + + n_rows = contents.shape[0] + if contents.ndim == 1: + # If it's a vector, we need to wrap each element in a new list, otherwise + # we would turn the string itself into a row (see test code) + rows = (make_table_row([contents[i]]) for i in range(n_rows)) + else: + rows = (make_table_row(contents[i, :]) for i in range(n_rows)) + + return "\n%s\n%s\n
" % (header, "".join(rows)) - Args: - text_arr: A numpy.ndarray containing strings. - Returns: - The array converted to html. - """ - if not text_arr.shape: - # It is a scalar. No need to put it in a table, just apply markdown - return plugin_util.markdown_to_safe_html(np.asscalar(text_arr)) - warning = '' - if len(text_arr.shape) > 2: - warning = plugin_util.markdown_to_safe_html(WARNING_TEMPLATE - % len(text_arr.shape)) - text_arr = reduce_to_2d(text_arr) +def reduce_to_2d(arr): + """Given a np.npdarray with nDims > 2, reduce it to 2d. - html_arr = [plugin_util.markdown_to_safe_html(x) - for x in text_arr.reshape(-1)] - html_arr = np.array(html_arr).reshape(text_arr.shape) + It does this by selecting the zeroth coordinate for every dimension greater + than two. - return warning + make_table(html_arr) + Args: + arr: a numpy ndarray of dimension at least 2. + Returns: + A two-dimensional subarray from the input array. -def process_string_tensor_event(event): - """Convert a TensorEvent into a JSON-compatible response.""" - string_arr = tensor_util.make_ndarray(event.tensor_proto) - html = text_array_to_html(string_arr) - return { - 'wall_time': event.wall_time, - 'step': event.step, - 'text': html, - } + Raises: + ValueError: If the argument is not a numpy ndarray, or the dimensionality + is too low. + """ + if not isinstance(arr, np.ndarray): + raise ValueError("reduce_to_2d requires a numpy.ndarray") + ndims = len(arr.shape) + if ndims < 2: + raise ValueError("reduce_to_2d requires an array of dimensionality >=2") + # slice(None) is equivalent to `:`, so we take arr[0,0,...0,:,:] + slices = ([0] * (ndims - 2)) + [slice(None), slice(None)] + return arr[slices] -class TextPlugin(base_plugin.TBPlugin): - """Text Plugin for TensorBoard.""" - plugin_name = metadata.PLUGIN_NAME +def text_array_to_html(text_arr): + """Take a numpy.ndarray containing strings, and convert it into html. - def __init__(self, context): - """Instantiates TextPlugin via TensorBoard core. 
+ If the ndarray contains a single scalar string, that string is converted to + html via our sanitized markdown parser. If it contains an array of strings, + the strings are individually converted to html and then composed into a table + using make_table. If the array contains dimensionality greater than 2, + all but two of the dimensions are removed, and a warning message is prefixed + to the table. Args: - context: A base_plugin.TBContext instance. - """ - self._multiplexer = context.multiplexer - - def is_active(self): - """Determines whether this plugin is active. - - This plugin is only active if TensorBoard sampled any text summaries. + text_arr: A numpy.ndarray containing strings. Returns: - Whether this plugin is active. + The array converted to html. """ - if not self._multiplexer: - return False - return bool(self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME)) + if not text_arr.shape: + # It is a scalar. No need to put it in a table, just apply markdown + return plugin_util.markdown_to_safe_html(np.asscalar(text_arr)) + warning = "" + if len(text_arr.shape) > 2: + warning = plugin_util.markdown_to_safe_html( + WARNING_TEMPLATE % len(text_arr.shape) + ) + text_arr = reduce_to_2d(text_arr) - def frontend_metadata(self): - return base_plugin.FrontendMetadata(element_name='tf-text-dashboard') + html_arr = [ + plugin_util.markdown_to_safe_html(x) for x in text_arr.reshape(-1) + ] + html_arr = np.array(html_arr).reshape(text_arr.shape) + + return warning + make_table(html_arr) - def index_impl(self): - mapping = self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) - return { - run: list(tag_to_content) - for (run, tag_to_content) - in six.iteritems(mapping) - } - @wrappers.Request.application - def tags_route(self, request): - index = self.index_impl() - return http_util.Respond(request, index, 'application/json') - - def text_impl(self, run, tag): - try: - text_events = self._multiplexer.Tensors(run, tag) - except KeyError: - 
text_events = [] - responses = [process_string_tensor_event(ev) for ev in text_events] - return responses - - @wrappers.Request.application - def text_route(self, request): - run = request.args.get('run') - tag = request.args.get('tag') - response = self.text_impl(run, tag) - return http_util.Respond(request, response, 'application/json') - - def get_plugin_apps(self): +def process_string_tensor_event(event): + """Convert a TensorEvent into a JSON-compatible response.""" + string_arr = tensor_util.make_ndarray(event.tensor_proto) + html = text_array_to_html(string_arr) return { - TAGS_ROUTE: self.tags_route, - TEXT_ROUTE: self.text_route, + "wall_time": event.wall_time, + "step": event.step, + "text": html, } + + +class TextPlugin(base_plugin.TBPlugin): + """Text Plugin for TensorBoard.""" + + plugin_name = metadata.PLUGIN_NAME + + def __init__(self, context): + """Instantiates TextPlugin via TensorBoard core. + + Args: + context: A base_plugin.TBContext instance. + """ + self._multiplexer = context.multiplexer + + def is_active(self): + """Determines whether this plugin is active. + + This plugin is only active if TensorBoard sampled any text summaries. + + Returns: + Whether this plugin is active. 
+ """ + if not self._multiplexer: + return False + return bool( + self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) + ) + + def frontend_metadata(self): + return base_plugin.FrontendMetadata(element_name="tf-text-dashboard") + + def index_impl(self): + mapping = self._multiplexer.PluginRunToTagToContent( + metadata.PLUGIN_NAME + ) + return { + run: list(tag_to_content) + for (run, tag_to_content) in six.iteritems(mapping) + } + + @wrappers.Request.application + def tags_route(self, request): + index = self.index_impl() + return http_util.Respond(request, index, "application/json") + + def text_impl(self, run, tag): + try: + text_events = self._multiplexer.Tensors(run, tag) + except KeyError: + text_events = [] + responses = [process_string_tensor_event(ev) for ev in text_events] + return responses + + @wrappers.Request.application + def text_route(self, request): + run = request.args.get("run") + tag = request.args.get("tag") + response = self.text_impl(run, tag) + return http_util.Respond(request, response, "application/json") + + def get_plugin_apps(self): + return { + TAGS_ROUTE: self.tags_route, + TEXT_ROUTE: self.text_route, + } diff --git a/tensorboard/plugins/text/text_plugin_test.py b/tensorboard/plugins/text/text_plugin_test.py index 2638618d74..52c744ca71 100644 --- a/tensorboard/plugins/text/text_plugin_test.py +++ b/tensorboard/plugins/text/text_plugin_test.py @@ -26,85 +26,90 @@ import tensorflow as tf from tensorboard import plugin_util -from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer +from tensorboard.backend.event_processing import ( + plugin_event_multiplexer as event_multiplexer, +) from tensorboard.plugins import base_plugin from tensorboard.plugins.text import text_plugin from tensorboard.util import test_util tf.compat.v1.disable_v2_behavior() -GEMS = ['garnet', 'amethyst', 'pearl', 'steven'] +GEMS = ["garnet", "amethyst", "pearl", "steven"] class TextPluginTest(tf.test.TestCase): - - 
def setUp(self): - self.logdir = self.get_temp_dir() - self.generate_testdata() - multiplexer = event_multiplexer.EventMultiplexer() - multiplexer.AddRunsFromDirectory(self.logdir) - multiplexer.Reload() - context = base_plugin.TBContext(logdir=self.logdir, multiplexer=multiplexer) - self.plugin = text_plugin.TextPlugin(context) - - def testRoutesProvided(self): - routes = self.plugin.get_plugin_apps() - self.assertIsInstance(routes['/tags'], collections.Callable) - self.assertIsInstance(routes['/text'], collections.Callable) - - def generate_testdata(self, include_text=True, logdir=None): - tf.compat.v1.reset_default_graph() - sess = tf.compat.v1.Session() - placeholder = tf.compat.v1.placeholder(tf.string) - summary_tensor = tf.compat.v1.summary.text('message', placeholder) - vector_summary = tf.compat.v1.summary.text('vector', placeholder) - scalar_summary = tf.compat.v1.summary.scalar('twelve', tf.constant(12)) - - run_names = ['fry', 'leela'] - for run_name in run_names: - subdir = os.path.join(logdir or self.logdir, run_name) - with test_util.FileWriterCache.get(subdir) as writer: - writer.add_graph(sess.graph) - - step = 0 - for gem in GEMS: - message = run_name + ' *loves* ' + gem - feed_dict = { - placeholder: message, - } - if include_text: - summ = sess.run(summary_tensor, feed_dict=feed_dict) - writer.add_summary(summ, global_step=step) - step += 1 - - vector_message = ['one', 'two', 'three', 'four'] - if include_text: - summ = sess.run(vector_summary, - feed_dict={placeholder: vector_message}) - writer.add_summary(summ) - - summ = sess.run(scalar_summary, feed_dict={placeholder: []}) - writer.add_summary(summ) - - - def testIndex(self): - index = self.plugin.index_impl() - self.assertItemsEqual(['fry', 'leela'], index.keys()) - self.assertItemsEqual(['message', 'vector'], index['fry']) - self.assertItemsEqual(['message', 'vector'], index['leela']) - - def testText(self): - fry = self.plugin.text_impl('fry', 'message') - leela = 
self.plugin.text_impl('leela', 'message') - self.assertEqual(len(fry), 4) - self.assertEqual(len(leela), 4) - for i in range(4): - self.assertEqual(fry[i]['step'], i) - self.assertEqual(leela[i]['step'], i) - - table = self.plugin.text_impl('fry', 'vector')[0]['text'] - self.assertEqual(table, - textwrap.dedent("""\ + def setUp(self): + self.logdir = self.get_temp_dir() + self.generate_testdata() + multiplexer = event_multiplexer.EventMultiplexer() + multiplexer.AddRunsFromDirectory(self.logdir) + multiplexer.Reload() + context = base_plugin.TBContext( + logdir=self.logdir, multiplexer=multiplexer + ) + self.plugin = text_plugin.TextPlugin(context) + + def testRoutesProvided(self): + routes = self.plugin.get_plugin_apps() + self.assertIsInstance(routes["/tags"], collections.Callable) + self.assertIsInstance(routes["/text"], collections.Callable) + + def generate_testdata(self, include_text=True, logdir=None): + tf.compat.v1.reset_default_graph() + sess = tf.compat.v1.Session() + placeholder = tf.compat.v1.placeholder(tf.string) + summary_tensor = tf.compat.v1.summary.text("message", placeholder) + vector_summary = tf.compat.v1.summary.text("vector", placeholder) + scalar_summary = tf.compat.v1.summary.scalar("twelve", tf.constant(12)) + + run_names = ["fry", "leela"] + for run_name in run_names: + subdir = os.path.join(logdir or self.logdir, run_name) + with test_util.FileWriterCache.get(subdir) as writer: + writer.add_graph(sess.graph) + + step = 0 + for gem in GEMS: + message = run_name + " *loves* " + gem + feed_dict = { + placeholder: message, + } + if include_text: + summ = sess.run(summary_tensor, feed_dict=feed_dict) + writer.add_summary(summ, global_step=step) + step += 1 + + vector_message = ["one", "two", "three", "four"] + if include_text: + summ = sess.run( + vector_summary, feed_dict={placeholder: vector_message} + ) + writer.add_summary(summ) + + summ = sess.run(scalar_summary, feed_dict={placeholder: []}) + writer.add_summary(summ) + + def 
testIndex(self): + index = self.plugin.index_impl() + self.assertItemsEqual(["fry", "leela"], index.keys()) + self.assertItemsEqual(["message", "vector"], index["fry"]) + self.assertItemsEqual(["message", "vector"], index["leela"]) + + def testText(self): + fry = self.plugin.text_impl("fry", "message") + leela = self.plugin.text_impl("leela", "message") + self.assertEqual(len(fry), 4) + self.assertEqual(len(leela), 4) + for i in range(4): + self.assertEqual(fry[i]["step"], i) + self.assertEqual(leela[i]["step"], i) + + table = self.plugin.text_impl("fry", "vector")[0]["text"] + self.assertEqual( + table, + textwrap.dedent( + """\ @@ -120,11 +125,14 @@ def testText(self): -

four

""")) - - def testTableGeneration(self): - array2d = np.array([['one', 'two'], ['three', 'four']]) - expected_table = textwrap.dedent("""\ + """ + ), + ) + + def testTableGeneration(self): + array2d = np.array([["one", "two"], ["three", "four"]]) + expected_table = textwrap.dedent( + """\ @@ -136,10 +144,12 @@ def testTableGeneration(self): -
four
""") - self.assertEqual(text_plugin.make_table(array2d), expected_table) + """ + ) + self.assertEqual(text_plugin.make_table(array2d), expected_table) - expected_table_with_headers = textwrap.dedent("""\ + expected_table_with_headers = textwrap.dedent( + """\ @@ -157,13 +167,17 @@ def testTableGeneration(self): -
four
""") + """ + ) - actual_with_headers = text_plugin.make_table(array2d, headers=['c1', 'c2']) - self.assertEqual(actual_with_headers, expected_table_with_headers) + actual_with_headers = text_plugin.make_table( + array2d, headers=["c1", "c2"] + ) + self.assertEqual(actual_with_headers, expected_table_with_headers) - array_1d = np.array(['one', 'two', 'three', 'four', 'five']) - expected_1d = textwrap.dedent("""\ + array_1d = np.array(["one", "two", "three", "four", "five"]) + expected_1d = textwrap.dedent( + """\ @@ -182,10 +196,12 @@ def testTableGeneration(self): -
five
""") - self.assertEqual(text_plugin.make_table(array_1d), expected_1d) + """ + ) + self.assertEqual(text_plugin.make_table(array_1d), expected_1d) - expected_1d_with_headers = textwrap.dedent("""\ + expected_1d_with_headers = textwrap.dedent( + """\ @@ -209,77 +225,82 @@ def testTableGeneration(self): -
five
""") - actual_1d_with_headers = text_plugin.make_table(array_1d, headers=['X']) - self.assertEqual(actual_1d_with_headers, expected_1d_with_headers) - - def testMakeTableExceptions(self): - # Verify that contents is being type-checked and shape-checked. - with self.assertRaises(ValueError): - text_plugin.make_table([]) - - with self.assertRaises(ValueError): - text_plugin.make_table('foo') - - with self.assertRaises(ValueError): - invalid_shape = np.full((3, 3, 3), 'nope', dtype=np.dtype('S3')) - text_plugin.make_table(invalid_shape) - - # Test headers exceptions in 2d array case. - test_array = np.full((3, 3), 'foo', dtype=np.dtype('S3')) - with self.assertRaises(ValueError): - # Headers is wrong type. - text_plugin.make_table(test_array, headers='foo') - with self.assertRaises(ValueError): - # Too many headers. - text_plugin.make_table(test_array, headers=['foo', 'bar', 'zod', 'zoink']) - with self.assertRaises(ValueError): - # headers is 2d - text_plugin.make_table(test_array, headers=test_array) - - # Also make sure the column counting logic works in the 1d array case. - test_array = np.array(['foo', 'bar', 'zod']) - with self.assertRaises(ValueError): - # Too many headers. - text_plugin.make_table(test_array, headers=test_array) - - def test_reduce_to_2d(self): - - def make_range_array(dim): - """Produce an incrementally increasing multidimensional array. - - Args: - dim: the number of dimensions for the array - - Returns: - An array of increasing integer elements, with dim dimensions and size - two in each dimension. - - Example: rangeArray(2) results in [[0,1],[2,3]]. - """ - return np.array(range(2**dim)).reshape([2] * dim) - - for i in range(2, 5): - actual = text_plugin.reduce_to_2d(make_range_array(i)) - expected = make_range_array(2) - np.testing.assert_array_equal(actual, expected) - - def test_text_array_to_html(self): - convert = text_plugin.text_array_to_html - scalar = np.array('foo') - scalar_expected = '

foo

' - self.assertEqual(convert(scalar), scalar_expected) - - # Check that underscores are preserved correctly; this detects erroneous - # use of UTF-16 or UTF-32 encoding when calling markdown_to_safe_html(), - # which would introduce spurious null bytes and cause undesired tags - # around the underscores. - scalar_underscores = np.array('word_with_underscores') - scalar_underscores_expected = '

word_with_underscores

' - self.assertEqual(convert(scalar_underscores), scalar_underscores_expected) - - vector = np.array(['foo', 'bar']) - vector_expected = textwrap.dedent("""\ + """ + ) + actual_1d_with_headers = text_plugin.make_table(array_1d, headers=["X"]) + self.assertEqual(actual_1d_with_headers, expected_1d_with_headers) + + def testMakeTableExceptions(self): + # Verify that contents is being type-checked and shape-checked. + with self.assertRaises(ValueError): + text_plugin.make_table([]) + + with self.assertRaises(ValueError): + text_plugin.make_table("foo") + + with self.assertRaises(ValueError): + invalid_shape = np.full((3, 3, 3), "nope", dtype=np.dtype("S3")) + text_plugin.make_table(invalid_shape) + + # Test headers exceptions in 2d array case. + test_array = np.full((3, 3), "foo", dtype=np.dtype("S3")) + with self.assertRaises(ValueError): + # Headers is wrong type. + text_plugin.make_table(test_array, headers="foo") + with self.assertRaises(ValueError): + # Too many headers. + text_plugin.make_table( + test_array, headers=["foo", "bar", "zod", "zoink"] + ) + with self.assertRaises(ValueError): + # headers is 2d + text_plugin.make_table(test_array, headers=test_array) + + # Also make sure the column counting logic works in the 1d array case. + test_array = np.array(["foo", "bar", "zod"]) + with self.assertRaises(ValueError): + # Too many headers. + text_plugin.make_table(test_array, headers=test_array) + + def test_reduce_to_2d(self): + def make_range_array(dim): + """Produce an incrementally increasing multidimensional array. + + Args: + dim: the number of dimensions for the array + + Returns: + An array of increasing integer elements, with dim dimensions and size + two in each dimension. + + Example: rangeArray(2) results in [[0,1],[2,3]]. 
+ """ + return np.array(range(2 ** dim)).reshape([2] * dim) + + for i in range(2, 5): + actual = text_plugin.reduce_to_2d(make_range_array(i)) + expected = make_range_array(2) + np.testing.assert_array_equal(actual, expected) + + def test_text_array_to_html(self): + convert = text_plugin.text_array_to_html + scalar = np.array("foo") + scalar_expected = "

foo

" + self.assertEqual(convert(scalar), scalar_expected) + + # Check that underscores are preserved correctly; this detects erroneous + # use of UTF-16 or UTF-32 encoding when calling markdown_to_safe_html(), + # which would introduce spurious null bytes and cause undesired tags + # around the underscores. + scalar_underscores = np.array("word_with_underscores") + scalar_underscores_expected = "

word_with_underscores

" + self.assertEqual( + convert(scalar_underscores), scalar_underscores_expected + ) + + vector = np.array(["foo", "bar"]) + vector_expected = textwrap.dedent( + """\ @@ -289,11 +310,13 @@ def test_text_array_to_html(self): -

bar

""") - self.assertEqual(convert(vector), vector_expected) + """ + ) + self.assertEqual(convert(vector), vector_expected) - d2 = np.array([['foo', 'bar'], ['zoink', 'zod']]) - d2_expected = textwrap.dedent("""\ + d2 = np.array([["foo", "bar"], ["zoink", "zod"]]) + d2_expected = textwrap.dedent( + """\ @@ -305,15 +328,22 @@ def test_text_array_to_html(self): -

zod

""") - self.assertEqual(convert(d2), d2_expected) - - d3 = np.array([[['foo', 'bar'], ['zoink', 'zod']], [['FOO', 'BAR'], - ['ZOINK', 'ZOD']]]) - - warning = plugin_util.markdown_to_safe_html( - text_plugin.WARNING_TEMPLATE % 3) - d3_expected = warning + textwrap.dedent("""\ + """ + ) + self.assertEqual(convert(d2), d2_expected) + + d3 = np.array( + [ + [["foo", "bar"], ["zoink", "zod"]], + [["FOO", "BAR"], ["ZOINK", "ZOD"]], + ] + ) + + warning = plugin_util.markdown_to_safe_html( + text_plugin.WARNING_TEMPLATE % 3 + ) + d3_expected = warning + textwrap.dedent( + """\ @@ -325,42 +355,48 @@ def test_text_array_to_html(self): -

zod

""") - self.assertEqual(convert(d3), d3_expected) - - def testPluginIsActiveWhenNoRuns(self): - """The plugin should be inactive when there are no runs.""" - multiplexer = event_multiplexer.EventMultiplexer() - context = base_plugin.TBContext(logdir=self.logdir, multiplexer=multiplexer) - plugin = text_plugin.TextPlugin(context) - self.assertFalse(plugin.is_active()) - - def testPluginIsActiveWhenTextRuns(self): - """The plugin should be active when there are runs with text.""" - multiplexer = event_multiplexer.EventMultiplexer() - context = base_plugin.TBContext(logdir=self.logdir, multiplexer=multiplexer) - plugin = text_plugin.TextPlugin(context) - multiplexer.AddRunsFromDirectory(self.logdir) - multiplexer.Reload() - self.assertTrue(plugin.is_active()) - - def testPluginIsActiveWhenRunsButNoText(self): - """The plugin should be inactive when there are runs but none has text.""" - logdir = os.path.join(self.get_temp_dir(), 'runs_with_no_text') - multiplexer = event_multiplexer.EventMultiplexer() - context = base_plugin.TBContext(logdir=logdir, multiplexer=multiplexer) - plugin = text_plugin.TextPlugin(context) - self.generate_testdata(include_text=False, logdir=logdir) - multiplexer.AddRunsFromDirectory(logdir) - multiplexer.Reload() - self.assertFalse(plugin.is_active()) - - def testPluginIndexImpl(self): - run_to_tags = self.plugin.index_impl() - self.assertItemsEqual(['fry', 'leela'], run_to_tags.keys()) - self.assertItemsEqual(['message', 'vector'], run_to_tags['fry']) - self.assertItemsEqual(['message', 'vector'], run_to_tags['leela']) - - -if __name__ == '__main__': - tf.test.main() + """ + ) + self.assertEqual(convert(d3), d3_expected) + + def testPluginIsActiveWhenNoRuns(self): + """The plugin should be inactive when there are no runs.""" + multiplexer = event_multiplexer.EventMultiplexer() + context = base_plugin.TBContext( + logdir=self.logdir, multiplexer=multiplexer + ) + plugin = text_plugin.TextPlugin(context) + self.assertFalse(plugin.is_active()) 
+ + def testPluginIsActiveWhenTextRuns(self): + """The plugin should be active when there are runs with text.""" + multiplexer = event_multiplexer.EventMultiplexer() + context = base_plugin.TBContext( + logdir=self.logdir, multiplexer=multiplexer + ) + plugin = text_plugin.TextPlugin(context) + multiplexer.AddRunsFromDirectory(self.logdir) + multiplexer.Reload() + self.assertTrue(plugin.is_active()) + + def testPluginIsActiveWhenRunsButNoText(self): + """The plugin should be inactive when there are runs but none has + text.""" + logdir = os.path.join(self.get_temp_dir(), "runs_with_no_text") + multiplexer = event_multiplexer.EventMultiplexer() + context = base_plugin.TBContext(logdir=logdir, multiplexer=multiplexer) + plugin = text_plugin.TextPlugin(context) + self.generate_testdata(include_text=False, logdir=logdir) + multiplexer.AddRunsFromDirectory(logdir) + multiplexer.Reload() + self.assertFalse(plugin.is_active()) + + def testPluginIndexImpl(self): + run_to_tags = self.plugin.index_impl() + self.assertItemsEqual(["fry", "leela"], run_to_tags.keys()) + self.assertItemsEqual(["message", "vector"], run_to_tags["fry"]) + self.assertItemsEqual(["message", "vector"], run_to_tags["leela"]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/program.py b/tensorboard/program.py index d90df1beeb..6858bdf696 100644 --- a/tensorboard/program.py +++ b/tensorboard/program.py @@ -65,612 +65,687 @@ logger = tb_logging.get_logger() # Default subcommand name. This is a user-facing CLI and should not change. -_SERVE_SUBCOMMAND_NAME = 'serve' +_SERVE_SUBCOMMAND_NAME = "serve" # Internal flag name used to store which subcommand was invoked. -_SUBCOMMAND_FLAG = '__tensorboard_subcommand' +_SUBCOMMAND_FLAG = "__tensorboard_subcommand" def setup_environment(): - """Makes recommended modifications to the environment. + """Makes recommended modifications to the environment. - This functions changes global state in the Python process. 
Calling - this function is a good idea, but it can't appropriately be called - from library routines. - """ - absl.logging.set_verbosity(absl.logging.WARNING) + This functions changes global state in the Python process. Calling + this function is a good idea, but it can't appropriately be called + from library routines. + """ + absl.logging.set_verbosity(absl.logging.WARNING) - # The default is HTTP/1.0 for some strange reason. If we don't use - # HTTP/1.1 then a new TCP socket and Python thread is created for - # each HTTP request. The tradeoff is we must always specify the - # Content-Length header, or do chunked encoding for streaming. - serving.WSGIRequestHandler.protocol_version = 'HTTP/1.1' + # The default is HTTP/1.0 for some strange reason. If we don't use + # HTTP/1.1 then a new TCP socket and Python thread is created for + # each HTTP request. The tradeoff is we must always specify the + # Content-Length header, or do chunked encoding for streaming. + serving.WSGIRequestHandler.protocol_version = "HTTP/1.1" -def get_default_assets_zip_provider(): - """Opens stock TensorBoard web assets collection. - - Returns: - Returns function that returns a newly opened file handle to zip file - containing static assets for stock TensorBoard, or None if webfiles.zip - could not be found. The value the callback returns must be closed. The - paths inside the zip file are considered absolute paths on the web server. - """ - path = os.path.join(os.path.dirname(inspect.getfile(sys._getframe(1))), - 'webfiles.zip') - if not os.path.exists(path): - logger.warning('webfiles.zip static assets not found: %s', path) - return None - return lambda: open(path, 'rb') -class TensorBoard(object): - """Class for running TensorBoard. - - Fields: - plugin_loaders: Set from plugins passed to constructor. - assets_zip_provider: Set by constructor. - server_class: Set by constructor. - flags: An argparse.Namespace set by the configure() method. 
- cache_key: As `manager.cache_key`; set by the configure() method. - """ - - def __init__( - self, - plugins=None, - assets_zip_provider=None, - server_class=None, - subcommands=None, - ): - """Creates new instance. - - Args: - plugins: A list of TensorBoard plugins to load, as TBPlugin classes or - TBLoader instances or classes. If not specified, defaults to first-party - plugins. - assets_zip_provider: Delegates to TBContext or uses default if None. - server_class: An optional factory for a `TensorBoardServer` to use - for serving the TensorBoard WSGI app. If provided, its callable - signature should match that of `TensorBoardServer.__init__`. - - :type plugins: - list[ - base_plugin.TBLoader | Type[base_plugin.TBLoader] | - Type[base_plugin.TBPlugin] - ] - """ - if plugins is None: - from tensorboard import default - plugins = default.get_plugins() - if assets_zip_provider is None: - assets_zip_provider = get_default_assets_zip_provider() - if server_class is None: - server_class = create_port_scanning_werkzeug_server - if subcommands is None: - subcommands = [] - self.plugin_loaders = [application.make_plugin_loader(p) for p in plugins] - self.assets_zip_provider = assets_zip_provider - self.server_class = server_class - self.subcommands = {} - for subcommand in subcommands: - name = subcommand.name() - if name in self.subcommands or name == _SERVE_SUBCOMMAND_NAME: - raise ValueError("Duplicate subcommand name: %r" % name) - self.subcommands[name] = subcommand - self.flags = None - - def configure(self, argv=('',), **kwargs): - """Configures TensorBoard behavior via flags. - - This method will populate the "flags" property with an argparse.Namespace - representing flag values parsed from the provided argv list, overridden by - explicit flags from remaining keyword arguments. - - Args: - argv: Can be set to CLI args equivalent to sys.argv; the first arg is - taken to be the name of the path being executed. 
- kwargs: Additional arguments will override what was parsed from - argv. They must be passed as Python data structures, e.g. - `foo=1` rather than `foo="1"`. +def get_default_assets_zip_provider(): + """Opens stock TensorBoard web assets collection. Returns: - Either argv[:1] if argv was non-empty, or [''] otherwise, as a mechanism - for absl.app.run() compatibility. - - Raises: - ValueError: If flag values are invalid. + Returns function that returns a newly opened file handle to zip file + containing static assets for stock TensorBoard, or None if webfiles.zip + could not be found. The value the callback returns must be closed. The + paths inside the zip file are considered absolute paths on the web server. """ - - base_parser = argparse_flags.ArgumentParser( - prog='tensorboard', - description=('TensorBoard is a suite of web applications for ' - 'inspecting and understanding your TensorFlow runs ' - 'and graphs. https://github.com/tensorflow/tensorboard ')) - subparsers = base_parser.add_subparsers( - help="TensorBoard subcommand (defaults to %r)" % _SERVE_SUBCOMMAND_NAME) - - serve_subparser = subparsers.add_parser( - _SERVE_SUBCOMMAND_NAME, - help='start local TensorBoard server (default subcommand)') - serve_subparser.set_defaults(**{_SUBCOMMAND_FLAG: _SERVE_SUBCOMMAND_NAME}) - - if len(argv) < 2 or argv[1].startswith('-'): - # This invocation, if valid, must not use any subcommands: we - # don't permit flags before the subcommand name. - serve_parser = base_parser - else: - # This invocation, if valid, must use a subcommand: we don't take - # any positional arguments to `serve`. 
- serve_parser = serve_subparser - - for (name, subcommand) in six.iteritems(self.subcommands): - subparser = subparsers.add_parser( - name, help=subcommand.help(), description=subcommand.description()) - subparser.set_defaults(**{_SUBCOMMAND_FLAG: name}) - subcommand.define_flags(subparser) - - for loader in self.plugin_loaders: - loader.define_flags(serve_parser) - - arg0 = argv[0] if argv else '' - - with argparse_util.allow_missing_subcommand(): - flags = base_parser.parse_args(argv[1:]) # Strip binary name from argv. - if getattr(flags, _SUBCOMMAND_FLAG, None) is None: - # Manually assign default value rather than using `set_defaults` - # on the base parser to work around Python bug #9351 on old - # versions of `argparse`: - setattr(flags, _SUBCOMMAND_FLAG, _SERVE_SUBCOMMAND_NAME) - - self.cache_key = manager.cache_key( - working_directory=os.getcwd(), - arguments=argv[1:], - configure_kwargs=kwargs, + path = os.path.join( + os.path.dirname(inspect.getfile(sys._getframe(1))), "webfiles.zip" ) - if arg0: - # Only expose main module Abseil flags as TensorBoard native flags. - # This is the same logic Abseil's ArgumentParser uses for determining - # which Abseil flags to include in the short helpstring. - for flag in set(absl_flags.FLAGS.get_key_flags_for_module(arg0)): - if hasattr(flags, flag.name): - raise ValueError('Conflicting Abseil flag: %s' % flag.name) - setattr(flags, flag.name, flag.value) - for k, v in kwargs.items(): - if not hasattr(flags, k): - raise ValueError('Unknown TensorBoard flag: %s' % k) - setattr(flags, k, v) - if getattr(flags, _SUBCOMMAND_FLAG) == _SERVE_SUBCOMMAND_NAME: - for loader in self.plugin_loaders: - loader.fix_flags(flags) - self.flags = flags - return [arg0] - - def main(self, ignored_argv=('',)): - """Blocking main function for TensorBoard. - - This method is called by `tensorboard.main.run_main`, which is the - standard entrypoint for the tensorboard command line program. The - configure() method must be called first. 
- - Args: - ignored_argv: Do not pass. Required for Abseil compatibility. + if not os.path.exists(path): + logger.warning("webfiles.zip static assets not found: %s", path) + return None + return lambda: open(path, "rb") - Returns: - Process exit code, i.e. 0 if successful or non-zero on failure. In - practice, an exception will most likely be raised instead of - returning non-zero. - :rtype: int +class TensorBoard(object): + """Class for running TensorBoard. + + Fields: + plugin_loaders: Set from plugins passed to constructor. + assets_zip_provider: Set by constructor. + server_class: Set by constructor. + flags: An argparse.Namespace set by the configure() method. + cache_key: As `manager.cache_key`; set by the configure() method. """ - self._install_signal_handler(signal.SIGTERM, "SIGTERM") - subcommand_name = getattr(self.flags, _SUBCOMMAND_FLAG) - if subcommand_name == _SERVE_SUBCOMMAND_NAME: - runner = self._run_serve_subcommand - else: - runner = self.subcommands[subcommand_name].run - return runner(self.flags) or 0 - - def _run_serve_subcommand(self, flags): - # TODO(#2801): Make `--version` a flag on only the base parser, not `serve`. - if flags.version_tb: - print(version.VERSION) - return 0 - if flags.inspect: - # TODO(@wchargin): Convert `inspect` to a normal subcommand? - logger.info('Not bringing up TensorBoard, but inspecting event files.') - event_file = os.path.expanduser(flags.event_file) - efi.inspect(flags.logdir, event_file, flags.tag) - return 0 - try: - server = self._make_server() - server.print_serving_message() - self._register_info(server) - server.serve_forever() - return 0 - except TensorBoardServerException as e: - logger.error(e.msg) - sys.stderr.write('ERROR: %s\n' % e.msg) - sys.stderr.flush() - return -1 - - def launch(self): - """Python API for launching TensorBoard. - - This method is the same as main() except it launches TensorBoard in - a separate permanent thread. The configure() method must be called - first. 
- Returns: - The URL of the TensorBoard web server. + def __init__( + self, + plugins=None, + assets_zip_provider=None, + server_class=None, + subcommands=None, + ): + """Creates new instance. + + Args: + plugins: A list of TensorBoard plugins to load, as TBPlugin classes or + TBLoader instances or classes. If not specified, defaults to first-party + plugins. + assets_zip_provider: Delegates to TBContext or uses default if None. + server_class: An optional factory for a `TensorBoardServer` to use + for serving the TensorBoard WSGI app. If provided, its callable + signature should match that of `TensorBoardServer.__init__`. + + :type plugins: + list[ + base_plugin.TBLoader | Type[base_plugin.TBLoader] | + Type[base_plugin.TBPlugin] + ] + """ + if plugins is None: + from tensorboard import default + + plugins = default.get_plugins() + if assets_zip_provider is None: + assets_zip_provider = get_default_assets_zip_provider() + if server_class is None: + server_class = create_port_scanning_werkzeug_server + if subcommands is None: + subcommands = [] + self.plugin_loaders = [ + application.make_plugin_loader(p) for p in plugins + ] + self.assets_zip_provider = assets_zip_provider + self.server_class = server_class + self.subcommands = {} + for subcommand in subcommands: + name = subcommand.name() + if name in self.subcommands or name == _SERVE_SUBCOMMAND_NAME: + raise ValueError("Duplicate subcommand name: %r" % name) + self.subcommands[name] = subcommand + self.flags = None + + def configure(self, argv=("",), **kwargs): + """Configures TensorBoard behavior via flags. + + This method will populate the "flags" property with an argparse.Namespace + representing flag values parsed from the provided argv list, overridden by + explicit flags from remaining keyword arguments. + + Args: + argv: Can be set to CLI args equivalent to sys.argv; the first arg is + taken to be the name of the path being executed. 
+ kwargs: Additional arguments will override what was parsed from + argv. They must be passed as Python data structures, e.g. + `foo=1` rather than `foo="1"`. + + Returns: + Either argv[:1] if argv was non-empty, or [''] otherwise, as a mechanism + for absl.app.run() compatibility. + + Raises: + ValueError: If flag values are invalid. + """ + + base_parser = argparse_flags.ArgumentParser( + prog="tensorboard", + description=( + "TensorBoard is a suite of web applications for " + "inspecting and understanding your TensorFlow runs " + "and graphs. https://github.com/tensorflow/tensorboard " + ), + ) + subparsers = base_parser.add_subparsers( + help="TensorBoard subcommand (defaults to %r)" + % _SERVE_SUBCOMMAND_NAME + ) + + serve_subparser = subparsers.add_parser( + _SERVE_SUBCOMMAND_NAME, + help="start local TensorBoard server (default subcommand)", + ) + serve_subparser.set_defaults( + **{_SUBCOMMAND_FLAG: _SERVE_SUBCOMMAND_NAME} + ) + + if len(argv) < 2 or argv[1].startswith("-"): + # This invocation, if valid, must not use any subcommands: we + # don't permit flags before the subcommand name. + serve_parser = base_parser + else: + # This invocation, if valid, must use a subcommand: we don't take + # any positional arguments to `serve`. + serve_parser = serve_subparser + + for (name, subcommand) in six.iteritems(self.subcommands): + subparser = subparsers.add_parser( + name, + help=subcommand.help(), + description=subcommand.description(), + ) + subparser.set_defaults(**{_SUBCOMMAND_FLAG: name}) + subcommand.define_flags(subparser) + + for loader in self.plugin_loaders: + loader.define_flags(serve_parser) + + arg0 = argv[0] if argv else "" + + with argparse_util.allow_missing_subcommand(): + flags = base_parser.parse_args( + argv[1:] + ) # Strip binary name from argv. 
+ if getattr(flags, _SUBCOMMAND_FLAG, None) is None: + # Manually assign default value rather than using `set_defaults` + # on the base parser to work around Python bug #9351 on old + # versions of `argparse`: + setattr(flags, _SUBCOMMAND_FLAG, _SERVE_SUBCOMMAND_NAME) + + self.cache_key = manager.cache_key( + working_directory=os.getcwd(), + arguments=argv[1:], + configure_kwargs=kwargs, + ) + if arg0: + # Only expose main module Abseil flags as TensorBoard native flags. + # This is the same logic Abseil's ArgumentParser uses for determining + # which Abseil flags to include in the short helpstring. + for flag in set(absl_flags.FLAGS.get_key_flags_for_module(arg0)): + if hasattr(flags, flag.name): + raise ValueError("Conflicting Abseil flag: %s" % flag.name) + setattr(flags, flag.name, flag.value) + for k, v in kwargs.items(): + if not hasattr(flags, k): + raise ValueError("Unknown TensorBoard flag: %s" % k) + setattr(flags, k, v) + if getattr(flags, _SUBCOMMAND_FLAG) == _SERVE_SUBCOMMAND_NAME: + for loader in self.plugin_loaders: + loader.fix_flags(flags) + self.flags = flags + return [arg0] + + def main(self, ignored_argv=("",)): + """Blocking main function for TensorBoard. + + This method is called by `tensorboard.main.run_main`, which is the + standard entrypoint for the tensorboard command line program. The + configure() method must be called first. + + Args: + ignored_argv: Do not pass. Required for Abseil compatibility. + + Returns: + Process exit code, i.e. 0 if successful or non-zero on failure. In + practice, an exception will most likely be raised instead of + returning non-zero. 
+ + :rtype: int + """ + self._install_signal_handler(signal.SIGTERM, "SIGTERM") + subcommand_name = getattr(self.flags, _SUBCOMMAND_FLAG) + if subcommand_name == _SERVE_SUBCOMMAND_NAME: + runner = self._run_serve_subcommand + else: + runner = self.subcommands[subcommand_name].run + return runner(self.flags) or 0 + + def _run_serve_subcommand(self, flags): + # TODO(#2801): Make `--version` a flag on only the base parser, not `serve`. + if flags.version_tb: + print(version.VERSION) + return 0 + if flags.inspect: + # TODO(@wchargin): Convert `inspect` to a normal subcommand? + logger.info( + "Not bringing up TensorBoard, but inspecting event files." + ) + event_file = os.path.expanduser(flags.event_file) + efi.inspect(flags.logdir, event_file, flags.tag) + return 0 + try: + server = self._make_server() + server.print_serving_message() + self._register_info(server) + server.serve_forever() + return 0 + except TensorBoardServerException as e: + logger.error(e.msg) + sys.stderr.write("ERROR: %s\n" % e.msg) + sys.stderr.flush() + return -1 + + def launch(self): + """Python API for launching TensorBoard. + + This method is the same as main() except it launches TensorBoard in + a separate permanent thread. The configure() method must be called + first. + + Returns: + The URL of the TensorBoard web server. + + :rtype: str + """ + # Make it easy to run TensorBoard inside other programs, e.g. Colab. + server = self._make_server() + thread = threading.Thread( + target=server.serve_forever, name="TensorBoard" + ) + thread.daemon = True + thread.start() + return server.get_url() + + def _register_info(self, server): + """Write a TensorBoardInfo file and arrange for its cleanup. + + Args: + server: The result of `self._make_server()`. 
+ """ + server_url = urllib.parse.urlparse(server.get_url()) + info = manager.TensorBoardInfo( + version=version.VERSION, + start_time=int(time.time()), + port=server_url.port, + pid=os.getpid(), + path_prefix=self.flags.path_prefix, + logdir=self.flags.logdir or self.flags.logdir_spec, + db=self.flags.db, + cache_key=self.cache_key, + ) + atexit.register(manager.remove_info_file) + manager.write_info_file(info) + + def _install_signal_handler(self, signal_number, signal_name): + """Set a signal handler to gracefully exit on the given signal. + + When this process receives the given signal, it will run `atexit` + handlers and then exit with `0`. + + Args: + signal_number: The numeric code for the signal to handle, like + `signal.SIGTERM`. + signal_name: The human-readable signal name. + """ + old_signal_handler = None # set below + + def handler(handled_signal_number, frame): + # In case we catch this signal again while running atexit + # handlers, take the hint and actually die. + signal.signal(signal_number, signal.SIG_DFL) + sys.stderr.write( + "TensorBoard caught %s; exiting...\n" % signal_name + ) + # The main thread is the only non-daemon thread, so it suffices to + # exit hence. + if old_signal_handler not in (signal.SIG_IGN, signal.SIG_DFL): + old_signal_handler(handled_signal_number, frame) + sys.exit(0) + + old_signal_handler = signal.signal(signal_number, handler) + + def _make_server(self): + """Constructs the TensorBoard WSGI app and instantiates the server.""" + app = application.standard_tensorboard_wsgi( + self.flags, self.plugin_loaders, self.assets_zip_provider + ) + return self.server_class(app, self.flags) - :rtype: str - """ - # Make it easy to run TensorBoard inside other programs, e.g. Colab. 
- server = self._make_server() - thread = threading.Thread(target=server.serve_forever, name='TensorBoard') - thread.daemon = True - thread.start() - return server.get_url() - def _register_info(self, server): - """Write a TensorBoardInfo file and arrange for its cleanup. +@six.add_metaclass(ABCMeta) +class TensorBoardSubcommand(object): + """Experimental private API for defining subcommands to tensorboard(1).""" - Args: - server: The result of `self._make_server()`. - """ - server_url = urllib.parse.urlparse(server.get_url()) - info = manager.TensorBoardInfo( - version=version.VERSION, - start_time=int(time.time()), - port=server_url.port, - pid=os.getpid(), - path_prefix=self.flags.path_prefix, - logdir=self.flags.logdir or self.flags.logdir_spec, - db=self.flags.db, - cache_key=self.cache_key, - ) - atexit.register(manager.remove_info_file) - manager.write_info_file(info) + @abstractmethod + def name(self): + """Name of this subcommand, as specified on the command line. - def _install_signal_handler(self, signal_number, signal_name): - """Set a signal handler to gracefully exit on the given signal. + This must be unique across all subcommands. - When this process receives the given signal, it will run `atexit` - handlers and then exit with `0`. + Returns: + A string. + """ + pass - Args: - signal_number: The numeric code for the signal to handle, like - `signal.SIGTERM`. - signal_name: The human-readable signal name. - """ - old_signal_handler = None # set below - def handler(handled_signal_number, frame): - # In case we catch this signal again while running atexit - # handlers, take the hint and actually die. - signal.signal(signal_number, signal.SIG_DFL) - sys.stderr.write("TensorBoard caught %s; exiting...\n" % signal_name) - # The main thread is the only non-daemon thread, so it suffices to - # exit hence. 
- if old_signal_handler not in (signal.SIG_IGN, signal.SIG_DFL): - old_signal_handler(handled_signal_number, frame) - sys.exit(0) - old_signal_handler = signal.signal(signal_number, handler) - - - def _make_server(self): - """Constructs the TensorBoard WSGI app and instantiates the server.""" - app = application.standard_tensorboard_wsgi(self.flags, - self.plugin_loaders, - self.assets_zip_provider) - return self.server_class(app, self.flags) + @abstractmethod + def define_flags(self, parser): + """Configure an argument parser for this subcommand. + Flags whose names start with two underscores (e.g., `__foo`) are + reserved for use by the runtime and must not be defined by + subcommands. -@six.add_metaclass(ABCMeta) -class TensorBoardSubcommand(object): - """Experimental private API for defining subcommands to tensorboard(1).""" + Args: + parser: An `argparse.ArgumentParser` scoped to this subcommand, + which this function should mutate. + """ + pass - @abstractmethod - def name(self): - """Name of this subcommand, as specified on the command line. + @abstractmethod + def run(self, flags): + """Execute this subcommand with user-provided flags. - This must be unique across all subcommands. + Args: + flags: An `argparse.Namespace` object with all defined flags. - Returns: - A string. - """ - pass + Returns: + An `int` exit code, or `None` as an alias for `0`. + """ + pass - @abstractmethod - def define_flags(self, parser): - """Configure an argument parser for this subcommand. + def help(self): + """Short, one-line help text to display on `tensorboard --help`.""" + return None - Flags whose names start with two underscores (e.g., `__foo`) are - reserved for use by the runtime and must not be defined by - subcommands. + def description(self): + """Description to display on `tensorboard SUBCOMMAND --help`.""" + return None - Args: - parser: An `argparse.ArgumentParser` scoped to this subcommand, - which this function should mutate. 
- """ - pass - @abstractmethod - def run(self, flags): - """Execute this subcommand with user-provided flags. +@six.add_metaclass(ABCMeta) +class TensorBoardServer(object): + """Class for customizing TensorBoard WSGI app serving.""" - Args: - flags: An `argparse.Namespace` object with all defined flags. + @abstractmethod + def __init__(self, wsgi_app, flags): + """Create a flag-configured HTTP server for TensorBoard's WSGI app. - Returns: - An `int` exit code, or `None` as an alias for `0`. - """ - pass + Args: + wsgi_app: The TensorBoard WSGI application to create a server for. + flags: argparse.Namespace instance of TensorBoard flags. + """ + raise NotImplementedError() - def help(self): - """Short, one-line help text to display on `tensorboard --help`.""" - return None + @abstractmethod + def serve_forever(self): + """Blocking call to start serving the TensorBoard server.""" + raise NotImplementedError() - def description(self): - """Description to display on `tensorboard SUBCOMMAND --help`.""" - return None + @abstractmethod + def get_url(self): + """Returns a URL at which this server should be reachable.""" + raise NotImplementedError() + def print_serving_message(self): + """Prints a user-friendly message prior to server start. -@six.add_metaclass(ABCMeta) -class TensorBoardServer(object): - """Class for customizing TensorBoard WSGI app serving.""" + This will be called just before `serve_forever`. + """ + sys.stderr.write( + "TensorBoard %s at %s (Press CTRL+C to quit)\n" + % (version.VERSION, self.get_url()) + ) + sys.stderr.flush() - @abstractmethod - def __init__(self, wsgi_app, flags): - """Create a flag-configured HTTP server for TensorBoard's WSGI app. - Args: - wsgi_app: The TensorBoard WSGI application to create a server for. - flags: argparse.Namespace instance of TensorBoard flags. +class TensorBoardServerException(Exception): + """Exception raised by TensorBoardServer for user-friendly errors. 
+ + Subclasses of TensorBoardServer can raise this exception in order to + generate a clean error message for the user rather than a + stacktrace. """ - raise NotImplementedError() - @abstractmethod - def serve_forever(self): - """Blocking call to start serving the TensorBoard server.""" - raise NotImplementedError() + def __init__(self, msg): + self.msg = msg - @abstractmethod - def get_url(self): - """Returns a URL at which this server should be reachable.""" - raise NotImplementedError() - def print_serving_message(self): - """Prints a user-friendly message prior to server start. +class TensorBoardPortInUseError(TensorBoardServerException): + """Error raised when attempting to bind to a port that is in use. - This will be called just before `serve_forever`. + This should be raised when it is expected that binding to another + similar port would succeed. It is used as a signal to indicate that + automatic port searching should continue rather than abort. """ - sys.stderr.write( - 'TensorBoard %s at %s (Press CTRL+C to quit)\n' - % (version.VERSION, self.get_url()) - ) - sys.stderr.flush() + pass -class TensorBoardServerException(Exception): - """Exception raised by TensorBoardServer for user-friendly errors. - Subclasses of TensorBoardServer can raise this exception in order to - generate a clean error message for the user rather than a stacktrace. - """ - def __init__(self, msg): - self.msg = msg +def with_port_scanning(cls): + """Create a server factory that performs port scanning. + This function returns a callable whose signature matches the + specification of `TensorBoardServer.__init__`, using `cls` as an + underlying implementation. It passes through `flags` unchanged except + in the case that `flags.port is None`, in which case it repeatedly + instantiates the underlying server with new port suggestions. -class TensorBoardPortInUseError(TensorBoardServerException): - """Error raised when attempting to bind to a port that is in use. 
+ Args: + cls: A valid implementation of `TensorBoardServer`. This class's + initializer should raise a `TensorBoardPortInUseError` upon + failing to bind to a port when it is expected that binding to + another nearby port might succeed. - This should be raised when it is expected that binding to another - similar port would succeed. It is used as a signal to indicate that - automatic port searching should continue rather than abort. - """ - pass + The initializer for `cls` will only ever be invoked with `flags` + such that `flags.port is not None`. + Returns: + A function that implements the `__init__` contract of + `TensorBoardServer`. + """ -def with_port_scanning(cls): - """Create a server factory that performs port scanning. - - This function returns a callable whose signature matches the - specification of `TensorBoardServer.__init__`, using `cls` as an - underlying implementation. It passes through `flags` unchanged except - in the case that `flags.port is None`, in which case it repeatedly - instantiates the underlying server with new port suggestions. - - Args: - cls: A valid implementation of `TensorBoardServer`. This class's - initializer should raise a `TensorBoardPortInUseError` upon - failing to bind to a port when it is expected that binding to - another nearby port might succeed. - - The initializer for `cls` will only ever be invoked with `flags` - such that `flags.port is not None`. - - Returns: - A function that implements the `__init__` contract of - `TensorBoardServer`. - """ - - def init(wsgi_app, flags): - # base_port: what's the first port to which we should try to bind? - # should_scan: if that fails, shall we try additional ports? - # max_attempts: how many ports shall we try? 
- should_scan = flags.port is None - base_port = core_plugin.DEFAULT_PORT if flags.port is None else flags.port - max_attempts = 10 if should_scan else 1 - - if base_port > 0xFFFF: - raise TensorBoardServerException( - 'TensorBoard cannot bind to port %d > %d' % (base_port, 0xFFFF) - ) - max_attempts = 10 if should_scan else 1 - base_port = min(base_port + max_attempts, 0x10000) - max_attempts - - for port in xrange(base_port, base_port + max_attempts): - subflags = argparse.Namespace(**vars(flags)) - subflags.port = port - try: - return cls(wsgi_app=wsgi_app, flags=subflags) - except TensorBoardPortInUseError: - if not should_scan: - raise - # All attempts failed to bind. - raise TensorBoardServerException( - 'TensorBoard could not bind to any port around %s ' - '(tried %d times)' - % (base_port, max_attempts)) - - return init + def init(wsgi_app, flags): + # base_port: what's the first port to which we should try to bind? + # should_scan: if that fails, shall we try additional ports? + # max_attempts: how many ports shall we try? + should_scan = flags.port is None + base_port = ( + core_plugin.DEFAULT_PORT if flags.port is None else flags.port + ) + max_attempts = 10 if should_scan else 1 + + if base_port > 0xFFFF: + raise TensorBoardServerException( + "TensorBoard cannot bind to port %d > %d" % (base_port, 0xFFFF) + ) + max_attempts = 10 if should_scan else 1 + base_port = min(base_port + max_attempts, 0x10000) - max_attempts + + for port in xrange(base_port, base_port + max_attempts): + subflags = argparse.Namespace(**vars(flags)) + subflags.port = port + try: + return cls(wsgi_app=wsgi_app, flags=subflags) + except TensorBoardPortInUseError: + if not should_scan: + raise + # All attempts failed to bind. 
+ raise TensorBoardServerException( + "TensorBoard could not bind to any port around %s " + "(tried %d times)" % (base_port, max_attempts) + ) + + return init class WerkzeugServer(serving.ThreadedWSGIServer, TensorBoardServer): - """Implementation of TensorBoardServer using the Werkzeug dev server.""" - - # ThreadedWSGIServer handles this in werkzeug 0.12+ but we allow 0.11.x. - daemon_threads = True - - def __init__(self, wsgi_app, flags): - self._flags = flags - host = flags.host - port = flags.port - - self._auto_wildcard = flags.bind_all - if self._auto_wildcard: - # Serve on all interfaces, and attempt to serve both IPv4 and IPv6 - # traffic through one socket. - host = self._get_wildcard_address(port) - elif host is None: - host = 'localhost' - - self._host = host - - self._fix_werkzeug_logging() - try: - super(WerkzeugServer, self).__init__(host, port, wsgi_app) - except socket.error as e: - if hasattr(errno, 'EACCES') and e.errno == errno.EACCES: - raise TensorBoardServerException( - 'TensorBoard must be run as superuser to bind to port %d' % - port) - elif hasattr(errno, 'EADDRINUSE') and e.errno == errno.EADDRINUSE: - if port == 0: - raise TensorBoardServerException( - 'TensorBoard unable to find any open port') - else: - raise TensorBoardPortInUseError( - 'TensorBoard could not bind to port %d, it was already in use' % - port) - elif hasattr(errno, 'EADDRNOTAVAIL') and e.errno == errno.EADDRNOTAVAIL: - raise TensorBoardServerException( - 'TensorBoard could not bind to unavailable address %s' % host) - elif hasattr(errno, 'EAFNOSUPPORT') and e.errno == errno.EAFNOSUPPORT: - raise TensorBoardServerException( - 'Tensorboard could not bind to unsupported address family %s' % - host) - # Raise the raw exception if it wasn't identifiable as a user error. - raise - - def _get_wildcard_address(self, port): - """Returns a wildcard address for the port in question. 
- - This will attempt to follow the best practice of calling getaddrinfo() with - a null host and AI_PASSIVE to request a server-side socket wildcard address. - If that succeeds, this returns the first IPv6 address found, or if none, - then returns the first IPv4 address. If that fails, then this returns the - hardcoded address "::" if socket.has_ipv6 is True, else "0.0.0.0". - """ - fallback_address = '::' if socket.has_ipv6 else '0.0.0.0' - if hasattr(socket, 'AI_PASSIVE'): - try: - addrinfos = socket.getaddrinfo(None, port, socket.AF_UNSPEC, - socket.SOCK_STREAM, socket.IPPROTO_TCP, - socket.AI_PASSIVE) - except socket.gaierror as e: - logger.warn('Failed to auto-detect wildcard address, assuming %s: %s', - fallback_address, str(e)) + """Implementation of TensorBoardServer using the Werkzeug dev server.""" + + # ThreadedWSGIServer handles this in werkzeug 0.12+ but we allow 0.11.x. + daemon_threads = True + + def __init__(self, wsgi_app, flags): + self._flags = flags + host = flags.host + port = flags.port + + self._auto_wildcard = flags.bind_all + if self._auto_wildcard: + # Serve on all interfaces, and attempt to serve both IPv4 and IPv6 + # traffic through one socket. 
+ host = self._get_wildcard_address(port) + elif host is None: + host = "localhost" + + self._host = host + + self._fix_werkzeug_logging() + try: + super(WerkzeugServer, self).__init__(host, port, wsgi_app) + except socket.error as e: + if hasattr(errno, "EACCES") and e.errno == errno.EACCES: + raise TensorBoardServerException( + "TensorBoard must be run as superuser to bind to port %d" + % port + ) + elif hasattr(errno, "EADDRINUSE") and e.errno == errno.EADDRINUSE: + if port == 0: + raise TensorBoardServerException( + "TensorBoard unable to find any open port" + ) + else: + raise TensorBoardPortInUseError( + "TensorBoard could not bind to port %d, it was already in use" + % port + ) + elif ( + hasattr(errno, "EADDRNOTAVAIL") + and e.errno == errno.EADDRNOTAVAIL + ): + raise TensorBoardServerException( + "TensorBoard could not bind to unavailable address %s" + % host + ) + elif ( + hasattr(errno, "EAFNOSUPPORT") and e.errno == errno.EAFNOSUPPORT + ): + raise TensorBoardServerException( + "Tensorboard could not bind to unsupported address family %s" + % host + ) + # Raise the raw exception if it wasn't identifiable as a user error. + raise + + def _get_wildcard_address(self, port): + """Returns a wildcard address for the port in question. + + This will attempt to follow the best practice of calling + getaddrinfo() with a null host and AI_PASSIVE to request a + server-side socket wildcard address. If that succeeds, this + returns the first IPv6 address found, or if none, then returns + the first IPv4 address. If that fails, then this returns the + hardcoded address "::" if socket.has_ipv6 is True, else + "0.0.0.0". 
+ """ + fallback_address = "::" if socket.has_ipv6 else "0.0.0.0" + if hasattr(socket, "AI_PASSIVE"): + try: + addrinfos = socket.getaddrinfo( + None, + port, + socket.AF_UNSPEC, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + socket.AI_PASSIVE, + ) + except socket.gaierror as e: + logger.warn( + "Failed to auto-detect wildcard address, assuming %s: %s", + fallback_address, + str(e), + ) + return fallback_address + addrs_by_family = defaultdict(list) + for family, _, _, _, sockaddr in addrinfos: + # Format of the "sockaddr" socket address varies by address family, + # but [0] is always the IP address portion. + addrs_by_family[family].append(sockaddr[0]) + if hasattr(socket, "AF_INET6") and addrs_by_family[socket.AF_INET6]: + return addrs_by_family[socket.AF_INET6][0] + if hasattr(socket, "AF_INET") and addrs_by_family[socket.AF_INET]: + return addrs_by_family[socket.AF_INET][0] + logger.warn( + "Failed to auto-detect wildcard address, assuming %s", + fallback_address, + ) return fallback_address - addrs_by_family = defaultdict(list) - for family, _, _, _, sockaddr in addrinfos: - # Format of the "sockaddr" socket address varies by address family, - # but [0] is always the IP address portion. - addrs_by_family[family].append(sockaddr[0]) - if hasattr(socket, 'AF_INET6') and addrs_by_family[socket.AF_INET6]: - return addrs_by_family[socket.AF_INET6][0] - if hasattr(socket, 'AF_INET') and addrs_by_family[socket.AF_INET]: - return addrs_by_family[socket.AF_INET][0] - logger.warn('Failed to auto-detect wildcard address, assuming %s', - fallback_address) - return fallback_address - - def server_bind(self): - """Override to enable IPV4 mapping for IPV6 sockets when desired. - - The main use case for this is so that when no host is specified, TensorBoard - can listen on all interfaces for both IPv4 and IPv6 connections, rather than - having to choose v4 or v6 and hope the browser didn't choose the other one. 
- """ - socket_is_v6 = ( - hasattr(socket, 'AF_INET6') and self.socket.family == socket.AF_INET6) - has_v6only_option = ( - hasattr(socket, 'IPPROTO_IPV6') and hasattr(socket, 'IPV6_V6ONLY')) - if self._auto_wildcard and socket_is_v6 and has_v6only_option: - try: - self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) - except socket.error as e: - # Log a warning on failure to dual-bind, except for EAFNOSUPPORT - # since that's expected if IPv4 isn't supported at all (IPv6-only). - if hasattr(errno, 'EAFNOSUPPORT') and e.errno != errno.EAFNOSUPPORT: - logger.warn('Failed to dual-bind to IPv4 wildcard: %s', str(e)) - super(WerkzeugServer, self).server_bind() - - def handle_error(self, request, client_address): - """Override to get rid of noisy EPIPE errors.""" - del request # unused - # Kludge to override a SocketServer.py method so we can get rid of noisy - # EPIPE errors. They're kind of a red herring as far as errors go. For - # example, `curl -N http://localhost:6006/ | head` will cause an EPIPE. - exc_info = sys.exc_info() - e = exc_info[1] - if isinstance(e, IOError) and e.errno == errno.EPIPE: - logger.warn('EPIPE caused by %s in HTTP serving' % str(client_address)) - else: - logger.error('HTTP serving error', exc_info=exc_info) - - def get_url(self): - if self._auto_wildcard: - display_host = socket.gethostname() - else: - host = self._host - display_host = ( - '[%s]' % host if ':' in host and not host.startswith('[') else host) - return 'http://%s:%d%s/' % (display_host, self.server_port, - self._flags.path_prefix.rstrip('/')) - - def print_serving_message(self): - if self._flags.host is None and not self._flags.bind_all: - sys.stderr.write( - 'Serving TensorBoard on localhost; to expose to the network, ' - 'use a proxy or pass --bind_all\n' - ) - sys.stderr.flush() - super(WerkzeugServer, self).print_serving_message() - - def _fix_werkzeug_logging(self): - """Fix werkzeug logging setup so it inherits TensorBoard's log level. 
- - This addresses a change in werkzeug 0.15.0+ [1] that causes it set its own - log level to INFO regardless of the root logger configuration. We instead - want werkzeug to inherit TensorBoard's root logger log level (set via absl - to WARNING by default). - - [1]: https://github.com/pallets/werkzeug/commit/4cf77d25858ff46ac7e9d64ade054bf05b41ce12 - """ - # Log once at DEBUG to force werkzeug to initialize its singleton logger, - # which sets the logger level to INFO it if is unset, and then access that - # object via logging.getLogger('werkzeug') to durably revert the level to - # unset (and thus make messages logged to it inherit the root logger level). - self.log('debug', 'Fixing werkzeug logger to inherit TensorBoard log level') - logging.getLogger('werkzeug').setLevel(logging.NOTSET) + + def server_bind(self): + """Override to enable IPV4 mapping for IPV6 sockets when desired. + + The main use case for this is so that when no host is specified, + TensorBoard can listen on all interfaces for both IPv4 and IPv6 + connections, rather than having to choose v4 or v6 and hope the + browser didn't choose the other one. + """ + socket_is_v6 = ( + hasattr(socket, "AF_INET6") + and self.socket.family == socket.AF_INET6 + ) + has_v6only_option = hasattr(socket, "IPPROTO_IPV6") and hasattr( + socket, "IPV6_V6ONLY" + ) + if self._auto_wildcard and socket_is_v6 and has_v6only_option: + try: + self.socket.setsockopt( + socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0 + ) + except socket.error as e: + # Log a warning on failure to dual-bind, except for EAFNOSUPPORT + # since that's expected if IPv4 isn't supported at all (IPv6-only). 
+ if ( + hasattr(errno, "EAFNOSUPPORT") + and e.errno != errno.EAFNOSUPPORT + ): + logger.warn( + "Failed to dual-bind to IPv4 wildcard: %s", str(e) + ) + super(WerkzeugServer, self).server_bind() + + def handle_error(self, request, client_address): + """Override to get rid of noisy EPIPE errors.""" + del request # unused + # Kludge to override a SocketServer.py method so we can get rid of noisy + # EPIPE errors. They're kind of a red herring as far as errors go. For + # example, `curl -N http://localhost:6006/ | head` will cause an EPIPE. + exc_info = sys.exc_info() + e = exc_info[1] + if isinstance(e, IOError) and e.errno == errno.EPIPE: + logger.warn( + "EPIPE caused by %s in HTTP serving" % str(client_address) + ) + else: + logger.error("HTTP serving error", exc_info=exc_info) + + def get_url(self): + if self._auto_wildcard: + display_host = socket.gethostname() + else: + host = self._host + display_host = ( + "[%s]" % host + if ":" in host and not host.startswith("[") + else host + ) + return "http://%s:%d%s/" % ( + display_host, + self.server_port, + self._flags.path_prefix.rstrip("/"), + ) + + def print_serving_message(self): + if self._flags.host is None and not self._flags.bind_all: + sys.stderr.write( + "Serving TensorBoard on localhost; to expose to the network, " + "use a proxy or pass --bind_all\n" + ) + sys.stderr.flush() + super(WerkzeugServer, self).print_serving_message() + + def _fix_werkzeug_logging(self): + """Fix werkzeug logging setup so it inherits TensorBoard's log level. + + This addresses a change in werkzeug 0.15.0+ [1] that causes it set its own + log level to INFO regardless of the root logger configuration. We instead + want werkzeug to inherit TensorBoard's root logger log level (set via absl + to WARNING by default). 
+ + [1]: https://github.com/pallets/werkzeug/commit/4cf77d25858ff46ac7e9d64ade054bf05b41ce12 + """ + # Log once at DEBUG to force werkzeug to initialize its singleton logger, + # which sets the logger level to INFO it if is unset, and then access that + # object via logging.getLogger('werkzeug') to durably revert the level to + # unset (and thus make messages logged to it inherit the root logger level). + self.log( + "debug", "Fixing werkzeug logger to inherit TensorBoard log level" + ) + logging.getLogger("werkzeug").setLevel(logging.NOTSET) create_port_scanning_werkzeug_server = with_port_scanning(WerkzeugServer) diff --git a/tensorboard/program_test.py b/tensorboard/program_test.py index 1ed1dcbdf2..8c0621cf78 100644 --- a/tensorboard/program_test.py +++ b/tensorboard/program_test.py @@ -25,10 +25,10 @@ import six try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import from tensorboard import program from tensorboard import test as tb_test @@ -37,227 +37,238 @@ class TensorBoardTest(tb_test.TestCase): - """Tests the TensorBoard program.""" - - def testPlugins_pluginClass(self): - tb = program.TensorBoard(plugins=[core_plugin.CorePlugin]) - self.assertIsInstance(tb.plugin_loaders[0], base_plugin.BasicLoader) - self.assertIs(tb.plugin_loaders[0].plugin_class, core_plugin.CorePlugin) - - def testPlugins_pluginLoaderClass(self): - tb = program.TensorBoard(plugins=[core_plugin.CorePluginLoader]) - self.assertIsInstance(tb.plugin_loaders[0], core_plugin.CorePluginLoader) - - def testPlugins_pluginLoader(self): - loader = core_plugin.CorePluginLoader() - tb = program.TensorBoard(plugins=[loader]) - self.assertIs(tb.plugin_loaders[0], loader) - - def testPlugins_invalidType(self): - plugin_instance = core_plugin.CorePlugin(base_plugin.TBContext()) - with six.assertRaisesRegex(self, TypeError, 
'CorePlugin'): - tb = program.TensorBoard(plugins=[plugin_instance]) - - def testConfigure(self): - tb = program.TensorBoard(plugins=[core_plugin.CorePluginLoader]) - tb.configure(logdir='foo') - self.assertEqual(tb.flags.logdir, 'foo') - - def testConfigure_unknownFlag(self): - tb = program.TensorBoard(plugins=[core_plugin.CorePlugin]) - with six.assertRaisesRegex(self, ValueError, 'Unknown TensorBoard flag'): - tb.configure(foo='bar') + """Tests the TensorBoard program.""" + + def testPlugins_pluginClass(self): + tb = program.TensorBoard(plugins=[core_plugin.CorePlugin]) + self.assertIsInstance(tb.plugin_loaders[0], base_plugin.BasicLoader) + self.assertIs(tb.plugin_loaders[0].plugin_class, core_plugin.CorePlugin) + + def testPlugins_pluginLoaderClass(self): + tb = program.TensorBoard(plugins=[core_plugin.CorePluginLoader]) + self.assertIsInstance( + tb.plugin_loaders[0], core_plugin.CorePluginLoader + ) + + def testPlugins_pluginLoader(self): + loader = core_plugin.CorePluginLoader() + tb = program.TensorBoard(plugins=[loader]) + self.assertIs(tb.plugin_loaders[0], loader) + + def testPlugins_invalidType(self): + plugin_instance = core_plugin.CorePlugin(base_plugin.TBContext()) + with six.assertRaisesRegex(self, TypeError, "CorePlugin"): + tb = program.TensorBoard(plugins=[plugin_instance]) + + def testConfigure(self): + tb = program.TensorBoard(plugins=[core_plugin.CorePluginLoader]) + tb.configure(logdir="foo") + self.assertEqual(tb.flags.logdir, "foo") + + def testConfigure_unknownFlag(self): + tb = program.TensorBoard(plugins=[core_plugin.CorePlugin]) + with six.assertRaisesRegex( + self, ValueError, "Unknown TensorBoard flag" + ): + tb.configure(foo="bar") class WerkzeugServerTest(tb_test.TestCase): - """Tests the default Werkzeug implementation of TensorBoardServer. - - Mostly useful for IPv4/IPv6 testing. This test should run with only IPv4, only - IPv6, and both IPv4 and IPv6 enabled. 
- """ - - class _StubApplication(object): - pass - - def make_flags(self, **kwargs): - flags = argparse.Namespace() - kwargs.setdefault('host', None) - kwargs.setdefault('bind_all', kwargs['host'] is None) - for k, v in six.iteritems(kwargs): - setattr(flags, k, v) - return flags - - def testMakeServerBlankHost(self): - # Test that we can bind to all interfaces without throwing an error - server = program.WerkzeugServer( - self._StubApplication(), - self.make_flags(port=0, path_prefix='')) - self.assertStartsWith(server.get_url(), 'http://') - - def testPathPrefixSlash(self): - #Test that checks the path prefix ends with one trailing slash - server = program.WerkzeugServer( - self._StubApplication(), - self.make_flags(port=0, path_prefix='/test')) - self.assertEndsWith(server.get_url(), '/test/') - - server = program.WerkzeugServer( - self._StubApplication(), - self.make_flags(port=0, path_prefix='/test/')) - self.assertEndsWith(server.get_url(), '/test/') - - def testSpecifiedHost(self): - one_passed = False - try: - server = program.WerkzeugServer( - self._StubApplication(), - self.make_flags(host='127.0.0.1', port=0, path_prefix='')) - self.assertStartsWith(server.get_url(), 'http://127.0.0.1:') - one_passed = True - except program.TensorBoardServerException: - # IPv4 is not supported - pass - try: - server = program.WerkzeugServer( - self._StubApplication(), - self.make_flags(host='::1', port=0, path_prefix='')) - self.assertStartsWith(server.get_url(), 'http://[::1]:') - one_passed = True - except program.TensorBoardServerException: - # IPv6 is not supported - pass - self.assertTrue(one_passed) # We expect either IPv4 or IPv6 to be supported + """Tests the default Werkzeug implementation of TensorBoardServer. + + Mostly useful for IPv4/IPv6 testing. This test should run with only + IPv4, only IPv6, and both IPv4 and IPv6 enabled. 
+ """ + + class _StubApplication(object): + pass + + def make_flags(self, **kwargs): + flags = argparse.Namespace() + kwargs.setdefault("host", None) + kwargs.setdefault("bind_all", kwargs["host"] is None) + for k, v in six.iteritems(kwargs): + setattr(flags, k, v) + return flags + + def testMakeServerBlankHost(self): + # Test that we can bind to all interfaces without throwing an error + server = program.WerkzeugServer( + self._StubApplication(), self.make_flags(port=0, path_prefix="") + ) + self.assertStartsWith(server.get_url(), "http://") + + def testPathPrefixSlash(self): + # Test that checks the path prefix ends with one trailing slash + server = program.WerkzeugServer( + self._StubApplication(), + self.make_flags(port=0, path_prefix="/test"), + ) + self.assertEndsWith(server.get_url(), "/test/") + + server = program.WerkzeugServer( + self._StubApplication(), + self.make_flags(port=0, path_prefix="/test/"), + ) + self.assertEndsWith(server.get_url(), "/test/") + + def testSpecifiedHost(self): + one_passed = False + try: + server = program.WerkzeugServer( + self._StubApplication(), + self.make_flags(host="127.0.0.1", port=0, path_prefix=""), + ) + self.assertStartsWith(server.get_url(), "http://127.0.0.1:") + one_passed = True + except program.TensorBoardServerException: + # IPv4 is not supported + pass + try: + server = program.WerkzeugServer( + self._StubApplication(), + self.make_flags(host="::1", port=0, path_prefix=""), + ) + self.assertStartsWith(server.get_url(), "http://[::1]:") + one_passed = True + except program.TensorBoardServerException: + # IPv6 is not supported + pass + self.assertTrue( + one_passed + ) # We expect either IPv4 or IPv6 to be supported class SubcommandTest(tb_test.TestCase): - - def setUp(self): - super(SubcommandTest, self).setUp() - self.stderr = six.StringIO() - patchers = [ - mock.patch.object(program.TensorBoard, '_install_signal_handler'), - mock.patch.object(program.TensorBoard, '_run_serve_subcommand'), - 
mock.patch.object(_TestSubcommand, 'run'), - mock.patch.object(sys, 'stderr', self.stderr), - ] - for p in patchers: - p.start() - self.addCleanup(p.stop) - _TestSubcommand.run.return_value = None - - def tearDown(self): - stderr = self.stderr.getvalue() - if stderr: - # In case of failing tests, let there be debug info. - print('Stderr:\n%s' % stderr) - - def testImplicitServe(self): - tb = program.TensorBoard( - plugins=[core_plugin.CorePluginLoader], - subcommands=[_TestSubcommand(lambda parser: None)], - ) - tb.configure(('tb', '--logdir', 'logs', '--path_prefix', '/x///')) - tb.main() - program.TensorBoard._run_serve_subcommand.assert_called_once() - flags = program.TensorBoard._run_serve_subcommand.call_args[0][0] - self.assertEqual(flags.logdir, 'logs') - self.assertEqual(flags.path_prefix, '/x') # fixed by core_plugin - - def testExplicitServe(self): - tb = program.TensorBoard( - plugins=[core_plugin.CorePluginLoader], - subcommands=[_TestSubcommand()], - ) - tb.configure(('tb', 'serve', '--logdir', 'logs', '--path_prefix', '/x///')) - tb.main() - program.TensorBoard._run_serve_subcommand.assert_called_once() - flags = program.TensorBoard._run_serve_subcommand.call_args[0][0] - self.assertEqual(flags.logdir, 'logs') - self.assertEqual(flags.path_prefix, '/x') # fixed by core_plugin - - def testSubcommand(self): - def define_flags(parser): - parser.add_argument('--hello') - - tb = program.TensorBoard( - plugins=[core_plugin.CorePluginLoader], - subcommands=[_TestSubcommand(define_flags=define_flags)], - ) - tb.configure(('tb', 'test', '--hello', 'world')) - self.assertEqual(tb.main(), 0) - _TestSubcommand.run.assert_called_once() - flags = _TestSubcommand.run.call_args[0][0] - self.assertEqual(flags.hello, 'world') - - def testSubcommand_ExitCode(self): - tb = program.TensorBoard( - plugins=[core_plugin.CorePluginLoader], - subcommands=[_TestSubcommand()], - ) - _TestSubcommand.run.return_value = 77 - tb.configure(('tb', 'test')) - 
self.assertEqual(tb.main(), 77) - - def testSubcommand_DoesNotInheritBaseArgs(self): - tb = program.TensorBoard( - plugins=[core_plugin.CorePluginLoader], - subcommands=[_TestSubcommand()], - ) - with self.assertRaises(SystemExit): - tb.configure(('tb', 'test', '--logdir', 'logs')) - self.assertIn( - 'unrecognized arguments: --logdir logs', self.stderr.getvalue()) - self.stderr.truncate(0) - - def testSubcommand_MayRequirePositionals(self): - def define_flags(parser): - parser.add_argument('payload') - - tb = program.TensorBoard( - plugins=[core_plugin.CorePluginLoader], - subcommands=[_TestSubcommand(define_flags=define_flags)], - ) - with self.assertRaises(SystemExit): - tb.configure(('tb', 'test')) - self.assertIn('required', self.stderr.getvalue()) - self.assertIn('payload', self.stderr.getvalue()) - self.stderr.truncate(0) - - def testConflictingNames_AmongSubcommands(self): - with self.assertRaises(ValueError) as cm: - tb = program.TensorBoard( - plugins=[core_plugin.CorePluginLoader], - subcommands=[_TestSubcommand(), _TestSubcommand()], - ) - self.assertIn('Duplicate subcommand name:', str(cm.exception)) - self.assertIn('test', str(cm.exception)) - - def testConflictingNames_WithServe(self): - with self.assertRaises(ValueError) as cm: - tb = program.TensorBoard( - plugins=[core_plugin.CorePluginLoader], - subcommands=[_TestSubcommand(name='serve')], - ) - self.assertIn('Duplicate subcommand name:', str(cm.exception)) - self.assertIn('serve', str(cm.exception)) + def setUp(self): + super(SubcommandTest, self).setUp() + self.stderr = six.StringIO() + patchers = [ + mock.patch.object(program.TensorBoard, "_install_signal_handler"), + mock.patch.object(program.TensorBoard, "_run_serve_subcommand"), + mock.patch.object(_TestSubcommand, "run"), + mock.patch.object(sys, "stderr", self.stderr), + ] + for p in patchers: + p.start() + self.addCleanup(p.stop) + _TestSubcommand.run.return_value = None + + def tearDown(self): + stderr = self.stderr.getvalue() + if 
stderr: + # In case of failing tests, let there be debug info. + print("Stderr:\n%s" % stderr) + + def testImplicitServe(self): + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand(lambda parser: None)], + ) + tb.configure(("tb", "--logdir", "logs", "--path_prefix", "/x///")) + tb.main() + program.TensorBoard._run_serve_subcommand.assert_called_once() + flags = program.TensorBoard._run_serve_subcommand.call_args[0][0] + self.assertEqual(flags.logdir, "logs") + self.assertEqual(flags.path_prefix, "/x") # fixed by core_plugin + + def testExplicitServe(self): + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand()], + ) + tb.configure( + ("tb", "serve", "--logdir", "logs", "--path_prefix", "/x///") + ) + tb.main() + program.TensorBoard._run_serve_subcommand.assert_called_once() + flags = program.TensorBoard._run_serve_subcommand.call_args[0][0] + self.assertEqual(flags.logdir, "logs") + self.assertEqual(flags.path_prefix, "/x") # fixed by core_plugin + + def testSubcommand(self): + def define_flags(parser): + parser.add_argument("--hello") + + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand(define_flags=define_flags)], + ) + tb.configure(("tb", "test", "--hello", "world")) + self.assertEqual(tb.main(), 0) + _TestSubcommand.run.assert_called_once() + flags = _TestSubcommand.run.call_args[0][0] + self.assertEqual(flags.hello, "world") + + def testSubcommand_ExitCode(self): + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand()], + ) + _TestSubcommand.run.return_value = 77 + tb.configure(("tb", "test")) + self.assertEqual(tb.main(), 77) + + def testSubcommand_DoesNotInheritBaseArgs(self): + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand()], + ) + with self.assertRaises(SystemExit): + tb.configure(("tb", "test", "--logdir", 
"logs")) + self.assertIn( + "unrecognized arguments: --logdir logs", self.stderr.getvalue() + ) + self.stderr.truncate(0) + + def testSubcommand_MayRequirePositionals(self): + def define_flags(parser): + parser.add_argument("payload") + + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand(define_flags=define_flags)], + ) + with self.assertRaises(SystemExit): + tb.configure(("tb", "test")) + self.assertIn("required", self.stderr.getvalue()) + self.assertIn("payload", self.stderr.getvalue()) + self.stderr.truncate(0) + + def testConflictingNames_AmongSubcommands(self): + with self.assertRaises(ValueError) as cm: + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand(), _TestSubcommand()], + ) + self.assertIn("Duplicate subcommand name:", str(cm.exception)) + self.assertIn("test", str(cm.exception)) + + def testConflictingNames_WithServe(self): + with self.assertRaises(ValueError) as cm: + tb = program.TensorBoard( + plugins=[core_plugin.CorePluginLoader], + subcommands=[_TestSubcommand(name="serve")], + ) + self.assertIn("Duplicate subcommand name:", str(cm.exception)) + self.assertIn("serve", str(cm.exception)) class _TestSubcommand(program.TensorBoardSubcommand): + def __init__(self, name=None, define_flags=None): + self._name = name + self._define_flags = define_flags - def __init__(self, name=None, define_flags=None): - self._name = name - self._define_flags = define_flags - - def name(self): - return self._name or 'test' + def name(self): + return self._name or "test" - def define_flags(self, parser): - if self._define_flags: - self._define_flags(parser) + def define_flags(self, parser): + if self._define_flags: + self._define_flags(parser) - def run(self, flags): - pass + def run(self, flags): + pass -if __name__ == '__main__': - tb_test.main() +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/scripts/execrooter.py 
b/tensorboard/scripts/execrooter.py index 76bc900b80..526e560156 100644 --- a/tensorboard/scripts/execrooter.py +++ b/tensorboard/scripts/execrooter.py @@ -26,69 +26,71 @@ def run(inputs, program, outputs): - """Creates temp symlink tree, runs program, and copies back outputs. + """Creates temp symlink tree, runs program, and copies back outputs. - Args: - inputs: List of fake paths to real paths, which are used for symlink tree. - program: List containing real path of program and its arguments. The - execroot directory will be appended as the last argument. - outputs: List of fake outputted paths to copy back to real paths. - Returns: - 0 if succeeded or nonzero if failed. - """ - root = tempfile.mkdtemp() - try: - cwd = os.getcwd() - for fake, real in inputs: - parent = os.path.join(root, os.path.dirname(fake)) - if not os.path.exists(parent): - os.makedirs(parent) - # Use symlink if possible and not on Windows, since on Windows 10 - # symlinks exist but they require administrator privileges to use. - if hasattr(os, 'symlink') and not os.name == 'nt': - os.symlink(os.path.join(cwd, real), os.path.join(root, fake)) - else: - shutil.copyfile(os.path.join(cwd, real), os.path.join(root, fake)) - if subprocess.call(program + [root]) != 0: - return 1 - for fake, real in outputs: - shutil.copyfile(os.path.join(root, fake), real) - return 0 - finally: + Args: + inputs: List of fake paths to real paths, which are used for symlink tree. + program: List containing real path of program and its arguments. The + execroot directory will be appended as the last argument. + outputs: List of fake outputted paths to copy back to real paths. + Returns: + 0 if succeeded or nonzero if failed. + """ + root = tempfile.mkdtemp() try: - shutil.rmtree(root) - except EnvironmentError: - # Ignore "file in use" errors on Windows; ok since it's just a tmpdir. 
- pass + cwd = os.getcwd() + for fake, real in inputs: + parent = os.path.join(root, os.path.dirname(fake)) + if not os.path.exists(parent): + os.makedirs(parent) + # Use symlink if possible and not on Windows, since on Windows 10 + # symlinks exist but they require administrator privileges to use. + if hasattr(os, "symlink") and not os.name == "nt": + os.symlink(os.path.join(cwd, real), os.path.join(root, fake)) + else: + shutil.copyfile( + os.path.join(cwd, real), os.path.join(root, fake) + ) + if subprocess.call(program + [root]) != 0: + return 1 + for fake, real in outputs: + shutil.copyfile(os.path.join(root, fake), real) + return 0 + finally: + try: + shutil.rmtree(root) + except EnvironmentError: + # Ignore "file in use" errors on Windows; ok since it's just a tmpdir. + pass def main(args): - """Invokes run function using a JSON file config. + """Invokes run function using a JSON file config. - Args: - args: CLI args, which can be a JSON file containing an object whose - attributes are the parameters to the run function. If multiple JSON - files are passed, their contents are concatenated. - Returns: - 0 if succeeded or nonzero if failed. - Raises: - Exception: If input data is missing. - """ - if not args: - raise Exception('Please specify at least one JSON config path') - inputs = [] - program = [] - outputs = [] - for arg in args: - with open(arg) as fd: - config = json.load(fd) - inputs.extend(config.get('inputs', [])) - program.extend(config.get('program', [])) - outputs.extend(config.get('outputs', [])) - if not program: - raise Exception('Please specify a program') - return run(inputs, program, outputs) + Args: + args: CLI args, which can be a JSON file containing an object whose + attributes are the parameters to the run function. If multiple JSON + files are passed, their contents are concatenated. + Returns: + 0 if succeeded or nonzero if failed. + Raises: + Exception: If input data is missing. 
+ """ + if not args: + raise Exception("Please specify at least one JSON config path") + inputs = [] + program = [] + outputs = [] + for arg in args: + with open(arg) as fd: + config = json.load(fd) + inputs.extend(config.get("inputs", [])) + program.extend(config.get("program", [])) + outputs.extend(config.get("outputs", [])) + if not program: + raise Exception("Please specify a program") + return run(inputs, program, outputs) -if __name__ == '__main__': - sys.exit(main(sys.argv[1:])) +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/tensorboard/scripts/generate_testdata.py b/tensorboard/scripts/generate_testdata.py index 69d246a2ce..a10a4b6be8 100644 --- a/tensorboard/scripts/generate_testdata.py +++ b/tensorboard/scripts/generate_testdata.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Generate some standard test data for debugging TensorBoard. -""" +"""Generate some standard test data for debugging TensorBoard.""" from __future__ import absolute_import from __future__ import division @@ -33,11 +32,19 @@ import tensorflow as tf -flags.DEFINE_string("target", None, """The directory where serialized data -will be written""") +flags.DEFINE_string( + "target", + None, + """The directory where serialized data +will be written""", +) -flags.DEFINE_boolean("overwrite", False, """Whether to remove and overwrite -TARGET if it already exists.""") +flags.DEFINE_boolean( + "overwrite", + False, + """Whether to remove and overwrite +TARGET if it already exists.""", +) FLAGS = flags.FLAGS @@ -47,182 +54,194 @@ def _MakeHistogramBuckets(): - v = 1E-12 - buckets = [] - neg_buckets = [] - while v < 1E20: - buckets.append(v) - neg_buckets.append(-v) - v *= 1.1 - # Should include DBL_MAX, but won't bother for test data. 
- return neg_buckets[::-1] + [0] + buckets + v = 1e-12 + buckets = [] + neg_buckets = [] + while v < 1e20: + buckets.append(v) + neg_buckets.append(-v) + v *= 1.1 + # Should include DBL_MAX, but won't bother for test data. + return neg_buckets[::-1] + [0] + buckets def _MakeHistogram(values): - """Convert values into a histogram proto using logic from histogram.cc.""" - limits = _MakeHistogramBuckets() - counts = [0] * len(limits) - for v in values: - idx = bisect.bisect_left(limits, v) - counts[idx] += 1 - - limit_counts = [(limits[i], counts[i]) for i in xrange(len(limits)) - if counts[i]] - bucket_limit = [lc[0] for lc in limit_counts] - bucket = [lc[1] for lc in limit_counts] - sum_sq = sum(v * v for v in values) - return tf.compat.v1.HistogramProto( - min=min(values), - max=max(values), - num=len(values), - sum=sum(values), - sum_squares=sum_sq, - bucket_limit=bucket_limit, - bucket=bucket) + """Convert values into a histogram proto using logic from histogram.cc.""" + limits = _MakeHistogramBuckets() + counts = [0] * len(limits) + for v in values: + idx = bisect.bisect_left(limits, v) + counts[idx] += 1 + + limit_counts = [ + (limits[i], counts[i]) for i in xrange(len(limits)) if counts[i] + ] + bucket_limit = [lc[0] for lc in limit_counts] + bucket = [lc[1] for lc in limit_counts] + sum_sq = sum(v * v for v in values) + return tf.compat.v1.HistogramProto( + min=min(values), + max=max(values), + num=len(values), + sum=sum(values), + sum_squares=sum_sq, + bucket_limit=bucket_limit, + bucket=bucket, + ) def WriteScalarSeries(writer, tag, f, n=5): - """Write a series of scalar events to writer, using f to create values.""" - step = 0 - wall_time = _start_time - for i in xrange(n): - v = f(i) - value = tf.Summary.Value(tag=tag, simple_value=v) - summary = tf.Summary(value=[value]) - event = tf.Event(wall_time=wall_time, step=step, summary=summary) - writer.add_event(event) - step += 1 - wall_time += 10 + """Write a series of scalar events to writer, using f to 
create values.""" + step = 0 + wall_time = _start_time + for i in xrange(n): + v = f(i) + value = tf.Summary.Value(tag=tag, simple_value=v) + summary = tf.Summary(value=[value]) + event = tf.Event(wall_time=wall_time, step=step, summary=summary) + writer.add_event(event) + step += 1 + wall_time += 10 def WriteHistogramSeries(writer, tag, mu_sigma_tuples, n=20): - """Write a sequence of normally distributed histograms to writer.""" - step = 0 - wall_time = _start_time - for [mean, stddev] in mu_sigma_tuples: - data = [random.normalvariate(mean, stddev) for _ in xrange(n)] - histo = _MakeHistogram(data) - summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=histo)]) - event = tf.Event(wall_time=wall_time, step=step, summary=summary) - writer.add_event(event) - step += 10 - wall_time += 100 + """Write a sequence of normally distributed histograms to writer.""" + step = 0 + wall_time = _start_time + for [mean, stddev] in mu_sigma_tuples: + data = [random.normalvariate(mean, stddev) for _ in xrange(n)] + histo = _MakeHistogram(data) + summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=histo)]) + event = tf.Event(wall_time=wall_time, step=step, summary=summary) + writer.add_event(event) + step += 10 + wall_time += 100 def WriteImageSeries(writer, tag, n_images=1): - """Write a few dummy images to writer.""" - step = 0 - session = tf.compat.v1.Session() - p = tf.compat.v1.placeholder("uint8", (1, 4, 4, 3)) - s = tf.compat.v1.summary.image(tag, p) - for _ in xrange(n_images): - im = np.random.random_integers(0, 255, (1, 4, 4, 3)) - summ = session.run(s, feed_dict={p: im}) - writer.add_summary(summ, step) - step += 20 - session.close() + """Write a few dummy images to writer.""" + step = 0 + session = tf.compat.v1.Session() + p = tf.compat.v1.placeholder("uint8", (1, 4, 4, 3)) + s = tf.compat.v1.summary.image(tag, p) + for _ in xrange(n_images): + im = np.random.random_integers(0, 255, (1, 4, 4, 3)) + summ = session.run(s, feed_dict={p: im}) + 
writer.add_summary(summ, step) + step += 20 + session.close() def WriteAudioSeries(writer, tag, n_audio=1): - """Write a few dummy audio clips to writer.""" - step = 0 - session = tf.compat.v1.Session() - - min_frequency_hz = 440 - max_frequency_hz = 880 - sample_rate = 4000 - duration_frames = sample_rate // 2 # 0.5 seconds. - frequencies_per_run = 1 - num_channels = 2 - - p = tf.compat.v1.placeholder("float32", (frequencies_per_run, duration_frames, - num_channels)) - s = tf.compat.v1.summary.audio(tag, p, sample_rate) - - for _ in xrange(n_audio): - # Generate a different frequency for each channel to show stereo works. - frequencies = np.random.random_integers( - min_frequency_hz, - max_frequency_hz, - size=(frequencies_per_run, num_channels)) - tiled_frequencies = np.tile(frequencies, (1, duration_frames)) - tiled_increments = np.tile( - np.arange(0, duration_frames), - (num_channels, 1)).T.reshape(1, duration_frames * num_channels) - tones = np.sin(2.0 * np.pi * tiled_frequencies * tiled_increments / - sample_rate) - tones = tones.reshape(frequencies_per_run, duration_frames, num_channels) - - summ = session.run(s, feed_dict={p: tones}) - writer.add_summary(summ, step) - step += 20 - session.close() + """Write a few dummy audio clips to writer.""" + step = 0 + session = tf.compat.v1.Session() + + min_frequency_hz = 440 + max_frequency_hz = 880 + sample_rate = 4000 + duration_frames = sample_rate // 2 # 0.5 seconds. + frequencies_per_run = 1 + num_channels = 2 + + p = tf.compat.v1.placeholder( + "float32", (frequencies_per_run, duration_frames, num_channels) + ) + s = tf.compat.v1.summary.audio(tag, p, sample_rate) + + for _ in xrange(n_audio): + # Generate a different frequency for each channel to show stereo works. 
+ frequencies = np.random.random_integers( + min_frequency_hz, + max_frequency_hz, + size=(frequencies_per_run, num_channels), + ) + tiled_frequencies = np.tile(frequencies, (1, duration_frames)) + tiled_increments = np.tile( + np.arange(0, duration_frames), (num_channels, 1) + ).T.reshape(1, duration_frames * num_channels) + tones = np.sin( + 2.0 * np.pi * tiled_frequencies * tiled_increments / sample_rate + ) + tones = tones.reshape( + frequencies_per_run, duration_frames, num_channels + ) + + summ = session.run(s, feed_dict={p: tones}) + writer.add_summary(summ, step) + step += 20 + session.close() def GenerateTestData(path): - """Generates the test data directory.""" - run1_path = os.path.join(path, "run1") - os.makedirs(run1_path) - writer1 = tf.summary.FileWriter(run1_path) - WriteScalarSeries(writer1, "foo/square", lambda x: x * x) - WriteScalarSeries(writer1, "bar/square", lambda x: x * x) - WriteScalarSeries(writer1, "foo/sin", math.sin) - WriteScalarSeries(writer1, "foo/cos", math.cos) - WriteHistogramSeries(writer1, "histo1", [[0, 1], [0.3, 1], [0.5, 1], [0.7, 1], - [1, 1]]) - WriteImageSeries(writer1, "im1") - WriteImageSeries(writer1, "im2") - WriteAudioSeries(writer1, "au1") - - run2_path = os.path.join(path, "run2") - os.makedirs(run2_path) - writer2 = tf.summary.FileWriter(run2_path) - WriteScalarSeries(writer2, "foo/square", lambda x: x * x * 2) - WriteScalarSeries(writer2, "bar/square", lambda x: x * x * 3) - WriteScalarSeries(writer2, "foo/cos", lambda x: math.cos(x) * 2) - WriteHistogramSeries(writer2, "histo1", [[0, 2], [0.3, 2], [0.5, 2], [0.7, 2], - [1, 2]]) - WriteHistogramSeries(writer2, "histo2", [[0, 1], [0.3, 1], [0.5, 1], [0.7, 1], - [1, 1]]) - WriteImageSeries(writer2, "im1") - WriteAudioSeries(writer2, "au2") - - graph_def = tf.compat.v1.GraphDef() - node1 = graph_def.node.add() - node1.name = "a" - node1.op = "matmul" - node2 = graph_def.node.add() - node2.name = "b" - node2.op = "matmul" - node2.input.extend(["a:0"]) - - 
writer1.add_graph(graph_def) - node3 = graph_def.node.add() - node3.name = "c" - node3.op = "matmul" - node3.input.extend(["a:0", "b:0"]) - writer2.add_graph(graph_def) - writer1.close() - writer2.close() + """Generates the test data directory.""" + run1_path = os.path.join(path, "run1") + os.makedirs(run1_path) + writer1 = tf.summary.FileWriter(run1_path) + WriteScalarSeries(writer1, "foo/square", lambda x: x * x) + WriteScalarSeries(writer1, "bar/square", lambda x: x * x) + WriteScalarSeries(writer1, "foo/sin", math.sin) + WriteScalarSeries(writer1, "foo/cos", math.cos) + WriteHistogramSeries( + writer1, "histo1", [[0, 1], [0.3, 1], [0.5, 1], [0.7, 1], [1, 1]] + ) + WriteImageSeries(writer1, "im1") + WriteImageSeries(writer1, "im2") + WriteAudioSeries(writer1, "au1") + + run2_path = os.path.join(path, "run2") + os.makedirs(run2_path) + writer2 = tf.summary.FileWriter(run2_path) + WriteScalarSeries(writer2, "foo/square", lambda x: x * x * 2) + WriteScalarSeries(writer2, "bar/square", lambda x: x * x * 3) + WriteScalarSeries(writer2, "foo/cos", lambda x: math.cos(x) * 2) + WriteHistogramSeries( + writer2, "histo1", [[0, 2], [0.3, 2], [0.5, 2], [0.7, 2], [1, 2]] + ) + WriteHistogramSeries( + writer2, "histo2", [[0, 1], [0.3, 1], [0.5, 1], [0.7, 1], [1, 1]] + ) + WriteImageSeries(writer2, "im1") + WriteAudioSeries(writer2, "au2") + + graph_def = tf.compat.v1.GraphDef() + node1 = graph_def.node.add() + node1.name = "a" + node1.op = "matmul" + node2 = graph_def.node.add() + node2.name = "b" + node2.op = "matmul" + node2.input.extend(["a:0"]) + + writer1.add_graph(graph_def) + node3 = graph_def.node.add() + node3.name = "c" + node3.op = "matmul" + node3.input.extend(["a:0", "b:0"]) + writer2.add_graph(graph_def) + writer1.close() + writer2.close() def main(unused_argv=None): - target = FLAGS.target - if not target: - print("The --target flag is required.") - return -1 - if os.path.exists(target): - if FLAGS.overwrite: - if os.path.isdir(target): - shutil.rmtree(target) 
- else: - os.remove(target) - else: - print("Refusing to overwrite target %s without --overwrite" % target) - return -2 - GenerateTestData(target) - return 0 + target = FLAGS.target + if not target: + print("The --target flag is required.") + return -1 + if os.path.exists(target): + if FLAGS.overwrite: + if os.path.isdir(target): + shutil.rmtree(target) + else: + os.remove(target) + else: + print( + "Refusing to overwrite target %s without --overwrite" % target + ) + return -2 + GenerateTestData(target) + return 0 if __name__ == "__main__": - app.run(main) + app.run(main) diff --git a/tensorboard/summary/__init__.py b/tensorboard/summary/__init__.py index e0d9777754..29f1f55fcd 100644 --- a/tensorboard/summary/__init__.py +++ b/tensorboard/summary/__init__.py @@ -22,16 +22,16 @@ # If the V1 summary API is accessible, load and re-export it here. try: - from tensorboard.summary import v1 + from tensorboard.summary import v1 except ImportError: - pass + pass # Load the V2 summary API if accessible. try: - from tensorboard.summary import v2 - from tensorboard.summary.v2 import * + from tensorboard.summary import v2 + from tensorboard.summary.v2 import * except ImportError: - pass + pass del absolute_import, division, print_function diff --git a/tensorboard/summary/_tf/summary/__init__.py b/tensorboard/summary/_tf/summary/__init__.py index 89045f24cf..3b813b57e8 100644 --- a/tensorboard/summary/_tf/summary/__init__.py +++ b/tensorboard/summary/_tf/summary/__init__.py @@ -23,8 +23,7 @@ # docstring below is what users will see as the tf.summary docstring and in the # generated API documentation, and this is just an implementation detail. -""" -Operations for writing summary data, for use in analysis and visualization. +"""Operations for writing summary data, for use in analysis and visualization. The `tf.summary` module provides APIs for writing summary data. This data can be visualized in TensorBoard, the visualization toolkit that comes with TensorFlow. 
@@ -77,7 +76,6 @@ def my_func(step): sess.run(step_update) sess.run(writer_flush) ``` - """ from __future__ import absolute_import @@ -87,88 +85,95 @@ def my_func(step): # Keep this import outside the function below for internal sync reasons. import tensorflow as tf + def reexport_tf_summary(): - """Re-export all symbols from the original tf.summary. - - This function finds the original tf.summary V2 API and re-exports all the - symbols from it within this module as well, so that when this module is - patched into the TF API namespace as the new tf.summary, the effect is an - overlay that just adds TensorBoard-provided symbols to the module. - - Finding the original tf.summary V2 API module reliably is a challenge, since - this code runs *during* the overall TF API import process and depending on - the order of imports (which is subject to change), different parts of the API - may or may not be defined at the point in time we attempt to access them. This - code also may be inserted into two places in the API (tf and tf.compat.v2) - and may be re-executed multiple times even for the same place in the API (due - to the TF module import system not populating sys.modules properly), so it - needs to be robust to many different scenarios. - - The one constraint we can count on is that everywhere this module is loaded - (via the component_api_helper mechanism in TF), it's going to be the 'summary' - submodule of a larger API package that already has a 'summary' attribute - that contains the TF-only summary API symbols we need to re-export. This - may either be the original TF-only summary module (the first time we load - this module) or a pre-existing copy of this module (if we're re-loading this - module again). We don't actually need to differentiate those two cases, - because it's okay if we re-import our own TensorBoard-provided symbols; they - will just be overwritten later on in this file. 
- - So given that guarantee, the approach we take is to first attempt to locate - a TF V2 API package that already has a 'summary' attribute (most likely this - is the parent package into which we're being imported, but not necessarily), - and then do the dynamic version of "from tf_api_package.summary import *". - - Lastly, this logic is encapsulated in a function to avoid symbol leakage. - """ - import sys - - # API packages to check for the original V2 summary API, in preference order - # to avoid going "under the hood" to the _api packages unless necessary. - packages = [ - 'tensorflow', - 'tensorflow.compat.v2', - 'tensorflow_core._api.v2', - 'tensorflow_core._api.v2.compat.v2', - 'tensorflow_core._api.v1.compat.v2', - # Old names for `tensorflow_core._api.*`. - 'tensorflow._api.v2', - 'tensorflow._api.v2.compat.v2', - 'tensorflow._api.v1.compat.v2', - ] - # If we aren't sure we're on V2, don't use tf.summary since it could be V1. - # Note there may be false positives since the __version__ attribute may not be - # defined at this point in the import process. - if not getattr(tf, '__version__', '').startswith('2.'): # noqa: F821 - packages.remove('tensorflow') - - def dynamic_wildcard_import(module): - """Implements the logic of "from module import *" for the given module.""" - symbols = getattr(module, '__all__', None) - if symbols is None: - symbols = [k for k in module.__dict__.keys() if not k.startswith('_')] - globals().update({symbol: getattr(module, symbol) for symbol in symbols}) - - notfound = object() # sentinel value - for package_name in packages: - package = sys.modules.get(package_name, notfound) - if package is notfound: - # Either it isn't in this installation at all (e.g. 
the _api.vX packages - # are only in API version X), it isn't imported yet, or it was imported - # but not inserted into sys.modules under its user-facing name (for the - # non-'_api' packages), at which point we continue down the list to look - # "under the hood" for it via its '_api' package name. - continue - module = getattr(package, 'summary', None) - if module is None: - # This happens if the package hasn't been fully imported yet. For example, - # the 'tensorflow' package won't yet have 'summary' attribute if we are - # loading this code via the 'tensorflow.compat...' path and 'compat' is - # imported before 'summary' in the 'tensorflow' __init__.py file. - continue - # Success, we hope. Import all the public symbols into this module. - dynamic_wildcard_import(module) - return + """Re-export all symbols from the original tf.summary. + + This function finds the original tf.summary V2 API and re-exports all the + symbols from it within this module as well, so that when this module is + patched into the TF API namespace as the new tf.summary, the effect is an + overlay that just adds TensorBoard-provided symbols to the module. + + Finding the original tf.summary V2 API module reliably is a challenge, since + this code runs *during* the overall TF API import process and depending on + the order of imports (which is subject to change), different parts of the API + may or may not be defined at the point in time we attempt to access them. This + code also may be inserted into two places in the API (tf and tf.compat.v2) + and may be re-executed multiple times even for the same place in the API (due + to the TF module import system not populating sys.modules properly), so it + needs to be robust to many different scenarios. 
+ + The one constraint we can count on is that everywhere this module is loaded + (via the component_api_helper mechanism in TF), it's going to be the 'summary' + submodule of a larger API package that already has a 'summary' attribute + that contains the TF-only summary API symbols we need to re-export. This + may either be the original TF-only summary module (the first time we load + this module) or a pre-existing copy of this module (if we're re-loading this + module again). We don't actually need to differentiate those two cases, + because it's okay if we re-import our own TensorBoard-provided symbols; they + will just be overwritten later on in this file. + + So given that guarantee, the approach we take is to first attempt to locate + a TF V2 API package that already has a 'summary' attribute (most likely this + is the parent package into which we're being imported, but not necessarily), + and then do the dynamic version of "from tf_api_package.summary import *". + + Lastly, this logic is encapsulated in a function to avoid symbol leakage. + """ + import sys + + # API packages to check for the original V2 summary API, in preference order + # to avoid going "under the hood" to the _api packages unless necessary. + packages = [ + "tensorflow", + "tensorflow.compat.v2", + "tensorflow_core._api.v2", + "tensorflow_core._api.v2.compat.v2", + "tensorflow_core._api.v1.compat.v2", + # Old names for `tensorflow_core._api.*`. + "tensorflow._api.v2", + "tensorflow._api.v2.compat.v2", + "tensorflow._api.v1.compat.v2", + ] + # If we aren't sure we're on V2, don't use tf.summary since it could be V1. + # Note there may be false positives since the __version__ attribute may not be + # defined at this point in the import process. 
+ if not getattr(tf, "__version__", "").startswith("2."): # noqa: F821 + packages.remove("tensorflow") + + def dynamic_wildcard_import(module): + """Implements the logic of "from module import *" for the given + module.""" + symbols = getattr(module, "__all__", None) + if symbols is None: + symbols = [ + k for k in module.__dict__.keys() if not k.startswith("_") + ] + globals().update( + {symbol: getattr(module, symbol) for symbol in symbols} + ) + + notfound = object() # sentinel value + for package_name in packages: + package = sys.modules.get(package_name, notfound) + if package is notfound: + # Either it isn't in this installation at all (e.g. the _api.vX packages + # are only in API version X), it isn't imported yet, or it was imported + # but not inserted into sys.modules under its user-facing name (for the + # non-'_api' packages), at which point we continue down the list to look + # "under the hood" for it via its '_api' package name. + continue + module = getattr(package, "summary", None) + if module is None: + # This happens if the package hasn't been fully imported yet. For example, + # the 'tensorflow' package won't yet have 'summary' attribute if we are + # loading this code via the 'tensorflow.compat...' path and 'compat' is + # imported before 'summary' in the 'tensorflow' __init__.py file. + continue + # Success, we hope. Import all the public symbols into this module. 
+ dynamic_wildcard_import(module) + return + reexport_tf_summary() diff --git a/tensorboard/summary/summary_test.py b/tensorboard/summary/summary_test.py index 8a2a8aa300..62a1448017 100644 --- a/tensorboard/summary/summary_test.py +++ b/tensorboard/summary/summary_test.py @@ -33,91 +33,89 @@ class SummaryExportsBaseTest(object): - module = None - plugins = None - allowed = frozenset() - - def test_each_plugin_has_an_export(self): - for plugin in self.plugins: - self.assertIsInstance(getattr(self.module, plugin), collections.Callable) - - def test_plugins_export_pb_functions(self): - for plugin in self.plugins: - self.assertIsInstance( - getattr(self.module, '%s_pb' % plugin), collections.Callable) - - def test_all_exports_correspond_to_plugins(self): - exports = [name for name in dir(self.module) if not name.startswith('_')] - bad_exports = [ - name for name in exports - if name not in self.allowed and not any( - name == plugin or name.startswith('%s_' % plugin) - for plugin in self.plugins) - ] - if bad_exports: - self.fail( - 'The following exports do not correspond to known standard ' - 'plugins: %r. Please mark these as private by prepending an ' - 'underscore to their names, or, if they correspond to a new ' - 'plugin that you are certain should be part of the public API ' - 'forever, add that plugin to the STANDARD_PLUGINS set in this ' - 'module.' 
% bad_exports) + module = None + plugins = None + allowed = frozenset() + + def test_each_plugin_has_an_export(self): + for plugin in self.plugins: + self.assertIsInstance( + getattr(self.module, plugin), collections.Callable + ) + + def test_plugins_export_pb_functions(self): + for plugin in self.plugins: + self.assertIsInstance( + getattr(self.module, "%s_pb" % plugin), collections.Callable + ) + + def test_all_exports_correspond_to_plugins(self): + exports = [ + name for name in dir(self.module) if not name.startswith("_") + ] + bad_exports = [ + name + for name in exports + if name not in self.allowed + and not any( + name == plugin or name.startswith("%s_" % plugin) + for plugin in self.plugins + ) + ] + if bad_exports: + self.fail( + "The following exports do not correspond to known standard " + "plugins: %r. Please mark these as private by prepending an " + "underscore to their names, or, if they correspond to a new " + "plugin that you are certain should be part of the public API " + "forever, add that plugin to the STANDARD_PLUGINS set in this " + "module." 
% bad_exports + ) class SummaryExportsTest(SummaryExportsBaseTest, unittest.TestCase): - module = tb_summary - allowed = frozenset(('v1', 'v2')) - plugins = frozenset([ - 'audio', - 'histogram', - 'image', - 'scalar', - 'text', - ]) + module = tb_summary + allowed = frozenset(("v1", "v2")) + plugins = frozenset(["audio", "histogram", "image", "scalar", "text",]) - def test_plugins_export_pb_functions(self): - self.skipTest('V2 summary API _pb functions are not finalized yet') + def test_plugins_export_pb_functions(self): + self.skipTest("V2 summary API _pb functions are not finalized yet") class SummaryExportsV1Test(SummaryExportsBaseTest, unittest.TestCase): - module = tb_summary_v1 - plugins = frozenset([ - 'audio', - 'custom_scalar', - 'histogram', - 'image', - 'pr_curve', - 'scalar', - 'text', - ]) + module = tb_summary_v1 + plugins = frozenset( + [ + "audio", + "custom_scalar", + "histogram", + "image", + "pr_curve", + "scalar", + "text", + ] + ) class SummaryExportsV2Test(SummaryExportsBaseTest, unittest.TestCase): - module = tb_summary_v2 - plugins = frozenset([ - 'audio', - 'histogram', - 'image', - 'scalar', - 'text', - ]) + module = tb_summary_v2 + plugins = frozenset(["audio", "histogram", "image", "scalar", "text",]) - def test_plugins_export_pb_functions(self): - self.skipTest('V2 summary API _pb functions are not finalized yet') + def test_plugins_export_pb_functions(self): + self.skipTest("V2 summary API _pb functions are not finalized yet") class SummaryDepTest(unittest.TestCase): - - def test_no_tf_dep(self): - # Check as a precondition that TF wasn't already imported. - self.assertEqual('notfound', sys.modules.get('tensorflow', 'notfound')) - # Check that referencing summary API symbols still avoids a TF import - # as long as we don't actually invoke any API functions. 
- for module in (tb_summary, tb_summary_v1, tb_summary_v2): - print(dir(module)) - print(module.scalar) - self.assertEqual('notfound', sys.modules.get('tensorflow', 'notfound')) - - -if __name__ == '__main__': - unittest.main() + def test_no_tf_dep(self): + # Check as a precondition that TF wasn't already imported. + self.assertEqual("notfound", sys.modules.get("tensorflow", "notfound")) + # Check that referencing summary API symbols still avoids a TF import + # as long as we don't actually invoke any API functions. + for module in (tb_summary, tb_summary_v1, tb_summary_v2): + print(dir(module)) + print(module.scalar) + self.assertEqual("notfound", sys.modules.get("tensorflow", "notfound")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tensorboard/summary/tf_summary_test.py b/tensorboard/summary/tf_summary_test.py index ce4c3d23c0..7a91b5dd96 100644 --- a/tensorboard/summary/tf_summary_test.py +++ b/tensorboard/summary/tf_summary_test.py @@ -31,29 +31,30 @@ class TfSummaryExportTest(unittest.TestCase): - - def test_tf_summary_export(self): - # Ensure that TF wasn't already imported, since we want this test to cover - # the entire flow of "import tensorflow; use tf.summary" and if TF was in - # fact already imported that reduces the comprehensiveness of the test. - # This means this test has to be kept in its own file and that no other - # test methods in this file should import tensorflow. - self.assertEqual('notfound', sys.modules.get('tensorflow', 'notfound')) - import tensorflow as tf - if not tf.__version__.startswith('2.'): - if hasattr(tf, 'compat') and hasattr(tf.compat, 'v2'): - tf = tf.compat.v2 - else: - self.skipTest('TF v2 summary API not available') - # Check that tf.summary contains both TB-provided and TF-provided symbols. 
- expected_symbols = frozenset( - ['scalar', 'image', 'audio', 'histogram', 'text'] - + ['write', 'create_file_writer', 'SummaryWriter']) - self.assertLessEqual(expected_symbols, frozenset(dir(tf.summary))) - # Ensure we can dereference symbols as well. - print(tf.summary.scalar) - print(tf.summary.write) - - -if __name__ == '__main__': - unittest.main() + def test_tf_summary_export(self): + # Ensure that TF wasn't already imported, since we want this test to cover + # the entire flow of "import tensorflow; use tf.summary" and if TF was in + # fact already imported that reduces the comprehensiveness of the test. + # This means this test has to be kept in its own file and that no other + # test methods in this file should import tensorflow. + self.assertEqual("notfound", sys.modules.get("tensorflow", "notfound")) + import tensorflow as tf + + if not tf.__version__.startswith("2."): + if hasattr(tf, "compat") and hasattr(tf.compat, "v2"): + tf = tf.compat.v2 + else: + self.skipTest("TF v2 summary API not available") + # Check that tf.summary contains both TB-provided and TF-provided symbols. + expected_symbols = frozenset( + ["scalar", "image", "audio", "histogram", "text"] + + ["write", "create_file_writer", "SummaryWriter"] + ) + self.assertLessEqual(expected_symbols, frozenset(dir(tf.summary))) + # Ensure we can dereference symbols as well. 
+ print(tf.summary.scalar) + print(tf.summary.write) + + +if __name__ == "__main__": + unittest.main() diff --git a/tensorboard/summary/writer/event_file_writer.py b/tensorboard/summary/writer/event_file_writer.py index adecd6823c..5fefcc5cab 100644 --- a/tensorboard/summary/writer/event_file_writer.py +++ b/tensorboard/summary/writer/event_file_writer.py @@ -31,16 +31,16 @@ class AtomicCounter(object): - def __init__(self, initial_value): - self._value = initial_value - self._lock = threading.Lock() + def __init__(self, initial_value): + self._value = initial_value + self._lock = threading.Lock() - def get(self): - with self._lock: - try: - return self._value - finally: - self._value += 1 + def get(self): + with self._lock: + try: + return self._value + finally: + self._value += 1 _global_uid = AtomicCounter(0) @@ -49,12 +49,15 @@ def get(self): class EventFileWriter(object): """Writes `Event` protocol buffers to an event file. - The `EventFileWriter` class creates an event file in the specified directory, - and asynchronously writes Event protocol buffers to the file. The Event file - is encoded using the tfrecord format, which is similar to RecordIO. + The `EventFileWriter` class creates an event file in the specified + directory, and asynchronously writes Event protocol buffers to the + file. The Event file is encoded using the tfrecord format, which is + similar to RecordIO. """ - def __init__(self, logdir, max_queue_size=10, flush_secs=120, filename_suffix=''): + def __init__( + self, logdir, max_queue_size=10, flush_secs=120, filename_suffix="" + ): """Creates a `EventFileWriter` and an event file to write to. On construction the summary writer creates a new event file in `logdir`. 
@@ -72,13 +75,28 @@ def __init__(self, logdir, max_queue_size=10, flush_secs=120, filename_suffix='' self._logdir = logdir if not tf.io.gfile.exists(logdir): tf.io.gfile.makedirs(logdir) - self._file_name = os.path.join(logdir, "events.out.tfevents.%010d.%s.%s.%s" % - (time.time(), socket.gethostname(), os.getpid(), _global_uid.get())) + filename_suffix # noqa E128 - self._general_file_writer = tf.io.gfile.GFile(self._file_name, 'wb') - self._async_writer = _AsyncWriter(RecordWriter(self._general_file_writer), max_queue_size, flush_secs) + self._file_name = ( + os.path.join( + logdir, + "events.out.tfevents.%010d.%s.%s.%s" + % ( + time.time(), + socket.gethostname(), + os.getpid(), + _global_uid.get(), + ), + ) + + filename_suffix + ) # noqa E128 + self._general_file_writer = tf.io.gfile.GFile(self._file_name, "wb") + self._async_writer = _AsyncWriter( + RecordWriter(self._general_file_writer), max_queue_size, flush_secs + ) # Initialize an event instance. - _event = event_pb2.Event(wall_time=time.time(), file_version='brain.Event:2') + _event = event_pb2.Event( + wall_time=time.time(), file_version="brain.Event:2" + ) self.add_event(_event) self.flush() @@ -93,37 +111,42 @@ def add_event(self, event): event: An `Event` protocol buffer. """ if not isinstance(event, event_pb2.Event): - raise TypeError("Expected an event_pb2.Event proto, " - " but got %s" % type(event)) + raise TypeError( + "Expected an event_pb2.Event proto, " + " but got %s" % type(event) + ) self._async_writer.write(event.SerializeToString()) def flush(self): """Flushes the event file to disk. - Call this method to make sure that all pending events have been written to - disk. + Call this method to make sure that all pending events have been + written to disk. """ self._async_writer.flush() def close(self): """Performs a final flush of the event file to disk, stops the - write/flush worker and closes the file. Call this method when you do not - need the summary writer anymore. 
+ write/flush worker and closes the file. + + Call this method when you do not need the summary writer + anymore. """ self._async_writer.close() class _AsyncWriter(object): - '''Writes bytes to a file.''' + """Writes bytes to a file.""" def __init__(self, record_writer, max_queue_size=20, flush_secs=120): - """Writes bytes to a file asynchronously. - An instance of this class holds a queue to keep the incoming data temporarily. - Data passed to the `write` function will be put to the queue and the function - returns immediately. This class also maintains a thread to write data in the + """Writes bytes to a file asynchronously. An instance of this class + holds a queue to keep the incoming data temporarily. Data passed to the + `write` function will be put to the queue and the function returns + immediately. This class also maintains a thread to write data in the queue to disk. The first initialization parameter is an instance of - `tensorboard.summary.record_writer` which computes the CRC checksum and then write - the combined result to the disk. So we use an async approach to improve performance. + `tensorboard.summary.record_writer` which computes the CRC checksum and + then write the combined result to the disk. So we use an async approach + to improve performance. 
Args: record_writer: A RecordWriter instance @@ -134,29 +157,32 @@ def __init__(self, record_writer, max_queue_size=20, flush_secs=120): self._writer = record_writer self._closed = False self._byte_queue = six.moves.queue.Queue(max_queue_size) - self._worker = _AsyncWriterThread(self._byte_queue, self._writer, flush_secs) + self._worker = _AsyncWriterThread( + self._byte_queue, self._writer, flush_secs + ) self._lock = threading.Lock() self._worker.start() def write(self, bytestring): - '''Enqueue the given bytes to be written asychronously''' + """Enqueue the given bytes to be written asychronously.""" with self._lock: if self._closed: - raise IOError('Writer is closed') + raise IOError("Writer is closed") self._byte_queue.put(bytestring) def flush(self): - '''Write all the enqueued bytestring before this flush call to disk. + """Write all the enqueued bytestring before this flush call to disk. + Block until all the above bytestring are written. - ''' + """ with self._lock: if self._closed: - raise IOError('Writer is closed') + raise IOError("Writer is closed") self._byte_queue.join() self._writer.flush() def close(self): - '''Closes the underlying writer, flushing any pending writes first.''' + """Closes the underlying writer, flushing any pending writes first.""" if not self._closed: with self._lock: if not self._closed: @@ -171,6 +197,7 @@ class _AsyncWriterThread(threading.Thread): def __init__(self, queue, record_writer, flush_secs): """Creates an _AsyncWriterThread. + Args: queue: A Queue from which to dequeue data. record_writer: An instance of record_writer writer. 
diff --git a/tensorboard/summary/writer/event_file_writer_s3_test.py b/tensorboard/summary/writer/event_file_writer_s3_test.py index 920cefc1d7..e2d509a9f2 100644 --- a/tensorboard/summary/writer/event_file_writer_s3_test.py +++ b/tensorboard/summary/writer/event_file_writer_s3_test.py @@ -28,7 +28,9 @@ from tensorboard.compat.proto import event_pb2 from tensorboard.compat.proto.summary_pb2 import Summary from tensorboard.compat.tensorflow_stub.io import gfile -from tensorboard.compat.tensorflow_stub.pywrap_tensorflow import PyRecordReader_New +from tensorboard.compat.tensorflow_stub.pywrap_tensorflow import ( + PyRecordReader_New, +) from moto import mock_s3 from tensorboard import test as tb_test @@ -37,62 +39,64 @@ os.environ.setdefault("AWS_ACCESS_KEY_ID", "foobar_key") os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "foobar_secret") -USING_REAL_TF = tf.__version__ != 'stub' +USING_REAL_TF = tf.__version__ != "stub" -def s3_temp_dir(top_directory='top_dir', bucket_name='test', - region_name='us-east-1'): - """Creates a test S3 bucket and returns directory location. +def s3_temp_dir( + top_directory="top_dir", bucket_name="test", region_name="us-east-1" +): + """Creates a test S3 bucket and returns directory location. - Args: - top_directory: The path of the top level S3 directory in which - to create the directory structure. Defaults to 'top_dir'. - bucket_name: The S3 bucket name. Defaults to 'test'. - region_name: The S3 region name. Defaults to 'us-east-1'. + Args: + top_directory: The path of the top level S3 directory in which + to create the directory structure. Defaults to 'top_dir'. + bucket_name: The S3 bucket name. Defaults to 'test'. + region_name: The S3 region name. Defaults to 'us-east-1'. 
- Returns S3 URL of the top directory in the form 's3://bucket/path' - """ - s3_url = 's3://{}/{}'.format(bucket_name, top_directory) - client = boto3.client('s3', region_name=region_name) - client.create_bucket(Bucket=bucket_name) - return s3_url + Returns S3 URL of the top directory in the form 's3://bucket/path' + """ + s3_url = "s3://{}/{}".format(bucket_name, top_directory) + client = boto3.client("s3", region_name=region_name) + client.create_bucket(Bucket=bucket_name) + return s3_url def s3_join(*args): """Joins an S3 directory path as a replacement for os.path.join.""" - return '/'.join(args) + return "/".join(args) class EventFileWriterTest(tb_test.TestCase): - - @unittest.skipIf(USING_REAL_TF, 'Test only passes when using stub TF') - @mock_s3 - def test_event_file_writer_roundtrip(self): - _TAGNAME = 'dummy' - _DUMMY_VALUE = 42 - logdir = s3_temp_dir() - w = EventFileWriter(logdir) - summary = Summary(value=[Summary.Value(tag=_TAGNAME, simple_value=_DUMMY_VALUE)]) - fakeevent = event_pb2.Event(summary=summary) - w.add_event(fakeevent) - w.close() - event_files = sorted(gfile.glob(s3_join(logdir, '*'))) - self.assertEqual(len(event_files), 1) - r = PyRecordReader_New(event_files[0]) - r.GetNext() # meta data, so skip - r.GetNext() - self.assertEqual(fakeevent.SerializeToString(), r.record()) - - @unittest.skipIf(USING_REAL_TF, 'Test only passes when using stub TF') - @mock_s3 - def test_setting_filename_suffix_works(self): - logdir = s3_temp_dir() - - w = EventFileWriter(logdir, filename_suffix='.event_horizon') - w.close() - event_files = sorted(gfile.glob(s3_join(logdir, '*'))) - self.assertEqual(event_files[0].split('.')[-1], 'event_horizon') - - -if __name__ == '__main__': - tb_test.main() + @unittest.skipIf(USING_REAL_TF, "Test only passes when using stub TF") + @mock_s3 + def test_event_file_writer_roundtrip(self): + _TAGNAME = "dummy" + _DUMMY_VALUE = 42 + logdir = s3_temp_dir() + w = EventFileWriter(logdir) + summary = Summary( + 
value=[Summary.Value(tag=_TAGNAME, simple_value=_DUMMY_VALUE)] + ) + fakeevent = event_pb2.Event(summary=summary) + w.add_event(fakeevent) + w.close() + event_files = sorted(gfile.glob(s3_join(logdir, "*"))) + self.assertEqual(len(event_files), 1) + r = PyRecordReader_New(event_files[0]) + r.GetNext() # meta data, so skip + r.GetNext() + self.assertEqual(fakeevent.SerializeToString(), r.record()) + + @unittest.skipIf(USING_REAL_TF, "Test only passes when using stub TF") + @mock_s3 + def test_setting_filename_suffix_works(self): + logdir = s3_temp_dir() + + w = EventFileWriter(logdir, filename_suffix=".event_horizon") + w.close() + event_files = sorted(gfile.glob(s3_join(logdir, "*"))) + self.assertEqual(event_files[0].split(".")[-1], "event_horizon") + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/summary/writer/event_file_writer_test.py b/tensorboard/summary/writer/event_file_writer_test.py index acbef412b8..dee0ba84ba 100644 --- a/tensorboard/summary/writer/event_file_writer_test.py +++ b/tensorboard/summary/writer/event_file_writer_test.py @@ -26,102 +26,112 @@ from tensorboard.summary.writer.event_file_writer import _AsyncWriter from tensorboard.compat.proto import event_pb2 from tensorboard.compat.proto.summary_pb2 import Summary -from tensorboard.compat.tensorflow_stub.pywrap_tensorflow import PyRecordReader_New +from tensorboard.compat.tensorflow_stub.pywrap_tensorflow import ( + PyRecordReader_New, +) from tensorboard import test as tb_test class EventFileWriterTest(tb_test.TestCase): - - def test_event_file_writer_roundtrip(self): - _TAGNAME = 'dummy' - _DUMMY_VALUE = 42 - logdir = self.get_temp_dir() - w = EventFileWriter(logdir) - summary = Summary(value=[Summary.Value(tag=_TAGNAME, simple_value=_DUMMY_VALUE)]) - fakeevent = event_pb2.Event(summary=summary) - w.add_event(fakeevent) - w.close() - event_files = sorted(glob.glob(os.path.join(logdir, '*'))) - self.assertEqual(len(event_files), 1) - r = 
PyRecordReader_New(event_files[0]) - r.GetNext() # meta data, so skip - r.GetNext() - self.assertEqual(fakeevent.SerializeToString(), r.record()) - - def test_setting_filename_suffix_works(self): - logdir = self.get_temp_dir() - - w = EventFileWriter(logdir, filename_suffix='.event_horizon') - w.close() - event_files = sorted(glob.glob(os.path.join(logdir, '*'))) - self.assertEqual(event_files[0].split('.')[-1], 'event_horizon') - - def test_async_writer_without_write(self): - logdir = self.get_temp_dir() - w = EventFileWriter(logdir) - w.close() - event_files = sorted(glob.glob(os.path.join(logdir, '*'))) - r = PyRecordReader_New(event_files[0]) - r.GetNext() - s = event_pb2.Event.FromString(r.record()) - self.assertEqual(s.file_version, "brain.Event:2") + def test_event_file_writer_roundtrip(self): + _TAGNAME = "dummy" + _DUMMY_VALUE = 42 + logdir = self.get_temp_dir() + w = EventFileWriter(logdir) + summary = Summary( + value=[Summary.Value(tag=_TAGNAME, simple_value=_DUMMY_VALUE)] + ) + fakeevent = event_pb2.Event(summary=summary) + w.add_event(fakeevent) + w.close() + event_files = sorted(glob.glob(os.path.join(logdir, "*"))) + self.assertEqual(len(event_files), 1) + r = PyRecordReader_New(event_files[0]) + r.GetNext() # meta data, so skip + r.GetNext() + self.assertEqual(fakeevent.SerializeToString(), r.record()) + + def test_setting_filename_suffix_works(self): + logdir = self.get_temp_dir() + + w = EventFileWriter(logdir, filename_suffix=".event_horizon") + w.close() + event_files = sorted(glob.glob(os.path.join(logdir, "*"))) + self.assertEqual(event_files[0].split(".")[-1], "event_horizon") + + def test_async_writer_without_write(self): + logdir = self.get_temp_dir() + w = EventFileWriter(logdir) + w.close() + event_files = sorted(glob.glob(os.path.join(logdir, "*"))) + r = PyRecordReader_New(event_files[0]) + r.GetNext() + s = event_pb2.Event.FromString(r.record()) + self.assertEqual(s.file_version, "brain.Event:2") class 
AsyncWriterTest(tb_test.TestCase): - - def test_async_writer_write_once(self): - filename = os.path.join(self.get_temp_dir(), "async_writer_write_once") - w = _AsyncWriter(open(filename, 'wb')) - bytes_to_write = b"hello world" - w.write(bytes_to_write) - w.close() - with open(filename, 'rb') as f: - self.assertEqual(f.read(), bytes_to_write) - - def test_async_writer_write_queue_full(self): - filename = os.path.join(self.get_temp_dir(), "async_writer_write_queue_full") - w = _AsyncWriter(open(filename, 'wb')) - bytes_to_write = b"hello world" - repeat = 100 - for i in range(repeat): - w.write(bytes_to_write) - w.close() - with open(filename, 'rb') as f: - self.assertEqual(f.read(), bytes_to_write * repeat) - - def test_async_writer_write_one_slot_queue(self): - filename = os.path.join(self.get_temp_dir(), "async_writer_write_one_slot_queue") - w = _AsyncWriter(open(filename, 'wb'), max_queue_size=1) - bytes_to_write = b"hello world" - repeat = 10 # faster - for i in range(repeat): - w.write(bytes_to_write) - w.close() - with open(filename, 'rb') as f: - self.assertEqual(f.read(), bytes_to_write * repeat) - - def test_async_writer_close_triggers_flush(self): - filename = os.path.join(self.get_temp_dir(), "async_writer_close_triggers_flush") - w = _AsyncWriter(open(filename, 'wb')) - bytes_to_write = b"x" * 64 - w.write(bytes_to_write) - w.close() - with open(filename, 'rb') as f: - self.assertEqual(f.read(), bytes_to_write) - - def test_write_after_async_writer_closed(self): - filename = os.path.join(self.get_temp_dir(), "write_after_async_writer_closed") - w = _AsyncWriter(open(filename, 'wb')) - bytes_to_write = b"x" * 64 - w.write(bytes_to_write) - w.close() - - with self.assertRaises(IOError): - w.write(bytes_to_write) - # nothing is written to the file after close - with open(filename, 'rb') as f: - self.assertEqual(f.read(), bytes_to_write) - - -if __name__ == '__main__': - tb_test.main() + def test_async_writer_write_once(self): + filename = 
os.path.join(self.get_temp_dir(), "async_writer_write_once") + w = _AsyncWriter(open(filename, "wb")) + bytes_to_write = b"hello world" + w.write(bytes_to_write) + w.close() + with open(filename, "rb") as f: + self.assertEqual(f.read(), bytes_to_write) + + def test_async_writer_write_queue_full(self): + filename = os.path.join( + self.get_temp_dir(), "async_writer_write_queue_full" + ) + w = _AsyncWriter(open(filename, "wb")) + bytes_to_write = b"hello world" + repeat = 100 + for i in range(repeat): + w.write(bytes_to_write) + w.close() + with open(filename, "rb") as f: + self.assertEqual(f.read(), bytes_to_write * repeat) + + def test_async_writer_write_one_slot_queue(self): + filename = os.path.join( + self.get_temp_dir(), "async_writer_write_one_slot_queue" + ) + w = _AsyncWriter(open(filename, "wb"), max_queue_size=1) + bytes_to_write = b"hello world" + repeat = 10 # faster + for i in range(repeat): + w.write(bytes_to_write) + w.close() + with open(filename, "rb") as f: + self.assertEqual(f.read(), bytes_to_write * repeat) + + def test_async_writer_close_triggers_flush(self): + filename = os.path.join( + self.get_temp_dir(), "async_writer_close_triggers_flush" + ) + w = _AsyncWriter(open(filename, "wb")) + bytes_to_write = b"x" * 64 + w.write(bytes_to_write) + w.close() + with open(filename, "rb") as f: + self.assertEqual(f.read(), bytes_to_write) + + def test_write_after_async_writer_closed(self): + filename = os.path.join( + self.get_temp_dir(), "write_after_async_writer_closed" + ) + w = _AsyncWriter(open(filename, "wb")) + bytes_to_write = b"x" * 64 + w.write(bytes_to_write) + w.close() + + with self.assertRaises(IOError): + w.write(bytes_to_write) + # nothing is written to the file after close + with open(filename, "rb") as f: + self.assertEqual(f.read(), bytes_to_write) + + +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/summary/writer/record_writer.py b/tensorboard/summary/writer/record_writer.py index a49a6b14d1..52e14048b5 100644 
--- a/tensorboard/summary/writer/record_writer.py +++ b/tensorboard/summary/writer/record_writer.py @@ -18,7 +18,8 @@ class RecordWriter(object): - """Write encoded protobuf to a file with packing defined in tensorflow""" + """Write encoded protobuf to a file with packing defined in tensorflow.""" + def __init__(self, writer): """Open a file to keep the tensorboard records. @@ -33,9 +34,9 @@ def __init__(self, writer): # byte data[length] # uint32 masked crc of data def write(self, data): - header = struct.pack(' 1: - logging.warning("conflicting installations: %s", sorted(actual)) - found_conflict = True - - if found_conflict: - preamble = reflow( - """ + freeze = pip(["freeze", "--all"]).decode("utf-8").splitlines() + packages = {line.split(u"==")[0]: line for line in freeze} + packages_set = frozenset(packages) + + # For each of the following families, expect exactly one package to be + # installed. + expect_unique = [ + frozenset([u"tensorboard", u"tb-nightly", u"tensorflow-tensorboard",]), + frozenset( + [ + u"tensorflow", + u"tensorflow-gpu", + u"tf-nightly", + u"tf-nightly-2.0-preview", + u"tf-nightly-gpu", + u"tf-nightly-gpu-2.0-preview", + ] + ), + frozenset( + [ + u"tensorflow-estimator", + u"tensorflow-estimator-2.0-preview", + u"tf-estimator-nightly", + ] + ), + ] + + found_conflict = False + for family in expect_unique: + actual = family & packages_set + for package in actual: + logging.info("installed: %s", packages[package]) + if len(actual) == 0: + logging.warning("no installation among: %s", sorted(family)) + elif len(actual) > 1: + logging.warning("conflicting installations: %s", sorted(actual)) + found_conflict = True + + if found_conflict: + preamble = reflow( + """ Conflicting package installations found. Depending on the order of installations and uninstallations, behavior may be undefined. 
Please uninstall ALL versions of TensorFlow and TensorBoard, @@ -223,268 +226,287 @@ def installed_packages(): you use TensorBoard without TensorFlow, just reinstall the appropriate version of TensorBoard directly.) """ - ) - packages_to_uninstall = sorted( - frozenset().union(*expect_unique) & packages_set - ) - commands = [ - "pip uninstall %s" % " ".join(packages_to_uninstall), - "pip install tensorflow # or `tensorflow-gpu`, or `tf-nightly`, ...", - ] - message = "%s\n\nNamely:\n\n%s" % ( - preamble, - "\n".join("\t%s" % c for c in commands), - ) - yield Suggestion("Fix conflicting installations", message) + ) + packages_to_uninstall = sorted( + frozenset().union(*expect_unique) & packages_set + ) + commands = [ + "pip uninstall %s" % " ".join(packages_to_uninstall), + "pip install tensorflow # or `tensorflow-gpu`, or `tf-nightly`, ...", + ] + message = "%s\n\nNamely:\n\n%s" % ( + preamble, + "\n".join("\t%s" % c for c in commands), + ) + yield Suggestion("Fix conflicting installations", message) @check def tensorboard_python_version(): - from tensorboard import version - logging.info("tensorboard.version.VERSION: %r", version.VERSION) + from tensorboard import version + + logging.info("tensorboard.version.VERSION: %r", version.VERSION) @check def tensorflow_python_version(): - import tensorflow as tf - logging.info("tensorflow.__version__: %r", tf.__version__) - logging.info("tensorflow.__git_version__: %r", tf.__git_version__) + import tensorflow as tf + + logging.info("tensorflow.__version__: %r", tf.__version__) + logging.info("tensorflow.__git_version__: %r", tf.__git_version__) @check def tensorboard_binary_path(): - logging.info("which tensorboard: %r", which("tensorboard")) + logging.info("which tensorboard: %r", which("tensorboard")) @check def addrinfos(): - sgetattr("has_ipv6", None) - family = sgetattr("AF_UNSPEC", 0) - socktype = sgetattr("SOCK_STREAM", 0) - protocol = 0 - flags_loopback = sgetattr("AI_ADDRCONFIG", 0) - flags_wildcard = 
sgetattr("AI_PASSIVE", 0) + sgetattr("has_ipv6", None) + family = sgetattr("AF_UNSPEC", 0) + socktype = sgetattr("SOCK_STREAM", 0) + protocol = 0 + flags_loopback = sgetattr("AI_ADDRCONFIG", 0) + flags_wildcard = sgetattr("AI_PASSIVE", 0) - hints_loopback = (family, socktype, protocol, flags_loopback) - infos_loopback = socket.getaddrinfo(None, 0, *hints_loopback) - print("Loopback flags: %r" % (flags_loopback,)) - print("Loopback infos: %r" % (infos_loopback,)) + hints_loopback = (family, socktype, protocol, flags_loopback) + infos_loopback = socket.getaddrinfo(None, 0, *hints_loopback) + print("Loopback flags: %r" % (flags_loopback,)) + print("Loopback infos: %r" % (infos_loopback,)) - hints_wildcard = (family, socktype, protocol, flags_wildcard) - infos_wildcard = socket.getaddrinfo(None, 0, *hints_wildcard) - print("Wildcard flags: %r" % (flags_wildcard,)) - print("Wildcard infos: %r" % (infos_wildcard,)) + hints_wildcard = (family, socktype, protocol, flags_wildcard) + infos_wildcard = socket.getaddrinfo(None, 0, *hints_wildcard) + print("Wildcard flags: %r" % (flags_wildcard,)) + print("Wildcard infos: %r" % (infos_wildcard,)) @check def readable_fqdn(): - # May raise `UnicodeDecodeError` for non-ASCII hostnames: - # https://github.com/tensorflow/tensorboard/issues/682 - try: - logging.info("socket.getfqdn(): %r", socket.getfqdn()) - except UnicodeDecodeError as e: + # May raise `UnicodeDecodeError` for non-ASCII hostnames: + # https://github.com/tensorflow/tensorboard/issues/682 try: - binary_hostname = subprocess.check_output(["hostname"]).strip() - except subprocess.CalledProcessError: - binary_hostname = b"" - is_non_ascii = not all( - 0x20 <= (ord(c) if not isinstance(c, int) else c) <= 0x7E # Python 2 - for c in binary_hostname - ) - if is_non_ascii: - message = reflow( - """ + logging.info("socket.getfqdn(): %r", socket.getfqdn()) + except UnicodeDecodeError as e: + try: + binary_hostname = subprocess.check_output(["hostname"]).strip() + except 
subprocess.CalledProcessError: + binary_hostname = b"" + is_non_ascii = not all( + 0x20 + <= (ord(c) if not isinstance(c, int) else c) + <= 0x7E # Python 2 + for c in binary_hostname + ) + if is_non_ascii: + message = reflow( + """ Your computer's hostname, %r, contains bytes outside of the printable ASCII range. Some versions of Python have trouble working with such names (https://bugs.python.org/issue26227). Consider changing to a hostname that only contains printable ASCII bytes. - """ % (binary_hostname,) - ) - yield Suggestion("Use an ASCII hostname", message) - else: - message = reflow( """ + % (binary_hostname,) + ) + yield Suggestion("Use an ASCII hostname", message) + else: + message = reflow( + """ Python can't read your computer's hostname, %r. This can occur if the hostname contains non-ASCII bytes (https://bugs.python.org/issue26227). Consider changing your hostname, rebooting your machine, and rerunning this diagnosis script to see if the problem is resolved. - """ % (binary_hostname,) - ) - yield Suggestion("Use a simpler hostname", message) - raise e + """ + % (binary_hostname,) + ) + yield Suggestion("Use a simpler hostname", message) + raise e @check def stat_tensorboardinfo(): - # We don't use `manager._get_info_dir`, because (a) that requires - # TensorBoard, and (b) that creates the directory if it doesn't exist. - path = os.path.join(tempfile.gettempdir(), ".tensorboard-info") - logging.info("directory: %s", path) - try: - stat_result = os.stat(path) - except OSError as e: - if e.errno == errno.ENOENT: - # No problem; this is just fine. - logging.info(".tensorboard-info directory does not exist") - return - else: - raise - logging.info("os.stat(...): %r", stat_result) - logging.info("mode: 0o%o", stat_result.st_mode) - if stat_result.st_mode & 0o777 != 0o777: - preamble = reflow( - """ + # We don't use `manager._get_info_dir`, because (a) that requires + # TensorBoard, and (b) that creates the directory if it doesn't exist. 
+ path = os.path.join(tempfile.gettempdir(), ".tensorboard-info") + logging.info("directory: %s", path) + try: + stat_result = os.stat(path) + except OSError as e: + if e.errno == errno.ENOENT: + # No problem; this is just fine. + logging.info(".tensorboard-info directory does not exist") + return + else: + raise + logging.info("os.stat(...): %r", stat_result) + logging.info("mode: 0o%o", stat_result.st_mode) + if stat_result.st_mode & 0o777 != 0o777: + preamble = reflow( + """ The ".tensorboard-info" directory was created by an old version of TensorBoard, and its permissions are not set correctly; see issue #2010. Change that directory to be world-accessible (may require superuser privilege): """ - ) - # This error should only appear on Unices, so it's okay to use - # Unix-specific utilities and shell syntax. - quote = getattr(shlex, "quote", None) or pipes.quote # Python <3.3 - command = "chmod 777 %s" % quote(path) - message = "%s\n\n\t%s" % (preamble, command) - yield Suggestion("Fix permissions on \"%s\"" % path, message) + ) + # This error should only appear on Unices, so it's okay to use + # Unix-specific utilities and shell syntax. + quote = getattr(shlex, "quote", None) or pipes.quote # Python <3.3 + command = "chmod 777 %s" % quote(path) + message = "%s\n\n\t%s" % (preamble, command) + yield Suggestion('Fix permissions on "%s"' % path, message) @check def source_trees_without_genfiles(): - roots = list(sys.path) - if "" not in roots: - # Catch problems that would occur in a Python interactive shell - # (where `""` is prepended to `sys.path`) but not when - # `diagnose_tensorboard.py` is run as a standalone script. 
- roots.insert(0, "") - - def has_tensorboard(root): - return os.path.isfile(os.path.join(root, "tensorboard", "__init__.py")) - def has_genfiles(root): - sample_genfile = os.path.join("compat", "proto", "summary_pb2.py") - return os.path.isfile(os.path.join(root, "tensorboard", sample_genfile)) - def is_bad(root): - return has_tensorboard(root) and not has_genfiles(root) - - tensorboard_roots = [root for root in roots if has_tensorboard(root)] - bad_roots = [root for root in roots if is_bad(root)] - - logging.info( - "tensorboard_roots (%d): %r; bad_roots (%d): %r", - len(tensorboard_roots), - tensorboard_roots, - len(bad_roots), - bad_roots, - ) - - if bad_roots: - if bad_roots == [""]: - message = reflow( - """ + roots = list(sys.path) + if "" not in roots: + # Catch problems that would occur in a Python interactive shell + # (where `""` is prepended to `sys.path`) but not when + # `diagnose_tensorboard.py` is run as a standalone script. + roots.insert(0, "") + + def has_tensorboard(root): + return os.path.isfile(os.path.join(root, "tensorboard", "__init__.py")) + + def has_genfiles(root): + sample_genfile = os.path.join("compat", "proto", "summary_pb2.py") + return os.path.isfile(os.path.join(root, "tensorboard", sample_genfile)) + + def is_bad(root): + return has_tensorboard(root) and not has_genfiles(root) + + tensorboard_roots = [root for root in roots if has_tensorboard(root)] + bad_roots = [root for root in roots if is_bad(root)] + + logging.info( + "tensorboard_roots (%d): %r; bad_roots (%d): %r", + len(tensorboard_roots), + tensorboard_roots, + len(bad_roots), + bad_roots, + ) + + if bad_roots: + if bad_roots == [""]: + message = reflow( + """ Your current directory contains a `tensorboard` Python package that does not include generated files. This can happen if your current directory includes the TensorBoard source tree (e.g., you are in the TensorBoard Git repository). Consider changing to a different directory. 
""" - ) - else: - preamble = reflow( - """ + ) + else: + preamble = reflow( + """ Your Python path contains a `tensorboard` package that does not include generated files. This can happen if your current directory includes the TensorBoard source tree (e.g., you are in the TensorBoard Git repository). The following directories from your Python path may be problematic: """ - ) - roots = [] - realpaths_seen = set() - for root in bad_roots: - label = repr(root) if root else "current directory" - realpath = os.path.realpath(root) - if realpath in realpaths_seen: - # virtualenvs on Ubuntu install to both `lib` and `local/lib`; - # explicitly call out such duplicates to avoid confusion. - label += " (duplicate underlying directory)" - realpaths_seen.add(realpath) - roots.append(label) - message = "%s\n\n%s" % (preamble, "\n".join(" - %s" % s for s in roots)) - yield Suggestion("Avoid `tensorboard` packages without genfiles", message) + ) + roots = [] + realpaths_seen = set() + for root in bad_roots: + label = repr(root) if root else "current directory" + realpath = os.path.realpath(root) + if realpath in realpaths_seen: + # virtualenvs on Ubuntu install to both `lib` and `local/lib`; + # explicitly call out such duplicates to avoid confusion. + label += " (duplicate underlying directory)" + realpaths_seen.add(realpath) + roots.append(label) + message = "%s\n\n%s" % ( + preamble, + "\n".join(" - %s" % s for s in roots), + ) + yield Suggestion( + "Avoid `tensorboard` packages without genfiles", message + ) # Prefer to include this check last, as its output is long. 
@check def full_pip_freeze(): - logging.info("pip freeze --all:\n%s", pip(["freeze", "--all"]).decode("utf-8")) + logging.info( + "pip freeze --all:\n%s", pip(["freeze", "--all"]).decode("utf-8") + ) def set_up_logging(): - # Manually install handlers to prevent TensorFlow from stomping the - # default configuration if it's imported: - # https://github.com/tensorflow/tensorflow/issues/28147 - logger = logging.getLogger() - logger.setLevel(logging.INFO) - handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) - logger.addHandler(handler) + # Manually install handlers to prevent TensorFlow from stomping the + # default configuration if it's imported: + # https://github.com/tensorflow/tensorflow/issues/28147 + logger = logging.getLogger() + logger.setLevel(logging.INFO) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) + logger.addHandler(handler) def main(): - set_up_logging() - - print("### Diagnostics") - print() - - print("
") - print("Diagnostics output") - print() - - markdown_code_fence = "``````" # seems likely to be sufficient - print(markdown_code_fence) - suggestions = [] - for (i, check) in enumerate(CHECKS): - if i > 0: - print() - print("--- check: %s" % check.__name__) - try: - suggestions.extend(check()) - except Exception: - traceback.print_exc(file=sys.stdout) - pass - print(markdown_code_fence) - print() - print("
") - - for suggestion in suggestions: + set_up_logging() + + print("### Diagnostics") print() - print("### Suggestion: %s" % suggestion.headline) + + print("
") + print("Diagnostics output") print() - print(suggestion.description) - print() - print("### Next steps") - print() - if suggestions: - print(reflow( - """ + markdown_code_fence = "``````" # seems likely to be sufficient + print(markdown_code_fence) + suggestions = [] + for (i, check) in enumerate(CHECKS): + if i > 0: + print() + print("--- check: %s" % check.__name__) + try: + suggestions.extend(check()) + except Exception: + traceback.print_exc(file=sys.stdout) + pass + print(markdown_code_fence) + print() + print("
") + + for suggestion in suggestions: + print() + print("### Suggestion: %s" % suggestion.headline) + print() + print(suggestion.description) + + print() + print("### Next steps") + print() + if suggestions: + print( + reflow( + """ Please try each suggestion enumerated above to determine whether it solves your problem. If none of these suggestions works, please copy ALL of the above output, including the lines containing only backticks, into your GitHub issue or comment. Be sure to redact any sensitive information. """ - )) - else: - print(reflow( - """ + ) + ) + else: + print( + reflow( + """ No action items identified. Please copy ALL of the above output, including the lines containing only backticks, into your GitHub issue or comment. Be sure to redact any sensitive information. """ - )) + ) + ) if __name__ == "__main__": - main() + main() diff --git a/tensorboard/tools/import_google_fonts.py b/tensorboard/tools/import_google_fonts.py index c2cc739874..f6de26cb19 100644 --- a/tensorboard/tools/import_google_fonts.py +++ b/tensorboard/tools/import_google_fonts.py @@ -37,11 +37,11 @@ import tensorflow as tf ROBOTO_URLS = [ - 'https://fonts.googleapis.com/css?family=Roboto:400,300,300italic,400italic,500,500italic,700,700italic', - 'https://fonts.googleapis.com/css?family=Roboto+Mono:400,700', + "https://fonts.googleapis.com/css?family=Roboto:400,300,300italic,400italic,500,500italic,700,700italic", + "https://fonts.googleapis.com/css?family=Roboto+Mono:400,700", ] -GOOGLE_LICENSE_HTML = '''\ +GOOGLE_LICENSE_HTML = """\ -''' - -flags.DEFINE_string('urls', ';'.join(ROBOTO_URLS), - 'Google Fonts stylesheet URLs (semicolons delimited)') -flags.DEFINE_string('path', '/font-roboto/roboto.html', 'Path of HTML file') -flags.DEFINE_string('repo', 'com_google_fonts_roboto', 'Name of repository') -flags.DEFINE_string('license', 'notice', 'Bazel category of license') -flags.DEFINE_string('license_summary', 'Apache 2.0', 'License description') 
-flags.DEFINE_string('license_html', GOOGLE_LICENSE_HTML, - 'HTML @license comment') -flags.DEFINE_string('user_agent', - 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6) ' - 'AppleWebKit/537.36 (KHTML, like Gecko) ' - 'Chrome/62.0.3202.94 ' - 'Safari/537.36', - 'HTTP User-Agent header to send to Google Fonts') -flags.DEFINE_string('mirror', 'http://mirror.tensorflow.org/', - 'Mirror URL prefix') +""" + +flags.DEFINE_string( + "urls", + ";".join(ROBOTO_URLS), + "Google Fonts stylesheet URLs (semicolons delimited)", +) +flags.DEFINE_string("path", "/font-roboto/roboto.html", "Path of HTML file") +flags.DEFINE_string("repo", "com_google_fonts_roboto", "Name of repository") +flags.DEFINE_string("license", "notice", "Bazel category of license") +flags.DEFINE_string("license_summary", "Apache 2.0", "License description") +flags.DEFINE_string( + "license_html", GOOGLE_LICENSE_HTML, "HTML @license comment" +) +flags.DEFINE_string( + "user_agent", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/62.0.3202.94 " + "Safari/537.36", + "HTTP User-Agent header to send to Google Fonts", +) +flags.DEFINE_string( + "mirror", "http://mirror.tensorflow.org/", "Mirror URL prefix" +) FLAGS = flags.FLAGS -BAR = '/%s/' % ('*' * 78) -NON_REPO_PATTERN = re.compile(r'[^_a-z0-9]') -SCHEME_PATTERN = re.compile(r'https?://') -CSS_PATTERN = re.compile(r'(?:/\* (?P[-\w]+) \*/\s+)?' - r'(?P@font-face \{' - r'.*?family: [\'"]?(?P[^\'"]+)' - r'.*?src: local\([\'"]?(?P[^\'")]+)' - r'.*?url\([\'"]?(?P[^\'")]+)' - r'.*?\})', re.S) +BAR = "/%s/" % ("*" * 78) +NON_REPO_PATTERN = re.compile(r"[^_a-z0-9]") +SCHEME_PATTERN = re.compile(r"https?://") +CSS_PATTERN = re.compile( + r"(?:/\* (?P[-\w]+) \*/\s+)?" 
+ r"(?P@font-face \{" + r'.*?family: [\'"]?(?P[^\'"]+)' + r'.*?src: local\([\'"]?(?P[^\'")]+)' + r'.*?url\([\'"]?(?P[^\'")]+)' + r".*?\})", + re.S, +) def open_url(url): - ru = urlparse.urlparse(url) - pu = urlparse.ParseResult('', '', ru.path, ru.params, ru.query, ru.fragment) - if ru.scheme == 'https': - c = httplib.HTTPSConnection(ru.netloc) - else: - c = httplib.HTTPConnection(ru.netloc) - c.putrequest('GET', pu.geturl()) - c.putheader('User-Agent', FLAGS.user_agent) - c.endheaders() - return c.getresponse() + ru = urlparse.urlparse(url) + pu = urlparse.ParseResult("", "", ru.path, ru.params, ru.query, ru.fragment) + if ru.scheme == "https": + c = httplib.HTTPSConnection(ru.netloc) + else: + c = httplib.HTTPConnection(ru.netloc) + c.putrequest("GET", pu.geturl()) + c.putheader("User-Agent", FLAGS.user_agent) + c.endheaders() + return c.getresponse() def get_sha256(fp): - hasher = hashlib.sha256() - for chunk in iter(lambda: fp.read(8 * 1024), ''): - hasher.update(chunk) - return hasher.hexdigest() + hasher = hashlib.sha256() + for chunk in iter(lambda: fp.read(8 * 1024), ""): + hasher.update(chunk) + return hasher.hexdigest() def get_mirror_url(original): - return SCHEME_PATTERN.sub(FLAGS.mirror, original) + return SCHEME_PATTERN.sub(FLAGS.mirror, original) def get_css(m): - url = m.group('url') - path = os.path.dirname(FLAGS.path) + url[url.rindex('/'):] - return m.group('css').replace(url, path) + url = m.group("url") + path = os.path.dirname(FLAGS.path) + url[url.rindex("/") :] + return m.group("css").replace(url, path) def underify(g): - return '_'.join(NON_REPO_PATTERN.sub('_', s.lower()) for s in g if s) + return "_".join(NON_REPO_PATTERN.sub("_", s.lower()) for s in g if s) def add_inline_file(lines, inner_lines): - for line in inner_lines: - lines.append(' %r,' % line) + for line in inner_lines: + lines.append(" %r," % line) def get_html_file(css): - result = [] - result.extend(FLAGS.license_html.split('\n')) - result.append('') - return result + result 
= [] + result.extend(FLAGS.license_html.split("\n")) + result.append("") + return result def get_extra_build_file_content(html): - result = [ - 'load("@io_bazel_rules_closure//closure:defs.bzl", "web_library")', - '', - 'web_library(', - ' name = "%s",' % FLAGS.repo, - ' path = "%s",' % os.path.dirname(FLAGS.path), - ' srcs = [', - ' "%s",' % os.path.basename(FLAGS.path), - ' ":files",', - ' ],', - ')', - '', - 'genrule(', - ' name = "html",', - ' outs = ["%s"],' % os.path.basename(FLAGS.path), - ' cmd = "\\n".join([', - ' "cat <<\'EOF\' >$@",', - ] - add_inline_file(result, html) - result.append(' "EOF",') - result.append(' ]),') - result.append(')') - return result + result = [ + 'load("@io_bazel_rules_closure//closure:defs.bzl", "web_library")', + "", + "web_library(", + ' name = "%s",' % FLAGS.repo, + ' path = "%s",' % os.path.dirname(FLAGS.path), + " srcs = [", + ' "%s",' % os.path.basename(FLAGS.path), + ' ":files",', + " ],", + ")", + "", + "genrule(", + ' name = "html",', + ' outs = ["%s"],' % os.path.basename(FLAGS.path), + ' cmd = "\\n".join([', + " \"cat <<'EOF' >$@\",", + ] + add_inline_file(result, html) + result.append(' "EOF",') + result.append(" ]),") + result.append(")") + return result def main(unused_argv=None): - assets = [] - for url in FLAGS.urls.split(';'): - for m in CSS_PATTERN.finditer(open_url(url).read()): - assets.append(m) - assets.sort(key=lambda m: (m.group('family'), - m.group('name'), - m.group('language'))) - - sys.stdout.write( - 'filegroup_external(\n' - ' name = "%s",\n' - ' licenses = ["%s"], # %s\n' - ' sha256_urls = {\n' % - (FLAGS.repo, FLAGS.license, FLAGS.license_summary)) - - css = [] - for m in assets: - css.append(get_css(m)) + assets = [] + for url in FLAGS.urls.split(";"): + for m in CSS_PATTERN.finditer(open_url(url).read()): + assets.append(m) + assets.sort( + key=lambda m: (m.group("family"), m.group("name"), m.group("language")) + ) + + sys.stdout.write( + "filegroup_external(\n" + ' name = "%s",\n' + ' licenses 
= ["%s"], # %s\n' + " sha256_urls = {\n" + % (FLAGS.repo, FLAGS.license, FLAGS.license_summary) + ) + + css = [] + for m in assets: + css.append(get_css(m)) + sys.stdout.write( + " # %s (%s)\n" + ' "%s": [\n' + ' "%s",\n' + ' "%s",\n' + " ],\n" + % ( + m.group("name"), + m.group("language") or "all", + get_sha256(open_url(m.group("url"))), + get_mirror_url(m.group("url")), + m.group("url"), + ) + ) + sys.stdout.write( - ' # %s (%s)\n' - ' "%s": [\n' - ' \"%s\",\n' - ' \"%s\",\n' - ' ],\n' % - (m.group('name'), - m.group('language') or 'all', - get_sha256(open_url(m.group('url'))), - get_mirror_url(m.group('url')), - m.group('url'))) - - sys.stdout.write( - ' },\n' - ' generated_rule_name = "files",\n' - ' extra_build_file_content = "\\n".join([\n') - result = [] - add_inline_file( - result, - get_extra_build_file_content( - get_html_file(css))) - for line in result: - sys.stdout.write(line + '\n') - sys.stdout.write( - ' ]),\n' - ')\n\n') - - -if __name__ == '__main__': - app.run(main) + " },\n" + ' generated_rule_name = "files",\n' + ' extra_build_file_content = "\\n".join([\n' + ) + result = [] + add_inline_file(result, get_extra_build_file_content(get_html_file(css))) + for line in result: + sys.stdout.write(line + "\n") + sys.stdout.write(" ]),\n" ")\n\n") + + +if __name__ == "__main__": + app.run(main) diff --git a/tensorboard/tools/whitespace_hygiene_test.py b/tensorboard/tools/whitespace_hygiene_test.py index 3e127811af..0ca9829e7b 100755 --- a/tensorboard/tools/whitespace_hygiene_test.py +++ b/tensorboard/tools/whitespace_hygiene_test.py @@ -28,75 +28,83 @@ import sys -exceptions = frozenset([ - # End-of-line whitespace is semantic in patch files when a line - # contains a single space. - "third_party/mock_call_assertions.patch", -]) +exceptions = frozenset( + [ + # End-of-line whitespace is semantic in patch files when a line + # contains a single space. 
+ "third_party/mock_call_assertions.patch", + ] +) Match = collections.namedtuple("Match", ("filename", "line_number", "line")) def main(): - chdir_to_repo_root() - matches = git_grep(" *$") - errors = [m for m in matches if m.filename not in exceptions] - okay = True - - if errors: - print("Superfluous trailing whitespace:") - for error in errors: - print("%s:%d:%s$" % (error.filename, error.line_number, error.line)) - print() - okay = False - - stale_exceptions = exceptions - frozenset(m.filename for m in matches) - if stale_exceptions: - print("Stale exceptions (no whitespace problems; prune exceptions list):") - for filename in stale_exceptions: - print(filename) - print() - okay = False - - sys.exit(0 if okay else 1) + chdir_to_repo_root() + matches = git_grep(" *$") + errors = [m for m in matches if m.filename not in exceptions] + okay = True + + if errors: + print("Superfluous trailing whitespace:") + for error in errors: + print("%s:%d:%s$" % (error.filename, error.line_number, error.line)) + print() + okay = False + + stale_exceptions = exceptions - frozenset(m.filename for m in matches) + if stale_exceptions: + print( + "Stale exceptions (no whitespace problems; prune exceptions list):" + ) + for filename in stale_exceptions: + print(filename) + print() + okay = False + + sys.exit(0 if okay else 1) def git_grep(pattern): - """Run `git grep` and collect matches. - - This function exits the process if `git grep` writes any stderr: for - instance, if the provided pattern is an invalid regular expression. - - Args: - pattern: `str`; a pattern argument to `git grep`. - - Returns: - A list of `Match` values. 
- """ - cmd = ["git", "grep", "-Izn", "--", pattern] - p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - (stdout, stderr) = p.communicate() - if stderr: - getattr(sys.stderr, "buffer", sys.stderr).write(stderr) # Python 2 compat - sys.exit(1) - result = [] - for line in stdout.splitlines(): # assumes no newline characters in filenames - (filename_raw, line_number_raw, line_raw) = line.split(b"\0", 2) - match = Match( - filename=filename_raw.decode("utf-8", errors="replace"), - line_number=int(line_number_raw), - line=line_raw.decode("utf-8", errors="replace"), - ) - result.append(match) - return result + """Run `git grep` and collect matches. + + This function exits the process if `git grep` writes any stderr: for + instance, if the provided pattern is an invalid regular expression. + + Args: + pattern: `str`; a pattern argument to `git grep`. + + Returns: + A list of `Match` values. + """ + cmd = ["git", "grep", "-Izn", "--", pattern] + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + (stdout, stderr) = p.communicate() + if stderr: + getattr(sys.stderr, "buffer", sys.stderr).write( + stderr + ) # Python 2 compat + sys.exit(1) + result = [] + for ( + line + ) in stdout.splitlines(): # assumes no newline characters in filenames + (filename_raw, line_number_raw, line_raw) = line.split(b"\0", 2) + match = Match( + filename=filename_raw.decode("utf-8", errors="replace"), + line_number=int(line_number_raw), + line=line_raw.decode("utf-8", errors="replace"), + ) + result.append(match) + return result def chdir_to_repo_root(): - toplevel = subprocess.check_output(["git", "rev-parse", "--show-toplevel"]) - toplevel = toplevel[:-1] # trim trailing LF - os.chdir(toplevel) + toplevel = subprocess.check_output(["git", "rev-parse", "--show-toplevel"]) + toplevel = toplevel[:-1] # trim trailing LF + os.chdir(toplevel) if __name__ == "__main__": - main() + main() diff --git a/tensorboard/uploader/auth.py 
b/tensorboard/uploader/auth.py index a435c6ff5a..1a1210c7a5 100644 --- a/tensorboard/uploader/auth.py +++ b/tensorboard/uploader/auth.py @@ -69,159 +69,179 @@ # Components of the relative path (within the user settings directory) at which # to store TensorBoard uploader credentials. TENSORBOARD_CREDENTIALS_FILEPATH_PARTS = [ - "tensorboard", "credentials", "uploader-creds.json"] + "tensorboard", + "credentials", + "uploader-creds.json", +] class CredentialsStore(object): - """Private file store for a `google.oauth2.credentials.Credentials`.""" - - _DEFAULT_CONFIG_DIRECTORY = object() # Sentinel value. - - def __init__(self, user_config_directory=_DEFAULT_CONFIG_DIRECTORY): - """Creates a CredentialsStore. - - Args: - user_config_directory: Optional absolute path to the root directory for - storing user configs, under which to store the credentials file. If not - set, defaults to a platform-specific location. If set to None, the - store is disabled (reads return None; write and clear are no-ops). 
- """ - if user_config_directory is CredentialsStore._DEFAULT_CONFIG_DIRECTORY: - user_config_directory = util.get_user_config_directory() - if user_config_directory is None: - logger.warning( - "Credentials caching disabled - no private config directory found") - if user_config_directory is None: - self._credentials_filepath = None - else: - self._credentials_filepath = os.path.join( - user_config_directory, *TENSORBOARD_CREDENTIALS_FILEPATH_PARTS) - - def read_credentials(self): - """Returns the current `google.oauth2.credentials.Credentials`, or None.""" - if self._credentials_filepath is None: - return None - if os.path.exists(self._credentials_filepath): - return google.oauth2.credentials.Credentials.from_authorized_user_file( - self._credentials_filepath) - return None - - def write_credentials(self, credentials): - """Writes a `google.oauth2.credentials.Credentials` to the store.""" - if not isinstance(credentials, google.oauth2.credentials.Credentials): - raise TypeError("Cannot write credentials of type %s" % type(credentials)) - if self._credentials_filepath is None: - return - # Make the credential file private if not on Windows; on Windows we rely on - # the default user config settings directory being private since we don't - # have a straightforward way to make an individual file private. 
- private = os.name != "nt" - util.make_file_with_directories(self._credentials_filepath, private=private) - data = { - "refresh_token": credentials.refresh_token, - "token_uri": credentials.token_uri, - "client_id": credentials.client_id, - "client_secret": credentials.client_secret, - "scopes": credentials.scopes, - "type": "authorized_user", - } - with open(self._credentials_filepath, "w") as f: - json.dump(data, f) - - def clear(self): - """Clears the store of any persisted credentials information.""" - if self._credentials_filepath is None: - return - try: - os.remove(self._credentials_filepath) - except OSError as e: - if e.errno != errno.ENOENT: - raise + """Private file store for a `google.oauth2.credentials.Credentials`.""" + + _DEFAULT_CONFIG_DIRECTORY = object() # Sentinel value. + + def __init__(self, user_config_directory=_DEFAULT_CONFIG_DIRECTORY): + """Creates a CredentialsStore. + + Args: + user_config_directory: Optional absolute path to the root directory for + storing user configs, under which to store the credentials file. If not + set, defaults to a platform-specific location. If set to None, the + store is disabled (reads return None; write and clear are no-ops). 
+ """ + if user_config_directory is CredentialsStore._DEFAULT_CONFIG_DIRECTORY: + user_config_directory = util.get_user_config_directory() + if user_config_directory is None: + logger.warning( + "Credentials caching disabled - no private config directory found" + ) + if user_config_directory is None: + self._credentials_filepath = None + else: + self._credentials_filepath = os.path.join( + user_config_directory, *TENSORBOARD_CREDENTIALS_FILEPATH_PARTS + ) + + def read_credentials(self): + """Returns the current `google.oauth2.credentials.Credentials`, or + None.""" + if self._credentials_filepath is None: + return None + if os.path.exists(self._credentials_filepath): + return google.oauth2.credentials.Credentials.from_authorized_user_file( + self._credentials_filepath + ) + return None + + def write_credentials(self, credentials): + """Writes a `google.oauth2.credentials.Credentials` to the store.""" + if not isinstance(credentials, google.oauth2.credentials.Credentials): + raise TypeError( + "Cannot write credentials of type %s" % type(credentials) + ) + if self._credentials_filepath is None: + return + # Make the credential file private if not on Windows; on Windows we rely on + # the default user config settings directory being private since we don't + # have a straightforward way to make an individual file private. 
+ private = os.name != "nt" + util.make_file_with_directories( + self._credentials_filepath, private=private + ) + data = { + "refresh_token": credentials.refresh_token, + "token_uri": credentials.token_uri, + "client_id": credentials.client_id, + "client_secret": credentials.client_secret, + "scopes": credentials.scopes, + "type": "authorized_user", + } + with open(self._credentials_filepath, "w") as f: + json.dump(data, f) + + def clear(self): + """Clears the store of any persisted credentials information.""" + if self._credentials_filepath is None: + return + try: + os.remove(self._credentials_filepath) + except OSError as e: + if e.errno != errno.ENOENT: + raise def build_installed_app_flow(client_config): - """Returns a `CustomInstalledAppFlow` for the given config. + """Returns a `CustomInstalledAppFlow` for the given config. - Args: - client_config (Mapping[str, Any]): The client configuration in the Google - client secrets format. + Args: + client_config (Mapping[str, Any]): The client configuration in the Google + client secrets format. - Returns: - CustomInstalledAppFlow: the constructed flow. - """ - return CustomInstalledAppFlow.from_client_config( - client_config, scopes=OPENID_CONNECT_SCOPES) + Returns: + CustomInstalledAppFlow: the constructed flow. + """ + return CustomInstalledAppFlow.from_client_config( + client_config, scopes=OPENID_CONNECT_SCOPES + ) class CustomInstalledAppFlow(google_auth_oauthlib.flow.InstalledAppFlow): - """Customized version of the Installed App OAuth2 flow.""" - - def run(self, force_console=False): - """Run the flow using a local server if possible, otherwise the console.""" - # TODO(b/141721828): make auto-detection smarter, especially for macOS. 
- if not force_console and os.getenv("DISPLAY"): - try: - return self.run_local_server(port=0) - except webbrowser.Error: - sys.stderr.write("Falling back to console authentication flow...\n") - return self.run_console() + """Customized version of the Installed App OAuth2 flow.""" + + def run(self, force_console=False): + """Run the flow using a local server if possible, otherwise the + console.""" + # TODO(b/141721828): make auto-detection smarter, especially for macOS. + if not force_console and os.getenv("DISPLAY"): + try: + return self.run_local_server(port=0) + except webbrowser.Error: + sys.stderr.write( + "Falling back to console authentication flow...\n" + ) + return self.run_console() class IdTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin): - """A `gRPC AuthMetadataPlugin` that uses ID tokens. + """A `gRPC AuthMetadataPlugin` that uses ID tokens. - This works like the existing `google.auth.transport.grpc.AuthMetadataPlugin` - except that instead of always using access tokens, it preferentially uses the - `Credentials.id_token` property if available (and logs an error otherwise). + This works like the existing `google.auth.transport.grpc.AuthMetadataPlugin` + except that instead of always using access tokens, it preferentially uses the + `Credentials.id_token` property if available (and logs an error otherwise). - See http://www.grpc.io/grpc/python/grpc.html#grpc.AuthMetadataPlugin - """ - - def __init__(self, credentials, request): - """Constructs an IdTokenAuthMetadataPlugin. - - Args: - credentials (google.auth.credentials.Credentials): The credentials to - add to requests. - request (google.auth.transport.Request): A HTTP transport request object - used to refresh credentials as needed. 
+ See http://www.grpc.io/grpc/python/grpc.html#grpc.AuthMetadataPlugin """ - super(IdTokenAuthMetadataPlugin, self).__init__() - if not isinstance(credentials, google.oauth2.credentials.Credentials): - raise TypeError( - "Cannot get ID tokens from credentials type %s" % type(credentials)) - self._credentials = credentials - self._request = request - - def __call__(self, context, callback): - """Passes authorization metadata into the given callback. - Args: - context (grpc.AuthMetadataContext): The RPC context. - callback (grpc.AuthMetadataPluginCallback): The callback that will - be invoked to pass in the authorization metadata. - """ - headers = {} - self._credentials.before_request( - self._request, context.method_name, context.service_url, headers) - id_token = getattr(self._credentials, "id_token", None) - if id_token: - self._credentials.apply(headers, token=id_token) - else: - logger.error("Failed to find ID token credentials") - # Pass headers as key-value pairs to match CallCredentials metadata. - callback(list(headers.items()), None) + def __init__(self, credentials, request): + """Constructs an IdTokenAuthMetadataPlugin. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to + add to requests. + request (google.auth.transport.Request): A HTTP transport request object + used to refresh credentials as needed. + """ + super(IdTokenAuthMetadataPlugin, self).__init__() + if not isinstance(credentials, google.oauth2.credentials.Credentials): + raise TypeError( + "Cannot get ID tokens from credentials type %s" + % type(credentials) + ) + self._credentials = credentials + self._request = request + + def __call__(self, context, callback): + """Passes authorization metadata into the given callback. + + Args: + context (grpc.AuthMetadataContext): The RPC context. + callback (grpc.AuthMetadataPluginCallback): The callback that will + be invoked to pass in the authorization metadata. 
+ """ + headers = {} + self._credentials.before_request( + self._request, context.method_name, context.service_url, headers + ) + id_token = getattr(self._credentials, "id_token", None) + if id_token: + self._credentials.apply(headers, token=id_token) + else: + logger.error("Failed to find ID token credentials") + # Pass headers as key-value pairs to match CallCredentials metadata. + callback(list(headers.items()), None) def id_token_call_credentials(credentials): - """Constructs `grpc.CallCredentials` using `google.auth.Credentials.id_token`. + """Constructs `grpc.CallCredentials` using + `google.auth.Credentials.id_token`. - Args: - credentials (google.auth.credentials.Credentials): The credentials to use. + Args: + credentials (google.auth.credentials.Credentials): The credentials to use. - Returns: - grpc.CallCredentials: The call credentials. - """ - request = google.auth.transport.requests.Request() - return grpc.metadata_call_credentials( - IdTokenAuthMetadataPlugin(credentials, request)) + Returns: + grpc.CallCredentials: The call credentials. 
+ """ + request = google.auth.transport.requests.Request() + return grpc.metadata_call_credentials( + IdTokenAuthMetadataPlugin(credentials, request) + ) diff --git a/tensorboard/uploader/auth_test.py b/tensorboard/uploader/auth_test.py index 587109f625..0e9178a5b9 100644 --- a/tensorboard/uploader/auth_test.py +++ b/tensorboard/uploader/auth_test.py @@ -30,111 +30,127 @@ class CredentialsStoreTest(tb_test.TestCase): - - def test_no_config_dir(self): - store = auth.CredentialsStore(user_config_directory=None) - self.assertIsNone(store.read_credentials()) - creds = google.oauth2.credentials.Credentials(token=None) - store.write_credentials(creds) - store.clear() - - def test_clear_existent_file(self): - root = self.get_temp_dir() - path = os.path.join( - root, "tensorboard", "credentials", "uploader-creds.json") - os.makedirs(os.path.dirname(path)) - open(path, mode="w").close() - self.assertTrue(os.path.exists(path)) - auth.CredentialsStore(user_config_directory=root).clear() - self.assertFalse(os.path.exists(path)) - - def test_clear_nonexistent_file(self): - root = self.get_temp_dir() - path = os.path.join( - root, "tensorboard", "credentials", "uploader-creds.json") - self.assertFalse(os.path.exists(path)) - auth.CredentialsStore(user_config_directory=root).clear() - self.assertFalse(os.path.exists(path)) - - def test_write_wrong_type(self): - creds = google.auth.credentials.AnonymousCredentials() - with self.assertRaisesRegex(TypeError, "google.auth.credentials"): - auth.CredentialsStore(user_config_directory=None).write_credentials(creds) - - def test_write_creates_private_file(self): - root = self.get_temp_dir() - auth.CredentialsStore(user_config_directory=root).write_credentials( - google.oauth2.credentials.Credentials( - token=None, refresh_token="12345")) - path = os.path.join( - root, "tensorboard", "credentials", "uploader-creds.json") - self.assertTrue(os.path.exists(path)) - # Skip permissions check on Windows. 
- if os.name != "nt": - self.assertEqual(0o600, os.stat(path).st_mode & 0o777) - with open(path) as f: - contents = json.load(f) - self.assertEqual("12345", contents["refresh_token"]) - - def test_write_overwrites_file(self): - root = self.get_temp_dir() - store = auth.CredentialsStore(user_config_directory=root) - # Write twice to ensure that we're overwriting correctly. - store.write_credentials(google.oauth2.credentials.Credentials( - token=None, refresh_token="12345")) - store.write_credentials(google.oauth2.credentials.Credentials( - token=None, refresh_token="67890")) - path = os.path.join( - root, "tensorboard", "credentials", "uploader-creds.json") - self.assertTrue(os.path.exists(path)) - with open(path) as f: - contents = json.load(f) - self.assertEqual("67890", contents["refresh_token"]) - - def test_write_and_read_roundtrip(self): - orig_creds = google.oauth2.credentials.Credentials( - token="12345", - refresh_token="67890", - token_uri="https://oauth2.googleapis.com/token", - client_id="my-client", - client_secret="123abc456xyz", - scopes=["userinfo", "email"]) - root = self.get_temp_dir() - store = auth.CredentialsStore(user_config_directory=root) - store.write_credentials(orig_creds) - creds = store.read_credentials() - self.assertEqual(orig_creds.refresh_token, creds.refresh_token) - self.assertEqual(orig_creds.token_uri, creds.token_uri) - self.assertEqual(orig_creds.client_id, creds.client_id) - self.assertEqual(orig_creds.client_secret, creds.client_secret) - - def test_read_nonexistent_file(self): - root = self.get_temp_dir() - store = auth.CredentialsStore(user_config_directory=root) - self.assertIsNone(store.read_credentials()) - - def test_read_non_json_file(self): - root = self.get_temp_dir() - store = auth.CredentialsStore(user_config_directory=root) - path = os.path.join( - root, "tensorboard", "credentials", "uploader-creds.json") - os.makedirs(os.path.dirname(path)) - with open(path, mode="w") as f: - f.write("foobar") - with 
self.assertRaises(ValueError): - store.read_credentials() - - def test_read_invalid_json_file(self): - root = self.get_temp_dir() - store = auth.CredentialsStore(user_config_directory=root) - path = os.path.join( - root, "tensorboard", "credentials", "uploader-creds.json") - os.makedirs(os.path.dirname(path)) - with open(path, mode="w") as f: - f.write("{}") - with self.assertRaises(ValueError): - store.read_credentials() + def test_no_config_dir(self): + store = auth.CredentialsStore(user_config_directory=None) + self.assertIsNone(store.read_credentials()) + creds = google.oauth2.credentials.Credentials(token=None) + store.write_credentials(creds) + store.clear() + + def test_clear_existent_file(self): + root = self.get_temp_dir() + path = os.path.join( + root, "tensorboard", "credentials", "uploader-creds.json" + ) + os.makedirs(os.path.dirname(path)) + open(path, mode="w").close() + self.assertTrue(os.path.exists(path)) + auth.CredentialsStore(user_config_directory=root).clear() + self.assertFalse(os.path.exists(path)) + + def test_clear_nonexistent_file(self): + root = self.get_temp_dir() + path = os.path.join( + root, "tensorboard", "credentials", "uploader-creds.json" + ) + self.assertFalse(os.path.exists(path)) + auth.CredentialsStore(user_config_directory=root).clear() + self.assertFalse(os.path.exists(path)) + + def test_write_wrong_type(self): + creds = google.auth.credentials.AnonymousCredentials() + with self.assertRaisesRegex(TypeError, "google.auth.credentials"): + auth.CredentialsStore(user_config_directory=None).write_credentials( + creds + ) + + def test_write_creates_private_file(self): + root = self.get_temp_dir() + auth.CredentialsStore(user_config_directory=root).write_credentials( + google.oauth2.credentials.Credentials( + token=None, refresh_token="12345" + ) + ) + path = os.path.join( + root, "tensorboard", "credentials", "uploader-creds.json" + ) + self.assertTrue(os.path.exists(path)) + # Skip permissions check on Windows. 
+ if os.name != "nt": + self.assertEqual(0o600, os.stat(path).st_mode & 0o777) + with open(path) as f: + contents = json.load(f) + self.assertEqual("12345", contents["refresh_token"]) + + def test_write_overwrites_file(self): + root = self.get_temp_dir() + store = auth.CredentialsStore(user_config_directory=root) + # Write twice to ensure that we're overwriting correctly. + store.write_credentials( + google.oauth2.credentials.Credentials( + token=None, refresh_token="12345" + ) + ) + store.write_credentials( + google.oauth2.credentials.Credentials( + token=None, refresh_token="67890" + ) + ) + path = os.path.join( + root, "tensorboard", "credentials", "uploader-creds.json" + ) + self.assertTrue(os.path.exists(path)) + with open(path) as f: + contents = json.load(f) + self.assertEqual("67890", contents["refresh_token"]) + + def test_write_and_read_roundtrip(self): + orig_creds = google.oauth2.credentials.Credentials( + token="12345", + refresh_token="67890", + token_uri="https://oauth2.googleapis.com/token", + client_id="my-client", + client_secret="123abc456xyz", + scopes=["userinfo", "email"], + ) + root = self.get_temp_dir() + store = auth.CredentialsStore(user_config_directory=root) + store.write_credentials(orig_creds) + creds = store.read_credentials() + self.assertEqual(orig_creds.refresh_token, creds.refresh_token) + self.assertEqual(orig_creds.token_uri, creds.token_uri) + self.assertEqual(orig_creds.client_id, creds.client_id) + self.assertEqual(orig_creds.client_secret, creds.client_secret) + + def test_read_nonexistent_file(self): + root = self.get_temp_dir() + store = auth.CredentialsStore(user_config_directory=root) + self.assertIsNone(store.read_credentials()) + + def test_read_non_json_file(self): + root = self.get_temp_dir() + store = auth.CredentialsStore(user_config_directory=root) + path = os.path.join( + root, "tensorboard", "credentials", "uploader-creds.json" + ) + os.makedirs(os.path.dirname(path)) + with open(path, mode="w") as f: + 
f.write("foobar") + with self.assertRaises(ValueError): + store.read_credentials() + + def test_read_invalid_json_file(self): + root = self.get_temp_dir() + store = auth.CredentialsStore(user_config_directory=root) + path = os.path.join( + root, "tensorboard", "credentials", "uploader-creds.json" + ) + os.makedirs(os.path.dirname(path)) + with open(path, mode="w") as f: + f.write("{}") + with self.assertRaises(ValueError): + store.read_credentials() if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/uploader/exporter.py b/tensorboard/uploader/exporter.py index ef6d6d7205..99ece52060 100644 --- a/tensorboard/uploader/exporter.py +++ b/tensorboard/uploader/exporter.py @@ -41,205 +41,217 @@ _FILENAME_SAFE_CHARS = frozenset(string.ascii_letters + string.digits + "-_") # Maximum value of a signed 64-bit integer. -_MAX_INT64 = 2**63 - 1 +_MAX_INT64 = 2 ** 63 - 1 + class TensorBoardExporter(object): - """Exports all of the user's experiment data from TensorBoard.dev. - - Data is exported into a directory, with one file per experiment. Each - experiment file is a sequence of time series, represented as a stream - of JSON objects, one per line. Each JSON object includes a run name, - tag name, `tensorboard.compat.proto.summary_pb2.SummaryMetadata` proto - (base64-encoded, standard RFC 4648 alphabet), and set of points. - Points are stored in three equal-length lists of steps, wall times (as - seconds since epoch), and scalar values, for storage efficiency. - - Such streams of JSON objects may be conveniently processed with tools - like jq(1). 
- - For example one line of an experiment file might read (when - pretty-printed): - - { - "points": { - "steps": [0, 5], - "values": [4.8935227394104, 2.5438034534454346], - "wall_times": [1563406522.669238, 1563406523.0268838] - }, - "run": "lr_1E-04,conv=1,fc=2", - "summary_metadata": "CgkKB3NjYWxhcnMSC3hlbnQveGVudF8x", - "tag": "xent/xent_1" - } - - This is a time series with two points, both logged on 2019-07-17, one - about 0.36 seconds after the other. - """ - - def __init__(self, reader_service_client, output_directory): - """Constructs a TensorBoardExporter. + """Exports all of the user's experiment data from TensorBoard.dev. + + Data is exported into a directory, with one file per experiment. Each + experiment file is a sequence of time series, represented as a stream + of JSON objects, one per line. Each JSON object includes a run name, + tag name, `tensorboard.compat.proto.summary_pb2.SummaryMetadata` proto + (base64-encoded, standard RFC 4648 alphabet), and set of points. + Points are stored in three equal-length lists of steps, wall times (as + seconds since epoch), and scalar values, for storage efficiency. + + Such streams of JSON objects may be conveniently processed with tools + like jq(1). + + For example one line of an experiment file might read (when + pretty-printed): + + { + "points": { + "steps": [0, 5], + "values": [4.8935227394104, 2.5438034534454346], + "wall_times": [1563406522.669238, 1563406523.0268838] + }, + "run": "lr_1E-04,conv=1,fc=2", + "summary_metadata": "CgkKB3NjYWxhcnMSC3hlbnQveGVudF8x", + "tag": "xent/xent_1" + } - Args: - reader_service_client: A TensorBoardExporterService stub instance. - output_directory: Path to a directory into which to write data. The - directory must not exist, to avoid stomping existing or concurrent - output. Its ancestors will be created if needed. + This is a time series with two points, both logged on 2019-07-17, one + about 0.36 seconds after the other. 
""" - self._api = reader_service_client - self._outdir = output_directory - parent_dir = os.path.dirname(self._outdir) - if parent_dir: - _mkdir_p(parent_dir) - try: - os.mkdir(self._outdir) - except OSError as e: - if e.errno == errno.EEXIST: - # Bail to avoid stomping existing output. - raise OutputDirectoryExistsError() - def export(self, read_time=None): - """Executes the export flow. + def __init__(self, reader_service_client, output_directory): + """Constructs a TensorBoardExporter. + + Args: + reader_service_client: A TensorBoardExporterService stub instance. + output_directory: Path to a directory into which to write data. The + directory must not exist, to avoid stomping existing or concurrent + output. Its ancestors will be created if needed. + """ + self._api = reader_service_client + self._outdir = output_directory + parent_dir = os.path.dirname(self._outdir) + if parent_dir: + _mkdir_p(parent_dir) + try: + os.mkdir(self._outdir) + except OSError as e: + if e.errno == errno.EEXIST: + # Bail to avoid stomping existing output. + raise OutputDirectoryExistsError() + + def export(self, read_time=None): + """Executes the export flow. + + Args: + read_time: A fixed timestamp from which to export data, as float seconds + since epoch (like `time.time()`). Optional; defaults to the current + time. + + Yields: + After each experiment is successfully downloaded, the ID of that + experiment, as a string. 
+ """ + if read_time is None: + read_time = time.time() + for experiment_id in self._request_experiment_ids(read_time): + filepath = _scalars_filepath(self._outdir, experiment_id) + try: + with _open_excl(filepath) as outfile: + data = self._request_scalar_data(experiment_id, read_time) + for block in data: + json.dump(block, outfile, sort_keys=True) + outfile.write("\n") + outfile.flush() + yield experiment_id + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.CANCELLED: + raise GrpcTimeoutException(experiment_id) + else: + raise + + def _request_experiment_ids(self, read_time): + """Yields all of the calling user's experiment IDs, as strings.""" + for experiment in list_experiments(self._api, read_time=read_time): + if isinstance(experiment, export_service_pb2.Experiment): + yield experiment.experiment_id + elif isinstance(experiment, six.string_types): + yield experiment + else: + raise AssertionError( + "Unexpected experiment type: %r" % (experiment,) + ) + + def _request_scalar_data(self, experiment_id, read_time): + """Yields JSON-serializable blocks of scalar data.""" + request = export_service_pb2.StreamExperimentDataRequest() + request.experiment_id = experiment_id + util.set_timestamp(request.read_timestamp, read_time) + # No special error handling as we don't expect any errors from these + # calls: all experiments should exist (read consistency timestamp) + # and be owned by the calling user (only queried for own experiment + # IDs). Any non-transient errors would be internal, and we have no + # way to efficiently resume from transient errors because the server + # does not support pagination. 
+ stream = self._api.StreamExperimentData( + request, metadata=grpc_util.version_metadata() + ) + for response in stream: + metadata = base64.b64encode( + response.tag_metadata.SerializeToString() + ).decode("ascii") + wall_times = [ + t.ToNanoseconds() / 1e9 for t in response.points.wall_times + ] + yield { + u"run": response.run_name, + u"tag": response.tag_name, + u"summary_metadata": metadata, + u"points": { + u"steps": list(response.points.steps), + u"wall_times": wall_times, + u"values": list(response.points.values), + }, + } + + +def list_experiments(api_client, fieldmask=None, read_time=None): + """Yields all of the calling user's experiments. Args: + api_client: A TensorBoardExporterService stub instance. + fieldmask: An optional `export_service_pb2.ExperimentMask` value. read_time: A fixed timestamp from which to export data, as float seconds since epoch (like `time.time()`). Optional; defaults to the current time. Yields: - After each experiment is successfully downloaded, the ID of that - experiment, as a string. + For each experiment owned by the user, an `export_service_pb2.Experiment` + value, or a simple string experiment ID for older servers. 
""" if read_time is None: - read_time = time.time() - for experiment_id in self._request_experiment_ids(read_time): - filepath = _scalars_filepath(self._outdir, experiment_id) - try: - with _open_excl(filepath) as outfile: - data = self._request_scalar_data(experiment_id, read_time) - for block in data: - json.dump(block, outfile, sort_keys=True) - outfile.write("\n") - outfile.flush() - yield experiment_id - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.CANCELLED: - raise GrpcTimeoutException(experiment_id) - else: - raise - - def _request_experiment_ids(self, read_time): - """Yields all of the calling user's experiment IDs, as strings.""" - for experiment in list_experiments(self._api, read_time=read_time): - if isinstance(experiment, export_service_pb2.Experiment): - yield experiment.experiment_id - elif isinstance(experiment, six.string_types): - yield experiment - else: - raise AssertionError("Unexpected experiment type: %r" % (experiment,)) - - def _request_scalar_data(self, experiment_id, read_time): - """Yields JSON-serializable blocks of scalar data.""" - request = export_service_pb2.StreamExperimentDataRequest() - request.experiment_id = experiment_id + read_time = time.time() + request = export_service_pb2.StreamExperimentsRequest(limit=_MAX_INT64) util.set_timestamp(request.read_timestamp, read_time) - # No special error handling as we don't expect any errors from these - # calls: all experiments should exist (read consistency timestamp) - # and be owned by the calling user (only queried for own experiment - # IDs). Any non-transient errors would be internal, and we have no - # way to efficiently resume from transient errors because the server - # does not support pagination. 
- stream = self._api.StreamExperimentData( - request, metadata=grpc_util.version_metadata()) + if fieldmask: + request.experiments_mask.CopyFrom(fieldmask) + stream = api_client.StreamExperiments( + request, metadata=grpc_util.version_metadata() + ) for response in stream: - metadata = base64.b64encode( - response.tag_metadata.SerializeToString()).decode("ascii") - wall_times = [t.ToNanoseconds() / 1e9 for t in response.points.wall_times] - yield { - u"run": response.run_name, - u"tag": response.tag_name, - u"summary_metadata": metadata, - u"points": { - u"steps": list(response.points.steps), - u"wall_times": wall_times, - u"values": list(response.points.values), - }, - } - - -def list_experiments(api_client, fieldmask=None, read_time=None): - """Yields all of the calling user's experiments. - - Args: - api_client: A TensorBoardExporterService stub instance. - fieldmask: An optional `export_service_pb2.ExperimentMask` value. - read_time: A fixed timestamp from which to export data, as float seconds - since epoch (like `time.time()`). Optional; defaults to the current - time. - - Yields: - For each experiment owned by the user, an `export_service_pb2.Experiment` - value, or a simple string experiment ID for older servers. - """ - if read_time is None: - read_time = time.time() - request = export_service_pb2.StreamExperimentsRequest(limit=_MAX_INT64) - util.set_timestamp(request.read_timestamp, read_time) - if fieldmask: - request.experiments_mask.CopyFrom(fieldmask) - stream = api_client.StreamExperiments( - request, metadata=grpc_util.version_metadata()) - for response in stream: - if response.experiments: - for experiment in response.experiments: - yield experiment - else: - # Old servers. - for experiment_id in response.experiment_ids: - yield experiment_id + if response.experiments: + for experiment in response.experiments: + yield experiment + else: + # Old servers. 
+ for experiment_id in response.experiment_ids: + yield experiment_id class OutputDirectoryExistsError(ValueError): - pass + pass class OutputFileExistsError(ValueError): - # Like Python 3's `__builtins__.FileExistsError`. - pass + # Like Python 3's `__builtins__.FileExistsError`. + pass + class GrpcTimeoutException(Exception): - def __init__(self, experiment_id): - super(GrpcTimeoutException, self).__init__(experiment_id) - self.experiment_id = experiment_id + def __init__(self, experiment_id): + super(GrpcTimeoutException, self).__init__(experiment_id) + self.experiment_id = experiment_id + def _scalars_filepath(base_dir, experiment_id): - """Gets file path in which to store scalars for the given experiment.""" - # Experiment IDs from the server should be filename-safe; verify - # this before creating any files. - bad_chars = frozenset(experiment_id) - _FILENAME_SAFE_CHARS - if bad_chars: - raise RuntimeError( - "Unexpected characters ({bad_chars!r}) in experiment ID {eid!r}".format( - bad_chars=sorted(bad_chars), eid=experiment_id)) - return os.path.join(base_dir, "scalars_%s.json" % experiment_id) + """Gets file path in which to store scalars for the given experiment.""" + # Experiment IDs from the server should be filename-safe; verify + # this before creating any files. 
+ bad_chars = frozenset(experiment_id) - _FILENAME_SAFE_CHARS + if bad_chars: + raise RuntimeError( + "Unexpected characters ({bad_chars!r}) in experiment ID {eid!r}".format( + bad_chars=sorted(bad_chars), eid=experiment_id + ) + ) + return os.path.join(base_dir, "scalars_%s.json" % experiment_id) def _mkdir_p(path): - """Like `os.makedirs(path, exist_ok=True)`, but Python 2-compatible.""" - try: - os.makedirs(path) - except OSError as e: - if e.errno != errno.EEXIST or not os.path.isdir(path): - raise + """Like `os.makedirs(path, exist_ok=True)`, but Python 2-compatible.""" + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST or not os.path.isdir(path): + raise def _open_excl(path): - """Like `open(path, "x")`, but Python 2-compatible.""" - try: - # `os.O_EXCL` works on Windows as well as POSIX-compliant systems. - # See: - fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_EXCL) - except OSError as e: - if e.errno == errno.EEXIST: - raise OutputFileExistsError(path) - else: - raise - return os.fdopen(fd, "w") + """Like `open(path, "x")`, but Python 2-compatible.""" + try: + # `os.O_EXCL` works on Windows as well as POSIX-compliant systems. 
+ # See: + fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_EXCL) + except OSError as e: + if e.errno == errno.EEXIST: + raise OutputFileExistsError(path) + else: + raise + return os.fdopen(fd, "w") diff --git a/tensorboard/uploader/exporter_test.py b/tensorboard/uploader/exporter_test.py index 3fe06f016b..e490a67b3e 100644 --- a/tensorboard/uploader/exporter_test.py +++ b/tensorboard/uploader/exporter_test.py @@ -27,10 +27,10 @@ import grpc_testing try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import from tensorboard.uploader.proto import export_service_pb2 @@ -43,406 +43,435 @@ class TensorBoardExporterTest(tb_test.TestCase): - - def _create_mock_api_client(self): - return _create_mock_api_client() - - def _make_experiments_response(self, eids): - return export_service_pb2.StreamExperimentsResponse(experiment_ids=eids) - - def test_e2e_success_case(self): - mock_api_client = self._create_mock_api_client() - mock_api_client.StreamExperiments.return_value = iter([ - export_service_pb2.StreamExperimentsResponse(experiment_ids=["789"]), - ]) - - def stream_experiments(request, **kwargs): - del request # unused - self.assertEqual(kwargs["metadata"], grpc_util.version_metadata()) - yield export_service_pb2.StreamExperimentsResponse( - experiment_ids=["123", "456"]) - yield export_service_pb2.StreamExperimentsResponse(experiment_ids=["789"]) - - def stream_experiment_data(request, **kwargs): - self.assertEqual(kwargs["metadata"], grpc_util.version_metadata()) - for run in ("train", "test"): - for tag in ("accuracy", "loss"): - response = export_service_pb2.StreamExperimentDataResponse() - response.run_name = run - response.tag_name = tag - display_name = "%s:%s" % (request.experiment_id, tag) - response.tag_metadata.CopyFrom( - test_util.scalar_metadata(display_name)) - for step in range(10): - 
response.points.steps.append(step) - response.points.values.append(2.0 * step) - response.points.wall_times.add( - seconds=1571084520 + step, nanos=862939144) - yield response - - mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments) - mock_api_client.StreamExperimentData = mock.Mock( - wraps=stream_experiment_data) - - outdir = os.path.join(self.get_temp_dir(), "outdir") - exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir) - start_time = 1571084846.25 - start_time_pb = test_util.timestamp_pb(1571084846250000000) - - generator = exporter.export(read_time=start_time) - expected_files = [] - self.assertTrue(os.path.isdir(outdir)) - self.assertCountEqual(expected_files, os.listdir(outdir)) - mock_api_client.StreamExperiments.assert_not_called() - mock_api_client.StreamExperimentData.assert_not_called() - - # The first iteration should request the list of experiments and - # data for one of them. - self.assertEqual(next(generator), "123") - expected_files.append("scalars_123.json") - self.assertCountEqual(expected_files, os.listdir(outdir)) - - expected_eids_request = export_service_pb2.StreamExperimentsRequest() - expected_eids_request.read_timestamp.CopyFrom(start_time_pb) - expected_eids_request.limit = 2**63 - 1 - mock_api_client.StreamExperiments.assert_called_once_with( - expected_eids_request, metadata=grpc_util.version_metadata()) - - expected_data_request = export_service_pb2.StreamExperimentDataRequest() - expected_data_request.experiment_id = "123" - expected_data_request.read_timestamp.CopyFrom(start_time_pb) - mock_api_client.StreamExperimentData.assert_called_once_with( - expected_data_request, metadata=grpc_util.version_metadata()) - - # The next iteration should just request data for the next experiment. 
- mock_api_client.StreamExperiments.reset_mock() - mock_api_client.StreamExperimentData.reset_mock() - self.assertEqual(next(generator), "456") - - expected_files.append("scalars_456.json") - self.assertCountEqual(expected_files, os.listdir(outdir)) - mock_api_client.StreamExperiments.assert_not_called() - expected_data_request.experiment_id = "456" - mock_api_client.StreamExperimentData.assert_called_once_with( - expected_data_request, metadata=grpc_util.version_metadata()) - - # Again, request data for the next experiment; this experiment ID - # was in the second response batch in the list of IDs. - expected_files.append("scalars_789.json") - mock_api_client.StreamExperiments.reset_mock() - mock_api_client.StreamExperimentData.reset_mock() - self.assertEqual(next(generator), "789") - - self.assertCountEqual(expected_files, os.listdir(outdir)) - mock_api_client.StreamExperiments.assert_not_called() - expected_data_request.experiment_id = "789" - mock_api_client.StreamExperimentData.assert_called_once_with( - expected_data_request, metadata=grpc_util.version_metadata()) - - # The final continuation shouldn't need to send any RPCs. - mock_api_client.StreamExperiments.reset_mock() - mock_api_client.StreamExperimentData.reset_mock() - self.assertEqual(list(generator), []) - - self.assertCountEqual(expected_files, os.listdir(outdir)) - mock_api_client.StreamExperiments.assert_not_called() - mock_api_client.StreamExperimentData.assert_not_called() - - # Spot-check one of the files. 
- with open(os.path.join(outdir, "scalars_456.json")) as infile: - jsons = [json.loads(line) for line in infile] - self.assertLen(jsons, 4) - datum = jsons[2] - self.assertEqual(datum.pop("run"), "test") - self.assertEqual(datum.pop("tag"), "accuracy") - summary_metadata = summary_pb2.SummaryMetadata.FromString( - base64.b64decode(datum.pop("summary_metadata"))) - expected_summary_metadata = test_util.scalar_metadata("456:accuracy") - self.assertEqual(summary_metadata, expected_summary_metadata) - points = datum.pop("points") - expected_steps = [x for x in range(10)] - expected_values = [2.0 * x for x in range(10)] - expected_wall_times = [1571084520.862939144 + x for x in range(10)] - self.assertEqual(points.pop("steps"), expected_steps) - self.assertEqual(points.pop("values"), expected_values) - self.assertEqual(points.pop("wall_times"), expected_wall_times) - self.assertEqual(points, {}) - self.assertEqual(datum, {}) - - def test_rejects_dangerous_experiment_ids(self): - mock_api_client = self._create_mock_api_client() - - def stream_experiments(request, **kwargs): - del request # unused - yield export_service_pb2.StreamExperimentsResponse( - experiment_ids=["../authorized_keys"]) - - mock_api_client.StreamExperiments = stream_experiments - - outdir = os.path.join(self.get_temp_dir(), "outdir") - exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir) - generator = exporter.export() - - with self.assertRaises(RuntimeError) as cm: - next(generator) - - msg = str(cm.exception) - self.assertIn("Unexpected characters", msg) - self.assertIn(repr(sorted([u".", u"/"])), msg) - self.assertIn("../authorized_keys", msg) - mock_api_client.StreamExperimentData.assert_not_called() - - def test_fails_nicely_on_stream_experiment_data_timeout(self): - # Setup: Client where: - # 1. stream_experiments will say there is one experiment_id. - # 2. stream_experiment_data will raise a grpc CANCELLED, as per - # a timeout. 
- mock_api_client = self._create_mock_api_client() - experiment_id="123" - - def stream_experiments(request, **kwargs): - del request # unused - yield export_service_pb2.StreamExperimentsResponse( - experiment_ids=[experiment_id]) - - def stream_experiment_data(request, **kwargs): - raise test_util.grpc_error(grpc.StatusCode.CANCELLED, "details string") - - mock_api_client.StreamExperiments = stream_experiments - mock_api_client.StreamExperimentData = stream_experiment_data - - outdir = os.path.join(self.get_temp_dir(), "outdir") - # Execute: exporter.export() - exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir) - generator = exporter.export() - # Expect: A nice exception of the right type and carrying the right - # experiment_id. - with self.assertRaises(exporter_lib.GrpcTimeoutException) as cm: - next(generator) - self.assertEquals(cm.exception.experiment_id, experiment_id) - - def test_stream_experiment_data_passes_through_unexpected_exception(self): - # Setup: Client where: - # 1. stream_experiments will say there is one experiment_id. - # 2. stream_experiment_data will throw an internal error. - mock_api_client = self._create_mock_api_client() - experiment_id = "123" - - def stream_experiments(request, **kwargs): - del request # unused - yield export_service_pb2.StreamExperimentsResponse( - experiment_ids=[experiment_id]) - - def stream_experiment_data(request, **kwargs): - del request # unused - raise test_util.grpc_error(grpc.StatusCode.INTERNAL, "details string") - - mock_api_client.StreamExperiments = stream_experiments - mock_api_client.StreamExperimentData = stream_experiment_data - - outdir = os.path.join(self.get_temp_dir(), "outdir") - # Execute: exporter.export(). - exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir) - generator = exporter.export() - # Expect: The internal error is passed through. 
- with self.assertRaises(grpc.RpcError) as cm: - next(generator) - self.assertEquals(cm.exception.details(), "details string") - - def test_handles_outdir_with_no_slash(self): - oldcwd = os.getcwd() - try: - os.chdir(self.get_temp_dir()) - mock_api_client = self._create_mock_api_client() - mock_api_client.StreamExperiments.return_value = iter([ - export_service_pb2.StreamExperimentsResponse(experiment_ids=["123"]), - ]) - mock_api_client.StreamExperimentData.return_value = iter([ - export_service_pb2.StreamExperimentDataResponse() - ]) - - exporter = exporter_lib.TensorBoardExporter(mock_api_client, "outdir") - generator = exporter.export() - self.assertEqual(list(generator), ["123"]) - self.assertTrue(os.path.isdir("outdir")) - finally: - os.chdir(oldcwd) - - def test_rejects_existing_directory(self): - mock_api_client = self._create_mock_api_client() - outdir = os.path.join(self.get_temp_dir(), "outdir") - os.mkdir(outdir) - with open(os.path.join(outdir, "scalars_999.json"), "w"): - pass - - with self.assertRaises(exporter_lib.OutputDirectoryExistsError): - exporter_lib.TensorBoardExporter(mock_api_client, outdir) - - mock_api_client.StreamExperiments.assert_not_called() - mock_api_client.StreamExperimentData.assert_not_called() - - def test_rejects_existing_file(self): - mock_api_client = self._create_mock_api_client() - - def stream_experiments(request, **kwargs): - del request # unused - yield export_service_pb2.StreamExperimentsResponse(experiment_ids=["123"]) - - mock_api_client.StreamExperiments = stream_experiments - - outdir = os.path.join(self.get_temp_dir(), "outdir") - exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir) - generator = exporter.export() - - with open(os.path.join(outdir, "scalars_123.json"), "w"): - pass - - with self.assertRaises(exporter_lib.OutputFileExistsError): - next(generator) - - mock_api_client.StreamExperimentData.assert_not_called() - - def test_propagates_mkdir_errors(self): - mock_api_client = 
self._create_mock_api_client() - outdir = os.path.join(self.get_temp_dir(), "some_file", "outdir") - with open(os.path.join(self.get_temp_dir(), "some_file"), "w"): - pass - - with self.assertRaises(OSError): - exporter_lib.TensorBoardExporter(mock_api_client, outdir) - - mock_api_client.StreamExperiments.assert_not_called() - mock_api_client.StreamExperimentData.assert_not_called() + def _create_mock_api_client(self): + return _create_mock_api_client() + + def _make_experiments_response(self, eids): + return export_service_pb2.StreamExperimentsResponse(experiment_ids=eids) + + def test_e2e_success_case(self): + mock_api_client = self._create_mock_api_client() + mock_api_client.StreamExperiments.return_value = iter( + [ + export_service_pb2.StreamExperimentsResponse( + experiment_ids=["789"] + ), + ] + ) + + def stream_experiments(request, **kwargs): + del request # unused + self.assertEqual(kwargs["metadata"], grpc_util.version_metadata()) + yield export_service_pb2.StreamExperimentsResponse( + experiment_ids=["123", "456"] + ) + yield export_service_pb2.StreamExperimentsResponse( + experiment_ids=["789"] + ) + + def stream_experiment_data(request, **kwargs): + self.assertEqual(kwargs["metadata"], grpc_util.version_metadata()) + for run in ("train", "test"): + for tag in ("accuracy", "loss"): + response = export_service_pb2.StreamExperimentDataResponse() + response.run_name = run + response.tag_name = tag + display_name = "%s:%s" % (request.experiment_id, tag) + response.tag_metadata.CopyFrom( + test_util.scalar_metadata(display_name) + ) + for step in range(10): + response.points.steps.append(step) + response.points.values.append(2.0 * step) + response.points.wall_times.add( + seconds=1571084520 + step, nanos=862939144 + ) + yield response + + mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments) + mock_api_client.StreamExperimentData = mock.Mock( + wraps=stream_experiment_data + ) + + outdir = os.path.join(self.get_temp_dir(), "outdir") + 
exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir) + start_time = 1571084846.25 + start_time_pb = test_util.timestamp_pb(1571084846250000000) + + generator = exporter.export(read_time=start_time) + expected_files = [] + self.assertTrue(os.path.isdir(outdir)) + self.assertCountEqual(expected_files, os.listdir(outdir)) + mock_api_client.StreamExperiments.assert_not_called() + mock_api_client.StreamExperimentData.assert_not_called() + + # The first iteration should request the list of experiments and + # data for one of them. + self.assertEqual(next(generator), "123") + expected_files.append("scalars_123.json") + self.assertCountEqual(expected_files, os.listdir(outdir)) + + expected_eids_request = export_service_pb2.StreamExperimentsRequest() + expected_eids_request.read_timestamp.CopyFrom(start_time_pb) + expected_eids_request.limit = 2 ** 63 - 1 + mock_api_client.StreamExperiments.assert_called_once_with( + expected_eids_request, metadata=grpc_util.version_metadata() + ) + + expected_data_request = export_service_pb2.StreamExperimentDataRequest() + expected_data_request.experiment_id = "123" + expected_data_request.read_timestamp.CopyFrom(start_time_pb) + mock_api_client.StreamExperimentData.assert_called_once_with( + expected_data_request, metadata=grpc_util.version_metadata() + ) + + # The next iteration should just request data for the next experiment. 
+ mock_api_client.StreamExperiments.reset_mock() + mock_api_client.StreamExperimentData.reset_mock() + self.assertEqual(next(generator), "456") + + expected_files.append("scalars_456.json") + self.assertCountEqual(expected_files, os.listdir(outdir)) + mock_api_client.StreamExperiments.assert_not_called() + expected_data_request.experiment_id = "456" + mock_api_client.StreamExperimentData.assert_called_once_with( + expected_data_request, metadata=grpc_util.version_metadata() + ) + + # Again, request data for the next experiment; this experiment ID + # was in the second response batch in the list of IDs. + expected_files.append("scalars_789.json") + mock_api_client.StreamExperiments.reset_mock() + mock_api_client.StreamExperimentData.reset_mock() + self.assertEqual(next(generator), "789") + + self.assertCountEqual(expected_files, os.listdir(outdir)) + mock_api_client.StreamExperiments.assert_not_called() + expected_data_request.experiment_id = "789" + mock_api_client.StreamExperimentData.assert_called_once_with( + expected_data_request, metadata=grpc_util.version_metadata() + ) + + # The final continuation shouldn't need to send any RPCs. + mock_api_client.StreamExperiments.reset_mock() + mock_api_client.StreamExperimentData.reset_mock() + self.assertEqual(list(generator), []) + + self.assertCountEqual(expected_files, os.listdir(outdir)) + mock_api_client.StreamExperiments.assert_not_called() + mock_api_client.StreamExperimentData.assert_not_called() + + # Spot-check one of the files. 
+ with open(os.path.join(outdir, "scalars_456.json")) as infile: + jsons = [json.loads(line) for line in infile] + self.assertLen(jsons, 4) + datum = jsons[2] + self.assertEqual(datum.pop("run"), "test") + self.assertEqual(datum.pop("tag"), "accuracy") + summary_metadata = summary_pb2.SummaryMetadata.FromString( + base64.b64decode(datum.pop("summary_metadata")) + ) + expected_summary_metadata = test_util.scalar_metadata("456:accuracy") + self.assertEqual(summary_metadata, expected_summary_metadata) + points = datum.pop("points") + expected_steps = [x for x in range(10)] + expected_values = [2.0 * x for x in range(10)] + expected_wall_times = [1571084520.862939144 + x for x in range(10)] + self.assertEqual(points.pop("steps"), expected_steps) + self.assertEqual(points.pop("values"), expected_values) + self.assertEqual(points.pop("wall_times"), expected_wall_times) + self.assertEqual(points, {}) + self.assertEqual(datum, {}) + + def test_rejects_dangerous_experiment_ids(self): + mock_api_client = self._create_mock_api_client() + + def stream_experiments(request, **kwargs): + del request # unused + yield export_service_pb2.StreamExperimentsResponse( + experiment_ids=["../authorized_keys"] + ) + + mock_api_client.StreamExperiments = stream_experiments + + outdir = os.path.join(self.get_temp_dir(), "outdir") + exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir) + generator = exporter.export() + + with self.assertRaises(RuntimeError) as cm: + next(generator) + + msg = str(cm.exception) + self.assertIn("Unexpected characters", msg) + self.assertIn(repr(sorted([u".", u"/"])), msg) + self.assertIn("../authorized_keys", msg) + mock_api_client.StreamExperimentData.assert_not_called() + + def test_fails_nicely_on_stream_experiment_data_timeout(self): + # Setup: Client where: + # 1. stream_experiments will say there is one experiment_id. + # 2. stream_experiment_data will raise a grpc CANCELLED, as per + # a timeout. 
+ mock_api_client = self._create_mock_api_client() + experiment_id = "123" + + def stream_experiments(request, **kwargs): + del request # unused + yield export_service_pb2.StreamExperimentsResponse( + experiment_ids=[experiment_id] + ) + + def stream_experiment_data(request, **kwargs): + raise test_util.grpc_error( + grpc.StatusCode.CANCELLED, "details string" + ) + + mock_api_client.StreamExperiments = stream_experiments + mock_api_client.StreamExperimentData = stream_experiment_data + + outdir = os.path.join(self.get_temp_dir(), "outdir") + # Execute: exporter.export() + exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir) + generator = exporter.export() + # Expect: A nice exception of the right type and carrying the right + # experiment_id. + with self.assertRaises(exporter_lib.GrpcTimeoutException) as cm: + next(generator) + self.assertEquals(cm.exception.experiment_id, experiment_id) + + def test_stream_experiment_data_passes_through_unexpected_exception(self): + # Setup: Client where: + # 1. stream_experiments will say there is one experiment_id. + # 2. stream_experiment_data will throw an internal error. + mock_api_client = self._create_mock_api_client() + experiment_id = "123" + + def stream_experiments(request, **kwargs): + del request # unused + yield export_service_pb2.StreamExperimentsResponse( + experiment_ids=[experiment_id] + ) + + def stream_experiment_data(request, **kwargs): + del request # unused + raise test_util.grpc_error( + grpc.StatusCode.INTERNAL, "details string" + ) + + mock_api_client.StreamExperiments = stream_experiments + mock_api_client.StreamExperimentData = stream_experiment_data + + outdir = os.path.join(self.get_temp_dir(), "outdir") + # Execute: exporter.export(). + exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir) + generator = exporter.export() + # Expect: The internal error is passed through. 
+ with self.assertRaises(grpc.RpcError) as cm: + next(generator) + self.assertEquals(cm.exception.details(), "details string") + + def test_handles_outdir_with_no_slash(self): + oldcwd = os.getcwd() + try: + os.chdir(self.get_temp_dir()) + mock_api_client = self._create_mock_api_client() + mock_api_client.StreamExperiments.return_value = iter( + [ + export_service_pb2.StreamExperimentsResponse( + experiment_ids=["123"] + ), + ] + ) + mock_api_client.StreamExperimentData.return_value = iter( + [export_service_pb2.StreamExperimentDataResponse()] + ) + + exporter = exporter_lib.TensorBoardExporter( + mock_api_client, "outdir" + ) + generator = exporter.export() + self.assertEqual(list(generator), ["123"]) + self.assertTrue(os.path.isdir("outdir")) + finally: + os.chdir(oldcwd) + + def test_rejects_existing_directory(self): + mock_api_client = self._create_mock_api_client() + outdir = os.path.join(self.get_temp_dir(), "outdir") + os.mkdir(outdir) + with open(os.path.join(outdir, "scalars_999.json"), "w"): + pass + + with self.assertRaises(exporter_lib.OutputDirectoryExistsError): + exporter_lib.TensorBoardExporter(mock_api_client, outdir) + + mock_api_client.StreamExperiments.assert_not_called() + mock_api_client.StreamExperimentData.assert_not_called() + + def test_rejects_existing_file(self): + mock_api_client = self._create_mock_api_client() + + def stream_experiments(request, **kwargs): + del request # unused + yield export_service_pb2.StreamExperimentsResponse( + experiment_ids=["123"] + ) + + mock_api_client.StreamExperiments = stream_experiments + + outdir = os.path.join(self.get_temp_dir(), "outdir") + exporter = exporter_lib.TensorBoardExporter(mock_api_client, outdir) + generator = exporter.export() + + with open(os.path.join(outdir, "scalars_123.json"), "w"): + pass + + with self.assertRaises(exporter_lib.OutputFileExistsError): + next(generator) + + mock_api_client.StreamExperimentData.assert_not_called() + + def test_propagates_mkdir_errors(self): + 
mock_api_client = self._create_mock_api_client() + outdir = os.path.join(self.get_temp_dir(), "some_file", "outdir") + with open(os.path.join(self.get_temp_dir(), "some_file"), "w"): + pass + + with self.assertRaises(OSError): + exporter_lib.TensorBoardExporter(mock_api_client, outdir) + + mock_api_client.StreamExperiments.assert_not_called() + mock_api_client.StreamExperimentData.assert_not_called() class ListExperimentsTest(tb_test.TestCase): - - def test_experiment_ids_only(self): - mock_api_client = _create_mock_api_client() - - def stream_experiments(request, **kwargs): - del request # unused - yield export_service_pb2.StreamExperimentsResponse( - experiment_ids=["123", "456"]) - yield export_service_pb2.StreamExperimentsResponse( - experiment_ids=["789"]) - - mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments) - gen = exporter_lib.list_experiments(mock_api_client) - mock_api_client.StreamExperiments.assert_not_called() - self.assertEqual(list(gen), ["123", "456", "789"]) - - def test_mixed_experiments_and_ids(self): - mock_api_client = _create_mock_api_client() - - def stream_experiments(request, **kwargs): - del request # unused - - # Should include `experiment_ids` when no `experiments` given. - response = export_service_pb2.StreamExperimentsResponse() - response.experiment_ids.append("123") - response.experiment_ids.append("456") - yield response - - # Should ignore `experiment_ids` in the presence of `experiments`. - response = export_service_pb2.StreamExperimentsResponse() - response.experiment_ids.append("999") # will be omitted - response.experiments.add(experiment_id="789") - response.experiments.add(experiment_id="012") - yield response - - # Should include `experiments` even when no `experiment_ids` are given. 
- response = export_service_pb2.StreamExperimentsResponse() - response.experiments.add(experiment_id="345") - response.experiments.add(experiment_id="678") - yield response - - mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments) - gen = exporter_lib.list_experiments(mock_api_client) - mock_api_client.StreamExperiments.assert_not_called() - expected = [ - "123", - "456", - export_service_pb2.Experiment(experiment_id="789"), - export_service_pb2.Experiment(experiment_id="012"), - export_service_pb2.Experiment(experiment_id="345"), - export_service_pb2.Experiment(experiment_id="678"), - ] - self.assertEqual(list(gen), expected) + def test_experiment_ids_only(self): + mock_api_client = _create_mock_api_client() + + def stream_experiments(request, **kwargs): + del request # unused + yield export_service_pb2.StreamExperimentsResponse( + experiment_ids=["123", "456"] + ) + yield export_service_pb2.StreamExperimentsResponse( + experiment_ids=["789"] + ) + + mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments) + gen = exporter_lib.list_experiments(mock_api_client) + mock_api_client.StreamExperiments.assert_not_called() + self.assertEqual(list(gen), ["123", "456", "789"]) + + def test_mixed_experiments_and_ids(self): + mock_api_client = _create_mock_api_client() + + def stream_experiments(request, **kwargs): + del request # unused + + # Should include `experiment_ids` when no `experiments` given. + response = export_service_pb2.StreamExperimentsResponse() + response.experiment_ids.append("123") + response.experiment_ids.append("456") + yield response + + # Should ignore `experiment_ids` in the presence of `experiments`. + response = export_service_pb2.StreamExperimentsResponse() + response.experiment_ids.append("999") # will be omitted + response.experiments.add(experiment_id="789") + response.experiments.add(experiment_id="012") + yield response + + # Should include `experiments` even when no `experiment_ids` are given. 
+ response = export_service_pb2.StreamExperimentsResponse() + response.experiments.add(experiment_id="345") + response.experiments.add(experiment_id="678") + yield response + + mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments) + gen = exporter_lib.list_experiments(mock_api_client) + mock_api_client.StreamExperiments.assert_not_called() + expected = [ + "123", + "456", + export_service_pb2.Experiment(experiment_id="789"), + export_service_pb2.Experiment(experiment_id="012"), + export_service_pb2.Experiment(experiment_id="345"), + export_service_pb2.Experiment(experiment_id="678"), + ] + self.assertEqual(list(gen), expected) class MkdirPTest(tb_test.TestCase): - - def test_makes_full_chain(self): - path = os.path.join(self.get_temp_dir(), "a", "b", "c") - exporter_lib._mkdir_p(path) - self.assertTrue(os.path.isdir(path)) - - def test_makes_leaf(self): - base = os.path.join(self.get_temp_dir(), "a", "b") - exporter_lib._mkdir_p(base) - leaf = os.path.join(self.get_temp_dir(), "a", "b", "c") - exporter_lib._mkdir_p(leaf) - self.assertTrue(os.path.isdir(leaf)) - - def test_fails_when_path_is_a_normal_file(self): - path = os.path.join(self.get_temp_dir(), "somefile") - with open(path, "w"): - pass - with self.assertRaises(OSError) as cm: - exporter_lib._mkdir_p(path) - self.assertEqual(cm.exception.errno, errno.EEXIST) - - def test_propagates_other_errors(self): - base = os.path.join(self.get_temp_dir(), "somefile") - with open(base, "w"): - pass - leaf = os.path.join(self.get_temp_dir(), "somefile", "somedir") - with self.assertRaises(OSError) as cm: - exporter_lib._mkdir_p(leaf) - self.assertNotEqual(cm.exception.errno, errno.EEXIST) - if os.name == "nt": - expected_errno = errno.ENOENT - else: - expected_errno = errno.ENOTDIR - self.assertEqual(cm.exception.errno, expected_errno) + def test_makes_full_chain(self): + path = os.path.join(self.get_temp_dir(), "a", "b", "c") + exporter_lib._mkdir_p(path) + self.assertTrue(os.path.isdir(path)) + + def 
test_makes_leaf(self): + base = os.path.join(self.get_temp_dir(), "a", "b") + exporter_lib._mkdir_p(base) + leaf = os.path.join(self.get_temp_dir(), "a", "b", "c") + exporter_lib._mkdir_p(leaf) + self.assertTrue(os.path.isdir(leaf)) + + def test_fails_when_path_is_a_normal_file(self): + path = os.path.join(self.get_temp_dir(), "somefile") + with open(path, "w"): + pass + with self.assertRaises(OSError) as cm: + exporter_lib._mkdir_p(path) + self.assertEqual(cm.exception.errno, errno.EEXIST) + + def test_propagates_other_errors(self): + base = os.path.join(self.get_temp_dir(), "somefile") + with open(base, "w"): + pass + leaf = os.path.join(self.get_temp_dir(), "somefile", "somedir") + with self.assertRaises(OSError) as cm: + exporter_lib._mkdir_p(leaf) + self.assertNotEqual(cm.exception.errno, errno.EEXIST) + if os.name == "nt": + expected_errno = errno.ENOENT + else: + expected_errno = errno.ENOTDIR + self.assertEqual(cm.exception.errno, expected_errno) class OpenExclTest(tb_test.TestCase): - - def test_success(self): - path = os.path.join(self.get_temp_dir(), "test.txt") - with exporter_lib._open_excl(path) as outfile: - outfile.write("hello\n") - with open(path) as infile: - self.assertEqual(infile.read(), "hello\n") - - def test_fails_when_file_exists(self): - path = os.path.join(self.get_temp_dir(), "test.txt") - with open(path, "w"): - pass - with self.assertRaises(exporter_lib.OutputFileExistsError) as cm: - exporter_lib._open_excl(path) - self.assertEqual(str(cm.exception), path) - - def test_propagates_other_errors(self): - path = os.path.join(self.get_temp_dir(), "enoent", "test.txt") - with self.assertRaises(OSError) as cm: - exporter_lib._open_excl(path) - self.assertEqual(cm.exception.errno, errno.ENOENT) + def test_success(self): + path = os.path.join(self.get_temp_dir(), "test.txt") + with exporter_lib._open_excl(path) as outfile: + outfile.write("hello\n") + with open(path) as infile: + self.assertEqual(infile.read(), "hello\n") + + def 
test_fails_when_file_exists(self): + path = os.path.join(self.get_temp_dir(), "test.txt") + with open(path, "w"): + pass + with self.assertRaises(exporter_lib.OutputFileExistsError) as cm: + exporter_lib._open_excl(path) + self.assertEqual(str(cm.exception), path) + + def test_propagates_other_errors(self): + path = os.path.join(self.get_temp_dir(), "enoent", "test.txt") + with self.assertRaises(OSError) as cm: + exporter_lib._open_excl(path) + self.assertEqual(cm.exception.errno, errno.ENOENT) def _create_mock_api_client(): - # Create a stub instance (using a test channel) in order to derive a mock - # from it with autospec enabled. Mocking TensorBoardExporterServiceStub - # itself doesn't work with autospec because grpc constructs stubs via - # metaclassing. - test_channel = grpc_testing.channel( - service_descriptors=[], time=grpc_testing.strict_real_time()) - stub = export_service_pb2_grpc.TensorBoardExporterServiceStub(test_channel) - mock_api_client = mock.create_autospec(stub) - return mock_api_client + # Create a stub instance (using a test channel) in order to derive a mock + # from it with autospec enabled. Mocking TensorBoardExporterServiceStub + # itself doesn't work with autospec because grpc constructs stubs via + # metaclassing. + test_channel = grpc_testing.channel( + service_descriptors=[], time=grpc_testing.strict_real_time() + ) + stub = export_service_pb2_grpc.TensorBoardExporterServiceStub(test_channel) + mock_api_client = mock.create_autospec(stub) + return mock_api_client if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/uploader/logdir_loader.py b/tensorboard/uploader/logdir_loader.py index fc83d2428a..f29094801a 100644 --- a/tensorboard/uploader/logdir_loader.py +++ b/tensorboard/uploader/logdir_loader.py @@ -30,78 +30,82 @@ class LogdirLoader(object): - """Loader for a root log directory, maintaining multiple DirectoryLoaders. + """Loader for a root log directory, maintaining multiple DirectoryLoaders. 
- This class takes a root log directory and a factory for DirectoryLoaders, and - maintains one DirectoryLoader per "logdir subdirectory" of the root logdir. + This class takes a root log directory and a factory for DirectoryLoaders, and + maintains one DirectoryLoader per "logdir subdirectory" of the root logdir. - Note that this class is not thread-safe. - """ - - def __init__(self, logdir, directory_loader_factory): - """Constructs a new LogdirLoader. - - Args: - logdir: The root log directory to load from. - directory_loader_factory: A factory for creating DirectoryLoaders. The - factory should take a path and return a DirectoryLoader. - - Raises: - ValueError: If logdir or directory_loader_factory are None. + Note that this class is not thread-safe. """ - if logdir is None: - raise ValueError('A logdir is required') - if directory_loader_factory is None: - raise ValueError('A directory loader factory is required') - self._logdir = logdir - self._directory_loader_factory = directory_loader_factory - # Maps run names to corresponding DirectoryLoader instances. - self._directory_loaders = {} - - def synchronize_runs(self): - """Finds new runs within `logdir` and makes `DirectoryLoaders` for them. - - In addition, any existing `DirectoryLoader` whose run directory no longer - exists will be deleted. 
- """ - logger.info('Starting logdir traversal of %s', self._logdir) - runs_seen = set() - for subdir in io_wrapper.GetLogdirSubdirectories(self._logdir): - run = os.path.relpath(subdir, self._logdir) - runs_seen.add(run) - if run not in self._directory_loaders: - logger.info('- Adding run for relative directory %s', run) - self._directory_loaders[run] = self._directory_loader_factory(subdir) - stale_runs = set(self._directory_loaders) - runs_seen - if stale_runs: - for run in stale_runs: - logger.info('- Removing run for relative directory %s', run) - del self._directory_loaders[run] - logger.info('Ending logdir traversal of %s', self._logdir) - - def get_run_events(self): - """Returns tf.Event generators for each run's `DirectoryLoader`. - - Warning: the generators are stateful and consuming them will affect the - results of any other existing generators for that run; calling code should - ensure it takes events from only a single generator per run at a time. - - Returns: - Dictionary containing an entry for each run, mapping the run name to a - generator yielding tf.Event protobuf objects loaded from that run. - """ - runs = list(self._directory_loaders) - logger.info('Creating event loading generators for %d runs', len(runs)) - run_to_loader = collections.OrderedDict() - for run_name in sorted(runs): - loader = self._directory_loaders[run_name] - run_to_loader[run_name] = self._wrap_loader_generator(loader.Load()) - return run_to_loader - - def _wrap_loader_generator(self, loader_generator): - """Wraps `DirectoryLoader` generator to swallow `DirectoryDeletedError`.""" - try: - for item in loader_generator: - yield item - except directory_watcher.DirectoryDeletedError: - return + + def __init__(self, logdir, directory_loader_factory): + """Constructs a new LogdirLoader. + + Args: + logdir: The root log directory to load from. + directory_loader_factory: A factory for creating DirectoryLoaders. The + factory should take a path and return a DirectoryLoader. 
+ + Raises: + ValueError: If logdir or directory_loader_factory are None. + """ + if logdir is None: + raise ValueError("A logdir is required") + if directory_loader_factory is None: + raise ValueError("A directory loader factory is required") + self._logdir = logdir + self._directory_loader_factory = directory_loader_factory + # Maps run names to corresponding DirectoryLoader instances. + self._directory_loaders = {} + + def synchronize_runs(self): + """Finds new runs within `logdir` and makes `DirectoryLoaders` for + them. + + In addition, any existing `DirectoryLoader` whose run directory + no longer exists will be deleted. + """ + logger.info("Starting logdir traversal of %s", self._logdir) + runs_seen = set() + for subdir in io_wrapper.GetLogdirSubdirectories(self._logdir): + run = os.path.relpath(subdir, self._logdir) + runs_seen.add(run) + if run not in self._directory_loaders: + logger.info("- Adding run for relative directory %s", run) + self._directory_loaders[run] = self._directory_loader_factory( + subdir + ) + stale_runs = set(self._directory_loaders) - runs_seen + if stale_runs: + for run in stale_runs: + logger.info("- Removing run for relative directory %s", run) + del self._directory_loaders[run] + logger.info("Ending logdir traversal of %s", self._logdir) + + def get_run_events(self): + """Returns tf.Event generators for each run's `DirectoryLoader`. + + Warning: the generators are stateful and consuming them will affect the + results of any other existing generators for that run; calling code should + ensure it takes events from only a single generator per run at a time. + + Returns: + Dictionary containing an entry for each run, mapping the run name to a + generator yielding tf.Event protobuf objects loaded from that run. 
+ """ + runs = list(self._directory_loaders) + logger.info("Creating event loading generators for %d runs", len(runs)) + run_to_loader = collections.OrderedDict() + for run_name in sorted(runs): + loader = self._directory_loaders[run_name] + run_to_loader[run_name] = self._wrap_loader_generator(loader.Load()) + return run_to_loader + + def _wrap_loader_generator(self, loader_generator): + """Wraps `DirectoryLoader` generator to swallow + `DirectoryDeletedError`.""" + try: + for item in loader_generator: + yield item + except directory_watcher.DirectoryDeletedError: + return diff --git a/tensorboard/uploader/logdir_loader_test.py b/tensorboard/uploader/logdir_loader_test.py index c3d8e09f3c..0625fbae82 100644 --- a/tensorboard/uploader/logdir_loader_test.py +++ b/tensorboard/uploader/logdir_loader_test.py @@ -31,124 +31,140 @@ class LogdirLoaderTest(tb_test.TestCase): + def _create_logdir_loader(self, logdir): + def directory_loader_factory(path): + return directory_loader.DirectoryLoader( + path, + event_file_loader.TimestampedEventFileLoader, + path_filter=io_wrapper.IsTensorFlowEventsFile, + ) - def _create_logdir_loader(self, logdir): - def directory_loader_factory(path): - return directory_loader.DirectoryLoader( - path, - event_file_loader.TimestampedEventFileLoader, - path_filter=io_wrapper.IsTensorFlowEventsFile) - return logdir_loader.LogdirLoader(logdir, directory_loader_factory) + return logdir_loader.LogdirLoader(logdir, directory_loader_factory) - def _extract_tags(self, event_generator): - """Converts a generator of tf.Events into a list of event tags.""" - return [ - event.summary.value[0].tag for event in event_generator - if not event.file_version - ] + def _extract_tags(self, event_generator): + """Converts a generator of tf.Events into a list of event tags.""" + return [ + event.summary.value[0].tag + for event in event_generator + if not event.file_version + ] - def _extract_run_to_tags(self, run_to_events): - """Returns run-to-tags dict from 
run-to-event-generator dict.""" - run_to_tags = {} - for run_name, event_generator in six.iteritems(run_to_events): - # There should be no duplicate runs. - self.assertNotIn(run_name, run_to_tags) - run_to_tags[run_name] = self._extract_tags(event_generator) - return run_to_tags + def _extract_run_to_tags(self, run_to_events): + """Returns run-to-tags dict from run-to-event-generator dict.""" + run_to_tags = {} + for run_name, event_generator in six.iteritems(run_to_events): + # There should be no duplicate runs. + self.assertNotIn(run_name, run_to_tags) + run_to_tags[run_name] = self._extract_tags(event_generator) + return run_to_tags - def test_empty_logdir(self): - logdir = self.get_temp_dir() - loader = self._create_logdir_loader(logdir) - # Default state is empty. - self.assertEmpty(list(loader.get_run_events())) - loader.synchronize_runs() - # Still empty, since there's no data. - self.assertEmpty(list(loader.get_run_events())) + def test_empty_logdir(self): + logdir = self.get_temp_dir() + loader = self._create_logdir_loader(logdir) + # Default state is empty. + self.assertEmpty(list(loader.get_run_events())) + loader.synchronize_runs() + # Still empty, since there's no data. + self.assertEmpty(list(loader.get_run_events())) - def test_single_event_logdir(self): - logdir = self.get_temp_dir() - with test_util.FileWriter(logdir) as writer: - writer.add_test_summary("foo") - loader = self._create_logdir_loader(logdir) - loader.synchronize_runs() - self.assertEqual( - self._extract_run_to_tags(loader.get_run_events()), {".": ["foo"]}) - # A second load should indicate no new data for the run. 
- self.assertEqual( - self._extract_run_to_tags(loader.get_run_events()), {".": []}) + def test_single_event_logdir(self): + logdir = self.get_temp_dir() + with test_util.FileWriter(logdir) as writer: + writer.add_test_summary("foo") + loader = self._create_logdir_loader(logdir) + loader.synchronize_runs() + self.assertEqual( + self._extract_run_to_tags(loader.get_run_events()), {".": ["foo"]} + ) + # A second load should indicate no new data for the run. + self.assertEqual( + self._extract_run_to_tags(loader.get_run_events()), {".": []} + ) - def test_multiple_writes_to_logdir(self): - logdir = self.get_temp_dir() - with test_util.FileWriter(os.path.join(logdir, "a")) as writer: - writer.add_test_summary("tag_a") - with test_util.FileWriter(os.path.join(logdir, "b")) as writer: - writer.add_test_summary("tag_b") - with test_util.FileWriter(os.path.join(logdir, "b", "x")) as writer: - writer.add_test_summary("tag_b_x") - writer_c = test_util.FileWriter(os.path.join(logdir, "c")) - writer_c.add_test_summary("tag_c") - writer_c.flush() - loader = self._create_logdir_loader(logdir) - loader.synchronize_runs() - self.assertEqual( - self._extract_run_to_tags(loader.get_run_events()), - {"a": ["tag_a"], "b": ["tag_b"], "b/x": ["tag_b_x"], "c": ["tag_c"]}) - # A second load should indicate no new data. - self.assertEqual( - self._extract_run_to_tags(loader.get_run_events()), - {"a": [], "b": [], "b/x": [], "c": []}) - # Write some new data to both new and pre-existing event files. - with test_util.FileWriter( - os.path.join(logdir, "a"), filename_suffix=".other") as writer: - writer.add_test_summary("tag_a_2") - writer.add_test_summary("tag_a_3") - writer.add_test_summary("tag_a_4") - with test_util.FileWriter( - os.path.join(logdir, "b", "x"), filename_suffix=".other") as writer: - writer.add_test_summary("tag_b_x_2") - with writer_c as writer: - writer.add_test_summary("tag_c_2") - # New data should appear on the next load. 
- self.assertEqual( - self._extract_run_to_tags(loader.get_run_events()), { - "a": ["tag_a_2", "tag_a_3", "tag_a_4"], - "b": [], - "b/x": ["tag_b_x_2"], - "c": ["tag_c_2"] - }) + def test_multiple_writes_to_logdir(self): + logdir = self.get_temp_dir() + with test_util.FileWriter(os.path.join(logdir, "a")) as writer: + writer.add_test_summary("tag_a") + with test_util.FileWriter(os.path.join(logdir, "b")) as writer: + writer.add_test_summary("tag_b") + with test_util.FileWriter(os.path.join(logdir, "b", "x")) as writer: + writer.add_test_summary("tag_b_x") + writer_c = test_util.FileWriter(os.path.join(logdir, "c")) + writer_c.add_test_summary("tag_c") + writer_c.flush() + loader = self._create_logdir_loader(logdir) + loader.synchronize_runs() + self.assertEqual( + self._extract_run_to_tags(loader.get_run_events()), + { + "a": ["tag_a"], + "b": ["tag_b"], + "b/x": ["tag_b_x"], + "c": ["tag_c"], + }, + ) + # A second load should indicate no new data. + self.assertEqual( + self._extract_run_to_tags(loader.get_run_events()), + {"a": [], "b": [], "b/x": [], "c": []}, + ) + # Write some new data to both new and pre-existing event files. + with test_util.FileWriter( + os.path.join(logdir, "a"), filename_suffix=".other" + ) as writer: + writer.add_test_summary("tag_a_2") + writer.add_test_summary("tag_a_3") + writer.add_test_summary("tag_a_4") + with test_util.FileWriter( + os.path.join(logdir, "b", "x"), filename_suffix=".other" + ) as writer: + writer.add_test_summary("tag_b_x_2") + with writer_c as writer: + writer.add_test_summary("tag_c_2") + # New data should appear on the next load. 
+ self.assertEqual( + self._extract_run_to_tags(loader.get_run_events()), + { + "a": ["tag_a_2", "tag_a_3", "tag_a_4"], + "b": [], + "b/x": ["tag_b_x_2"], + "c": ["tag_c_2"], + }, + ) - def test_directory_deletion(self): - logdir = self.get_temp_dir() - with test_util.FileWriter(os.path.join(logdir, "a")) as writer: - writer.add_test_summary("tag_a") - with test_util.FileWriter(os.path.join(logdir, "b")) as writer: - writer.add_test_summary("tag_b") - with test_util.FileWriter(os.path.join(logdir, "c")) as writer: - writer.add_test_summary("tag_c") - loader = self._create_logdir_loader(logdir) - loader.synchronize_runs() - self.assertEqual(list(loader.get_run_events().keys()), ["a", "b", "c"]) - shutil.rmtree(os.path.join(logdir, "b")) - loader.synchronize_runs() - self.assertEqual(list(loader.get_run_events().keys()), ["a", "c"]) - shutil.rmtree(logdir) - loader.synchronize_runs() - self.assertEmpty(loader.get_run_events()) + def test_directory_deletion(self): + logdir = self.get_temp_dir() + with test_util.FileWriter(os.path.join(logdir, "a")) as writer: + writer.add_test_summary("tag_a") + with test_util.FileWriter(os.path.join(logdir, "b")) as writer: + writer.add_test_summary("tag_b") + with test_util.FileWriter(os.path.join(logdir, "c")) as writer: + writer.add_test_summary("tag_c") + loader = self._create_logdir_loader(logdir) + loader.synchronize_runs() + self.assertEqual(list(loader.get_run_events().keys()), ["a", "b", "c"]) + shutil.rmtree(os.path.join(logdir, "b")) + loader.synchronize_runs() + self.assertEqual(list(loader.get_run_events().keys()), ["a", "c"]) + shutil.rmtree(logdir) + loader.synchronize_runs() + self.assertEmpty(loader.get_run_events()) - def test_directory_deletion_during_event_loading(self): - logdir = self.get_temp_dir() - with test_util.FileWriter(logdir) as writer: - writer.add_test_summary("foo") - loader = self._create_logdir_loader(logdir) - loader.synchronize_runs() - self.assertEqual( - 
self._extract_run_to_tags(loader.get_run_events()), {".": ["foo"]}) - shutil.rmtree(logdir) - runs_to_events = loader.get_run_events() - self.assertEqual(list(runs_to_events.keys()), ["."]) - events = runs_to_events["."] - self.assertEqual(self._extract_tags(events), []) + def test_directory_deletion_during_event_loading(self): + logdir = self.get_temp_dir() + with test_util.FileWriter(logdir) as writer: + writer.add_test_summary("foo") + loader = self._create_logdir_loader(logdir) + loader.synchronize_runs() + self.assertEqual( + self._extract_run_to_tags(loader.get_run_events()), {".": ["foo"]} + ) + shutil.rmtree(logdir) + runs_to_events = loader.get_run_events() + self.assertEqual(list(runs_to_events.keys()), ["."]) + events = runs_to_events["."] + self.assertEqual(self._extract_tags(events), []) if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/uploader/peekable_iterator.py b/tensorboard/uploader/peekable_iterator.py index 59dea6a3ae..694bce5ef6 100644 --- a/tensorboard/uploader/peekable_iterator.py +++ b/tensorboard/uploader/peekable_iterator.py @@ -20,68 +20,68 @@ class PeekableIterator(object): - """Iterator adapter that supports peeking ahead. + """Iterator adapter that supports peeking ahead. - As with most Python iterators, this is also iterable; its `__iter__` - returns itself. + As with most Python iterators, this is also iterable; its `__iter__` + returns itself. - This class is not thread-safe. Use external synchronization if - iterating concurrently. - """ - - def __init__(self, iterable): - """Initializes a peeking iterator wrapping the provided iterable. - - Args: - iterable: An iterable to wrap. - """ - self._iterator = iter(iterable) - self._has_peeked = False - self._peeked_element = None - - def has_next(self): - """Checks whether there are any more items in this iterator. - - The next call to `next` or `peek` will raise `StopIteration` if and - only if this method returns `False`. 
- - Returns: - `True` if there are any more items in this iterator, else `False`. + This class is not thread-safe. Use external synchronization if + iterating concurrently. """ - try: - self.peek() - return True - except StopIteration: - return False - - def peek(self): - """Gets the next item in the iterator without consuming it. - - Multiple consecutive calls will return the same element. - Returns: - The value that would be returned by `next`. - - Raises: - StopIteration: If there are no more items in the iterator. - """ - if not self._has_peeked: - self._peeked_element = next(self._iterator) - self._has_peeked = True - return self._peeked_element - - def __iter__(self): - return self - - def __next__(self): - if self._has_peeked: - self._has_peeked = False - result = self._peeked_element - self._peeked_element = None # allow GC - return result - else: - return next(self._iterator) - - def next(self): - # (Like `__next__`, but Python 2.) - return self.__next__() + def __init__(self, iterable): + """Initializes a peeking iterator wrapping the provided iterable. + + Args: + iterable: An iterable to wrap. + """ + self._iterator = iter(iterable) + self._has_peeked = False + self._peeked_element = None + + def has_next(self): + """Checks whether there are any more items in this iterator. + + The next call to `next` or `peek` will raise `StopIteration` if and + only if this method returns `False`. + + Returns: + `True` if there are any more items in this iterator, else `False`. + """ + try: + self.peek() + return True + except StopIteration: + return False + + def peek(self): + """Gets the next item in the iterator without consuming it. + + Multiple consecutive calls will return the same element. + + Returns: + The value that would be returned by `next`. + + Raises: + StopIteration: If there are no more items in the iterator. 
+ """ + if not self._has_peeked: + self._peeked_element = next(self._iterator) + self._has_peeked = True + return self._peeked_element + + def __iter__(self): + return self + + def __next__(self): + if self._has_peeked: + self._has_peeked = False + result = self._peeked_element + self._peeked_element = None # allow GC + return result + else: + return next(self._iterator) + + def next(self): + # (Like `__next__`, but Python 2.) + return self.__next__() diff --git a/tensorboard/uploader/peekable_iterator_test.py b/tensorboard/uploader/peekable_iterator_test.py index e2c52505d0..9cea3deb7f 100644 --- a/tensorboard/uploader/peekable_iterator_test.py +++ b/tensorboard/uploader/peekable_iterator_test.py @@ -23,46 +23,46 @@ class PeekableIteratorTest(tb_test.TestCase): - """Tests for `PeekableIterator`.""" + """Tests for `PeekableIterator`.""" - def test_empty_iteration(self): - it = peekable_iterator.PeekableIterator([]) - self.assertEqual(list(it), []) + def test_empty_iteration(self): + it = peekable_iterator.PeekableIterator([]) + self.assertEqual(list(it), []) - def test_normal_iteration(self): - it = peekable_iterator.PeekableIterator([1, 2, 3]) - self.assertEqual(list(it), [1, 2, 3]) + def test_normal_iteration(self): + it = peekable_iterator.PeekableIterator([1, 2, 3]) + self.assertEqual(list(it), [1, 2, 3]) - def test_simple_peek(self): - it = peekable_iterator.PeekableIterator([1, 2, 3]) - self.assertEqual(it.peek(), 1) - self.assertEqual(it.peek(), 1) - self.assertEqual(next(it), 1) - self.assertEqual(it.peek(), 2) - self.assertEqual(next(it), 2) - self.assertEqual(next(it), 3) - self.assertEqual(list(it), []) + def test_simple_peek(self): + it = peekable_iterator.PeekableIterator([1, 2, 3]) + self.assertEqual(it.peek(), 1) + self.assertEqual(it.peek(), 1) + self.assertEqual(next(it), 1) + self.assertEqual(it.peek(), 2) + self.assertEqual(next(it), 2) + self.assertEqual(next(it), 3) + self.assertEqual(list(it), []) - def test_simple_has_next(self): - it = 
peekable_iterator.PeekableIterator([1, 2]) - self.assertTrue(it.has_next()) - self.assertEqual(it.peek(), 1) - self.assertTrue(it.has_next()) - self.assertEqual(next(it), 1) - self.assertEqual(it.peek(), 2) - self.assertTrue(it.has_next()) - self.assertEqual(next(it), 2) - self.assertFalse(it.has_next()) - self.assertFalse(it.has_next()) + def test_simple_has_next(self): + it = peekable_iterator.PeekableIterator([1, 2]) + self.assertTrue(it.has_next()) + self.assertEqual(it.peek(), 1) + self.assertTrue(it.has_next()) + self.assertEqual(next(it), 1) + self.assertEqual(it.peek(), 2) + self.assertTrue(it.has_next()) + self.assertEqual(next(it), 2) + self.assertFalse(it.has_next()) + self.assertFalse(it.has_next()) - def test_peek_after_end(self): - it = peekable_iterator.PeekableIterator([1, 2, 3]) - self.assertEqual(list(it), [1, 2, 3]) - with self.assertRaises(StopIteration): - it.peek() - with self.assertRaises(StopIteration): - it.peek() + def test_peek_after_end(self): + it = peekable_iterator.PeekableIterator([1, 2, 3]) + self.assertEqual(list(it), [1, 2, 3]) + with self.assertRaises(StopIteration): + it.peek() + with self.assertRaises(StopIteration): + it.peek() if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/uploader/server_info.py b/tensorboard/uploader/server_info.py index 5906bb11b3..e381f38bb8 100644 --- a/tensorboard/uploader/server_info.py +++ b/tensorboard/uploader/server_info.py @@ -30,88 +30,88 @@ def _server_info_request(): - request = server_info_pb2.ServerInfoRequest() - request.version = version.VERSION - return request + request = server_info_pb2.ServerInfoRequest() + request.version = version.VERSION + return request def fetch_server_info(origin): - """Fetches server info from a remote server. - - Args: - origin: The server with which to communicate. Should be a string - like "https://tensorboard.dev", including protocol, host, and (if - needed) port. 
- - Returns: - A `server_info_pb2.ServerInfoResponse` message. - - Raises: - CommunicationError: Upon failure to connect to or successfully - communicate with the remote server. - """ - endpoint = "%s/api/uploader" % origin - post_body = _server_info_request().SerializeToString() - try: - response = requests.post( - endpoint, - data=post_body, - timeout=_REQUEST_TIMEOUT_SECONDS, - headers={"User-Agent": "tensorboard/%s" % version.VERSION}, - ) - except requests.RequestException as e: - raise CommunicationError("Failed to connect to backend: %s" % e) - if not response.ok: - raise CommunicationError( - "Non-OK status from backend (%d %s): %r" - % (response.status_code, response.reason, response.content) - ) - try: - return server_info_pb2.ServerInfoResponse.FromString(response.content) - except message.DecodeError as e: - raise CommunicationError( - "Corrupt response from backend (%s): %r" % (e, response.content) - ) + """Fetches server info from a remote server. + + Args: + origin: The server with which to communicate. Should be a string + like "https://tensorboard.dev", including protocol, host, and (if + needed) port. + + Returns: + A `server_info_pb2.ServerInfoResponse` message. + + Raises: + CommunicationError: Upon failure to connect to or successfully + communicate with the remote server. 
+ """ + endpoint = "%s/api/uploader" % origin + post_body = _server_info_request().SerializeToString() + try: + response = requests.post( + endpoint, + data=post_body, + timeout=_REQUEST_TIMEOUT_SECONDS, + headers={"User-Agent": "tensorboard/%s" % version.VERSION}, + ) + except requests.RequestException as e: + raise CommunicationError("Failed to connect to backend: %s" % e) + if not response.ok: + raise CommunicationError( + "Non-OK status from backend (%d %s): %r" + % (response.status_code, response.reason, response.content) + ) + try: + return server_info_pb2.ServerInfoResponse.FromString(response.content) + except message.DecodeError as e: + raise CommunicationError( + "Corrupt response from backend (%s): %r" % (e, response.content) + ) def create_server_info(frontend_origin, api_endpoint): - """Manually creates server info given a frontend and backend. - - Args: - frontend_origin: The origin of the TensorBoard.dev frontend, like - "https://tensorboard.dev" or "http://localhost:8000". - api_endpoint: As to `server_info_pb2.ApiServer.endpoint`. - - Returns: - A `server_info_pb2.ServerInfoResponse` message. - """ - result = server_info_pb2.ServerInfoResponse() - result.compatibility.verdict = server_info_pb2.VERDICT_OK - result.api_server.endpoint = api_endpoint - url_format = result.url_format - placeholder = "{{EID}}" - while placeholder in frontend_origin: - placeholder = "{%s}" % placeholder - url_format.template = "%s/experiment/%s/" % (frontend_origin, placeholder) - url_format.id_placeholder = placeholder - return result + """Manually creates server info given a frontend and backend. + + Args: + frontend_origin: The origin of the TensorBoard.dev frontend, like + "https://tensorboard.dev" or "http://localhost:8000". + api_endpoint: As to `server_info_pb2.ApiServer.endpoint`. + + Returns: + A `server_info_pb2.ServerInfoResponse` message. 
+ """ + result = server_info_pb2.ServerInfoResponse() + result.compatibility.verdict = server_info_pb2.VERDICT_OK + result.api_server.endpoint = api_endpoint + url_format = result.url_format + placeholder = "{{EID}}" + while placeholder in frontend_origin: + placeholder = "{%s}" % placeholder + url_format.template = "%s/experiment/%s/" % (frontend_origin, placeholder) + url_format.id_placeholder = placeholder + return result def experiment_url(server_info, experiment_id): - """Formats a URL that will resolve to the provided experiment. + """Formats a URL that will resolve to the provided experiment. - Args: - server_info: A `server_info_pb2.ServerInfoResponse` message. - experiment_id: A string; the ID of the experiment to link to. + Args: + server_info: A `server_info_pb2.ServerInfoResponse` message. + experiment_id: A string; the ID of the experiment to link to. - Returns: - A URL resolving to the given experiment, as a string. - """ - url_format = server_info.url_format - return url_format.template.replace(url_format.id_placeholder, experiment_id) + Returns: + A URL resolving to the given experiment, as a string. 
+ """ + url_format = server_info.url_format + return url_format.template.replace(url_format.id_placeholder, experiment_id) class CommunicationError(RuntimeError): - """Raised upon failure to communicate with the server.""" + """Raised upon failure to communicate with the server.""" - pass + pass diff --git a/tensorboard/uploader/server_info_test.py b/tensorboard/uploader/server_info_test.py index 8d50fdf7ad..d891a150ff 100644 --- a/tensorboard/uploader/server_info_test.py +++ b/tensorboard/uploader/server_info_test.py @@ -33,150 +33,154 @@ class FetchServerInfoTest(tb_test.TestCase): - """Tests for `fetch_server_info`.""" - - def _start_server(self, app): - """Starts a server and returns its origin ("http://localhost:PORT").""" - (_, localhost) = _localhost() - server_class = _make_ipv6_compatible_wsgi_server() - server = simple_server.make_server(localhost, 0, app, server_class) - executor = futures.ThreadPoolExecutor() - future = executor.submit(server.serve_forever, poll_interval=0.01) - - def cleanup(): - server.shutdown() # stop handling requests - server.server_close() # release port - future.result(timeout=3) # wait for server termination - - self.addCleanup(cleanup) - if ":" in localhost and not localhost.startswith("["): - # IPv6 IP address, probably "::1". 
- localhost = "[%s]" % localhost - return "http://%s:%d" % (localhost, server.server_port) - - def test_fetches_response(self): - expected_result = server_info_pb2.ServerInfoResponse() - expected_result.compatibility.verdict = server_info_pb2.VERDICT_OK - expected_result.compatibility.details = "all clear" - expected_result.api_server.endpoint = "api.example.com:443" - expected_result.url_format.template = "http://localhost:8080/{{eid}}" - expected_result.url_format.id_placeholder = "{{eid}}" - - @wrappers.BaseRequest.application - def app(request): - self.assertEqual(request.method, "POST") - self.assertEqual(request.path, "/api/uploader") - body = request.get_data() - request_pb = server_info_pb2.ServerInfoRequest.FromString(body) - self.assertEqual(request_pb.version, version.VERSION) - return wrappers.BaseResponse(expected_result.SerializeToString()) - - origin = self._start_server(app) - result = server_info.fetch_server_info(origin) - self.assertEqual(result, expected_result) - - def test_econnrefused(self): - (family, localhost) = _localhost() - s = socket.socket(family) - s.bind((localhost, 0)) - self.addCleanup(s.close) - port = s.getsockname()[1] - with self.assertRaises(server_info.CommunicationError) as cm: - server_info.fetch_server_info("http://localhost:%d" % port) - msg = str(cm.exception) - self.assertIn("Failed to connect to backend", msg) - if os.name != "nt": - self.assertIn(os.strerror(errno.ECONNREFUSED), msg) - - def test_non_ok_response(self): - @wrappers.BaseRequest.application - def app(request): - del request # unused - return wrappers.BaseResponse(b"very sad", status="502 Bad Gateway") - - origin = self._start_server(app) - with self.assertRaises(server_info.CommunicationError) as cm: - server_info.fetch_server_info(origin) - msg = str(cm.exception) - self.assertIn("Non-OK status from backend (502 Bad Gateway)", msg) - self.assertIn("very sad", msg) - - def test_corrupt_response(self): - @wrappers.BaseRequest.application - def 
app(request): - del request # unused - return wrappers.BaseResponse(b"an unlikely proto") - - origin = self._start_server(app) - with self.assertRaises(server_info.CommunicationError) as cm: - server_info.fetch_server_info(origin) - msg = str(cm.exception) - self.assertIn("Corrupt response from backend", msg) - self.assertIn("an unlikely proto", msg) - - def test_user_agent(self): - @wrappers.BaseRequest.application - def app(request): - result = server_info_pb2.ServerInfoResponse() - result.compatibility.details = request.headers["User-Agent"] - return wrappers.BaseResponse(result.SerializeToString()) - - origin = self._start_server(app) - result = server_info.fetch_server_info(origin) - expected_user_agent = "tensorboard/%s" % version.VERSION - self.assertEqual(result.compatibility.details, expected_user_agent) + """Tests for `fetch_server_info`.""" + + def _start_server(self, app): + """Starts a server and returns its origin ("http://localhost:PORT").""" + (_, localhost) = _localhost() + server_class = _make_ipv6_compatible_wsgi_server() + server = simple_server.make_server(localhost, 0, app, server_class) + executor = futures.ThreadPoolExecutor() + future = executor.submit(server.serve_forever, poll_interval=0.01) + + def cleanup(): + server.shutdown() # stop handling requests + server.server_close() # release port + future.result(timeout=3) # wait for server termination + + self.addCleanup(cleanup) + if ":" in localhost and not localhost.startswith("["): + # IPv6 IP address, probably "::1". 
+ localhost = "[%s]" % localhost + return "http://%s:%d" % (localhost, server.server_port) + + def test_fetches_response(self): + expected_result = server_info_pb2.ServerInfoResponse() + expected_result.compatibility.verdict = server_info_pb2.VERDICT_OK + expected_result.compatibility.details = "all clear" + expected_result.api_server.endpoint = "api.example.com:443" + expected_result.url_format.template = "http://localhost:8080/{{eid}}" + expected_result.url_format.id_placeholder = "{{eid}}" + + @wrappers.BaseRequest.application + def app(request): + self.assertEqual(request.method, "POST") + self.assertEqual(request.path, "/api/uploader") + body = request.get_data() + request_pb = server_info_pb2.ServerInfoRequest.FromString(body) + self.assertEqual(request_pb.version, version.VERSION) + return wrappers.BaseResponse(expected_result.SerializeToString()) + + origin = self._start_server(app) + result = server_info.fetch_server_info(origin) + self.assertEqual(result, expected_result) + + def test_econnrefused(self): + (family, localhost) = _localhost() + s = socket.socket(family) + s.bind((localhost, 0)) + self.addCleanup(s.close) + port = s.getsockname()[1] + with self.assertRaises(server_info.CommunicationError) as cm: + server_info.fetch_server_info("http://localhost:%d" % port) + msg = str(cm.exception) + self.assertIn("Failed to connect to backend", msg) + if os.name != "nt": + self.assertIn(os.strerror(errno.ECONNREFUSED), msg) + + def test_non_ok_response(self): + @wrappers.BaseRequest.application + def app(request): + del request # unused + return wrappers.BaseResponse(b"very sad", status="502 Bad Gateway") + + origin = self._start_server(app) + with self.assertRaises(server_info.CommunicationError) as cm: + server_info.fetch_server_info(origin) + msg = str(cm.exception) + self.assertIn("Non-OK status from backend (502 Bad Gateway)", msg) + self.assertIn("very sad", msg) + + def test_corrupt_response(self): + @wrappers.BaseRequest.application + def 
app(request): + del request # unused + return wrappers.BaseResponse(b"an unlikely proto") + + origin = self._start_server(app) + with self.assertRaises(server_info.CommunicationError) as cm: + server_info.fetch_server_info(origin) + msg = str(cm.exception) + self.assertIn("Corrupt response from backend", msg) + self.assertIn("an unlikely proto", msg) + + def test_user_agent(self): + @wrappers.BaseRequest.application + def app(request): + result = server_info_pb2.ServerInfoResponse() + result.compatibility.details = request.headers["User-Agent"] + return wrappers.BaseResponse(result.SerializeToString()) + + origin = self._start_server(app) + result = server_info.fetch_server_info(origin) + expected_user_agent = "tensorboard/%s" % version.VERSION + self.assertEqual(result.compatibility.details, expected_user_agent) class CreateServerInfoTest(tb_test.TestCase): - """Tests for `create_server_info`.""" + """Tests for `create_server_info`.""" - def test(self): - frontend = "http://localhost:8080" - backend = "localhost:10000" - result = server_info.create_server_info(frontend, backend) + def test(self): + frontend = "http://localhost:8080" + backend = "localhost:10000" + result = server_info.create_server_info(frontend, backend) - expected_compatibility = server_info_pb2.Compatibility() - expected_compatibility.verdict = server_info_pb2.VERDICT_OK - expected_compatibility.details = "" - self.assertEqual(result.compatibility, expected_compatibility) + expected_compatibility = server_info_pb2.Compatibility() + expected_compatibility.verdict = server_info_pb2.VERDICT_OK + expected_compatibility.details = "" + self.assertEqual(result.compatibility, expected_compatibility) - expected_api_server = server_info_pb2.ApiServer() - expected_api_server.endpoint = backend - self.assertEqual(result.api_server, expected_api_server) + expected_api_server = server_info_pb2.ApiServer() + expected_api_server.endpoint = backend + self.assertEqual(result.api_server, expected_api_server) - 
url_format = result.url_format - actual_url = url_format.template.replace(url_format.id_placeholder, "123") - expected_url = "http://localhost:8080/experiment/123/" - self.assertEqual(actual_url, expected_url) + url_format = result.url_format + actual_url = url_format.template.replace( + url_format.id_placeholder, "123" + ) + expected_url = "http://localhost:8080/experiment/123/" + self.assertEqual(actual_url, expected_url) class ExperimentUrlTest(tb_test.TestCase): - """Tests for `experiment_url`.""" + """Tests for `experiment_url`.""" - def test(self): - info = server_info_pb2.ServerInfoResponse() - info.url_format.template = "https://unittest.tensorboard.dev/x/???" - info.url_format.id_placeholder = "???" - actual = server_info.experiment_url(info, "123") - self.assertEqual(actual, "https://unittest.tensorboard.dev/x/123") + def test(self): + info = server_info_pb2.ServerInfoResponse() + info.url_format.template = "https://unittest.tensorboard.dev/x/???" + info.url_format.id_placeholder = "???" 
+ actual = server_info.experiment_url(info, "123") + self.assertEqual(actual, "https://unittest.tensorboard.dev/x/123") def _localhost(): - """Gets family and nodename for a loopback address.""" - s = socket - infos = s.getaddrinfo(None, 0, s.AF_UNSPEC, s.SOCK_STREAM, 0, s.AI_ADDRCONFIG) - (family, _, _, _, address) = infos[0] - nodename = address[0] - return (family, nodename) + """Gets family and nodename for a loopback address.""" + s = socket + infos = s.getaddrinfo( + None, 0, s.AF_UNSPEC, s.SOCK_STREAM, 0, s.AI_ADDRCONFIG + ) + (family, _, _, _, address) = infos[0] + nodename = address[0] + return (family, nodename) def _make_ipv6_compatible_wsgi_server(): - """Creates a `WSGIServer` subclass that works on IPv6-only machines.""" - address_family = _localhost()[0] - attrs = {"address_family": address_family} - bases = (simple_server.WSGIServer, object) # `object` needed for py2 - return type("_Ipv6CompatibleWsgiServer", bases, attrs) + """Creates a `WSGIServer` subclass that works on IPv6-only machines.""" + address_family = _localhost()[0] + attrs = {"address_family": address_family} + bases = (simple_server.WSGIServer, object) # `object` needed for py2 + return type("_Ipv6CompatibleWsgiServer", bases, attrs) if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/uploader/test_util.py b/tensorboard/uploader/test_util.py index d6c5cca9d1..2506ab27df 100644 --- a/tensorboard/uploader/test_util.py +++ b/tensorboard/uploader/test_util.py @@ -27,37 +27,37 @@ class FakeTime(object): - """Thread-safe fake replacement for the `time` module.""" + """Thread-safe fake replacement for the `time` module.""" - def __init__(self, current=0.0): - self._time = float(current) - self._lock = threading.Lock() + def __init__(self, current=0.0): + self._time = float(current) + self._lock = threading.Lock() - def time(self): - with self._lock: - return self._time + def time(self): + with self._lock: + return self._time - def sleep(self, secs): - with 
self._lock: - self._time += secs + def sleep(self, secs): + with self._lock: + self._time += secs def scalar_metadata(display_name): - """Makes a scalar metadata proto, for constructing expected requests.""" - metadata = summary_pb2.SummaryMetadata(display_name=display_name) - metadata.plugin_data.plugin_name = "scalars" - return metadata + """Makes a scalar metadata proto, for constructing expected requests.""" + metadata = summary_pb2.SummaryMetadata(display_name=display_name) + metadata.plugin_data.plugin_name = "scalars" + return metadata def grpc_error(code, details): - # Monkey patch insertion for the methods a real grpc.RpcError would have. - error = grpc.RpcError("RPC error %r: %s" % (code, details)) - error.code = lambda: code - error.details = lambda: details - return error + # Monkey patch insertion for the methods a real grpc.RpcError would have. + error = grpc.RpcError("RPC error %r: %s" % (code, details)) + error.code = lambda: code + error.details = lambda: details + return error def timestamp_pb(nanos): - result = timestamp_pb2.Timestamp() - result.FromNanoseconds(nanos) - return result + result = timestamp_pb2.Timestamp() + result.FromNanoseconds(nanos) + return result diff --git a/tensorboard/uploader/uploader.py b/tensorboard/uploader/uploader.py index 5cb780b1f5..b9388a202b 100644 --- a/tensorboard/uploader/uploader.py +++ b/tensorboard/uploader/uploader.py @@ -62,352 +62,366 @@ class TensorBoardUploader(object): - """Uploads a TensorBoard logdir to TensorBoard.dev.""" + """Uploads a TensorBoard logdir to TensorBoard.dev.""" + + def __init__(self, writer_client, logdir, rate_limiter=None): + """Constructs a TensorBoardUploader. 
+ + Args: + writer_client: a TensorBoardWriterService stub instance + logdir: path of the log directory to upload + rate_limiter: a `RateLimiter` to use to limit upload cycle frequency + """ + self._api = writer_client + self._logdir = logdir + self._request_builder = None + if rate_limiter is None: + self._rate_limiter = util.RateLimiter( + _MIN_UPLOAD_CYCLE_DURATION_SECS + ) + else: + self._rate_limiter = rate_limiter + active_filter = ( + lambda secs: secs + _EVENT_FILE_INACTIVE_SECS >= time.time() + ) + directory_loader_factory = functools.partial( + directory_loader.DirectoryLoader, + loader_factory=event_file_loader.TimestampedEventFileLoader, + path_filter=io_wrapper.IsTensorFlowEventsFile, + active_filter=active_filter, + ) + self._logdir_loader = logdir_loader.LogdirLoader( + self._logdir, directory_loader_factory + ) + + def create_experiment(self): + """Creates an Experiment for this upload session and returns the ID.""" + logger.info("Creating experiment") + request = write_service_pb2.CreateExperimentRequest() + response = grpc_util.call_with_retries( + self._api.CreateExperiment, request + ) + self._request_builder = _RequestBuilder(response.experiment_id) + return response.experiment_id + + def start_uploading(self): + """Blocks forever to continuously upload data from the logdir. + + Raises: + RuntimeError: If `create_experiment` has not yet been called. + ExperimentNotFoundError: If the experiment is deleted during the + course of the upload. 
+ """ + if self._request_builder is None: + raise RuntimeError( + "Must call create_experiment() before start_uploading()" + ) + while True: + self._upload_once() + + def _upload_once(self): + """Runs one upload cycle, sending zero or more RPCs.""" + logger.info("Starting an upload cycle") + self._rate_limiter.tick() + + sync_start_time = time.time() + self._logdir_loader.synchronize_runs() + sync_duration_secs = time.time() - sync_start_time + logger.info("Logdir sync took %.3f seconds", sync_duration_secs) + + run_to_events = self._logdir_loader.get_run_events() + first_request = True + for request in self._request_builder.build_requests(run_to_events): + if not first_request: + self._rate_limiter.tick() + first_request = False + upload_start_time = time.time() + request_bytes = request.ByteSize() + logger.info("Trying request of %d bytes", request_bytes) + self._upload(request) + upload_duration_secs = time.time() - upload_start_time + logger.info( + "Upload for %d runs (%d bytes) took %.3f seconds", + len(request.runs), + request_bytes, + upload_duration_secs, + ) + + def _upload(self, request): + try: + # TODO(@nfelt): execute this RPC asynchronously. + grpc_util.call_with_retries(self._api.WriteScalar, request) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + raise ExperimentNotFoundError() + logger.error("Upload call failed with error %s", e) - def __init__(self, writer_client, logdir, rate_limiter=None): - """Constructs a TensorBoardUploader. + +def delete_experiment(writer_client, experiment_id): + """Permanently deletes an experiment and all of its contents. 
Args: writer_client: a TensorBoardWriterService stub instance - logdir: path of the log directory to upload - rate_limiter: a `RateLimiter` to use to limit upload cycle frequency - """ - self._api = writer_client - self._logdir = logdir - self._request_builder = None - if rate_limiter is None: - self._rate_limiter = util.RateLimiter(_MIN_UPLOAD_CYCLE_DURATION_SECS) - else: - self._rate_limiter = rate_limiter - active_filter = lambda secs: secs + _EVENT_FILE_INACTIVE_SECS >= time.time() - directory_loader_factory = functools.partial( - directory_loader.DirectoryLoader, - loader_factory=event_file_loader.TimestampedEventFileLoader, - path_filter=io_wrapper.IsTensorFlowEventsFile, - active_filter=active_filter) - self._logdir_loader = logdir_loader.LogdirLoader( - self._logdir, directory_loader_factory) - - def create_experiment(self): - """Creates an Experiment for this upload session and returns the ID.""" - logger.info("Creating experiment") - request = write_service_pb2.CreateExperimentRequest() - response = grpc_util.call_with_retries(self._api.CreateExperiment, request) - self._request_builder = _RequestBuilder(response.experiment_id) - return response.experiment_id - - def start_uploading(self): - """Blocks forever to continuously upload data from the logdir. + experiment_id: string ID of the experiment to delete Raises: - RuntimeError: If `create_experiment` has not yet been called. - ExperimentNotFoundError: If the experiment is deleted during the - course of the upload. + ExperimentNotFoundError: If no such experiment exists. + PermissionDeniedError: If the user is not authorized to delete this + experiment. + RuntimeError: On unexpected failure. 
""" - if self._request_builder is None: - raise RuntimeError( - "Must call create_experiment() before start_uploading()") - while True: - self._upload_once() - - def _upload_once(self): - """Runs one upload cycle, sending zero or more RPCs.""" - logger.info("Starting an upload cycle") - self._rate_limiter.tick() - - sync_start_time = time.time() - self._logdir_loader.synchronize_runs() - sync_duration_secs = time.time() - sync_start_time - logger.info("Logdir sync took %.3f seconds", sync_duration_secs) - - run_to_events = self._logdir_loader.get_run_events() - first_request = True - for request in self._request_builder.build_requests(run_to_events): - if not first_request: - self._rate_limiter.tick() - first_request = False - upload_start_time = time.time() - request_bytes = request.ByteSize() - logger.info("Trying request of %d bytes", request_bytes) - self._upload(request) - upload_duration_secs = time.time() - upload_start_time - logger.info( - "Upload for %d runs (%d bytes) took %.3f seconds", - len(request.runs), - request_bytes, - upload_duration_secs) - - def _upload(self, request): + logger.info("Deleting experiment %r", experiment_id) + request = write_service_pb2.DeleteExperimentRequest() + request.experiment_id = experiment_id try: - # TODO(@nfelt): execute this RPC asynchronously. - grpc_util.call_with_retries(self._api.WriteScalar, request) + grpc_util.call_with_retries(writer_client.DeleteExperiment, request) except grpc.RpcError as e: - if e.code() == grpc.StatusCode.NOT_FOUND: - raise ExperimentNotFoundError() - logger.error("Upload call failed with error %s", e) - - -def delete_experiment(writer_client, experiment_id): - """Permanently deletes an experiment and all of its contents. - - Args: - writer_client: a TensorBoardWriterService stub instance - experiment_id: string ID of the experiment to delete - - Raises: - ExperimentNotFoundError: If no such experiment exists. 
- PermissionDeniedError: If the user is not authorized to delete this - experiment. - RuntimeError: On unexpected failure. - """ - logger.info("Deleting experiment %r", experiment_id) - request = write_service_pb2.DeleteExperimentRequest() - request.experiment_id = experiment_id - try: - grpc_util.call_with_retries(writer_client.DeleteExperiment, request) - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.NOT_FOUND: - raise ExperimentNotFoundError() - if e.code() == grpc.StatusCode.PERMISSION_DENIED: - raise PermissionDeniedError() - raise + if e.code() == grpc.StatusCode.NOT_FOUND: + raise ExperimentNotFoundError() + if e.code() == grpc.StatusCode.PERMISSION_DENIED: + raise PermissionDeniedError() + raise class ExperimentNotFoundError(RuntimeError): - pass + pass class PermissionDeniedError(RuntimeError): - pass + pass class _OutOfSpaceError(Exception): - """Action could not proceed without overflowing request budget. + """Action could not proceed without overflowing request budget. - This is a signaling exception (like `StopIteration`) used internally - by `_RequestBuilder`; it does not mean that anything has gone wrong. - """ - pass - - -class _RequestBuilder(object): - """Helper class for building requests that fit under a size limit. - - This class is not threadsafe. Use external synchronization if calling - its methods concurrently. - """ - - _NON_SCALAR_TIME_SERIES = object() # sentinel - - def __init__(self, experiment_id): - self._experiment_id = experiment_id - # The request currently being populated. - self._request = None # type: write_service_pb2.WriteScalarRequest - # A lower bound on the number of bytes that we may yet add to the - # request. - self._byte_budget = None # type: int - # Map from `(run_name, tag_name)` to `SummaryMetadata` if the time - # series is a scalar time series, else to `_NON_SCALAR_TIME_SERIES`. 
- self._tag_metadata = {} - - def _new_request(self): - """Allocates a new request and refreshes the budget.""" - self._request = write_service_pb2.WriteScalarRequest() - self._byte_budget = _MAX_REQUEST_LENGTH_BYTES - self._request.experiment_id = self._experiment_id - self._byte_budget -= self._request.ByteSize() - if self._byte_budget < 0: - raise RuntimeError("Byte budget too small for experiment ID") - - def build_requests(self, run_to_events): - """Converts a stream of TF events to a stream of outgoing requests. - - Each yielded request will be at most `_MAX_REQUEST_LENGTH_BYTES` - bytes long. - - Args: - run_to_events: Mapping from run name to generator of `tf.Event` - values, as returned by `LogdirLoader.get_run_events`. + This is a signaling exception (like `StopIteration`) used internally + by `_RequestBuilder`; it does not mean that anything has gone wrong. + """ - Yields: - A finite stream of `WriteScalarRequest` objects. + pass - Raises: - RuntimeError: If no progress can be made because even a single - point is too large (say, due to a gigabyte-long tag name). 
- """ - self._new_request() - runs = {} # cache: map from run name to `Run` proto in request - tags = {} # cache: map from `(run, tag)` to `Tag` proto in run in request - work_items = peekable_iterator.PeekableIterator( - self._run_values(run_to_events)) - - while work_items.has_next(): - (run_name, event, orig_value) = work_items.peek() - value = data_compat.migrate_value(orig_value) - time_series_key = (run_name, value.tag) - - metadata = self._tag_metadata.get(time_series_key) - if metadata is None: - plugin_name = value.metadata.plugin_data.plugin_name - if plugin_name == scalar_metadata.PLUGIN_NAME: - metadata = value.metadata - else: - metadata = _RequestBuilder._NON_SCALAR_TIME_SERIES - self._tag_metadata[time_series_key] = metadata - if metadata is _RequestBuilder._NON_SCALAR_TIME_SERIES: - next(work_items) - continue - try: - run_proto = runs.get(run_name) - if run_proto is None: - run_proto = self._create_run(run_name) - runs[run_name] = run_proto - tag_proto = tags.get((run_name, value.tag)) - if tag_proto is None: - tag_proto = self._create_tag(run_proto, value.tag, metadata) - tags[(run_name, value.tag)] = tag_proto - self._create_point(tag_proto, event, value) - next(work_items) - except _OutOfSpaceError: - # Flush request and start a new one. - request_to_emit = self._prune_request() - if request_to_emit is None: - raise RuntimeError("Could not make progress uploading data") - self._new_request() - runs.clear() - tags.clear() - yield request_to_emit - - final_request = self._prune_request() - if final_request is not None: - yield final_request - - def _run_values(self, run_to_events): - """Helper generator to create a single stream of work items.""" - # Note that each of these joins in principle has deletion anomalies: - # if the input stream contains runs with no events, or events with - # no values, we'll lose that information. This is not a problem: we - # would need to prune such data from the request anyway. 
- for (run_name, events) in six.iteritems(run_to_events): - for event in events: - for value in event.summary.value: - yield (run_name, event, value) - - def _prune_request(self): - """Removes empty runs and tags from the active request. - - This does not refund `self._byte_budget`; it is assumed that the - request will be emitted immediately, anyway. +class _RequestBuilder(object): + """Helper class for building requests that fit under a size limit. - Returns: - The active request, or `None` if after pruning the request - contains no data. + This class is not threadsafe. Use external synchronization if + calling its methods concurrently. """ - request = self._request - for (run_idx, run) in reversed(list(enumerate(request.runs))): - for (tag_idx, tag) in reversed(list(enumerate(run.tags))): - if not tag.points: - del run.tags[tag_idx] - if not run.tags: - del self._request.runs[run_idx] - if not request.runs: - request = None - return request - - def _create_run(self, run_name): - """Adds a run to the live request, if there's space. - - Args: - run_name: String name of the run to add. - Returns: - The `WriteScalarRequest.Run` that was added to `request.runs`. + _NON_SCALAR_TIME_SERIES = object() # sentinel + + def __init__(self, experiment_id): + self._experiment_id = experiment_id + # The request currently being populated. + self._request = None # type: write_service_pb2.WriteScalarRequest + # A lower bound on the number of bytes that we may yet add to the + # request. + self._byte_budget = None # type: int + # Map from `(run_name, tag_name)` to `SummaryMetadata` if the time + # series is a scalar time series, else to `_NON_SCALAR_TIME_SERIES`. 
+ self._tag_metadata = {} + + def _new_request(self): + """Allocates a new request and refreshes the budget.""" + self._request = write_service_pb2.WriteScalarRequest() + self._byte_budget = _MAX_REQUEST_LENGTH_BYTES + self._request.experiment_id = self._experiment_id + self._byte_budget -= self._request.ByteSize() + if self._byte_budget < 0: + raise RuntimeError("Byte budget too small for experiment ID") + + def build_requests(self, run_to_events): + """Converts a stream of TF events to a stream of outgoing requests. + + Each yielded request will be at most `_MAX_REQUEST_LENGTH_BYTES` + bytes long. + + Args: + run_to_events: Mapping from run name to generator of `tf.Event` + values, as returned by `LogdirLoader.get_run_events`. + + Yields: + A finite stream of `WriteScalarRequest` objects. + + Raises: + RuntimeError: If no progress can be made because even a single + point is too large (say, due to a gigabyte-long tag name). + """ - Raises: - _OutOfSpaceError: If adding the run would exceed the remaining - request budget. - """ - run_proto = self._request.runs.add(name=run_name) - # We can't calculate the proto key cost exactly ahead of time, as - # it depends on the total size of all tags. Be conservative. - cost = run_proto.ByteSize() + _MAX_VARINT64_LENGTH_BYTES + 1 - if cost > self._byte_budget: - raise _OutOfSpaceError() - self._byte_budget -= cost - return run_proto - - def _create_tag(self, run_proto, tag_name, metadata): - """Adds a tag for the given value, if there's space. 
+ self._new_request() + runs = {} # cache: map from run name to `Run` proto in request + tags = ( + {} + ) # cache: map from `(run, tag)` to `Tag` proto in run in request + work_items = peekable_iterator.PeekableIterator( + self._run_values(run_to_events) + ) + + while work_items.has_next(): + (run_name, event, orig_value) = work_items.peek() + value = data_compat.migrate_value(orig_value) + time_series_key = (run_name, value.tag) + + metadata = self._tag_metadata.get(time_series_key) + if metadata is None: + plugin_name = value.metadata.plugin_data.plugin_name + if plugin_name == scalar_metadata.PLUGIN_NAME: + metadata = value.metadata + else: + metadata = _RequestBuilder._NON_SCALAR_TIME_SERIES + self._tag_metadata[time_series_key] = metadata + if metadata is _RequestBuilder._NON_SCALAR_TIME_SERIES: + next(work_items) + continue + try: + run_proto = runs.get(run_name) + if run_proto is None: + run_proto = self._create_run(run_name) + runs[run_name] = run_proto + tag_proto = tags.get((run_name, value.tag)) + if tag_proto is None: + tag_proto = self._create_tag(run_proto, value.tag, metadata) + tags[(run_name, value.tag)] = tag_proto + self._create_point(tag_proto, event, value) + next(work_items) + except _OutOfSpaceError: + # Flush request and start a new one. + request_to_emit = self._prune_request() + if request_to_emit is None: + raise RuntimeError("Could not make progress uploading data") + self._new_request() + runs.clear() + tags.clear() + yield request_to_emit + + final_request = self._prune_request() + if final_request is not None: + yield final_request + + def _run_values(self, run_to_events): + """Helper generator to create a single stream of work items.""" + # Note that each of these joins in principle has deletion anomalies: + # if the input stream contains runs with no events, or events with + # no values, we'll lose that information. This is not a problem: we + # would need to prune such data from the request anyway. 
+ for (run_name, events) in six.iteritems(run_to_events): + for event in events: + for value in event.summary.value: + yield (run_name, event, value) + + def _prune_request(self): + """Removes empty runs and tags from the active request. + + This does not refund `self._byte_budget`; it is assumed that the + request will be emitted immediately, anyway. + + Returns: + The active request, or `None` if after pruning the request + contains no data. + """ + request = self._request + for (run_idx, run) in reversed(list(enumerate(request.runs))): + for (tag_idx, tag) in reversed(list(enumerate(run.tags))): + if not tag.points: + del run.tags[tag_idx] + if not run.tags: + del self._request.runs[run_idx] + if not request.runs: + request = None + return request + + def _create_run(self, run_name): + """Adds a run to the live request, if there's space. + + Args: + run_name: String name of the run to add. + + Returns: + The `WriteScalarRequest.Run` that was added to `request.runs`. + + Raises: + _OutOfSpaceError: If adding the run would exceed the remaining + request budget. + """ + run_proto = self._request.runs.add(name=run_name) + # We can't calculate the proto key cost exactly ahead of time, as + # it depends on the total size of all tags. Be conservative. + cost = run_proto.ByteSize() + _MAX_VARINT64_LENGTH_BYTES + 1 + if cost > self._byte_budget: + raise _OutOfSpaceError() + self._byte_budget -= cost + return run_proto + + def _create_tag(self, run_proto, tag_name, metadata): + """Adds a tag for the given value, if there's space. + + Args: + run_proto: `WriteScalarRequest.Run` proto to which to add a tag. + tag_name: String name of the tag to add (as `value.tag`). + metadata: TensorBoard `SummaryMetadata` proto from the first + occurrence of this time series. + + Returns: + The `WriteScalarRequest.Tag` that was added to `run_proto.tags`. + + Raises: + _OutOfSpaceError: If adding the tag would exceed the remaining + request budget. 
+ """ + tag_proto = run_proto.tags.add(name=tag_name) + tag_proto.metadata.CopyFrom(metadata) + submessage_cost = tag_proto.ByteSize() + # We can't calculate the proto key cost exactly ahead of time, as + # it depends on the number of points. Be conservative. + cost = submessage_cost + _MAX_VARINT64_LENGTH_BYTES + 1 + if cost > self._byte_budget: + raise _OutOfSpaceError() + self._byte_budget -= cost + return tag_proto + + def _create_point(self, tag_proto, event, value): + """Adds a scalar point to the given tag, if there's space. + + Args: + tag_proto: `WriteScalarRequest.Tag` proto to which to add a point. + event: Enclosing `Event` proto with the step and wall time data. + value: Scalar `Summary.Value` proto with the actual scalar data. + + Returns: + The `ScalarPoint` that was added to `tag_proto.points`. + + Raises: + _OutOfSpaceError: If adding the point would exceed the remaining + request budget. + """ + point = tag_proto.points.add() + point.step = event.step + # TODO(@nfelt): skip tensor roundtrip for Value with simple_value set + point.value = tensor_util.make_ndarray(value.tensor).item() + util.set_timestamp(point.wall_time, event.wall_time) + submessage_cost = point.ByteSize() + cost = submessage_cost + _varint_cost(submessage_cost) + 1 # proto key + if cost > self._byte_budget: + tag_proto.points.pop() + raise _OutOfSpaceError() + self._byte_budget -= cost + return point - Args: - run_proto: `WriteScalarRequest.Run` proto to which to add a tag. - tag_name: String name of the tag to add (as `value.tag`). - metadata: TensorBoard `SummaryMetadata` proto from the first - occurrence of this time series. - Returns: - The `WriteScalarRequest.Tag` that was added to `run_proto.tags`. +def _varint_cost(n): + """Computes the size of `n` encoded as an unsigned base-128 varint. - Raises: - _OutOfSpaceError: If adding the tag would exceed the remaining - request budget. 
- """ - tag_proto = run_proto.tags.add(name=tag_name) - tag_proto.metadata.CopyFrom(metadata) - submessage_cost = tag_proto.ByteSize() - # We can't calculate the proto key cost exactly ahead of time, as - # it depends on the number of points. Be conservative. - cost = submessage_cost + _MAX_VARINT64_LENGTH_BYTES + 1 - if cost > self._byte_budget: - raise _OutOfSpaceError() - self._byte_budget -= cost - return tag_proto - - def _create_point(self, tag_proto, event, value): - """Adds a scalar point to the given tag, if there's space. + This should be consistent with the proto wire format: + Args: - tag_proto: `WriteScalarRequest.Tag` proto to which to add a point. - event: Enclosing `Event` proto with the step and wall time data. - value: Scalar `Summary.Value` proto with the actual scalar data. + n: A non-negative integer. Returns: - The `ScalarPoint` that was added to `tag_proto.points`. - - Raises: - _OutOfSpaceError: If adding the point would exceed the remaining - request budget. + An integer number of bytes. """ - point = tag_proto.points.add() - point.step = event.step - # TODO(@nfelt): skip tensor roundtrip for Value with simple_value set - point.value = tensor_util.make_ndarray(value.tensor).item() - util.set_timestamp(point.wall_time, event.wall_time) - submessage_cost = point.ByteSize() - cost = submessage_cost + _varint_cost(submessage_cost) + 1 # proto key - if cost > self._byte_budget: - tag_proto.points.pop() - raise _OutOfSpaceError() - self._byte_budget -= cost - return point - - -def _varint_cost(n): - """Computes the size of `n` encoded as an unsigned base-128 varint. - - This should be consistent with the proto wire format: - - - Args: - n: A non-negative integer. - - Returns: - An integer number of bytes. 
- """ - result = 1 - while n >= 128: - result += 1 - n >>= 7 - return result + result = 1 + while n >= 128: + result += 1 + n >>= 7 + return result diff --git a/tensorboard/uploader/uploader_main.py b/tensorboard/uploader/uploader_main.py index 87ac1b6276..d9679f7cc9 100644 --- a/tensorboard/uploader/uploader_main.py +++ b/tensorboard/uploader/uploader_main.py @@ -60,253 +60,275 @@ """ -_SUBCOMMAND_FLAG = '_uploader__subcommand' -_SUBCOMMAND_KEY_UPLOAD = 'UPLOAD' -_SUBCOMMAND_KEY_DELETE = 'DELETE' -_SUBCOMMAND_KEY_LIST = 'LIST' -_SUBCOMMAND_KEY_EXPORT = 'EXPORT' -_SUBCOMMAND_KEY_AUTH = 'AUTH' -_AUTH_SUBCOMMAND_FLAG = '_uploader__subcommand_auth' -_AUTH_SUBCOMMAND_KEY_REVOKE = 'REVOKE' +_SUBCOMMAND_FLAG = "_uploader__subcommand" +_SUBCOMMAND_KEY_UPLOAD = "UPLOAD" +_SUBCOMMAND_KEY_DELETE = "DELETE" +_SUBCOMMAND_KEY_LIST = "LIST" +_SUBCOMMAND_KEY_EXPORT = "EXPORT" +_SUBCOMMAND_KEY_AUTH = "AUTH" +_AUTH_SUBCOMMAND_FLAG = "_uploader__subcommand_auth" +_AUTH_SUBCOMMAND_KEY_REVOKE = "REVOKE" _DEFAULT_ORIGIN = "https://tensorboard.dev" def _prompt_for_user_ack(intent): - """Prompts for user consent, exiting the program if they decline.""" - body = intent.get_ack_message_body() - header = '\n***** TensorBoard Uploader *****\n' - user_ack_message = '\n'.join((header, body, _MESSAGE_TOS)) - sys.stderr.write(user_ack_message) - sys.stderr.write('\n') - response = six.moves.input('Continue? (yes/NO) ') - if response.lower() not in ('y', 'yes'): - sys.exit(0) - sys.stderr.write('\n') + """Prompts for user consent, exiting the program if they decline.""" + body = intent.get_ack_message_body() + header = "\n***** TensorBoard Uploader *****\n" + user_ack_message = "\n".join((header, body, _MESSAGE_TOS)) + sys.stderr.write(user_ack_message) + sys.stderr.write("\n") + response = six.moves.input("Continue? (yes/NO) ") + if response.lower() not in ("y", "yes"): + sys.exit(0) + sys.stderr.write("\n") def _define_flags(parser): - """Configures flags on the provided argument parser. 
+ """Configures flags on the provided argument parser. - Integration point for `tensorboard.program`'s subcommand system. + Integration point for `tensorboard.program`'s subcommand system. - Args: - parser: An `argparse.ArgumentParser` to be mutated. - """ + Args: + parser: An `argparse.ArgumentParser` to be mutated. + """ - subparsers = parser.add_subparsers() - - parser.add_argument( - '--origin', - type=str, - default='', - help='Experimental. Origin for TensorBoard.dev service to which ' - 'to connect. If not set, defaults to %r.' % _DEFAULT_ORIGIN) - - parser.add_argument( - '--api_endpoint', - type=str, - default='', - help='Experimental. Direct URL for the API server accepting ' - 'write requests. If set, will skip initial server handshake ' - 'unless `--origin` is also set.') - - parser.add_argument( - '--grpc_creds_type', - type=str, - default='ssl', - choices=('local', 'ssl', 'ssl_dev'), - help='The type of credentials to use for the gRPC client') - - parser.add_argument( - '--auth_force_console', - action='store_true', - help='Set to true to force authentication flow to use the ' - '--console rather than a browser redirect to localhost.') - - upload = subparsers.add_parser( - 'upload', help='upload an experiment to TensorBoard.dev') - upload.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_UPLOAD}) - upload.add_argument( - '--logdir', - metavar='PATH', - type=str, - default=None, - help='Directory containing the logs to process') - - delete = subparsers.add_parser( - 'delete', - help='permanently delete an experiment', - inherited_absl_flags=None) - delete.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_DELETE}) - # We would really like to call this next flag `--experiment` rather - # than `--experiment_id`, but this is broken inside Google due to a - # long-standing Python bug: - # (Some Google-internal dependencies define `--experimental_*` flags.) - # This isn't exactly a principled fix, but it gets the job done. 
- delete.add_argument( - '--experiment_id', - metavar='EXPERIMENT_ID', - type=str, - default=None, - help='ID of an experiment to delete permanently') - - list_parser = subparsers.add_parser( - 'list', help='list previously uploaded experiments') - list_parser.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_LIST}) - - export = subparsers.add_parser( - 'export', help='download all your experiment data') - export.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_EXPORT}) - export.add_argument( - '--outdir', - metavar='OUTPUT_PATH', - type=str, - default=None, - help='Directory into which to download all experiment data; ' - 'must not yet exist') - - auth_parser = subparsers.add_parser('auth', help='log in, log out') - auth_parser.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_AUTH}) - auth_subparsers = auth_parser.add_subparsers() - - auth_revoke = auth_subparsers.add_parser( - 'revoke', help='revoke all existing credentials and log out') - auth_revoke.set_defaults( - **{_AUTH_SUBCOMMAND_FLAG: _AUTH_SUBCOMMAND_KEY_REVOKE}) - - -def _parse_flags(argv=('',)): - """Integration point for `absl.app`. - - Exits if flag values are invalid. - - Args: - argv: CLI arguments, as with `sys.argv`, where the first argument is taken - to be the name of the program being executed. - - Returns: - Either argv[:1] if argv was non-empty, or [''] otherwise, as a mechanism - for absl.app.run() compatibility. - """ - parser = argparse_flags.ArgumentParser( - prog='uploader', - description=('Upload your TensorBoard experiments to TensorBoard.dev')) - _define_flags(parser) - arg0 = argv[0] if argv else '' - global _FLAGS - _FLAGS = parser.parse_args(argv[1:]) - return [arg0] + subparsers = parser.add_subparsers() + parser.add_argument( + "--origin", + type=str, + default="", + help="Experimental. Origin for TensorBoard.dev service to which " + "to connect. If not set, defaults to %r." % _DEFAULT_ORIGIN, + ) -def _run(flags): - """Runs the main uploader program given parsed flags. 
+ parser.add_argument( + "--api_endpoint", + type=str, + default="", + help="Experimental. Direct URL for the API server accepting " + "write requests. If set, will skip initial server handshake " + "unless `--origin` is also set.", + ) - Args: - flags: An `argparse.Namespace`. - """ + parser.add_argument( + "--grpc_creds_type", + type=str, + default="ssl", + choices=("local", "ssl", "ssl_dev"), + help="The type of credentials to use for the gRPC client", + ) - logging.set_stderrthreshold(logging.WARNING) - intent = _get_intent(flags) + parser.add_argument( + "--auth_force_console", + action="store_true", + help="Set to true to force authentication flow to use the " + "--console rather than a browser redirect to localhost.", + ) - store = auth.CredentialsStore() - if isinstance(intent, _AuthRevokeIntent): - store.clear() - sys.stderr.write('Logged out of uploader.\n') - sys.stderr.flush() - return - # TODO(b/141723268): maybe reconfirm Google Account prior to reuse. - credentials = store.read_credentials() - if not credentials: - _prompt_for_user_ack(intent) - client_config = json.loads(auth.OAUTH_CLIENT_CONFIG) - flow = auth.build_installed_app_flow(client_config) - credentials = flow.run(force_console=flags.auth_force_console) - sys.stderr.write('\n') # Extra newline after auth flow messages. 
- store.write_credentials(credentials) - - channel_options = None - if flags.grpc_creds_type == 'local': - channel_creds = grpc.local_channel_credentials() - elif flags.grpc_creds_type == 'ssl': - channel_creds = grpc.ssl_channel_credentials() - elif flags.grpc_creds_type == 'ssl_dev': - channel_creds = grpc.ssl_channel_credentials(dev_creds.DEV_SSL_CERT) - channel_options = [('grpc.ssl_target_name_override', 'localhost')] - else: - msg = 'Invalid --grpc_creds_type %s' % flags.grpc_creds_type - raise base_plugin.FlagsError(msg) - - try: - server_info = _get_server_info(flags) - except server_info_lib.CommunicationError as e: - _die(str(e)) - _handle_server_info(server_info) - - if not server_info.api_server.endpoint: - logging.error('Server info response: %s', server_info) - _die('Internal error: frontend did not specify an API server') - composite_channel_creds = grpc.composite_channel_credentials( - channel_creds, auth.id_token_call_credentials(credentials)) - - # TODO(@nfelt): In the `_UploadIntent` case, consider waiting until - # logdir exists to open channel. - channel = grpc.secure_channel( - server_info.api_server.endpoint, - composite_channel_creds, - options=channel_options) - with channel: - intent.execute(server_info, channel) + upload = subparsers.add_parser( + "upload", help="upload an experiment to TensorBoard.dev" + ) + upload.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_UPLOAD}) + upload.add_argument( + "--logdir", + metavar="PATH", + type=str, + default=None, + help="Directory containing the logs to process", + ) + delete = subparsers.add_parser( + "delete", + help="permanently delete an experiment", + inherited_absl_flags=None, + ) + delete.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_DELETE}) + # We would really like to call this next flag `--experiment` rather + # than `--experiment_id`, but this is broken inside Google due to a + # long-standing Python bug: + # (Some Google-internal dependencies define `--experimental_*` flags.) 
+ # This isn't exactly a principled fix, but it gets the job done. + delete.add_argument( + "--experiment_id", + metavar="EXPERIMENT_ID", + type=str, + default=None, + help="ID of an experiment to delete permanently", + ) -@six.add_metaclass(abc.ABCMeta) -class _Intent(object): - """A description of the user's intent in invoking this program. + list_parser = subparsers.add_parser( + "list", help="list previously uploaded experiments" + ) + list_parser.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_LIST}) - Each valid set of CLI flags corresponds to one intent: e.g., "upload - data from this logdir", or "delete the experiment with that ID". - """ + export = subparsers.add_parser( + "export", help="download all your experiment data" + ) + export.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_EXPORT}) + export.add_argument( + "--outdir", + metavar="OUTPUT_PATH", + type=str, + default=None, + help="Directory into which to download all experiment data; " + "must not yet exist", + ) - @abc.abstractmethod - def get_ack_message_body(self): - """Gets the message to show when executing this intent at first login. + auth_parser = subparsers.add_parser("auth", help="log in, log out") + auth_parser.set_defaults(**{_SUBCOMMAND_FLAG: _SUBCOMMAND_KEY_AUTH}) + auth_subparsers = auth_parser.add_subparsers() + + auth_revoke = auth_subparsers.add_parser( + "revoke", help="revoke all existing credentials and log out" + ) + auth_revoke.set_defaults( + **{_AUTH_SUBCOMMAND_FLAG: _AUTH_SUBCOMMAND_KEY_REVOKE} + ) - This need not include the header (program name) or Terms of Service - notice. + +def _parse_flags(argv=("",)): + """Integration point for `absl.app`. + + Exits if flag values are invalid. + + Args: + argv: CLI arguments, as with `sys.argv`, where the first argument is taken + to be the name of the program being executed. Returns: - A Unicode string, potentially spanning multiple lines. 
+ Either argv[:1] if argv was non-empty, or [''] otherwise, as a mechanism + for absl.app.run() compatibility. """ - pass + parser = argparse_flags.ArgumentParser( + prog="uploader", + description=("Upload your TensorBoard experiments to TensorBoard.dev"), + ) + _define_flags(parser) + arg0 = argv[0] if argv else "" + global _FLAGS + _FLAGS = parser.parse_args(argv[1:]) + return [arg0] - @abc.abstractmethod - def execute(self, server_info, channel): - """Carries out this intent with the specified gRPC channel. + +def _run(flags): + """Runs the main uploader program given parsed flags. Args: - server_info: A `server_info_pb2.ServerInfoResponse` value. - channel: A connected gRPC channel whose server provides the TensorBoard - reader and writer services. + flags: An `argparse.Namespace`. """ - pass + + logging.set_stderrthreshold(logging.WARNING) + intent = _get_intent(flags) + + store = auth.CredentialsStore() + if isinstance(intent, _AuthRevokeIntent): + store.clear() + sys.stderr.write("Logged out of uploader.\n") + sys.stderr.flush() + return + # TODO(b/141723268): maybe reconfirm Google Account prior to reuse. + credentials = store.read_credentials() + if not credentials: + _prompt_for_user_ack(intent) + client_config = json.loads(auth.OAUTH_CLIENT_CONFIG) + flow = auth.build_installed_app_flow(client_config) + credentials = flow.run(force_console=flags.auth_force_console) + sys.stderr.write("\n") # Extra newline after auth flow messages. 
+ store.write_credentials(credentials) + + channel_options = None + if flags.grpc_creds_type == "local": + channel_creds = grpc.local_channel_credentials() + elif flags.grpc_creds_type == "ssl": + channel_creds = grpc.ssl_channel_credentials() + elif flags.grpc_creds_type == "ssl_dev": + channel_creds = grpc.ssl_channel_credentials(dev_creds.DEV_SSL_CERT) + channel_options = [("grpc.ssl_target_name_override", "localhost")] + else: + msg = "Invalid --grpc_creds_type %s" % flags.grpc_creds_type + raise base_plugin.FlagsError(msg) + + try: + server_info = _get_server_info(flags) + except server_info_lib.CommunicationError as e: + _die(str(e)) + _handle_server_info(server_info) + + if not server_info.api_server.endpoint: + logging.error("Server info response: %s", server_info) + _die("Internal error: frontend did not specify an API server") + composite_channel_creds = grpc.composite_channel_credentials( + channel_creds, auth.id_token_call_credentials(credentials) + ) + + # TODO(@nfelt): In the `_UploadIntent` case, consider waiting until + # logdir exists to open channel. + channel = grpc.secure_channel( + server_info.api_server.endpoint, + composite_channel_creds, + options=channel_options, + ) + with channel: + intent.execute(server_info, channel) + + +@six.add_metaclass(abc.ABCMeta) +class _Intent(object): + """A description of the user's intent in invoking this program. + + Each valid set of CLI flags corresponds to one intent: e.g., "upload + data from this logdir", or "delete the experiment with that ID". + """ + + @abc.abstractmethod + def get_ack_message_body(self): + """Gets the message to show when executing this intent at first login. + + This need not include the header (program name) or Terms of Service + notice. + + Returns: + A Unicode string, potentially spanning multiple lines. + """ + pass + + @abc.abstractmethod + def execute(self, server_info, channel): + """Carries out this intent with the specified gRPC channel. 
+ + Args: + server_info: A `server_info_pb2.ServerInfoResponse` value. + channel: A connected gRPC channel whose server provides the TensorBoard + reader and writer services. + """ + pass class _AuthRevokeIntent(_Intent): - """The user intends to revoke credentials.""" + """The user intends to revoke credentials.""" + + def get_ack_message_body(self): + """Must not be called.""" + raise AssertionError("No user ack needed to revoke credentials") - def get_ack_message_body(self): - """Must not be called.""" - raise AssertionError('No user ack needed to revoke credentials') + def execute(self, server_info, channel): + """Execute handled specially by `main`. - def execute(self, server_info, channel): - """Execute handled specially by `main`. Must not be called.""" - raise AssertionError('_AuthRevokeIntent should not be directly executed') + Must not be called. + """ + raise AssertionError( + "_AuthRevokeIntent should not be directly executed" + ) class _DeleteExperimentIntent(_Intent): - """The user intends to delete an experiment.""" + """The user intends to delete an experiment.""" - _MESSAGE_TEMPLATE = textwrap.dedent(u"""\ + _MESSAGE_TEMPLATE = textwrap.dedent( + u"""\ This will delete the experiment on https://tensorboard.dev with the following experiment ID: @@ -315,90 +337,102 @@ class _DeleteExperimentIntent(_Intent): You have chosen to delete an experiment. All experiments uploaded to TensorBoard.dev are publicly visible. Do not upload sensitive data. 
- """) - - def __init__(self, experiment_id): - self.experiment_id = experiment_id - - def get_ack_message_body(self): - return self._MESSAGE_TEMPLATE.format(experiment_id=self.experiment_id) + """ + ) - def execute(self, server_info, channel): - api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub(channel) - experiment_id = self.experiment_id - if not experiment_id: - raise base_plugin.FlagsError( - 'Must specify a non-empty experiment ID to delete.') - try: - uploader_lib.delete_experiment(api_client, experiment_id) - except uploader_lib.ExperimentNotFoundError: - _die( - 'No such experiment %s. Either it never existed or it has ' - 'already been deleted.' % experiment_id) - except uploader_lib.PermissionDeniedError: - _die( - 'Cannot delete experiment %s because it is owned by a ' - 'different user.' % experiment_id) - except grpc.RpcError as e: - _die('Internal error deleting experiment: %s' % e) - print('Deleted experiment %s.' % experiment_id) + def __init__(self, experiment_id): + self.experiment_id = experiment_id + + def get_ack_message_body(self): + return self._MESSAGE_TEMPLATE.format(experiment_id=self.experiment_id) + + def execute(self, server_info, channel): + api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub( + channel + ) + experiment_id = self.experiment_id + if not experiment_id: + raise base_plugin.FlagsError( + "Must specify a non-empty experiment ID to delete." + ) + try: + uploader_lib.delete_experiment(api_client, experiment_id) + except uploader_lib.ExperimentNotFoundError: + _die( + "No such experiment %s. Either it never existed or it has " + "already been deleted." % experiment_id + ) + except uploader_lib.PermissionDeniedError: + _die( + "Cannot delete experiment %s because it is owned by a " + "different user." % experiment_id + ) + except grpc.RpcError as e: + _die("Internal error deleting experiment: %s" % e) + print("Deleted experiment %s." 
% experiment_id) class _ListIntent(_Intent): - """The user intends to list all their experiments.""" + """The user intends to list all their experiments.""" - _MESSAGE = textwrap.dedent(u"""\ + _MESSAGE = textwrap.dedent( + u"""\ This will list all experiments that you've uploaded to https://tensorboard.dev. TensorBoard.dev experiments are visible to everyone. Do not upload sensitive data. - """) - - def get_ack_message_body(self): - return self._MESSAGE - - def execute(self, server_info, channel): - api_client = export_service_pb2_grpc.TensorBoardExporterServiceStub(channel) - fieldmask = export_service_pb2.ExperimentMask( - create_time=True, - update_time=True, - num_scalars=True, - num_runs=True, - num_tags=True, + """ ) - gen = exporter_lib.list_experiments(api_client, fieldmask=fieldmask) - count = 0 - for experiment in gen: - count += 1 - if not isinstance(experiment, export_service_pb2.Experiment): - url = server_info_lib.experiment_url(server_info, experiment) - print(url) - continue - experiment_id = experiment.experiment_id - url = server_info_lib.experiment_url(server_info, experiment_id) - print(url) - data = [ - ('Id', experiment.experiment_id), - ('Created', util.format_time(experiment.create_time)), - ('Updated', util.format_time(experiment.update_time)), - ('Scalars', str(experiment.num_scalars)), - ('Runs', str(experiment.num_runs)), - ('Tags', str(experiment.num_tags)), - ] - for (name, value) in data: - print('\t%s %s' % (name.ljust(10), value)) - sys.stdout.flush() - if not count: - sys.stderr.write( - 'No experiments. 
Use `tensorboard dev upload` to get started.\n') - else: - sys.stderr.write('Total: %d experiment(s)\n' % count) - sys.stderr.flush() + + def get_ack_message_body(self): + return self._MESSAGE + + def execute(self, server_info, channel): + api_client = export_service_pb2_grpc.TensorBoardExporterServiceStub( + channel + ) + fieldmask = export_service_pb2.ExperimentMask( + create_time=True, + update_time=True, + num_scalars=True, + num_runs=True, + num_tags=True, + ) + gen = exporter_lib.list_experiments(api_client, fieldmask=fieldmask) + count = 0 + for experiment in gen: + count += 1 + if not isinstance(experiment, export_service_pb2.Experiment): + url = server_info_lib.experiment_url(server_info, experiment) + print(url) + continue + experiment_id = experiment.experiment_id + url = server_info_lib.experiment_url(server_info, experiment_id) + print(url) + data = [ + ("Id", experiment.experiment_id), + ("Created", util.format_time(experiment.create_time)), + ("Updated", util.format_time(experiment.update_time)), + ("Scalars", str(experiment.num_scalars)), + ("Runs", str(experiment.num_runs)), + ("Tags", str(experiment.num_tags)), + ] + for (name, value) in data: + print("\t%s %s" % (name.ljust(10), value)) + sys.stdout.flush() + if not count: + sys.stderr.write( + "No experiments. Use `tensorboard dev upload` to get started.\n" + ) + else: + sys.stderr.write("Total: %d experiment(s)\n" % count) + sys.stderr.flush() class _UploadIntent(_Intent): - """The user intends to upload an experiment from the given logdir.""" + """The user intends to upload an experiment from the given logdir.""" - _MESSAGE_TEMPLATE = textwrap.dedent(u"""\ + _MESSAGE_TEMPLATE = textwrap.dedent( + u"""\ This will upload your TensorBoard logs to https://tensorboard.dev/ from the following directory: @@ -406,40 +440,46 @@ class _UploadIntent(_Intent): This TensorBoard will be visible to everyone. Do not upload sensitive data. 
- """) - - def __init__(self, logdir): - self.logdir = logdir - - def get_ack_message_body(self): - return self._MESSAGE_TEMPLATE.format(logdir=self.logdir) - - def execute(self, server_info, channel): - api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub(channel) - uploader = uploader_lib.TensorBoardUploader(api_client, self.logdir) - experiment_id = uploader.create_experiment() - url = server_info_lib.experiment_url(server_info, experiment_id) - print("Upload started and will continue reading any new data as it's added") - print("to the logdir. To stop uploading, press Ctrl-C.") - print("View your TensorBoard live at: %s" % url) - try: - uploader.start_uploading() - except uploader_lib.ExperimentNotFoundError: - print('Experiment was deleted; uploading has been cancelled') - return - except KeyboardInterrupt: - print() - print('Upload stopped. View your TensorBoard at %s' % url) - return - # TODO(@nfelt): make it possible for the upload cycle to end once we - # detect that no more runs are active, so this code can be reached. - print('Done! View your TensorBoard at %s' % url) + """ + ) + + def __init__(self, logdir): + self.logdir = logdir + + def get_ack_message_body(self): + return self._MESSAGE_TEMPLATE.format(logdir=self.logdir) + + def execute(self, server_info, channel): + api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub( + channel + ) + uploader = uploader_lib.TensorBoardUploader(api_client, self.logdir) + experiment_id = uploader.create_experiment() + url = server_info_lib.experiment_url(server_info, experiment_id) + print( + "Upload started and will continue reading any new data as it's added" + ) + print("to the logdir. To stop uploading, press Ctrl-C.") + print("View your TensorBoard live at: %s" % url) + try: + uploader.start_uploading() + except uploader_lib.ExperimentNotFoundError: + print("Experiment was deleted; uploading has been cancelled") + return + except KeyboardInterrupt: + print() + print("Upload stopped. 
View your TensorBoard at %s" % url) + return + # TODO(@nfelt): make it possible for the upload cycle to end once we + # detect that no more runs are active, so this code can be reached. + print("Done! View your TensorBoard at %s" % url) class _ExportIntent(_Intent): - """The user intends to download all their experiment data.""" + """The user intends to download all their experiment data.""" - _MESSAGE_TEMPLATE = textwrap.dedent(u"""\ + _MESSAGE_TEMPLATE = textwrap.dedent( + u"""\ This will download all your experiment data from https://tensorboard.dev and save it to the following directory: @@ -448,139 +488,148 @@ class _ExportIntent(_Intent): Downloading your experiment data does not delete it from the service. All experiments uploaded to TensorBoard.dev are publicly visible. Do not upload sensitive data. - """) - - def __init__(self, output_dir): - self.output_dir = output_dir - - def get_ack_message_body(self): - return self._MESSAGE_TEMPLATE.format(output_dir=self.output_dir) + """ + ) - def execute(self, server_info, channel): - api_client = export_service_pb2_grpc.TensorBoardExporterServiceStub(channel) - outdir = self.output_dir - try: - exporter = exporter_lib.TensorBoardExporter(api_client, outdir) - except exporter_lib.OutputDirectoryExistsError: - msg = 'Output directory already exists: %r' % outdir - raise base_plugin.FlagsError(msg) - num_experiments = 0 - try: - for experiment_id in exporter.export(): - num_experiments += 1 - print('Downloaded experiment %s' % experiment_id) - except exporter_lib.GrpcTimeoutException as e: - print( - '\nUploader has failed because of a timeout error. Please reach ' - 'out via e-mail to tensorboard.dev-support@google.com to get help ' - 'completing your export of experiment %s.' % e.experiment_id) - print('Done. 
Downloaded %d experiments to: %s' % (num_experiments, outdir)) + def __init__(self, output_dir): + self.output_dir = output_dir + + def get_ack_message_body(self): + return self._MESSAGE_TEMPLATE.format(output_dir=self.output_dir) + + def execute(self, server_info, channel): + api_client = export_service_pb2_grpc.TensorBoardExporterServiceStub( + channel + ) + outdir = self.output_dir + try: + exporter = exporter_lib.TensorBoardExporter(api_client, outdir) + except exporter_lib.OutputDirectoryExistsError: + msg = "Output directory already exists: %r" % outdir + raise base_plugin.FlagsError(msg) + num_experiments = 0 + try: + for experiment_id in exporter.export(): + num_experiments += 1 + print("Downloaded experiment %s" % experiment_id) + except exporter_lib.GrpcTimeoutException as e: + print( + "\nUploader has failed because of a timeout error. Please reach " + "out via e-mail to tensorboard.dev-support@google.com to get help " + "completing your export of experiment %s." % e.experiment_id + ) + print( + "Done. Downloaded %d experiments to: %s" % (num_experiments, outdir) + ) def _get_intent(flags): - """Determines what the program should do (upload, delete, ...). + """Determines what the program should do (upload, delete, ...). - Args: - flags: An `argparse.Namespace` with the parsed flags. + Args: + flags: An `argparse.Namespace` with the parsed flags. - Returns: - An `_Intent` instance. + Returns: + An `_Intent` instance. - Raises: - base_plugin.FlagsError: If the command-line `flags` do not correctly - specify an intent. 
- """ - cmd = getattr(flags, _SUBCOMMAND_FLAG, None) - if cmd is None: - raise base_plugin.FlagsError('Must specify subcommand (try --help).') - if cmd == _SUBCOMMAND_KEY_UPLOAD: - if flags.logdir: - return _UploadIntent(os.path.expanduser(flags.logdir)) - else: - raise base_plugin.FlagsError( - 'Must specify directory to upload via `--logdir`.') - elif cmd == _SUBCOMMAND_KEY_DELETE: - if flags.experiment_id: - return _DeleteExperimentIntent(flags.experiment_id) - else: - raise base_plugin.FlagsError( - 'Must specify experiment to delete via `--experiment_id`.') - elif cmd == _SUBCOMMAND_KEY_LIST: - return _ListIntent() - elif cmd == _SUBCOMMAND_KEY_EXPORT: - if flags.outdir: - return _ExportIntent(flags.outdir) - else: - raise base_plugin.FlagsError( - 'Must specify output directory via `--outdir`.') - elif cmd == _SUBCOMMAND_KEY_AUTH: - auth_cmd = getattr(flags, _AUTH_SUBCOMMAND_FLAG, None) - if auth_cmd is None: - raise base_plugin.FlagsError('Must specify a subcommand to `auth`.') - if auth_cmd == _AUTH_SUBCOMMAND_KEY_REVOKE: - return _AuthRevokeIntent() + Raises: + base_plugin.FlagsError: If the command-line `flags` do not correctly + specify an intent. + """ + cmd = getattr(flags, _SUBCOMMAND_FLAG, None) + if cmd is None: + raise base_plugin.FlagsError("Must specify subcommand (try --help).") + if cmd == _SUBCOMMAND_KEY_UPLOAD: + if flags.logdir: + return _UploadIntent(os.path.expanduser(flags.logdir)) + else: + raise base_plugin.FlagsError( + "Must specify directory to upload via `--logdir`." + ) + elif cmd == _SUBCOMMAND_KEY_DELETE: + if flags.experiment_id: + return _DeleteExperimentIntent(flags.experiment_id) + else: + raise base_plugin.FlagsError( + "Must specify experiment to delete via `--experiment_id`." + ) + elif cmd == _SUBCOMMAND_KEY_LIST: + return _ListIntent() + elif cmd == _SUBCOMMAND_KEY_EXPORT: + if flags.outdir: + return _ExportIntent(flags.outdir) + else: + raise base_plugin.FlagsError( + "Must specify output directory via `--outdir`." 
+ ) + elif cmd == _SUBCOMMAND_KEY_AUTH: + auth_cmd = getattr(flags, _AUTH_SUBCOMMAND_FLAG, None) + if auth_cmd is None: + raise base_plugin.FlagsError("Must specify a subcommand to `auth`.") + if auth_cmd == _AUTH_SUBCOMMAND_KEY_REVOKE: + return _AuthRevokeIntent() + else: + raise AssertionError("Unknown auth subcommand %r" % (auth_cmd,)) else: - raise AssertionError('Unknown auth subcommand %r' % (auth_cmd,)) - else: - raise AssertionError('Unknown subcommand %r' % (cmd,)) + raise AssertionError("Unknown subcommand %r" % (cmd,)) def _get_server_info(flags): - origin = flags.origin or _DEFAULT_ORIGIN - if flags.api_endpoint and not flags.origin: - return server_info_lib.create_server_info(origin, flags.api_endpoint) - server_info = server_info_lib.fetch_server_info(origin) - # Override with any API server explicitly specified on the command - # line, but only if the server accepted our initial handshake. - if flags.api_endpoint and server_info.api_server.endpoint: - server_info.api_server.endpoint = flags.api_endpoint - return server_info + origin = flags.origin or _DEFAULT_ORIGIN + if flags.api_endpoint and not flags.origin: + return server_info_lib.create_server_info(origin, flags.api_endpoint) + server_info = server_info_lib.fetch_server_info(origin) + # Override with any API server explicitly specified on the command + # line, but only if the server accepted our initial handshake. + if flags.api_endpoint and server_info.api_server.endpoint: + server_info.api_server.endpoint = flags.api_endpoint + return server_info def _handle_server_info(info): - compat = info.compatibility - if compat.verdict == server_info_pb2.VERDICT_WARN: - sys.stderr.write('Warning [from server]: %s\n' % compat.details) - sys.stderr.flush() - elif compat.verdict == server_info_pb2.VERDICT_ERROR: - _die('Error [from server]: %s' % compat.details) - else: - # OK or unknown; assume OK. 
- if compat.details: - sys.stderr.write('%s\n' % compat.details) - sys.stderr.flush() + compat = info.compatibility + if compat.verdict == server_info_pb2.VERDICT_WARN: + sys.stderr.write("Warning [from server]: %s\n" % compat.details) + sys.stderr.flush() + elif compat.verdict == server_info_pb2.VERDICT_ERROR: + _die("Error [from server]: %s" % compat.details) + else: + # OK or unknown; assume OK. + if compat.details: + sys.stderr.write("%s\n" % compat.details) + sys.stderr.flush() def _die(message): - sys.stderr.write('%s\n' % (message,)) - sys.stderr.flush() - sys.exit(1) + sys.stderr.write("%s\n" % (message,)) + sys.stderr.flush() + sys.exit(1) def main(unused_argv): - global _FLAGS - flags = _FLAGS - # Prevent accidental use of `_FLAGS` until migration to TensorBoard - # subcommand is complete, at which point `_FLAGS` goes away. - del _FLAGS - return _run(flags) + global _FLAGS + flags = _FLAGS + # Prevent accidental use of `_FLAGS` until migration to TensorBoard + # subcommand is complete, at which point `_FLAGS` goes away. 
+ del _FLAGS + return _run(flags) class UploaderSubcommand(program.TensorBoardSubcommand): - """Integration point with `tensorboard` CLI.""" + """Integration point with `tensorboard` CLI.""" - def name(self): - return 'dev' + def name(self): + return "dev" - def define_flags(self, parser): - _define_flags(parser) + def define_flags(self, parser): + _define_flags(parser) - def run(self, flags): - return _run(flags) + def run(self, flags): + return _run(flags) - def help(self): - return 'upload data to TensorBoard.dev' + def help(self): + return "upload data to TensorBoard.dev" -if __name__ == '__main__': - app.run(main, flags_parser=_parse_flags) +if __name__ == "__main__": + app.run(main, flags_parser=_parse_flags) diff --git a/tensorboard/uploader/uploader_test.py b/tensorboard/uploader/uploader_test.py index db760b30c2..26a9fdd912 100644 --- a/tensorboard/uploader/uploader_test.py +++ b/tensorboard/uploader/uploader_test.py @@ -25,10 +25,10 @@ import grpc_testing try: - # python version >= 3.3 - from unittest import mock + # python version >= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import import tensorflow as tf @@ -47,616 +47,685 @@ class AbortUploadError(Exception): - """Exception used in testing to abort the upload process.""" + """Exception used in testing to abort the upload process.""" class TensorboardUploaderTest(tf.test.TestCase): - - def _create_mock_client(self): - # Create a stub instance (using a test channel) in order to derive a mock - # from it with autospec enabled. Mocking TensorBoardWriterServiceStub itself - # doesn't work with autospec because grpc constructs stubs via metaclassing. 
- test_channel = grpc_testing.channel( - service_descriptors=[], time=grpc_testing.strict_real_time()) - stub = write_service_pb2_grpc.TensorBoardWriterServiceStub(test_channel) - mock_client = mock.create_autospec(stub) - fake_exp_response = write_service_pb2.CreateExperimentResponse( - experiment_id="123", url="should not be used!") - mock_client.CreateExperiment.return_value = fake_exp_response - return mock_client - - def test_create_experiment(self): - logdir = "/logs/foo" - mock_client = self._create_mock_client() - uploader = uploader_lib.TensorBoardUploader(mock_client, logdir) - eid = uploader.create_experiment() - self.assertEqual(eid, "123") - - def test_start_uploading_without_create_experiment_fails(self): - mock_client = self._create_mock_client() - uploader = uploader_lib.TensorBoardUploader(mock_client, "/logs/foo") - with self.assertRaisesRegex(RuntimeError, "call create_experiment()"): - uploader.start_uploading() - - def test_start_uploading(self): - mock_client = self._create_mock_client() - mock_rate_limiter = mock.create_autospec(util.RateLimiter) - uploader = uploader_lib.TensorBoardUploader( - mock_client, "/logs/foo", mock_rate_limiter) - uploader.create_experiment() - mock_builder = mock.create_autospec(uploader_lib._RequestBuilder) - request = write_service_pb2.WriteScalarRequest() - mock_builder.build_requests.side_effect = [ - iter([request, request]), - iter([request, request, request, request, request]), - AbortUploadError, - ] - # pylint: disable=g-backslash-continuation - with mock.patch.object(uploader, "_upload") as mock_upload, \ - mock.patch.object(uploader, "_request_builder", mock_builder), \ - self.assertRaises(AbortUploadError): - uploader.start_uploading() - # pylint: enable=g-backslash-continuation - self.assertEqual(7, mock_upload.call_count) - self.assertEqual(2 + 5 + 1, mock_rate_limiter.tick.call_count) - - def test_upload_empty_logdir(self): - logdir = self.get_temp_dir() - mock_client = self._create_mock_client() - 
mock_rate_limiter = mock.create_autospec(util.RateLimiter) - uploader = uploader_lib.TensorBoardUploader( - mock_client, logdir, mock_rate_limiter) - uploader.create_experiment() - uploader._upload_once() - mock_client.WriteScalar.assert_not_called() - - def test_upload_swallows_rpc_failure(self): - logdir = self.get_temp_dir() - with tb_test_util.FileWriter(logdir) as writer: - writer.add_test_summary("foo") - mock_client = self._create_mock_client() - mock_rate_limiter = mock.create_autospec(util.RateLimiter) - uploader = uploader_lib.TensorBoardUploader( - mock_client, logdir, mock_rate_limiter) - uploader.create_experiment() - error = test_util.grpc_error(grpc.StatusCode.INTERNAL, "Failure") - mock_client.WriteScalar.side_effect = error - uploader._upload_once() - mock_client.WriteScalar.assert_called_once() - - def test_upload_propagates_experiment_deletion(self): - logdir = self.get_temp_dir() - with tb_test_util.FileWriter(logdir) as writer: - writer.add_test_summary("foo") - mock_client = self._create_mock_client() - mock_rate_limiter = mock.create_autospec(util.RateLimiter) - uploader = uploader_lib.TensorBoardUploader( - mock_client, logdir, mock_rate_limiter) - uploader.create_experiment() - error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope") - mock_client.WriteScalar.side_effect = error - with self.assertRaises(uploader_lib.ExperimentNotFoundError): - uploader._upload_once() - - def test_upload_preserves_wall_time(self): - logdir = self.get_temp_dir() - with tb_test_util.FileWriter(logdir) as writer: - # Add a raw event so we can specify the wall_time value deterministically. 
- writer.add_event( - event_pb2.Event( - step=1, - wall_time=123.123123123, - summary=scalar_v2.scalar_pb("foo", 5.0))) - mock_client = self._create_mock_client() - mock_rate_limiter = mock.create_autospec(util.RateLimiter) - uploader = uploader_lib.TensorBoardUploader( - mock_client, logdir, mock_rate_limiter) - uploader.create_experiment() - uploader._upload_once() - mock_client.WriteScalar.assert_called_once() - request = mock_client.WriteScalar.call_args[0][0] - # Just check the wall_time value; everything else is covered in the full - # logdir test below. - self.assertEqual( - 123123123123, - request.runs[0].tags[0].points[0].wall_time.ToNanoseconds()) - - def test_upload_full_logdir(self): - logdir = self.get_temp_dir() - mock_client = self._create_mock_client() - mock_rate_limiter = mock.create_autospec(util.RateLimiter) - uploader = uploader_lib.TensorBoardUploader( - mock_client, logdir, mock_rate_limiter) - uploader.create_experiment() - - # Convenience helpers for constructing expected requests. 
- run = write_service_pb2.WriteScalarRequest.Run - tag = write_service_pb2.WriteScalarRequest.Tag - point = scalar_pb2.ScalarPoint - - # First round - writer = tb_test_util.FileWriter(logdir) - writer.add_test_summary("foo", simple_value=5.0, step=1) - writer.add_test_summary("foo", simple_value=6.0, step=2) - writer.add_test_summary("foo", simple_value=7.0, step=3) - writer.add_test_summary("bar", simple_value=8.0, step=3) - writer.flush() - writer_a = tb_test_util.FileWriter(os.path.join(logdir, "a")) - writer_a.add_test_summary("qux", simple_value=9.0, step=2) - writer_a.flush() - uploader._upload_once() - self.assertEqual(1, mock_client.WriteScalar.call_count) - request1 = mock_client.WriteScalar.call_args[0][0] - _clear_wall_times(request1) - expected_request1 = write_service_pb2.WriteScalarRequest( - experiment_id="123", - runs=[ - run(name=".", - tags=[ - tag(name="foo", - metadata=test_util.scalar_metadata("foo"), - points=[ - point(step=1, value=5.0), - point(step=2, value=6.0), - point(step=3, value=7.0), - ]), - tag(name="bar", - metadata=test_util.scalar_metadata("bar"), - points=[ - point(step=3, value=8.0), - ]), - ]), - run(name="a", - tags=[ - tag(name="qux", - metadata=test_util.scalar_metadata("qux"), - points=[ - point(step=2, value=9.0), - ]), - ]), - ]) - self.assertProtoEquals(expected_request1, request1) - mock_client.WriteScalar.reset_mock() - - # Second round - writer.add_test_summary("foo", simple_value=10.0, step=5) - writer.add_test_summary("baz", simple_value=11.0, step=1) - writer.flush() - writer_b = tb_test_util.FileWriter(os.path.join(logdir, "b")) - writer_b.add_test_summary("xyz", simple_value=12.0, step=1) - writer_b.flush() - uploader._upload_once() - self.assertEqual(1, mock_client.WriteScalar.call_count) - request2 = mock_client.WriteScalar.call_args[0][0] - _clear_wall_times(request2) - expected_request2 = write_service_pb2.WriteScalarRequest( - experiment_id="123", - runs=[ - run(name=".", - tags=[ - tag(name="foo", - 
metadata=test_util.scalar_metadata("foo"), - points=[ - point(step=5, value=10.0), - ]), - tag(name="baz", - metadata=test_util.scalar_metadata("baz"), - points=[ - point(step=1, value=11.0), - ]), - ]), - run(name="b", - tags=[ - tag(name="xyz", - metadata=test_util.scalar_metadata("xyz"), - points=[ - point(step=1, value=12.0), - ]), - ]), - ]) - self.assertProtoEquals(expected_request2, request2) - mock_client.WriteScalar.reset_mock() - - # Empty third round - uploader._upload_once() - mock_client.WriteScalar.assert_not_called() + def _create_mock_client(self): + # Create a stub instance (using a test channel) in order to derive a mock + # from it with autospec enabled. Mocking TensorBoardWriterServiceStub itself + # doesn't work with autospec because grpc constructs stubs via metaclassing. + test_channel = grpc_testing.channel( + service_descriptors=[], time=grpc_testing.strict_real_time() + ) + stub = write_service_pb2_grpc.TensorBoardWriterServiceStub(test_channel) + mock_client = mock.create_autospec(stub) + fake_exp_response = write_service_pb2.CreateExperimentResponse( + experiment_id="123", url="should not be used!" 
+ ) + mock_client.CreateExperiment.return_value = fake_exp_response + return mock_client + + def test_create_experiment(self): + logdir = "/logs/foo" + mock_client = self._create_mock_client() + uploader = uploader_lib.TensorBoardUploader(mock_client, logdir) + eid = uploader.create_experiment() + self.assertEqual(eid, "123") + + def test_start_uploading_without_create_experiment_fails(self): + mock_client = self._create_mock_client() + uploader = uploader_lib.TensorBoardUploader(mock_client, "/logs/foo") + with self.assertRaisesRegex(RuntimeError, "call create_experiment()"): + uploader.start_uploading() + + def test_start_uploading(self): + mock_client = self._create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + uploader = uploader_lib.TensorBoardUploader( + mock_client, "/logs/foo", mock_rate_limiter + ) + uploader.create_experiment() + mock_builder = mock.create_autospec(uploader_lib._RequestBuilder) + request = write_service_pb2.WriteScalarRequest() + mock_builder.build_requests.side_effect = [ + iter([request, request]), + iter([request, request, request, request, request]), + AbortUploadError, + ] + # pylint: disable=g-backslash-continuation + with mock.patch.object( + uploader, "_upload" + ) as mock_upload, mock.patch.object( + uploader, "_request_builder", mock_builder + ), self.assertRaises( + AbortUploadError + ): + uploader.start_uploading() + # pylint: enable=g-backslash-continuation + self.assertEqual(7, mock_upload.call_count) + self.assertEqual(2 + 5 + 1, mock_rate_limiter.tick.call_count) + + def test_upload_empty_logdir(self): + logdir = self.get_temp_dir() + mock_client = self._create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + uploader = uploader_lib.TensorBoardUploader( + mock_client, logdir, mock_rate_limiter + ) + uploader.create_experiment() + uploader._upload_once() + mock_client.WriteScalar.assert_not_called() + + def test_upload_swallows_rpc_failure(self): + logdir = 
self.get_temp_dir() + with tb_test_util.FileWriter(logdir) as writer: + writer.add_test_summary("foo") + mock_client = self._create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + uploader = uploader_lib.TensorBoardUploader( + mock_client, logdir, mock_rate_limiter + ) + uploader.create_experiment() + error = test_util.grpc_error(grpc.StatusCode.INTERNAL, "Failure") + mock_client.WriteScalar.side_effect = error + uploader._upload_once() + mock_client.WriteScalar.assert_called_once() + + def test_upload_propagates_experiment_deletion(self): + logdir = self.get_temp_dir() + with tb_test_util.FileWriter(logdir) as writer: + writer.add_test_summary("foo") + mock_client = self._create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + uploader = uploader_lib.TensorBoardUploader( + mock_client, logdir, mock_rate_limiter + ) + uploader.create_experiment() + error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope") + mock_client.WriteScalar.side_effect = error + with self.assertRaises(uploader_lib.ExperimentNotFoundError): + uploader._upload_once() + + def test_upload_preserves_wall_time(self): + logdir = self.get_temp_dir() + with tb_test_util.FileWriter(logdir) as writer: + # Add a raw event so we can specify the wall_time value deterministically. + writer.add_event( + event_pb2.Event( + step=1, + wall_time=123.123123123, + summary=scalar_v2.scalar_pb("foo", 5.0), + ) + ) + mock_client = self._create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + uploader = uploader_lib.TensorBoardUploader( + mock_client, logdir, mock_rate_limiter + ) + uploader.create_experiment() + uploader._upload_once() + mock_client.WriteScalar.assert_called_once() + request = mock_client.WriteScalar.call_args[0][0] + # Just check the wall_time value; everything else is covered in the full + # logdir test below. 
+ self.assertEqual( + 123123123123, + request.runs[0].tags[0].points[0].wall_time.ToNanoseconds(), + ) + + def test_upload_full_logdir(self): + logdir = self.get_temp_dir() + mock_client = self._create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + uploader = uploader_lib.TensorBoardUploader( + mock_client, logdir, mock_rate_limiter + ) + uploader.create_experiment() + + # Convenience helpers for constructing expected requests. + run = write_service_pb2.WriteScalarRequest.Run + tag = write_service_pb2.WriteScalarRequest.Tag + point = scalar_pb2.ScalarPoint + + # First round + writer = tb_test_util.FileWriter(logdir) + writer.add_test_summary("foo", simple_value=5.0, step=1) + writer.add_test_summary("foo", simple_value=6.0, step=2) + writer.add_test_summary("foo", simple_value=7.0, step=3) + writer.add_test_summary("bar", simple_value=8.0, step=3) + writer.flush() + writer_a = tb_test_util.FileWriter(os.path.join(logdir, "a")) + writer_a.add_test_summary("qux", simple_value=9.0, step=2) + writer_a.flush() + uploader._upload_once() + self.assertEqual(1, mock_client.WriteScalar.call_count) + request1 = mock_client.WriteScalar.call_args[0][0] + _clear_wall_times(request1) + expected_request1 = write_service_pb2.WriteScalarRequest( + experiment_id="123", + runs=[ + run( + name=".", + tags=[ + tag( + name="foo", + metadata=test_util.scalar_metadata("foo"), + points=[ + point(step=1, value=5.0), + point(step=2, value=6.0), + point(step=3, value=7.0), + ], + ), + tag( + name="bar", + metadata=test_util.scalar_metadata("bar"), + points=[point(step=3, value=8.0),], + ), + ], + ), + run( + name="a", + tags=[ + tag( + name="qux", + metadata=test_util.scalar_metadata("qux"), + points=[point(step=2, value=9.0),], + ), + ], + ), + ], + ) + self.assertProtoEquals(expected_request1, request1) + mock_client.WriteScalar.reset_mock() + + # Second round + writer.add_test_summary("foo", simple_value=10.0, step=5) + writer.add_test_summary("baz", 
simple_value=11.0, step=1) + writer.flush() + writer_b = tb_test_util.FileWriter(os.path.join(logdir, "b")) + writer_b.add_test_summary("xyz", simple_value=12.0, step=1) + writer_b.flush() + uploader._upload_once() + self.assertEqual(1, mock_client.WriteScalar.call_count) + request2 = mock_client.WriteScalar.call_args[0][0] + _clear_wall_times(request2) + expected_request2 = write_service_pb2.WriteScalarRequest( + experiment_id="123", + runs=[ + run( + name=".", + tags=[ + tag( + name="foo", + metadata=test_util.scalar_metadata("foo"), + points=[point(step=5, value=10.0),], + ), + tag( + name="baz", + metadata=test_util.scalar_metadata("baz"), + points=[point(step=1, value=11.0),], + ), + ], + ), + run( + name="b", + tags=[ + tag( + name="xyz", + metadata=test_util.scalar_metadata("xyz"), + points=[point(step=1, value=12.0),], + ), + ], + ), + ], + ) + self.assertProtoEquals(expected_request2, request2) + mock_client.WriteScalar.reset_mock() + + # Empty third round + uploader._upload_once() + mock_client.WriteScalar.assert_not_called() class RequestBuilderTest(tf.test.TestCase): - - def _populate_run_from_events(self, run_proto, events): - builder = uploader_lib._RequestBuilder(experiment_id="123") - requests = builder.build_requests({"": events}) - request = next(requests, None) - if request is not None: - self.assertLen(request.runs, 1) - run_proto.MergeFrom(request.runs[0]) - self.assertIsNone(next(requests, None)) - - def test_empty_events(self): - run_proto = write_service_pb2.WriteScalarRequest.Run() - self._populate_run_from_events(run_proto, []) - self.assertProtoEquals( - run_proto, write_service_pb2.WriteScalarRequest.Run()) - - def test_aggregation_by_tag(self): - def make_event(step, wall_time, tag, value): - return event_pb2.Event( - step=step, - wall_time=wall_time, - summary=scalar_v2.scalar_pb(tag, value)) - events = [ - make_event(1, 1.0, "one", 11.0), - make_event(1, 2.0, "two", 22.0), - make_event(2, 3.0, "one", 33.0), - make_event(2, 4.0, "two", 
44.0), - make_event(1, 5.0, "one", 55.0), # Should preserve duplicate step=1. - make_event(1, 6.0, "three", 66.0), - ] - run_proto = write_service_pb2.WriteScalarRequest.Run() - self._populate_run_from_events(run_proto, events) - tag_data = { - tag.name: [ - (p.step, p.wall_time.ToSeconds(), p.value) for p in tag.points] - for tag in run_proto.tags} - self.assertEqual( - tag_data, { - "one": [(1, 1.0, 11.0), (2, 3.0, 33.0), (1, 5.0, 55.0)], - "two": [(1, 2.0, 22.0), (2, 4.0, 44.0)], - "three": [(1, 6.0, 66.0)], - }) - - def test_skips_non_scalar_events(self): - events = [ - event_pb2.Event(file_version="brain.Event:2"), - event_pb2.Event(summary=scalar_v2.scalar_pb("scalar1", 5.0)), - event_pb2.Event(summary=scalar_v2.scalar_pb("scalar2", 5.0)), - event_pb2.Event( - summary=histogram_v2.histogram_pb("histogram", [5.0])) - ] - run_proto = write_service_pb2.WriteScalarRequest.Run() - self._populate_run_from_events(run_proto, events) - tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags} - self.assertEqual(tag_counts, {"scalar1": 1, "scalar2": 1}) - - def test_skips_scalar_events_in_non_scalar_time_series(self): - events = [ - event_pb2.Event(file_version="brain.Event:2"), - event_pb2.Event(summary=scalar_v2.scalar_pb("scalar1", 5.0)), - event_pb2.Event(summary=scalar_v2.scalar_pb("scalar2", 5.0)), - event_pb2.Event( - summary=histogram_v2.histogram_pb("histogram", [5.0])), - event_pb2.Event(summary=scalar_v2.scalar_pb("histogram", 5.0)), - ] - run_proto = write_service_pb2.WriteScalarRequest.Run() - self._populate_run_from_events(run_proto, events) - tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags} - self.assertEqual(tag_counts, {"scalar1": 1, "scalar2": 1}) - - def test_remembers_first_metadata_in_scalar_time_series(self): - scalar_1 = event_pb2.Event(summary=scalar_v2.scalar_pb("loss", 4.0)) - scalar_2 = event_pb2.Event(summary=scalar_v2.scalar_pb("loss", 3.0)) - scalar_2.summary.value[0].ClearField("metadata") - events = [ - 
event_pb2.Event(file_version="brain.Event:2"), - scalar_1, - scalar_2, - ] - run_proto = write_service_pb2.WriteScalarRequest.Run() - self._populate_run_from_events(run_proto, events) - tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags} - self.assertEqual(tag_counts, {"loss": 2}) - - def test_v1_summary_single_value(self): - event = event_pb2.Event(step=1, wall_time=123.456) - event.summary.value.add(tag="foo", simple_value=5.0) - run_proto = write_service_pb2.WriteScalarRequest.Run() - self._populate_run_from_events(run_proto, [event]) - expected_run_proto = write_service_pb2.WriteScalarRequest.Run() - foo_tag = expected_run_proto.tags.add() - foo_tag.name = "foo" - foo_tag.metadata.display_name = "foo" - foo_tag.metadata.plugin_data.plugin_name = "scalars" - foo_tag.points.add( - step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0) - self.assertProtoEquals(run_proto, expected_run_proto) - - def test_v1_summary_multiple_value(self): - event = event_pb2.Event(step=1, wall_time=123.456) - event.summary.value.add(tag="foo", simple_value=1.0) - event.summary.value.add(tag="foo", simple_value=2.0) - event.summary.value.add(tag="foo", simple_value=3.0) - run_proto = write_service_pb2.WriteScalarRequest.Run() - self._populate_run_from_events(run_proto, [event]) - expected_run_proto = write_service_pb2.WriteScalarRequest.Run() - foo_tag = expected_run_proto.tags.add() - foo_tag.name = "foo" - foo_tag.metadata.display_name = "foo" - foo_tag.metadata.plugin_data.plugin_name = "scalars" - foo_tag.points.add( - step=1, wall_time=test_util.timestamp_pb(123456000000), value=1.0) - foo_tag.points.add( - step=1, wall_time=test_util.timestamp_pb(123456000000), value=2.0) - foo_tag.points.add( - step=1, wall_time=test_util.timestamp_pb(123456000000), value=3.0) - self.assertProtoEquals(run_proto, expected_run_proto) - - def test_v1_summary_tb_summary(self): - tf_summary = summary_v1.scalar_pb("foo", 5.0) - tb_summary = 
summary_pb2.Summary.FromString(tf_summary.SerializeToString()) - event = event_pb2.Event(step=1, wall_time=123.456, summary=tb_summary) - run_proto = write_service_pb2.WriteScalarRequest.Run() - self._populate_run_from_events(run_proto, [event]) - expected_run_proto = write_service_pb2.WriteScalarRequest.Run() - foo_tag = expected_run_proto.tags.add() - foo_tag.name = "foo/scalar_summary" - foo_tag.metadata.display_name = "foo" - foo_tag.metadata.plugin_data.plugin_name = "scalars" - foo_tag.points.add( - step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0) - self.assertProtoEquals(run_proto, expected_run_proto) - - def test_v2_summary(self): - event = event_pb2.Event( - step=1, wall_time=123.456, summary=scalar_v2.scalar_pb("foo", 5.0)) - run_proto = write_service_pb2.WriteScalarRequest.Run() - self._populate_run_from_events(run_proto, [event]) - expected_run_proto = write_service_pb2.WriteScalarRequest.Run() - foo_tag = expected_run_proto.tags.add() - foo_tag.name = "foo" - foo_tag.metadata.plugin_data.plugin_name = "scalars" - foo_tag.points.add( - step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0) - self.assertProtoEquals(run_proto, expected_run_proto) - - def test_no_budget_for_experiment_id(self): - event = event_pb2.Event(step=1, wall_time=123.456) - event.summary.value.add(tag="foo", simple_value=1.0) - run_to_events = {"run_name": [event]} - long_experiment_id = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES - with self.assertRaises(RuntimeError) as cm: - builder = uploader_lib._RequestBuilder(long_experiment_id) - list(builder.build_requests(run_to_events)) - self.assertEqual( - str(cm.exception), "Byte budget too small for experiment ID") - - def test_no_room_for_single_point(self): - event = event_pb2.Event(step=1, wall_time=123.456) - event.summary.value.add(tag="foo", simple_value=1.0) - long_run_name = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES - run_to_events = {long_run_name: [event]} - with 
self.assertRaises(RuntimeError) as cm: - builder = uploader_lib._RequestBuilder("123") - list(builder.build_requests(run_to_events)) - self.assertEqual( - str(cm.exception), "Could not make progress uploading data") - - @mock.patch.object(uploader_lib, "_MAX_REQUEST_LENGTH_BYTES", 1024) - def test_break_at_run_boundary(self): - # Choose run name sizes such that one run fits, but not two. - long_run_1 = "A" * 768 - long_run_2 = "B" * 768 - event_1 = event_pb2.Event(step=1) - event_1.summary.value.add(tag="foo", simple_value=1.0) - event_2 = event_pb2.Event(step=2) - event_2.summary.value.add(tag="bar", simple_value=-2.0) - run_to_events = collections.OrderedDict([ - (long_run_1, [event_1]), - (long_run_2, [event_2]), - ]) - - builder = uploader_lib._RequestBuilder("123") - requests = list(builder.build_requests(run_to_events)) - for request in requests: - _clear_wall_times(request) - - expected = [ - write_service_pb2.WriteScalarRequest(experiment_id="123"), - write_service_pb2.WriteScalarRequest(experiment_id="123"), - ] - (expected[0].runs.add(name=long_run_1).tags.add( - name="foo", metadata=test_util.scalar_metadata("foo")).points.add( - step=1, value=1.0)) - (expected[1].runs.add(name=long_run_2).tags.add( - name="bar", metadata=test_util.scalar_metadata("bar")).points.add( - step=2, value=-2.0)) - self.assertEqual(requests, expected) - - @mock.patch.object(uploader_lib, "_MAX_REQUEST_LENGTH_BYTES", 1024) - def test_break_at_tag_boundary(self): - # Choose tag name sizes such that one tag fits, but not two. Note - # that tag names appear in both `Tag.name` and the summary metadata. 
- long_tag_1 = "a" * 384 - long_tag_2 = "b" * 384 - event = event_pb2.Event(step=1) - event.summary.value.add(tag=long_tag_1, simple_value=1.0) - event.summary.value.add(tag=long_tag_2, simple_value=2.0) - run_to_events = {"train": [event]} - - builder = uploader_lib._RequestBuilder("123") - requests = list(builder.build_requests(run_to_events)) - for request in requests: - _clear_wall_times(request) - - expected = [ - write_service_pb2.WriteScalarRequest(experiment_id="123"), - write_service_pb2.WriteScalarRequest(experiment_id="123"), - ] - (expected[0].runs.add(name="train").tags.add( - name=long_tag_1, - metadata=test_util.scalar_metadata(long_tag_1)).points.add( - step=1, value=1.0)) - (expected[1].runs.add(name="train").tags.add( - name=long_tag_2, - metadata=test_util.scalar_metadata(long_tag_2)).points.add( - step=1, value=2.0)) - self.assertEqual(requests, expected) - - @mock.patch.object(uploader_lib, "_MAX_REQUEST_LENGTH_BYTES", 1024) - def test_break_at_scalar_point_boundary(self): - point_count = 2000 # comfortably saturates a single 1024-byte request - events = [] - for step in range(point_count): - summary = scalar_v2.scalar_pb("loss", -2.0 * step) - if step > 0: - summary.value[0].ClearField("metadata") - events.append(event_pb2.Event(summary=summary, step=step)) - run_to_events = {"train": events} - - builder = uploader_lib._RequestBuilder("123") - requests = list(builder.build_requests(run_to_events)) - for request in requests: - _clear_wall_times(request) - - self.assertGreater(len(requests), 1) - self.assertLess(len(requests), point_count) - - total_points_in_result = 0 - for request in requests: - self.assertLen(request.runs, 1) - run = request.runs[0] - self.assertEqual(run.name, "train") - self.assertLen(run.tags, 1) - tag = run.tags[0] - self.assertEqual(tag.name, "loss") - for point in tag.points: - self.assertEqual(point.step, total_points_in_result) - self.assertEqual(point.value, -2.0 * point.step) - total_points_in_result += 1 - 
self.assertLessEqual( - request.ByteSize(), uploader_lib._MAX_REQUEST_LENGTH_BYTES) - self.assertEqual(total_points_in_result, point_count) - - def test_prunes_tags_and_runs(self): - event_1 = event_pb2.Event(step=1) - event_1.summary.value.add(tag="foo", simple_value=1.0) - event_2 = event_pb2.Event(step=2) - event_2.summary.value.add(tag="bar", simple_value=-2.0) - run_to_events = collections.OrderedDict([ - ("train", [event_1]), - ("test", [event_2]), - ]) - - real_create_point = uploader_lib._RequestBuilder._create_point - - create_point_call_count_box = [0] - - def mock_create_point(uploader_self, *args, **kwargs): - # Simulate out-of-space error the first time that we try to store - # the second point. - create_point_call_count_box[0] += 1 - if create_point_call_count_box[0] == 2: - raise uploader_lib._OutOfSpaceError() - return real_create_point(uploader_self, *args, **kwargs) - - with mock.patch.object( - uploader_lib._RequestBuilder, "_create_point", mock_create_point): - builder = uploader_lib._RequestBuilder("123") - requests = list(builder.build_requests(run_to_events)) - for request in requests: - _clear_wall_times(request) - - expected = [ - write_service_pb2.WriteScalarRequest(experiment_id="123"), - write_service_pb2.WriteScalarRequest(experiment_id="123"), - ] - (expected[0].runs.add(name="train").tags.add( - name="foo", metadata=test_util.scalar_metadata("foo")).points.add( - step=1, value=1.0)) - (expected[1].runs.add(name="test").tags.add( - name="bar", metadata=test_util.scalar_metadata("bar")).points.add( - step=2, value=-2.0)) - self.assertEqual(expected, requests) - - def test_wall_time_precision(self): - # Test a wall time that is exactly representable in float64 but has enough - # digits to incur error if converted to nanonseconds the naive way (* 1e9). 
- event1 = event_pb2.Event(step=1, wall_time=1567808404.765432119) - event1.summary.value.add(tag="foo", simple_value=1.0) - # Test a wall time where as a float64, the fractional part on its own will - # introduce error if truncated to 9 decimal places instead of rounded. - event2 = event_pb2.Event(step=2, wall_time=1.000000002) - event2.summary.value.add(tag="foo", simple_value=2.0) - run_proto = write_service_pb2.WriteScalarRequest.Run() - self._populate_run_from_events(run_proto, [event1, event2]) - self.assertEqual( - test_util.timestamp_pb(1567808404765432119), - run_proto.tags[0].points[0].wall_time) - self.assertEqual( - test_util.timestamp_pb(1000000002), - run_proto.tags[0].points[1].wall_time) + def _populate_run_from_events(self, run_proto, events): + builder = uploader_lib._RequestBuilder(experiment_id="123") + requests = builder.build_requests({"": events}) + request = next(requests, None) + if request is not None: + self.assertLen(request.runs, 1) + run_proto.MergeFrom(request.runs[0]) + self.assertIsNone(next(requests, None)) + + def test_empty_events(self): + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, []) + self.assertProtoEquals( + run_proto, write_service_pb2.WriteScalarRequest.Run() + ) + + def test_aggregation_by_tag(self): + def make_event(step, wall_time, tag, value): + return event_pb2.Event( + step=step, + wall_time=wall_time, + summary=scalar_v2.scalar_pb(tag, value), + ) + + events = [ + make_event(1, 1.0, "one", 11.0), + make_event(1, 2.0, "two", 22.0), + make_event(2, 3.0, "one", 33.0), + make_event(2, 4.0, "two", 44.0), + make_event( + 1, 5.0, "one", 55.0 + ), # Should preserve duplicate step=1. 
+ make_event(1, 6.0, "three", 66.0), + ] + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, events) + tag_data = { + tag.name: [ + (p.step, p.wall_time.ToSeconds(), p.value) for p in tag.points + ] + for tag in run_proto.tags + } + self.assertEqual( + tag_data, + { + "one": [(1, 1.0, 11.0), (2, 3.0, 33.0), (1, 5.0, 55.0)], + "two": [(1, 2.0, 22.0), (2, 4.0, 44.0)], + "three": [(1, 6.0, 66.0)], + }, + ) + + def test_skips_non_scalar_events(self): + events = [ + event_pb2.Event(file_version="brain.Event:2"), + event_pb2.Event(summary=scalar_v2.scalar_pb("scalar1", 5.0)), + event_pb2.Event(summary=scalar_v2.scalar_pb("scalar2", 5.0)), + event_pb2.Event( + summary=histogram_v2.histogram_pb("histogram", [5.0]) + ), + ] + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, events) + tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags} + self.assertEqual(tag_counts, {"scalar1": 1, "scalar2": 1}) + + def test_skips_scalar_events_in_non_scalar_time_series(self): + events = [ + event_pb2.Event(file_version="brain.Event:2"), + event_pb2.Event(summary=scalar_v2.scalar_pb("scalar1", 5.0)), + event_pb2.Event(summary=scalar_v2.scalar_pb("scalar2", 5.0)), + event_pb2.Event( + summary=histogram_v2.histogram_pb("histogram", [5.0]) + ), + event_pb2.Event(summary=scalar_v2.scalar_pb("histogram", 5.0)), + ] + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, events) + tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags} + self.assertEqual(tag_counts, {"scalar1": 1, "scalar2": 1}) + + def test_remembers_first_metadata_in_scalar_time_series(self): + scalar_1 = event_pb2.Event(summary=scalar_v2.scalar_pb("loss", 4.0)) + scalar_2 = event_pb2.Event(summary=scalar_v2.scalar_pb("loss", 3.0)) + scalar_2.summary.value[0].ClearField("metadata") + events = [ + event_pb2.Event(file_version="brain.Event:2"), + scalar_1, + 
scalar_2, + ] + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, events) + tag_counts = {tag.name: len(tag.points) for tag in run_proto.tags} + self.assertEqual(tag_counts, {"loss": 2}) + + def test_v1_summary_single_value(self): + event = event_pb2.Event(step=1, wall_time=123.456) + event.summary.value.add(tag="foo", simple_value=5.0) + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, [event]) + expected_run_proto = write_service_pb2.WriteScalarRequest.Run() + foo_tag = expected_run_proto.tags.add() + foo_tag.name = "foo" + foo_tag.metadata.display_name = "foo" + foo_tag.metadata.plugin_data.plugin_name = "scalars" + foo_tag.points.add( + step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0 + ) + self.assertProtoEquals(run_proto, expected_run_proto) + + def test_v1_summary_multiple_value(self): + event = event_pb2.Event(step=1, wall_time=123.456) + event.summary.value.add(tag="foo", simple_value=1.0) + event.summary.value.add(tag="foo", simple_value=2.0) + event.summary.value.add(tag="foo", simple_value=3.0) + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, [event]) + expected_run_proto = write_service_pb2.WriteScalarRequest.Run() + foo_tag = expected_run_proto.tags.add() + foo_tag.name = "foo" + foo_tag.metadata.display_name = "foo" + foo_tag.metadata.plugin_data.plugin_name = "scalars" + foo_tag.points.add( + step=1, wall_time=test_util.timestamp_pb(123456000000), value=1.0 + ) + foo_tag.points.add( + step=1, wall_time=test_util.timestamp_pb(123456000000), value=2.0 + ) + foo_tag.points.add( + step=1, wall_time=test_util.timestamp_pb(123456000000), value=3.0 + ) + self.assertProtoEquals(run_proto, expected_run_proto) + + def test_v1_summary_tb_summary(self): + tf_summary = summary_v1.scalar_pb("foo", 5.0) + tb_summary = summary_pb2.Summary.FromString( + tf_summary.SerializeToString() + ) + event = 
event_pb2.Event(step=1, wall_time=123.456, summary=tb_summary) + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, [event]) + expected_run_proto = write_service_pb2.WriteScalarRequest.Run() + foo_tag = expected_run_proto.tags.add() + foo_tag.name = "foo/scalar_summary" + foo_tag.metadata.display_name = "foo" + foo_tag.metadata.plugin_data.plugin_name = "scalars" + foo_tag.points.add( + step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0 + ) + self.assertProtoEquals(run_proto, expected_run_proto) + + def test_v2_summary(self): + event = event_pb2.Event( + step=1, wall_time=123.456, summary=scalar_v2.scalar_pb("foo", 5.0) + ) + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, [event]) + expected_run_proto = write_service_pb2.WriteScalarRequest.Run() + foo_tag = expected_run_proto.tags.add() + foo_tag.name = "foo" + foo_tag.metadata.plugin_data.plugin_name = "scalars" + foo_tag.points.add( + step=1, wall_time=test_util.timestamp_pb(123456000000), value=5.0 + ) + self.assertProtoEquals(run_proto, expected_run_proto) + + def test_no_budget_for_experiment_id(self): + event = event_pb2.Event(step=1, wall_time=123.456) + event.summary.value.add(tag="foo", simple_value=1.0) + run_to_events = {"run_name": [event]} + long_experiment_id = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES + with self.assertRaises(RuntimeError) as cm: + builder = uploader_lib._RequestBuilder(long_experiment_id) + list(builder.build_requests(run_to_events)) + self.assertEqual( + str(cm.exception), "Byte budget too small for experiment ID" + ) + + def test_no_room_for_single_point(self): + event = event_pb2.Event(step=1, wall_time=123.456) + event.summary.value.add(tag="foo", simple_value=1.0) + long_run_name = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES + run_to_events = {long_run_name: [event]} + with self.assertRaises(RuntimeError) as cm: + builder = uploader_lib._RequestBuilder("123") + 
list(builder.build_requests(run_to_events)) + self.assertEqual( + str(cm.exception), "Could not make progress uploading data" + ) + + @mock.patch.object(uploader_lib, "_MAX_REQUEST_LENGTH_BYTES", 1024) + def test_break_at_run_boundary(self): + # Choose run name sizes such that one run fits, but not two. + long_run_1 = "A" * 768 + long_run_2 = "B" * 768 + event_1 = event_pb2.Event(step=1) + event_1.summary.value.add(tag="foo", simple_value=1.0) + event_2 = event_pb2.Event(step=2) + event_2.summary.value.add(tag="bar", simple_value=-2.0) + run_to_events = collections.OrderedDict( + [(long_run_1, [event_1]), (long_run_2, [event_2]),] + ) + + builder = uploader_lib._RequestBuilder("123") + requests = list(builder.build_requests(run_to_events)) + for request in requests: + _clear_wall_times(request) + + expected = [ + write_service_pb2.WriteScalarRequest(experiment_id="123"), + write_service_pb2.WriteScalarRequest(experiment_id="123"), + ] + ( + expected[0] + .runs.add(name=long_run_1) + .tags.add(name="foo", metadata=test_util.scalar_metadata("foo")) + .points.add(step=1, value=1.0) + ) + ( + expected[1] + .runs.add(name=long_run_2) + .tags.add(name="bar", metadata=test_util.scalar_metadata("bar")) + .points.add(step=2, value=-2.0) + ) + self.assertEqual(requests, expected) + + @mock.patch.object(uploader_lib, "_MAX_REQUEST_LENGTH_BYTES", 1024) + def test_break_at_tag_boundary(self): + # Choose tag name sizes such that one tag fits, but not two. Note + # that tag names appear in both `Tag.name` and the summary metadata. 
+ long_tag_1 = "a" * 384 + long_tag_2 = "b" * 384 + event = event_pb2.Event(step=1) + event.summary.value.add(tag=long_tag_1, simple_value=1.0) + event.summary.value.add(tag=long_tag_2, simple_value=2.0) + run_to_events = {"train": [event]} + + builder = uploader_lib._RequestBuilder("123") + requests = list(builder.build_requests(run_to_events)) + for request in requests: + _clear_wall_times(request) + + expected = [ + write_service_pb2.WriteScalarRequest(experiment_id="123"), + write_service_pb2.WriteScalarRequest(experiment_id="123"), + ] + ( + expected[0] + .runs.add(name="train") + .tags.add( + name=long_tag_1, metadata=test_util.scalar_metadata(long_tag_1) + ) + .points.add(step=1, value=1.0) + ) + ( + expected[1] + .runs.add(name="train") + .tags.add( + name=long_tag_2, metadata=test_util.scalar_metadata(long_tag_2) + ) + .points.add(step=1, value=2.0) + ) + self.assertEqual(requests, expected) + + @mock.patch.object(uploader_lib, "_MAX_REQUEST_LENGTH_BYTES", 1024) + def test_break_at_scalar_point_boundary(self): + point_count = 2000 # comfortably saturates a single 1024-byte request + events = [] + for step in range(point_count): + summary = scalar_v2.scalar_pb("loss", -2.0 * step) + if step > 0: + summary.value[0].ClearField("metadata") + events.append(event_pb2.Event(summary=summary, step=step)) + run_to_events = {"train": events} + + builder = uploader_lib._RequestBuilder("123") + requests = list(builder.build_requests(run_to_events)) + for request in requests: + _clear_wall_times(request) + + self.assertGreater(len(requests), 1) + self.assertLess(len(requests), point_count) + + total_points_in_result = 0 + for request in requests: + self.assertLen(request.runs, 1) + run = request.runs[0] + self.assertEqual(run.name, "train") + self.assertLen(run.tags, 1) + tag = run.tags[0] + self.assertEqual(tag.name, "loss") + for point in tag.points: + self.assertEqual(point.step, total_points_in_result) + self.assertEqual(point.value, -2.0 * point.step) + 
total_points_in_result += 1 + self.assertLessEqual( + request.ByteSize(), uploader_lib._MAX_REQUEST_LENGTH_BYTES + ) + self.assertEqual(total_points_in_result, point_count) + + def test_prunes_tags_and_runs(self): + event_1 = event_pb2.Event(step=1) + event_1.summary.value.add(tag="foo", simple_value=1.0) + event_2 = event_pb2.Event(step=2) + event_2.summary.value.add(tag="bar", simple_value=-2.0) + run_to_events = collections.OrderedDict( + [("train", [event_1]), ("test", [event_2]),] + ) + + real_create_point = uploader_lib._RequestBuilder._create_point + + create_point_call_count_box = [0] + + def mock_create_point(uploader_self, *args, **kwargs): + # Simulate out-of-space error the first time that we try to store + # the second point. + create_point_call_count_box[0] += 1 + if create_point_call_count_box[0] == 2: + raise uploader_lib._OutOfSpaceError() + return real_create_point(uploader_self, *args, **kwargs) + + with mock.patch.object( + uploader_lib._RequestBuilder, "_create_point", mock_create_point + ): + builder = uploader_lib._RequestBuilder("123") + requests = list(builder.build_requests(run_to_events)) + for request in requests: + _clear_wall_times(request) + + expected = [ + write_service_pb2.WriteScalarRequest(experiment_id="123"), + write_service_pb2.WriteScalarRequest(experiment_id="123"), + ] + ( + expected[0] + .runs.add(name="train") + .tags.add(name="foo", metadata=test_util.scalar_metadata("foo")) + .points.add(step=1, value=1.0) + ) + ( + expected[1] + .runs.add(name="test") + .tags.add(name="bar", metadata=test_util.scalar_metadata("bar")) + .points.add(step=2, value=-2.0) + ) + self.assertEqual(expected, requests) + + def test_wall_time_precision(self): + # Test a wall time that is exactly representable in float64 but has enough + # digits to incur error if converted to nanonseconds the naive way (* 1e9). 
+ event1 = event_pb2.Event(step=1, wall_time=1567808404.765432119) + event1.summary.value.add(tag="foo", simple_value=1.0) + # Test a wall time where as a float64, the fractional part on its own will + # introduce error if truncated to 9 decimal places instead of rounded. + event2 = event_pb2.Event(step=2, wall_time=1.000000002) + event2.summary.value.add(tag="foo", simple_value=2.0) + run_proto = write_service_pb2.WriteScalarRequest.Run() + self._populate_run_from_events(run_proto, [event1, event2]) + self.assertEqual( + test_util.timestamp_pb(1567808404765432119), + run_proto.tags[0].points[0].wall_time, + ) + self.assertEqual( + test_util.timestamp_pb(1000000002), + run_proto.tags[0].points[1].wall_time, + ) class DeleteExperimentTest(tf.test.TestCase): - - def _create_mock_client(self): - # Create a stub instance (using a test channel) in order to derive a mock - # from it with autospec enabled. Mocking TensorBoardWriterServiceStub itself - # doesn't work with autospec because grpc constructs stubs via metaclassing. 
- test_channel = grpc_testing.channel( - service_descriptors=[], time=grpc_testing.strict_real_time()) - stub = write_service_pb2_grpc.TensorBoardWriterServiceStub(test_channel) - mock_client = mock.create_autospec(stub) - return mock_client - - def test_success(self): - mock_client = self._create_mock_client() - response = write_service_pb2.DeleteExperimentResponse() - mock_client.DeleteExperiment.return_value = response - - uploader_lib.delete_experiment(mock_client, "123") - - expected_request = write_service_pb2.DeleteExperimentRequest() - expected_request.experiment_id = "123" - mock_client.DeleteExperiment.assert_called_once() - (args, _) = mock_client.DeleteExperiment.call_args - self.assertEqual(args[0], expected_request) - - def test_not_found(self): - mock_client = self._create_mock_client() - error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope") - mock_client.DeleteExperiment.side_effect = error - - with self.assertRaises(uploader_lib.ExperimentNotFoundError): - uploader_lib.delete_experiment(mock_client, "123") - - def test_unauthorized(self): - mock_client = self._create_mock_client() - error = test_util.grpc_error(grpc.StatusCode.PERMISSION_DENIED, "nope") - mock_client.DeleteExperiment.side_effect = error - - with self.assertRaises(uploader_lib.PermissionDeniedError): - uploader_lib.delete_experiment(mock_client, "123") - - def test_internal_error(self): - mock_client = self._create_mock_client() - error = test_util.grpc_error(grpc.StatusCode.INTERNAL, "travesty") - mock_client.DeleteExperiment.side_effect = error - - with self.assertRaises(grpc.RpcError) as cm: - uploader_lib.delete_experiment(mock_client, "123") - msg = str(cm.exception) - self.assertIn("travesty", msg) + def _create_mock_client(self): + # Create a stub instance (using a test channel) in order to derive a mock + # from it with autospec enabled. Mocking TensorBoardWriterServiceStub itself + # doesn't work with autospec because grpc constructs stubs via metaclassing. 
+ test_channel = grpc_testing.channel( + service_descriptors=[], time=grpc_testing.strict_real_time() + ) + stub = write_service_pb2_grpc.TensorBoardWriterServiceStub(test_channel) + mock_client = mock.create_autospec(stub) + return mock_client + + def test_success(self): + mock_client = self._create_mock_client() + response = write_service_pb2.DeleteExperimentResponse() + mock_client.DeleteExperiment.return_value = response + + uploader_lib.delete_experiment(mock_client, "123") + + expected_request = write_service_pb2.DeleteExperimentRequest() + expected_request.experiment_id = "123" + mock_client.DeleteExperiment.assert_called_once() + (args, _) = mock_client.DeleteExperiment.call_args + self.assertEqual(args[0], expected_request) + + def test_not_found(self): + mock_client = self._create_mock_client() + error = test_util.grpc_error(grpc.StatusCode.NOT_FOUND, "nope") + mock_client.DeleteExperiment.side_effect = error + + with self.assertRaises(uploader_lib.ExperimentNotFoundError): + uploader_lib.delete_experiment(mock_client, "123") + + def test_unauthorized(self): + mock_client = self._create_mock_client() + error = test_util.grpc_error(grpc.StatusCode.PERMISSION_DENIED, "nope") + mock_client.DeleteExperiment.side_effect = error + + with self.assertRaises(uploader_lib.PermissionDeniedError): + uploader_lib.delete_experiment(mock_client, "123") + + def test_internal_error(self): + mock_client = self._create_mock_client() + error = test_util.grpc_error(grpc.StatusCode.INTERNAL, "travesty") + mock_client.DeleteExperiment.side_effect = error + + with self.assertRaises(grpc.RpcError) as cm: + uploader_lib.delete_experiment(mock_client, "123") + msg = str(cm.exception) + self.assertIn("travesty", msg) class VarintCostTest(tf.test.TestCase): - - def test_varint_cost(self): - self.assertEqual(uploader_lib._varint_cost(0), 1) - self.assertEqual(uploader_lib._varint_cost(7), 1) - self.assertEqual(uploader_lib._varint_cost(127), 1) - 
self.assertEqual(uploader_lib._varint_cost(128), 2) - self.assertEqual(uploader_lib._varint_cost(128 * 128 - 1), 2) - self.assertEqual(uploader_lib._varint_cost(128 * 128), 3) + def test_varint_cost(self): + self.assertEqual(uploader_lib._varint_cost(0), 1) + self.assertEqual(uploader_lib._varint_cost(7), 1) + self.assertEqual(uploader_lib._varint_cost(127), 1) + self.assertEqual(uploader_lib._varint_cost(128), 2) + self.assertEqual(uploader_lib._varint_cost(128 * 128 - 1), 2) + self.assertEqual(uploader_lib._varint_cost(128 * 128), 3) def _clear_wall_times(request): - """Clears the wall_time fields in a WriteScalarRequest to be deterministic.""" - for run in request.runs: - for tag in run.tags: - for point in tag.points: - point.ClearField("wall_time") + """Clears the wall_time fields in a WriteScalarRequest to be + deterministic.""" + for run in request.runs: + for tag in run.tags: + for point in tag.points: + point.ClearField("wall_time") if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/tensorboard/uploader/util.py b/tensorboard/uploader/util.py index 16ad493bc5..cc0534baab 100644 --- a/tensorboard/uploader/util.py +++ b/tensorboard/uploader/util.py @@ -26,138 +26,142 @@ class RateLimiter(object): - """Helper class for rate-limiting using a fixed minimum interval.""" - - def __init__(self, interval_secs): - """Constructs a RateLimiter that permits a tick() every `interval_secs`.""" - self._time = time # Use property for ease of testing. 
- self._interval_secs = interval_secs - self._last_called_secs = 0 - - def tick(self): - """Blocks until it has been at least `interval_secs` since last tick().""" - wait_secs = self._last_called_secs + self._interval_secs - self._time.time() - if wait_secs > 0: - self._time.sleep(wait_secs) - self._last_called_secs = self._time.time() + """Helper class for rate-limiting using a fixed minimum interval.""" + + def __init__(self, interval_secs): + """Constructs a RateLimiter that permits a tick() every + `interval_secs`.""" + self._time = time # Use property for ease of testing. + self._interval_secs = interval_secs + self._last_called_secs = 0 + + def tick(self): + """Blocks until it has been at least `interval_secs` since last + tick().""" + wait_secs = ( + self._last_called_secs + self._interval_secs - self._time.time() + ) + if wait_secs > 0: + self._time.sleep(wait_secs) + self._last_called_secs = self._time.time() def get_user_config_directory(): - """Returns a platform-specific root directory for user config settings.""" - # On Windows, prefer %LOCALAPPDATA%, then %APPDATA%, since we can expect the - # AppData directories to be ACLed to be visible only to the user and admin - # users (https://stackoverflow.com/a/7617601/1179226). If neither is set, - # return None instead of falling back to something that may be world-readable. - if os.name == "nt": - appdata = os.getenv("LOCALAPPDATA") - if appdata: - return appdata - appdata = os.getenv("APPDATA") - if appdata: - return appdata - return None - # On non-windows, use XDG_CONFIG_HOME if set, else default to ~/.config. 
- xdg_config_home = os.getenv("XDG_CONFIG_HOME") - if xdg_config_home: - return xdg_config_home - return os.path.join(os.path.expanduser("~"), ".config") + """Returns a platform-specific root directory for user config settings.""" + # On Windows, prefer %LOCALAPPDATA%, then %APPDATA%, since we can expect the + # AppData directories to be ACLed to be visible only to the user and admin + # users (https://stackoverflow.com/a/7617601/1179226). If neither is set, + # return None instead of falling back to something that may be world-readable. + if os.name == "nt": + appdata = os.getenv("LOCALAPPDATA") + if appdata: + return appdata + appdata = os.getenv("APPDATA") + if appdata: + return appdata + return None + # On non-windows, use XDG_CONFIG_HOME if set, else default to ~/.config. + xdg_config_home = os.getenv("XDG_CONFIG_HOME") + if xdg_config_home: + return xdg_config_home + return os.path.join(os.path.expanduser("~"), ".config") def make_file_with_directories(path, private=False): - """Creates a file and its containing directories, if they don't already exist. - - - If `private` is True, the file will be made private (readable only by the - current user) and so will the leaf directory. Pre-existing contents of the - file are not modified. - - Passing `private=True` is not supported on Windows because it doesn't support - the relevant parts of `os.chmod()`. - - Args: - path: str, The path of the file to create. - private: boolean, Whether to make the file and leaf directory readable only - by the current user. - - Raises: - RuntimeError: If called on Windows with `private` set to True. - """ - if private and os.name == "nt": - raise RuntimeError("Creating private file not supported on Windows") - try: - path = os.path.realpath(path) - leaf_dir = os.path.dirname(path) + """Creates a file and its containing directories, if they don't already + exist. 
+ + If `private` is True, the file will be made private (readable only by the + current user) and so will the leaf directory. Pre-existing contents of the + file are not modified. + + Passing `private=True` is not supported on Windows because it doesn't support + the relevant parts of `os.chmod()`. + + Args: + path: str, The path of the file to create. + private: boolean, Whether to make the file and leaf directory readable only + by the current user. + + Raises: + RuntimeError: If called on Windows with `private` set to True. + """ + if private and os.name == "nt": + raise RuntimeError("Creating private file not supported on Windows") try: - os.makedirs(leaf_dir) - except OSError as e: - if e.errno != errno.EEXIST: - raise - if private: - os.chmod(leaf_dir, 0o700) - open(path, "a").close() - if private: - os.chmod(path, 0o600) - except EnvironmentError as e: - raise RuntimeError("Failed to create file %s: %s" % (path, e)) + path = os.path.realpath(path) + leaf_dir = os.path.dirname(path) + try: + os.makedirs(leaf_dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise + if private: + os.chmod(leaf_dir, 0o700) + open(path, "a").close() + if private: + os.chmod(path, 0o600) + except EnvironmentError as e: + raise RuntimeError("Failed to create file %s: %s" % (path, e)) def set_timestamp(pb, seconds_since_epoch): - """Sets a `Timestamp` proto message to a floating point UNIX time. + """Sets a `Timestamp` proto message to a floating point UNIX time. - This is like `pb.FromNanoseconds(int(seconds_since_epoch * 1e9))` but - without introducing floating-point error. + This is like `pb.FromNanoseconds(int(seconds_since_epoch * 1e9))` but + without introducing floating-point error. - Args: - pb: A `google.protobuf.Timestamp` message to mutate. - seconds_since_epoch: A `float`, as returned by `time.time`. 
- """ - pb.seconds = int(seconds_since_epoch) - pb.nanos = int(round((seconds_since_epoch % 1) * 10**9)) + Args: + pb: A `google.protobuf.Timestamp` message to mutate. + seconds_since_epoch: A `float`, as returned by `time.time`. + """ + pb.seconds = int(seconds_since_epoch) + pb.nanos = int(round((seconds_since_epoch % 1) * 10 ** 9)) def format_time(timestamp_pb, now=None): - """Converts a `timestamp_pb2.Timestamp` to human-readable string. + """Converts a `timestamp_pb2.Timestamp` to human-readable string. - This always includes the absolute date and time, and for recent dates - may include a relative time like "(just now)" or "(2 hours ago)". + This always includes the absolute date and time, and for recent dates + may include a relative time like "(just now)" or "(2 hours ago)". - Args: - timestamp_pb: A `google.protobuf.timestamp_pb2.Timestamp` value to - convert to string. The input will not be modified. - now: A `datetime.datetime` object representing the current time, - used for determining relative times like "just now". Optional; - defaults to `datetime.datetime.now()`. + Args: + timestamp_pb: A `google.protobuf.timestamp_pb2.Timestamp` value to + convert to string. The input will not be modified. + now: A `datetime.datetime` object representing the current time, + used for determining relative times like "just now". Optional; + defaults to `datetime.datetime.now()`. - Returns: - A string suitable for human consumption. - """ + Returns: + A string suitable for human consumption. + """ - # Add and subtract a day for , - # which breaks early datetime conversions on Windows for small - # timestamps. - dt = datetime.datetime.fromtimestamp(timestamp_pb.seconds + 86400) - dt = dt - datetime.timedelta(seconds=86400) + # Add and subtract a day for , + # which breaks early datetime conversions on Windows for small + # timestamps. 
+ dt = datetime.datetime.fromtimestamp(timestamp_pb.seconds + 86400) + dt = dt - datetime.timedelta(seconds=86400) - if now is None: - now = datetime.datetime.now() - ago = now.replace(microsecond=0) - dt + if now is None: + now = datetime.datetime.now() + ago = now.replace(microsecond=0) - dt - def ago_text(n, singular, plural): - return "%d %s ago" % (n, singular if n == 1 else plural) + def ago_text(n, singular, plural): + return "%d %s ago" % (n, singular if n == 1 else plural) - relative = None - if ago < datetime.timedelta(seconds=5): - relative = "just now" - elif ago < datetime.timedelta(minutes=1): - relative = ago_text(int(ago.total_seconds()), "second", "seconds") - elif ago < datetime.timedelta(hours=1): - relative = ago_text(int(ago.total_seconds()) // 60, "minute", "minutes") - elif ago < datetime.timedelta(days=1): - relative = ago_text(int(ago.total_seconds()) // 3600, "hour", "hours") + relative = None + if ago < datetime.timedelta(seconds=5): + relative = "just now" + elif ago < datetime.timedelta(minutes=1): + relative = ago_text(int(ago.total_seconds()), "second", "seconds") + elif ago < datetime.timedelta(hours=1): + relative = ago_text(int(ago.total_seconds()) // 60, "minute", "minutes") + elif ago < datetime.timedelta(days=1): + relative = ago_text(int(ago.total_seconds()) // 3600, "hour", "hours") - relative_part = " (%s)" % relative if relative is not None else "" - return str(dt) + relative_part + relative_part = " (%s)" % relative if relative is not None else "" + return str(dt) + relative_part def _ngettext(n, singular, plural): - return "%d %s ago" % (n, singular if n == 1 else plural) + return "%d %s ago" % (n, singular if n == 1 else plural) diff --git a/tensorboard/uploader/util_test.py b/tensorboard/uploader/util_test.py index 444797e7da..a1f38a357c 100644 --- a/tensorboard/uploader/util_test.py +++ b/tensorboard/uploader/util_test.py @@ -25,10 +25,10 @@ try: - # python version >= 3.3 - from unittest import mock + # python version 
>= 3.3 + from unittest import mock except ImportError: - import mock # pylint: disable=unused-import + import mock # pylint: disable=unused-import from google.protobuf import timestamp_pb2 @@ -38,209 +38,215 @@ class RateLimiterTest(tb_test.TestCase): - - def test_rate_limiting(self): - rate_limiter = util.RateLimiter(10) - fake_time = test_util.FakeTime(current=1000) - with mock.patch.object(rate_limiter, "_time", fake_time): - self.assertEqual(1000, fake_time.time()) - # No sleeping for initial tick. - rate_limiter.tick() - self.assertEqual(1000, fake_time.time()) - # Second tick requires a full sleep. - rate_limiter.tick() - self.assertEqual(1010, fake_time.time()) - # Third tick requires a sleep just to make up the remaining second. - fake_time.sleep(9) - self.assertEqual(1019, fake_time.time()) - rate_limiter.tick() - self.assertEqual(1020, fake_time.time()) - # Fourth tick requires no sleep since we have no remaining seconds. - fake_time.sleep(11) - self.assertEqual(1031, fake_time.time()) - rate_limiter.tick() - self.assertEqual(1031, fake_time.time()) + def test_rate_limiting(self): + rate_limiter = util.RateLimiter(10) + fake_time = test_util.FakeTime(current=1000) + with mock.patch.object(rate_limiter, "_time", fake_time): + self.assertEqual(1000, fake_time.time()) + # No sleeping for initial tick. + rate_limiter.tick() + self.assertEqual(1000, fake_time.time()) + # Second tick requires a full sleep. + rate_limiter.tick() + self.assertEqual(1010, fake_time.time()) + # Third tick requires a sleep just to make up the remaining second. + fake_time.sleep(9) + self.assertEqual(1019, fake_time.time()) + rate_limiter.tick() + self.assertEqual(1020, fake_time.time()) + # Fourth tick requires no sleep since we have no remaining seconds. 
+ fake_time.sleep(11) + self.assertEqual(1031, fake_time.time()) + rate_limiter.tick() + self.assertEqual(1031, fake_time.time()) class GetUserConfigDirectoryTest(tb_test.TestCase): - - def test_windows(self): - with mock.patch.object(os, "name", "nt"): - with mock.patch.dict(os.environ, { - "LOCALAPPDATA": "C:\\Users\\Alice\\AppData\\Local", - "APPDATA": "C:\\Users\\Alice\\AppData\\Roaming", - }): - self.assertEqual( - "C:\\Users\\Alice\\AppData\\Local", - util.get_user_config_directory()) - with mock.patch.dict(os.environ, { - "LOCALAPPDATA": "", - "APPDATA": "C:\\Users\\Alice\\AppData\\Roaming", - }): - self.assertEqual( - "C:\\Users\\Alice\\AppData\\Roaming", - util.get_user_config_directory()) - with mock.patch.dict(os.environ, { - "LOCALAPPDATA": "", - "APPDATA": "", - }): - self.assertIsNone(util.get_user_config_directory()) - - def test_non_windows(self): - with mock.patch.dict(os.environ, {"HOME": "/home/alice"}): - self.assertEqual( - "/home/alice%s.config" % os.sep, util.get_user_config_directory()) - with mock.patch.dict( - os.environ, {"XDG_CONFIG_HOME": "/home/alice/configz"}): - self.assertEqual( - "/home/alice/configz", util.get_user_config_directory()) + def test_windows(self): + with mock.patch.object(os, "name", "nt"): + with mock.patch.dict( + os.environ, + { + "LOCALAPPDATA": "C:\\Users\\Alice\\AppData\\Local", + "APPDATA": "C:\\Users\\Alice\\AppData\\Roaming", + }, + ): + self.assertEqual( + "C:\\Users\\Alice\\AppData\\Local", + util.get_user_config_directory(), + ) + with mock.patch.dict( + os.environ, + { + "LOCALAPPDATA": "", + "APPDATA": "C:\\Users\\Alice\\AppData\\Roaming", + }, + ): + self.assertEqual( + "C:\\Users\\Alice\\AppData\\Roaming", + util.get_user_config_directory(), + ) + with mock.patch.dict( + os.environ, {"LOCALAPPDATA": "", "APPDATA": "",} + ): + self.assertIsNone(util.get_user_config_directory()) + + def test_non_windows(self): + with mock.patch.dict(os.environ, {"HOME": "/home/alice"}): + self.assertEqual( + 
"/home/alice%s.config" % os.sep, + util.get_user_config_directory(), + ) + with mock.patch.dict( + os.environ, {"XDG_CONFIG_HOME": "/home/alice/configz"} + ): + self.assertEqual( + "/home/alice/configz", util.get_user_config_directory() + ) skip_if_windows = unittest.skipIf(os.name == "nt", "Unsupported on Windows") class MakeFileWithDirectoriesTest(tb_test.TestCase): - - def test_windows_private(self): - with mock.patch.object(os, "name", "nt"): - with self.assertRaisesRegex(RuntimeError, "Windows"): - util.make_file_with_directories("/tmp/foo", private=True) - - def test_existing_file(self): - root = self.get_temp_dir() - path = os.path.join(root, "foo", "bar", "qux.txt") - os.makedirs(os.path.dirname(path)) - with open(path, mode="w") as f: - f.write("foobar") - util.make_file_with_directories(path) - with open(path, mode="r") as f: - self.assertEqual("foobar", f.read()) - - def test_existing_dir(self): - root = self.get_temp_dir() - path = os.path.join(root, "foo", "bar", "qux.txt") - os.makedirs(os.path.dirname(path)) - util.make_file_with_directories(path) - self.assertEqual(0, os.path.getsize(path)) - - def test_nonexistent_leaf_dir(self): - root = self.get_temp_dir() - path = os.path.join(root, "foo", "bar", "qux.txt") - os.makedirs(os.path.dirname(os.path.dirname(path))) - util.make_file_with_directories(path) - self.assertEqual(0, os.path.getsize(path)) - - def test_nonexistent_multiple_dirs(self): - root = self.get_temp_dir() - path = os.path.join(root, "foo", "bar", "qux.txt") - util.make_file_with_directories(path) - self.assertEqual(0, os.path.getsize(path)) - - def assertMode(self, mode, path): - self.assertEqual(mode, os.stat(path).st_mode & 0o777) - - @skip_if_windows - def test_private_existing_file(self): - root = self.get_temp_dir() - path = os.path.join(root, "foo", "bar", "qux.txt") - os.makedirs(os.path.dirname(path)) - with open(path, mode="w") as f: - f.write("foobar") - os.chmod(os.path.dirname(path), 0o777) - os.chmod(path, 0o666) - 
util.make_file_with_directories(path, private=True) - self.assertMode(0o700, os.path.dirname(path)) - self.assertMode(0o600, path) - with open(path, mode="r") as f: - self.assertEqual("foobar", f.read()) - - @skip_if_windows - def test_private_existing_dir(self): - root = self.get_temp_dir() - path = os.path.join(root, "foo", "bar", "qux.txt") - os.makedirs(os.path.dirname(path)) - os.chmod(os.path.dirname(path), 0o777) - util.make_file_with_directories(path, private=True) - self.assertMode(0o700, os.path.dirname(path)) - self.assertMode(0o600, path) - self.assertEqual(0, os.path.getsize(path)) - - @skip_if_windows - def test_private_nonexistent_leaf_dir(self): - root = self.get_temp_dir() - path = os.path.join(root, "foo", "bar", "qux.txt") - os.makedirs(os.path.dirname(os.path.dirname(path))) - util.make_file_with_directories(path, private=True) - self.assertMode(0o700, os.path.dirname(path)) - self.assertMode(0o600, path) - self.assertEqual(0, os.path.getsize(path)) - - @skip_if_windows - def test_private_nonexistent_multiple_dirs(self): - root = self.get_temp_dir() - path = os.path.join(root, "foo", "bar", "qux.txt") - util.make_file_with_directories(path, private=True) - self.assertMode(0o700, os.path.dirname(path)) - self.assertMode(0o600, path) - self.assertEqual(0, os.path.getsize(path)) + def test_windows_private(self): + with mock.patch.object(os, "name", "nt"): + with self.assertRaisesRegex(RuntimeError, "Windows"): + util.make_file_with_directories("/tmp/foo", private=True) + + def test_existing_file(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + os.makedirs(os.path.dirname(path)) + with open(path, mode="w") as f: + f.write("foobar") + util.make_file_with_directories(path) + with open(path, mode="r") as f: + self.assertEqual("foobar", f.read()) + + def test_existing_dir(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + os.makedirs(os.path.dirname(path)) + 
util.make_file_with_directories(path) + self.assertEqual(0, os.path.getsize(path)) + + def test_nonexistent_leaf_dir(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + os.makedirs(os.path.dirname(os.path.dirname(path))) + util.make_file_with_directories(path) + self.assertEqual(0, os.path.getsize(path)) + + def test_nonexistent_multiple_dirs(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + util.make_file_with_directories(path) + self.assertEqual(0, os.path.getsize(path)) + + def assertMode(self, mode, path): + self.assertEqual(mode, os.stat(path).st_mode & 0o777) + + @skip_if_windows + def test_private_existing_file(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + os.makedirs(os.path.dirname(path)) + with open(path, mode="w") as f: + f.write("foobar") + os.chmod(os.path.dirname(path), 0o777) + os.chmod(path, 0o666) + util.make_file_with_directories(path, private=True) + self.assertMode(0o700, os.path.dirname(path)) + self.assertMode(0o600, path) + with open(path, mode="r") as f: + self.assertEqual("foobar", f.read()) + + @skip_if_windows + def test_private_existing_dir(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + os.makedirs(os.path.dirname(path)) + os.chmod(os.path.dirname(path), 0o777) + util.make_file_with_directories(path, private=True) + self.assertMode(0o700, os.path.dirname(path)) + self.assertMode(0o600, path) + self.assertEqual(0, os.path.getsize(path)) + + @skip_if_windows + def test_private_nonexistent_leaf_dir(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + os.makedirs(os.path.dirname(os.path.dirname(path))) + util.make_file_with_directories(path, private=True) + self.assertMode(0o700, os.path.dirname(path)) + self.assertMode(0o600, path) + self.assertEqual(0, os.path.getsize(path)) + + @skip_if_windows + def 
test_private_nonexistent_multiple_dirs(self): + root = self.get_temp_dir() + path = os.path.join(root, "foo", "bar", "qux.txt") + util.make_file_with_directories(path, private=True) + self.assertMode(0o700, os.path.dirname(path)) + self.assertMode(0o600, path) + self.assertEqual(0, os.path.getsize(path)) class SetTimestampTest(tb_test.TestCase): - - def test_set_timestamp(self): - pb = timestamp_pb2.Timestamp() - t = 1234567890.007812500 - # Note that just multiplying by 1e9 would lose precision: - self.assertEqual(int(t * 1e9) % int(1e9), 7812608) - util.set_timestamp(pb, t) - self.assertEqual(pb.seconds, 1234567890) - self.assertEqual(pb.nanos, 7812500) + def test_set_timestamp(self): + pb = timestamp_pb2.Timestamp() + t = 1234567890.007812500 + # Note that just multiplying by 1e9 would lose precision: + self.assertEqual(int(t * 1e9) % int(1e9), 7812608) + util.set_timestamp(pb, t) + self.assertEqual(pb.seconds, 1234567890) + self.assertEqual(pb.nanos, 7812500) class FormatTimeTest(tb_test.TestCase): - - def _run(self, t=None, now=None): - timestamp_pb = timestamp_pb2.Timestamp() - util.set_timestamp(timestamp_pb, t) - with mock.patch.dict(os.environ, {"TZ": "UTC"}): - now = datetime.datetime.fromtimestamp(now) - return util.format_time(timestamp_pb, now=now) - - def test_just_now(self): - base = 1546398245 - actual = self._run(t=base, now=base + 1) - self.assertEqual(actual, "2019-01-02 03:04:05 (just now)") - - def test_seconds_ago(self): - base = 1546398245 - actual = self._run(t=base, now=base + 10) - self.assertEqual(actual, "2019-01-02 03:04:05 (10 seconds ago)") - - def test_minute_ago(self): - base = 1546398245 - actual = self._run(t=base, now=base + 66) - self.assertEqual(actual, "2019-01-02 03:04:05 (1 minute ago)") - - def test_minutes_ago(self): - base = 1546398245 - actual = self._run(t=base, now=base + 222) - self.assertEqual(actual, "2019-01-02 03:04:05 (3 minutes ago)") - - def test_hour_ago(self): - base = 1546398245 - actual = self._run(t=base, 
now=base + 3601) - self.assertEqual(actual, "2019-01-02 03:04:05 (1 hour ago)") - - def test_hours_ago(self): - base = 1546398245 - actual = self._run(t=base, now=base + 9999) - self.assertEqual(actual, "2019-01-02 03:04:05 (2 hours ago)") - - def test_long_ago(self): - base = 1546398245 - actual = self._run(t=base, now=base + 7 * 86400) - self.assertEqual(actual, "2019-01-02 03:04:05") + def _run(self, t=None, now=None): + timestamp_pb = timestamp_pb2.Timestamp() + util.set_timestamp(timestamp_pb, t) + with mock.patch.dict(os.environ, {"TZ": "UTC"}): + now = datetime.datetime.fromtimestamp(now) + return util.format_time(timestamp_pb, now=now) + + def test_just_now(self): + base = 1546398245 + actual = self._run(t=base, now=base + 1) + self.assertEqual(actual, "2019-01-02 03:04:05 (just now)") + + def test_seconds_ago(self): + base = 1546398245 + actual = self._run(t=base, now=base + 10) + self.assertEqual(actual, "2019-01-02 03:04:05 (10 seconds ago)") + + def test_minute_ago(self): + base = 1546398245 + actual = self._run(t=base, now=base + 66) + self.assertEqual(actual, "2019-01-02 03:04:05 (1 minute ago)") + + def test_minutes_ago(self): + base = 1546398245 + actual = self._run(t=base, now=base + 222) + self.assertEqual(actual, "2019-01-02 03:04:05 (3 minutes ago)") + + def test_hour_ago(self): + base = 1546398245 + actual = self._run(t=base, now=base + 3601) + self.assertEqual(actual, "2019-01-02 03:04:05 (1 hour ago)") + + def test_hours_ago(self): + base = 1546398245 + actual = self._run(t=base, now=base + 9999) + self.assertEqual(actual, "2019-01-02 03:04:05 (2 hours ago)") + + def test_long_ago(self): + base = 1546398245 + actual = self._run(t=base, now=base + 7 * 86400) + self.assertEqual(actual, "2019-01-02 03:04:05") if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/util/argparse_util.py b/tensorboard/util/argparse_util.py index 27d08dd1d5..5bb26e84c1 100644 --- a/tensorboard/util/argparse_util.py +++ 
b/tensorboard/util/argparse_util.py @@ -25,41 +25,41 @@ @contextlib.contextmanager def allow_missing_subcommand(): - """Make Python 2.7 behave like Python 3 w.r.t. default subcommands. + """Make Python 2.7 behave like Python 3 w.r.t. default subcommands. - The behavior of argparse was changed [1] [2] in Python 3.3. When a - parser defines subcommands, it used to be an error for the user to - invoke the binary without specifying a subcommand. As of Python 3.3, - this is permitted. This monkey patch backports the new behavior to - earlier versions of Python. + The behavior of argparse was changed [1] [2] in Python 3.3. When a + parser defines subcommands, it used to be an error for the user to + invoke the binary without specifying a subcommand. As of Python 3.3, + this is permitted. This monkey patch backports the new behavior to + earlier versions of Python. - This context manager need only be used around `parse_args`; parsers - may be constructed and configured outside of the context manager. + This context manager need only be used around `parse_args`; parsers + may be constructed and configured outside of the context manager. - [1]: https://github.com/python/cpython/commit/f97c59aaba2d93e48cbc6d25f7ff9f9c87f8d0b2 - [2]: https://bugs.python.org/issue16308 - """ + [1]: https://github.com/python/cpython/commit/f97c59aaba2d93e48cbc6d25f7ff9f9c87f8d0b2 + [2]: https://bugs.python.org/issue16308 + """ - real_error = argparse.ArgumentParser.error + real_error = argparse.ArgumentParser.error - # This must exactly match the error message raised by Python 2.7's - # `argparse` when no subparser is given. This is `argparse.py:1954` at - # Git tag `v2.7.16`. - ignored_message = gettext.gettext("too few arguments") + # This must exactly match the error message raised by Python 2.7's + # `argparse` when no subparser is given. This is `argparse.py:1954` at + # Git tag `v2.7.16`. 
+ ignored_message = gettext.gettext("too few arguments") - def error(*args, **kwargs): - # Expected signature is `error(self, message)`, but we retain more - # flexibility to be forward-compatible with implementation changes. - if "message" not in kwargs and len(args) < 2: - return real_error(*args, **kwargs) - message = kwargs["message"] if "message" in kwargs else args[1] - if message == ignored_message: - return None - else: - return real_error(*args, **kwargs) + def error(*args, **kwargs): + # Expected signature is `error(self, message)`, but we retain more + # flexibility to be forward-compatible with implementation changes. + if "message" not in kwargs and len(args) < 2: + return real_error(*args, **kwargs) + message = kwargs["message"] if "message" in kwargs else args[1] + if message == ignored_message: + return None + else: + return real_error(*args, **kwargs) - argparse.ArgumentParser.error = error - try: - yield - finally: - argparse.ArgumentParser.error = real_error + argparse.ArgumentParser.error = error + try: + yield + finally: + argparse.ArgumentParser.error = real_error diff --git a/tensorboard/util/argparse_util_test.py b/tensorboard/util/argparse_util_test.py index 4e18d4a027..7279da1918 100644 --- a/tensorboard/util/argparse_util_test.py +++ b/tensorboard/util/argparse_util_test.py @@ -25,32 +25,31 @@ class AllowMissingSubcommandTest(tb_test.TestCase): - - def test_allows_missing_subcommands(self): - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers() - subparser = subparsers.add_parser("magic") - subparser.set_defaults(chosen="magic") - with argparse_util.allow_missing_subcommand(): - args = parser.parse_args([]) - self.assertEqual(args, argparse.Namespace()) - - def test_allows_provided_subcommands(self): - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers() - subparser = subparsers.add_parser("magic") - subparser.set_defaults(chosen="magic") - with argparse_util.allow_missing_subcommand(): - args = 
parser.parse_args(["magic"]) - self.assertEqual(args, argparse.Namespace(chosen="magic")) - - def test_still_complains_on_missing_arguments(self): - parser = argparse.ArgumentParser() - parser.add_argument("please_provide_me") - with argparse_util.allow_missing_subcommand(): - with self.assertRaises(SystemExit): - parser.parse_args([]) + def test_allows_missing_subcommands(self): + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + subparser = subparsers.add_parser("magic") + subparser.set_defaults(chosen="magic") + with argparse_util.allow_missing_subcommand(): + args = parser.parse_args([]) + self.assertEqual(args, argparse.Namespace()) + + def test_allows_provided_subcommands(self): + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + subparser = subparsers.add_parser("magic") + subparser.set_defaults(chosen="magic") + with argparse_util.allow_missing_subcommand(): + args = parser.parse_args(["magic"]) + self.assertEqual(args, argparse.Namespace(chosen="magic")) + + def test_still_complains_on_missing_arguments(self): + parser = argparse.ArgumentParser() + parser.add_argument("please_provide_me") + with argparse_util.allow_missing_subcommand(): + with self.assertRaises(SystemExit): + parser.parse_args([]) if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/util/encoder.py b/tensorboard/util/encoder.py index 130d1632d4..072331a6dc 100644 --- a/tensorboard/util/encoder.py +++ b/tensorboard/util/encoder.py @@ -27,83 +27,94 @@ class _TensorFlowPngEncoder(op_evaluator.PersistentOpEvaluator): - """Encode an image to PNG. + """Encode an image to PNG. - This function is thread-safe, and has high performance when run in - parallel. See `encode_png_benchmark.py` for details. + This function is thread-safe, and has high performance when run in + parallel. See `encode_png_benchmark.py` for details. 
- Arguments: - image: A numpy array of shape `[height, width, channels]`, where - `channels` is 1, 3, or 4, and of dtype uint8. + Arguments: + image: A numpy array of shape `[height, width, channels]`, where + `channels` is 1, 3, or 4, and of dtype uint8. - Returns: - A bytestring with PNG-encoded data. - """ + Returns: + A bytestring with PNG-encoded data. + """ - def __init__(self): - super(_TensorFlowPngEncoder, self).__init__() - self._image_placeholder = None - self._encode_op = None + def __init__(self): + super(_TensorFlowPngEncoder, self).__init__() + self._image_placeholder = None + self._encode_op = None - def initialize_graph(self): - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - self._image_placeholder = tf.placeholder( - dtype=tf.uint8, name='image_to_encode') - self._encode_op = tf.image.encode_png(self._image_placeholder) + def initialize_graph(self): + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. 
+ import tensorflow.compat.v1 as tf - def run(self, image): # pylint: disable=arguments-differ - if not isinstance(image, np.ndarray): - raise ValueError("'image' must be a numpy array: %r" % image) - if image.dtype != np.uint8: - raise ValueError("'image' dtype must be uint8, but is %r" % image.dtype) - return self._encode_op.eval(feed_dict={self._image_placeholder: image}) + self._image_placeholder = tf.placeholder( + dtype=tf.uint8, name="image_to_encode" + ) + self._encode_op = tf.image.encode_png(self._image_placeholder) + + def run(self, image): # pylint: disable=arguments-differ + if not isinstance(image, np.ndarray): + raise ValueError("'image' must be a numpy array: %r" % image) + if image.dtype != np.uint8: + raise ValueError( + "'image' dtype must be uint8, but is %r" % image.dtype + ) + return self._encode_op.eval(feed_dict={self._image_placeholder: image}) encode_png = _TensorFlowPngEncoder() class _TensorFlowWavEncoder(op_evaluator.PersistentOpEvaluator): - """Encode an audio clip to WAV. - - This function is thread-safe and exhibits good parallel performance. - - Arguments: - audio: A numpy array of shape `[samples, channels]`. - samples_per_second: A positive `int`, in Hz. - - Returns: - A bytestring with WAV-encoded data. - """ - - def __init__(self): - super(_TensorFlowWavEncoder, self).__init__() - self._audio_placeholder = None - self._samples_per_second_placeholder = None - self._encode_op = None - - def initialize_graph(self): - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. 
- import tensorflow.compat.v1 as tf - self._audio_placeholder = tf.placeholder( - dtype=tf.float32, name='image_to_encode') - self._samples_per_second_placeholder = tf.placeholder( - dtype=tf.int32, name='samples_per_second') - self._encode_op = tf.audio.encode_wav( - self._audio_placeholder, - sample_rate=self._samples_per_second_placeholder) - - def run(self, audio, samples_per_second): # pylint: disable=arguments-differ - if not isinstance(audio, np.ndarray): - raise ValueError("'audio' must be a numpy array: %r" % audio) - if not isinstance(samples_per_second, int): - raise ValueError("'samples_per_second' must be an int: %r" - % samples_per_second) - feed_dict = { - self._audio_placeholder: audio, - self._samples_per_second_placeholder: samples_per_second, - } - return self._encode_op.eval(feed_dict=feed_dict) + """Encode an audio clip to WAV. + + This function is thread-safe and exhibits good parallel performance. + + Arguments: + audio: A numpy array of shape `[samples, channels]`. + samples_per_second: A positive `int`, in Hz. + + Returns: + A bytestring with WAV-encoded data. + """ + + def __init__(self): + super(_TensorFlowWavEncoder, self).__init__() + self._audio_placeholder = None + self._samples_per_second_placeholder = None + self._encode_op = None + + def initialize_graph(self): + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. 
+ import tensorflow.compat.v1 as tf + + self._audio_placeholder = tf.placeholder( + dtype=tf.float32, name="image_to_encode" + ) + self._samples_per_second_placeholder = tf.placeholder( + dtype=tf.int32, name="samples_per_second" + ) + self._encode_op = tf.audio.encode_wav( + self._audio_placeholder, + sample_rate=self._samples_per_second_placeholder, + ) + + def run( + self, audio, samples_per_second + ): # pylint: disable=arguments-differ + if not isinstance(audio, np.ndarray): + raise ValueError("'audio' must be a numpy array: %r" % audio) + if not isinstance(samples_per_second, int): + raise ValueError( + "'samples_per_second' must be an int: %r" % samples_per_second + ) + feed_dict = { + self._audio_placeholder: audio, + self._samples_per_second_placeholder: samples_per_second, + } + return self._encode_op.eval(feed_dict=feed_dict) encode_wav = _TensorFlowWavEncoder() diff --git a/tensorboard/util/encoder_test.py b/tensorboard/util/encoder_test.py index c58916e196..ecaf44224e 100644 --- a/tensorboard/util/encoder_test.py +++ b/tensorboard/util/encoder_test.py @@ -24,59 +24,59 @@ class TensorFlowPngEncoderTest(tf.test.TestCase): - - def setUp(self): - super(TensorFlowPngEncoderTest, self).setUp() - self._encode = encoder._TensorFlowPngEncoder() - self._rgb = np.arange(12 * 34 * 3).reshape((12, 34, 3)).astype(np.uint8) - self._rgba = np.arange(21 * 43 * 4).reshape((21, 43, 4)).astype(np.uint8) - - def _check_png(self, data): - # If it has a valid PNG header and is of a reasonable size, we can - # assume it did the right thing. We trust the underlying - # `encode_png` op. 
- self.assertEqual(b'\x89PNG', data[:4]) - self.assertGreater(len(data), 128) - - def test_invalid_non_numpy(self): - with six.assertRaisesRegex(self, ValueError, "must be a numpy array"): - self._encode(self._rgb.tolist()) - - def test_invalid_non_uint8(self): - with six.assertRaisesRegex(self, ValueError, "dtype must be uint8"): - self._encode(self._rgb.astype(np.float32)) - - def test_encodes_png(self): - data = self._encode(self._rgb) - self._check_png(data) - - def test_encodes_png_with_alpha(self): - data = self._encode(self._rgba) - self._check_png(data) + def setUp(self): + super(TensorFlowPngEncoderTest, self).setUp() + self._encode = encoder._TensorFlowPngEncoder() + self._rgb = np.arange(12 * 34 * 3).reshape((12, 34, 3)).astype(np.uint8) + self._rgba = ( + np.arange(21 * 43 * 4).reshape((21, 43, 4)).astype(np.uint8) + ) + + def _check_png(self, data): + # If it has a valid PNG header and is of a reasonable size, we can + # assume it did the right thing. We trust the underlying + # `encode_png` op. 
+ self.assertEqual(b"\x89PNG", data[:4]) + self.assertGreater(len(data), 128) + + def test_invalid_non_numpy(self): + with six.assertRaisesRegex(self, ValueError, "must be a numpy array"): + self._encode(self._rgb.tolist()) + + def test_invalid_non_uint8(self): + with six.assertRaisesRegex(self, ValueError, "dtype must be uint8"): + self._encode(self._rgb.astype(np.float32)) + + def test_encodes_png(self): + data = self._encode(self._rgb) + self._check_png(data) + + def test_encodes_png_with_alpha(self): + data = self._encode(self._rgba) + self._check_png(data) class TensorFlowWavEncoderTest(tf.test.TestCase): + def setUp(self): + super(TensorFlowWavEncoderTest, self).setUp() + self._encode = encoder._TensorFlowWavEncoder() + space = np.linspace(0.0, 100.0, 44100) + self._stereo = np.array([np.sin(space), np.cos(space)]).transpose() + self._mono = self._stereo.mean(axis=1, keepdims=True) - def setUp(self): - super(TensorFlowWavEncoderTest, self).setUp() - self._encode = encoder._TensorFlowWavEncoder() - space = np.linspace(0.0, 100.0, 44100) - self._stereo = np.array([np.sin(space), np.cos(space)]).transpose() - self._mono = self._stereo.mean(axis=1, keepdims=True) - - def _check_wav(self, data): - # If it has a valid WAV/RIFF header and is of a reasonable size, we - # can assume it did the right thing. We trust the underlying - # `encode_audio` op. - self.assertEqual(b'RIFF', data[:4]) - self.assertGreater(len(data), 128) + def _check_wav(self, data): + # If it has a valid WAV/RIFF header and is of a reasonable size, we + # can assume it did the right thing. We trust the underlying + # `encode_audio` op. 
+ self.assertEqual(b"RIFF", data[:4]) + self.assertGreater(len(data), 128) - def test_encodes_mono_wav(self): - self._check_wav(self._encode(self._mono, samples_per_second=44100)) + def test_encodes_mono_wav(self): + self._check_wav(self._encode(self._mono, samples_per_second=44100)) - def test_encodes_stereo_wav(self): - self._check_wav(self._encode(self._stereo, samples_per_second=44100)) + def test_encodes_stereo_wav(self): + self._check_wav(self._encode(self._stereo, samples_per_second=44100)) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/util/grpc_util.py b/tensorboard/util/grpc_util.py index 43c551dfeb..7c22e4bfb0 100644 --- a/tensorboard/util/grpc_util.py +++ b/tensorboard/util/grpc_util.py @@ -40,87 +40,96 @@ _GRPC_RETRY_JITTER_FACTOR_MAX = 1.5 # Status codes from gRPC for which it's reasonable to retry the RPC. -_GRPC_RETRYABLE_STATUS_CODES = frozenset([ - grpc.StatusCode.ABORTED, - grpc.StatusCode.DEADLINE_EXCEEDED, - grpc.StatusCode.RESOURCE_EXHAUSTED, - grpc.StatusCode.UNAVAILABLE, -]) +_GRPC_RETRYABLE_STATUS_CODES = frozenset( + [ + grpc.StatusCode.ABORTED, + grpc.StatusCode.DEADLINE_EXCEEDED, + grpc.StatusCode.RESOURCE_EXHAUSTED, + grpc.StatusCode.UNAVAILABLE, + ] +) # gRPC metadata key whose value contains the client version. _VERSION_METADATA_KEY = "tensorboard-version" def call_with_retries(api_method, request, clock=None): - """Call a gRPC stub API method, with automatic retry logic. - - This only supports unary-unary RPCs: i.e., no streaming on either end. - Streamed RPCs will generally need application-level pagination support, - because after a gRPC error one must retry the entire request; there is no - "retry-resume" functionality. - - Args: - api_method: Callable for the API method to invoke. - request: Request protocol buffer to pass to the API method. 
- clock: an interface object supporting `time()` and `sleep()` methods - like the standard `time` module; if not passed, uses the normal module. - - Returns: - Response protocol buffer returned by the API method. - - Raises: - grpc.RpcError: if a non-retryable error is returned, or if all retry - attempts have been exhausted. - """ - if clock is None: - clock = time - # We can't actually use api_method.__name__ because it's not a real method, - # it's a special gRPC callable instance that doesn't expose the method name. - rpc_name = request.__class__.__name__.replace("Request", "") - logger.debug("RPC call %s with request: %r", rpc_name, request) - num_attempts = 0 - while True: - num_attempts += 1 - try: - return api_method( - request, - timeout=_GRPC_DEFAULT_TIMEOUT_SECS, - metadata=version_metadata()) - except grpc.RpcError as e: - logger.info("RPC call %s got error %s", rpc_name, e) - if e.code() not in _GRPC_RETRYABLE_STATUS_CODES: - raise - if num_attempts >= _GRPC_RETRY_MAX_ATTEMPTS: - raise - jitter_factor = random.uniform( - _GRPC_RETRY_JITTER_FACTOR_MIN, _GRPC_RETRY_JITTER_FACTOR_MAX) - backoff_secs = (_GRPC_RETRY_EXPONENTIAL_BASE**num_attempts) * jitter_factor - logger.info( - "RPC call %s attempted %d times, retrying in %.1f seconds", - rpc_name, num_attempts, backoff_secs) - clock.sleep(backoff_secs) + """Call a gRPC stub API method, with automatic retry logic. + + This only supports unary-unary RPCs: i.e., no streaming on either end. + Streamed RPCs will generally need application-level pagination support, + because after a gRPC error one must retry the entire request; there is no + "retry-resume" functionality. + + Args: + api_method: Callable for the API method to invoke. + request: Request protocol buffer to pass to the API method. + clock: an interface object supporting `time()` and `sleep()` methods + like the standard `time` module; if not passed, uses the normal module. + + Returns: + Response protocol buffer returned by the API method. 
+ + Raises: + grpc.RpcError: if a non-retryable error is returned, or if all retry + attempts have been exhausted. + """ + if clock is None: + clock = time + # We can't actually use api_method.__name__ because it's not a real method, + # it's a special gRPC callable instance that doesn't expose the method name. + rpc_name = request.__class__.__name__.replace("Request", "") + logger.debug("RPC call %s with request: %r", rpc_name, request) + num_attempts = 0 + while True: + num_attempts += 1 + try: + return api_method( + request, + timeout=_GRPC_DEFAULT_TIMEOUT_SECS, + metadata=version_metadata(), + ) + except grpc.RpcError as e: + logger.info("RPC call %s got error %s", rpc_name, e) + if e.code() not in _GRPC_RETRYABLE_STATUS_CODES: + raise + if num_attempts >= _GRPC_RETRY_MAX_ATTEMPTS: + raise + jitter_factor = random.uniform( + _GRPC_RETRY_JITTER_FACTOR_MIN, _GRPC_RETRY_JITTER_FACTOR_MAX + ) + backoff_secs = ( + _GRPC_RETRY_EXPONENTIAL_BASE ** num_attempts + ) * jitter_factor + logger.info( + "RPC call %s attempted %d times, retrying in %.1f seconds", + rpc_name, + num_attempts, + backoff_secs, + ) + clock.sleep(backoff_secs) def version_metadata(): - """Creates gRPC invocation metadata encoding the TensorBoard version. + """Creates gRPC invocation metadata encoding the TensorBoard version. - Usage: `stub.MyRpc(request, metadata=version_metadata())`. + Usage: `stub.MyRpc(request, metadata=version_metadata())`. - Returns: - A tuple of key-value pairs (themselves 2-tuples) to be passed as the - `metadata` kwarg to gRPC stub API methods. - """ - return ((_VERSION_METADATA_KEY, version.VERSION),) + Returns: + A tuple of key-value pairs (themselves 2-tuples) to be passed as the + `metadata` kwarg to gRPC stub API methods. + """ + return ((_VERSION_METADATA_KEY, version.VERSION),) def extract_version(metadata): - """Extracts version from invocation metadata. + """Extracts version from invocation metadata. 
- The argument should be the result of a prior call to `metadata` or the - result of combining such a result with other metadata. + The argument should be the result of a prior call to `metadata` or the + result of combining such a result with other metadata. - Returns: - The TensorBoard version listed in this metadata, or `None` if none - is listed. - """ - return dict(metadata).get(_VERSION_METADATA_KEY) + Returns: + The TensorBoard version listed in this metadata, or `None` if none + is listed. + """ + return dict(metadata).get(_VERSION_METADATA_KEY) diff --git a/tensorboard/util/grpc_util_test.py b/tensorboard/util/grpc_util_test.py index f35ad1859f..28897967ec 100644 --- a/tensorboard/util/grpc_util_test.py +++ b/tensorboard/util/grpc_util_test.py @@ -35,130 +35,151 @@ def make_request(nonce): - return grpc_util_test_pb2.TestRpcRequest(nonce=nonce) + return grpc_util_test_pb2.TestRpcRequest(nonce=nonce) def make_response(nonce): - return grpc_util_test_pb2.TestRpcResponse(nonce=nonce) + return grpc_util_test_pb2.TestRpcResponse(nonce=nonce) class TestGrpcServer(grpc_util_test_pb2_grpc.TestServiceServicer): - """Helper for testing gRPC client logic with a dummy gRPC server.""" - - def __init__(self, handler): - super(TestGrpcServer, self).__init__() - self._handler = handler - - def TestRpc(self, request, context): - return self._handler(request, context) - - @contextlib.contextmanager - def run(self): - """Context manager to run the gRPC server and yield a client for it.""" - server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) - grpc_util_test_pb2_grpc.add_TestServiceServicer_to_server(self, server) - port = server.add_secure_port( - "localhost:0", grpc.local_server_credentials()) - def launch_server(): - server.start() - server.wait_for_termination() - thread = threading.Thread(target=launch_server, name="TestGrpcServer") - thread.daemon = True - thread.start() - with grpc.secure_channel( - "localhost:%d" % port, grpc.local_channel_credentials()) 
as channel: - yield grpc_util_test_pb2_grpc.TestServiceStub(channel) - server.stop(grace=None) - thread.join() + """Helper for testing gRPC client logic with a dummy gRPC server.""" + + def __init__(self, handler): + super(TestGrpcServer, self).__init__() + self._handler = handler + + def TestRpc(self, request, context): + return self._handler(request, context) + + @contextlib.contextmanager + def run(self): + """Context manager to run the gRPC server and yield a client for it.""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) + grpc_util_test_pb2_grpc.add_TestServiceServicer_to_server(self, server) + port = server.add_secure_port( + "localhost:0", grpc.local_server_credentials() + ) + + def launch_server(): + server.start() + server.wait_for_termination() + + thread = threading.Thread(target=launch_server, name="TestGrpcServer") + thread.daemon = True + thread.start() + with grpc.secure_channel( + "localhost:%d" % port, grpc.local_channel_credentials() + ) as channel: + yield grpc_util_test_pb2_grpc.TestServiceStub(channel) + server.stop(grace=None) + thread.join() class CallWithRetriesTest(tb_test.TestCase): - - def test_call_with_retries_succeeds(self): - def handler(request, _): - return make_response(request.nonce) - server = TestGrpcServer(handler) - with server.run() as client: - response = grpc_util.call_with_retries(client.TestRpc, make_request(42)) - self.assertEqual(make_response(42), response) - - def test_call_with_retries_fails_immediately_on_permanent_error(self): - def handler(_, context): - context.abort(grpc.StatusCode.INTERNAL, "foo") - server = TestGrpcServer(handler) - with server.run() as client: - with self.assertRaises(grpc.RpcError) as raised: - grpc_util.call_with_retries(client.TestRpc, make_request(42)) - self.assertEqual(grpc.StatusCode.INTERNAL, raised.exception.code()) - self.assertEqual("foo", raised.exception.details()) - - def test_call_with_retries_fails_after_backoff_on_nonpermanent_error(self): - attempt_times 
= [] - fake_time = test_util.FakeTime() - def handler(_, context): - attempt_times.append(fake_time.time()) - context.abort(grpc.StatusCode.UNAVAILABLE, "foo") - server = TestGrpcServer(handler) - with server.run() as client: - with self.assertRaises(grpc.RpcError) as raised: - grpc_util.call_with_retries(client.TestRpc, make_request(42), fake_time) - self.assertEqual(grpc.StatusCode.UNAVAILABLE, raised.exception.code()) - self.assertEqual("foo", raised.exception.details()) - self.assertLen(attempt_times, 5) - self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) - self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) - self.assertBetween(attempt_times[3] - attempt_times[2], 8, 16) - self.assertBetween(attempt_times[4] - attempt_times[3], 16, 32) - - def test_call_with_retries_succeeds_after_backoff_on_transient_error(self): - attempt_times = [] - fake_time = test_util.FakeTime() - def handler(request, context): - attempt_times.append(fake_time.time()) - if len(attempt_times) < 3: - context.abort(grpc.StatusCode.UNAVAILABLE, "foo") - return make_response(request.nonce) - server = TestGrpcServer(handler) - with server.run() as client: - response = grpc_util.call_with_retries( - client.TestRpc, make_request(42), fake_time) - self.assertEqual(make_response(42), response) - self.assertLen(attempt_times, 3) - self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) - self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) - - def test_call_with_retries_includes_version_metadata(self): - def digest(s): - """Hashes a string into a positive 32-bit signed integer.""" - return int(hashlib.sha256(s.encode("utf-8")).hexdigest(), 16) & 0x7fffffff - def handler(request, context): - metadata = context.invocation_metadata() - client_version = grpc_util.extract_version(metadata) - return make_response(digest(client_version)) - server = TestGrpcServer(handler) - with server.run() as client: - response = grpc_util.call_with_retries(client.TestRpc, 
make_request(0)) - expected_nonce = digest( - grpc_util.extract_version(grpc_util.version_metadata())) - self.assertEqual(make_response(expected_nonce), response) + def test_call_with_retries_succeeds(self): + def handler(request, _): + return make_response(request.nonce) + + server = TestGrpcServer(handler) + with server.run() as client: + response = grpc_util.call_with_retries( + client.TestRpc, make_request(42) + ) + self.assertEqual(make_response(42), response) + + def test_call_with_retries_fails_immediately_on_permanent_error(self): + def handler(_, context): + context.abort(grpc.StatusCode.INTERNAL, "foo") + + server = TestGrpcServer(handler) + with server.run() as client: + with self.assertRaises(grpc.RpcError) as raised: + grpc_util.call_with_retries(client.TestRpc, make_request(42)) + self.assertEqual(grpc.StatusCode.INTERNAL, raised.exception.code()) + self.assertEqual("foo", raised.exception.details()) + + def test_call_with_retries_fails_after_backoff_on_nonpermanent_error(self): + attempt_times = [] + fake_time = test_util.FakeTime() + + def handler(_, context): + attempt_times.append(fake_time.time()) + context.abort(grpc.StatusCode.UNAVAILABLE, "foo") + + server = TestGrpcServer(handler) + with server.run() as client: + with self.assertRaises(grpc.RpcError) as raised: + grpc_util.call_with_retries( + client.TestRpc, make_request(42), fake_time + ) + self.assertEqual(grpc.StatusCode.UNAVAILABLE, raised.exception.code()) + self.assertEqual("foo", raised.exception.details()) + self.assertLen(attempt_times, 5) + self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) + self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) + self.assertBetween(attempt_times[3] - attempt_times[2], 8, 16) + self.assertBetween(attempt_times[4] - attempt_times[3], 16, 32) + + def test_call_with_retries_succeeds_after_backoff_on_transient_error(self): + attempt_times = [] + fake_time = test_util.FakeTime() + + def handler(request, context): + 
attempt_times.append(fake_time.time()) + if len(attempt_times) < 3: + context.abort(grpc.StatusCode.UNAVAILABLE, "foo") + return make_response(request.nonce) + + server = TestGrpcServer(handler) + with server.run() as client: + response = grpc_util.call_with_retries( + client.TestRpc, make_request(42), fake_time + ) + self.assertEqual(make_response(42), response) + self.assertLen(attempt_times, 3) + self.assertBetween(attempt_times[1] - attempt_times[0], 2, 4) + self.assertBetween(attempt_times[2] - attempt_times[1], 4, 8) + + def test_call_with_retries_includes_version_metadata(self): + def digest(s): + """Hashes a string into a positive 32-bit signed integer.""" + return ( + int(hashlib.sha256(s.encode("utf-8")).hexdigest(), 16) + & 0x7FFFFFFF + ) + + def handler(request, context): + metadata = context.invocation_metadata() + client_version = grpc_util.extract_version(metadata) + return make_response(digest(client_version)) + + server = TestGrpcServer(handler) + with server.run() as client: + response = grpc_util.call_with_retries( + client.TestRpc, make_request(0) + ) + expected_nonce = digest( + grpc_util.extract_version(grpc_util.version_metadata()) + ) + self.assertEqual(make_response(expected_nonce), response) class VersionMetadataTest(tb_test.TestCase): + def test_structure(self): + result = grpc_util.version_metadata() + self.assertIsInstance(result, tuple) + for kv in result: + self.assertIsInstance(kv, tuple) + self.assertLen(kv, 2) + (k, v) = kv + self.assertIsInstance(k, str) + self.assertIsInstance(v, six.string_types) - def test_structure(self): - result = grpc_util.version_metadata() - self.assertIsInstance(result, tuple) - for kv in result: - self.assertIsInstance(kv, tuple) - self.assertLen(kv, 2) - (k, v) = kv - self.assertIsInstance(k, str) - self.assertIsInstance(v, six.string_types) - - def test_roundtrip(self): - result = grpc_util.extract_version(grpc_util.version_metadata()) - self.assertEqual(result, version.VERSION) + def 
test_roundtrip(self): + result = grpc_util.extract_version(grpc_util.version_metadata()) + self.assertEqual(result, version.VERSION) if __name__ == "__main__": - tb_test.main() + tb_test.main() diff --git a/tensorboard/util/lazy_tensor_creator.py b/tensorboard/util/lazy_tensor_creator.py index b70b736e93..85d230adb7 100644 --- a/tensorboard/util/lazy_tensor_creator.py +++ b/tensorboard/util/lazy_tensor_creator.py @@ -25,58 +25,61 @@ class LazyTensorCreator(object): - """Lazy auto-converting wrapper for a callable that returns a `tf.Tensor`. + """Lazy auto-converting wrapper for a callable that returns a `tf.Tensor`. - This class wraps an arbitrary callable that returns a `Tensor` so that it - will be automatically converted to a `Tensor` by any logic that calls - `tf.convert_to_tensor()`. This also memoizes the callable so that it is - called at most once. + This class wraps an arbitrary callable that returns a `Tensor` so that it + will be automatically converted to a `Tensor` by any logic that calls + `tf.convert_to_tensor()`. This also memoizes the callable so that it is + called at most once. - The intended use of this class is to defer the construction of a `Tensor` - (e.g. to avoid unnecessary wasted computation, or ensure any new ops are - created in a context only available later on in execution), while remaining - compatible with APIs that expect to be given an already materialized value - that can be converted to a `Tensor`. + The intended use of this class is to defer the construction of a `Tensor` + (e.g. to avoid unnecessary wasted computation, or ensure any new ops are + created in a context only available later on in execution), while remaining + compatible with APIs that expect to be given an already materialized value + that can be converted to a `Tensor`. - This class is thread-safe. - """ - - def __init__(self, tensor_callable): - """Initializes a LazyTensorCreator object. - - Args: - tensor_callable: A callable that returns a `tf.Tensor`. 
+ This class is thread-safe. """ - if not callable(tensor_callable): - raise ValueError("Not a callable: %r" % tensor_callable) - self._tensor_callable = tensor_callable - self._tensor = None - self._tensor_lock = threading.RLock() - _register_conversion_function_once() - - def __call__(self): - if self._tensor is None or self._tensor is _CALL_IN_PROGRESS_SENTINEL: - with self._tensor_lock: - if self._tensor is _CALL_IN_PROGRESS_SENTINEL: - raise RuntimeError("Cannot use LazyTensorCreator with reentrant callable") - elif self._tensor is None: - self._tensor = _CALL_IN_PROGRESS_SENTINEL - self._tensor = self._tensor_callable() - return self._tensor + + def __init__(self, tensor_callable): + """Initializes a LazyTensorCreator object. + + Args: + tensor_callable: A callable that returns a `tf.Tensor`. + """ + if not callable(tensor_callable): + raise ValueError("Not a callable: %r" % tensor_callable) + self._tensor_callable = tensor_callable + self._tensor = None + self._tensor_lock = threading.RLock() + _register_conversion_function_once() + + def __call__(self): + if self._tensor is None or self._tensor is _CALL_IN_PROGRESS_SENTINEL: + with self._tensor_lock: + if self._tensor is _CALL_IN_PROGRESS_SENTINEL: + raise RuntimeError( + "Cannot use LazyTensorCreator with reentrant callable" + ) + elif self._tensor is None: + self._tensor = _CALL_IN_PROGRESS_SENTINEL + self._tensor = self._tensor_callable() + return self._tensor def _lazy_tensor_creator_converter(value, dtype=None, name=None, as_ref=False): - del name # ignored - if not isinstance(value, LazyTensorCreator): - raise RuntimeError("Expected LazyTensorCreator, got %r" % value) - if as_ref: - raise RuntimeError("Cannot use LazyTensorCreator to create ref tensor") - tensor = value() - if dtype not in (None, tensor.dtype): - raise RuntimeError( - "Cannot convert LazyTensorCreator returning dtype %s to dtype %s" % ( - tensor.dtype, dtype)) - return tensor + del name # ignored + if not isinstance(value, 
LazyTensorCreator): + raise RuntimeError("Expected LazyTensorCreator, got %r" % value) + if as_ref: + raise RuntimeError("Cannot use LazyTensorCreator to create ref tensor") + tensor = value() + if dtype not in (None, tensor.dtype): + raise RuntimeError( + "Cannot convert LazyTensorCreator returning dtype %s to dtype %s" + % (tensor.dtype, dtype) + ) + return tensor # Use module-level bit and lock to ensure that registration of the @@ -86,22 +89,23 @@ def _lazy_tensor_creator_converter(value, dtype=None, name=None, as_ref=False): def _register_conversion_function_once(): - """Performs one-time registration of `_lazy_tensor_creator_converter`. - - This helper can be invoked multiple times but only registers the conversion - function on the first invocation, making it suitable for calling when - constructing a LazyTensorCreator. - - Deferring the registration is necessary because doing it at at module import - time would trigger the lazy TensorFlow import to resolve, and that in turn - would break the delicate `tf.summary` import cycle avoidance scheme. - """ - global _conversion_registered - if not _conversion_registered: - with _conversion_registered_lock: - if not _conversion_registered: - _conversion_registered = True - tf.register_tensor_conversion_function( - base_type=LazyTensorCreator, - conversion_func=_lazy_tensor_creator_converter, - priority=0) + """Performs one-time registration of `_lazy_tensor_creator_converter`. + + This helper can be invoked multiple times but only registers the conversion + function on the first invocation, making it suitable for calling when + constructing a LazyTensorCreator. + + Deferring the registration is necessary because doing it at at module import + time would trigger the lazy TensorFlow import to resolve, and that in turn + would break the delicate `tf.summary` import cycle avoidance scheme. 
+ """ + global _conversion_registered + if not _conversion_registered: + with _conversion_registered_lock: + if not _conversion_registered: + _conversion_registered = True + tf.register_tensor_conversion_function( + base_type=LazyTensorCreator, + conversion_func=_lazy_tensor_creator_converter, + priority=0, + ) diff --git a/tensorboard/util/lazy_tensor_creator_test.py b/tensorboard/util/lazy_tensor_creator_test.py index 2156280950..6053d40cda 100644 --- a/tensorboard/util/lazy_tensor_creator_test.py +++ b/tensorboard/util/lazy_tensor_creator_test.py @@ -25,76 +25,85 @@ class LazyTensorCreatorTest(tf.test.TestCase): - - def assertEqualAsNumpy(self, a, b): - # TODO(#2507): Remove after we no longer test against TF 1.x. - self.assertEqual(a.numpy(), b.numpy()) - - def test_lazy_creation_with_memoization(self): - boxed_count = [0] - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - boxed_count[0] = boxed_count[0] + 1 - return tf.constant(1) - self.assertEqual(0, boxed_count[0]) - real_tensor = lazy_tensor() - self.assertEqual(1, boxed_count[0]) - lazy_tensor() - self.assertEqual(1, boxed_count[0]) - - def test_conversion_explicit(self): - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - return tf.constant(1) - real_tensor = tf.convert_to_tensor(lazy_tensor) - self.assertEqualAsNumpy(tf.constant(1), real_tensor) - - def test_conversion_identity(self): - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - return tf.constant(1) - real_tensor = tf.identity(lazy_tensor) - self.assertEqualAsNumpy(tf.constant(1), real_tensor) - - def test_conversion_implicit(self): - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - return tf.constant(1) - real_tensor = lazy_tensor + tf.constant(1) - self.assertEqualAsNumpy(tf.constant(2), real_tensor) - - def test_explicit_dtype_okay_if_matches(self): - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - return tf.constant(1, dtype=tf.int32) - real_tensor = 
tf.convert_to_tensor(lazy_tensor, dtype=tf.int32) - self.assertEqual(tf.int32, real_tensor.dtype) - self.assertEqualAsNumpy(tf.constant(1, dtype=tf.int32), real_tensor) - - def test_explicit_dtype_rejected_if_different(self): - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - return tf.constant(1, dtype=tf.int32) - with self.assertRaisesRegex(RuntimeError, "dtype"): - tf.convert_to_tensor(lazy_tensor, dtype=tf.int64) - - def test_as_ref_rejected(self): - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - return tf.constant(1, dtype=tf.int32) - with self.assertRaisesRegex(RuntimeError, "ref tensor"): - # Call conversion routine manually since this isn't actually - # exposed as an argument to tf.convert_to_tensor. - lazy_tensor_creator._lazy_tensor_creator_converter( - lazy_tensor, as_ref=True) - - def test_reentrant_callable_does_not_deadlock(self): - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - return lazy_tensor() - with self.assertRaisesRegex(RuntimeError, "reentrant callable"): - lazy_tensor() - - -if __name__ == '__main__': - tf.test.main() + def assertEqualAsNumpy(self, a, b): + # TODO(#2507): Remove after we no longer test against TF 1.x. 
+ self.assertEqual(a.numpy(), b.numpy()) + + def test_lazy_creation_with_memoization(self): + boxed_count = [0] + + @lazy_tensor_creator.LazyTensorCreator + def lazy_tensor(): + boxed_count[0] = boxed_count[0] + 1 + return tf.constant(1) + + self.assertEqual(0, boxed_count[0]) + real_tensor = lazy_tensor() + self.assertEqual(1, boxed_count[0]) + lazy_tensor() + self.assertEqual(1, boxed_count[0]) + + def test_conversion_explicit(self): + @lazy_tensor_creator.LazyTensorCreator + def lazy_tensor(): + return tf.constant(1) + + real_tensor = tf.convert_to_tensor(lazy_tensor) + self.assertEqualAsNumpy(tf.constant(1), real_tensor) + + def test_conversion_identity(self): + @lazy_tensor_creator.LazyTensorCreator + def lazy_tensor(): + return tf.constant(1) + + real_tensor = tf.identity(lazy_tensor) + self.assertEqualAsNumpy(tf.constant(1), real_tensor) + + def test_conversion_implicit(self): + @lazy_tensor_creator.LazyTensorCreator + def lazy_tensor(): + return tf.constant(1) + + real_tensor = lazy_tensor + tf.constant(1) + self.assertEqualAsNumpy(tf.constant(2), real_tensor) + + def test_explicit_dtype_okay_if_matches(self): + @lazy_tensor_creator.LazyTensorCreator + def lazy_tensor(): + return tf.constant(1, dtype=tf.int32) + + real_tensor = tf.convert_to_tensor(lazy_tensor, dtype=tf.int32) + self.assertEqual(tf.int32, real_tensor.dtype) + self.assertEqualAsNumpy(tf.constant(1, dtype=tf.int32), real_tensor) + + def test_explicit_dtype_rejected_if_different(self): + @lazy_tensor_creator.LazyTensorCreator + def lazy_tensor(): + return tf.constant(1, dtype=tf.int32) + + with self.assertRaisesRegex(RuntimeError, "dtype"): + tf.convert_to_tensor(lazy_tensor, dtype=tf.int64) + + def test_as_ref_rejected(self): + @lazy_tensor_creator.LazyTensorCreator + def lazy_tensor(): + return tf.constant(1, dtype=tf.int32) + + with self.assertRaisesRegex(RuntimeError, "ref tensor"): + # Call conversion routine manually since this isn't actually + # exposed as an argument to 
tf.convert_to_tensor. + lazy_tensor_creator._lazy_tensor_creator_converter( + lazy_tensor, as_ref=True + ) + + def test_reentrant_callable_does_not_deadlock(self): + @lazy_tensor_creator.LazyTensorCreator + def lazy_tensor(): + return lazy_tensor() + + with self.assertRaisesRegex(RuntimeError, "reentrant callable"): + lazy_tensor() + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/util/op_evaluator.py b/tensorboard/util/op_evaluator.py index f82f497484..b3bd1eecdb 100644 --- a/tensorboard/util/op_evaluator.py +++ b/tensorboard/util/op_evaluator.py @@ -25,81 +25,84 @@ class PersistentOpEvaluator(object): - """Evaluate a fixed TensorFlow graph repeatedly, safely, efficiently. - - Extend this class to create a particular kind of op evaluator, like an - image encoder. In `initialize_graph`, create an appropriate TensorFlow - graph with placeholder inputs. In `run`, evaluate this graph and - return its result. This class will manage a singleton graph and - session to preserve memory usage, and will ensure that this graph and - session do not interfere with other concurrent sessions. - - A subclass of this class offers a threadsafe, highly parallel Python - entry point for evaluating a particular TensorFlow graph. - - Example usage: - - class FluxCapacitanceEvaluator(PersistentOpEvaluator): - \"\"\"Compute the flux capacitance required for a system. - - Arguments: - x: Available power input, as a `float`, in jigawatts. - - Returns: - A `float`, in nanofarads. 
- \"\"\" - - def initialize_graph(self): - self._placeholder = tf.placeholder(some_dtype) - self._op = some_op(self._placeholder) - - def run(self, x): - return self._op.eval(feed_dict: {self._placeholder: x}) - - evaluate_flux_capacitance = FluxCapacitanceEvaluator() - - for x in xs: - evaluate_flux_capacitance(x) - """ - - def __init__(self): - super(PersistentOpEvaluator, self).__init__() - self._session = None - self._initialization_lock = threading.Lock() - - def _lazily_initialize(self): - """Initialize the graph and session, if this has not yet been done.""" - # TODO(nickfelt): remove on-demand imports once dep situation is fixed. - import tensorflow.compat.v1 as tf - with self._initialization_lock: - if self._session: - return - graph = tf.Graph() - with graph.as_default(): - self.initialize_graph() - # Don't reserve GPU because libpng can't run on GPU. - config = tf.ConfigProto(device_count={'GPU': 0}) - self._session = tf.Session(graph=graph, config=config) - - def initialize_graph(self): - """Create the TensorFlow graph needed to compute this operation. - - This should write ops to the default graph and return `None`. - """ - raise NotImplementedError('Subclasses must implement "initialize_graph".') + """Evaluate a fixed TensorFlow graph repeatedly, safely, efficiently. + + Extend this class to create a particular kind of op evaluator, like an + image encoder. In `initialize_graph`, create an appropriate TensorFlow + graph with placeholder inputs. In `run`, evaluate this graph and + return its result. This class will manage a singleton graph and + session to preserve memory usage, and will ensure that this graph and + session do not interfere with other concurrent sessions. + + A subclass of this class offers a threadsafe, highly parallel Python + entry point for evaluating a particular TensorFlow graph. + + Example usage: + + class FluxCapacitanceEvaluator(PersistentOpEvaluator): + \"\"\"Compute the flux capacitance required for a system. 
+ + Arguments: + x: Available power input, as a `float`, in jigawatts. + + Returns: + A `float`, in nanofarads. + \"\"\" + + def initialize_graph(self): + self._placeholder = tf.placeholder(some_dtype) + self._op = some_op(self._placeholder) + + def run(self, x): + return self._op.eval(feed_dict: {self._placeholder: x}) - def run(self, *args, **kwargs): - """Evaluate the ops with the given input. + evaluate_flux_capacitance = FluxCapacitanceEvaluator() - When this function is called, the default session will have the - graph defined by a previous call to `initialize_graph`. This - function should evaluate any ops necessary to compute the result of - the query for the given *args and **kwargs, likely returning the - result of a call to `some_op.eval(...)`. + for x in xs: + evaluate_flux_capacitance(x) """ - raise NotImplementedError('Subclasses must implement "run".') - def __call__(self, *args, **kwargs): - self._lazily_initialize() - with self._session.as_default(): - return self.run(*args, **kwargs) + def __init__(self): + super(PersistentOpEvaluator, self).__init__() + self._session = None + self._initialization_lock = threading.Lock() + + def _lazily_initialize(self): + """Initialize the graph and session, if this has not yet been done.""" + # TODO(nickfelt): remove on-demand imports once dep situation is fixed. + import tensorflow.compat.v1 as tf + + with self._initialization_lock: + if self._session: + return + graph = tf.Graph() + with graph.as_default(): + self.initialize_graph() + # Don't reserve GPU because libpng can't run on GPU. + config = tf.ConfigProto(device_count={"GPU": 0}) + self._session = tf.Session(graph=graph, config=config) + + def initialize_graph(self): + """Create the TensorFlow graph needed to compute this operation. + + This should write ops to the default graph and return `None`. + """ + raise NotImplementedError( + 'Subclasses must implement "initialize_graph".' 
+ ) + + def run(self, *args, **kwargs): + """Evaluate the ops with the given input. + + When this function is called, the default session will have the + graph defined by a previous call to `initialize_graph`. This + function should evaluate any ops necessary to compute the result + of the query for the given *args and **kwargs, likely returning + the result of a call to `some_op.eval(...)`. + """ + raise NotImplementedError('Subclasses must implement "run".') + + def __call__(self, *args, **kwargs): + self._lazily_initialize() + with self._session.as_default(): + return self.run(*args, **kwargs) diff --git a/tensorboard/util/op_evaluator_test.py b/tensorboard/util/op_evaluator_test.py index ca30c868c8..d6ec40793a 100644 --- a/tensorboard/util/op_evaluator_test.py +++ b/tensorboard/util/op_evaluator_test.py @@ -20,52 +20,53 @@ from tensorboard.util import op_evaluator -class PersistentOpEvaluatorTest(tf.test.TestCase): - - def setUp(self): - super(PersistentOpEvaluatorTest, self).setUp() - patch = tf.test.mock.patch('tensorflow.compat.v1.Session', wraps=tf.Session) - patch.start() - self.addCleanup(patch.stop) +class PersistentOpEvaluatorTest(tf.test.TestCase): + def setUp(self): + super(PersistentOpEvaluatorTest, self).setUp() - class Squarer(op_evaluator.PersistentOpEvaluator): + patch = tf.test.mock.patch( + "tensorflow.compat.v1.Session", wraps=tf.Session + ) + patch.start() + self.addCleanup(patch.stop) - def __init__(self): - super(Squarer, self).__init__() - self._input = None - self._squarer = None + class Squarer(op_evaluator.PersistentOpEvaluator): + def __init__(self): + super(Squarer, self).__init__() + self._input = None + self._squarer = None - def initialize_graph(self): - self._input = tf.placeholder(tf.int32) - self._squarer = tf.square(self._input) + def initialize_graph(self): + self._input = tf.placeholder(tf.int32) + self._squarer = tf.square(self._input) - def run(self, xs): # pylint: disable=arguments-differ - return 
self._squarer.eval(feed_dict={self._input: xs}) + def run(self, xs): # pylint: disable=arguments-differ + return self._squarer.eval(feed_dict={self._input: xs}) - self._square = Squarer() + self._square = Squarer() - def test_preserves_existing_session(self): - with tf.Session() as sess: - op = tf.reduce_sum(input_tensor=[2, 2]) - self.assertIs(sess, tf.get_default_session()) + def test_preserves_existing_session(self): + with tf.Session() as sess: + op = tf.reduce_sum(input_tensor=[2, 2]) + self.assertIs(sess, tf.get_default_session()) - result = self._square(123) - self.assertEqual(123 * 123, result) + result = self._square(123) + self.assertEqual(123 * 123, result) - self.assertIs(sess, tf.get_default_session()) - number_of_lights = sess.run(op) - self.assertEqual(number_of_lights, 4) + self.assertIs(sess, tf.get_default_session()) + number_of_lights = sess.run(op) + self.assertEqual(number_of_lights, 4) - def test_lazily_initializes_sessions(self): - self.assertEqual(tf.Session.call_count, 0) + def test_lazily_initializes_sessions(self): + self.assertEqual(tf.Session.call_count, 0) - def test_reuses_sessions(self): - self._square(123) - self.assertEqual(tf.Session.call_count, 1) - self._square(234) - self.assertEqual(tf.Session.call_count, 1) + def test_reuses_sessions(self): + self._square(123) + self.assertEqual(tf.Session.call_count, 1) + self._square(234) + self.assertEqual(tf.Session.call_count, 1) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorboard/util/platform_util.py b/tensorboard/util/platform_util.py index 89cd74c943..8b8ba5c310 100644 --- a/tensorboard/util/platform_util.py +++ b/tensorboard/util/platform_util.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""TensorBoard helper routine for platform. 
-""" +"""TensorBoard helper routine for platform.""" from __future__ import absolute_import from __future__ import division @@ -21,5 +20,5 @@ def readahead_file_path(path, unused_readahead=None): - """Readahead files not implemented; simply returns given path.""" - return path + """Readahead files not implemented; simply returns given path.""" + return path diff --git a/tensorboard/util/platform_util_test.py b/tensorboard/util/platform_util_test.py index eed21c8fae..c0c4ec2af5 100644 --- a/tensorboard/util/platform_util_test.py +++ b/tensorboard/util/platform_util_test.py @@ -21,10 +21,11 @@ class PlatformUtilTest(tb_test.TestCase): + def test_readahead_file_path(self): + self.assertEqual( + "foo/bar", platform_util.readahead_file_path("foo/bar") + ) - def test_readahead_file_path(self): - self.assertEqual('foo/bar', platform_util.readahead_file_path('foo/bar')) - -if __name__ == '__main__': - tb_test.main() +if __name__ == "__main__": + tb_test.main() diff --git a/tensorboard/util/tb_logging.py b/tensorboard/util/tb_logging.py index 375d9fd391..f49a5b2a20 100644 --- a/tensorboard/util/tb_logging.py +++ b/tensorboard/util/tb_logging.py @@ -17,8 +17,9 @@ import logging -_logger = logging.getLogger('tensorboard') +_logger = logging.getLogger("tensorboard") + def get_logger(): - """Returns TensorBoard logger""" - return _logger + """Returns TensorBoard logger.""" + return _logger diff --git a/tensorboard/util/tensor_util.py b/tensorboard/util/tensor_util.py index 3a374a44cc..5b7c5ade77 100644 --- a/tensorboard/util/tensor_util.py +++ b/tensorboard/util/tensor_util.py @@ -27,8 +27,12 @@ def ExtractBitsFromFloat16(x): return np.asscalar(np.asarray(x, dtype=np.float16).view(np.uint16)) + def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values): - tensor_proto.half_val.extend([ExtractBitsFromFloat16(x) for x in proto_values]) + tensor_proto.half_val.extend( + [ExtractBitsFromFloat16(x) for x in proto_values] + ) + def ExtractBitsFromBFloat16(x): return 
np.asscalar( @@ -37,7 +41,10 @@ def ExtractBitsFromBFloat16(x): def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values): - tensor_proto.half_val.extend([ExtractBitsFromBFloat16(x) for x in proto_values]) + tensor_proto.half_val.extend( + [ExtractBitsFromBFloat16(x) for x in proto_values] + ) + def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values): tensor_proto.float_val.extend([np.asscalar(x) for x in proto_values]) @@ -72,17 +79,21 @@ def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values): [np.asscalar(v) for x in proto_values for v in [x.real, x.imag]] ) + def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values): tensor_proto.dcomplex_val.extend( [np.asscalar(v) for x in proto_values for v in [x.real, x.imag]] ) + def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values): tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values]) + def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values): tensor_proto.bool_val.extend([np.asscalar(x) for x in proto_values]) + _NP_TO_APPEND_FN = { np.float16: SlowAppendFloat16ArrayToTensorProto, np.float32: SlowAppendFloat32ArrayToTensorProto, @@ -107,7 +118,9 @@ def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values): # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16. } -BACKUP_DICT = {dtypes.bfloat16.as_numpy_dtype: SlowAppendBFloat16ArrayToTensorProto} +BACKUP_DICT = { + dtypes.bfloat16.as_numpy_dtype: SlowAppendBFloat16ArrayToTensorProto +} def GetFromNumpyDTypeDict(dtype_dict, dtype): @@ -140,6 +153,7 @@ def _GetDenseDimensions(list_of_lists): else: return [len(list_of_lists)] + _GetDenseDimensions(list_of_lists[0]) + def _FlattenToStrings(nested_strings): if isinstance(nested_strings, (list, tuple)): for inner in nested_strings: @@ -280,46 +294,46 @@ def _Assertconvertible(values, dtype): def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): """Create a TensorProto. 
- Args: - values: Values to put in the TensorProto. - dtype: Optional tensor_pb2 DataType value. - shape: List of integers representing the dimensions of tensor. - verify_shape: Boolean that enables verification of a shape of values. + Args: + values: Values to put in the TensorProto. + dtype: Optional tensor_pb2 DataType value. + shape: List of integers representing the dimensions of tensor. + verify_shape: Boolean that enables verification of a shape of values. - Returns: - A `TensorProto`. Depending on the type, it may contain data in the - "tensor_content" attribute, which is not directly useful to Python programs. - To access the values you should convert the proto back to a numpy ndarray - with `tensor_util.MakeNdarray(proto)`. + Returns: + A `TensorProto`. Depending on the type, it may contain data in the + "tensor_content" attribute, which is not directly useful to Python programs. + To access the values you should convert the proto back to a numpy ndarray + with `tensor_util.MakeNdarray(proto)`. - If `values` is a `TensorProto`, it is immediately returned; `dtype` and - `shape` are ignored. + If `values` is a `TensorProto`, it is immediately returned; `dtype` and + `shape` are ignored. - Raises: - TypeError: if unsupported types are provided. - ValueError: if arguments have inappropriate values or if verify_shape is - True and shape of values is not equals to a shape from the argument. + Raises: + TypeError: if unsupported types are provided. + ValueError: if arguments have inappropriate values or if verify_shape is + True and shape of values is not equals to a shape from the argument. - make_tensor_proto accepts "values" of a python scalar, a python list, a - numpy ndarray, or a numpy scalar. + make_tensor_proto accepts "values" of a python scalar, a python list, a + numpy ndarray, or a numpy scalar. - If "values" is a python scalar or a python list, make_tensor_proto - first convert it to numpy ndarray. 
If dtype is None, the - conversion tries its best to infer the right numpy data - type. Otherwise, the resulting numpy array has a convertible data - type with the given dtype. + If "values" is a python scalar or a python list, make_tensor_proto + first convert it to numpy ndarray. If dtype is None, the + conversion tries its best to infer the right numpy data + type. Otherwise, the resulting numpy array has a convertible data + type with the given dtype. - In either case above, the numpy ndarray (either the caller provided - or the auto converted) must have the convertible type with dtype. + In either case above, the numpy ndarray (either the caller provided + or the auto converted) must have the convertible type with dtype. - make_tensor_proto then converts the numpy array to a tensor proto. + make_tensor_proto then converts the numpy array to a tensor proto. - If "shape" is None, the resulting tensor proto represents the numpy - array precisely. + If "shape" is None, the resulting tensor proto represents the numpy + array precisely. - Otherwise, "shape" specifies the tensor's shape and the numpy array - can not have more elements than what "shape" specifies. - """ + Otherwise, "shape" specifies the tensor's shape and the numpy array + can not have more elements than what "shape" specifies. + """ if isinstance(values, tensor_pb2.TensorProto): return values @@ -368,7 +382,10 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): nparray = np.array(values, dtype=np_dt) # check to them. 
# We need to pass in quantized values as tuples, so don't apply the shape - if list(nparray.shape) != _GetDenseDimensions(values) and not is_quantized: + if ( + list(nparray.shape) != _GetDenseDimensions(values) + and not is_quantized + ): raise ValueError( """Argument must be a dense tensor: %s""" """ - got shape %s, but wanted %s.""" @@ -397,7 +414,8 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): numpy_dtype = dtype if dtype is not None and ( - not hasattr(dtype, "base_dtype") or dtype.base_dtype != numpy_dtype.base_dtype + not hasattr(dtype, "base_dtype") + or dtype.base_dtype != numpy_dtype.base_dtype ): raise TypeError( "Inconvertible types: %s vs. %s. Value is %s" @@ -483,25 +501,28 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): def make_ndarray(tensor): """Create a numpy ndarray from a tensor. - Create a numpy ndarray with the same shape and data as the tensor. - - Args: - tensor: A TensorProto. + Create a numpy ndarray with the same shape and data as the tensor. - Returns: - A numpy array with the tensor contents. + Args: + tensor: A TensorProto. - Raises: - TypeError: if tensor has unsupported type. + Returns: + A numpy array with the tensor contents. - """ + Raises: + TypeError: if tensor has unsupported type. 
+ """ shape = [d.size for d in tensor.tensor_shape.dim] num_elements = np.prod(shape, dtype=np.int64) tensor_dtype = dtypes.as_dtype(tensor.dtype) dtype = tensor_dtype.as_numpy_dtype if tensor.tensor_content: - return np.frombuffer(tensor.tensor_content, dtype=dtype).copy().reshape(shape) + return ( + np.frombuffer(tensor.tensor_content, dtype=dtype) + .copy() + .reshape(shape) + ) elif tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16: # the half_val field of the TensorProto stores the binary representation # of the fp16: we need to reinterpret this as a proper float16 @@ -558,13 +579,16 @@ def make_ndarray(tensor): np.array(tensor.string_val[0], dtype=dtype), num_elements ).reshape(shape) else: - return np.array([x for x in tensor.string_val], dtype=dtype).reshape(shape) + return np.array( + [x for x in tensor.string_val], dtype=dtype + ).reshape(shape) elif tensor_dtype == dtypes.complex64: it = iter(tensor.scomplex_val) if len(tensor.scomplex_val) == 2: return np.repeat( np.array( - complex(tensor.scomplex_val[0], tensor.scomplex_val[1]), dtype=dtype + complex(tensor.scomplex_val[0], tensor.scomplex_val[1]), + dtype=dtype, ), num_elements, ).reshape(shape) @@ -577,7 +601,8 @@ def make_ndarray(tensor): if len(tensor.dcomplex_val) == 2: return np.repeat( np.array( - complex(tensor.dcomplex_val[0], tensor.dcomplex_val[1]), dtype=dtype + complex(tensor.dcomplex_val[0], tensor.dcomplex_val[1]), + dtype=dtype, ), num_elements, ).reshape(shape) diff --git a/tensorboard/util/test_util.py b/tensorboard/util/test_util.py index 44d56fa164..f758b9f002 100644 --- a/tensorboard/util/test_util.py +++ b/tensorboard/util/test_util.py @@ -26,6 +26,7 @@ import unittest import tensorflow as tf + # See discussion on issue #1996 for private module import justification. from tensorflow.python import tf2 as tensorflow_python_tf2 @@ -39,126 +40,146 @@ class FileWriter(tf.compat.v1.summary.FileWriter): - """FileWriter for test. 
- - TensorFlow FileWriter uses TensorFlow's Protobuf Python binding which is - largely discouraged in TensorBoard. We do not want a TB.Writer but require one - for testing in integrational style (writing out event files and use the real - event readers). - """ - def __init__(self, *args, **kwargs): - # Briefly enter graph mode context so this testing FileWriter can be - # created from an eager mode context without triggering a usage error. - with tf.compat.v1.Graph().as_default(): - super(FileWriter, self).__init__(*args, **kwargs) - - def add_test_summary(self, tag, simple_value=1.0, step=None): - """Convenience for writing a simple summary for a given tag.""" - value = summary_pb2.Summary.Value(tag=tag, simple_value=simple_value) - summary = summary_pb2.Summary(value=[value]) - self.add_summary(summary, global_step=step) - - def add_event(self, event): - if isinstance(event, event_pb2.Event): - tf_event = tf.compat.v1.Event.FromString(event.SerializeToString()) - else: - logger.warn('Added TensorFlow event proto. ' - 'Please prefer TensorBoard copy of the proto') - tf_event = event - super(FileWriter, self).add_event(tf_event) - - def add_summary(self, summary, global_step=None): - if isinstance(summary, summary_pb2.Summary): - tf_summary = tf.compat.v1.Summary.FromString(summary.SerializeToString()) - else: - logger.warn('Added TensorFlow summary proto. ' - 'Please prefer TensorBoard copy of the proto') - tf_summary = summary - super(FileWriter, self).add_summary(tf_summary, global_step) - - def add_session_log(self, session_log, global_step=None): - if isinstance(session_log, event_pb2.SessionLog): - tf_session_log = tf.compat.v1.SessionLog.FromString(session_log.SerializeToString()) - else: - logger.warn('Added TensorFlow session_log proto. 
' - 'Please prefer TensorBoard copy of the proto') - tf_session_log = session_log - super(FileWriter, self).add_session_log(tf_session_log, global_step) - - def add_graph(self, graph, global_step=None, graph_def=None): - if isinstance(graph_def, graph_pb2.GraphDef): - tf_graph_def = tf.compat.v1.GraphDef.FromString(graph_def.SerializeToString()) - else: - tf_graph_def = graph_def - - super(FileWriter, self).add_graph(graph, global_step=global_step, graph_def=tf_graph_def) - - def add_meta_graph(self, meta_graph_def, global_step=None): - if isinstance(meta_graph_def, meta_graph_pb2.MetaGraphDef): - tf_meta_graph_def = tf.compat.v1.MetaGraphDef.FromString(meta_graph_def.SerializeToString()) - else: - tf_meta_graph_def = meta_graph_def - - super(FileWriter, self).add_meta_graph(meta_graph_def=tf_meta_graph_def, global_step=global_step) + """FileWriter for test. + + TensorFlow FileWriter uses TensorFlow's Protobuf Python binding + which is largely discouraged in TensorBoard. We do not want a + TB.Writer but require one for testing in integrational style + (writing out event files and use the real event readers). + """ + + def __init__(self, *args, **kwargs): + # Briefly enter graph mode context so this testing FileWriter can be + # created from an eager mode context without triggering a usage error. + with tf.compat.v1.Graph().as_default(): + super(FileWriter, self).__init__(*args, **kwargs) + + def add_test_summary(self, tag, simple_value=1.0, step=None): + """Convenience for writing a simple summary for a given tag.""" + value = summary_pb2.Summary.Value(tag=tag, simple_value=simple_value) + summary = summary_pb2.Summary(value=[value]) + self.add_summary(summary, global_step=step) + + def add_event(self, event): + if isinstance(event, event_pb2.Event): + tf_event = tf.compat.v1.Event.FromString(event.SerializeToString()) + else: + logger.warn( + "Added TensorFlow event proto. 
" + "Please prefer TensorBoard copy of the proto" + ) + tf_event = event + super(FileWriter, self).add_event(tf_event) + + def add_summary(self, summary, global_step=None): + if isinstance(summary, summary_pb2.Summary): + tf_summary = tf.compat.v1.Summary.FromString( + summary.SerializeToString() + ) + else: + logger.warn( + "Added TensorFlow summary proto. " + "Please prefer TensorBoard copy of the proto" + ) + tf_summary = summary + super(FileWriter, self).add_summary(tf_summary, global_step) + + def add_session_log(self, session_log, global_step=None): + if isinstance(session_log, event_pb2.SessionLog): + tf_session_log = tf.compat.v1.SessionLog.FromString( + session_log.SerializeToString() + ) + else: + logger.warn( + "Added TensorFlow session_log proto. " + "Please prefer TensorBoard copy of the proto" + ) + tf_session_log = session_log + super(FileWriter, self).add_session_log(tf_session_log, global_step) + + def add_graph(self, graph, global_step=None, graph_def=None): + if isinstance(graph_def, graph_pb2.GraphDef): + tf_graph_def = tf.compat.v1.GraphDef.FromString( + graph_def.SerializeToString() + ) + else: + tf_graph_def = graph_def + + super(FileWriter, self).add_graph( + graph, global_step=global_step, graph_def=tf_graph_def + ) + + def add_meta_graph(self, meta_graph_def, global_step=None): + if isinstance(meta_graph_def, meta_graph_pb2.MetaGraphDef): + tf_meta_graph_def = tf.compat.v1.MetaGraphDef.FromString( + meta_graph_def.SerializeToString() + ) + else: + tf_meta_graph_def = meta_graph_def + + super(FileWriter, self).add_meta_graph( + meta_graph_def=tf_meta_graph_def, global_step=global_step + ) class FileWriterCache(object): - """Cache for TensorBoard test file writers. - """ - # Cache, keyed by directory. - _cache = {} + """Cache for TensorBoard test file writers.""" - # Lock protecting _FILE_WRITERS. - _lock = threading.RLock() + # Cache, keyed by directory. 
+ _cache = {} - @staticmethod - def get(logdir): - """Returns the FileWriter for the specified directory. + # Lock protecting _FILE_WRITERS. + _lock = threading.RLock() - Args: - logdir: str, name of the directory. + @staticmethod + def get(logdir): + """Returns the FileWriter for the specified directory. - Returns: - A `FileWriter`. - """ - with FileWriterCache._lock: - if logdir not in FileWriterCache._cache: - FileWriterCache._cache[logdir] = FileWriter( - logdir, graph=tf.compat.v1.get_default_graph()) - return FileWriterCache._cache[logdir] + Args: + logdir: str, name of the directory. + + Returns: + A `FileWriter`. + """ + with FileWriterCache._lock: + if logdir not in FileWriterCache._cache: + FileWriterCache._cache[logdir] = FileWriter( + logdir, graph=tf.compat.v1.get_default_graph() + ) + return FileWriterCache._cache[logdir] class FakeTime(object): - """Thread-safe fake replacement for the `time` module.""" + """Thread-safe fake replacement for the `time` module.""" - def __init__(self, current=0.0): - self._time = float(current) - self._lock = threading.Lock() + def __init__(self, current=0.0): + self._time = float(current) + self._lock = threading.Lock() - def time(self): - with self._lock: - return self._time + def time(self): + with self._lock: + return self._time - def sleep(self, secs): - with self._lock: - self._time += secs + def sleep(self, secs): + with self._lock: + self._time += secs def ensure_tb_summary_proto(summary): - """Ensures summary is TensorBoard Summary proto. + """Ensures summary is TensorBoard Summary proto. - TB v1 summary API returns TF Summary proto. To make test for v1 and v2 API - congruent, one can use this API to convert result of v1 API to TB Summary - proto. - """ - if isinstance(summary, summary_pb2.Summary): - return summary + TB v1 summary API returns TF Summary proto. To make test for v1 and + v2 API congruent, one can use this API to convert result of v1 API + to TB Summary proto. 
+ """ + if isinstance(summary, summary_pb2.Summary): + return summary - return summary_pb2.Summary.FromString(summary.SerializeToString()) + return summary_pb2.Summary.FromString(summary.SerializeToString()) def _run_conditionally(guard, name, default_reason=None): - """Create a decorator factory that skips a test when guard returns False. + """Create a decorator factory that skips a test when guard returns False. The factory raises ValueError when default_reason is None and reason is not passed to the factory. @@ -177,19 +198,21 @@ def _run_conditionally(guard, name, default_reason=None): A function that returns a decorator. """ - def _impl(reason=None): - if reason is None: - if default_reason is None: - raise ValueError('%s requires a reason for skipping.' % name) - reason = default_reason - return unittest.skipUnless(guard(), reason) + def _impl(reason=None): + if reason is None: + if default_reason is None: + raise ValueError("%s requires a reason for skipping." % name) + reason = default_reason + return unittest.skipUnless(guard(), reason) + + return _impl - return _impl run_v1_only = _run_conditionally( - lambda: not tensorflow_python_tf2.enabled(), - name='run_v1_only') + lambda: not tensorflow_python_tf2.enabled(), name="run_v1_only" +) run_v2_only = _run_conditionally( lambda: tensorflow_python_tf2.enabled(), - name='run_v2_only', - default_reason='Test only appropriate for TensorFlow v2') + name="run_v2_only", + default_reason="Test only appropriate for TensorFlow v2", +) diff --git a/tensorboard/version.py b/tensorboard/version.py index 1b3a476f1a..fa05473e7d 100644 --- a/tensorboard/version.py +++ b/tensorboard/version.py @@ -15,4 +15,4 @@ """Contains the version string.""" -VERSION = '2.2.0a0' +VERSION = "2.2.0a0" diff --git a/tensorboard/version_test.py b/tensorboard/version_test.py index d517aeffcc..2b30428cf1 100644 --- a/tensorboard/version_test.py +++ b/tensorboard/version_test.py @@ -23,18 +23,17 @@ class VersionTest(tb_test.TestCase): + def 
test_valid_pep440_version(self): + """Ensure that our version is PEP 440-compliant.""" + # `Version` and `LegacyVersion` are vendored and not publicly + # exported; get handles to them. + compliant_version = pkg_resources.parse_version("1.0.0") + legacy_version = pkg_resources.parse_version("some arbitrary string") + self.assertNotEqual(type(compliant_version), type(legacy_version)) - def test_valid_pep440_version(self): - """Ensure that our version is PEP 440-compliant.""" - # `Version` and `LegacyVersion` are vendored and not publicly - # exported; get handles to them. - compliant_version = pkg_resources.parse_version("1.0.0") - legacy_version = pkg_resources.parse_version("some arbitrary string") - self.assertNotEqual(type(compliant_version), type(legacy_version)) + tensorboard_version = pkg_resources.parse_version(version.VERSION) + self.assertIsInstance(tensorboard_version, type(compliant_version)) - tensorboard_version = pkg_resources.parse_version(version.VERSION) - self.assertIsInstance(tensorboard_version, type(compliant_version)) - -if __name__ == '__main__': - tb_test.main() +if __name__ == "__main__": + tb_test.main()