diff --git a/.travis.yml b/.travis.yml
index 2f8406347f..af1297cf03 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,8 +1,8 @@
dist: trusty
language: python
python:
- - "2.7"
- "3.6"
+ - "2.7"
branches:
only:
@@ -46,12 +46,14 @@ before_install:
install:
- elapsed "install"
+ - "PY3=\"$(python -c 'if __import__(\"sys\").version_info[0] > 2: print(1)')\""
# Older versions of Pip sometimes resolve specifiers like `tf-nightly`
# to versions other than the most recent(!).
- pip install -U pip
# Lint check deps.
- pip install flake8==3.7.8
- pip install yamllint==1.17.0
+ - if [ -n "${PY3}" ]; then pip install black==19.10b0; fi
# TensorBoard deps.
- pip install futures==3.1.1
- pip install grpcio==1.24.3
@@ -87,6 +89,8 @@ before_script:
- flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# Lint frontend code
- yarn lint
+ # Lint backend code
+ - if [ -n "${PY3}" ]; then black --check .; fi
# Lint .yaml docs files. Use '# yamllint disable-line rule:foo' to suppress.
- yamllint -c docs/.yamllint docs docs/.yamllint
# Make sure that IPython notebooks have valid Markdown.
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000..c73ad6175e
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[tool.black]
+line-length = 80
+target-version = ["py27", "py36", "py37", "py38"]
diff --git a/tensorboard/__init__.py b/tensorboard/__init__.py
index 6d172de7f1..510de45b0b 100644
--- a/tensorboard/__init__.py
+++ b/tensorboard/__init__.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""TensorBoard is a webapp for understanding TensorFlow runs and graphs.
-"""
+"""TensorBoard is a webapp for understanding TensorFlow runs and graphs."""
from __future__ import absolute_import
from __future__ import division
@@ -67,33 +66,36 @@
# additional discussion.
-@lazy.lazy_load('tensorboard.notebook')
+@lazy.lazy_load("tensorboard.notebook")
def notebook():
- import importlib
- return importlib.import_module('tensorboard.notebook')
+ import importlib
+
+ return importlib.import_module("tensorboard.notebook")
-@lazy.lazy_load('tensorboard.program')
+
+@lazy.lazy_load("tensorboard.program")
def program():
- import importlib
- return importlib.import_module('tensorboard.program')
+ import importlib
+
+ return importlib.import_module("tensorboard.program")
-@lazy.lazy_load('tensorboard.summary')
+@lazy.lazy_load("tensorboard.summary")
def summary():
- import importlib
- return importlib.import_module('tensorboard.summary')
+ import importlib
+
+ return importlib.import_module("tensorboard.summary")
def load_ipython_extension(ipython):
- """IPython API entry point.
+ """IPython API entry point.
- Only intended to be called by the IPython runtime.
+ Only intended to be called by the IPython runtime.
- See:
- https://ipython.readthedocs.io/en/stable/config/extensions/index.html
- """
- notebook._load_ipython_extension(ipython)
+ See:
+ https://ipython.readthedocs.io/en/stable/config/extensions/index.html
+ """
+ notebook._load_ipython_extension(ipython)
__version__ = version.VERSION
diff --git a/tensorboard/__main__.py b/tensorboard/__main__.py
index 5ae362013a..51022d7afa 100644
--- a/tensorboard/__main__.py
+++ b/tensorboard/__main__.py
@@ -24,5 +24,5 @@
del _main
-if __name__ == '__main__':
- run_main()
+if __name__ == "__main__":
+ run_main()
diff --git a/tensorboard/backend/application.py b/tensorboard/backend/application.py
index 5dd0f132dc..6ccde8d3f8 100644
--- a/tensorboard/backend/application.py
+++ b/tensorboard/backend/application.py
@@ -37,7 +37,9 @@
import time
import six
-from six.moves.urllib import parse as urlparse # pylint: disable=wrong-import-order
+from six.moves.urllib import (
+ parse as urlparse,
+) # pylint: disable=wrong-import-order
from werkzeug import wrappers
@@ -48,9 +50,15 @@
from tensorboard.backend import path_prefix
from tensorboard.backend import security_validator
from tensorboard.backend.event_processing import db_import_multiplexer
-from tensorboard.backend.event_processing import data_provider as event_data_provider
-from tensorboard.backend.event_processing import plugin_event_accumulator as event_accumulator
-from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer
+from tensorboard.backend.event_processing import (
+ data_provider as event_data_provider,
+)
+from tensorboard.backend.event_processing import (
+ plugin_event_accumulator as event_accumulator,
+)
+from tensorboard.backend.event_processing import (
+ plugin_event_multiplexer as event_multiplexer,
+)
from tensorboard.plugins import base_plugin
from tensorboard.plugins.audio import metadata as audio_metadata
from tensorboard.plugins.core import core_plugin
@@ -75,110 +83,115 @@
pr_curve_metadata.PLUGIN_NAME: 100,
}
-DATA_PREFIX = '/data'
-PLUGIN_PREFIX = '/plugin'
-PLUGINS_LISTING_ROUTE = '/plugins_listing'
-PLUGIN_ENTRY_ROUTE = '/plugin_entry.html'
+DATA_PREFIX = "/data"
+PLUGIN_PREFIX = "/plugin"
+PLUGINS_LISTING_ROUTE = "/plugins_listing"
+PLUGIN_ENTRY_ROUTE = "/plugin_entry.html"
# Slashes in a plugin name could throw the router for a loop. An empty
# name would be confusing, too. To be safe, let's restrict the valid
# names as follows.
-_VALID_PLUGIN_RE = re.compile(r'^[A-Za-z0-9_.-]+$')
+_VALID_PLUGIN_RE = re.compile(r"^[A-Za-z0-9_.-]+$")
logger = tb_logging.get_logger()
def tensor_size_guidance_from_flags(flags):
- """Apply user per-summary size guidance overrides."""
+ """Apply user per-summary size guidance overrides."""
- tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE)
- if not flags or not flags.samples_per_plugin:
- return tensor_size_guidance
+ tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE)
+ if not flags or not flags.samples_per_plugin:
+ return tensor_size_guidance
- for token in flags.samples_per_plugin.split(','):
- k, v = token.strip().split('=')
- tensor_size_guidance[k] = int(v)
+ for token in flags.samples_per_plugin.split(","):
+ k, v = token.strip().split("=")
+ tensor_size_guidance[k] = int(v)
- return tensor_size_guidance
+ return tensor_size_guidance
def standard_tensorboard_wsgi(flags, plugin_loaders, assets_zip_provider):
- """Construct a TensorBoardWSGIApp with standard plugins and multiplexer.
-
- Args:
- flags: An argparse.Namespace containing TensorBoard CLI flags.
- plugin_loaders: A list of TBLoader instances.
- assets_zip_provider: See TBContext documentation for more information.
-
- Returns:
- The new TensorBoard WSGI application.
-
- :type plugin_loaders: list[base_plugin.TBLoader]
- :rtype: TensorBoardWSGI
- """
- data_provider = None
- multiplexer = None
- reload_interval = flags.reload_interval
- if flags.db_import:
- # DB import mode.
- db_uri = flags.db
- # Create a temporary DB file if we weren't given one.
- if not db_uri:
- tmpdir = tempfile.mkdtemp(prefix='tbimport')
- atexit.register(shutil.rmtree, tmpdir)
- db_uri = 'sqlite:%s/tmp.sqlite' % tmpdir
- db_connection_provider = create_sqlite_connection_provider(db_uri)
- logger.info('Importing logdir into DB at %s', db_uri)
- multiplexer = db_import_multiplexer.DbImportMultiplexer(
- db_uri=db_uri,
- db_connection_provider=db_connection_provider,
- purge_orphaned_data=flags.purge_orphaned_data,
- max_reload_threads=flags.max_reload_threads)
- elif flags.db:
- # DB read-only mode, never load event logs.
- reload_interval = -1
- db_connection_provider = create_sqlite_connection_provider(flags.db)
- multiplexer = _DbModeMultiplexer(flags.db, db_connection_provider)
- else:
- # Regular logdir loading mode.
- multiplexer = event_multiplexer.EventMultiplexer(
- size_guidance=DEFAULT_SIZE_GUIDANCE,
- tensor_size_guidance=tensor_size_guidance_from_flags(flags),
- purge_orphaned_data=flags.purge_orphaned_data,
- max_reload_threads=flags.max_reload_threads,
- event_file_active_filter=_get_event_file_active_filter(flags))
- if flags.generic_data != 'false':
- data_provider = event_data_provider.MultiplexerDataProvider(
- multiplexer, flags.logdir or flags.logdir_spec
- )
-
- if reload_interval >= 0:
- # We either reload the multiplexer once when TensorBoard starts up, or we
- # continuously reload the multiplexer.
- if flags.logdir:
- path_to_run = {os.path.expanduser(flags.logdir): None}
+ """Construct a TensorBoardWSGIApp with standard plugins and multiplexer.
+
+ Args:
+ flags: An argparse.Namespace containing TensorBoard CLI flags.
+ plugin_loaders: A list of TBLoader instances.
+ assets_zip_provider: See TBContext documentation for more information.
+
+ Returns:
+ The new TensorBoard WSGI application.
+
+ :type plugin_loaders: list[base_plugin.TBLoader]
+ :rtype: TensorBoardWSGI
+ """
+ data_provider = None
+ multiplexer = None
+ reload_interval = flags.reload_interval
+ if flags.db_import:
+ # DB import mode.
+ db_uri = flags.db
+ # Create a temporary DB file if we weren't given one.
+ if not db_uri:
+ tmpdir = tempfile.mkdtemp(prefix="tbimport")
+ atexit.register(shutil.rmtree, tmpdir)
+ db_uri = "sqlite:%s/tmp.sqlite" % tmpdir
+ db_connection_provider = create_sqlite_connection_provider(db_uri)
+ logger.info("Importing logdir into DB at %s", db_uri)
+ multiplexer = db_import_multiplexer.DbImportMultiplexer(
+ db_uri=db_uri,
+ db_connection_provider=db_connection_provider,
+ purge_orphaned_data=flags.purge_orphaned_data,
+ max_reload_threads=flags.max_reload_threads,
+ )
+ elif flags.db:
+ # DB read-only mode, never load event logs.
+ reload_interval = -1
+ db_connection_provider = create_sqlite_connection_provider(flags.db)
+ multiplexer = _DbModeMultiplexer(flags.db, db_connection_provider)
else:
- path_to_run = parse_event_files_spec(flags.logdir_spec)
- start_reloading_multiplexer(
- multiplexer, path_to_run, reload_interval, flags.reload_task)
- return TensorBoardWSGIApp(
- flags, plugin_loaders, data_provider, assets_zip_provider, multiplexer)
+ # Regular logdir loading mode.
+ multiplexer = event_multiplexer.EventMultiplexer(
+ size_guidance=DEFAULT_SIZE_GUIDANCE,
+ tensor_size_guidance=tensor_size_guidance_from_flags(flags),
+ purge_orphaned_data=flags.purge_orphaned_data,
+ max_reload_threads=flags.max_reload_threads,
+ event_file_active_filter=_get_event_file_active_filter(flags),
+ )
+ if flags.generic_data != "false":
+ data_provider = event_data_provider.MultiplexerDataProvider(
+ multiplexer, flags.logdir or flags.logdir_spec
+ )
+
+ if reload_interval >= 0:
+ # We either reload the multiplexer once when TensorBoard starts up, or we
+ # continuously reload the multiplexer.
+ if flags.logdir:
+ path_to_run = {os.path.expanduser(flags.logdir): None}
+ else:
+ path_to_run = parse_event_files_spec(flags.logdir_spec)
+ start_reloading_multiplexer(
+ multiplexer, path_to_run, reload_interval, flags.reload_task
+ )
+ return TensorBoardWSGIApp(
+ flags, plugin_loaders, data_provider, assets_zip_provider, multiplexer
+ )
def _handling_errors(wsgi_app):
- def wrapper(*args):
- (environ, start_response) = (args[-2], args[-1])
- try:
- return wsgi_app(*args)
- except errors.PublicError as e:
- request = wrappers.Request(environ)
- error_app = http_util.Respond(
- request, str(e), "text/plain", code=e.http_code
- )
- return error_app(environ, start_response)
- # Let other exceptions be handled by the server, as an opaque
- # internal server error.
- return wrapper
+ def wrapper(*args):
+ (environ, start_response) = (args[-2], args[-1])
+ try:
+ return wsgi_app(*args)
+ except errors.PublicError as e:
+ request = wrappers.Request(environ)
+ error_app = http_util.Respond(
+ request, str(e), "text/plain", code=e.http_code
+ )
+ return error_app(environ, start_response)
+ # Let other exceptions be handled by the server, as an opaque
+ # internal server error.
+
+ return wrapper
def TensorBoardWSGIApp(
@@ -186,568 +199,624 @@ def TensorBoardWSGIApp(
plugins,
data_provider=None,
assets_zip_provider=None,
- deprecated_multiplexer=None):
- """Constructs a TensorBoard WSGI app from plugins and data providers.
-
- Args:
- flags: An argparse.Namespace containing TensorBoard CLI flags.
- plugins: A list of plugin loader instances.
- assets_zip_provider: See TBContext documentation for more information.
- data_provider: Instance of `tensorboard.data.provider.DataProvider`. May
- be `None` if `flags.generic_data` is set to `"false"` in which case
- `deprecated_multiplexer` must be passed instead.
- deprecated_multiplexer: Optional `plugin_event_multiplexer.EventMultiplexer`
- to use for any plugins not yet enabled for the DataProvider API.
- Required if the data_provider argument is not passed.
-
- Returns:
- A WSGI application that implements the TensorBoard backend.
-
- :type plugins: list[base_plugin.TBLoader]
- """
- db_uri = None
- db_connection_provider = None
- if isinstance(
- deprecated_multiplexer,
- (db_import_multiplexer.DbImportMultiplexer, _DbModeMultiplexer)):
- db_uri = deprecated_multiplexer.db_uri
- db_connection_provider = deprecated_multiplexer.db_connection_provider
- plugin_name_to_instance = {}
- context = base_plugin.TBContext(
- data_provider=data_provider,
- db_connection_provider=db_connection_provider,
- db_uri=db_uri,
- flags=flags,
- logdir=flags.logdir,
- multiplexer=deprecated_multiplexer,
- assets_zip_provider=assets_zip_provider,
- plugin_name_to_instance=plugin_name_to_instance,
- window_title=flags.window_title)
- tbplugins = []
- for plugin_spec in plugins:
- loader = make_plugin_loader(plugin_spec)
- plugin = loader.load(context)
- if plugin is None:
- continue
- tbplugins.append(plugin)
- plugin_name_to_instance[plugin.plugin_name] = plugin
- return TensorBoardWSGI(tbplugins, flags.path_prefix)
-
-
-class TensorBoardWSGI(object):
- """The TensorBoard WSGI app that delegates to a set of TBPlugin."""
-
- def __init__(self, plugins, path_prefix=''):
- """Constructs TensorBoardWSGI instance.
+ deprecated_multiplexer=None,
+):
+ """Constructs a TensorBoard WSGI app from plugins and data providers.
Args:
- plugins: A list of base_plugin.TBPlugin subclass instances.
flags: An argparse.Namespace containing TensorBoard CLI flags.
+ plugins: A list of plugin loader instances.
+ assets_zip_provider: See TBContext documentation for more information.
+ data_provider: Instance of `tensorboard.data.provider.DataProvider`. May
+ be `None` if `flags.generic_data` is set to `"false"` in which case
+ `deprecated_multiplexer` must be passed instead.
+ deprecated_multiplexer: Optional `plugin_event_multiplexer.EventMultiplexer`
+ to use for any plugins not yet enabled for the DataProvider API.
+ Required if the data_provider argument is not passed.
Returns:
- A WSGI application for the set of all TBPlugin instances.
+ A WSGI application that implements the TensorBoard backend.
- Raises:
- ValueError: If some plugin has no plugin_name
- ValueError: If some plugin has an invalid plugin_name (plugin
- names must only contain [A-Za-z0-9_.-])
- ValueError: If two plugins have the same plugin_name
- ValueError: If some plugin handles a route that does not start
- with a slash
-
- :type plugins: list[base_plugin.TBPlugin]
+ :type plugins: list[base_plugin.TBLoader]
"""
- self._plugins = plugins
- self._path_prefix = path_prefix
- if self._path_prefix.endswith('/'):
- # Should have been fixed by `fix_flags`.
- raise ValueError('Trailing slash in path prefix: %r' % self._path_prefix)
-
- self.exact_routes = {
- # TODO(@chihuahua): Delete this RPC once we have skylark rules that
- # obviate the need for the frontend to determine which plugins are
- # active.
- DATA_PREFIX + PLUGINS_LISTING_ROUTE: self._serve_plugins_listing,
- DATA_PREFIX + PLUGIN_ENTRY_ROUTE: self._serve_plugin_entry,
- }
- unordered_prefix_routes = {}
-
- # Serve the routes from the registered plugins using their name as the route
- # prefix. For example if plugin z has two routes /a and /b, they will be
- # served as /data/plugin/z/a and /data/plugin/z/b.
- plugin_names_encountered = set()
- for plugin in self._plugins:
- if plugin.plugin_name is None:
- raise ValueError('Plugin %s has no plugin_name' % plugin)
- if not _VALID_PLUGIN_RE.match(plugin.plugin_name):
- raise ValueError('Plugin %s has invalid name %r' % (plugin,
- plugin.plugin_name))
- if plugin.plugin_name in plugin_names_encountered:
- raise ValueError('Duplicate plugins for name %s' % plugin.plugin_name)
- plugin_names_encountered.add(plugin.plugin_name)
-
- try:
- plugin_apps = plugin.get_plugin_apps()
- except Exception as e: # pylint: disable=broad-except
- if type(plugin) is core_plugin.CorePlugin: # pylint: disable=unidiomatic-typecheck
- raise
- logger.warn('Plugin %s failed. Exception: %s',
- plugin.plugin_name, str(e))
- continue
- for route, app in plugin_apps.items():
- if not route.startswith('/'):
- raise ValueError('Plugin named %r handles invalid route %r: '
- 'route does not start with a slash' %
- (plugin.plugin_name, route))
- if type(plugin) is core_plugin.CorePlugin: # pylint: disable=unidiomatic-typecheck
- path = route
- else:
- path = (
- DATA_PREFIX + PLUGIN_PREFIX + '/' + plugin.plugin_name + route
- )
-
- if path.endswith('/*'):
- # Note we remove the '*' but leave the slash in place.
- path = path[:-1]
- if '*' in path:
- # note we re-add the removed * in the format string
- raise ValueError('Plugin %r handles invalid route \'%s*\': Only '
- 'trailing wildcards are supported '
- '(i.e., `/.../*`)' %
- (plugin.plugin_name, path))
- unordered_prefix_routes[path] = app
- else:
- if '*' in path:
- raise ValueError('Plugin %r handles invalid route %r: Only '
- 'trailing wildcards are supported '
- '(i.e., `/.../*`)' %
- (plugin.plugin_name, path))
- self.exact_routes[path] = app
-
- # Wildcard routes will be checked in the given order, so we sort them
- # longest to shortest so that a more specific route will take precedence
- # over a more general one (e.g., a catchall route `/*` should come last).
- self.prefix_routes = collections.OrderedDict(
- sorted(
- six.iteritems(unordered_prefix_routes),
- key=lambda x: len(x[0]),
- reverse=True))
-
- self._app = self._create_wsgi_app()
-
- def _create_wsgi_app(self):
- """Apply middleware to create the final WSGI app."""
- app = self._route_request
- app = empty_path_redirect.EmptyPathRedirectMiddleware(app)
- app = experiment_id.ExperimentIdMiddleware(app)
- app = path_prefix.PathPrefixMiddleware(app, self._path_prefix)
- app = security_validator.SecurityValidatorMiddleware(app)
- app = _handling_errors(app)
- return app
-
- @wrappers.Request.application
- def _serve_plugin_entry(self, request):
- """Serves a HTML for iframed plugin entry point.
+ db_uri = None
+ db_connection_provider = None
+ if isinstance(
+ deprecated_multiplexer,
+ (db_import_multiplexer.DbImportMultiplexer, _DbModeMultiplexer),
+ ):
+ db_uri = deprecated_multiplexer.db_uri
+ db_connection_provider = deprecated_multiplexer.db_connection_provider
+ plugin_name_to_instance = {}
+ context = base_plugin.TBContext(
+ data_provider=data_provider,
+ db_connection_provider=db_connection_provider,
+ db_uri=db_uri,
+ flags=flags,
+ logdir=flags.logdir,
+ multiplexer=deprecated_multiplexer,
+ assets_zip_provider=assets_zip_provider,
+ plugin_name_to_instance=plugin_name_to_instance,
+ window_title=flags.window_title,
+ )
+ tbplugins = []
+ for plugin_spec in plugins:
+ loader = make_plugin_loader(plugin_spec)
+ plugin = loader.load(context)
+ if plugin is None:
+ continue
+ tbplugins.append(plugin)
+ plugin_name_to_instance[plugin.plugin_name] = plugin
+ return TensorBoardWSGI(tbplugins, flags.path_prefix)
- Args:
- request: The werkzeug.Request object.
- Returns:
- A werkzeug.Response object.
- """
- name = request.args.get('name')
- plugins = [
- plugin for plugin in self._plugins if plugin.plugin_name == name]
-
- if not plugins:
- raise errors.NotFoundError(name)
-
- if len(plugins) > 1:
- # Technically is not possible as plugin names are unique and is checked
- # by the check on __init__.
- reason = (
- 'Plugin invariant error: multiple plugins with name '
- '{name} found: {list}'
- ).format(name=name, list=plugins)
- raise AssertionError(reason)
-
- plugin = plugins[0]
- module_path = plugin.frontend_metadata().es_module_path
- if not module_path:
- return http_util.Respond(
- request, 'Plugin is not module loadable', 'text/plain', code=400)
-
- # non-self origin is blocked by CSP but this is a good invariant checking.
- if urlparse.urlparse(module_path).netloc:
- raise ValueError('Expected es_module_path to be non-absolute path')
-
- module_json = json.dumps('.' + module_path)
- script_content = 'import({}).then((m) => void m.render());'.format(
- module_json)
- digest = hashlib.sha256(script_content.encode('utf-8')).digest()
- script_sha = base64.b64encode(digest).decode('ascii')
-
- html = textwrap.dedent("""
+class TensorBoardWSGI(object):
+ """The TensorBoard WSGI app that delegates to a set of TBPlugin."""
+
+ def __init__(self, plugins, path_prefix=""):
+ """Constructs TensorBoardWSGI instance.
+
+ Args:
+ plugins: A list of base_plugin.TBPlugin subclass instances.
+ flags: An argparse.Namespace containing TensorBoard CLI flags.
+
+ Returns:
+ A WSGI application for the set of all TBPlugin instances.
+
+ Raises:
+ ValueError: If some plugin has no plugin_name
+ ValueError: If some plugin has an invalid plugin_name (plugin
+ names must only contain [A-Za-z0-9_.-])
+ ValueError: If two plugins have the same plugin_name
+ ValueError: If some plugin handles a route that does not start
+ with a slash
+
+ :type plugins: list[base_plugin.TBPlugin]
+ """
+ self._plugins = plugins
+ self._path_prefix = path_prefix
+ if self._path_prefix.endswith("/"):
+ # Should have been fixed by `fix_flags`.
+ raise ValueError(
+ "Trailing slash in path prefix: %r" % self._path_prefix
+ )
+
+ self.exact_routes = {
+ # TODO(@chihuahua): Delete this RPC once we have skylark rules that
+ # obviate the need for the frontend to determine which plugins are
+ # active.
+ DATA_PREFIX + PLUGINS_LISTING_ROUTE: self._serve_plugins_listing,
+ DATA_PREFIX + PLUGIN_ENTRY_ROUTE: self._serve_plugin_entry,
+ }
+ unordered_prefix_routes = {}
+
+ # Serve the routes from the registered plugins using their name as the route
+ # prefix. For example if plugin z has two routes /a and /b, they will be
+ # served as /data/plugin/z/a and /data/plugin/z/b.
+ plugin_names_encountered = set()
+ for plugin in self._plugins:
+ if plugin.plugin_name is None:
+ raise ValueError("Plugin %s has no plugin_name" % plugin)
+ if not _VALID_PLUGIN_RE.match(plugin.plugin_name):
+ raise ValueError(
+ "Plugin %s has invalid name %r"
+ % (plugin, plugin.plugin_name)
+ )
+ if plugin.plugin_name in plugin_names_encountered:
+ raise ValueError(
+ "Duplicate plugins for name %s" % plugin.plugin_name
+ )
+ plugin_names_encountered.add(plugin.plugin_name)
+
+ try:
+ plugin_apps = plugin.get_plugin_apps()
+ except Exception as e: # pylint: disable=broad-except
+ if (
+ type(plugin) is core_plugin.CorePlugin
+ ): # pylint: disable=unidiomatic-typecheck
+ raise
+ logger.warn(
+ "Plugin %s failed. Exception: %s",
+ plugin.plugin_name,
+ str(e),
+ )
+ continue
+ for route, app in plugin_apps.items():
+ if not route.startswith("/"):
+ raise ValueError(
+ "Plugin named %r handles invalid route %r: "
+ "route does not start with a slash"
+ % (plugin.plugin_name, route)
+ )
+ if (
+ type(plugin) is core_plugin.CorePlugin
+ ): # pylint: disable=unidiomatic-typecheck
+ path = route
+ else:
+ path = (
+ DATA_PREFIX
+ + PLUGIN_PREFIX
+ + "/"
+ + plugin.plugin_name
+ + route
+ )
+
+ if path.endswith("/*"):
+ # Note we remove the '*' but leave the slash in place.
+ path = path[:-1]
+ if "*" in path:
+ # note we re-add the removed * in the format string
+ raise ValueError(
+ "Plugin %r handles invalid route '%s*': Only "
+ "trailing wildcards are supported "
+ "(i.e., `/.../*`)" % (plugin.plugin_name, path)
+ )
+ unordered_prefix_routes[path] = app
+ else:
+ if "*" in path:
+ raise ValueError(
+ "Plugin %r handles invalid route %r: Only "
+ "trailing wildcards are supported "
+ "(i.e., `/.../*`)" % (plugin.plugin_name, path)
+ )
+ self.exact_routes[path] = app
+
+ # Wildcard routes will be checked in the given order, so we sort them
+ # longest to shortest so that a more specific route will take precedence
+ # over a more general one (e.g., a catchall route `/*` should come last).
+ self.prefix_routes = collections.OrderedDict(
+ sorted(
+ six.iteritems(unordered_prefix_routes),
+ key=lambda x: len(x[0]),
+ reverse=True,
+ )
+ )
+
+ self._app = self._create_wsgi_app()
+
+ def _create_wsgi_app(self):
+ """Apply middleware to create the final WSGI app."""
+ app = self._route_request
+ app = empty_path_redirect.EmptyPathRedirectMiddleware(app)
+ app = experiment_id.ExperimentIdMiddleware(app)
+ app = path_prefix.PathPrefixMiddleware(app, self._path_prefix)
+ app = security_validator.SecurityValidatorMiddleware(app)
+ app = _handling_errors(app)
+ return app
+
+ @wrappers.Request.application
+ def _serve_plugin_entry(self, request):
+ """Serves a HTML for iframed plugin entry point.
+
+ Args:
+ request: The werkzeug.Request object.
+
+ Returns:
+ A werkzeug.Response object.
+ """
+ name = request.args.get("name")
+ plugins = [
+ plugin for plugin in self._plugins if plugin.plugin_name == name
+ ]
+
+ if not plugins:
+ raise errors.NotFoundError(name)
+
+ if len(plugins) > 1:
+ # Technically is not possible as plugin names are unique and is checked
+ # by the check on __init__.
+ reason = (
+ "Plugin invariant error: multiple plugins with name "
+ "{name} found: {list}"
+ ).format(name=name, list=plugins)
+ raise AssertionError(reason)
+
+ plugin = plugins[0]
+ module_path = plugin.frontend_metadata().es_module_path
+ if not module_path:
+ return http_util.Respond(
+ request, "Plugin is not module loadable", "text/plain", code=400
+ )
+
+ # non-self origin is blocked by CSP but this is a good invariant checking.
+ if urlparse.urlparse(module_path).netloc:
+ raise ValueError("Expected es_module_path to be non-absolute path")
+
+ module_json = json.dumps("." + module_path)
+ script_content = "import({}).then((m) => void m.render());".format(
+ module_json
+ )
+ digest = hashlib.sha256(script_content.encode("utf-8")).digest()
+ script_sha = base64.b64encode(digest).decode("ascii")
+
+ html = textwrap.dedent(
+ """
- """).format(name=name, script_content=script_content)
- return http_util.Respond(
- request,
- html,
- 'text/html',
- csp_scripts_sha256s=[script_sha],
- )
+ """
+ ).format(name=name, script_content=script_content)
+ return http_util.Respond(
+ request, html, "text/html", csp_scripts_sha256s=[script_sha],
+ )
- @wrappers.Request.application
- def _serve_plugins_listing(self, request):
- """Serves an object mapping plugin name to whether it is enabled.
+ @wrappers.Request.application
+ def _serve_plugins_listing(self, request):
+ """Serves an object mapping plugin name to whether it is enabled.
+
+ Args:
+ request: The werkzeug.Request object.
+
+ Returns:
+ A werkzeug.Response object.
+ """
+ response = collections.OrderedDict()
+ for plugin in self._plugins:
+ if (
+ type(plugin) is core_plugin.CorePlugin
+ ): # pylint: disable=unidiomatic-typecheck
+ # This plugin's existence is a backend implementation detail.
+ continue
+ start = time.time()
+ is_active = plugin.is_active()
+ elapsed = time.time() - start
+ logger.info(
+ "Plugin listing: is_active() for %s took %0.3f seconds",
+ plugin.plugin_name,
+ elapsed,
+ )
+
+ plugin_metadata = plugin.frontend_metadata()
+ output_metadata = {
+ "disable_reload": plugin_metadata.disable_reload,
+ "enabled": is_active,
+ # loading_mechanism set below
+ "remove_dom": plugin_metadata.remove_dom,
+ # tab_name set below
+ }
+
+ if plugin_metadata.tab_name is not None:
+ output_metadata["tab_name"] = plugin_metadata.tab_name
+ else:
+ output_metadata["tab_name"] = plugin.plugin_name
+
+ es_module_handler = plugin_metadata.es_module_path
+ element_name = plugin_metadata.element_name
+ is_ng_component = plugin_metadata.is_ng_component
+ if is_ng_component:
+ if element_name is not None:
+ raise ValueError(
+ "Plugin %r declared as both Angular built-in and legacy"
+ % plugin.plugin_name
+ )
+ if es_module_handler is not None:
+ raise ValueError(
+ "Plugin %r declared as both Angular built-in and iframed"
+ % plugin.plugin_name
+ )
+ loading_mechanism = {
+ "type": "NG_COMPONENT",
+ }
+ elif element_name is not None and es_module_handler is not None:
+ logger.error(
+ "Plugin %r declared as both legacy and iframed; skipping",
+ plugin.plugin_name,
+ )
+ continue
+ elif element_name is not None and es_module_handler is None:
+ loading_mechanism = {
+ "type": "CUSTOM_ELEMENT",
+ "element_name": element_name,
+ }
+ elif element_name is None and es_module_handler is not None:
+ loading_mechanism = {
+ "type": "IFRAME",
+ "module_path": "".join(
+ [
+ request.script_root,
+ DATA_PREFIX,
+ PLUGIN_PREFIX,
+ "/",
+ plugin.plugin_name,
+ es_module_handler,
+ ]
+ ),
+ }
+ else:
+ # As a compatibility measure (for plugins that we don't
+ # control), we'll pull it from the frontend registry for now.
+ loading_mechanism = {
+ "type": "NONE",
+ }
+ output_metadata["loading_mechanism"] = loading_mechanism
+
+ response[plugin.plugin_name] = output_metadata
+ return http_util.Respond(request, response, "application/json")
+
+ def __call__(self, environ, start_response):
+ """Central entry point for the TensorBoard application.
+
+ This __call__ method conforms to the WSGI spec, so that instances of this
+ class are WSGI applications.
+
+ Args:
+ environ: See WSGI spec (PEP 3333).
+ start_response: See WSGI spec (PEP 3333).
+ """
+ return self._app(environ, start_response)
+
+ def _route_request(self, environ, start_response):
+ """Delegate an incoming request to sub-applications.
+
+ This method supports strict string matching and wildcard routes of a
+ single path component, such as `/foo/*`. Other routing patterns,
+ like regular expressions, are not supported.
+
+ This is the main TensorBoard entry point before middleware is
+ applied. (See `_create_wsgi_app`.)
+
+ Args:
+ environ: See WSGI spec (PEP 3333).
+ start_response: See WSGI spec (PEP 3333).
+ """
+
+ request = wrappers.Request(environ)
+ parsed_url = urlparse.urlparse(request.path)
+ clean_path = _clean_path(parsed_url.path)
+
+ # pylint: disable=too-many-function-args
+ if clean_path in self.exact_routes:
+ return self.exact_routes[clean_path](environ, start_response)
+ else:
+ for path_prefix in self.prefix_routes:
+ if clean_path.startswith(path_prefix):
+ return self.prefix_routes[path_prefix](
+ environ, start_response
+ )
- Args:
- request: The werkzeug.Request object.
+ logger.warn("path %s not found, sending 404", clean_path)
+ return http_util.Respond(
+ request, "Not found", "text/plain", code=404
+ )(environ, start_response)
+ # pylint: enable=too-many-function-args
- Returns:
- A werkzeug.Response object.
- """
- response = collections.OrderedDict()
- for plugin in self._plugins:
- if type(plugin) is core_plugin.CorePlugin: # pylint: disable=unidiomatic-typecheck
- # This plugin's existence is a backend implementation detail.
- continue
- start = time.time()
- is_active = plugin.is_active()
- elapsed = time.time() - start
- logger.info(
- 'Plugin listing: is_active() for %s took %0.3f seconds',
- plugin.plugin_name, elapsed)
-
- plugin_metadata = plugin.frontend_metadata()
- output_metadata = {
- 'disable_reload': plugin_metadata.disable_reload,
- 'enabled': is_active,
- # loading_mechanism set below
- 'remove_dom': plugin_metadata.remove_dom,
- # tab_name set below
- }
-
- if plugin_metadata.tab_name is not None:
- output_metadata['tab_name'] = plugin_metadata.tab_name
- else:
- output_metadata['tab_name'] = plugin.plugin_name
-
- es_module_handler = plugin_metadata.es_module_path
- element_name = plugin_metadata.element_name
- is_ng_component = plugin_metadata.is_ng_component
- if is_ng_component:
- if element_name is not None:
- raise ValueError(
- 'Plugin %r declared as both Angular built-in and legacy' %
- plugin.plugin_name)
- if es_module_handler is not None:
- raise ValueError(
- 'Plugin %r declared as both Angular built-in and iframed' %
- plugin.plugin_name)
- loading_mechanism = {
- 'type': 'NG_COMPONENT',
- }
- elif element_name is not None and es_module_handler is not None:
- logger.error(
- 'Plugin %r declared as both legacy and iframed; skipping',
- plugin.plugin_name,
- )
- continue
- elif element_name is not None and es_module_handler is None:
- loading_mechanism = {
- 'type': 'CUSTOM_ELEMENT',
- 'element_name': element_name,
- }
- elif element_name is None and es_module_handler is not None:
- loading_mechanism = {
- 'type': 'IFRAME',
- 'module_path': ''.join([
- request.script_root, DATA_PREFIX, PLUGIN_PREFIX, '/',
- plugin.plugin_name, es_module_handler,
- ]),
- }
- else:
- # As a compatibility measure (for plugins that we don't
- # control), we'll pull it from the frontend registry for now.
- loading_mechanism = {
- 'type': 'NONE',
- }
- output_metadata['loading_mechanism'] = loading_mechanism
- response[plugin.plugin_name] = output_metadata
- return http_util.Respond(request, response, 'application/json')
+def parse_event_files_spec(logdir_spec):
+ """Parses `logdir_spec` into a map from paths to run group names.
- def __call__(self, environ, start_response):
- """Central entry point for the TensorBoard application.
+ The `--logdir_spec` flag format is a comma-separated list of path
+ specifications. A path spec looks like 'group_name:/path/to/directory' or
+ '/path/to/directory'; in the latter case, the group is unnamed. Group names
+ cannot start with a forward slash: /foo:bar/baz will be interpreted as a spec
+ with no name and path '/foo:bar/baz'.
- This __call__ method conforms to the WSGI spec, so that instances of this
- class are WSGI applications.
+ Globs are not supported.
Args:
- environ: See WSGI spec (PEP 3333).
- start_response: See WSGI spec (PEP 3333).
+ logdir: A comma-separated list of run specifications.
+ Returns:
+ A dict mapping directory paths to names like {'/path/to/directory': 'name'}.
+ Groups without an explicit name are named after their path. If logdir is
+ None, returns an empty dict, which is helpful for testing things that don't
+ require any valid runs.
"""
- return self._app(environ, start_response)
+ files = {}
+ if logdir_spec is None:
+ return files
+ # Make sure keeping consistent with ParseURI in core/lib/io/path.cc
+ uri_pattern = re.compile("[a-zA-Z][0-9a-zA-Z.]*://.*")
+ for specification in logdir_spec.split(","):
+ # Check if the spec contains group. A spec start with xyz:// is regarded as
+ # URI path spec instead of group spec. If the spec looks like /foo:bar/baz,
+ # then we assume it's a path with a colon. If the spec looks like
+ # [a-zA-z]:\foo then we assume its a Windows path and not a single letter
+ # group
+ if (
+ uri_pattern.match(specification) is None
+ and ":" in specification
+ and specification[0] != "/"
+ and not os.path.splitdrive(specification)[0]
+ ):
+ # We split at most once so run_name:/path:with/a/colon will work.
+ run_name, _, path = specification.partition(":")
+ else:
+ run_name = None
+ path = specification
+ if uri_pattern.match(path) is None:
+ path = os.path.realpath(os.path.expanduser(path))
+ files[path] = run_name
+ return files
- def _route_request(self, environ, start_response):
- """Delegate an incoming request to sub-applications.
- This method supports strict string matching and wildcard routes of a
- single path component, such as `/foo/*`. Other routing patterns,
- like regular expressions, are not supported.
+def start_reloading_multiplexer(
+ multiplexer, path_to_run, load_interval, reload_task
+):
+ """Starts automatically reloading the given multiplexer.
- This is the main TensorBoard entry point before middleware is
- applied. (See `_create_wsgi_app`.)
+ If `load_interval` is positive, the thread will reload the multiplexer
+ by calling `ReloadMultiplexer` every `load_interval` seconds, starting
+ immediately. Otherwise, reloads the multiplexer once and never again.
Args:
- environ: See WSGI spec (PEP 3333).
- start_response: See WSGI spec (PEP 3333).
- """
-
- request = wrappers.Request(environ)
- parsed_url = urlparse.urlparse(request.path)
- clean_path = _clean_path(parsed_url.path)
+ multiplexer: The `EventMultiplexer` to add runs to and reload.
+ path_to_run: A dict mapping from paths to run names, where `None` as the run
+ name is interpreted as a run name equal to the path.
+ load_interval: An integer greater than or equal to 0. If positive, how many
+ seconds to wait after one load before starting the next load. Otherwise,
+ reloads the multiplexer once and never again (no continuous reloading).
+ reload_task: Indicates the type of background task to reload with.
- # pylint: disable=too-many-function-args
- if clean_path in self.exact_routes:
- return self.exact_routes[clean_path](environ, start_response)
+ Raises:
+ ValueError: If `load_interval` is negative.
+ """
+ if load_interval < 0:
+ raise ValueError("load_interval is negative: %d" % load_interval)
+
+ def _reload():
+ while True:
+ start = time.time()
+ logger.info("TensorBoard reload process beginning")
+ for path, name in six.iteritems(path_to_run):
+ multiplexer.AddRunsFromDirectory(path, name)
+ logger.info(
+ "TensorBoard reload process: Reload the whole Multiplexer"
+ )
+ multiplexer.Reload()
+ duration = time.time() - start
+ logger.info(
+ "TensorBoard done reloading. Load took %0.3f secs", duration
+ )
+ if load_interval == 0:
+ # Only load the multiplexer once. Do not continuously reload.
+ break
+ time.sleep(load_interval)
+
+ if reload_task == "process":
+ logger.info("Launching reload in a child process")
+ import multiprocessing
+
+ process = multiprocessing.Process(target=_reload, name="Reloader")
+ # Best-effort cleanup; on exit, the main TB parent process will attempt to
+ # kill all its daemonic children.
+ process.daemon = True
+ process.start()
+ elif reload_task in ("thread", "auto"):
+ logger.info("Launching reload in a daemon thread")
+ thread = threading.Thread(target=_reload, name="Reloader")
+ # Make this a daemon thread, which won't block TB from exiting.
+ thread.daemon = True
+ thread.start()
+ elif reload_task == "blocking":
+ if load_interval != 0:
+ raise ValueError(
+ "blocking reload only allowed with load_interval=0"
+ )
+ _reload()
else:
- for path_prefix in self.prefix_routes:
- if clean_path.startswith(path_prefix):
- return self.prefix_routes[path_prefix](environ, start_response)
+ raise ValueError("unrecognized reload_task: %s" % reload_task)
- logger.warn('path %s not found, sending 404', clean_path)
- return http_util.Respond(request, 'Not found', 'text/plain', code=404)(
- environ, start_response)
- # pylint: enable=too-many-function-args
+def create_sqlite_connection_provider(db_uri):
+ """Returns function that returns SQLite Connection objects.
-def parse_event_files_spec(logdir_spec):
- """Parses `logdir_spec` into a map from paths to run group names.
-
- The `--logdir_spec` flag format is a comma-separated list of path
- specifications. A path spec looks like 'group_name:/path/to/directory' or
- '/path/to/directory'; in the latter case, the group is unnamed. Group names
- cannot start with a forward slash: /foo:bar/baz will be interpreted as a spec
- with no name and path '/foo:bar/baz'.
-
- Globs are not supported.
-
- Args:
- logdir: A comma-separated list of run specifications.
- Returns:
- A dict mapping directory paths to names like {'/path/to/directory': 'name'}.
- Groups without an explicit name are named after their path. If logdir is
- None, returns an empty dict, which is helpful for testing things that don't
- require any valid runs.
- """
- files = {}
- if logdir_spec is None:
- return files
- # Make sure keeping consistent with ParseURI in core/lib/io/path.cc
- uri_pattern = re.compile('[a-zA-Z][0-9a-zA-Z.]*://.*')
- for specification in logdir_spec.split(','):
- # Check if the spec contains group. A spec start with xyz:// is regarded as
- # URI path spec instead of group spec. If the spec looks like /foo:bar/baz,
- # then we assume it's a path with a colon. If the spec looks like
- # [a-zA-z]:\foo then we assume its a Windows path and not a single letter
- # group
- if (uri_pattern.match(specification) is None and ':' in specification and
- specification[0] != '/' and not os.path.splitdrive(specification)[0]):
- # We split at most once so run_name:/path:with/a/colon will work.
- run_name, _, path = specification.partition(':')
- else:
- run_name = None
- path = specification
- if uri_pattern.match(path) is None:
- path = os.path.realpath(os.path.expanduser(path))
- files[path] = run_name
- return files
-
-
-def start_reloading_multiplexer(multiplexer, path_to_run, load_interval,
- reload_task):
- """Starts automatically reloading the given multiplexer.
-
- If `load_interval` is positive, the thread will reload the multiplexer
- by calling `ReloadMultiplexer` every `load_interval` seconds, starting
- immediately. Otherwise, reloads the multiplexer once and never again.
-
- Args:
- multiplexer: The `EventMultiplexer` to add runs to and reload.
- path_to_run: A dict mapping from paths to run names, where `None` as the run
- name is interpreted as a run name equal to the path.
- load_interval: An integer greater than or equal to 0. If positive, how many
- seconds to wait after one load before starting the next load. Otherwise,
- reloads the multiplexer once and never again (no continuous reloading).
- reload_task: Indicates the type of background task to reload with.
-
- Raises:
- ValueError: If `load_interval` is negative.
- """
- if load_interval < 0:
- raise ValueError('load_interval is negative: %d' % load_interval)
-
- def _reload():
- while True:
- start = time.time()
- logger.info('TensorBoard reload process beginning')
- for path, name in six.iteritems(path_to_run):
- multiplexer.AddRunsFromDirectory(path, name)
- logger.info('TensorBoard reload process: Reload the whole Multiplexer')
- multiplexer.Reload()
- duration = time.time() - start
- logger.info('TensorBoard done reloading. Load took %0.3f secs', duration)
- if load_interval == 0:
- # Only load the multiplexer once. Do not continuously reload.
- break
- time.sleep(load_interval)
-
- if reload_task == 'process':
- logger.info('Launching reload in a child process')
- import multiprocessing
- process = multiprocessing.Process(target=_reload, name='Reloader')
- # Best-effort cleanup; on exit, the main TB parent process will attempt to
- # kill all its daemonic children.
- process.daemon = True
- process.start()
- elif reload_task in ('thread', 'auto'):
- logger.info('Launching reload in a daemon thread')
- thread = threading.Thread(target=_reload, name='Reloader')
- # Make this a daemon thread, which won't block TB from exiting.
- thread.daemon = True
- thread.start()
- elif reload_task == 'blocking':
- if load_interval != 0:
- raise ValueError('blocking reload only allowed with load_interval=0')
- _reload()
- else:
- raise ValueError('unrecognized reload_task: %s' % reload_task)
+ Args:
+ db_uri: A string URI expressing the DB file, e.g. "sqlite:~/tb.db".
+
+ Returns:
+ A function that returns a new PEP-249 DB Connection, which must be closed,
+ each time it is called.
-def create_sqlite_connection_provider(db_uri):
- """Returns function that returns SQLite Connection objects.
-
- Args:
- db_uri: A string URI expressing the DB file, e.g. "sqlite:~/tb.db".
-
- Returns:
- A function that returns a new PEP-249 DB Connection, which must be closed,
- each time it is called.
-
- Raises:
- ValueError: If db_uri is not a valid sqlite file URI.
- """
- uri = urlparse.urlparse(db_uri)
- if uri.scheme != 'sqlite':
- raise ValueError('Only sqlite DB URIs are supported: ' + db_uri)
- if uri.netloc:
- raise ValueError('Can not connect to SQLite over network: ' + db_uri)
- if uri.path == ':memory:':
- raise ValueError('Memory mode SQLite not supported: ' + db_uri)
- path = os.path.expanduser(uri.path)
- params = _get_connect_params(uri.query)
- # TODO(@jart): Add thread-local pooling.
- return lambda: sqlite3.connect(path, **params)
+ Raises:
+ ValueError: If db_uri is not a valid sqlite file URI.
+ """
+ uri = urlparse.urlparse(db_uri)
+ if uri.scheme != "sqlite":
+ raise ValueError("Only sqlite DB URIs are supported: " + db_uri)
+ if uri.netloc:
+ raise ValueError("Can not connect to SQLite over network: " + db_uri)
+ if uri.path == ":memory:":
+ raise ValueError("Memory mode SQLite not supported: " + db_uri)
+ path = os.path.expanduser(uri.path)
+ params = _get_connect_params(uri.query)
+ # TODO(@jart): Add thread-local pooling.
+ return lambda: sqlite3.connect(path, **params)
def _get_connect_params(query):
- params = urlparse.parse_qs(query)
- if any(len(v) > 2 for v in params.values()):
- raise ValueError('DB URI params list has duplicate keys: ' + query)
- return {k: json.loads(v[0]) for k, v in params.items()}
+ params = urlparse.parse_qs(query)
+ if any(len(v) > 2 for v in params.values()):
+ raise ValueError("DB URI params list has duplicate keys: " + query)
+ return {k: json.loads(v[0]) for k, v in params.items()}
def _clean_path(path):
- """Removes a trailing slash from a non-root path.
+ """Removes a trailing slash from a non-root path.
- Arguments:
- path: The path of a request.
+ Arguments:
+ path: The path of a request.
- Returns:
- The route to use to serve the request.
- """
- if path != '/' and path.endswith('/'):
- return path[:-1]
- return path
+ Returns:
+ The route to use to serve the request.
+ """
+ if path != "/" and path.endswith("/"):
+ return path[:-1]
+ return path
def _get_event_file_active_filter(flags):
- """Returns a predicate for whether an event file load timestamp is active.
-
- Returns:
- A predicate function accepting a single UNIX timestamp float argument, or
- None if multi-file loading is not enabled.
- """
- if not flags.reload_multifile:
- return None
- inactive_secs = flags.reload_multifile_inactive_secs
- if inactive_secs == 0:
- return None
- if inactive_secs < 0:
- return lambda timestamp: True
- return lambda timestamp: timestamp + inactive_secs >= time.time()
+ """Returns a predicate for whether an event file load timestamp is active.
+
+ Returns:
+ A predicate function accepting a single UNIX timestamp float argument, or
+ None if multi-file loading is not enabled.
+ """
+ if not flags.reload_multifile:
+ return None
+ inactive_secs = flags.reload_multifile_inactive_secs
+ if inactive_secs == 0:
+ return None
+ if inactive_secs < 0:
+ return lambda timestamp: True
+ return lambda timestamp: timestamp + inactive_secs >= time.time()
-class _DbModeMultiplexer(event_multiplexer.EventMultiplexer):
- """Shim EventMultiplexer to use when in read-only DB mode.
- In read-only DB mode, the EventMultiplexer is nonfunctional - there is no
- logdir to reload, and the data is all exposed via SQL. This class represents
- the do-nothing EventMultiplexer for that purpose, which serves only as a
- conduit for DB-related parameters.
+class _DbModeMultiplexer(event_multiplexer.EventMultiplexer):
+ """Shim EventMultiplexer to use when in read-only DB mode.
- The load APIs raise exceptions if called, and the read APIs always
- return empty results.
- """
- def __init__(self, db_uri, db_connection_provider):
- """Constructor for `_DbModeMultiplexer`.
+ In read-only DB mode, the EventMultiplexer is nonfunctional - there is no
+ logdir to reload, and the data is all exposed via SQL. This class represents
+ the do-nothing EventMultiplexer for that purpose, which serves only as a
+ conduit for DB-related parameters.
- Args:
- db_uri: A URI to the database file in use.
- db_connection_provider: Provider function for creating a DB connection.
+ The load APIs raise exceptions if called, and the read APIs always
+ return empty results.
"""
- logger.info('_DbModeMultiplexer initializing for %s', db_uri)
- super(_DbModeMultiplexer, self).__init__()
- self.db_uri = db_uri
- self.db_connection_provider = db_connection_provider
- logger.info('_DbModeMultiplexer done initializing')
- def AddRun(self, path, name=None):
- """Unsupported."""
- raise NotImplementedError()
+ def __init__(self, db_uri, db_connection_provider):
+ """Constructor for `_DbModeMultiplexer`.
+
+ Args:
+ db_uri: A URI to the database file in use.
+ db_connection_provider: Provider function for creating a DB connection.
+ """
+ logger.info("_DbModeMultiplexer initializing for %s", db_uri)
+ super(_DbModeMultiplexer, self).__init__()
+ self.db_uri = db_uri
+ self.db_connection_provider = db_connection_provider
+ logger.info("_DbModeMultiplexer done initializing")
+
+ def AddRun(self, path, name=None):
+ """Unsupported."""
+ raise NotImplementedError()
- def AddRunsFromDirectory(self, path, name=None):
- """Unsupported."""
- raise NotImplementedError()
+ def AddRunsFromDirectory(self, path, name=None):
+ """Unsupported."""
+ raise NotImplementedError()
- def Reload(self):
- """Unsupported."""
- raise NotImplementedError()
+ def Reload(self):
+ """Unsupported."""
+ raise NotImplementedError()
def make_plugin_loader(plugin_spec):
- """Returns a plugin loader for the given plugin.
-
- Args:
- plugin_spec: A TBPlugin subclass, or a TBLoader instance or subclass.
-
- Returns:
- A TBLoader for the given plugin.
-
- :type plugin_spec:
- Type[base_plugin.TBPlugin] | Type[base_plugin.TBLoader] |
- base_plugin.TBLoader
- :rtype: base_plugin.TBLoader
- """
- if isinstance(plugin_spec, base_plugin.TBLoader):
- return plugin_spec
- if isinstance(plugin_spec, type):
- if issubclass(plugin_spec, base_plugin.TBLoader):
- return plugin_spec()
- if issubclass(plugin_spec, base_plugin.TBPlugin):
- return base_plugin.BasicLoader(plugin_spec)
- raise TypeError("Not a TBLoader or TBPlugin subclass: %r" % (plugin_spec,))
+ """Returns a plugin loader for the given plugin.
+
+ Args:
+ plugin_spec: A TBPlugin subclass, or a TBLoader instance or subclass.
+
+ Returns:
+ A TBLoader for the given plugin.
+
+ :type plugin_spec:
+ Type[base_plugin.TBPlugin] | Type[base_plugin.TBLoader] |
+ base_plugin.TBLoader
+ :rtype: base_plugin.TBLoader
+ """
+ if isinstance(plugin_spec, base_plugin.TBLoader):
+ return plugin_spec
+ if isinstance(plugin_spec, type):
+ if issubclass(plugin_spec, base_plugin.TBLoader):
+ return plugin_spec()
+ if issubclass(plugin_spec, base_plugin.TBPlugin):
+ return base_plugin.BasicLoader(plugin_spec)
+ raise TypeError("Not a TBLoader or TBPlugin subclass: %r" % (plugin_spec,))
diff --git a/tensorboard/backend/application_test.py b/tensorboard/backend/application_test.py
index 111d135a54..36cf656495 100644
--- a/tensorboard/backend/application_test.py
+++ b/tensorboard/backend/application_test.py
@@ -32,10 +32,10 @@
import six
try:
- # python version >= 3.3
- from unittest import mock
+ # python version >= 3.3
+ from unittest import mock
except ImportError:
- import mock # pylint: disable=unused-import
+ import mock # pylint: disable=unused-import
from werkzeug import test as werkzeug_test
from werkzeug import wrappers
@@ -44,865 +44,920 @@
from tensorboard import plugin_util
from tensorboard import test as tb_test
from tensorboard.backend import application
-from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer
+from tensorboard.backend.event_processing import (
+ plugin_event_multiplexer as event_multiplexer,
+)
from tensorboard.plugins import base_plugin
class FakeFlags(object):
- def __init__(
- self,
- logdir,
- logdir_spec='',
- purge_orphaned_data=True,
- reload_interval=60,
- samples_per_plugin='',
- max_reload_threads=1,
- reload_task='auto',
- db='',
- db_import=False,
- window_title='',
- path_prefix='',
- reload_multifile=False,
- reload_multifile_inactive_secs=4000,
- generic_data='auto'):
- self.logdir = logdir
- self.logdir_spec = logdir_spec
- self.purge_orphaned_data = purge_orphaned_data
- self.reload_interval = reload_interval
- self.samples_per_plugin = samples_per_plugin
- self.max_reload_threads = max_reload_threads
- self.reload_task = reload_task
- self.db = db
- self.db_import = db_import
- self.window_title = window_title
- self.path_prefix = path_prefix
- self.reload_multifile = reload_multifile
- self.reload_multifile_inactive_secs = reload_multifile_inactive_secs
- self.generic_data = generic_data
+ def __init__(
+ self,
+ logdir,
+ logdir_spec="",
+ purge_orphaned_data=True,
+ reload_interval=60,
+ samples_per_plugin="",
+ max_reload_threads=1,
+ reload_task="auto",
+ db="",
+ db_import=False,
+ window_title="",
+ path_prefix="",
+ reload_multifile=False,
+ reload_multifile_inactive_secs=4000,
+ generic_data="auto",
+ ):
+ self.logdir = logdir
+ self.logdir_spec = logdir_spec
+ self.purge_orphaned_data = purge_orphaned_data
+ self.reload_interval = reload_interval
+ self.samples_per_plugin = samples_per_plugin
+ self.max_reload_threads = max_reload_threads
+ self.reload_task = reload_task
+ self.db = db
+ self.db_import = db_import
+ self.window_title = window_title
+ self.path_prefix = path_prefix
+ self.reload_multifile = reload_multifile
+ self.reload_multifile_inactive_secs = reload_multifile_inactive_secs
+ self.generic_data = generic_data
class FakePlugin(base_plugin.TBPlugin):
- """A plugin with no functionality."""
-
- def __init__(self,
- context=None,
- plugin_name='foo',
- is_active_value=True,
- routes_mapping={},
- element_name_value=None,
- es_module_path_value=None,
- is_ng_component=False,
- construction_callback=None):
- """Constructs a fake plugin.
-
- Args:
- context: The TBContext magic container. Contains properties that are
- potentially useful to this plugin.
- plugin_name: The name of this plugin.
- is_active_value: Whether the plugin is active.
- routes_mapping: A dictionary mapping from route (string URL path) to the
- method called when a user issues a request to that route.
- es_module_path_value: An optional string value that indicates a frontend
- module entry to the plugin. Must be one of the keys of routes_mapping.
- is_ng_component: Whether this plugin is of the built-in Angular-based
- type.
- construction_callback: An optional callback called when the plugin is
- constructed. The callback is passed the TBContext.
- """
- self.plugin_name = plugin_name
- self._is_active_value = is_active_value
- self._routes_mapping = routes_mapping
- self._element_name_value = element_name_value
- self._es_module_path_value = es_module_path_value
- self._is_ng_component = is_ng_component
-
- if construction_callback:
- construction_callback(context)
-
- def get_plugin_apps(self):
- """Returns a mapping from routes to handlers offered by this plugin.
-
- Returns:
- A dictionary mapping from routes to handlers offered by this plugin.
- """
- return self._routes_mapping
-
- def is_active(self):
- """Returns whether this plugin is active.
-
- Returns:
- A boolean. Whether this plugin is active.
- """
- return self._is_active_value
-
- def frontend_metadata(self):
- return base_plugin.FrontendMetadata(
- element_name=self._element_name_value,
- es_module_path=self._es_module_path_value,
- is_ng_component=self._is_ng_component,
- )
+ """A plugin with no functionality."""
+
+ def __init__(
+ self,
+ context=None,
+ plugin_name="foo",
+ is_active_value=True,
+ routes_mapping={},
+ element_name_value=None,
+ es_module_path_value=None,
+ is_ng_component=False,
+ construction_callback=None,
+ ):
+ """Constructs a fake plugin.
+
+ Args:
+ context: The TBContext magic container. Contains properties that are
+ potentially useful to this plugin.
+ plugin_name: The name of this plugin.
+ is_active_value: Whether the plugin is active.
+ routes_mapping: A dictionary mapping from route (string URL path) to the
+ method called when a user issues a request to that route.
+ es_module_path_value: An optional string value that indicates a frontend
+ module entry to the plugin. Must be one of the keys of routes_mapping.
+ is_ng_component: Whether this plugin is of the built-in Angular-based
+ type.
+ construction_callback: An optional callback called when the plugin is
+ constructed. The callback is passed the TBContext.
+ """
+ self.plugin_name = plugin_name
+ self._is_active_value = is_active_value
+ self._routes_mapping = routes_mapping
+ self._element_name_value = element_name_value
+ self._es_module_path_value = es_module_path_value
+ self._is_ng_component = is_ng_component
+
+ if construction_callback:
+ construction_callback(context)
+
+ def get_plugin_apps(self):
+ """Returns a mapping from routes to handlers offered by this plugin.
+
+ Returns:
+ A dictionary mapping from routes to handlers offered by this plugin.
+ """
+ return self._routes_mapping
+
+ def is_active(self):
+ """Returns whether this plugin is active.
+
+ Returns:
+ A boolean. Whether this plugin is active.
+ """
+ return self._is_active_value
+
+ def frontend_metadata(self):
+ return base_plugin.FrontendMetadata(
+ element_name=self._element_name_value,
+ es_module_path=self._es_module_path_value,
+ is_ng_component=self._is_ng_component,
+ )
class FakePluginLoader(base_plugin.TBLoader):
- """Pass-through loader for FakePlugin with arbitrary arguments."""
+ """Pass-through loader for FakePlugin with arbitrary arguments."""
- def __init__(self, **kwargs):
- self._kwargs = kwargs
+ def __init__(self, **kwargs):
+ self._kwargs = kwargs
- def load(self, context):
- return FakePlugin(context, **self._kwargs)
+ def load(self, context):
+ return FakePlugin(context, **self._kwargs)
class HandlingErrorsTest(tb_test.TestCase):
-
- def test_successful_response_passes_through(self):
- @application._handling_errors
- @wrappers.Request.application
- def app(request):
- return wrappers.Response('All is well', 200, content_type='text/html')
-
- server = werkzeug_test.Client(app, wrappers.BaseResponse)
- response = server.get('/')
- self.assertEqual(response.get_data(), b'All is well')
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.headers.get('Content-Type'), 'text/html')
-
- def test_public_errors_serve_response(self):
- @application._handling_errors
- @wrappers.Request.application
- def app(request):
- raise errors.NotFoundError('no scalar data for run=foo, tag=bar')
-
- server = werkzeug_test.Client(app, wrappers.BaseResponse)
- response = server.get('/')
- self.assertEqual(
- response.get_data(),
- b'Not found: no scalar data for run=foo, tag=bar',
- )
- self.assertEqual(response.status_code, 404)
- self.assertStartsWith(response.headers.get('Content-Type'), 'text/plain')
-
- def test_internal_errors_propagate(self):
- @application._handling_errors
- @wrappers.Request.application
- def app(request):
- raise ValueError('something borked internally')
-
- server = werkzeug_test.Client(app, wrappers.BaseResponse)
- with self.assertRaises(ValueError) as cm:
- response = server.get('/')
- self.assertEqual(str(cm.exception), 'something borked internally')
-
- def test_passes_through_non_wsgi_args(self):
- class C(object):
- @application._handling_errors
- def __call__(self, environ, start_response):
- start_response('200 OK', [('Content-Type', 'text/html')])
- yield b'All is well'
-
- app = C()
- server = werkzeug_test.Client(app, wrappers.BaseResponse)
- response = server.get('/')
- self.assertEqual(response.get_data(), b'All is well')
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.headers.get('Content-Type'), 'text/html')
+ def test_successful_response_passes_through(self):
+ @application._handling_errors
+ @wrappers.Request.application
+ def app(request):
+ return wrappers.Response(
+ "All is well", 200, content_type="text/html"
+ )
+
+ server = werkzeug_test.Client(app, wrappers.BaseResponse)
+ response = server.get("/")
+ self.assertEqual(response.get_data(), b"All is well")
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.headers.get("Content-Type"), "text/html")
+
+ def test_public_errors_serve_response(self):
+ @application._handling_errors
+ @wrappers.Request.application
+ def app(request):
+ raise errors.NotFoundError("no scalar data for run=foo, tag=bar")
+
+ server = werkzeug_test.Client(app, wrappers.BaseResponse)
+ response = server.get("/")
+ self.assertEqual(
+ response.get_data(),
+ b"Not found: no scalar data for run=foo, tag=bar",
+ )
+ self.assertEqual(response.status_code, 404)
+ self.assertStartsWith(
+ response.headers.get("Content-Type"), "text/plain"
+ )
+
+ def test_internal_errors_propagate(self):
+ @application._handling_errors
+ @wrappers.Request.application
+ def app(request):
+ raise ValueError("something borked internally")
+
+ server = werkzeug_test.Client(app, wrappers.BaseResponse)
+ with self.assertRaises(ValueError) as cm:
+ response = server.get("/")
+ self.assertEqual(str(cm.exception), "something borked internally")
+
+ def test_passes_through_non_wsgi_args(self):
+ class C(object):
+ @application._handling_errors
+ def __call__(self, environ, start_response):
+ start_response("200 OK", [("Content-Type", "text/html")])
+ yield b"All is well"
+
+ app = C()
+ server = werkzeug_test.Client(app, wrappers.BaseResponse)
+ response = server.get("/")
+ self.assertEqual(response.get_data(), b"All is well")
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.headers.get("Content-Type"), "text/html")
class ApplicationTest(tb_test.TestCase):
- def setUp(self):
- plugins = [
- FakePlugin(plugin_name='foo'),
- FakePlugin(
- plugin_name='bar',
- is_active_value=False,
- element_name_value='tf-bar-dashboard',
- ),
- FakePlugin(
- plugin_name='baz',
- routes_mapping={
- '/esmodule': lambda req: None,
- },
- es_module_path_value='/esmodule'
- ),
- FakePlugin(
- plugin_name='qux',
- is_active_value=False,
- is_ng_component=True,
- ),
- ]
- app = application.TensorBoardWSGI(plugins)
- self.server = werkzeug_test.Client(app, wrappers.BaseResponse)
-
- def _get_json(self, path):
- response = self.server.get(path)
- self.assertEqual(200, response.status_code)
- self.assertEqual('application/json', response.headers.get('Content-Type'))
- return json.loads(response.get_data().decode('utf-8'))
-
- def testBasicStartup(self):
- """Start the server up and then shut it down immediately."""
- pass
-
- def testRequestNonexistentPage(self):
- """Request a page that doesn't exist; it should 404."""
- response = self.server.get('/asdf')
- self.assertEqual(404, response.status_code)
-
- def testPluginsListing(self):
- """Test the format of the data/plugins_listing endpoint."""
- parsed_object = self._get_json('/data/plugins_listing')
- self.assertEqual(
- parsed_object,
- {
- 'foo': {
- 'enabled': True,
- 'loading_mechanism': {'type': 'NONE'},
- 'remove_dom': False,
- 'tab_name': 'foo',
- 'disable_reload': False,
- },
- 'bar': {
- 'enabled': False,
- 'loading_mechanism': {
- 'type': 'CUSTOM_ELEMENT',
- 'element_name': 'tf-bar-dashboard',
+ def setUp(self):
+ plugins = [
+ FakePlugin(plugin_name="foo"),
+ FakePlugin(
+ plugin_name="bar",
+ is_active_value=False,
+ element_name_value="tf-bar-dashboard",
+ ),
+ FakePlugin(
+ plugin_name="baz",
+ routes_mapping={"/esmodule": lambda req: None,},
+ es_module_path_value="/esmodule",
+ ),
+ FakePlugin(
+ plugin_name="qux", is_active_value=False, is_ng_component=True,
+ ),
+ ]
+ app = application.TensorBoardWSGI(plugins)
+ self.server = werkzeug_test.Client(app, wrappers.BaseResponse)
+
+ def _get_json(self, path):
+ response = self.server.get(path)
+ self.assertEqual(200, response.status_code)
+ self.assertEqual(
+ "application/json", response.headers.get("Content-Type")
+ )
+ return json.loads(response.get_data().decode("utf-8"))
+
+ def testBasicStartup(self):
+ """Start the server up and then shut it down immediately."""
+ pass
+
+ def testRequestNonexistentPage(self):
+ """Request a page that doesn't exist; it should 404."""
+ response = self.server.get("/asdf")
+ self.assertEqual(404, response.status_code)
+
+ def testPluginsListing(self):
+ """Test the format of the data/plugins_listing endpoint."""
+ parsed_object = self._get_json("/data/plugins_listing")
+ self.assertEqual(
+ parsed_object,
+ {
+ "foo": {
+ "enabled": True,
+ "loading_mechanism": {"type": "NONE"},
+ "remove_dom": False,
+ "tab_name": "foo",
+ "disable_reload": False,
},
- 'tab_name': 'bar',
- 'remove_dom': False,
- 'disable_reload': False,
- },
- 'baz': {
- 'enabled': True,
- 'loading_mechanism': {
- 'type': 'IFRAME',
- 'module_path': '/data/plugin/baz/esmodule',
+ "bar": {
+ "enabled": False,
+ "loading_mechanism": {
+ "type": "CUSTOM_ELEMENT",
+ "element_name": "tf-bar-dashboard",
+ },
+ "tab_name": "bar",
+ "remove_dom": False,
+ "disable_reload": False,
},
- 'tab_name': 'baz',
- 'remove_dom': False,
- 'disable_reload': False,
- },
- 'qux': {
- 'enabled': False,
- 'loading_mechanism': {
- 'type': 'NG_COMPONENT',
+ "baz": {
+ "enabled": True,
+ "loading_mechanism": {
+ "type": "IFRAME",
+ "module_path": "/data/plugin/baz/esmodule",
+ },
+ "tab_name": "baz",
+ "remove_dom": False,
+ "disable_reload": False,
+ },
+ "qux": {
+ "enabled": False,
+ "loading_mechanism": {"type": "NG_COMPONENT",},
+ "tab_name": "qux",
+ "remove_dom": False,
+ "disable_reload": False,
},
- 'tab_name': 'qux',
- 'remove_dom': False,
- 'disable_reload': False,
},
-
- }
- )
-
- def testPluginEntry(self):
- """Test the data/plugin_entry.html endpoint."""
- response = self.server.get('/data/plugin_entry.html?name=baz')
- self.assertEqual(200, response.status_code)
- self.assertEqual(
- 'text/html; charset=utf-8', response.headers.get('Content-Type'))
-
- document = response.get_data().decode('utf-8')
- self.assertIn('', document)
- self.assertIn(
- 'import("./esmodule").then((m) => void m.render());', document)
- # base64 sha256 of above script
- self.assertIn(
- "'sha256-3KGOnqHhLsX2RmjH/K2DurN9N2qtApZk5zHdSPg4LcA='",
- response.headers.get('Content-Security-Policy'),
- )
-
- for name in ['bazz', 'baz ']:
- response = self.server.get('/data/plugin_entry.html?name=%s' % name)
- self.assertEqual(404, response.status_code)
-
- for name in ['foo', 'bar']:
- response = self.server.get('/data/plugin_entry.html?name=%s' % name)
- self.assertEqual(400, response.status_code)
- self.assertEqual(
- response.get_data().decode('utf-8'),
- 'Plugin is not module loadable',
- )
-
- def testPluginEntryBadModulePath(self):
- plugins = [
- FakePlugin(
- plugin_name='mallory',
- es_module_path_value='//pwn.tb/somepath'
- ),
- ]
- app = application.TensorBoardWSGI(plugins)
- server = werkzeug_test.Client(app, wrappers.BaseResponse)
- with six.assertRaisesRegex(
- self, ValueError, 'Expected es_module_path to be non-absolute path'):
- server.get('/data/plugin_entry.html?name=mallory')
-
- def testNgComponentPluginWithIncompatibleSetElementName(self):
- plugins = [
- FakePlugin(
- plugin_name='quux',
- is_ng_component=True,
- element_name_value='incompatible',
- ),
- ]
- app = application.TensorBoardWSGI(plugins)
- server = werkzeug_test.Client(app, wrappers.BaseResponse)
- with six.assertRaisesRegex(
- self, ValueError, 'quux.*declared.*both Angular.*legacy'):
- server.get('/data/plugins_listing')
-
- def testNgComponentPluginWithIncompatiblEsModulePath(self):
- plugins = [
- FakePlugin(
- plugin_name='quux',
- is_ng_component=True,
- es_module_path_value='//incompatible',
- ),
- ]
- app = application.TensorBoardWSGI(plugins)
- server = werkzeug_test.Client(app, wrappers.BaseResponse)
- with six.assertRaisesRegex(
- self, ValueError, 'quux.*declared.*both Angular.*iframed'):
- server.get('/data/plugins_listing')
+ )
+
+ def testPluginEntry(self):
+ """Test the data/plugin_entry.html endpoint."""
+ response = self.server.get("/data/plugin_entry.html?name=baz")
+ self.assertEqual(200, response.status_code)
+ self.assertEqual(
+ "text/html; charset=utf-8", response.headers.get("Content-Type")
+ )
+
+ document = response.get_data().decode("utf-8")
+ self.assertIn('', document)
+ self.assertIn(
+ 'import("./esmodule").then((m) => void m.render());', document
+ )
+ # base64 sha256 of above script
+ self.assertIn(
+ "'sha256-3KGOnqHhLsX2RmjH/K2DurN9N2qtApZk5zHdSPg4LcA='",
+ response.headers.get("Content-Security-Policy"),
+ )
+
+ for name in ["bazz", "baz "]:
+ response = self.server.get("/data/plugin_entry.html?name=%s" % name)
+ self.assertEqual(404, response.status_code)
+
+ for name in ["foo", "bar"]:
+ response = self.server.get("/data/plugin_entry.html?name=%s" % name)
+ self.assertEqual(400, response.status_code)
+ self.assertEqual(
+ response.get_data().decode("utf-8"),
+ "Plugin is not module loadable",
+ )
+
+ def testPluginEntryBadModulePath(self):
+ plugins = [
+ FakePlugin(
+ plugin_name="mallory", es_module_path_value="//pwn.tb/somepath"
+ ),
+ ]
+ app = application.TensorBoardWSGI(plugins)
+ server = werkzeug_test.Client(app, wrappers.BaseResponse)
+ with six.assertRaisesRegex(
+ self, ValueError, "Expected es_module_path to be non-absolute path"
+ ):
+ server.get("/data/plugin_entry.html?name=mallory")
+
+ def testNgComponentPluginWithIncompatibleSetElementName(self):
+ plugins = [
+ FakePlugin(
+ plugin_name="quux",
+ is_ng_component=True,
+ element_name_value="incompatible",
+ ),
+ ]
+ app = application.TensorBoardWSGI(plugins)
+ server = werkzeug_test.Client(app, wrappers.BaseResponse)
+ with six.assertRaisesRegex(
+ self, ValueError, "quux.*declared.*both Angular.*legacy"
+ ):
+ server.get("/data/plugins_listing")
+
+ def testNgComponentPluginWithIncompatibleEsModulePath(self):
+ plugins = [
+ FakePlugin(
+ plugin_name="quux",
+ is_ng_component=True,
+ es_module_path_value="//incompatible",
+ ),
+ ]
+ app = application.TensorBoardWSGI(plugins)
+ server = werkzeug_test.Client(app, wrappers.BaseResponse)
+ with six.assertRaisesRegex(
+ self, ValueError, "quux.*declared.*both Angular.*iframed"
+ ):
+ server.get("/data/plugins_listing")
class ApplicationBaseUrlTest(tb_test.TestCase):
- path_prefix = '/test'
- def setUp(self):
- plugins = [
- FakePlugin(plugin_name='foo'),
- FakePlugin(
- plugin_name='bar',
- is_active_value=False,
- element_name_value='tf-bar-dashboard',
- ),
- FakePlugin(
- plugin_name='baz',
- routes_mapping={
- '/esmodule': lambda req: None,
- },
- es_module_path_value='/esmodule'
- ),
- ]
- app = application.TensorBoardWSGI(plugins, path_prefix=self.path_prefix)
- self.server = werkzeug_test.Client(app, wrappers.BaseResponse)
-
- def _get_json(self, path):
- response = self.server.get(path)
- self.assertEqual(200, response.status_code)
- self.assertEqual('application/json', response.headers.get('Content-Type'))
- return json.loads(response.get_data().decode('utf-8'))
-
- def testBaseUrlRequest(self):
- """Base URL should redirect to "/" for proper relative URLs."""
- response = self.server.get(self.path_prefix)
- self.assertEqual(301, response.status_code)
-
- def testBaseUrlRequestNonexistentPage(self):
- """Request a page that doesn't exist; it should 404."""
- response = self.server.get(self.path_prefix + '/asdf')
- self.assertEqual(404, response.status_code)
-
- def testBaseUrlNonexistentPluginsListing(self):
- """Test the format of the data/plugins_listing endpoint."""
- response = self.server.get('/non_existent_prefix/data/plugins_listing')
- self.assertEqual(404, response.status_code)
-
- def testPluginsListing(self):
- """Test the format of the data/plugins_listing endpoint."""
- parsed_object = self._get_json(self.path_prefix + '/data/plugins_listing')
- self.assertEqual(
- parsed_object,
- {
- 'foo': {
- 'enabled': True,
- 'loading_mechanism': {'type': 'NONE'},
- 'remove_dom': False,
- 'tab_name': 'foo',
- 'disable_reload': False,
- },
- 'bar': {
- 'enabled': False,
- 'loading_mechanism': {
- 'type': 'CUSTOM_ELEMENT',
- 'element_name': 'tf-bar-dashboard',
+ path_prefix = "/test"
+
+ def setUp(self):
+ plugins = [
+ FakePlugin(plugin_name="foo"),
+ FakePlugin(
+ plugin_name="bar",
+ is_active_value=False,
+ element_name_value="tf-bar-dashboard",
+ ),
+ FakePlugin(
+ plugin_name="baz",
+ routes_mapping={"/esmodule": lambda req: None,},
+ es_module_path_value="/esmodule",
+ ),
+ ]
+ app = application.TensorBoardWSGI(plugins, path_prefix=self.path_prefix)
+ self.server = werkzeug_test.Client(app, wrappers.BaseResponse)
+
+ def _get_json(self, path):
+ response = self.server.get(path)
+ self.assertEqual(200, response.status_code)
+ self.assertEqual(
+ "application/json", response.headers.get("Content-Type")
+ )
+ return json.loads(response.get_data().decode("utf-8"))
+
+ def testBaseUrlRequest(self):
+ """Base URL should redirect to "/" for proper relative URLs."""
+ response = self.server.get(self.path_prefix)
+ self.assertEqual(301, response.status_code)
+
+ def testBaseUrlRequestNonexistentPage(self):
+ """Request a page that doesn't exist; it should 404."""
+ response = self.server.get(self.path_prefix + "/asdf")
+ self.assertEqual(404, response.status_code)
+
+ def testBaseUrlNonexistentPluginsListing(self):
+ """Test the format of the data/plugins_listing endpoint."""
+ response = self.server.get("/non_existent_prefix/data/plugins_listing")
+ self.assertEqual(404, response.status_code)
+
+ def testPluginsListing(self):
+ """Test the format of the data/plugins_listing endpoint."""
+ parsed_object = self._get_json(
+ self.path_prefix + "/data/plugins_listing"
+ )
+ self.assertEqual(
+ parsed_object,
+ {
+ "foo": {
+ "enabled": True,
+ "loading_mechanism": {"type": "NONE"},
+ "remove_dom": False,
+ "tab_name": "foo",
+ "disable_reload": False,
},
- 'tab_name': 'bar',
- 'remove_dom': False,
- 'disable_reload': False,
- },
- 'baz': {
- 'enabled': True,
- 'loading_mechanism': {
- 'type': 'IFRAME',
- 'module_path': '/test/data/plugin/baz/esmodule',
+ "bar": {
+ "enabled": False,
+ "loading_mechanism": {
+ "type": "CUSTOM_ELEMENT",
+ "element_name": "tf-bar-dashboard",
+ },
+ "tab_name": "bar",
+ "remove_dom": False,
+ "disable_reload": False,
+ },
+ "baz": {
+ "enabled": True,
+ "loading_mechanism": {
+ "type": "IFRAME",
+ "module_path": "/test/data/plugin/baz/esmodule",
+ },
+ "tab_name": "baz",
+ "remove_dom": False,
+ "disable_reload": False,
},
- 'tab_name': 'baz',
- 'remove_dom': False,
- 'disable_reload': False,
},
- }
- )
+ )
class ApplicationPluginNameTest(tb_test.TestCase):
-
- def testSimpleName(self):
- application.TensorBoardWSGI(
- plugins=[FakePlugin(plugin_name='scalars')])
-
- def testComprehensiveName(self):
- application.TensorBoardWSGI(
- plugins=[FakePlugin(plugin_name='Scalar-Dashboard_3000.1')])
-
- def testNameIsNone(self):
- with six.assertRaisesRegex(self, ValueError, r'no plugin_name'):
- application.TensorBoardWSGI(
- plugins=[FakePlugin(plugin_name=None)])
-
- def testEmptyName(self):
- with six.assertRaisesRegex(self, ValueError, r'invalid name'):
- application.TensorBoardWSGI(
- plugins=[FakePlugin(plugin_name='')])
-
- def testNameWithSlashes(self):
- with six.assertRaisesRegex(self, ValueError, r'invalid name'):
- application.TensorBoardWSGI(
- plugins=[FakePlugin(plugin_name='scalars/data')])
-
- def testNameWithSpaces(self):
- with six.assertRaisesRegex(self, ValueError, r'invalid name'):
- application.TensorBoardWSGI(
- plugins=[FakePlugin(plugin_name='my favorite plugin')])
-
- def testDuplicateName(self):
- with six.assertRaisesRegex(self, ValueError, r'Duplicate'):
- application.TensorBoardWSGI(
- plugins=[FakePlugin(plugin_name='scalars'),
- FakePlugin(plugin_name='scalars')])
+ def testSimpleName(self):
+ application.TensorBoardWSGI(plugins=[FakePlugin(plugin_name="scalars")])
+
+ def testComprehensiveName(self):
+ application.TensorBoardWSGI(
+ plugins=[FakePlugin(plugin_name="Scalar-Dashboard_3000.1")]
+ )
+
+ def testNameIsNone(self):
+ with six.assertRaisesRegex(self, ValueError, r"no plugin_name"):
+ application.TensorBoardWSGI(plugins=[FakePlugin(plugin_name=None)])
+
+ def testEmptyName(self):
+ with six.assertRaisesRegex(self, ValueError, r"invalid name"):
+ application.TensorBoardWSGI(plugins=[FakePlugin(plugin_name="")])
+
+ def testNameWithSlashes(self):
+ with six.assertRaisesRegex(self, ValueError, r"invalid name"):
+ application.TensorBoardWSGI(
+ plugins=[FakePlugin(plugin_name="scalars/data")]
+ )
+
+ def testNameWithSpaces(self):
+ with six.assertRaisesRegex(self, ValueError, r"invalid name"):
+ application.TensorBoardWSGI(
+ plugins=[FakePlugin(plugin_name="my favorite plugin")]
+ )
+
+ def testDuplicateName(self):
+ with six.assertRaisesRegex(self, ValueError, r"Duplicate"):
+ application.TensorBoardWSGI(
+ plugins=[
+ FakePlugin(plugin_name="scalars"),
+ FakePlugin(plugin_name="scalars"),
+ ]
+ )
class ApplicationPluginRouteTest(tb_test.TestCase):
+ def _make_plugin(self, route):
+ return FakePlugin(
+ plugin_name="foo",
+ routes_mapping={route: lambda environ, start_response: None},
+ )
- def _make_plugin(self, route):
- return FakePlugin(
- plugin_name='foo',
- routes_mapping={route: lambda environ, start_response: None})
-
- def testNormalRoute(self):
- application.TensorBoardWSGI([self._make_plugin('/runs')])
+ def testNormalRoute(self):
+ application.TensorBoardWSGI([self._make_plugin("/runs")])
- def testWildcardRoute(self):
- application.TensorBoardWSGI([self._make_plugin('/foo/*')])
+ def testWildcardRoute(self):
+ application.TensorBoardWSGI([self._make_plugin("/foo/*")])
- def testNonPathComponentWildcardRoute(self):
- with six.assertRaisesRegex(self, ValueError, r'invalid route'):
- application.TensorBoardWSGI([self._make_plugin('/foo*')])
+ def testNonPathComponentWildcardRoute(self):
+ with six.assertRaisesRegex(self, ValueError, r"invalid route"):
+ application.TensorBoardWSGI([self._make_plugin("/foo*")])
- def testMultiWildcardRoute(self):
- with six.assertRaisesRegex(self, ValueError, r'invalid route'):
- application.TensorBoardWSGI([self._make_plugin('/foo/*/bar/*')])
+ def testMultiWildcardRoute(self):
+ with six.assertRaisesRegex(self, ValueError, r"invalid route"):
+ application.TensorBoardWSGI([self._make_plugin("/foo/*/bar/*")])
- def testInternalWildcardRoute(self):
- with six.assertRaisesRegex(self, ValueError, r'invalid route'):
- application.TensorBoardWSGI([self._make_plugin('/foo/*/bar')])
+ def testInternalWildcardRoute(self):
+ with six.assertRaisesRegex(self, ValueError, r"invalid route"):
+ application.TensorBoardWSGI([self._make_plugin("/foo/*/bar")])
- def testEmptyRoute(self):
- with six.assertRaisesRegex(self, ValueError, r'invalid route'):
- application.TensorBoardWSGI([self._make_plugin('')])
+ def testEmptyRoute(self):
+ with six.assertRaisesRegex(self, ValueError, r"invalid route"):
+ application.TensorBoardWSGI([self._make_plugin("")])
- def testSlashlessRoute(self):
- with six.assertRaisesRegex(self, ValueError, r'invalid route'):
- application.TensorBoardWSGI([self._make_plugin('runaway')])
+ def testSlashlessRoute(self):
+ with six.assertRaisesRegex(self, ValueError, r"invalid route"):
+ application.TensorBoardWSGI([self._make_plugin("runaway")])
class MakePluginLoaderTest(tb_test.TestCase):
+ def testMakePluginLoader_pluginClass(self):
+ loader = application.make_plugin_loader(FakePlugin)
+ self.assertIsInstance(loader, base_plugin.BasicLoader)
+ self.assertIs(loader.plugin_class, FakePlugin)
- def testMakePluginLoader_pluginClass(self):
- loader = application.make_plugin_loader(FakePlugin)
- self.assertIsInstance(loader, base_plugin.BasicLoader)
- self.assertIs(loader.plugin_class, FakePlugin)
-
- def testMakePluginLoader_pluginLoaderClass(self):
- loader = application.make_plugin_loader(FakePluginLoader)
- self.assertIsInstance(loader, FakePluginLoader)
+ def testMakePluginLoader_pluginLoaderClass(self):
+ loader = application.make_plugin_loader(FakePluginLoader)
+ self.assertIsInstance(loader, FakePluginLoader)
- def testMakePluginLoader_pluginLoader(self):
- loader = FakePluginLoader()
- self.assertIs(loader, application.make_plugin_loader(loader))
+ def testMakePluginLoader_pluginLoader(self):
+ loader = FakePluginLoader()
+ self.assertIs(loader, application.make_plugin_loader(loader))
- def testMakePluginLoader_invalidType(self):
- with six.assertRaisesRegex(self, TypeError, 'FakePlugin'):
- application.make_plugin_loader(FakePlugin())
+ def testMakePluginLoader_invalidType(self):
+ with six.assertRaisesRegex(self, TypeError, "FakePlugin"):
+ application.make_plugin_loader(FakePlugin())
class GetEventFileActiveFilterTest(tb_test.TestCase):
-
- def testDisabled(self):
- flags = FakeFlags('logdir', reload_multifile=False)
- self.assertIsNone(application._get_event_file_active_filter(flags))
-
- def testInactiveSecsZero(self):
- flags = FakeFlags('logdir', reload_multifile=True,
- reload_multifile_inactive_secs=0)
- self.assertIsNone(application._get_event_file_active_filter(flags))
-
- def testInactiveSecsNegative(self):
- flags = FakeFlags('logdir', reload_multifile=True,
- reload_multifile_inactive_secs=-1)
- filter_fn = application._get_event_file_active_filter(flags)
- self.assertTrue(filter_fn(0))
- self.assertTrue(filter_fn(time.time()))
- self.assertTrue(filter_fn(float("inf")))
-
- def testInactiveSecs(self):
- flags = FakeFlags('logdir', reload_multifile=True,
- reload_multifile_inactive_secs=10)
- filter_fn = application._get_event_file_active_filter(flags)
- with mock.patch.object(time, 'time') as mock_time:
- mock_time.return_value = 100
- self.assertFalse(filter_fn(0))
- self.assertFalse(filter_fn(time.time() - 11))
- self.assertTrue(filter_fn(time.time() - 10))
- self.assertTrue(filter_fn(time.time()))
- self.assertTrue(filter_fn(float("inf")))
+ def testDisabled(self):
+ flags = FakeFlags("logdir", reload_multifile=False)
+ self.assertIsNone(application._get_event_file_active_filter(flags))
+
+ def testInactiveSecsZero(self):
+ flags = FakeFlags(
+ "logdir", reload_multifile=True, reload_multifile_inactive_secs=0
+ )
+ self.assertIsNone(application._get_event_file_active_filter(flags))
+
+ def testInactiveSecsNegative(self):
+ flags = FakeFlags(
+ "logdir", reload_multifile=True, reload_multifile_inactive_secs=-1
+ )
+ filter_fn = application._get_event_file_active_filter(flags)
+ self.assertTrue(filter_fn(0))
+ self.assertTrue(filter_fn(time.time()))
+ self.assertTrue(filter_fn(float("inf")))
+
+ def testInactiveSecs(self):
+ flags = FakeFlags(
+ "logdir", reload_multifile=True, reload_multifile_inactive_secs=10
+ )
+ filter_fn = application._get_event_file_active_filter(flags)
+ with mock.patch.object(time, "time") as mock_time:
+ mock_time.return_value = 100
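+ # With time frozen at 100 and a 10-second window, only timestamps
+ # of 90 or later count as active.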
+ self.assertFalse(filter_fn(0))
+ self.assertFalse(filter_fn(time.time() - 11))
+ self.assertTrue(filter_fn(time.time() - 10))
+ self.assertTrue(filter_fn(time.time()))
+ self.assertTrue(filter_fn(float("inf")))
class ParseEventFilesSpecTest(tb_test.TestCase):
-
- def assertPlatformSpecificLogdirParsing(self, pathObj, logdir, expected):
- """
- A custom assertion to test :func:`parse_event_files_spec` under various
- systems.
-
- Args:
- pathObj: a custom replacement object for `os.path`, typically
- `posixpath` or `ntpath`
- logdir: the string to be parsed by
- :func:`~application.parse_event_files_spec`
- expected: the expected dictionary as returned by
- :func:`~application.parse_event_files_spec`
-
- """
-
- with mock.patch('os.path', pathObj):
- self.assertEqual(application.parse_event_files_spec(logdir), expected)
-
- def testBasic(self):
- self.assertPlatformSpecificLogdirParsing(
- posixpath, '/lol/cat', {'/lol/cat': None})
- self.assertPlatformSpecificLogdirParsing(
- ntpath, r'C:\lol\cat', {r'C:\lol\cat': None})
-
- def testRunName(self):
- self.assertPlatformSpecificLogdirParsing(
- posixpath, 'lol:/cat', {'/cat': 'lol'})
- self.assertPlatformSpecificLogdirParsing(
- ntpath, 'lol:C:\\cat', {'C:\\cat': 'lol'})
-
- def testPathWithColonThatComesAfterASlash_isNotConsideredARunName(self):
- self.assertPlatformSpecificLogdirParsing(
- posixpath, '/lol:/cat', {'/lol:/cat': None})
-
- def testExpandsUser(self):
- oldhome = os.environ.get('HOME', None)
- try:
- os.environ['HOME'] = '/usr/eliza'
- self.assertPlatformSpecificLogdirParsing(
- posixpath, '~/lol/cat~dog', {'/usr/eliza/lol/cat~dog': None})
- os.environ['HOME'] = r'C:\Users\eliza'
- self.assertPlatformSpecificLogdirParsing(
- ntpath, r'~\lol\cat~dog', {r'C:\Users\eliza\lol\cat~dog': None})
- finally:
- if oldhome is not None:
- os.environ['HOME'] = oldhome
-
- def testExpandsUserForMultipleDirectories(self):
- oldhome = os.environ.get('HOME', None)
- try:
- os.environ['HOME'] = '/usr/eliza'
- self.assertPlatformSpecificLogdirParsing(
- posixpath, 'a:~/lol,b:~/cat',
- {'/usr/eliza/lol': 'a', '/usr/eliza/cat': 'b'})
- os.environ['HOME'] = r'C:\Users\eliza'
- self.assertPlatformSpecificLogdirParsing(
- ntpath, r'aa:~\lol,bb:~\cat',
- {r'C:\Users\eliza\lol': 'aa', r'C:\Users\eliza\cat': 'bb'})
- finally:
- if oldhome is not None:
- os.environ['HOME'] = oldhome
-
- def testMultipleDirectories(self):
- self.assertPlatformSpecificLogdirParsing(
- posixpath, '/a,/b', {'/a': None, '/b': None})
- self.assertPlatformSpecificLogdirParsing(
- ntpath, 'C:\\a,C:\\b', {'C:\\a': None, 'C:\\b': None})
-
- def testNormalizesPaths(self):
- self.assertPlatformSpecificLogdirParsing(
- posixpath, '/lol/.//cat/../cat', {'/lol/cat': None})
- self.assertPlatformSpecificLogdirParsing(
- ntpath, 'C:\\lol\\.\\\\cat\\..\\cat', {'C:\\lol\\cat': None})
-
- def testAbsolutifies(self):
- self.assertPlatformSpecificLogdirParsing(
- posixpath, 'lol/cat', {posixpath.realpath('lol/cat'): None})
- self.assertPlatformSpecificLogdirParsing(
- ntpath, 'lol\\cat', {ntpath.realpath('lol\\cat'): None})
-
- def testRespectsGCSPath(self):
- self.assertPlatformSpecificLogdirParsing(
- posixpath, 'gs://foo/path', {'gs://foo/path': None})
- self.assertPlatformSpecificLogdirParsing(
- ntpath, 'gs://foo/path', {'gs://foo/path': None})
-
- def testRespectsHDFSPath(self):
- self.assertPlatformSpecificLogdirParsing(
- posixpath, 'hdfs://foo/path', {'hdfs://foo/path': None})
- self.assertPlatformSpecificLogdirParsing(
- ntpath, 'hdfs://foo/path', {'hdfs://foo/path': None})
-
- def testDoesNotExpandUserInGCSPath(self):
- self.assertPlatformSpecificLogdirParsing(
- posixpath, 'gs://~/foo/path', {'gs://~/foo/path': None})
- self.assertPlatformSpecificLogdirParsing(
- ntpath, 'gs://~/foo/path', {'gs://~/foo/path': None})
-
- def testDoesNotNormalizeGCSPath(self):
- self.assertPlatformSpecificLogdirParsing(
- posixpath, 'gs://foo/./path//..', {'gs://foo/./path//..': None})
- self.assertPlatformSpecificLogdirParsing(
- ntpath, 'gs://foo/./path//..', {'gs://foo/./path//..': None})
-
- def testRunNameWithGCSPath(self):
- self.assertPlatformSpecificLogdirParsing(
- posixpath, 'lol:gs://foo/path', {'gs://foo/path': 'lol'})
- self.assertPlatformSpecificLogdirParsing(
- ntpath, 'lol:gs://foo/path', {'gs://foo/path': 'lol'})
-
- def testSingleLetterGroup(self):
- self.assertPlatformSpecificLogdirParsing(
- posixpath, 'A:/foo/path', {'/foo/path': 'A'})
- # single letter groups are not supported on Windows
- with self.assertRaises(AssertionError):
- self.assertPlatformSpecificLogdirParsing(
- ntpath, 'A:C:\\foo\\path', {'C:\\foo\\path': 'A'})
+ def assertPlatformSpecificLogdirParsing(self, pathObj, logdir, expected):
+ """A custom assertion to test :func:`parse_event_files_spec` under
+ various systems.
+
+ Args:
+ pathObj: a custom replacement object for `os.path`, typically
+ `posixpath` or `ntpath`
+ logdir: the string to be parsed by
+ :func:`~application.parse_event_files_spec`
+ expected: the expected dictionary as returned by
+ :func:`~application.parse_event_files_spec`
+ """
+
+ with mock.patch("os.path", pathObj):
+ self.assertEqual(
+ application.parse_event_files_spec(logdir), expected
+ )
+
+ def testBasic(self):
+ self.assertPlatformSpecificLogdirParsing(
+ posixpath, "/lol/cat", {"/lol/cat": None}
+ )
+ self.assertPlatformSpecificLogdirParsing(
+ ntpath, r"C:\lol\cat", {r"C:\lol\cat": None}
+ )
+
+ def testRunName(self):
+ self.assertPlatformSpecificLogdirParsing(
+ posixpath, "lol:/cat", {"/cat": "lol"}
+ )
+ self.assertPlatformSpecificLogdirParsing(
+ ntpath, "lol:C:\\cat", {"C:\\cat": "lol"}
+ )
+
+ def testPathWithColonThatComesAfterASlash_isNotConsideredARunName(self):
+ self.assertPlatformSpecificLogdirParsing(
+ posixpath, "/lol:/cat", {"/lol:/cat": None}
+ )
+
+ def testExpandsUser(self):
+ oldhome = os.environ.get("HOME", None)
+ try:
+ os.environ["HOME"] = "/usr/eliza"
+ self.assertPlatformSpecificLogdirParsing(
+ posixpath, "~/lol/cat~dog", {"/usr/eliza/lol/cat~dog": None}
+ )
+ os.environ["HOME"] = r"C:\Users\eliza"
+ self.assertPlatformSpecificLogdirParsing(
+ ntpath, r"~\lol\cat~dog", {r"C:\Users\eliza\lol\cat~dog": None}
+ )
+ finally:
+ if oldhome is not None:
+ os.environ["HOME"] = oldhome
+
+ def testExpandsUserForMultipleDirectories(self):
+ oldhome = os.environ.get("HOME", None)
+ try:
+ os.environ["HOME"] = "/usr/eliza"
+ self.assertPlatformSpecificLogdirParsing(
+ posixpath,
+ "a:~/lol,b:~/cat",
+ {"/usr/eliza/lol": "a", "/usr/eliza/cat": "b"},
+ )
+ os.environ["HOME"] = r"C:\Users\eliza"
+ self.assertPlatformSpecificLogdirParsing(
+ ntpath,
+ r"aa:~\lol,bb:~\cat",
+ {r"C:\Users\eliza\lol": "aa", r"C:\Users\eliza\cat": "bb"},
+ )
+ finally:
+ if oldhome is not None:
+ os.environ["HOME"] = oldhome
+
+ def testMultipleDirectories(self):
+ self.assertPlatformSpecificLogdirParsing(
+ posixpath, "/a,/b", {"/a": None, "/b": None}
+ )
+ self.assertPlatformSpecificLogdirParsing(
+ ntpath, "C:\\a,C:\\b", {"C:\\a": None, "C:\\b": None}
+ )
+
+ def testNormalizesPaths(self):
+ self.assertPlatformSpecificLogdirParsing(
+ posixpath, "/lol/.//cat/../cat", {"/lol/cat": None}
+ )
+ self.assertPlatformSpecificLogdirParsing(
+ ntpath, "C:\\lol\\.\\\\cat\\..\\cat", {"C:\\lol\\cat": None}
+ )
+
+ def testAbsolutifies(self):
+ self.assertPlatformSpecificLogdirParsing(
+ posixpath, "lol/cat", {posixpath.realpath("lol/cat"): None}
+ )
+ self.assertPlatformSpecificLogdirParsing(
+ ntpath, "lol\\cat", {ntpath.realpath("lol\\cat"): None}
+ )
+
+ def testRespectsGCSPath(self):
+ self.assertPlatformSpecificLogdirParsing(
+ posixpath, "gs://foo/path", {"gs://foo/path": None}
+ )
+ self.assertPlatformSpecificLogdirParsing(
+ ntpath, "gs://foo/path", {"gs://foo/path": None}
+ )
+
+ def testRespectsHDFSPath(self):
+ self.assertPlatformSpecificLogdirParsing(
+ posixpath, "hdfs://foo/path", {"hdfs://foo/path": None}
+ )
+ self.assertPlatformSpecificLogdirParsing(
+ ntpath, "hdfs://foo/path", {"hdfs://foo/path": None}
+ )
+
+ def testDoesNotExpandUserInGCSPath(self):
+ self.assertPlatformSpecificLogdirParsing(
+ posixpath, "gs://~/foo/path", {"gs://~/foo/path": None}
+ )
+ self.assertPlatformSpecificLogdirParsing(
+ ntpath, "gs://~/foo/path", {"gs://~/foo/path": None}
+ )
+
+ def testDoesNotNormalizeGCSPath(self):
+ self.assertPlatformSpecificLogdirParsing(
+ posixpath, "gs://foo/./path//..", {"gs://foo/./path//..": None}
+ )
+ self.assertPlatformSpecificLogdirParsing(
+ ntpath, "gs://foo/./path//..", {"gs://foo/./path//..": None}
+ )
+
+ def testRunNameWithGCSPath(self):
+ self.assertPlatformSpecificLogdirParsing(
+ posixpath, "lol:gs://foo/path", {"gs://foo/path": "lol"}
+ )
+ self.assertPlatformSpecificLogdirParsing(
+ ntpath, "lol:gs://foo/path", {"gs://foo/path": "lol"}
+ )
+
+ def testSingleLetterGroup(self):
+ self.assertPlatformSpecificLogdirParsing(
+ posixpath, "A:/foo/path", {"/foo/path": "A"}
+ )
+ # Single-letter groups are not supported on Windows, where "A:" is
+ # parsed as a drive specification rather than a run name.
+ with self.assertRaises(AssertionError):
+ self.assertPlatformSpecificLogdirParsing(
+ ntpath, "A:C:\\foo\\path", {"C:\\foo\\path": "A"}
+ )
class TensorBoardPluginsTest(tb_test.TestCase):
+ def setUp(self):
+ self.context = None
+ dummy_assets_zip_provider = lambda: None
+ # The application should have added routes for all of these plugins.
+ self.app = application.standard_tensorboard_wsgi(
+ FakeFlags(logdir=self.get_temp_dir()),
+ [
+ FakePluginLoader(
+ plugin_name="foo",
+ is_active_value=True,
+ routes_mapping={"/foo_route": self._foo_handler},
+ construction_callback=self._construction_callback,
+ ),
+ FakePluginLoader(
+ plugin_name="bar",
+ is_active_value=True,
+ routes_mapping={
+ "/bar_route": self._bar_handler,
+ "/wildcard/*": self._wildcard_handler,
+ "/wildcard/special/*": self._wildcard_special_handler,
+ "/wildcard/special/exact": self._foo_handler,
+ },
+ construction_callback=self._construction_callback,
+ ),
+ FakePluginLoader(
+ plugin_name="whoami",
+ routes_mapping={"/eid": self._eid_handler,},
+ ),
+ ],
+ dummy_assets_zip_provider,
+ )
+
+ self.server = werkzeug_test.Client(self.app, wrappers.BaseResponse)
+
+ def _construction_callback(self, context):
+ """Called when a plugin is constructed."""
+ self.context = context
+
+ def _test_route(self, route, expected_status_code):
+ response = self.server.get(route)
+ self.assertEqual(response.status_code, expected_status_code)
- def setUp(self):
- self.context = None
- dummy_assets_zip_provider = lambda: None
- # The application should have added routes for both plugins.
- self.app = application.standard_tensorboard_wsgi(
- FakeFlags(logdir=self.get_temp_dir()),
- [
- FakePluginLoader(
- plugin_name='foo',
- is_active_value=True,
- routes_mapping={'/foo_route': self._foo_handler},
- construction_callback=self._construction_callback),
- FakePluginLoader(
- plugin_name='bar',
- is_active_value=True,
- routes_mapping={
- '/bar_route': self._bar_handler,
- '/wildcard/*': self._wildcard_handler,
- '/wildcard/special/*': self._wildcard_special_handler,
- '/wildcard/special/exact': self._foo_handler,
- },
- construction_callback=self._construction_callback),
- FakePluginLoader(
- plugin_name='whoami',
- routes_mapping={
- '/eid': self._eid_handler,
- }),
- ],
- dummy_assets_zip_provider)
-
- self.server = werkzeug_test.Client(self.app, wrappers.BaseResponse)
-
- def _construction_callback(self, context):
- """Called when a plugin is constructed."""
- self.context = context
-
- def _test_route(self, route, expected_status_code):
- response = self.server.get(route)
- self.assertEqual(response.status_code, expected_status_code)
-
- @wrappers.Request.application
- def _foo_handler(self, request):
- return wrappers.Response(response='hello world', status=200)
-
- def _bar_handler(self):
- pass
-
- @wrappers.Request.application
- def _eid_handler(self, request):
- eid = plugin_util.experiment_id(request.environ)
- body = json.dumps({'experiment_id': eid})
- return wrappers.Response(body, 200, content_type='application/json')
-
- @wrappers.Request.application
- def _wildcard_handler(self, request):
- if request.path == '/data/plugin/bar/wildcard/ok':
- return wrappers.Response(response='hello world', status=200)
- elif request.path == '/data/plugin/bar/wildcard/':
- # this route cannot actually be hit; see testEmptyWildcardRouteWithSlash.
- return wrappers.Response(response='hello world', status=200)
- else:
- return wrappers.Response(status=401)
-
- @wrappers.Request.application
- def _wildcard_special_handler(self, request):
- return wrappers.Response(status=300)
-
- def testPluginsAdded(self):
- # The routes are prefixed with /data/plugin/[plugin name].
- expected_routes = frozenset((
- '/data/plugin/foo/foo_route',
- '/data/plugin/bar/bar_route',
- ))
- self.assertLessEqual(expected_routes, frozenset(self.app.exact_routes))
-
- def testNameToPluginMapping(self):
- # The mapping from plugin name to instance should include all plugins.
- mapping = self.context.plugin_name_to_instance
- self.assertItemsEqual(['foo', 'bar', 'whoami'], list(mapping.keys()))
- self.assertEqual('foo', mapping['foo'].plugin_name)
- self.assertEqual('bar', mapping['bar'].plugin_name)
- self.assertEqual('whoami', mapping['whoami'].plugin_name)
-
- def testNormalRoute(self):
- self._test_route('/data/plugin/foo/foo_route', 200)
-
- def testNormalRouteIsNotWildcard(self):
- self._test_route('/data/plugin/foo/foo_route/bogus', 404)
-
- def testMissingRoute(self):
- self._test_route('/data/plugin/foo/bogus', 404)
-
- def testExperimentIdIntegration_withNoExperimentId(self):
- response = self.server.get('/data/plugin/whoami/eid')
- self.assertEqual(response.status_code, 200)
- data = json.loads(response.get_data().decode('utf-8'))
- self.assertEqual(data, {'experiment_id': ''})
-
- def testExperimentIdIntegration_withExperimentId(self):
- response = self.server.get('/experiment/123/data/plugin/whoami/eid')
- self.assertEqual(response.status_code, 200)
- data = json.loads(response.get_data().decode('utf-8'))
- self.assertEqual(data, {'experiment_id': '123'})
-
- def testEmptyRoute(self):
- self._test_route('', 301)
-
- def testSlashlessRoute(self):
- self._test_route('runaway', 404)
-
- def testWildcardAcceptedRoute(self):
- self._test_route('/data/plugin/bar/wildcard/ok', 200)
-
- def testLongerWildcardRouteTakesPrecedence(self):
- # This tests that the longer 'special' wildcard takes precedence over
- # the shorter one.
- self._test_route('/data/plugin/bar/wildcard/special/blah', 300)
-
- def testExactRouteTakesPrecedence(self):
- # This tests that an exact match takes precedence over a wildcard.
- self._test_route('/data/plugin/bar/wildcard/special/exact', 200)
-
- def testWildcardRejectedRoute(self):
- # A plugin may reject a request passed to it via a wildcard route.
- # Note our test plugin returns 401 in this case, to distinguish this
- # response from a 404 passed if the route were not found.
- self._test_route('/data/plugin/bar/wildcard/bogus', 401)
-
- def testWildcardRouteWithoutSlash(self):
- # A wildcard route requires a slash before the '*'.
- # Lacking one, no route is matched.
- self._test_route('/data/plugin/bar/wildcard', 404)
-
- def testEmptyWildcardRouteWithSlash(self):
- # A wildcard route requires a slash before the '*'. Here we include the
- # slash, so we might expect the route to match.
- #
- # However: Trailing slashes are automatically removed from incoming requests
- # in _clean_path(). Consequently, this request does not match the wildcard
- # route after all.
- #
- # Note the test plugin specifically accepts this route (returning 200), so
- # the fact that 404 is returned demonstrates that the plugin was not
- # consulted.
- self._test_route('/data/plugin/bar/wildcard/', 404)
+ @wrappers.Request.application
+ def _foo_handler(self, request):
+ return wrappers.Response(response="hello world", status=200)
+ def _bar_handler(self):
+ pass
-class DbTest(tb_test.TestCase):
+ @wrappers.Request.application
+ def _eid_handler(self, request):
+ eid = plugin_util.experiment_id(request.environ)
+ body = json.dumps({"experiment_id": eid})
+ return wrappers.Response(body, 200, content_type="application/json")
- def testSqliteDb(self):
- db_uri = 'sqlite:' + os.path.join(self.get_temp_dir(), 'db')
- db_connection_provider = application.create_sqlite_connection_provider(
- db_uri)
- with contextlib.closing(db_connection_provider()) as conn:
- with conn:
- with contextlib.closing(conn.cursor()) as c:
- c.execute('create table peeps (name text)')
- c.execute('insert into peeps (name) values (?)', ('justine',))
- db_connection_provider = application.create_sqlite_connection_provider(
- db_uri)
- with contextlib.closing(db_connection_provider()) as conn:
- with contextlib.closing(conn.cursor()) as c:
- c.execute('select name from peeps')
- self.assertEqual(('justine',), c.fetchone())
-
- def testTransactionRollback(self):
- db_uri = 'sqlite:' + os.path.join(self.get_temp_dir(), 'db')
- db_connection_provider = application.create_sqlite_connection_provider(
- db_uri)
- with contextlib.closing(db_connection_provider()) as conn:
- with conn:
- with contextlib.closing(conn.cursor()) as c:
- c.execute('create table peeps (name text)')
- try:
- with conn:
- with contextlib.closing(conn.cursor()) as c:
- c.execute('insert into peeps (name) values (?)', ('justine',))
- raise IOError('hi')
- except IOError:
- pass
- with contextlib.closing(conn.cursor()) as c:
- c.execute('select name from peeps')
- self.assertIsNone(c.fetchone())
-
- def testTransactionRollback_doesntDoAnythingIfIsolationLevelIsNone(self):
- # NOTE: This is a terrible idea. Don't do this.
- db_uri = ('sqlite:' + os.path.join(self.get_temp_dir(), 'db') +
- '?isolation_level=null')
- db_connection_provider = application.create_sqlite_connection_provider(
- db_uri)
- with contextlib.closing(db_connection_provider()) as conn:
- with conn:
- with contextlib.closing(conn.cursor()) as c:
- c.execute('create table peeps (name text)')
- try:
- with conn:
- with contextlib.closing(conn.cursor()) as c:
- c.execute('insert into peeps (name) values (?)', ('justine',))
- raise IOError('hi')
- except IOError:
- pass
- with contextlib.closing(conn.cursor()) as c:
- c.execute('select name from peeps')
- self.assertEqual(('justine',), c.fetchone())
+ @wrappers.Request.application
+ def _wildcard_handler(self, request):
+ if request.path == "/data/plugin/bar/wildcard/ok":
+ return wrappers.Response(response="hello world", status=200)
+ elif request.path == "/data/plugin/bar/wildcard/":
+ # this route cannot actually be hit; see testEmptyWildcardRouteWithSlash.
+ return wrappers.Response(response="hello world", status=200)
+ else:
+ return wrappers.Response(status=401)
- def testSqliteUriErrors(self):
- with self.assertRaises(ValueError):
- application.create_sqlite_connection_provider("lol:cat")
- with self.assertRaises(ValueError):
- application.create_sqlite_connection_provider("sqlite::memory:")
- with self.assertRaises(ValueError):
- application.create_sqlite_connection_provider("sqlite://foo.example/bar")
+ @wrappers.Request.application
+ def _wildcard_special_handler(self, request):
+ return wrappers.Response(status=300)
+
+ def testPluginsAdded(self):
+ # The routes are prefixed with /data/plugin/[plugin name].
+ expected_routes = frozenset(
+ ("/data/plugin/foo/foo_route", "/data/plugin/bar/bar_route",)
+ )
+ self.assertLessEqual(expected_routes, frozenset(self.app.exact_routes))
+
+ def testNameToPluginMapping(self):
+ # The mapping from plugin name to instance should include all plugins.
+ mapping = self.context.plugin_name_to_instance
+ self.assertItemsEqual(["foo", "bar", "whoami"], list(mapping.keys()))
+ self.assertEqual("foo", mapping["foo"].plugin_name)
+ self.assertEqual("bar", mapping["bar"].plugin_name)
+ self.assertEqual("whoami", mapping["whoami"].plugin_name)
+
+ def testNormalRoute(self):
+ self._test_route("/data/plugin/foo/foo_route", 200)
+
+ def testNormalRouteIsNotWildcard(self):
+ self._test_route("/data/plugin/foo/foo_route/bogus", 404)
+
+ def testMissingRoute(self):
+ self._test_route("/data/plugin/foo/bogus", 404)
+
+ def testExperimentIdIntegration_withNoExperimentId(self):
+ response = self.server.get("/data/plugin/whoami/eid")
+ self.assertEqual(response.status_code, 200)
+ data = json.loads(response.get_data().decode("utf-8"))
+ self.assertEqual(data, {"experiment_id": ""})
+
+ def testExperimentIdIntegration_withExperimentId(self):
+ response = self.server.get("/experiment/123/data/plugin/whoami/eid")
+ self.assertEqual(response.status_code, 200)
+ data = json.loads(response.get_data().decode("utf-8"))
+ self.assertEqual(data, {"experiment_id": "123"})
+
+ def testEmptyRoute(self):
+ self._test_route("", 301)
+
+ def testSlashlessRoute(self):
+ self._test_route("runaway", 404)
+
+ def testWildcardAcceptedRoute(self):
+ self._test_route("/data/plugin/bar/wildcard/ok", 200)
+
+ def testLongerWildcardRouteTakesPrecedence(self):
+ # This tests that the longer 'special' wildcard takes precedence over
+ # the shorter one.
+ self._test_route("/data/plugin/bar/wildcard/special/blah", 300)
+
+ def testExactRouteTakesPrecedence(self):
+ # This tests that an exact match takes precedence over a wildcard.
+ self._test_route("/data/plugin/bar/wildcard/special/exact", 200)
+
+ def testWildcardRejectedRoute(self):
+ # A plugin may reject a request passed to it via a wildcard route.
+ # Note our test plugin returns 401 in this case, to distinguish this
+ # response from the 404 that would be returned if the route were not found.
+ self._test_route("/data/plugin/bar/wildcard/bogus", 401)
+
+ def testWildcardRouteWithoutSlash(self):
+ # A wildcard route requires a slash before the '*'.
+ # Lacking one, no route is matched.
+ self._test_route("/data/plugin/bar/wildcard", 404)
+
+ def testEmptyWildcardRouteWithSlash(self):
+ # A wildcard route requires a slash before the '*'. Here we include the
+ # slash, so we might expect the route to match.
+ #
+ # However: Trailing slashes are automatically removed from incoming requests
+ # in _clean_path(). Consequently, this request does not match the wildcard
+ # route after all.
+ #
+ # Note the test plugin specifically accepts this route (returning 200), so
+ # the fact that 404 is returned demonstrates that the plugin was not
+ # consulted.
+ self._test_route("/data/plugin/bar/wildcard/", 404)
-if __name__ == '__main__':
- tb_test.main()
+class DbTest(tb_test.TestCase):
+ def testSqliteDb(self):
+ db_uri = "sqlite:" + os.path.join(self.get_temp_dir(), "db")
+ db_connection_provider = application.create_sqlite_connection_provider(
+ db_uri
+ )
+ with contextlib.closing(db_connection_provider()) as conn:
+ with conn:
+ with contextlib.closing(conn.cursor()) as c:
+ c.execute("create table peeps (name text)")
+ c.execute(
+ "insert into peeps (name) values (?)", ("justine",)
+ )
+ db_connection_provider = application.create_sqlite_connection_provider(
+ db_uri
+ )
+ with contextlib.closing(db_connection_provider()) as conn:
+ with contextlib.closing(conn.cursor()) as c:
+ c.execute("select name from peeps")
+ self.assertEqual(("justine",), c.fetchone())
+
+ def testTransactionRollback(self):
+ db_uri = "sqlite:" + os.path.join(self.get_temp_dir(), "db")
+ db_connection_provider = application.create_sqlite_connection_provider(
+ db_uri
+ )
+ with contextlib.closing(db_connection_provider()) as conn:
+ with conn:
+ with contextlib.closing(conn.cursor()) as c:
+ c.execute("create table peeps (name text)")
+ try:
+ with conn:
+ with contextlib.closing(conn.cursor()) as c:
+ c.execute(
+ "insert into peeps (name) values (?)", ("justine",)
+ )
+ raise IOError("hi")
+ except IOError:
+ pass
+ with contextlib.closing(conn.cursor()) as c:
+ c.execute("select name from peeps")
+ self.assertIsNone(c.fetchone())
+
+ def testTransactionRollback_doesntDoAnythingIfIsolationLevelIsNone(self):
+ # NOTE: This is a terrible idea. Don't do this.
+ db_uri = (
+ "sqlite:"
+ + os.path.join(self.get_temp_dir(), "db")
+ + "?isolation_level=null"
+ )
+ db_connection_provider = application.create_sqlite_connection_provider(
+ db_uri
+ )
+ with contextlib.closing(db_connection_provider()) as conn:
+ with conn:
+ with contextlib.closing(conn.cursor()) as c:
+ c.execute("create table peeps (name text)")
+ try:
+ with conn:
+ with contextlib.closing(conn.cursor()) as c:
+ c.execute(
+ "insert into peeps (name) values (?)", ("justine",)
+ )
+ raise IOError("hi")
+ except IOError:
+ pass
+ with contextlib.closing(conn.cursor()) as c:
+ c.execute("select name from peeps")
+ self.assertEqual(("justine",), c.fetchone())
+
+ def testSqliteUriErrors(self):
+ with self.assertRaises(ValueError):
+ application.create_sqlite_connection_provider("lol:cat")
+ with self.assertRaises(ValueError):
+ application.create_sqlite_connection_provider("sqlite::memory:")
+ with self.assertRaises(ValueError):
+ application.create_sqlite_connection_provider(
+ "sqlite://foo.example/bar"
+ )
+
+
+if __name__ == "__main__":
+ tb_test.main()
diff --git a/tensorboard/backend/empty_path_redirect.py b/tensorboard/backend/empty_path_redirect.py
index d2d5be4fe0..69b7cae5fc 100644
--- a/tensorboard/backend/empty_path_redirect.py
+++ b/tensorboard/backend/empty_path_redirect.py
@@ -31,20 +31,20 @@
class EmptyPathRedirectMiddleware(object):
- """WSGI middleware to redirect from "" to "/"."""
-
- def __init__(self, application):
- """Initializes this middleware.
-
- Args:
- application: The WSGI application to wrap (see PEP 3333).
- """
- self._application = application
-
- def __call__(self, environ, start_response):
- path = environ.get("PATH_INFO", "")
- if path:
- return self._application(environ, start_response)
- location = environ.get("SCRIPT_NAME", "") + "/"
- start_response("301 Moved Permanently", [("Location", location)])
- return []
+ """WSGI middleware to redirect from "" to "/"."""
+
+ def __init__(self, application):
+ """Initializes this middleware.
+
+ Args:
+ application: The WSGI application to wrap (see PEP 3333).
+ """
+ self._application = application
+
+ def __call__(self, environ, start_response):
+ path = environ.get("PATH_INFO", "")
+ if path:
+ return self._application(environ, start_response)
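+ # Empty path: redirect to the mount point plus "/" so relative URLs resolve.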
+ location = environ.get("SCRIPT_NAME", "") + "/"
+ start_response("301 Moved Permanently", [("Location", location)])
+ return []
diff --git a/tensorboard/backend/empty_path_redirect_test.py b/tensorboard/backend/empty_path_redirect_test.py
index 4aac642fda..7a643e14c3 100644
--- a/tensorboard/backend/empty_path_redirect_test.py
+++ b/tensorboard/backend/empty_path_redirect_test.py
@@ -28,46 +28,48 @@
class EmptyPathRedirectMiddlewareTest(tb_test.TestCase):
- """Tests for `EmptyPathRedirectMiddleware`."""
+ """Tests for `EmptyPathRedirectMiddleware`."""
- def setUp(self):
- super(EmptyPathRedirectMiddlewareTest, self).setUp()
- app = werkzeug.Request.application(lambda req: werkzeug.Response(req.path))
- app = empty_path_redirect.EmptyPathRedirectMiddleware(app)
- app = self._lax_strip_foo_middleware(app)
- self.app = app
- self.server = werkzeug_test.Client(self.app, werkzeug.BaseResponse)
+ def setUp(self):
+ super(EmptyPathRedirectMiddlewareTest, self).setUp()
+ app = werkzeug.Request.application(
+ lambda req: werkzeug.Response(req.path)
+ )
+ app = empty_path_redirect.EmptyPathRedirectMiddleware(app)
+ app = self._lax_strip_foo_middleware(app)
+ self.app = app
+ self.server = werkzeug_test.Client(self.app, werkzeug.BaseResponse)
- def _lax_strip_foo_middleware(self, app):
- """Strips a `/foo` prefix if it exists; no-op otherwise."""
+ def _lax_strip_foo_middleware(self, app):
+ """Strips a `/foo` prefix if it exists; no-op otherwise."""
- def wrapper(environ, start_response):
- path = environ.get("PATH_INFO", "")
- if path.startswith("/foo"):
- environ["PATH_INFO"] = path[len("/foo") :]
- environ["SCRIPT_NAME"] = "/foo"
- return app(environ, start_response)
+ def wrapper(environ, start_response):
+ path = environ.get("PATH_INFO", "")
+ if path.startswith("/foo"):
+ environ["PATH_INFO"] = path[len("/foo") :]
+ environ["SCRIPT_NAME"] = "/foo"
+ return app(environ, start_response)
- return wrapper
+ return wrapper
- def test_normal_route_not_redirected(self):
- response = self.server.get("/foo/bar")
- self.assertEqual(response.status_code, 200)
+ def test_normal_route_not_redirected(self):
+ response = self.server.get("/foo/bar")
+ self.assertEqual(response.status_code, 200)
- def test_slash_not_redirected(self):
- response = self.server.get("/foo/")
- self.assertEqual(response.status_code, 200)
+ def test_slash_not_redirected(self):
+ response = self.server.get("/foo/")
+ self.assertEqual(response.status_code, 200)
- def test_empty_redirected_with_script_name(self):
- response = self.server.get("/foo")
- self.assertEqual(response.status_code, 301)
- self.assertEqual(response.headers["Location"], "/foo/")
+ def test_empty_redirected_with_script_name(self):
+ response = self.server.get("/foo")
+ self.assertEqual(response.status_code, 301)
+ self.assertEqual(response.headers["Location"], "/foo/")
- def test_empty_redirected_with_blank_script_name(self):
- response = self.server.get("")
- self.assertEqual(response.status_code, 301)
- self.assertEqual(response.headers["Location"], "/")
+ def test_empty_redirected_with_blank_script_name(self):
+ response = self.server.get("")
+ self.assertEqual(response.status_code, 301)
+ self.assertEqual(response.headers["Location"], "/")
if __name__ == "__main__":
- tb_test.main()
+ tb_test.main()
diff --git a/tensorboard/backend/event_processing/data_provider.py b/tensorboard/backend/event_processing/data_provider.py
index 859ed7be73..aca177ef51 100644
--- a/tensorboard/backend/event_processing/data_provider.py
+++ b/tensorboard/backend/event_processing/data_provider.py
@@ -35,285 +35,292 @@
class MultiplexerDataProvider(provider.DataProvider):
- def __init__(self, multiplexer, logdir):
- """Trivial initializer.
-
- Args:
- multiplexer: A `plugin_event_multiplexer.EventMultiplexer` (note:
- not a boring old `event_multiplexer.EventMultiplexer`).
- logdir: The log directory from which data is being read. Only used
- cosmetically. Should be a `str`.
- """
- self._multiplexer = multiplexer
- self._logdir = logdir
-
- def _validate_experiment_id(self, experiment_id):
- # This data provider doesn't consume the experiment ID at all, but
- # as a courtesy to callers we require that it be a valid string, to
- # help catch usage errors.
- if not isinstance(experiment_id, str):
- raise TypeError(
- "experiment_id must be %r, but got %r: %r"
- % (str, type(experiment_id), experiment_id)
- )
-
- def _test_run_tag(self, run_tag_filter, run, tag):
- runs = run_tag_filter.runs
- if runs is not None and run not in runs:
- return False
- tags = run_tag_filter.tags
- if tags is not None and tag not in tags:
- return False
- return True
-
- def _get_first_event_timestamp(self, run_name):
- try:
- return self._multiplexer.FirstEventTimestamp(run_name)
- except ValueError as e:
- return None
-
- def data_location(self, experiment_id):
- self._validate_experiment_id(experiment_id)
- return str(self._logdir)
-
- def list_runs(self, experiment_id):
- self._validate_experiment_id(experiment_id)
- return [
- provider.Run(
- run_id=run, # use names as IDs
- run_name=run,
- start_time=self._get_first_event_timestamp(run),
+ def __init__(self, multiplexer, logdir):
+ """Trivial initializer.
+
+ Args:
+ multiplexer: A `plugin_event_multiplexer.EventMultiplexer` (note:
+ not a boring old `event_multiplexer.EventMultiplexer`).
+ logdir: The log directory from which data is being read. Only used
+ cosmetically. Should be a `str`.
+ """
+ self._multiplexer = multiplexer
+ self._logdir = logdir
+
+ def _validate_experiment_id(self, experiment_id):
+ # This data provider doesn't consume the experiment ID at all, but
+ # as a courtesy to callers we require that it be a valid string, to
+ # help catch usage errors.
+ if not isinstance(experiment_id, str):
+ raise TypeError(
+ "experiment_id must be %r, but got %r: %r"
+ % (str, type(experiment_id), experiment_id)
+ )
+
+ def _test_run_tag(self, run_tag_filter, run, tag):
+ runs = run_tag_filter.runs
+ if runs is not None and run not in runs:
+ return False
+ tags = run_tag_filter.tags
+ if tags is not None and tag not in tags:
+ return False
+ return True
+
+ def _get_first_event_timestamp(self, run_name):
+ try:
+ return self._multiplexer.FirstEventTimestamp(run_name)
+ except ValueError as e:
+ return None
+
+ def data_location(self, experiment_id):
+ self._validate_experiment_id(experiment_id)
+ return str(self._logdir)
+
+ def list_runs(self, experiment_id):
+ self._validate_experiment_id(experiment_id)
+ return [
+ provider.Run(
+ run_id=run, # use names as IDs
+ run_name=run,
+ start_time=self._get_first_event_timestamp(run),
+ )
+ for run in self._multiplexer.Runs()
+ ]
+
+ def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None):
+ self._validate_experiment_id(experiment_id)
+ run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name)
+ return self._list(
+ provider.ScalarTimeSeries, run_tag_content, run_tag_filter
)
- for run in self._multiplexer.Runs()
- ]
-
- def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None):
- self._validate_experiment_id(experiment_id)
- run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name)
- return self._list(
- provider.ScalarTimeSeries, run_tag_content, run_tag_filter
- )
- def read_scalars(
- self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
- ):
- # TODO(@wchargin): Downsampling not implemented, as the multiplexer
- # is already downsampled. We could downsample on top of the existing
- # sampling, which would be nice for testing.
- del downsample # ignored for now
- index = self.list_scalars(
- experiment_id, plugin_name, run_tag_filter=run_tag_filter
- )
- return self._read(_convert_scalar_event, index)
+ def read_scalars(
+ self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
+ ):
+ # TODO(@wchargin): Downsampling not implemented, as the multiplexer
+ # is already downsampled. We could downsample on top of the existing
+ # sampling, which would be nice for testing.
+ del downsample # ignored for now
+ index = self.list_scalars(
+ experiment_id, plugin_name, run_tag_filter=run_tag_filter
+ )
+ return self._read(_convert_scalar_event, index)
- def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
- self._validate_experiment_id(experiment_id)
- run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name)
- return self._list(
- provider.TensorTimeSeries, run_tag_content, run_tag_filter
- )
+ def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
+ self._validate_experiment_id(experiment_id)
+ run_tag_content = self._multiplexer.PluginRunToTagToContent(plugin_name)
+ return self._list(
+ provider.TensorTimeSeries, run_tag_content, run_tag_filter
+ )
- def read_tensors(
- self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
- ):
- # TODO(@wchargin): Downsampling not implemented, as the multiplexer
- # is already downsampled. We could downsample on top of the existing
- # sampling, which would be nice for testing.
- del downsample # ignored for now
- index = self.list_tensors(
- experiment_id, plugin_name, run_tag_filter=run_tag_filter
- )
- return self._read(_convert_tensor_event, index)
+ def read_tensors(
+ self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
+ ):
+ # TODO(@wchargin): Downsampling not implemented, as the multiplexer
+ # is already downsampled. We could downsample on top of the existing
+ # sampling, which would be nice for testing.
+ del downsample # ignored for now
+ index = self.list_tensors(
+ experiment_id, plugin_name, run_tag_filter=run_tag_filter
+ )
+ return self._read(_convert_tensor_event, index)
+
+ def _list(self, construct_time_series, run_tag_content, run_tag_filter):
+ """Helper to list scalar or tensor time series.
+
+ Args:
+ construct_time_series: `ScalarTimeSeries` or `TensorTimeSeries`.
+ run_tag_content: Result of `_multiplexer.PluginRunToTagToContent(...)`.
+ run_tag_filter: As given by the client; may be `None`.
+
+ Returns:
+ A list of objects of type given by `construct_time_series`,
+ suitable to be returned from `list_scalars` or `list_tensors`.
+ """
+ result = {}
+ if run_tag_filter is None:
+ run_tag_filter = provider.RunTagFilter(runs=None, tags=None)
+ for (run, tag_to_content) in six.iteritems(run_tag_content):
+ result_for_run = {}
+ for tag in tag_to_content:
+ if not self._test_run_tag(run_tag_filter, run, tag):
+ continue
+ result[run] = result_for_run
+ max_step = None
+ max_wall_time = None
+ for event in self._multiplexer.Tensors(run, tag):
+ if max_step is None or max_step < event.step:
+ max_step = event.step
+ if max_wall_time is None or max_wall_time < event.wall_time:
+ max_wall_time = event.wall_time
+ summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
+ result_for_run[tag] = construct_time_series(
+ max_step=max_step,
+ max_wall_time=max_wall_time,
+ plugin_content=summary_metadata.plugin_data.content,
+ description=summary_metadata.summary_description,
+ display_name=summary_metadata.display_name,
+ )
+ return result
+
+ def _read(self, convert_event, index):
+ """Helper to read scalar or tensor data from the multiplexer.
+
+ Args:
+ convert_event: Takes `plugin_event_accumulator.TensorEvent` to
+ either `provider.ScalarDatum` or `provider.TensorDatum`.
+ index: The result of `list_scalars` or `list_tensors`.
+
+ Returns:
+ A dict of dicts of values returned by `convert_event` calls,
+ suitable to be returned from `read_scalars` or `read_tensors`.
+ """
+ result = {}
+ for (run, tags_for_run) in six.iteritems(index):
+ result_for_run = {}
+ result[run] = result_for_run
+ for (tag, metadata) in six.iteritems(tags_for_run):
+ events = self._multiplexer.Tensors(run, tag)
+ result_for_run[tag] = [convert_event(e) for e in events]
+ return result
+
+ def list_blob_sequences(
+ self, experiment_id, plugin_name, run_tag_filter=None
+ ):
+ self._validate_experiment_id(experiment_id)
+ if run_tag_filter is None:
+ run_tag_filter = provider.RunTagFilter(runs=None, tags=None)
+
+ # TODO(davidsoergel, wchargin): consider images, etc.
+ # Note this plugin_name can really just be 'graphs' for now; the
+ # v2 cases are not handled yet.
+ if plugin_name != graphs_metadata.PLUGIN_NAME:
+ logger.warn("Directory has no blob data for plugin %r", plugin_name)
+ return {}
+
+ result = collections.defaultdict(lambda: {})
+ for (run, run_info) in six.iteritems(self._multiplexer.Runs()):
+ tag = None
+ if not self._test_run_tag(run_tag_filter, run, tag):
+ continue
+ if not run_info[plugin_event_accumulator.GRAPH]:
+ continue
+ result[run][tag] = provider.BlobSequenceTimeSeries(
+ max_step=0,
+ max_wall_time=0,
+ latest_max_index=0, # Graphs are always one blob at a time
+ plugin_content=None,
+ description=None,
+ display_name=None,
+ )
+ return result
+
+ def read_blob_sequences(
+ self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
+ ):
+ self._validate_experiment_id(experiment_id)
+ # TODO(davidsoergel, wchargin): consider images, etc.
+ # Note this plugin_name can really just be 'graphs' for now; the
+ # v2 cases are not handled yet.
+ if plugin_name != graphs_metadata.PLUGIN_NAME:
+ logger.warn("Directory has no blob data for plugin %r", plugin_name)
+ return {}
+
+ result = collections.defaultdict(
+ lambda: collections.defaultdict(lambda: [])
+ )
+ for (run, run_info) in six.iteritems(self._multiplexer.Runs()):
+ tag = None
+ if not self._test_run_tag(run_tag_filter, run, tag):
+ continue
+ if not run_info[plugin_event_accumulator.GRAPH]:
+ continue
+
+ time_series = result[run][tag]
+
+ wall_time = 0.0 # dummy value for graph
+ step = 0 # dummy value for graph
+ index = 0 # dummy value for graph
+
+ # In some situations these blobs may have directly accessible URLs.
+ # But, for now, we assume they don't.
+ graph_url = None
+ graph_blob_key = _encode_blob_key(
+ experiment_id, plugin_name, run, tag, step, index
+ )
+ blob_ref = provider.BlobReference(graph_blob_key, graph_url)
+
+ datum = provider.BlobSequenceDatum(
+ wall_time=wall_time, step=step, values=(blob_ref,),
+ )
+ time_series.append(datum)
+ return result
+
+ def read_blob(self, blob_key):
+ # note: ignoring nearly all key elements: there is only one graph per run.
+ (
+ unused_experiment_id,
+ plugin_name,
+ run,
+ unused_tag,
+ unused_step,
+ unused_index,
+ ) = _decode_blob_key(blob_key)
+
+ # TODO(davidsoergel, wchargin): consider images, etc.
+ if plugin_name != graphs_metadata.PLUGIN_NAME:
+ logger.warn("Directory has no blob data for plugin %r", plugin_name)
+ raise errors.NotFoundError()
+
+ serialized_graph = self._multiplexer.SerializedGraph(run)
+
+ # TODO(davidsoergel): graph_defs have no step attribute so we don't filter
+ # on it. Other blob types might, though.
+
+ if serialized_graph is None:
+ logger.warn("No blob found for key %r", blob_key)
+ raise errors.NotFoundError()
+
+ # TODO(davidsoergel): consider internal structure of non-graphdef blobs.
+ # In particular, note we ignore the requested index, since it's always 0.
+ return serialized_graph
- def _list(self, construct_time_series, run_tag_content, run_tag_filter):
- """Helper to list scalar or tensor time series.
- Args:
- construct_time_series: `ScalarTimeSeries` or `TensorTimeSeries`.
- run_tag_content: Result of `_multiplexer.PluginRunToTagToContent(...)`.
- run_tag_filter: As given by the client; may be `None`.
+# TODO(davidsoergel): deduplicate with other implementations
+def _encode_blob_key(experiment_id, plugin_name, run, tag, step, index):
+ """Generate a blob key: a short, URL-safe string identifying a blob.
- Returns:
- A list of objects of type given by `construct_time_series`,
- suitable to be returned from `list_scalars` or `list_tensors`.
- """
- result = {}
- if run_tag_filter is None:
- run_tag_filter = provider.RunTagFilter(runs=None, tags=None)
- for (run, tag_to_content) in six.iteritems(run_tag_content):
- result_for_run = {}
- for tag in tag_to_content:
- if not self._test_run_tag(run_tag_filter, run, tag):
- continue
- result[run] = result_for_run
- max_step = None
- max_wall_time = None
- for event in self._multiplexer.Tensors(run, tag):
- if max_step is None or max_step < event.step:
- max_step = event.step
- if max_wall_time is None or max_wall_time < event.wall_time:
- max_wall_time = event.wall_time
- summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
- result_for_run[tag] = construct_time_series(
- max_step=max_step,
- max_wall_time=max_wall_time,
- plugin_content=summary_metadata.plugin_data.content,
- description=summary_metadata.summary_description,
- display_name=summary_metadata.display_name,
- )
- return result
+ A blob can be located using a set of integer and string fields; here we
+ serialize these to allow passing the data through a URL. Specifically, we
+ 1) construct a tuple of the arguments in order; 2) represent that as an
+ ascii-encoded JSON string (without whitespace); and 3) take the URL-safe
+ base64 encoding of that, with no padding. For example:
- def _read(self, convert_event, index):
- """Helper to read scalar or tensor data from the multiplexer.
+ 1) Tuple: ("some_id", "graphs", "train", "graph_def", 2, 0)
+ 2) JSON: ["some_id","graphs","train","graph_def",2,0]
+ 3) base64: WyJzb21lX2lkIiwiZ3JhcGhzIiwidHJhaW4iLCJncmFwaF9kZWYiLDIsMF0K
Args:
- convert_event: Takes `plugin_event_accumulator.TensorEvent` to
- either `provider.ScalarDatum` or `provider.TensorDatum`.
- index: The result of `list_scalars` or `list_tensors`.
+ experiment_id: a string ID identifying an experiment.
+ plugin_name: string
+ run: string
+ tag: string
+ step: int
+ index: int
Returns:
- A dict of dicts of values returned by `convert_event` calls,
- suitable to be returned from `read_scalars` or `read_tensors`.
+ A URL-safe base64-encoded string representing the provided arguments.
"""
- result = {}
- for (run, tags_for_run) in six.iteritems(index):
- result_for_run = {}
- result[run] = result_for_run
- for (tag, metadata) in six.iteritems(tags_for_run):
- events = self._multiplexer.Tensors(run, tag)
- result_for_run[tag] = [convert_event(e) for e in events]
- return result
-
- def list_blob_sequences(
- self, experiment_id, plugin_name, run_tag_filter=None
- ):
- self._validate_experiment_id(experiment_id)
- if run_tag_filter is None:
- run_tag_filter = provider.RunTagFilter(runs=None, tags=None)
-
- # TODO(davidsoergel, wchargin): consider images, etc.
- # Note this plugin_name can really just be 'graphs' for now; the
- # v2 cases are not handled yet.
- if plugin_name != graphs_metadata.PLUGIN_NAME:
- logger.warn("Directory has no blob data for plugin %r", plugin_name)
- return {}
-
- result = collections.defaultdict(lambda: {})
- for (run, run_info) in six.iteritems(self._multiplexer.Runs()):
- tag = None
- if not self._test_run_tag(run_tag_filter, run, tag):
- continue
- if not run_info[plugin_event_accumulator.GRAPH]:
- continue
- result[run][tag] = provider.BlobSequenceTimeSeries(
- max_step=0,
- max_wall_time=0,
- latest_max_index=0, # Graphs are always one blob at a time
- plugin_content=None,
- description=None,
- display_name=None,
- )
- return result
-
- def read_blob_sequences(
- self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
- ):
- self._validate_experiment_id(experiment_id)
- # TODO(davidsoergel, wchargin): consider images, etc.
- # Note this plugin_name can really just be 'graphs' for now; the
- # v2 cases are not handled yet.
- if plugin_name != graphs_metadata.PLUGIN_NAME:
- logger.warn("Directory has no blob data for plugin %r", plugin_name)
- return {}
-
- result = collections.defaultdict(
- lambda: collections.defaultdict(lambda: []))
- for (run, run_info) in six.iteritems(self._multiplexer.Runs()):
- tag = None
- if not self._test_run_tag(run_tag_filter, run, tag):
- continue
- if not run_info[plugin_event_accumulator.GRAPH]:
- continue
-
- time_series = result[run][tag]
-
- wall_time = 0. # dummy value for graph
- step = 0 # dummy value for graph
- index = 0 # dummy value for graph
-
- # In some situations these blobs may have directly accessible URLs.
- # But, for now, we assume they don't.
- graph_url = None
- graph_blob_key = _encode_blob_key(
- experiment_id, plugin_name, run, tag, step, index)
- blob_ref = provider.BlobReference(graph_blob_key, graph_url)
-
- datum = provider.BlobSequenceDatum(
- wall_time=wall_time,
- step=step,
- values=(blob_ref,),
- )
- time_series.append(datum)
- return result
-
- def read_blob(self, blob_key):
- # note: ignoring nearly all key elements: there is only one graph per run.
- (unused_experiment_id, plugin_name, run, unused_tag, unused_step,
- unused_index) = _decode_blob_key(blob_key)
-
- # TODO(davidsoergel, wchargin): consider images, etc.
- if plugin_name != graphs_metadata.PLUGIN_NAME:
- logger.warn("Directory has no blob data for plugin %r", plugin_name)
- raise errors.NotFoundError()
-
- serialized_graph = self._multiplexer.SerializedGraph(run)
-
- # TODO(davidsoergel): graph_defs have no step attribute so we don't filter
- # on it. Other blob types might, though.
-
- if serialized_graph is None:
- logger.warn("No blob found for key %r", blob_key)
- raise errors.NotFoundError()
-
- # TODO(davidsoergel): consider internal structure of non-graphdef blobs.
- # In particular, note we ignore the requested index, since it's always 0.
- return serialized_graph
-
-
-# TODO(davidsoergel): deduplicate with other implementations
-def _encode_blob_key(experiment_id, plugin_name, run, tag, step, index):
- """Generate a blob key: a short, URL-safe string identifying a blob.
-
- A blob can be located using a set of integer and string fields; here we
- serialize these to allow passing the data through a URL. Specifically, we
- 1) construct a tuple of the arguments in order; 2) represent that as an
- ascii-encoded JSON string (without whitespace); and 3) take the URL-safe
- base64 encoding of that, with no padding. For example:
-
- 1) Tuple: ("some_id", "graphs", "train", "graph_def", 2, 0)
- 2) JSON: ["some_id","graphs","train","graph_def",2,0]
- 3) base64: WyJzb21lX2lkIiwiZ3JhcGhzIiwidHJhaW4iLCJncmFwaF9kZWYiLDIsMF0K
-
- Args:
- experiment_id: a string ID identifying an experiment.
- plugin_name: string
- run: string
- tag: string
- step: int
- index: int
-
- Returns:
- A URL-safe base64-encoded string representing the provided arguments.
- """
- # Encodes the blob key as a URL-safe string, as required by the
- # `BlobReference` API in `tensorboard/data/provider.py`, because these keys
- # may be used to construct URLs for retrieving blobs.
- stringified = json.dumps(
- (experiment_id, plugin_name, run, tag, step, index),
- separators=(",", ":"))
- bytesified = stringified.encode("ascii")
- encoded = base64.urlsafe_b64encode(bytesified)
- return six.ensure_str(encoded).rstrip("=")
+ # Encodes the blob key as a URL-safe string, as required by the
+ # `BlobReference` API in `tensorboard/data/provider.py`, because these keys
+ # may be used to construct URLs for retrieving blobs.
+ stringified = json.dumps(
+ (experiment_id, plugin_name, run, tag, step, index),
+ separators=(",", ":"),
+ )
+ bytesified = stringified.encode("ascii")
+ encoded = base64.urlsafe_b64encode(bytesified)
+ return six.ensure_str(encoded).rstrip("=")
# Any changes to this function need not be backward-compatible, even though
@@ -322,34 +329,36 @@ def _encode_blob_key(experiment_id, plugin_name, run, tag, step, index):
# within the context of the session that created them (via the matching
# `_encode_blob_key` function above).
def _decode_blob_key(key):
- """Decode a blob key produced by `_encode_blob_key` into component fields.
+ """Decode a blob key produced by `_encode_blob_key` into component fields.
- Args:
- key: a blob key, as generated by `_encode_blob_key`.
+ Args:
+ key: a blob key, as generated by `_encode_blob_key`.
- Returns:
- A tuple of `(experiment_id, plugin_name, run, tag, step, index)`, with types
- matching the arguments of `_encode_blob_key`.
- """
- decoded = base64.urlsafe_b64decode(key + "==") # pad past a multiple of 4.
- stringified = decoded.decode("ascii")
- (experiment_id, plugin_name, run, tag, step, index) = json.loads(stringified)
- return (experiment_id, plugin_name, run, tag, step, index)
+ Returns:
+ A tuple of `(experiment_id, plugin_name, run, tag, step, index)`, with types
+ matching the arguments of `_encode_blob_key`.
+ """
+ decoded = base64.urlsafe_b64decode(key + "==") # pad past a multiple of 4.
+ stringified = decoded.decode("ascii")
+ (experiment_id, plugin_name, run, tag, step, index) = json.loads(
+ stringified
+ )
+ return (experiment_id, plugin_name, run, tag, step, index)
def _convert_scalar_event(event):
- """Helper for `read_scalars`."""
- return provider.ScalarDatum(
- step=event.step,
- wall_time=event.wall_time,
- value=tensor_util.make_ndarray(event.tensor_proto).item(),
- )
+ """Helper for `read_scalars`."""
+ return provider.ScalarDatum(
+ step=event.step,
+ wall_time=event.wall_time,
+ value=tensor_util.make_ndarray(event.tensor_proto).item(),
+ )
def _convert_tensor_event(event):
- """Helper for `read_tensors`."""
- return provider.TensorDatum(
- step=event.step,
- wall_time=event.wall_time,
- numpy=tensor_util.make_ndarray(event.tensor_proto),
- )
+ """Helper for `read_tensors`."""
+ return provider.TensorDatum(
+ step=event.step,
+ wall_time=event.wall_time,
+ numpy=tensor_util.make_ndarray(event.tensor_proto),
+ )
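
For reference, the blob-key scheme documented in `_encode_blob_key` above (argument tuple, then compact JSON, then ASCII bytes, then URL-safe base64 with padding stripped) round-trips cleanly with its `_decode_blob_key` counterpart. A self-contained sketch of that round trip, using the same recipe but with standalone helper names defined here only for illustration:

    import base64
    import json

    def encode_blob_key(experiment_id, plugin_name, run, tag, step, index):
        # Same recipe as `_encode_blob_key`: compact JSON, ASCII bytes,
        # URL-safe base64, '=' padding stripped.
        stringified = json.dumps(
            (experiment_id, plugin_name, run, tag, step, index),
            separators=(",", ":"),
        )
        encoded = base64.urlsafe_b64encode(stringified.encode("ascii"))
        return encoded.decode("ascii").rstrip("=")

    def decode_blob_key(key):
        # Mirror `_decode_blob_key`: over-pad before decoding, then parse JSON.
        decoded = base64.urlsafe_b64decode(key + "==").decode("ascii")
        return tuple(json.loads(decoded))

    key = encode_blob_key("some_id", "graphs", "train", "graph_def", 2, 0)
    assert decode_blob_key(key) == ("some_id", "graphs", "train", "graph_def", 2, 0)
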
diff --git a/tensorboard/backend/event_processing/data_provider_test.py b/tensorboard/backend/event_processing/data_provider_test.py
index 53256c303f..e13ede243d 100644
--- a/tensorboard/backend/event_processing/data_provider_test.py
+++ b/tensorboard/backend/event_processing/data_provider_test.py
@@ -43,239 +43,267 @@
class MultiplexerDataProviderTest(tf.test.TestCase):
- def setUp(self):
- super(MultiplexerDataProviderTest, self).setUp()
- self.logdir = self.get_temp_dir()
-
- logdir = os.path.join(self.logdir, "polynomials")
- with tf.summary.create_file_writer(logdir).as_default():
- for i in xrange(10):
- scalar_summary.scalar("square", i ** 2, step=2 * i, description="boxen")
- scalar_summary.scalar("cube", i ** 3, step=3 * i)
-
- logdir = os.path.join(self.logdir, "waves")
- with tf.summary.create_file_writer(logdir).as_default():
- for i in xrange(10):
- scalar_summary.scalar("sine", tf.sin(float(i)), step=i)
- scalar_summary.scalar("square", tf.sign(tf.sin(float(i))), step=i)
- # Summary with rank-0 data but not owned by the scalars plugin.
- metadata = summary_pb2.SummaryMetadata()
- metadata.plugin_data.plugin_name = "marigraphs"
- tf.summary.write("high_tide", tensor=i, step=i, metadata=metadata)
-
- logdir = os.path.join(self.logdir, "pictures")
- with tf.summary.create_file_writer(logdir).as_default():
- colors = [
- ("`#F0F`", (255, 0, 255), "purple"),
- ("`#0F0`", (255, 0, 255), "green"),
- ]
- for (description, rgb, name) in colors:
- pixel = tf.constant([[list(rgb)]], dtype=tf.uint8)
- for i in xrange(1, 11):
- pixels = [tf.tile(pixel, [i, i, 1])]
- image_summary.image(name, pixels, step=i, description=description)
-
- def create_multiplexer(self):
- multiplexer = event_multiplexer.EventMultiplexer()
- multiplexer.AddRunsFromDirectory(self.logdir)
- multiplexer.Reload()
- return multiplexer
-
- def create_provider(self):
- multiplexer = self.create_multiplexer()
- return data_provider.MultiplexerDataProvider(multiplexer, self.logdir)
-
- def test_data_location(self):
- provider = self.create_provider()
- result = provider.data_location(experiment_id="unused")
- self.assertEqual(result, self.logdir)
-
- def test_list_runs(self):
- # We can't control the timestamps of events written to disk (without
- # manually reading the tfrecords, modifying the data, and writing
- # them back out), so we provide a fake multiplexer instead.
- start_times = {
- "second_2": 2.0,
- "first": 1.5,
- "no_time": None,
- "second_1": 2.0,
- }
- class FakeMultiplexer(object):
- def Runs(multiplexer):
- result = ["second_2", "first", "no_time", "second_1"]
- self.assertItemsEqual(result, start_times)
- return result
-
- def FirstEventTimestamp(multiplexer, run):
- self.assertIn(run, start_times)
- result = start_times[run]
- if result is None:
- raise ValueError("No event timestep could be found")
- else:
- return result
-
- multiplexer = FakeMultiplexer()
- provider = data_provider.MultiplexerDataProvider(multiplexer, "fake_logdir")
- result = provider.list_runs(experiment_id="unused")
- self.assertItemsEqual(result, [
- base_provider.Run(run_id=run, run_name=run, start_time=start_time)
- for (run, start_time) in six.iteritems(start_times)
- ])
-
- def test_list_scalars_all(self):
- provider = self.create_provider()
- result = provider.list_scalars(
- experiment_id="unused",
- plugin_name=scalar_metadata.PLUGIN_NAME,
- run_tag_filter=None,
- )
- self.assertItemsEqual(result.keys(), ["polynomials", "waves"])
- self.assertItemsEqual(result["polynomials"].keys(), ["square", "cube"])
- self.assertItemsEqual(result["waves"].keys(), ["square", "sine"])
- sample = result["polynomials"]["square"]
- self.assertIsInstance(sample, base_provider.ScalarTimeSeries)
- self.assertEqual(sample.max_step, 18)
- # nothing to test for wall time, as it can't be mocked out
- self.assertEqual(sample.plugin_content, b"")
- self.assertEqual(sample.display_name, "") # not written by V2 summary ops
- self.assertEqual(sample.description, "boxen")
-
- def test_list_scalars_filters(self):
- provider = self.create_provider()
-
- result = provider.list_scalars(
- experiment_id="unused",
- plugin_name=scalar_metadata.PLUGIN_NAME,
- run_tag_filter=base_provider.RunTagFilter(["waves"], ["square"]),
- )
- self.assertItemsEqual(result.keys(), ["waves"])
- self.assertItemsEqual(result["waves"].keys(), ["square"])
-
- result = provider.list_scalars(
- experiment_id="unused",
- plugin_name=scalar_metadata.PLUGIN_NAME,
- run_tag_filter=base_provider.RunTagFilter(tags=["square", "quartic"]),
- )
- self.assertItemsEqual(result.keys(), ["polynomials", "waves"])
- self.assertItemsEqual(result["polynomials"].keys(), ["square"])
- self.assertItemsEqual(result["waves"].keys(), ["square"])
-
- result = provider.list_scalars(
- experiment_id="unused",
- plugin_name=scalar_metadata.PLUGIN_NAME,
- run_tag_filter=base_provider.RunTagFilter(runs=["waves", "hugs"]),
- )
- self.assertItemsEqual(result.keys(), ["waves"])
- self.assertItemsEqual(result["waves"].keys(), ["sine", "square"])
-
- result = provider.list_scalars(
- experiment_id="unused",
- plugin_name=scalar_metadata.PLUGIN_NAME,
- run_tag_filter=base_provider.RunTagFilter(["un"], ["likely"]),
- )
- self.assertEqual(result, {})
-
- def test_read_scalars(self):
- multiplexer = self.create_multiplexer()
- provider = data_provider.MultiplexerDataProvider(multiplexer, self.logdir)
-
- run_tag_filter = base_provider.RunTagFilter(
- runs=["waves", "polynomials", "unicorns"],
- tags=["sine", "square", "cube", "iridescence"],
- )
- result = provider.read_scalars(
- experiment_id="unused",
- plugin_name=scalar_metadata.PLUGIN_NAME,
- run_tag_filter=run_tag_filter,
- downsample=None, # not yet implemented
- )
-
- self.assertItemsEqual(result.keys(), ["polynomials", "waves"])
- self.assertItemsEqual(result["polynomials"].keys(), ["square", "cube"])
- self.assertItemsEqual(result["waves"].keys(), ["square", "sine"])
- for run in result:
- for tag in result[run]:
- tensor_events = multiplexer.Tensors(run, tag)
- self.assertLen(result[run][tag], len(tensor_events))
- for (datum, event) in zip(result[run][tag], tensor_events):
- self.assertEqual(datum.step, event.step)
- self.assertEqual(datum.wall_time, event.wall_time)
- self.assertEqual(
- datum.value, tensor_util.make_ndarray(event.tensor_proto).item()
- )
-
- def test_read_scalars_but_not_rank_0(self):
- provider = self.create_provider()
- run_tag_filter = base_provider.RunTagFilter(["pictures"], ["purple"])
- # No explicit checks yet.
- with six.assertRaisesRegex(
- self,
- ValueError,
- "can only convert an array of size 1 to a Python scalar"):
- provider.read_scalars(
- experiment_id="unused",
- plugin_name=image_metadata.PLUGIN_NAME,
- run_tag_filter=run_tag_filter,
- )
-
- def test_list_tensors_all(self):
- provider = self.create_provider()
- result = provider.list_tensors(
- experiment_id="unused",
- plugin_name=image_metadata.PLUGIN_NAME,
- run_tag_filter=None,
- )
- self.assertItemsEqual(result.keys(), ["pictures"])
- self.assertItemsEqual(result["pictures"].keys(), ["purple", "green"])
- sample = result["pictures"]["purple"]
- self.assertIsInstance(sample, base_provider.TensorTimeSeries)
- self.assertEqual(sample.max_step, 10)
- # nothing to test for wall time, as it can't be mocked out
- self.assertEqual(sample.plugin_content, b"")
- self.assertEqual(sample.display_name, "") # not written by V2 summary ops
- self.assertEqual(sample.description, "`#F0F`")
-
- def test_list_tensors_filters(self):
- provider = self.create_provider()
-
- # Quick check only, as scalars and tensors use the same underlying
- # filtering implementation.
- result = provider.list_tensors(
- experiment_id="unused",
- plugin_name=image_metadata.PLUGIN_NAME,
- run_tag_filter=base_provider.RunTagFilter(["pictures"], ["green"]),
- )
- self.assertItemsEqual(result.keys(), ["pictures"])
- self.assertItemsEqual(result["pictures"].keys(), ["green"])
-
- def test_read_tensors(self):
- multiplexer = self.create_multiplexer()
- provider = data_provider.MultiplexerDataProvider(multiplexer, self.logdir)
-
- run_tag_filter = base_provider.RunTagFilter(
- runs=["pictures"],
- tags=["purple", "green"],
- )
- result = provider.read_tensors(
- experiment_id="unused",
- plugin_name=image_metadata.PLUGIN_NAME,
- run_tag_filter=run_tag_filter,
- downsample=None, # not yet implemented
- )
-
- self.assertItemsEqual(result.keys(), ["pictures"])
- self.assertItemsEqual(result["pictures"].keys(), ["purple", "green"])
- for run in result:
- for tag in result[run]:
- tensor_events = multiplexer.Tensors(run, tag)
- self.assertLen(result[run][tag], len(tensor_events))
- for (datum, event) in zip(result[run][tag], tensor_events):
- self.assertEqual(datum.step, event.step)
- self.assertEqual(datum.wall_time, event.wall_time)
- np.testing.assert_equal(
- datum.numpy, tensor_util.make_ndarray(event.tensor_proto)
- )
+ def setUp(self):
+ super(MultiplexerDataProviderTest, self).setUp()
+ self.logdir = self.get_temp_dir()
+
+ logdir = os.path.join(self.logdir, "polynomials")
+ with tf.summary.create_file_writer(logdir).as_default():
+ for i in xrange(10):
+ scalar_summary.scalar(
+ "square", i ** 2, step=2 * i, description="boxen"
+ )
+ scalar_summary.scalar("cube", i ** 3, step=3 * i)
+
+ logdir = os.path.join(self.logdir, "waves")
+ with tf.summary.create_file_writer(logdir).as_default():
+ for i in xrange(10):
+ scalar_summary.scalar("sine", tf.sin(float(i)), step=i)
+ scalar_summary.scalar(
+ "square", tf.sign(tf.sin(float(i))), step=i
+ )
+ # Summary with rank-0 data but not owned by the scalars plugin.
+ metadata = summary_pb2.SummaryMetadata()
+ metadata.plugin_data.plugin_name = "marigraphs"
+ tf.summary.write(
+ "high_tide", tensor=i, step=i, metadata=metadata
+ )
+
+ logdir = os.path.join(self.logdir, "pictures")
+ with tf.summary.create_file_writer(logdir).as_default():
+ colors = [
+ ("`#F0F`", (255, 0, 255), "purple"),
+ ("`#0F0`", (255, 0, 255), "green"),
+ ]
+ for (description, rgb, name) in colors:
+ pixel = tf.constant([[list(rgb)]], dtype=tf.uint8)
+ for i in xrange(1, 11):
+ pixels = [tf.tile(pixel, [i, i, 1])]
+ image_summary.image(
+ name, pixels, step=i, description=description
+ )
+
+ def create_multiplexer(self):
+ multiplexer = event_multiplexer.EventMultiplexer()
+ multiplexer.AddRunsFromDirectory(self.logdir)
+ multiplexer.Reload()
+ return multiplexer
+
+ def create_provider(self):
+ multiplexer = self.create_multiplexer()
+ return data_provider.MultiplexerDataProvider(multiplexer, self.logdir)
+
+ def test_data_location(self):
+ provider = self.create_provider()
+ result = provider.data_location(experiment_id="unused")
+ self.assertEqual(result, self.logdir)
+
+ def test_list_runs(self):
+ # We can't control the timestamps of events written to disk (without
+ # manually reading the tfrecords, modifying the data, and writing
+ # them back out), so we provide a fake multiplexer instead.
+ start_times = {
+ "second_2": 2.0,
+ "first": 1.5,
+ "no_time": None,
+ "second_1": 2.0,
+ }
+
+ class FakeMultiplexer(object):
+ def Runs(multiplexer):
+ result = ["second_2", "first", "no_time", "second_1"]
+ self.assertItemsEqual(result, start_times)
+ return result
+
+ def FirstEventTimestamp(multiplexer, run):
+ self.assertIn(run, start_times)
+ result = start_times[run]
+ if result is None:
+ raise ValueError("No event timestep could be found")
+ else:
+ return result
+
+ multiplexer = FakeMultiplexer()
+ provider = data_provider.MultiplexerDataProvider(
+ multiplexer, "fake_logdir"
+ )
+ result = provider.list_runs(experiment_id="unused")
+ self.assertItemsEqual(
+ result,
+ [
+ base_provider.Run(
+ run_id=run, run_name=run, start_time=start_time
+ )
+ for (run, start_time) in six.iteritems(start_times)
+ ],
+ )
+
+ def test_list_scalars_all(self):
+ provider = self.create_provider()
+ result = provider.list_scalars(
+ experiment_id="unused",
+ plugin_name=scalar_metadata.PLUGIN_NAME,
+ run_tag_filter=None,
+ )
+ self.assertItemsEqual(result.keys(), ["polynomials", "waves"])
+ self.assertItemsEqual(result["polynomials"].keys(), ["square", "cube"])
+ self.assertItemsEqual(result["waves"].keys(), ["square", "sine"])
+ sample = result["polynomials"]["square"]
+ self.assertIsInstance(sample, base_provider.ScalarTimeSeries)
+ self.assertEqual(sample.max_step, 18)
+ # nothing to test for wall time, as it can't be mocked out
+ self.assertEqual(sample.plugin_content, b"")
+ self.assertEqual(
+ sample.display_name, ""
+ ) # not written by V2 summary ops
+ self.assertEqual(sample.description, "boxen")
+
+ def test_list_scalars_filters(self):
+ provider = self.create_provider()
+
+ result = provider.list_scalars(
+ experiment_id="unused",
+ plugin_name=scalar_metadata.PLUGIN_NAME,
+ run_tag_filter=base_provider.RunTagFilter(["waves"], ["square"]),
+ )
+ self.assertItemsEqual(result.keys(), ["waves"])
+ self.assertItemsEqual(result["waves"].keys(), ["square"])
+
+ result = provider.list_scalars(
+ experiment_id="unused",
+ plugin_name=scalar_metadata.PLUGIN_NAME,
+ run_tag_filter=base_provider.RunTagFilter(
+ tags=["square", "quartic"]
+ ),
+ )
+ self.assertItemsEqual(result.keys(), ["polynomials", "waves"])
+ self.assertItemsEqual(result["polynomials"].keys(), ["square"])
+ self.assertItemsEqual(result["waves"].keys(), ["square"])
+
+ result = provider.list_scalars(
+ experiment_id="unused",
+ plugin_name=scalar_metadata.PLUGIN_NAME,
+ run_tag_filter=base_provider.RunTagFilter(runs=["waves", "hugs"]),
+ )
+ self.assertItemsEqual(result.keys(), ["waves"])
+ self.assertItemsEqual(result["waves"].keys(), ["sine", "square"])
+
+ result = provider.list_scalars(
+ experiment_id="unused",
+ plugin_name=scalar_metadata.PLUGIN_NAME,
+ run_tag_filter=base_provider.RunTagFilter(["un"], ["likely"]),
+ )
+ self.assertEqual(result, {})
+
+ def test_read_scalars(self):
+ multiplexer = self.create_multiplexer()
+ provider = data_provider.MultiplexerDataProvider(
+ multiplexer, self.logdir
+ )
+
+ run_tag_filter = base_provider.RunTagFilter(
+ runs=["waves", "polynomials", "unicorns"],
+ tags=["sine", "square", "cube", "iridescence"],
+ )
+ result = provider.read_scalars(
+ experiment_id="unused",
+ plugin_name=scalar_metadata.PLUGIN_NAME,
+ run_tag_filter=run_tag_filter,
+ downsample=None, # not yet implemented
+ )
+
+ self.assertItemsEqual(result.keys(), ["polynomials", "waves"])
+ self.assertItemsEqual(result["polynomials"].keys(), ["square", "cube"])
+ self.assertItemsEqual(result["waves"].keys(), ["square", "sine"])
+ for run in result:
+ for tag in result[run]:
+ tensor_events = multiplexer.Tensors(run, tag)
+ self.assertLen(result[run][tag], len(tensor_events))
+ for (datum, event) in zip(result[run][tag], tensor_events):
+ self.assertEqual(datum.step, event.step)
+ self.assertEqual(datum.wall_time, event.wall_time)
+ self.assertEqual(
+ datum.value,
+ tensor_util.make_ndarray(event.tensor_proto).item(),
+ )
+
+ def test_read_scalars_but_not_rank_0(self):
+ provider = self.create_provider()
+ run_tag_filter = base_provider.RunTagFilter(["pictures"], ["purple"])
+ # No explicit checks yet.
+ with six.assertRaisesRegex(
+ self,
+ ValueError,
+ "can only convert an array of size 1 to a Python scalar",
+ ):
+ provider.read_scalars(
+ experiment_id="unused",
+ plugin_name=image_metadata.PLUGIN_NAME,
+ run_tag_filter=run_tag_filter,
+ )
+
+ def test_list_tensors_all(self):
+ provider = self.create_provider()
+ result = provider.list_tensors(
+ experiment_id="unused",
+ plugin_name=image_metadata.PLUGIN_NAME,
+ run_tag_filter=None,
+ )
+ self.assertItemsEqual(result.keys(), ["pictures"])
+ self.assertItemsEqual(result["pictures"].keys(), ["purple", "green"])
+ sample = result["pictures"]["purple"]
+ self.assertIsInstance(sample, base_provider.TensorTimeSeries)
+ self.assertEqual(sample.max_step, 10)
+ # nothing to test for wall time, as it can't be mocked out
+ self.assertEqual(sample.plugin_content, b"")
+ self.assertEqual(
+ sample.display_name, ""
+ ) # not written by V2 summary ops
+ self.assertEqual(sample.description, "`#F0F`")
+
+ def test_list_tensors_filters(self):
+ provider = self.create_provider()
+
+ # Quick check only, as scalars and tensors use the same underlying
+ # filtering implementation.
+ result = provider.list_tensors(
+ experiment_id="unused",
+ plugin_name=image_metadata.PLUGIN_NAME,
+ run_tag_filter=base_provider.RunTagFilter(["pictures"], ["green"]),
+ )
+ self.assertItemsEqual(result.keys(), ["pictures"])
+ self.assertItemsEqual(result["pictures"].keys(), ["green"])
+
+ def test_read_tensors(self):
+ multiplexer = self.create_multiplexer()
+ provider = data_provider.MultiplexerDataProvider(
+ multiplexer, self.logdir
+ )
+
+ run_tag_filter = base_provider.RunTagFilter(
+ runs=["pictures"], tags=["purple", "green"],
+ )
+ result = provider.read_tensors(
+ experiment_id="unused",
+ plugin_name=image_metadata.PLUGIN_NAME,
+ run_tag_filter=run_tag_filter,
+ downsample=None, # not yet implemented
+ )
+
+ self.assertItemsEqual(result.keys(), ["pictures"])
+ self.assertItemsEqual(result["pictures"].keys(), ["purple", "green"])
+ for run in result:
+ for tag in result[run]:
+ tensor_events = multiplexer.Tensors(run, tag)
+ self.assertLen(result[run][tag], len(tensor_events))
+ for (datum, event) in zip(result[run][tag], tensor_events):
+ self.assertEqual(datum.step, event.step)
+ self.assertEqual(datum.wall_time, event.wall_time)
+ np.testing.assert_equal(
+ datum.numpy,
+ tensor_util.make_ndarray(event.tensor_proto),
+ )
if __name__ == "__main__":
- tf.test.main()
+ tf.test.main()
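
The fixtures in this test file also double as a usage recipe for `MultiplexerDataProvider`: build an `EventMultiplexer`, point it at a logdir, reload, and wrap it. A hedged sketch of that flow — the logdir path is hypothetical, and the import paths are inferred from how the modules appear to be used in this diff, so double-check them against your checkout:

    from tensorboard.backend.event_processing import data_provider
    from tensorboard.backend.event_processing import plugin_event_multiplexer
    from tensorboard.plugins.scalar import metadata as scalar_metadata

    logdir = "/tmp/some_logdir"  # hypothetical path

    # Load everything under the logdir, then wrap the multiplexer.
    multiplexer = plugin_event_multiplexer.EventMultiplexer()
    multiplexer.AddRunsFromDirectory(logdir)
    multiplexer.Reload()
    provider = data_provider.MultiplexerDataProvider(multiplexer, logdir)

    # The experiment ID is unused by this provider but must be a string.
    scalars = provider.read_scalars(
        experiment_id="unused",
        plugin_name=scalar_metadata.PLUGIN_NAME,
    )
    for run, tags in scalars.items():
        for tag, data in tags.items():
            print(run, tag, [(d.step, d.value) for d in data])
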
diff --git a/tensorboard/backend/event_processing/db_import_multiplexer.py b/tensorboard/backend/event_processing/db_import_multiplexer.py
index 968a06ea0f..dd4d39700b 100644
--- a/tensorboard/backend/event_processing/db_import_multiplexer.py
+++ b/tensorboard/backend/event_processing/db_import_multiplexer.py
@@ -42,262 +42,289 @@
class DbImportMultiplexer(plugin_event_multiplexer.EventMultiplexer):
- """A loading-only `EventMultiplexer` that populates a SQLite DB.
-
- This EventMultiplexer only loads data; the read APIs always return empty
- results, since all data is accessed instead via SQL against the
- db_connection_provider wrapped by this multiplexer.
- """
-
- def __init__(self,
- db_uri,
- db_connection_provider,
- purge_orphaned_data,
- max_reload_threads):
- """Constructor for `DbImportMultiplexer`.
-
- Args:
- db_uri: A URI to the database file in use.
- db_connection_provider: Provider function for creating a DB connection.
- purge_orphaned_data: Whether to discard any events that were "orphaned" by
- a TensorFlow restart.
- max_reload_threads: The max number of threads that TensorBoard can use
- to reload runs. Each thread reloads one run at a time. If not provided,
- reloads runs serially (one after another).
- """
- logger.info('DbImportMultiplexer initializing for %s', db_uri)
- super(DbImportMultiplexer, self).__init__()
- self.db_uri = db_uri
- self.db_connection_provider = db_connection_provider
- self._purge_orphaned_data = purge_orphaned_data
- self._max_reload_threads = max_reload_threads
- self._event_sink = None
- self._run_loaders = {}
-
- if self._purge_orphaned_data:
- logger.warn(
- '--db_import does not yet support purging orphaned data')
-
- conn = self.db_connection_provider()
- # Set the DB in WAL mode so reads don't block writes.
- conn.execute('PRAGMA journal_mode=wal')
- conn.execute('PRAGMA synchronous=normal') # Recommended for WAL mode
- sqlite_writer.initialize_schema(conn)
- logger.info('DbImportMultiplexer done initializing')
-
- def AddRun(self, path, name=None):
- """Unsupported; instead use AddRunsFromDirectory."""
- raise NotImplementedError("Unsupported; use AddRunsFromDirectory()")
-
- def AddRunsFromDirectory(self, path, name=None):
- """Load runs from a directory; recursively walks subdirectories.
-
- If path doesn't exist, no-op. This ensures that it is safe to call
- `AddRunsFromDirectory` multiple times, even before the directory is made.
-
- Args:
- path: A string path to a directory to load runs from.
- name: Optional, specifies a name for the experiment under which the
- runs from this directory hierarchy will be imported. If omitted, the
- path will be used as the name.
-
- Raises:
- ValueError: If the path exists and isn't a directory.
+ """A loading-only `EventMultiplexer` that populates a SQLite DB.
+
+ This EventMultiplexer only loads data; the read APIs always return
+ empty results, since all data is accessed instead via SQL against
+ the db_connection_provider wrapped by this multiplexer.
"""
- logger.info('Starting AddRunsFromDirectory: %s (as %s)', path, name)
- for subdir in io_wrapper.GetLogdirSubdirectories(path):
- logger.info('Processing directory %s', subdir)
- if subdir not in self._run_loaders:
- logger.info('Creating DB loader for directory %s', subdir)
- names = self._get_exp_and_run_names(path, subdir, name)
- experiment_name, run_name = names
- self._run_loaders[subdir] = _RunLoader(
- subdir=subdir,
- experiment_name=experiment_name,
- run_name=run_name)
- logger.info('Done with AddRunsFromDirectory: %s', path)
-
- def Reload(self):
- """Load events from every detected run."""
- logger.info('Beginning DbImportMultiplexer.Reload()')
- # Defer event sink creation until needed; this ensures it will only exist in
- # the thread that calls Reload(), since DB connections must be thread-local.
- if not self._event_sink:
- self._event_sink = _SqliteWriterEventSink(self.db_connection_provider)
- # Use collections.deque() for speed when we don't need blocking since it
- # also has thread-safe appends/pops.
- loader_queue = collections.deque(six.itervalues(self._run_loaders))
- loader_delete_queue = collections.deque()
-
- def batch_generator():
- while True:
- try:
- loader = loader_queue.popleft()
- except IndexError:
- return
- try:
- for batch in loader.load_batches():
- yield batch
- except directory_watcher.DirectoryDeletedError:
- loader_delete_queue.append(loader)
- except (OSError, IOError) as e:
- logger.error('Unable to load run %r: %s', loader.subdir, e)
-
- num_threads = min(self._max_reload_threads, len(self._run_loaders))
- if num_threads <= 1:
- logger.info('Importing runs serially on a single thread')
- for batch in batch_generator():
- self._event_sink.write_batch(batch)
- else:
- output_queue = queue.Queue()
- sentinel = object()
- def producer():
- try:
- for batch in batch_generator():
- output_queue.put(batch)
- finally:
- output_queue.put(sentinel)
- logger.info('Starting %d threads to import runs', num_threads)
- for i in xrange(num_threads):
- thread = threading.Thread(target=producer, name='Loader %d' % i)
- thread.daemon = True
- thread.start()
- num_live_threads = num_threads
- while num_live_threads > 0:
- output = output_queue.get()
- if output == sentinel:
- num_live_threads -= 1
- continue
- self._event_sink.write_batch(output)
- for loader in loader_delete_queue:
- logger.warn('Deleting loader %r', loader.subdir)
- del self._run_loaders[loader.subdir]
- logger.info('Finished with DbImportMultiplexer.Reload()')
-
- def _get_exp_and_run_names(self, path, subdir, experiment_name_override=None):
- if experiment_name_override is not None:
- return (experiment_name_override, os.path.relpath(subdir, path))
- sep = io_wrapper.PathSeparator(path)
- path_parts = os.path.relpath(subdir, path).split(sep, 1)
- experiment_name = path_parts[0]
- run_name = path_parts[1] if len(path_parts) == 2 else '.'
- return (experiment_name, run_name)
+
+ def __init__(
+ self,
+ db_uri,
+ db_connection_provider,
+ purge_orphaned_data,
+ max_reload_threads,
+ ):
+ """Constructor for `DbImportMultiplexer`.
+
+ Args:
+ db_uri: A URI to the database file in use.
+ db_connection_provider: Provider function for creating a DB connection.
+ purge_orphaned_data: Whether to discard any events that were "orphaned" by
+ a TensorFlow restart.
+ max_reload_threads: The max number of threads that TensorBoard can use
+ to reload runs. Each thread reloads one run at a time. If not provided,
+ reloads runs serially (one after another).
+ """
+ logger.info("DbImportMultiplexer initializing for %s", db_uri)
+ super(DbImportMultiplexer, self).__init__()
+ self.db_uri = db_uri
+ self.db_connection_provider = db_connection_provider
+ self._purge_orphaned_data = purge_orphaned_data
+ self._max_reload_threads = max_reload_threads
+ self._event_sink = None
+ self._run_loaders = {}
+
+ if self._purge_orphaned_data:
+ logger.warn(
+ "--db_import does not yet support purging orphaned data"
+ )
+
+ conn = self.db_connection_provider()
+ # Set the DB in WAL mode so reads don't block writes.
+ conn.execute("PRAGMA journal_mode=wal")
+ conn.execute("PRAGMA synchronous=normal") # Recommended for WAL mode
+ sqlite_writer.initialize_schema(conn)
+ logger.info("DbImportMultiplexer done initializing")
+
+ def AddRun(self, path, name=None):
+ """Unsupported; instead use AddRunsFromDirectory."""
+ raise NotImplementedError("Unsupported; use AddRunsFromDirectory()")
+
+ def AddRunsFromDirectory(self, path, name=None):
+ """Load runs from a directory; recursively walks subdirectories.
+
+ If path doesn't exist, no-op. This ensures that it is safe to call
+ `AddRunsFromDirectory` multiple times, even before the directory is made.
+
+ Args:
+ path: A string path to a directory to load runs from.
+ name: Optional, specifies a name for the experiment under which the
+ runs from this directory hierarchy will be imported. If omitted, the
+ path will be used as the name.
+
+ Raises:
+ ValueError: If the path exists and isn't a directory.
+ """
+ logger.info("Starting AddRunsFromDirectory: %s (as %s)", path, name)
+ for subdir in io_wrapper.GetLogdirSubdirectories(path):
+ logger.info("Processing directory %s", subdir)
+ if subdir not in self._run_loaders:
+ logger.info("Creating DB loader for directory %s", subdir)
+ names = self._get_exp_and_run_names(path, subdir, name)
+ experiment_name, run_name = names
+ self._run_loaders[subdir] = _RunLoader(
+ subdir=subdir,
+ experiment_name=experiment_name,
+ run_name=run_name,
+ )
+ logger.info("Done with AddRunsFromDirectory: %s", path)
+
+ def Reload(self):
+ """Load events from every detected run."""
+ logger.info("Beginning DbImportMultiplexer.Reload()")
+ # Defer event sink creation until needed; this ensures it will only exist in
+ # the thread that calls Reload(), since DB connections must be thread-local.
+ if not self._event_sink:
+ self._event_sink = _SqliteWriterEventSink(
+ self.db_connection_provider
+ )
+ # Use collections.deque() for speed when we don't need blocking since it
+ # also has thread-safe appends/pops.
+ loader_queue = collections.deque(six.itervalues(self._run_loaders))
+ loader_delete_queue = collections.deque()
+
+ def batch_generator():
+ while True:
+ try:
+ loader = loader_queue.popleft()
+ except IndexError:
+ return
+ try:
+ for batch in loader.load_batches():
+ yield batch
+ except directory_watcher.DirectoryDeletedError:
+ loader_delete_queue.append(loader)
+ except (OSError, IOError) as e:
+ logger.error("Unable to load run %r: %s", loader.subdir, e)
+
+ num_threads = min(self._max_reload_threads, len(self._run_loaders))
+ if num_threads <= 1:
+ logger.info("Importing runs serially on a single thread")
+ for batch in batch_generator():
+ self._event_sink.write_batch(batch)
+ else:
+ output_queue = queue.Queue()
+ sentinel = object()
+
+ def producer():
+ try:
+ for batch in batch_generator():
+ output_queue.put(batch)
+ finally:
+ output_queue.put(sentinel)
+
+ logger.info("Starting %d threads to import runs", num_threads)
+ for i in xrange(num_threads):
+ thread = threading.Thread(target=producer, name="Loader %d" % i)
+ thread.daemon = True
+ thread.start()
+ num_live_threads = num_threads
+ while num_live_threads > 0:
+ output = output_queue.get()
+ if output == sentinel:
+ num_live_threads -= 1
+ continue
+ self._event_sink.write_batch(output)
+ for loader in loader_delete_queue:
+ logger.warn("Deleting loader %r", loader.subdir)
+ del self._run_loaders[loader.subdir]
+ logger.info("Finished with DbImportMultiplexer.Reload()")
+
+ def _get_exp_and_run_names(
+ self, path, subdir, experiment_name_override=None
+ ):
+ if experiment_name_override is not None:
+ return (experiment_name_override, os.path.relpath(subdir, path))
+ sep = io_wrapper.PathSeparator(path)
+ path_parts = os.path.relpath(subdir, path).split(sep, 1)
+ experiment_name = path_parts[0]
+ run_name = path_parts[1] if len(path_parts) == 2 else "."
+ return (experiment_name, run_name)
+
# Struct holding a list of tf.Event serialized protos along with metadata about
# the associated experiment and run.
-_EventBatch = collections.namedtuple('EventBatch',
- ['events', 'experiment_name', 'run_name'])
+_EventBatch = collections.namedtuple(
+ "EventBatch", ["events", "experiment_name", "run_name"]
+)
class _RunLoader(object):
- """Loads a single run directory in batches."""
-
- _BATCH_COUNT = 5000
- _BATCH_BYTES = 2**20 # 1 MiB
-
- def __init__(self, subdir, experiment_name, run_name):
- """Constructs a `_RunLoader`.
-
- Args:
- subdir: string, filesystem path of the run directory
- experiment_name: string, name of the run's experiment
- run_name: string, name of the run
- """
- self._subdir = subdir
- self._experiment_name = experiment_name
- self._run_name = run_name
- self._directory_watcher = directory_watcher.DirectoryWatcher(
- subdir,
- event_file_loader.RawEventFileLoader,
- io_wrapper.IsTensorFlowEventsFile)
-
- @property
- def subdir(self):
- return self._subdir
-
- def load_batches(self):
- """Returns a batched event iterator over the run directory event files."""
- event_iterator = self._directory_watcher.Load()
- while True:
- events = []
- event_bytes = 0
- start = time.time()
- for event_proto in event_iterator:
- events.append(event_proto)
- event_bytes += len(event_proto)
- if len(events) >= self._BATCH_COUNT or event_bytes >= self._BATCH_BYTES:
- break
- elapsed = time.time() - start
- logger.debug('RunLoader.load_batch() yielded in %0.3f sec for %s',
- elapsed, self._subdir)
- if not events:
- return
- yield _EventBatch(
- events=events,
- experiment_name=self._experiment_name,
- run_name=self._run_name)
+ """Loads a single run directory in batches."""
+
+ _BATCH_COUNT = 5000
+ _BATCH_BYTES = 2 ** 20 # 1 MiB
+
+ def __init__(self, subdir, experiment_name, run_name):
+ """Constructs a `_RunLoader`.
+
+ Args:
+ subdir: string, filesystem path of the run directory
+ experiment_name: string, name of the run's experiment
+ run_name: string, name of the run
+ """
+ self._subdir = subdir
+ self._experiment_name = experiment_name
+ self._run_name = run_name
+ self._directory_watcher = directory_watcher.DirectoryWatcher(
+ subdir,
+ event_file_loader.RawEventFileLoader,
+ io_wrapper.IsTensorFlowEventsFile,
+ )
+
+ @property
+ def subdir(self):
+ return self._subdir
+
+ def load_batches(self):
+ """Returns a batched event iterator over the run directory event
+ files."""
+ event_iterator = self._directory_watcher.Load()
+ while True:
+ events = []
+ event_bytes = 0
+ start = time.time()
+ for event_proto in event_iterator:
+ events.append(event_proto)
+ event_bytes += len(event_proto)
+ if (
+ len(events) >= self._BATCH_COUNT
+ or event_bytes >= self._BATCH_BYTES
+ ):
+ break
+ elapsed = time.time() - start
+ logger.debug(
+ "RunLoader.load_batch() yielded in %0.3f sec for %s",
+ elapsed,
+ self._subdir,
+ )
+ if not events:
+ return
+ yield _EventBatch(
+ events=events,
+ experiment_name=self._experiment_name,
+ run_name=self._run_name,
+ )
@six.add_metaclass(abc.ABCMeta)
class _EventSink(object):
- """Abstract sink for batches of serialized tf.Event data."""
+ """Abstract sink for batches of serialized tf.Event data."""
- @abc.abstractmethod
- def write_batch(self, event_batch):
- """Writes the given event batch to the sink.
+ @abc.abstractmethod
+ def write_batch(self, event_batch):
+ """Writes the given event batch to the sink.
- Args:
- event_batch: an _EventBatch of event data.
- """
- raise NotImplementedError()
+ Args:
+ event_batch: an _EventBatch of event data.
+ """
+ raise NotImplementedError()
class _SqliteWriterEventSink(_EventSink):
- """Implementation of EventSink using SqliteWriter."""
-
- def __init__(self, db_connection_provider):
- """Constructs a SqliteWriterEventSink.
-
- Args:
- db_connection_provider: Provider function for creating a DB connection.
- """
- self._writer = sqlite_writer.SqliteWriter(db_connection_provider)
-
- def write_batch(self, event_batch):
- start = time.time()
- tagged_data = {}
- for event_proto in event_batch.events:
- event = event_pb2.Event.FromString(event_proto)
- self._process_event(event, tagged_data)
- if tagged_data:
- self._writer.write_summaries(
- tagged_data,
- experiment_name=event_batch.experiment_name,
- run_name=event_batch.run_name)
- elapsed = time.time() - start
- logger.debug(
- 'SqliteWriterEventSink.WriteBatch() took %0.3f sec for %s events',
- elapsed, len(event_batch.events))
-
- def _process_event(self, event, tagged_data):
- """Processes a single tf.Event and records it in tagged_data."""
- event_type = event.WhichOneof('what')
- # Handle the most common case first.
- if event_type == 'summary':
- for value in event.summary.value:
- value = data_compat.migrate_value(value)
- tag, metadata, values = tagged_data.get(value.tag, (None, None, []))
- values.append((event.step, event.wall_time, value.tensor))
- if tag is None:
- # Store metadata only from the first event.
- tagged_data[value.tag] = sqlite_writer.TagData(
- value.tag, value.metadata, values)
- elif event_type == 'file_version':
- pass # TODO: reject file version < 2 (at loader level)
- elif event_type == 'session_log':
- if event.session_log.status == event_pb2.SessionLog.START:
- pass # TODO: implement purging via sqlite writer truncation method
- elif event_type in ('graph_def', 'meta_graph_def'):
- pass # TODO: support graphs
- elif event_type == 'tagged_run_metadata':
- pass # TODO: support run metadata
+ """Implementation of EventSink using SqliteWriter."""
+
+ def __init__(self, db_connection_provider):
+ """Constructs a SqliteWriterEventSink.
+
+ Args:
+ db_connection_provider: Provider function for creating a DB connection.
+ """
+ self._writer = sqlite_writer.SqliteWriter(db_connection_provider)
+
+ def write_batch(self, event_batch):
+ start = time.time()
+ tagged_data = {}
+ for event_proto in event_batch.events:
+ event = event_pb2.Event.FromString(event_proto)
+ self._process_event(event, tagged_data)
+ if tagged_data:
+ self._writer.write_summaries(
+ tagged_data,
+ experiment_name=event_batch.experiment_name,
+ run_name=event_batch.run_name,
+ )
+ elapsed = time.time() - start
+ logger.debug(
+ "SqliteWriterEventSink.WriteBatch() took %0.3f sec for %s events",
+ elapsed,
+ len(event_batch.events),
+ )
+
+ def _process_event(self, event, tagged_data):
+ """Processes a single tf.Event and records it in tagged_data."""
+ event_type = event.WhichOneof("what")
+ # Handle the most common case first.
+ if event_type == "summary":
+ for value in event.summary.value:
+ value = data_compat.migrate_value(value)
+ tag, metadata, values = tagged_data.get(
+ value.tag, (None, None, [])
+ )
+ values.append((event.step, event.wall_time, value.tensor))
+ if tag is None:
+ # Store metadata only from the first event.
+ tagged_data[value.tag] = sqlite_writer.TagData(
+ value.tag, value.metadata, values
+ )
+ elif event_type == "file_version":
+ pass # TODO: reject file version < 2 (at loader level)
+ elif event_type == "session_log":
+ if event.session_log.status == event_pb2.SessionLog.START:
+ pass # TODO: implement purging via sqlite writer truncation method
+ elif event_type in ("graph_def", "meta_graph_def"):
+ pass # TODO: support graphs
+ elif event_type == "tagged_run_metadata":
+ pass # TODO: support run metadata
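
The threaded branch of `Reload()` above is a small fan-in pattern: each worker pops run loaders off a shared deque, pushes their batches onto a queue, and pushes a sentinel when the deque runs dry; the main thread drains the queue until it has counted one sentinel per worker. A generic, self-contained sketch of just that pattern, with toy work items standing in for event batches:

    import collections
    import queue
    import threading

    loaders = collections.deque(range(10))  # toy stand-ins for _RunLoader objects
    output_queue = queue.Queue()
    sentinel = object()
    num_threads = 4

    def producer():
        try:
            while True:
                try:
                    loader = loaders.popleft()  # deque pops are thread-safe
                except IndexError:
                    return
                for i in range(3):  # pretend each loader yields a few batches
                    output_queue.put("batch %d/%d" % (loader, i))
        finally:
            output_queue.put(sentinel)  # always announce this worker is done

    for i in range(num_threads):
        thread = threading.Thread(target=producer, name="Loader %d" % i)
        thread.daemon = True
        thread.start()

    num_live_threads = num_threads
    while num_live_threads > 0:
        output = output_queue.get()
        if output is sentinel:
            num_live_threads -= 1
            continue
        print("write_batch:", output)  # real code: self._event_sink.write_batch(output)

Using one sentinel per thread, rather than a shared "done" flag, is what lets the consumer exit only after every worker has drained its share of the deque.
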
diff --git a/tensorboard/backend/event_processing/db_import_multiplexer_test.py b/tensorboard/backend/event_processing/db_import_multiplexer_test.py
index 4a43dd3fb8..89ee99f2d3 100644
--- a/tensorboard/backend/event_processing/db_import_multiplexer_test.py
+++ b/tensorboard/backend/event_processing/db_import_multiplexer_test.py
@@ -31,134 +31,151 @@
def add_event(path):
- with test_util.FileWriterCache.get(path) as writer:
- event = event_pb2.Event()
- event.summary.value.add(tag='tag', tensor=tensor_util.make_tensor_proto(1))
- writer.add_event(event)
+ with test_util.FileWriterCache.get(path) as writer:
+ event = event_pb2.Event()
+ event.summary.value.add(
+ tag="tag", tensor=tensor_util.make_tensor_proto(1)
+ )
+ writer.add_event(event)
class DbImportMultiplexerTest(tf.test.TestCase):
-
- def setUp(self):
- super(DbImportMultiplexerTest, self).setUp()
-
- db_file_name = os.path.join(self.get_temp_dir(), 'db')
- self.db_connection_provider = lambda: sqlite3.connect(db_file_name)
- self.multiplexer = db_import_multiplexer.DbImportMultiplexer(
- db_uri='sqlite:' + db_file_name,
- db_connection_provider=self.db_connection_provider,
- purge_orphaned_data=False,
- max_reload_threads=1)
-
- def _get_runs(self):
- db = self.db_connection_provider()
- cursor = db.execute('''
+ def setUp(self):
+ super(DbImportMultiplexerTest, self).setUp()
+
+ db_file_name = os.path.join(self.get_temp_dir(), "db")
+ self.db_connection_provider = lambda: sqlite3.connect(db_file_name)
+ self.multiplexer = db_import_multiplexer.DbImportMultiplexer(
+ db_uri="sqlite:" + db_file_name,
+ db_connection_provider=self.db_connection_provider,
+ purge_orphaned_data=False,
+ max_reload_threads=1,
+ )
+
+ def _get_runs(self):
+ db = self.db_connection_provider()
+ cursor = db.execute(
+ """
SELECT
Runs.run_name
FROM Runs
ORDER BY Runs.run_name
- ''')
- return [row[0] for row in cursor]
-
- def _get_experiments(self):
- db = self.db_connection_provider()
- cursor = db.execute('''
+ """
+ )
+ return [row[0] for row in cursor]
+
+ def _get_experiments(self):
+ db = self.db_connection_provider()
+ cursor = db.execute(
+ """
SELECT
Experiments.experiment_name
FROM Experiments
ORDER BY Experiments.experiment_name
- ''')
- return [row[0] for row in cursor]
-
- def test_init(self):
- """Tests that DB schema is created when creating DbImportMultiplexer."""
- # Reading DB before schema initialization raises.
- self.assertEqual(self._get_experiments(), [])
- self.assertEqual(self._get_runs(), [])
-
- def test_empty_folder(self):
- fake_dir = os.path.join(self.get_temp_dir(), 'fake_dir')
- self.multiplexer.AddRunsFromDirectory(fake_dir)
- self.assertEqual(self._get_experiments(), [])
- self.assertEqual(self._get_runs(), [])
-
- def test_flat(self):
- path = self.get_temp_dir()
- add_event(path)
- self.multiplexer.AddRunsFromDirectory(path)
- self.multiplexer.Reload()
- # Because we added runs from `path`, there is no folder to infer experiment
- # and run names from.
- self.assertEqual(self._get_experiments(), [u'.'])
- self.assertEqual(self._get_runs(), [u'.'])
-
- def test_single_level(self):
- path = self.get_temp_dir()
- add_event(os.path.join(path, 'exp1'))
- add_event(os.path.join(path, 'exp2'))
- self.multiplexer.AddRunsFromDirectory(path)
- self.multiplexer.Reload()
- self.assertEqual(self._get_experiments(), [u'exp1', u'exp2'])
- # Run names are '.'. because we already used the directory name for
- # inferring experiment name. There are two items with the same name but
- # with different ids.
- self.assertEqual(self._get_runs(), [u'.', u'.'])
-
- def test_double_level(self):
- path = self.get_temp_dir()
- add_event(os.path.join(path, 'exp1', 'test'))
- add_event(os.path.join(path, 'exp1', 'train'))
- add_event(os.path.join(path, 'exp2', 'test'))
- self.multiplexer.AddRunsFromDirectory(path)
- self.multiplexer.Reload()
- self.assertEqual(self._get_experiments(), [u'exp1', u'exp2'])
- # There are two items with the same name but with different ids.
- self.assertEqual(self._get_runs(), [u'test', u'test', u'train'])
-
- def test_mixed_levels(self):
- # Mixture of root and single levels.
- path = self.get_temp_dir()
- # Train is in the root directory.
- add_event(os.path.join(path))
- add_event(os.path.join(path, 'eval'))
- self.multiplexer.AddRunsFromDirectory(path)
- self.multiplexer.Reload()
- self.assertEqual(self._get_experiments(), [u'.', u'eval'])
- self.assertEqual(self._get_runs(), [u'.', u'.'])
-
- def test_deep(self):
- path = self.get_temp_dir()
- add_event(os.path.join(path, 'exp1', 'run1', 'bar', 'train'))
- add_event(os.path.join(path, 'exp2', 'run1', 'baz', 'train'))
- self.multiplexer.AddRunsFromDirectory(path)
- self.multiplexer.Reload()
- self.assertEqual(self._get_experiments(), [u'exp1', u'exp2'])
- self.assertEqual(self._get_runs(), [os.path.join('run1', 'bar', 'train'),
- os.path.join('run1', 'baz', 'train')])
-
- def test_manual_name(self):
- path1 = os.path.join(self.get_temp_dir(), 'foo')
- path2 = os.path.join(self.get_temp_dir(), 'bar')
- add_event(os.path.join(path1, 'some', 'nested', 'name'))
- add_event(os.path.join(path2, 'some', 'nested', 'name'))
- self.multiplexer.AddRunsFromDirectory(path1, 'name1')
- self.multiplexer.AddRunsFromDirectory(path2, 'name2')
- self.multiplexer.Reload()
- self.assertEqual(self._get_experiments(), [u'name1', u'name2'])
- # Run name ignored 'foo' and 'bar' on 'foo/some/nested/name' and
- # 'bar/some/nested/name', respectively.
- # There are two items with the same name but with different ids.
- self.assertEqual(self._get_runs(), [os.path.join('some', 'nested', 'name'),
- os.path.join('some', 'nested', 'name')])
-
- def test_empty_read_apis(self):
- path = self.get_temp_dir()
- add_event(path)
- self.assertEmpty(self.multiplexer.Runs())
- self.multiplexer.AddRunsFromDirectory(path)
- self.multiplexer.Reload()
- self.assertEmpty(self.multiplexer.Runs())
-
-
-if __name__ == '__main__':
- tf.test.main()
+ """
+ )
+ return [row[0] for row in cursor]
+
+ def test_init(self):
+ """Tests that DB schema is created when creating
+ DbImportMultiplexer."""
+ # Reading DB before schema initialization raises.
+ self.assertEqual(self._get_experiments(), [])
+ self.assertEqual(self._get_runs(), [])
+
+ def test_empty_folder(self):
+ fake_dir = os.path.join(self.get_temp_dir(), "fake_dir")
+ self.multiplexer.AddRunsFromDirectory(fake_dir)
+ self.assertEqual(self._get_experiments(), [])
+ self.assertEqual(self._get_runs(), [])
+
+ def test_flat(self):
+ path = self.get_temp_dir()
+ add_event(path)
+ self.multiplexer.AddRunsFromDirectory(path)
+ self.multiplexer.Reload()
+ # Because we added runs from `path`, there is no folder to infer experiment
+ # and run names from.
+ self.assertEqual(self._get_experiments(), [u"."])
+ self.assertEqual(self._get_runs(), [u"."])
+
+ def test_single_level(self):
+ path = self.get_temp_dir()
+ add_event(os.path.join(path, "exp1"))
+ add_event(os.path.join(path, "exp2"))
+ self.multiplexer.AddRunsFromDirectory(path)
+ self.multiplexer.Reload()
+ self.assertEqual(self._get_experiments(), [u"exp1", u"exp2"])
+        # Run names are '.' because we already used the directory name for
+ # inferring experiment name. There are two items with the same name but
+ # with different ids.
+ self.assertEqual(self._get_runs(), [u".", u"."])
+
+ def test_double_level(self):
+ path = self.get_temp_dir()
+ add_event(os.path.join(path, "exp1", "test"))
+ add_event(os.path.join(path, "exp1", "train"))
+ add_event(os.path.join(path, "exp2", "test"))
+ self.multiplexer.AddRunsFromDirectory(path)
+ self.multiplexer.Reload()
+ self.assertEqual(self._get_experiments(), [u"exp1", u"exp2"])
+ # There are two items with the same name but with different ids.
+ self.assertEqual(self._get_runs(), [u"test", u"test", u"train"])
+
+ def test_mixed_levels(self):
+ # Mixture of root and single levels.
+ path = self.get_temp_dir()
+ # Train is in the root directory.
+ add_event(os.path.join(path))
+ add_event(os.path.join(path, "eval"))
+ self.multiplexer.AddRunsFromDirectory(path)
+ self.multiplexer.Reload()
+ self.assertEqual(self._get_experiments(), [u".", u"eval"])
+ self.assertEqual(self._get_runs(), [u".", u"."])
+
+ def test_deep(self):
+ path = self.get_temp_dir()
+ add_event(os.path.join(path, "exp1", "run1", "bar", "train"))
+ add_event(os.path.join(path, "exp2", "run1", "baz", "train"))
+ self.multiplexer.AddRunsFromDirectory(path)
+ self.multiplexer.Reload()
+ self.assertEqual(self._get_experiments(), [u"exp1", u"exp2"])
+ self.assertEqual(
+ self._get_runs(),
+ [
+ os.path.join("run1", "bar", "train"),
+ os.path.join("run1", "baz", "train"),
+ ],
+ )
+
+ def test_manual_name(self):
+ path1 = os.path.join(self.get_temp_dir(), "foo")
+ path2 = os.path.join(self.get_temp_dir(), "bar")
+ add_event(os.path.join(path1, "some", "nested", "name"))
+ add_event(os.path.join(path2, "some", "nested", "name"))
+ self.multiplexer.AddRunsFromDirectory(path1, "name1")
+ self.multiplexer.AddRunsFromDirectory(path2, "name2")
+ self.multiplexer.Reload()
+ self.assertEqual(self._get_experiments(), [u"name1", u"name2"])
+ # Run name ignored 'foo' and 'bar' on 'foo/some/nested/name' and
+ # 'bar/some/nested/name', respectively.
+ # There are two items with the same name but with different ids.
+ self.assertEqual(
+ self._get_runs(),
+ [
+ os.path.join("some", "nested", "name"),
+ os.path.join("some", "nested", "name"),
+ ],
+ )
+
+ def test_empty_read_apis(self):
+ path = self.get_temp_dir()
+ add_event(path)
+ self.assertEmpty(self.multiplexer.Runs())
+ self.multiplexer.AddRunsFromDirectory(path)
+ self.multiplexer.Reload()
+ self.assertEmpty(self.multiplexer.Runs())
+
+
+if __name__ == "__main__":
+ tf.test.main()
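
The tests above pin down how experiment and run names are inferred from the directory layout: relative to the directory passed to `AddRunsFromDirectory`, the first path component (or an explicitly supplied name) becomes the experiment and the remainder becomes the run, with "." standing in for an empty component. A small illustrative helper that assumes only that convention, not the multiplexer's actual code:

    import os


    def infer_names(root, event_dir, experiment_name=None):
        """Returns (experiment_name, run_name) for an event-file directory."""
        rel = os.path.relpath(event_dir, root)
        parts = [] if rel == "." else rel.split(os.sep)
        if experiment_name is None:
            experiment_name = parts[0] if parts else "."
            parts = parts[1:]
        run_name = os.path.join(*parts) if parts else "."
        return experiment_name, run_name


    root = os.path.join(os.sep, "logs")
    assert infer_names(root, root) == (".", ".")
    assert infer_names(root, os.path.join(root, "exp1")) == ("exp1", ".")
    assert infer_names(root, os.path.join(root, "exp1", "train")) == (
        "exp1",
        "train",
    )
    assert infer_names(root, os.path.join(root, "a", "b"), "name1") == (
        "name1",
        os.path.join("a", "b"),
    )
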
diff --git a/tensorboard/backend/event_processing/directory_loader.py b/tensorboard/backend/event_processing/directory_loader.py
index 4182c3f38d..27979842e6 100644
--- a/tensorboard/backend/event_processing/directory_loader.py
+++ b/tensorboard/backend/event_processing/directory_loader.py
@@ -33,104 +33,115 @@
class DirectoryLoader(object):
- """Loader for an entire directory, maintaining multiple active file loaders.
-
- This class takes a directory, a factory for loaders, and optionally a
- path filter and watches all the paths inside that directory for new data.
- Each file loader created by the factory must read a path and produce an
- iterator of (timestamp, value) pairs.
-
- Unlike DirectoryWatcher, this class does not assume that only one file
- receives new data at a time; there can be arbitrarily many active files.
- However, any file whose maximum load timestamp fails an "active" predicate
- will be marked as inactive and no longer checked for new data.
- """
-
- def __init__(self, directory, loader_factory, path_filter=lambda x: True,
- active_filter=lambda timestamp: True):
- """Constructs a new MultiFileDirectoryLoader.
-
- Args:
- directory: The directory to load files from.
- loader_factory: A factory for creating loaders. The factory should take a
- path and return an object that has a Load method returning an iterator
- yielding (unix timestamp as float, value) pairs for any new data
- path_filter: If specified, only paths matching this filter are loaded.
- active_filter: If specified, any loader whose maximum load timestamp does
- not pass this filter will be marked as inactive and no longer read.
-
- Raises:
- ValueError: If directory or loader_factory are None.
+ """Loader for an entire directory, maintaining multiple active file
+ loaders.
+
+ This class takes a directory, a factory for loaders, and optionally a
+ path filter and watches all the paths inside that directory for new data.
+ Each file loader created by the factory must read a path and produce an
+ iterator of (timestamp, value) pairs.
+
+ Unlike DirectoryWatcher, this class does not assume that only one file
+ receives new data at a time; there can be arbitrarily many active files.
+ However, any file whose maximum load timestamp fails an "active" predicate
+ will be marked as inactive and no longer checked for new data.
"""
- if directory is None:
- raise ValueError('A directory is required')
- if loader_factory is None:
- raise ValueError('A loader factory is required')
- self._directory = directory
- self._loader_factory = loader_factory
- self._path_filter = path_filter
- self._active_filter = active_filter
- self._loaders = {}
- self._max_timestamps = {}
-
- def Load(self):
- """Loads new values from all active files.
-
- Yields:
- All values that have not been yielded yet.
-
- Raises:
- DirectoryDeletedError: If the directory has been permanently deleted
- (as opposed to being temporarily unavailable).
- """
- try:
- all_paths = io_wrapper.ListDirectoryAbsolute(self._directory)
- paths = sorted(p for p in all_paths if self._path_filter(p))
- for path in paths:
- for value in self._LoadPath(path):
- yield value
- except tf.errors.OpError as e:
- if not tf.io.gfile.exists(self._directory):
- raise directory_watcher.DirectoryDeletedError(
- 'Directory %s has been permanently deleted' % self._directory)
- else:
- logger.info('Ignoring error during file loading: %s' % e)
-
- def _LoadPath(self, path):
- """Generator for values from a single path's loader.
-
- Args:
- path: the path to load from
-
- Yields:
- All values from this path's loader that have not been yielded yet.
- """
- max_timestamp = self._max_timestamps.get(path, None)
- if max_timestamp is _INACTIVE or self._MarkIfInactive(path, max_timestamp):
- logger.debug('Skipping inactive path %s', path)
- return
- loader = self._loaders.get(path, None)
- if loader is None:
- try:
- loader = self._loader_factory(path)
- except tf.errors.NotFoundError:
- # Happens if a file was removed after we listed the directory.
- logger.debug('Skipping nonexistent path %s', path)
- return
- self._loaders[path] = loader
- logger.info('Loading data from path %s', path)
- for timestamp, value in loader.Load():
- if max_timestamp is None or timestamp > max_timestamp:
- max_timestamp = timestamp
- yield value
- if not self._MarkIfInactive(path, max_timestamp):
- self._max_timestamps[path] = max_timestamp
-
- def _MarkIfInactive(self, path, max_timestamp):
- """If max_timestamp is inactive, returns True and marks the path as such."""
- logger.debug('Checking active status of %s at %s', path, max_timestamp)
- if max_timestamp is not None and not self._active_filter(max_timestamp):
- self._max_timestamps[path] = _INACTIVE
- del self._loaders[path]
- return True
- return False
+
+ def __init__(
+ self,
+ directory,
+ loader_factory,
+ path_filter=lambda x: True,
+ active_filter=lambda timestamp: True,
+ ):
+        """Constructs a new DirectoryLoader.
+
+ Args:
+ directory: The directory to load files from.
+ loader_factory: A factory for creating loaders. The factory should take a
+ path and return an object that has a Load method returning an iterator
+ yielding (unix timestamp as float, value) pairs for any new data
+ path_filter: If specified, only paths matching this filter are loaded.
+ active_filter: If specified, any loader whose maximum load timestamp does
+ not pass this filter will be marked as inactive and no longer read.
+
+ Raises:
+ ValueError: If directory or loader_factory are None.
+ """
+ if directory is None:
+ raise ValueError("A directory is required")
+ if loader_factory is None:
+ raise ValueError("A loader factory is required")
+ self._directory = directory
+ self._loader_factory = loader_factory
+ self._path_filter = path_filter
+ self._active_filter = active_filter
+ self._loaders = {}
+ self._max_timestamps = {}
+
+ def Load(self):
+ """Loads new values from all active files.
+
+ Yields:
+ All values that have not been yielded yet.
+
+ Raises:
+ DirectoryDeletedError: If the directory has been permanently deleted
+ (as opposed to being temporarily unavailable).
+ """
+ try:
+ all_paths = io_wrapper.ListDirectoryAbsolute(self._directory)
+ paths = sorted(p for p in all_paths if self._path_filter(p))
+ for path in paths:
+ for value in self._LoadPath(path):
+ yield value
+ except tf.errors.OpError as e:
+ if not tf.io.gfile.exists(self._directory):
+ raise directory_watcher.DirectoryDeletedError(
+ "Directory %s has been permanently deleted"
+ % self._directory
+ )
+ else:
+ logger.info("Ignoring error during file loading: %s" % e)
+
+ def _LoadPath(self, path):
+ """Generator for values from a single path's loader.
+
+ Args:
+ path: the path to load from
+
+ Yields:
+ All values from this path's loader that have not been yielded yet.
+ """
+ max_timestamp = self._max_timestamps.get(path, None)
+ if max_timestamp is _INACTIVE or self._MarkIfInactive(
+ path, max_timestamp
+ ):
+ logger.debug("Skipping inactive path %s", path)
+ return
+ loader = self._loaders.get(path, None)
+ if loader is None:
+ try:
+ loader = self._loader_factory(path)
+ except tf.errors.NotFoundError:
+ # Happens if a file was removed after we listed the directory.
+ logger.debug("Skipping nonexistent path %s", path)
+ return
+ self._loaders[path] = loader
+ logger.info("Loading data from path %s", path)
+ for timestamp, value in loader.Load():
+ if max_timestamp is None or timestamp > max_timestamp:
+ max_timestamp = timestamp
+ yield value
+ if not self._MarkIfInactive(path, max_timestamp):
+ self._max_timestamps[path] = max_timestamp
+
+ def _MarkIfInactive(self, path, max_timestamp):
+ """If max_timestamp is inactive, returns True and marks the path as
+ such."""
+ logger.debug("Checking active status of %s at %s", path, max_timestamp)
+ if max_timestamp is not None and not self._active_filter(max_timestamp):
+ self._max_timestamps[path] = _INACTIVE
+ del self._loaders[path]
+ return True
+ return False
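
A minimal usage sketch for `DirectoryLoader` as documented above: watch a logdir for tfevents files and stop polling any file whose newest event timestamp falls behind a cutoff. The module paths are assumed from the file locations in this diff; the logdir path and the 10-minute cutoff are illustrative.

    import time

    from tensorboard.backend.event_processing import directory_loader
    from tensorboard.backend.event_processing import event_file_loader

    CUTOFF_SECS = 10 * 60  # Illustrative staleness cutoff.

    loader = directory_loader.DirectoryLoader(
        "/tmp/logdir",  # Illustrative path.
        event_file_loader.TimestampedEventFileLoader,
        path_filter=lambda path: "tfevents" in path,
        active_filter=lambda timestamp: timestamp + CUTOFF_SECS >= time.time(),
    )

    # Each Load() pass yields only events that have not been yielded before.
    for event in loader.Load():
        for value in event.summary.value:
            print(event.step, value.tag)
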
diff --git a/tensorboard/backend/event_processing/directory_loader_test.py b/tensorboard/backend/event_processing/directory_loader_test.py
index deaa1030fa..efaab24227 100644
--- a/tensorboard/backend/event_processing/directory_loader_test.py
+++ b/tensorboard/backend/event_processing/directory_loader_test.py
@@ -25,10 +25,10 @@
import shutil
try:
- # python version >= 3.3
- from unittest import mock
+ # python version >= 3.3
+ from unittest import mock
except ImportError:
- import mock # pylint: disable=unused-import
+ import mock # pylint: disable=unused-import
import tensorflow as tf
@@ -40,214 +40,239 @@
class _TimestampedByteLoader(object):
- """A loader that loads timestamped bytes from a file."""
+ """A loader that loads timestamped bytes from a file."""
- def __init__(self, path, registry=None):
- self._path = path
- self._registry = registry if registry is not None else []
- self._registry.append(path)
- self._f = open(path)
+ def __init__(self, path, registry=None):
+ self._path = path
+ self._registry = registry if registry is not None else []
+ self._registry.append(path)
+ self._f = open(path)
- def __del__(self):
- self._registry.remove(self._path)
+ def __del__(self):
+ self._registry.remove(self._path)
- def Load(self):
- while True:
- line = self._f.readline()
- if not line:
- return
- ts, value = line.rstrip('\n').split(':')
- yield float(ts), value
+ def Load(self):
+ while True:
+ line = self._f.readline()
+ if not line:
+ return
+ ts, value = line.rstrip("\n").split(":")
+ yield float(ts), value
class DirectoryLoaderTest(tf.test.TestCase):
+ def setUp(self):
+ # Put everything in a directory so it's easier to delete w/in tests.
+ self._directory = os.path.join(self.get_temp_dir(), "testdir")
+ os.mkdir(self._directory)
+ self._loader = directory_loader.DirectoryLoader(
+ self._directory, _TimestampedByteLoader
+ )
- def setUp(self):
- # Put everything in a directory so it's easier to delete w/in tests.
- self._directory = os.path.join(self.get_temp_dir(), 'testdir')
- os.mkdir(self._directory)
- self._loader = directory_loader.DirectoryLoader(
- self._directory, _TimestampedByteLoader)
-
- def _WriteToFile(self, filename, data, timestamps=None):
- if timestamps is None:
- timestamps = range(len(data))
- self.assertEqual(len(data), len(timestamps))
- path = os.path.join(self._directory, filename)
- with open(path, 'a') as f:
- for byte, timestamp in zip(data, timestamps):
- f.write('%f:%s\n' % (timestamp, byte))
-
- def assertLoaderYields(self, values):
- self.assertEqual(list(self._loader.Load()), values)
-
- def testRaisesWithBadArguments(self):
- with self.assertRaises(ValueError):
- directory_loader.DirectoryLoader(None, lambda x: None)
- with self.assertRaises(ValueError):
- directory_loader.DirectoryLoader('dir', None)
-
- def testEmptyDirectory(self):
- self.assertLoaderYields([])
-
- def testSingleFileLoading(self):
- self._WriteToFile('a', 'abc')
- self.assertLoaderYields(['a', 'b', 'c'])
- self.assertLoaderYields([])
- self._WriteToFile('a', 'xyz')
- self.assertLoaderYields(['x', 'y', 'z'])
- self.assertLoaderYields([])
-
- def testMultipleFileLoading(self):
- self._WriteToFile('a', 'a')
- self._WriteToFile('b', 'b')
- self.assertLoaderYields(['a', 'b'])
- self.assertLoaderYields([])
- self._WriteToFile('a', 'A')
- self._WriteToFile('b', 'B')
- self._WriteToFile('c', 'c')
- # The loader should read new data from all the files.
- self.assertLoaderYields(['A', 'B', 'c'])
- self.assertLoaderYields([])
-
- def testMultipleFileLoading_intermediateEmptyFiles(self):
- self._WriteToFile('a', 'a')
- self._WriteToFile('b', '')
- self._WriteToFile('c', 'c')
- self.assertLoaderYields(['a', 'c'])
-
- def testPathFilter(self):
- self._loader = directory_loader.DirectoryLoader(
- self._directory, _TimestampedByteLoader,
- lambda path: 'tfevents' in path)
- self._WriteToFile('skipped', 'a')
- self._WriteToFile('event.out.tfevents.foo.bar', 'b')
- self._WriteToFile('tf.event', 'c')
- self.assertLoaderYields(['b'])
-
- def testActiveFilter_staticFilterBehavior(self):
- """Tests behavior of a static active_filter."""
- loader_registry = []
- loader_factory = functools.partial(
- _TimestampedByteLoader, registry=loader_registry)
- active_filter = lambda timestamp: timestamp >= 2
- self._loader = directory_loader.DirectoryLoader(
- self._directory, loader_factory, active_filter=active_filter)
- def assertLoadersForPaths(paths):
- paths = [os.path.join(self._directory, path) for path in paths]
- self.assertEqual(loader_registry, paths)
- # a: normal-looking file.
- # b: file without sufficiently active data (should be marked inactive).
- # c: file with timestamps in reverse order (max computed correctly).
- # d: empty file (should be considered active in absence of timestamps).
- self._WriteToFile('a', ['A1', 'A2'], [1, 2])
- self._WriteToFile('b', ['B1'], [1])
- self._WriteToFile('c', ['C2', 'C1', 'C0'], [2, 1, 0])
- self._WriteToFile('d', [], [])
- self.assertLoaderYields(['A1', 'A2', 'B1', 'C2', 'C1', 'C0'])
- assertLoadersForPaths(['a', 'c', 'd'])
- self._WriteToFile('a', ['A3'], [3])
- self._WriteToFile('b', ['B3'], [3])
- self._WriteToFile('c', ['C0'], [0])
- self._WriteToFile('d', ['D3'], [3])
- self.assertLoaderYields(['A3', 'C0', 'D3'])
- assertLoadersForPaths(['a', 'c', 'd'])
- # Check that a 0 timestamp in file C on the most recent load doesn't
- # override the max timestamp of 2 seen in the earlier load.
- self._WriteToFile('c', ['C4'], [4])
- self.assertLoaderYields(['C4'])
- assertLoadersForPaths(['a', 'c', 'd'])
-
- def testActiveFilter_dynamicFilterBehavior(self):
- """Tests behavior of a dynamic active_filter."""
- loader_registry = []
- loader_factory = functools.partial(
- _TimestampedByteLoader, registry=loader_registry)
- threshold = 0
- active_filter = lambda timestamp: timestamp >= threshold
- self._loader = directory_loader.DirectoryLoader(
- self._directory, loader_factory, active_filter=active_filter)
- def assertLoadersForPaths(paths):
- paths = [os.path.join(self._directory, path) for path in paths]
- self.assertEqual(loader_registry, paths)
- self._WriteToFile('a', ['A1', 'A2'], [1, 2])
- self._WriteToFile('b', ['B1', 'B2', 'B3'], [1, 2, 3])
- self._WriteToFile('c', ['C1'], [1])
- threshold = 2
- # First load pass should leave file C marked inactive.
- self.assertLoaderYields(['A1', 'A2', 'B1', 'B2', 'B3', 'C1'])
- assertLoadersForPaths(['a', 'b'])
- self._WriteToFile('a', ['A4'], [4])
- self._WriteToFile('b', ['B4'], [4])
- self._WriteToFile('c', ['C4'], [4])
- threshold = 3
- # Second load pass should mark file A as inactive (due to newly
- # increased threshold) and thus skip reading data from it.
- self.assertLoaderYields(['B4'])
- assertLoadersForPaths(['b'])
- self._WriteToFile('b', ['B5', 'B6'], [5, 6])
- # Simulate a third pass in which the threshold increases while
- # we're processing a file, so it's still active at the start of the
- # load but should be marked inactive at the end.
- load_generator = self._loader.Load()
- self.assertEqual('B5', next(load_generator))
- threshold = 7
- self.assertEqual(['B6'], list(load_generator))
- assertLoadersForPaths([])
- # Confirm that all loaders are now inactive.
- self._WriteToFile('b', ['B7'], [7])
- self.assertLoaderYields([])
-
- def testDoesntCrashWhenCurrentFileIsDeleted(self):
- # Use actual file loader so it emits the real error.
- self._loader = directory_loader.DirectoryLoader(
- self._directory, event_file_loader.TimestampedEventFileLoader)
- with test_util.FileWriter(self._directory, filename_suffix='.a') as writer_a:
- writer_a.add_test_summary('a')
- events = list(self._loader.Load())
- events.pop(0) # Ignore the file_version event.
- self.assertEqual(1, len(events))
- self.assertEqual('a', events[0].summary.value[0].tag)
- os.remove(glob.glob(os.path.join(self._directory, '*.a'))[0])
- with test_util.FileWriter(self._directory, filename_suffix='.b') as writer_b:
- writer_b.add_test_summary('b')
- events = list(self._loader.Load())
- events.pop(0) # Ignore the file_version event.
- self.assertEqual(1, len(events))
- self.assertEqual('b', events[0].summary.value[0].tag)
-
- def testDoesntCrashWhenUpcomingFileIsDeleted(self):
- # Use actual file loader so it emits the real error.
- self._loader = directory_loader.DirectoryLoader(
- self._directory, event_file_loader.TimestampedEventFileLoader)
- with test_util.FileWriter(self._directory, filename_suffix='.a') as writer_a:
- writer_a.add_test_summary('a')
- with test_util.FileWriter(self._directory, filename_suffix='.b') as writer_b:
- writer_b.add_test_summary('b')
- generator = self._loader.Load()
- next(generator) # Ignore the file_version event.
- event = next(generator)
- self.assertEqual('a', event.summary.value[0].tag)
- os.remove(glob.glob(os.path.join(self._directory, '*.b'))[0])
- self.assertEmpty(list(generator))
-
- def testRaisesDirectoryDeletedError_whenDirectoryIsDeleted(self):
- self._WriteToFile('a', 'a')
- self.assertLoaderYields(['a'])
- shutil.rmtree(self._directory)
- with self.assertRaises(directory_watcher.DirectoryDeletedError):
- next(self._loader.Load())
-
- def testDoesntRaiseDirectoryDeletedError_forUnrecognizedException(self):
- self._WriteToFile('a', 'a')
- self.assertLoaderYields(['a'])
- class MyException(Exception):
- pass
- with mock.patch.object(io_wrapper, 'ListDirectoryAbsolute') as mock_listdir:
- mock_listdir.side_effect = MyException
- with self.assertRaises(MyException):
- next(self._loader.Load())
- self.assertLoaderYields([])
-
-if __name__ == '__main__':
- tf.test.main()
+ def _WriteToFile(self, filename, data, timestamps=None):
+ if timestamps is None:
+ timestamps = range(len(data))
+ self.assertEqual(len(data), len(timestamps))
+ path = os.path.join(self._directory, filename)
+ with open(path, "a") as f:
+ for byte, timestamp in zip(data, timestamps):
+ f.write("%f:%s\n" % (timestamp, byte))
+
+ def assertLoaderYields(self, values):
+ self.assertEqual(list(self._loader.Load()), values)
+
+ def testRaisesWithBadArguments(self):
+ with self.assertRaises(ValueError):
+ directory_loader.DirectoryLoader(None, lambda x: None)
+ with self.assertRaises(ValueError):
+ directory_loader.DirectoryLoader("dir", None)
+
+ def testEmptyDirectory(self):
+ self.assertLoaderYields([])
+
+ def testSingleFileLoading(self):
+ self._WriteToFile("a", "abc")
+ self.assertLoaderYields(["a", "b", "c"])
+ self.assertLoaderYields([])
+ self._WriteToFile("a", "xyz")
+ self.assertLoaderYields(["x", "y", "z"])
+ self.assertLoaderYields([])
+
+ def testMultipleFileLoading(self):
+ self._WriteToFile("a", "a")
+ self._WriteToFile("b", "b")
+ self.assertLoaderYields(["a", "b"])
+ self.assertLoaderYields([])
+ self._WriteToFile("a", "A")
+ self._WriteToFile("b", "B")
+ self._WriteToFile("c", "c")
+ # The loader should read new data from all the files.
+ self.assertLoaderYields(["A", "B", "c"])
+ self.assertLoaderYields([])
+
+ def testMultipleFileLoading_intermediateEmptyFiles(self):
+ self._WriteToFile("a", "a")
+ self._WriteToFile("b", "")
+ self._WriteToFile("c", "c")
+ self.assertLoaderYields(["a", "c"])
+
+ def testPathFilter(self):
+ self._loader = directory_loader.DirectoryLoader(
+ self._directory,
+ _TimestampedByteLoader,
+ lambda path: "tfevents" in path,
+ )
+ self._WriteToFile("skipped", "a")
+ self._WriteToFile("event.out.tfevents.foo.bar", "b")
+ self._WriteToFile("tf.event", "c")
+ self.assertLoaderYields(["b"])
+
+ def testActiveFilter_staticFilterBehavior(self):
+ """Tests behavior of a static active_filter."""
+ loader_registry = []
+ loader_factory = functools.partial(
+ _TimestampedByteLoader, registry=loader_registry
+ )
+ active_filter = lambda timestamp: timestamp >= 2
+ self._loader = directory_loader.DirectoryLoader(
+ self._directory, loader_factory, active_filter=active_filter
+ )
+
+ def assertLoadersForPaths(paths):
+ paths = [os.path.join(self._directory, path) for path in paths]
+ self.assertEqual(loader_registry, paths)
+
+ # a: normal-looking file.
+ # b: file without sufficiently active data (should be marked inactive).
+ # c: file with timestamps in reverse order (max computed correctly).
+ # d: empty file (should be considered active in absence of timestamps).
+ self._WriteToFile("a", ["A1", "A2"], [1, 2])
+ self._WriteToFile("b", ["B1"], [1])
+ self._WriteToFile("c", ["C2", "C1", "C0"], [2, 1, 0])
+ self._WriteToFile("d", [], [])
+ self.assertLoaderYields(["A1", "A2", "B1", "C2", "C1", "C0"])
+ assertLoadersForPaths(["a", "c", "d"])
+ self._WriteToFile("a", ["A3"], [3])
+ self._WriteToFile("b", ["B3"], [3])
+ self._WriteToFile("c", ["C0"], [0])
+ self._WriteToFile("d", ["D3"], [3])
+ self.assertLoaderYields(["A3", "C0", "D3"])
+ assertLoadersForPaths(["a", "c", "d"])
+ # Check that a 0 timestamp in file C on the most recent load doesn't
+ # override the max timestamp of 2 seen in the earlier load.
+ self._WriteToFile("c", ["C4"], [4])
+ self.assertLoaderYields(["C4"])
+ assertLoadersForPaths(["a", "c", "d"])
+
+ def testActiveFilter_dynamicFilterBehavior(self):
+ """Tests behavior of a dynamic active_filter."""
+ loader_registry = []
+ loader_factory = functools.partial(
+ _TimestampedByteLoader, registry=loader_registry
+ )
+ threshold = 0
+ active_filter = lambda timestamp: timestamp >= threshold
+ self._loader = directory_loader.DirectoryLoader(
+ self._directory, loader_factory, active_filter=active_filter
+ )
+
+ def assertLoadersForPaths(paths):
+ paths = [os.path.join(self._directory, path) for path in paths]
+ self.assertEqual(loader_registry, paths)
+
+ self._WriteToFile("a", ["A1", "A2"], [1, 2])
+ self._WriteToFile("b", ["B1", "B2", "B3"], [1, 2, 3])
+ self._WriteToFile("c", ["C1"], [1])
+ threshold = 2
+ # First load pass should leave file C marked inactive.
+ self.assertLoaderYields(["A1", "A2", "B1", "B2", "B3", "C1"])
+ assertLoadersForPaths(["a", "b"])
+ self._WriteToFile("a", ["A4"], [4])
+ self._WriteToFile("b", ["B4"], [4])
+ self._WriteToFile("c", ["C4"], [4])
+ threshold = 3
+ # Second load pass should mark file A as inactive (due to newly
+ # increased threshold) and thus skip reading data from it.
+ self.assertLoaderYields(["B4"])
+ assertLoadersForPaths(["b"])
+ self._WriteToFile("b", ["B5", "B6"], [5, 6])
+ # Simulate a third pass in which the threshold increases while
+ # we're processing a file, so it's still active at the start of the
+ # load but should be marked inactive at the end.
+ load_generator = self._loader.Load()
+ self.assertEqual("B5", next(load_generator))
+ threshold = 7
+ self.assertEqual(["B6"], list(load_generator))
+ assertLoadersForPaths([])
+ # Confirm that all loaders are now inactive.
+ self._WriteToFile("b", ["B7"], [7])
+ self.assertLoaderYields([])
+
+ def testDoesntCrashWhenCurrentFileIsDeleted(self):
+ # Use actual file loader so it emits the real error.
+ self._loader = directory_loader.DirectoryLoader(
+ self._directory, event_file_loader.TimestampedEventFileLoader
+ )
+ with test_util.FileWriter(
+ self._directory, filename_suffix=".a"
+ ) as writer_a:
+ writer_a.add_test_summary("a")
+ events = list(self._loader.Load())
+ events.pop(0) # Ignore the file_version event.
+ self.assertEqual(1, len(events))
+ self.assertEqual("a", events[0].summary.value[0].tag)
+ os.remove(glob.glob(os.path.join(self._directory, "*.a"))[0])
+ with test_util.FileWriter(
+ self._directory, filename_suffix=".b"
+ ) as writer_b:
+ writer_b.add_test_summary("b")
+ events = list(self._loader.Load())
+ events.pop(0) # Ignore the file_version event.
+ self.assertEqual(1, len(events))
+ self.assertEqual("b", events[0].summary.value[0].tag)
+
+ def testDoesntCrashWhenUpcomingFileIsDeleted(self):
+ # Use actual file loader so it emits the real error.
+ self._loader = directory_loader.DirectoryLoader(
+ self._directory, event_file_loader.TimestampedEventFileLoader
+ )
+ with test_util.FileWriter(
+ self._directory, filename_suffix=".a"
+ ) as writer_a:
+ writer_a.add_test_summary("a")
+ with test_util.FileWriter(
+ self._directory, filename_suffix=".b"
+ ) as writer_b:
+ writer_b.add_test_summary("b")
+ generator = self._loader.Load()
+ next(generator) # Ignore the file_version event.
+ event = next(generator)
+ self.assertEqual("a", event.summary.value[0].tag)
+ os.remove(glob.glob(os.path.join(self._directory, "*.b"))[0])
+ self.assertEmpty(list(generator))
+
+ def testRaisesDirectoryDeletedError_whenDirectoryIsDeleted(self):
+ self._WriteToFile("a", "a")
+ self.assertLoaderYields(["a"])
+ shutil.rmtree(self._directory)
+ with self.assertRaises(directory_watcher.DirectoryDeletedError):
+ next(self._loader.Load())
+
+ def testDoesntRaiseDirectoryDeletedError_forUnrecognizedException(self):
+ self._WriteToFile("a", "a")
+ self.assertLoaderYields(["a"])
+
+ class MyException(Exception):
+ pass
+
+ with mock.patch.object(
+ io_wrapper, "ListDirectoryAbsolute"
+ ) as mock_listdir:
+ mock_listdir.side_effect = MyException
+ with self.assertRaises(MyException):
+ next(self._loader.Load())
+ self.assertLoaderYields([])
+
+
+if __name__ == "__main__":
+ tf.test.main()
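
A standalone sketch of the per-path bookkeeping that the static and dynamic active_filter tests above exercise: a path's maximum seen timestamp is remembered across loads, and once that maximum fails the filter the path is skipped on later passes. Plain dicts stand in for DirectoryLoader's internal state, and the names here are illustrative.

    _INACTIVE = object()  # Sentinel marking a path that should not be re-read.


    def load_path(path, new_records, max_timestamps, active_filter):
        """Yields values from new_records, marking `path` inactive when stale."""
        max_ts = max_timestamps.get(path)
        if max_ts is _INACTIVE:
            return  # Already marked inactive on an earlier pass.
        for timestamp, value in new_records:
            if max_ts is None or timestamp > max_ts:
                max_ts = timestamp
            yield value
        if max_ts is not None and not active_filter(max_ts):
            max_timestamps[path] = _INACTIVE
        else:
            max_timestamps[path] = max_ts


    state = {}
    active_filter = lambda timestamp: timestamp >= 2
    assert list(load_path("b", [(1, "B1")], state, active_filter)) == ["B1"]
    assert state["b"] is _INACTIVE  # Max timestamp 1 fails the filter.
    assert list(load_path("b", [(3, "B3")], state, active_filter)) == []
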
diff --git a/tensorboard/backend/event_processing/directory_watcher.py b/tensorboard/backend/event_processing/directory_watcher.py
index 6ca7df19bd..05dca4603b 100644
--- a/tensorboard/backend/event_processing/directory_watcher.py
+++ b/tensorboard/backend/event_processing/directory_watcher.py
@@ -27,231 +27,252 @@
logger = tb_logging.get_logger()
-class DirectoryWatcher(object):
- """A DirectoryWatcher wraps a loader to load from a sequence of paths.
-
- A loader reads a path and produces some kind of values as an iterator. A
- DirectoryWatcher takes a directory, a factory for loaders, and optionally a
- path filter and watches all the paths inside that directory.
-
- This class is only valid under the assumption that only one path will be
- written to by the data source at a time and that once the source stops writing
- to a path, it will start writing to a new path that's lexicographically
- greater and never come back. It uses some heuristics to check whether this is
- true based on tracking changes to the files' sizes, but the check can have
- false negatives. However, it should have no false positives.
- """
-
- def __init__(self, directory, loader_factory, path_filter=lambda x: True):
- """Constructs a new DirectoryWatcher.
-
- Args:
- directory: The directory to load files from.
- loader_factory: A factory for creating loaders. The factory should take a
- path and return an object that has a Load method returning an
- iterator that will yield all events that have not been yielded yet.
- path_filter: If specified, only paths matching this filter are loaded.
-
- Raises:
- ValueError: If path_provider or loader_factory are None.
- """
- if directory is None:
- raise ValueError('A directory is required')
- if loader_factory is None:
- raise ValueError('A loader factory is required')
- self._directory = directory
- self._path = None
- self._loader_factory = loader_factory
- self._loader = None
- self._path_filter = path_filter
- self._ooo_writes_detected = False
- # The file size for each file at the time it was finalized.
- self._finalized_sizes = {}
-
- def Load(self):
- """Loads new values.
-
- The watcher will load from one path at a time; as soon as that path stops
- yielding events, it will move on to the next path. We assume that old paths
- are never modified after a newer path has been written. As a result, Load()
- can be called multiple times in a row without losing events that have not
- been yielded yet. In other words, we guarantee that every event will be
- yielded exactly once.
-
- Yields:
- All values that have not been yielded yet.
-
- Raises:
- DirectoryDeletedError: If the directory has been permanently deleted
- (as opposed to being temporarily unavailable).
- """
- try:
- for event in self._LoadInternal():
- yield event
- except tf.errors.OpError:
- if not tf.io.gfile.exists(self._directory):
- raise DirectoryDeletedError(
- 'Directory %s has been permanently deleted' % self._directory)
-
- def _LoadInternal(self):
- """Internal implementation of Load().
-
- The only difference between this and Load() is that the latter will throw
- DirectoryDeletedError on I/O errors if it thinks that the directory has been
- permanently deleted.
-
- Yields:
- All values that have not been yielded yet.
- """
-
- # If the loader exists, check it for a value.
- if not self._loader:
- self._InitializeLoader()
-
- # If it still doesn't exist, there is no data
- if not self._loader:
- return
-
- while True:
- # Yield all the new events in the path we're currently loading from.
- for event in self._loader.Load():
- yield event
-
- next_path = self._GetNextPath()
- if not next_path:
- logger.info('No path found after %s', self._path)
- # Current path is empty and there are no new paths, so we're done.
- return
-
- # There's a new path, so check to make sure there weren't any events
- # written between when we finished reading the current path and when we
- # checked for the new one. The sequence of events might look something
- # like this:
- #
- # 1. Event #1 written to path #1.
- # 2. We check for events and yield event #1 from path #1
- # 3. We check for events and see that there are no more events in path #1.
- # 4. Event #2 is written to path #1.
- # 5. Event #3 is written to path #2.
- # 6. We check for a new path and see that path #2 exists.
- #
- # Without this loop, we would miss event #2. We're also guaranteed by the
- # loader contract that no more events will be written to path #1 after
- # events start being written to path #2, so we don't have to worry about
- # that.
- for event in self._loader.Load():
- yield event
-
- logger.info('Directory watcher advancing from %s to %s', self._path,
- next_path)
-
- # Advance to the next path and start over.
- self._SetPath(next_path)
-
- # The number of paths before the current one to check for out of order writes.
- _OOO_WRITE_CHECK_COUNT = 20
-
- def OutOfOrderWritesDetected(self):
- """Returns whether any out-of-order writes have been detected.
-
- Out-of-order writes are only checked as part of the Load() iterator. Once an
- out-of-order write is detected, this function will always return true.
-
- Note that out-of-order write detection is not performed on GCS paths, so
- this function will always return false.
-
- Returns:
- Whether any out-of-order write has ever been detected by this watcher.
+class DirectoryWatcher(object):
+ """A DirectoryWatcher wraps a loader to load from a sequence of paths.
+
+ A loader reads a path and produces some kind of values as an iterator. A
+ DirectoryWatcher takes a directory, a factory for loaders, and optionally a
+ path filter and watches all the paths inside that directory.
+
+ This class is only valid under the assumption that only one path will be
+ written to by the data source at a time and that once the source stops writing
+ to a path, it will start writing to a new path that's lexicographically
+ greater and never come back. It uses some heuristics to check whether this is
+ true based on tracking changes to the files' sizes, but the check can have
+ false negatives. However, it should have no false positives.
"""
- return self._ooo_writes_detected
- def _InitializeLoader(self):
- path = self._GetNextPath()
- if path:
- self._SetPath(path)
+ def __init__(self, directory, loader_factory, path_filter=lambda x: True):
+ """Constructs a new DirectoryWatcher.
+
+ Args:
+ directory: The directory to load files from.
+ loader_factory: A factory for creating loaders. The factory should take a
+ path and return an object that has a Load method returning an
+ iterator that will yield all events that have not been yielded yet.
+ path_filter: If specified, only paths matching this filter are loaded.
+
+ Raises:
+            ValueError: If directory or loader_factory are None.
+ """
+ if directory is None:
+ raise ValueError("A directory is required")
+ if loader_factory is None:
+ raise ValueError("A loader factory is required")
+ self._directory = directory
+ self._path = None
+ self._loader_factory = loader_factory
+ self._loader = None
+ self._path_filter = path_filter
+ self._ooo_writes_detected = False
+ # The file size for each file at the time it was finalized.
+ self._finalized_sizes = {}
+
+ def Load(self):
+ """Loads new values.
+
+ The watcher will load from one path at a time; as soon as that path stops
+ yielding events, it will move on to the next path. We assume that old paths
+ are never modified after a newer path has been written. As a result, Load()
+ can be called multiple times in a row without losing events that have not
+ been yielded yet. In other words, we guarantee that every event will be
+ yielded exactly once.
+
+ Yields:
+ All values that have not been yielded yet.
+
+ Raises:
+ DirectoryDeletedError: If the directory has been permanently deleted
+ (as opposed to being temporarily unavailable).
+ """
+ try:
+ for event in self._LoadInternal():
+ yield event
+ except tf.errors.OpError:
+ if not tf.io.gfile.exists(self._directory):
+ raise DirectoryDeletedError(
+ "Directory %s has been permanently deleted"
+ % self._directory
+ )
+
+ def _LoadInternal(self):
+ """Internal implementation of Load().
+
+ The only difference between this and Load() is that the latter will throw
+ DirectoryDeletedError on I/O errors if it thinks that the directory has been
+ permanently deleted.
+
+ Yields:
+ All values that have not been yielded yet.
+ """
+
+        # Create a loader if we don't already have one.
+ if not self._loader:
+ self._InitializeLoader()
+
+ # If it still doesn't exist, there is no data
+ if not self._loader:
+ return
+
+ while True:
+ # Yield all the new events in the path we're currently loading from.
+ for event in self._loader.Load():
+ yield event
+
+ next_path = self._GetNextPath()
+ if not next_path:
+ logger.info("No path found after %s", self._path)
+ # Current path is empty and there are no new paths, so we're done.
+ return
+
+ # There's a new path, so check to make sure there weren't any events
+ # written between when we finished reading the current path and when we
+ # checked for the new one. The sequence of events might look something
+ # like this:
+ #
+ # 1. Event #1 written to path #1.
+ # 2. We check for events and yield event #1 from path #1
+ # 3. We check for events and see that there are no more events in path #1.
+ # 4. Event #2 is written to path #1.
+ # 5. Event #3 is written to path #2.
+ # 6. We check for a new path and see that path #2 exists.
+ #
+ # Without this loop, we would miss event #2. We're also guaranteed by the
+ # loader contract that no more events will be written to path #1 after
+ # events start being written to path #2, so we don't have to worry about
+ # that.
+ for event in self._loader.Load():
+ yield event
+
+ logger.info(
+ "Directory watcher advancing from %s to %s",
+ self._path,
+ next_path,
+ )
+
+ # Advance to the next path and start over.
+ self._SetPath(next_path)
+
+ # The number of paths before the current one to check for out of order writes.
+ _OOO_WRITE_CHECK_COUNT = 20
+
+ def OutOfOrderWritesDetected(self):
+ """Returns whether any out-of-order writes have been detected.
+
+ Out-of-order writes are only checked as part of the Load() iterator. Once an
+ out-of-order write is detected, this function will always return true.
+
+ Note that out-of-order write detection is not performed on GCS paths, so
+ this function will always return false.
+
+ Returns:
+ Whether any out-of-order write has ever been detected by this watcher.
+ """
+ return self._ooo_writes_detected
+
+ def _InitializeLoader(self):
+ path = self._GetNextPath()
+ if path:
+ self._SetPath(path)
+
+ def _SetPath(self, path):
+ """Sets the current path to watch for new events.
+
+ This also records the size of the old path, if any. If the size can't be
+ found, an error is logged.
+
+ Args:
+ path: The full path of the file to watch.
+ """
+ old_path = self._path
+ if old_path and not io_wrapper.IsCloudPath(old_path):
+ try:
+ # We're done with the path, so store its size.
+ size = tf.io.gfile.stat(old_path).length
+ logger.debug("Setting latest size of %s to %d", old_path, size)
+ self._finalized_sizes[old_path] = size
+ except tf.errors.OpError as e:
+ logger.error("Unable to get size of %s: %s", old_path, e)
+
+ self._path = path
+ self._loader = self._loader_factory(path)
+
+ def _GetNextPath(self):
+ """Gets the next path to load from.
+
+ This function also does the checking for out-of-order writes as it iterates
+ through the paths.
+
+ Returns:
+ The next path to load events from, or None if there are no more paths.
+ """
+ paths = sorted(
+ path
+ for path in io_wrapper.ListDirectoryAbsolute(self._directory)
+ if self._path_filter(path)
+ )
+ if not paths:
+ return None
+
+ if self._path is None:
+ return paths[0]
+
+ # Don't bother checking if the paths are GCS (which we can't check) or if
+ # we've already detected an OOO write.
+ if (
+ not io_wrapper.IsCloudPath(paths[0])
+ and not self._ooo_writes_detected
+ ):
+ # Check the previous _OOO_WRITE_CHECK_COUNT paths for out of order writes.
+ current_path_index = bisect.bisect_left(paths, self._path)
+ ooo_check_start = max(
+ 0, current_path_index - self._OOO_WRITE_CHECK_COUNT
+ )
+ for path in paths[ooo_check_start:current_path_index]:
+ if self._HasOOOWrite(path):
+ self._ooo_writes_detected = True
+ break
+
+ next_paths = list(
+ path for path in paths if self._path is None or path > self._path
+ )
+ if next_paths:
+ return min(next_paths)
+ else:
+ return None
+
+ def _HasOOOWrite(self, path):
+ """Returns whether the path has had an out-of-order write."""
+ # Check the sizes of each path before the current one.
+ size = tf.io.gfile.stat(path).length
+ old_size = self._finalized_sizes.get(path, None)
+ if size != old_size:
+ if old_size is None:
+ logger.error(
+ "File %s created after file %s even though it's "
+ "lexicographically earlier",
+ path,
+ self._path,
+ )
+ else:
+ logger.error(
+ "File %s updated even though the current file is %s",
+ path,
+ self._path,
+ )
+ return True
+ else:
+ return False
- def _SetPath(self, path):
- """Sets the current path to watch for new events.
- This also records the size of the old path, if any. If the size can't be
- found, an error is logged.
+class DirectoryDeletedError(Exception):
+ """Thrown by Load() when the directory is *permanently* gone.
- Args:
- path: The full path of the file to watch.
- """
- old_path = self._path
- if old_path and not io_wrapper.IsCloudPath(old_path):
- try:
- # We're done with the path, so store its size.
- size = tf.io.gfile.stat(old_path).length
- logger.debug('Setting latest size of %s to %d', old_path, size)
- self._finalized_sizes[old_path] = size
- except tf.errors.OpError as e:
- logger.error('Unable to get size of %s: %s', old_path, e)
-
- self._path = path
- self._loader = self._loader_factory(path)
-
- def _GetNextPath(self):
- """Gets the next path to load from.
-
- This function also does the checking for out-of-order writes as it iterates
- through the paths.
-
- Returns:
- The next path to load events from, or None if there are no more paths.
+ We distinguish this from temporary errors so that other code can
+ decide to drop all of our data only when a directory has been
+ intentionally deleted, as opposed to due to transient filesystem
+ errors.
"""
- paths = sorted(path
- for path in io_wrapper.ListDirectoryAbsolute(self._directory)
- if self._path_filter(path))
- if not paths:
- return None
-
- if self._path is None:
- return paths[0]
-
- # Don't bother checking if the paths are GCS (which we can't check) or if
- # we've already detected an OOO write.
- if not io_wrapper.IsCloudPath(paths[0]) and not self._ooo_writes_detected:
- # Check the previous _OOO_WRITE_CHECK_COUNT paths for out of order writes.
- current_path_index = bisect.bisect_left(paths, self._path)
- ooo_check_start = max(0, current_path_index - self._OOO_WRITE_CHECK_COUNT)
- for path in paths[ooo_check_start:current_path_index]:
- if self._HasOOOWrite(path):
- self._ooo_writes_detected = True
- break
-
- next_paths = list(path
- for path in paths
- if self._path is None or path > self._path)
- if next_paths:
- return min(next_paths)
- else:
- return None
-
- def _HasOOOWrite(self, path):
- """Returns whether the path has had an out-of-order write."""
- # Check the sizes of each path before the current one.
- size = tf.io.gfile.stat(path).length
- old_size = self._finalized_sizes.get(path, None)
- if size != old_size:
- if old_size is None:
- logger.error('File %s created after file %s even though it\'s '
- 'lexicographically earlier', path, self._path)
- else:
- logger.error('File %s updated even though the current file is %s',
- path, self._path)
- return True
- else:
- return False
-
-
-class DirectoryDeletedError(Exception):
- """Thrown by Load() when the directory is *permanently* gone.
- We distinguish this from temporary errors so that other code can decide to
- drop all of our data only when a directory has been intentionally deleted,
- as opposed to due to transient filesystem errors.
- """
- pass
+ pass
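
A standalone sketch of the size-tracking heuristic behind `OutOfOrderWritesDetected()` above: once the watcher advances past a path it records that path's final size, and a later size change, or a lexicographically earlier path that was never finalized, counts as an out-of-order write. Plain values stand in for `tf.io.gfile` stat results.

    def has_ooo_write(finalized_sizes, path, current_size):
        """Returns True if `path` changed after the watcher moved past it."""
        old_size = finalized_sizes.get(path)  # None if never finalized.
        return current_size != old_size


    finalized_sizes = {"events.out.1": 120}
    assert not has_ooo_write(finalized_sizes, "events.out.1", 120)  # Unchanged.
    assert has_ooo_write(finalized_sizes, "events.out.1", 150)  # Grew later.
    assert has_ooo_write(finalized_sizes, "events.out.0", 80)  # Appeared late.
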
diff --git a/tensorboard/backend/event_processing/directory_watcher_test.py b/tensorboard/backend/event_processing/directory_watcher_test.py
index b9a00007d4..619edb8fbb 100644
--- a/tensorboard/backend/event_processing/directory_watcher_test.py
+++ b/tensorboard/backend/event_processing/directory_watcher_test.py
@@ -29,185 +29,192 @@
class _ByteLoader(object):
- """A loader that loads individual bytes from a file."""
+ """A loader that loads individual bytes from a file."""
- def __init__(self, path):
- self._f = open(path)
- self.bytes_read = 0
+ def __init__(self, path):
+ self._f = open(path)
+ self.bytes_read = 0
- def Load(self):
- while True:
- self._f.seek(self.bytes_read)
- byte = self._f.read(1)
- if byte:
- self.bytes_read += 1
- yield byte
- else:
- return
+ def Load(self):
+ while True:
+ self._f.seek(self.bytes_read)
+ byte = self._f.read(1)
+ if byte:
+ self.bytes_read += 1
+ yield byte
+ else:
+ return
class DirectoryWatcherTest(tf.test.TestCase):
-
- def setUp(self):
- # Put everything in a directory so it's easier to delete.
- self._directory = os.path.join(self.get_temp_dir(), 'monitor_dir')
- os.mkdir(self._directory)
- self._watcher = directory_watcher.DirectoryWatcher(self._directory,
- _ByteLoader)
- self.stubs = tf.compat.v1.test.StubOutForTesting()
-
- def tearDown(self):
- self.stubs.CleanUp()
- try:
- shutil.rmtree(self._directory)
- except OSError:
- # Some tests delete the directory.
- pass
-
- def _WriteToFile(self, filename, data):
- path = os.path.join(self._directory, filename)
- with open(path, 'a') as f:
- f.write(data)
-
- def _LoadAllEvents(self):
- """Loads all events in the watcher."""
- for _ in self._watcher.Load():
- pass
-
- def assertWatcherYields(self, values):
- self.assertEqual(list(self._watcher.Load()), values)
-
- def testRaisesWithBadArguments(self):
- with self.assertRaises(ValueError):
- directory_watcher.DirectoryWatcher(None, lambda x: None)
- with self.assertRaises(ValueError):
- directory_watcher.DirectoryWatcher('dir', None)
-
- def testEmptyDirectory(self):
- self.assertWatcherYields([])
-
- def testSingleWrite(self):
- self._WriteToFile('a', 'abc')
- self.assertWatcherYields(['a', 'b', 'c'])
- self.assertFalse(self._watcher.OutOfOrderWritesDetected())
-
- def testMultipleWrites(self):
- self._WriteToFile('a', 'abc')
- self.assertWatcherYields(['a', 'b', 'c'])
- self._WriteToFile('a', 'xyz')
- self.assertWatcherYields(['x', 'y', 'z'])
- self.assertFalse(self._watcher.OutOfOrderWritesDetected())
-
- def testMultipleLoads(self):
- self._WriteToFile('a', 'a')
- self._watcher.Load()
- self._watcher.Load()
- self.assertWatcherYields(['a'])
- self.assertFalse(self._watcher.OutOfOrderWritesDetected())
-
- def testMultipleFilesAtOnce(self):
- self._WriteToFile('b', 'b')
- self._WriteToFile('a', 'a')
- self.assertWatcherYields(['a', 'b'])
- self.assertFalse(self._watcher.OutOfOrderWritesDetected())
-
- def testFinishesLoadingFileWhenSwitchingToNewFile(self):
- self._WriteToFile('a', 'a')
- # Empty the iterator.
- self.assertEqual(['a'], list(self._watcher.Load()))
- self._WriteToFile('a', 'b')
- self._WriteToFile('b', 'c')
- # The watcher should finish its current file before starting a new one.
- self.assertWatcherYields(['b', 'c'])
- self.assertFalse(self._watcher.OutOfOrderWritesDetected())
-
- def testIntermediateEmptyFiles(self):
- self._WriteToFile('a', 'a')
- self._WriteToFile('b', '')
- self._WriteToFile('c', 'c')
- self.assertWatcherYields(['a', 'c'])
- self.assertFalse(self._watcher.OutOfOrderWritesDetected())
-
- def testPathFilter(self):
- self._watcher = directory_watcher.DirectoryWatcher(
- self._directory, _ByteLoader,
- lambda path: 'do_not_watch_me' not in path)
-
- self._WriteToFile('a', 'a')
- self._WriteToFile('do_not_watch_me', 'b')
- self._WriteToFile('c', 'c')
- self.assertWatcherYields(['a', 'c'])
- self.assertFalse(self._watcher.OutOfOrderWritesDetected())
-
- def testDetectsNewOldFiles(self):
- self._WriteToFile('b', 'a')
- self._LoadAllEvents()
- self._WriteToFile('a', 'a')
- self._LoadAllEvents()
- self.assertTrue(self._watcher.OutOfOrderWritesDetected())
-
- def testIgnoresNewerFiles(self):
- self._WriteToFile('a', 'a')
- self._LoadAllEvents()
- self._WriteToFile('q', 'a')
- self._LoadAllEvents()
- self.assertFalse(self._watcher.OutOfOrderWritesDetected())
-
- def testDetectsChangingOldFiles(self):
- self._WriteToFile('a', 'a')
- self._WriteToFile('b', 'a')
- self._LoadAllEvents()
- self._WriteToFile('a', 'c')
- self._LoadAllEvents()
- self.assertTrue(self._watcher.OutOfOrderWritesDetected())
-
- def testDoesntCrashWhenFileIsDeleted(self):
- self._WriteToFile('a', 'a')
- self._LoadAllEvents()
- os.remove(os.path.join(self._directory, 'a'))
- self._WriteToFile('b', 'b')
- self.assertWatcherYields(['b'])
-
- def testRaisesRightErrorWhenDirectoryIsDeleted(self):
- self._WriteToFile('a', 'a')
- self._LoadAllEvents()
- shutil.rmtree(self._directory)
- with self.assertRaises(directory_watcher.DirectoryDeletedError):
- self._LoadAllEvents()
-
- def testDoesntRaiseDirectoryDeletedErrorIfOutageIsTransient(self):
- self._WriteToFile('a', 'a')
- self._LoadAllEvents()
- shutil.rmtree(self._directory)
-
- # Fake a single transient I/O error.
- def FakeFactory(original):
-
- def Fake(*args, **kwargs):
- if FakeFactory.has_been_called:
- original(*args, **kwargs)
- else:
- raise OSError('lp0 temporarily on fire')
-
- return Fake
-
- FakeFactory.has_been_called = False
-
- stub_names = [
- 'ListDirectoryAbsolute',
- 'ListRecursivelyViaGlobbing',
- 'ListRecursivelyViaWalking',
- ]
- for stub_name in stub_names:
- self.stubs.Set(io_wrapper, stub_name,
- FakeFactory(getattr(io_wrapper, stub_name)))
- for stub_name in ['exists', 'stat']:
- self.stubs.Set(tf.io.gfile, stub_name,
- FakeFactory(getattr(tf.io.gfile, stub_name)))
-
- with self.assertRaises((IOError, OSError)):
- self._LoadAllEvents()
-
-
-if __name__ == '__main__':
- tf.test.main()
+ def setUp(self):
+ # Put everything in a directory so it's easier to delete.
+ self._directory = os.path.join(self.get_temp_dir(), "monitor_dir")
+ os.mkdir(self._directory)
+ self._watcher = directory_watcher.DirectoryWatcher(
+ self._directory, _ByteLoader
+ )
+ self.stubs = tf.compat.v1.test.StubOutForTesting()
+
+ def tearDown(self):
+ self.stubs.CleanUp()
+ try:
+ shutil.rmtree(self._directory)
+ except OSError:
+ # Some tests delete the directory.
+ pass
+
+ def _WriteToFile(self, filename, data):
+ path = os.path.join(self._directory, filename)
+ with open(path, "a") as f:
+ f.write(data)
+
+ def _LoadAllEvents(self):
+ """Loads all events in the watcher."""
+ for _ in self._watcher.Load():
+ pass
+
+ def assertWatcherYields(self, values):
+ self.assertEqual(list(self._watcher.Load()), values)
+
+ def testRaisesWithBadArguments(self):
+ with self.assertRaises(ValueError):
+ directory_watcher.DirectoryWatcher(None, lambda x: None)
+ with self.assertRaises(ValueError):
+ directory_watcher.DirectoryWatcher("dir", None)
+
+ def testEmptyDirectory(self):
+ self.assertWatcherYields([])
+
+ def testSingleWrite(self):
+ self._WriteToFile("a", "abc")
+ self.assertWatcherYields(["a", "b", "c"])
+ self.assertFalse(self._watcher.OutOfOrderWritesDetected())
+
+ def testMultipleWrites(self):
+ self._WriteToFile("a", "abc")
+ self.assertWatcherYields(["a", "b", "c"])
+ self._WriteToFile("a", "xyz")
+ self.assertWatcherYields(["x", "y", "z"])
+ self.assertFalse(self._watcher.OutOfOrderWritesDetected())
+
+ def testMultipleLoads(self):
+ self._WriteToFile("a", "a")
+ self._watcher.Load()
+ self._watcher.Load()
+ self.assertWatcherYields(["a"])
+ self.assertFalse(self._watcher.OutOfOrderWritesDetected())
+
+ def testMultipleFilesAtOnce(self):
+ self._WriteToFile("b", "b")
+ self._WriteToFile("a", "a")
+ self.assertWatcherYields(["a", "b"])
+ self.assertFalse(self._watcher.OutOfOrderWritesDetected())
+
+ def testFinishesLoadingFileWhenSwitchingToNewFile(self):
+ self._WriteToFile("a", "a")
+ # Empty the iterator.
+ self.assertEqual(["a"], list(self._watcher.Load()))
+ self._WriteToFile("a", "b")
+ self._WriteToFile("b", "c")
+ # The watcher should finish its current file before starting a new one.
+ self.assertWatcherYields(["b", "c"])
+ self.assertFalse(self._watcher.OutOfOrderWritesDetected())
+
+ def testIntermediateEmptyFiles(self):
+ self._WriteToFile("a", "a")
+ self._WriteToFile("b", "")
+ self._WriteToFile("c", "c")
+ self.assertWatcherYields(["a", "c"])
+ self.assertFalse(self._watcher.OutOfOrderWritesDetected())
+
+ def testPathFilter(self):
+ self._watcher = directory_watcher.DirectoryWatcher(
+ self._directory,
+ _ByteLoader,
+ lambda path: "do_not_watch_me" not in path,
+ )
+
+ self._WriteToFile("a", "a")
+ self._WriteToFile("do_not_watch_me", "b")
+ self._WriteToFile("c", "c")
+ self.assertWatcherYields(["a", "c"])
+ self.assertFalse(self._watcher.OutOfOrderWritesDetected())
+
+ def testDetectsNewOldFiles(self):
+ self._WriteToFile("b", "a")
+ self._LoadAllEvents()
+ self._WriteToFile("a", "a")
+ self._LoadAllEvents()
+ self.assertTrue(self._watcher.OutOfOrderWritesDetected())
+
+ def testIgnoresNewerFiles(self):
+ self._WriteToFile("a", "a")
+ self._LoadAllEvents()
+ self._WriteToFile("q", "a")
+ self._LoadAllEvents()
+ self.assertFalse(self._watcher.OutOfOrderWritesDetected())
+
+ def testDetectsChangingOldFiles(self):
+ self._WriteToFile("a", "a")
+ self._WriteToFile("b", "a")
+ self._LoadAllEvents()
+ self._WriteToFile("a", "c")
+ self._LoadAllEvents()
+ self.assertTrue(self._watcher.OutOfOrderWritesDetected())
+
+ def testDoesntCrashWhenFileIsDeleted(self):
+ self._WriteToFile("a", "a")
+ self._LoadAllEvents()
+ os.remove(os.path.join(self._directory, "a"))
+ self._WriteToFile("b", "b")
+ self.assertWatcherYields(["b"])
+
+ def testRaisesRightErrorWhenDirectoryIsDeleted(self):
+ self._WriteToFile("a", "a")
+ self._LoadAllEvents()
+ shutil.rmtree(self._directory)
+ with self.assertRaises(directory_watcher.DirectoryDeletedError):
+ self._LoadAllEvents()
+
+ def testDoesntRaiseDirectoryDeletedErrorIfOutageIsTransient(self):
+ self._WriteToFile("a", "a")
+ self._LoadAllEvents()
+ shutil.rmtree(self._directory)
+
+ # Fake a single transient I/O error.
+ def FakeFactory(original):
+ def Fake(*args, **kwargs):
+ if FakeFactory.has_been_called:
+ original(*args, **kwargs)
+ else:
+ raise OSError("lp0 temporarily on fire")
+
+ return Fake
+
+ FakeFactory.has_been_called = False
+
+ stub_names = [
+ "ListDirectoryAbsolute",
+ "ListRecursivelyViaGlobbing",
+ "ListRecursivelyViaWalking",
+ ]
+ for stub_name in stub_names:
+ self.stubs.Set(
+ io_wrapper,
+ stub_name,
+ FakeFactory(getattr(io_wrapper, stub_name)),
+ )
+ for stub_name in ["exists", "stat"]:
+ self.stubs.Set(
+ tf.io.gfile,
+ stub_name,
+ FakeFactory(getattr(tf.io.gfile, stub_name)),
+ )
+
+ with self.assertRaises((IOError, OSError)):
+ self._LoadAllEvents()
+
+
+if __name__ == "__main__":
+ tf.test.main()
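
For orientation, the tests above drive `DirectoryWatcher` the same way a caller would: construct it with a directory, a loader factory, and an optional path filter, then drain `Load()` and poll `OutOfOrderWritesDetected()`. A minimal usage sketch under those assumptions follows; the `_ByteLoader` used by the tests is defined earlier in this file, so the loader class and paths below are hypothetical stand-ins, not the project's own helpers.

    from tensorboard.backend.event_processing import directory_watcher

    class CharLoader(object):
        """Hypothetical loader: yields each character of a file as a value."""

        def __init__(self, path):
            self._path = path
            self._offset = 0

        def Load(self):
            # Resume from where the previous Load() call left off, as an
            # event-file loader would.
            with open(self._path) as f:
                f.seek(self._offset)
                for char in f.read():
                    self._offset += 1
                    yield char

    watcher = directory_watcher.DirectoryWatcher(
        "/tmp/monitor_dir",                          # hypothetical directory to watch
        CharLoader,                                  # loader factory, called per file
        lambda path: "do_not_watch_me" not in path,  # optional path filter
    )
    for value in watcher.Load():       # yields loader values, oldest file first
        print(value)
    print(watcher.OutOfOrderWritesDetected())  # True if an older file changed later
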
diff --git a/tensorboard/backend/event_processing/event_accumulator.py b/tensorboard/backend/event_processing/event_accumulator.py
index 83c42685b6..d616cf3006 100644
--- a/tensorboard/backend/event_processing/event_accumulator.py
+++ b/tensorboard/backend/event_processing/event_accumulator.py
@@ -37,49 +37,61 @@
logger = tb_logging.get_logger()
namedtuple = collections.namedtuple
-ScalarEvent = namedtuple('ScalarEvent', ['wall_time', 'step', 'value'])
-
-CompressedHistogramEvent = namedtuple('CompressedHistogramEvent',
- ['wall_time', 'step',
- 'compressed_histogram_values'])
-
-HistogramEvent = namedtuple('HistogramEvent',
- ['wall_time', 'step', 'histogram_value'])
-
-HistogramValue = namedtuple('HistogramValue', ['min', 'max', 'num', 'sum',
- 'sum_squares', 'bucket_limit',
- 'bucket'])
-
-ImageEvent = namedtuple('ImageEvent', ['wall_time', 'step',
- 'encoded_image_string', 'width',
- 'height'])
-
-AudioEvent = namedtuple('AudioEvent', ['wall_time', 'step',
- 'encoded_audio_string', 'content_type',
- 'sample_rate', 'length_frames'])
-
-TensorEvent = namedtuple('TensorEvent', ['wall_time', 'step', 'tensor_proto'])
+ScalarEvent = namedtuple("ScalarEvent", ["wall_time", "step", "value"])
+
+CompressedHistogramEvent = namedtuple(
+ "CompressedHistogramEvent",
+ ["wall_time", "step", "compressed_histogram_values"],
+)
+
+HistogramEvent = namedtuple(
+ "HistogramEvent", ["wall_time", "step", "histogram_value"]
+)
+
+HistogramValue = namedtuple(
+ "HistogramValue",
+ ["min", "max", "num", "sum", "sum_squares", "bucket_limit", "bucket"],
+)
+
+ImageEvent = namedtuple(
+ "ImageEvent",
+ ["wall_time", "step", "encoded_image_string", "width", "height"],
+)
+
+AudioEvent = namedtuple(
+ "AudioEvent",
+ [
+ "wall_time",
+ "step",
+ "encoded_audio_string",
+ "content_type",
+ "sample_rate",
+ "length_frames",
+ ],
+)
+
+TensorEvent = namedtuple("TensorEvent", ["wall_time", "step", "tensor_proto"])
## Different types of summary events handled by the event_accumulator
SUMMARY_TYPES = {
- 'simple_value': '_ProcessScalar',
- 'histo': '_ProcessHistogram',
- 'image': '_ProcessImage',
- 'audio': '_ProcessAudio',
- 'tensor': '_ProcessTensor',
+ "simple_value": "_ProcessScalar",
+ "histo": "_ProcessHistogram",
+ "image": "_ProcessImage",
+ "audio": "_ProcessAudio",
+ "tensor": "_ProcessTensor",
}
## The tagTypes below are just arbitrary strings chosen to pass the type
## information of the tag from the backend to the frontend
-COMPRESSED_HISTOGRAMS = 'distributions'
-HISTOGRAMS = 'histograms'
-IMAGES = 'images'
-AUDIO = 'audio'
-SCALARS = 'scalars'
-TENSORS = 'tensors'
-GRAPH = 'graph'
-META_GRAPH = 'meta_graph'
-RUN_METADATA = 'run_metadata'
+COMPRESSED_HISTOGRAMS = "distributions"
+HISTOGRAMS = "histograms"
+IMAGES = "images"
+AUDIO = "audio"
+SCALARS = "scalars"
+TENSORS = "tensors"
+GRAPH = "graph"
+META_GRAPH = "meta_graph"
+RUN_METADATA = "run_metadata"
## Normal CDF for std_devs: (-Inf, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, Inf)
## naturally gives bands around median of width 1 std dev, 2 std dev, 3 std dev,
@@ -106,662 +118,737 @@
class EventAccumulator(object):
- """An `EventAccumulator` takes an event generator, and accumulates the values.
-
- The `EventAccumulator` is intended to provide a convenient Python interface
- for loading Event data written during a TensorFlow run. TensorFlow writes out
- `Event` protobuf objects, which have a timestamp and step number, and often
- contain a `Summary`. Summaries can have different kinds of data like an image,
- a scalar value, or a histogram. The Summaries also have a tag, which we use to
- organize logically related data. The `EventAccumulator` supports retrieving
- the `Event` and `Summary` data by its tag.
-
- Calling `Tags()` gets a map from `tagType` (e.g. `'images'`,
- `'compressedHistograms'`, `'scalars'`, etc) to the associated tags for those
- data types. Then, various functional endpoints (eg
- `Accumulator.Scalars(tag)`) allow for the retrieval of all data
- associated with that tag.
-
- The `Reload()` method synchronously loads all of the data written so far.
-
- Histograms, audio, and images are very large, so storing all of them is not
- recommended.
-
- Fields:
- audios: A reservoir.Reservoir of audio summaries.
- compressed_histograms: A reservoir.Reservoir of compressed
- histogram summaries.
- histograms: A reservoir.Reservoir of histogram summaries.
- images: A reservoir.Reservoir of image summaries.
- most_recent_step: Step of last Event proto added. This should only
- be accessed from the thread that calls Reload. This is -1 if
- nothing has been loaded yet.
- most_recent_wall_time: Timestamp of last Event proto added. This is
- a float containing seconds from the UNIX epoch, or -1 if
- nothing has been loaded yet. This should only be accessed from
- the thread that calls Reload.
- path: A file path to a directory containing tf events files, or a single
- tf events file. The accumulator will load events from this path.
- scalars: A reservoir.Reservoir of scalar summaries.
- tensors: A reservoir.Reservoir of tensor summaries.
-
- @@Tensors
- """
-
- def __init__(self,
- path,
- size_guidance=None,
- compression_bps=NORMAL_HISTOGRAM_BPS,
- purge_orphaned_data=True):
- """Construct the `EventAccumulator`.
-
- Args:
+ """An `EventAccumulator` takes an event generator, and accumulates the
+ values.
+
+ The `EventAccumulator` is intended to provide a convenient Python interface
+ for loading Event data written during a TensorFlow run. TensorFlow writes out
+ `Event` protobuf objects, which have a timestamp and step number, and often
+ contain a `Summary`. Summaries can have different kinds of data like an image,
+ a scalar value, or a histogram. The Summaries also have a tag, which we use to
+ organize logically related data. The `EventAccumulator` supports retrieving
+ the `Event` and `Summary` data by its tag.
+
+ Calling `Tags()` gets a map from `tagType` (e.g. `'images'`,
+ `'compressedHistograms'`, `'scalars'`, etc) to the associated tags for those
+    data types. Then, various functional endpoints (e.g.,
+ `Accumulator.Scalars(tag)`) allow for the retrieval of all data
+ associated with that tag.
+
+ The `Reload()` method synchronously loads all of the data written so far.
+
+ Histograms, audio, and images are very large, so storing all of them is not
+ recommended.
+
+ Fields:
+ audios: A reservoir.Reservoir of audio summaries.
+ compressed_histograms: A reservoir.Reservoir of compressed
+ histogram summaries.
+ histograms: A reservoir.Reservoir of histogram summaries.
+ images: A reservoir.Reservoir of image summaries.
+ most_recent_step: Step of last Event proto added. This should only
+ be accessed from the thread that calls Reload. This is -1 if
+ nothing has been loaded yet.
+ most_recent_wall_time: Timestamp of last Event proto added. This is
+ a float containing seconds from the UNIX epoch, or -1 if
+ nothing has been loaded yet. This should only be accessed from
+ the thread that calls Reload.
path: A file path to a directory containing tf events files, or a single
- tf events file. The accumulator will load events from this path.
- size_guidance: Information on how much data the EventAccumulator should
- store in memory. The DEFAULT_SIZE_GUIDANCE tries not to store too much
- so as to avoid OOMing the client. The size_guidance should be a map
- from a `tagType` string to an integer representing the number of
- items to keep per tag for items of that `tagType`. If the size is 0,
- all events are stored.
- compression_bps: Information on how the `EventAccumulator` should compress
- histogram data for the `CompressedHistograms` tag (for details see
- `ProcessCompressedHistogram`).
- purge_orphaned_data: Whether to discard any events that were "orphaned" by
- a TensorFlow restart.
- """
- size_guidance = size_guidance or DEFAULT_SIZE_GUIDANCE
- sizes = {}
- for key in DEFAULT_SIZE_GUIDANCE:
- if key in size_guidance:
- sizes[key] = size_guidance[key]
- else:
- sizes[key] = DEFAULT_SIZE_GUIDANCE[key]
-
- self._first_event_timestamp = None
- self.scalars = reservoir.Reservoir(size=sizes[SCALARS])
-
- self._graph = None
- self._graph_from_metagraph = False
- self._meta_graph = None
- self._tagged_metadata = {}
- self.summary_metadata = {}
- self.histograms = reservoir.Reservoir(size=sizes[HISTOGRAMS])
- self.compressed_histograms = reservoir.Reservoir(
- size=sizes[COMPRESSED_HISTOGRAMS], always_keep_last=False)
- self.images = reservoir.Reservoir(size=sizes[IMAGES])
- self.audios = reservoir.Reservoir(size=sizes[AUDIO])
- self.tensors = reservoir.Reservoir(size=sizes[TENSORS])
-
- # Keep a mapping from plugin name to a dict mapping from tag to plugin data
- # content obtained from the SummaryMetadata (metadata field of Value) for
- # that plugin (This is not the entire SummaryMetadata proto - only the
- # content for that plugin). The SummaryWriter only keeps the content on the
- # first event encountered per tag, so we must store that first instance of
- # content for each tag.
- self._plugin_to_tag_to_content = collections.defaultdict(dict)
-
- self._generator_mutex = threading.Lock()
- self.path = path
- self._generator = _GeneratorFromPath(path)
-
- self._compression_bps = compression_bps
- self.purge_orphaned_data = purge_orphaned_data
-
- self.most_recent_step = -1
- self.most_recent_wall_time = -1
- self.file_version = None
-
- # The attributes that get built up by the accumulator
- self.accumulated_attrs = ('scalars', 'histograms',
- 'compressed_histograms', 'images', 'audios')
- self._tensor_summaries = {}
-
- def Reload(self):
- """Loads all events added since the last call to `Reload`.
-
- If `Reload` was never called, loads all events in the file.
+ tf events file. The accumulator will load events from this path.
+ scalars: A reservoir.Reservoir of scalar summaries.
+ tensors: A reservoir.Reservoir of tensor summaries.
- Returns:
- The `EventAccumulator`.
- """
- with self._generator_mutex:
- for event in self._generator.Load():
- self._ProcessEvent(event)
- return self
-
- def PluginAssets(self, plugin_name):
- """Return a list of all plugin assets for the given plugin.
-
- Args:
- plugin_name: The string name of a plugin to retrieve assets for.
-
- Returns:
- A list of string plugin asset names, or empty list if none are available.
- If the plugin was not registered, an empty list is returned.
+ @@Tensors
"""
- return plugin_asset_util.ListAssets(self.path, plugin_name)
-
- def RetrievePluginAsset(self, plugin_name, asset_name):
- """Return the contents of a given plugin asset.
-
- Args:
- plugin_name: The string name of a plugin.
- asset_name: The string name of an asset.
-
- Returns:
- The string contents of the plugin asset.
-
- Raises:
- KeyError: If the asset is not available.
- """
- return plugin_asset_util.RetrieveAsset(self.path, plugin_name, asset_name)
-
- def FirstEventTimestamp(self):
- """Returns the timestamp in seconds of the first event.
-
- If the first event has been loaded (either by this method or by `Reload`,
- this returns immediately. Otherwise, it will load in the first event. Note
- that this means that calling `Reload` will cause this to block until
- `Reload` has finished.
-
- Returns:
- The timestamp in seconds of the first event that was loaded.
-
- Raises:
- ValueError: If no events have been loaded and there were no events found
- on disk.
- """
- if self._first_event_timestamp is not None:
- return self._first_event_timestamp
- with self._generator_mutex:
- try:
- event = next(self._generator.Load())
- self._ProcessEvent(event)
- return self._first_event_timestamp
-
- except StopIteration:
- raise ValueError('No event timestamp could be found')
-
- def PluginTagToContent(self, plugin_name):
- """Returns a dict mapping tags to content specific to that plugin.
-
- Args:
- plugin_name: The name of the plugin for which to fetch plugin-specific
- content.
- Raises:
- KeyError: if the plugin name is not found.
-
- Returns:
- A dict mapping tag names to bytestrings of plugin-specific content-- by
- convention, in the form of binary serialized protos.
- """
- if plugin_name not in self._plugin_to_tag_to_content:
- raise KeyError('Plugin %r could not be found.' % plugin_name)
- return self._plugin_to_tag_to_content[plugin_name]
-
- def SummaryMetadata(self, tag):
- """Given a summary tag name, return the associated metadata object.
-
- Args:
- tag: The name of a tag, as a string.
-
- Raises:
- KeyError: If the tag is not found.
-
- Returns:
- A `SummaryMetadata` protobuf.
- """
- return self.summary_metadata[tag]
-
- def _ProcessEvent(self, event):
- """Called whenever an event is loaded."""
- if self._first_event_timestamp is None:
- self._first_event_timestamp = event.wall_time
-
- if event.HasField('file_version'):
- new_file_version = _ParseFileVersion(event.file_version)
- if self.file_version and self.file_version != new_file_version:
- ## This should not happen.
- logger.warn(('Found new file_version for event.proto. This will '
- 'affect purging logic for TensorFlow restarts. '
- 'Old: {0} New: {1}').format(self.file_version,
- new_file_version))
- self.file_version = new_file_version
-
- self._MaybePurgeOrphanedData(event)
-
- ## Process the event.
- # GraphDef and MetaGraphDef are handled in a special way:
- # If no graph_def Event is available, but a meta_graph_def is, and it
- # contains a graph_def, then use the meta_graph_def.graph_def as our graph.
- # If a graph_def Event is available, always prefer it to the graph_def
- # inside the meta_graph_def.
- if event.HasField('graph_def'):
- if self._graph is not None:
- logger.warn(
- ('Found more than one graph event per run, or there was '
- 'a metagraph containing a graph_def, as well as one or '
- 'more graph events. Overwriting the graph with the '
- 'newest event.'))
- self._graph = event.graph_def
- self._graph_from_metagraph = False
- elif event.HasField('meta_graph_def'):
- if self._meta_graph is not None:
- logger.warn(('Found more than one metagraph event per run. '
- 'Overwriting the metagraph with the newest event.'))
- self._meta_graph = event.meta_graph_def
- if self._graph is None or self._graph_from_metagraph:
- # We may have a graph_def in the metagraph. If so, and no
- # graph_def is directly available, use this one instead.
+ def __init__(
+ self,
+ path,
+ size_guidance=None,
+ compression_bps=NORMAL_HISTOGRAM_BPS,
+ purge_orphaned_data=True,
+ ):
+ """Construct the `EventAccumulator`.
+
+ Args:
+ path: A file path to a directory containing tf events files, or a single
+ tf events file. The accumulator will load events from this path.
+ size_guidance: Information on how much data the EventAccumulator should
+ store in memory. The DEFAULT_SIZE_GUIDANCE tries not to store too much
+ so as to avoid OOMing the client. The size_guidance should be a map
+ from a `tagType` string to an integer representing the number of
+ items to keep per tag for items of that `tagType`. If the size is 0,
+ all events are stored.
+ compression_bps: Information on how the `EventAccumulator` should compress
+ histogram data for the `CompressedHistograms` tag (for details see
+ `ProcessCompressedHistogram`).
+ purge_orphaned_data: Whether to discard any events that were "orphaned" by
+ a TensorFlow restart.
+ """
+ size_guidance = size_guidance or DEFAULT_SIZE_GUIDANCE
+ sizes = {}
+ for key in DEFAULT_SIZE_GUIDANCE:
+ if key in size_guidance:
+ sizes[key] = size_guidance[key]
+ else:
+ sizes[key] = DEFAULT_SIZE_GUIDANCE[key]
+
+ self._first_event_timestamp = None
+ self.scalars = reservoir.Reservoir(size=sizes[SCALARS])
+
+ self._graph = None
+ self._graph_from_metagraph = False
+ self._meta_graph = None
+ self._tagged_metadata = {}
+ self.summary_metadata = {}
+ self.histograms = reservoir.Reservoir(size=sizes[HISTOGRAMS])
+ self.compressed_histograms = reservoir.Reservoir(
+ size=sizes[COMPRESSED_HISTOGRAMS], always_keep_last=False
+ )
+ self.images = reservoir.Reservoir(size=sizes[IMAGES])
+ self.audios = reservoir.Reservoir(size=sizes[AUDIO])
+ self.tensors = reservoir.Reservoir(size=sizes[TENSORS])
+
+ # Keep a mapping from plugin name to a dict mapping from tag to plugin data
+ # content obtained from the SummaryMetadata (metadata field of Value) for
+ # that plugin (This is not the entire SummaryMetadata proto - only the
+ # content for that plugin). The SummaryWriter only keeps the content on the
+ # first event encountered per tag, so we must store that first instance of
+ # content for each tag.
+ self._plugin_to_tag_to_content = collections.defaultdict(dict)
+
+ self._generator_mutex = threading.Lock()
+ self.path = path
+ self._generator = _GeneratorFromPath(path)
+
+ self._compression_bps = compression_bps
+ self.purge_orphaned_data = purge_orphaned_data
+
+ self.most_recent_step = -1
+ self.most_recent_wall_time = -1
+ self.file_version = None
+
+ # The attributes that get built up by the accumulator
+ self.accumulated_attrs = (
+ "scalars",
+ "histograms",
+ "compressed_histograms",
+ "images",
+ "audios",
+ )
+ self._tensor_summaries = {}
+
+ def Reload(self):
+ """Loads all events added since the last call to `Reload`.
+
+ If `Reload` was never called, loads all events in the file.
+
+ Returns:
+ The `EventAccumulator`.
+ """
+ with self._generator_mutex:
+ for event in self._generator.Load():
+ self._ProcessEvent(event)
+ return self
+
+ def PluginAssets(self, plugin_name):
+ """Return a list of all plugin assets for the given plugin.
+
+ Args:
+ plugin_name: The string name of a plugin to retrieve assets for.
+
+ Returns:
+ A list of string plugin asset names, or empty list if none are available.
+ If the plugin was not registered, an empty list is returned.
+ """
+ return plugin_asset_util.ListAssets(self.path, plugin_name)
+
+ def RetrievePluginAsset(self, plugin_name, asset_name):
+ """Return the contents of a given plugin asset.
+
+ Args:
+ plugin_name: The string name of a plugin.
+ asset_name: The string name of an asset.
+
+ Returns:
+ The string contents of the plugin asset.
+
+ Raises:
+ KeyError: If the asset is not available.
+ """
+ return plugin_asset_util.RetrieveAsset(
+ self.path, plugin_name, asset_name
+ )
+
+ def FirstEventTimestamp(self):
+ """Returns the timestamp in seconds of the first event.
+
+        If the first event has been loaded (either by this method or by `Reload`),
+ this returns immediately. Otherwise, it will load in the first event. Note
+ that this means that calling `Reload` will cause this to block until
+ `Reload` has finished.
+
+ Returns:
+ The timestamp in seconds of the first event that was loaded.
+
+ Raises:
+ ValueError: If no events have been loaded and there were no events found
+ on disk.
+ """
+ if self._first_event_timestamp is not None:
+ return self._first_event_timestamp
+ with self._generator_mutex:
+ try:
+ event = next(self._generator.Load())
+ self._ProcessEvent(event)
+ return self._first_event_timestamp
+
+ except StopIteration:
+ raise ValueError("No event timestamp could be found")
+
+ def PluginTagToContent(self, plugin_name):
+ """Returns a dict mapping tags to content specific to that plugin.
+
+ Args:
+ plugin_name: The name of the plugin for which to fetch plugin-specific
+ content.
+
+ Raises:
+ KeyError: if the plugin name is not found.
+
+ Returns:
+            A dict mapping tag names to bytestrings of plugin-specific content; by
+ convention, in the form of binary serialized protos.
+ """
+ if plugin_name not in self._plugin_to_tag_to_content:
+ raise KeyError("Plugin %r could not be found." % plugin_name)
+ return self._plugin_to_tag_to_content[plugin_name]
+
+ def SummaryMetadata(self, tag):
+ """Given a summary tag name, return the associated metadata object.
+
+ Args:
+ tag: The name of a tag, as a string.
+
+ Raises:
+ KeyError: If the tag is not found.
+
+ Returns:
+ A `SummaryMetadata` protobuf.
+ """
+ return self.summary_metadata[tag]
+
+ def _ProcessEvent(self, event):
+ """Called whenever an event is loaded."""
+ if self._first_event_timestamp is None:
+ self._first_event_timestamp = event.wall_time
+
+ if event.HasField("file_version"):
+ new_file_version = _ParseFileVersion(event.file_version)
+ if self.file_version and self.file_version != new_file_version:
+ ## This should not happen.
+ logger.warn(
+ (
+ "Found new file_version for event.proto. This will "
+ "affect purging logic for TensorFlow restarts. "
+ "Old: {0} New: {1}"
+ ).format(self.file_version, new_file_version)
+ )
+ self.file_version = new_file_version
+
+ self._MaybePurgeOrphanedData(event)
+
+ ## Process the event.
+ # GraphDef and MetaGraphDef are handled in a special way:
+ # If no graph_def Event is available, but a meta_graph_def is, and it
+ # contains a graph_def, then use the meta_graph_def.graph_def as our graph.
+ # If a graph_def Event is available, always prefer it to the graph_def
+ # inside the meta_graph_def.
+ if event.HasField("graph_def"):
+ if self._graph is not None:
+ logger.warn(
+ (
+ "Found more than one graph event per run, or there was "
+ "a metagraph containing a graph_def, as well as one or "
+ "more graph events. Overwriting the graph with the "
+ "newest event."
+ )
+ )
+ self._graph = event.graph_def
+ self._graph_from_metagraph = False
+ elif event.HasField("meta_graph_def"):
+ if self._meta_graph is not None:
+ logger.warn(
+ (
+ "Found more than one metagraph event per run. "
+ "Overwriting the metagraph with the newest event."
+ )
+ )
+ self._meta_graph = event.meta_graph_def
+ if self._graph is None or self._graph_from_metagraph:
+ # We may have a graph_def in the metagraph. If so, and no
+ # graph_def is directly available, use this one instead.
+ meta_graph = meta_graph_pb2.MetaGraphDef()
+ meta_graph.ParseFromString(self._meta_graph)
+ if meta_graph.graph_def:
+ if self._graph is not None:
+ logger.warn(
+ (
+                                "Found multiple metagraphs containing graph_defs, "
+ "but did not find any graph events. Overwriting the "
+ "graph with the newest metagraph version."
+ )
+ )
+ self._graph_from_metagraph = True
+ self._graph = meta_graph.graph_def.SerializeToString()
+ elif event.HasField("tagged_run_metadata"):
+ tag = event.tagged_run_metadata.tag
+ if tag in self._tagged_metadata:
+ logger.warn(
+ 'Found more than one "run metadata" event with tag '
+ + tag
+ + ". Overwriting it with the newest event."
+ )
+ self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
+ elif event.HasField("summary"):
+ for value in event.summary.value:
+ if value.HasField("metadata"):
+ tag = value.tag
+ # We only store the first instance of the metadata. This check
+ # is important: the `FileWriter` does strip metadata from all
+ # values except the first one per each tag, but a new
+ # `FileWriter` is created every time a training job stops and
+ # restarts. Hence, we must also ignore non-initial metadata in
+ # this logic.
+ if tag not in self.summary_metadata:
+ self.summary_metadata[tag] = value.metadata
+ plugin_data = value.metadata.plugin_data
+ if plugin_data.plugin_name:
+ self._plugin_to_tag_to_content[
+ plugin_data.plugin_name
+ ][tag] = plugin_data.content
+ else:
+ logger.warn(
+ (
+ "This summary with tag %r is oddly not associated with a "
+ "plugin."
+ ),
+ tag,
+ )
+
+ for summary_type, summary_func in SUMMARY_TYPES.items():
+ if value.HasField(summary_type):
+ datum = getattr(value, summary_type)
+ tag = value.tag
+ if summary_type == "tensor" and not tag:
+ # This tensor summary was created using the old method that used
+ # plugin assets. We must still continue to support it.
+ tag = value.node_name
+ getattr(self, summary_func)(
+ tag, event.wall_time, event.step, datum
+ )
+
+ def Tags(self):
+ """Return all tags found in the value stream.
+
+ Returns:
+ A `{tagType: ['list', 'of', 'tags']}` dictionary.
+ """
+ return {
+ IMAGES: self.images.Keys(),
+ AUDIO: self.audios.Keys(),
+ HISTOGRAMS: self.histograms.Keys(),
+ SCALARS: self.scalars.Keys(),
+ COMPRESSED_HISTOGRAMS: self.compressed_histograms.Keys(),
+ TENSORS: self.tensors.Keys(),
+ # Use a heuristic: if the metagraph is available, but
+ # graph is not, then we assume the metagraph contains the graph.
+ GRAPH: self._graph is not None,
+ META_GRAPH: self._meta_graph is not None,
+ RUN_METADATA: list(self._tagged_metadata.keys()),
+ }
+
+ def Scalars(self, tag):
+ """Given a summary tag, return all associated `ScalarEvent`s.
+
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+
+ Returns:
+ An array of `ScalarEvent`s.
+ """
+ return self.scalars.Items(tag)
+
+ def Graph(self):
+ """Return the graph definition, if there is one.
+
+ If the graph is stored directly, return that. If no graph is stored
+ directly but a metagraph is stored containing a graph, return that.
+
+ Raises:
+ ValueError: If there is no graph for this run.
+
+ Returns:
+ The `graph_def` proto.
+ """
+ graph = graph_pb2.GraphDef()
+ if self._graph is not None:
+ graph.ParseFromString(self._graph)
+ return graph
+ raise ValueError("There is no graph in this EventAccumulator")
+
+ def MetaGraph(self):
+ """Return the metagraph definition, if there is one.
+
+ Raises:
+ ValueError: If there is no metagraph for this run.
+
+ Returns:
+ The `meta_graph_def` proto.
+ """
+ if self._meta_graph is None:
+ raise ValueError("There is no metagraph in this EventAccumulator")
meta_graph = meta_graph_pb2.MetaGraphDef()
meta_graph.ParseFromString(self._meta_graph)
- if meta_graph.graph_def:
- if self._graph is not None:
- logger.warn(
- ('Found multiple metagraphs containing graph_defs,'
- 'but did not find any graph events. Overwriting the '
- 'graph with the newest metagraph version.'))
- self._graph_from_metagraph = True
- self._graph = meta_graph.graph_def.SerializeToString()
- elif event.HasField('tagged_run_metadata'):
- tag = event.tagged_run_metadata.tag
- if tag in self._tagged_metadata:
- logger.warn('Found more than one "run metadata" event with tag ' +
- tag + '. Overwriting it with the newest event.')
- self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
- elif event.HasField('summary'):
- for value in event.summary.value:
- if value.HasField('metadata'):
- tag = value.tag
- # We only store the first instance of the metadata. This check
- # is important: the `FileWriter` does strip metadata from all
- # values except the first one per each tag, but a new
- # `FileWriter` is created every time a training job stops and
- # restarts. Hence, we must also ignore non-initial metadata in
- # this logic.
- if tag not in self.summary_metadata:
- self.summary_metadata[tag] = value.metadata
- plugin_data = value.metadata.plugin_data
- if plugin_data.plugin_name:
- self._plugin_to_tag_to_content[plugin_data.plugin_name][tag] = (
- plugin_data.content)
- else:
- logger.warn(
- ('This summary with tag %r is oddly not associated with a '
- 'plugin.'), tag)
-
- for summary_type, summary_func in SUMMARY_TYPES.items():
- if value.HasField(summary_type):
- datum = getattr(value, summary_type)
- tag = value.tag
- if summary_type == 'tensor' and not tag:
- # This tensor summary was created using the old method that used
- # plugin assets. We must still continue to support it.
- tag = value.node_name
- getattr(self, summary_func)(tag, event.wall_time, event.step, datum)
-
+ return meta_graph
- def Tags(self):
- """Return all tags found in the value stream.
+ def RunMetadata(self, tag):
+ """Given a tag, return the associated session.run() metadata.
- Returns:
- A `{tagType: ['list', 'of', 'tags']}` dictionary.
- """
- return {
- IMAGES: self.images.Keys(),
- AUDIO: self.audios.Keys(),
- HISTOGRAMS: self.histograms.Keys(),
- SCALARS: self.scalars.Keys(),
- COMPRESSED_HISTOGRAMS: self.compressed_histograms.Keys(),
- TENSORS: self.tensors.Keys(),
- # Use a heuristic: if the metagraph is available, but
- # graph is not, then we assume the metagraph contains the graph.
- GRAPH: self._graph is not None,
- META_GRAPH: self._meta_graph is not None,
- RUN_METADATA: list(self._tagged_metadata.keys())
- }
-
- def Scalars(self, tag):
- """Given a summary tag, return all associated `ScalarEvent`s.
-
- Args:
- tag: A string tag associated with the events.
-
- Raises:
- KeyError: If the tag is not found.
-
- Returns:
- An array of `ScalarEvent`s.
- """
- return self.scalars.Items(tag)
+ Args:
+ tag: A string tag associated with the event.
- def Graph(self):
- """Return the graph definition, if there is one.
+ Raises:
+ ValueError: If the tag is not found.
- If the graph is stored directly, return that. If no graph is stored
- directly but a metagraph is stored containing a graph, return that.
+ Returns:
+ The metadata in form of `RunMetadata` proto.
+ """
+ if tag not in self._tagged_metadata:
+ raise ValueError("There is no run metadata with this tag name")
- Raises:
- ValueError: If there is no graph for this run.
+ run_metadata = config_pb2.RunMetadata()
+ run_metadata.ParseFromString(self._tagged_metadata[tag])
+ return run_metadata
- Returns:
- The `graph_def` proto.
- """
- graph = graph_pb2.GraphDef()
- if self._graph is not None:
- graph.ParseFromString(self._graph)
- return graph
- raise ValueError('There is no graph in this EventAccumulator')
+ def Histograms(self, tag):
+ """Given a summary tag, return all associated histograms.
- def MetaGraph(self):
- """Return the metagraph definition, if there is one.
+ Args:
+ tag: A string tag associated with the events.
- Raises:
- ValueError: If there is no metagraph for this run.
-
- Returns:
- The `meta_graph_def` proto.
- """
- if self._meta_graph is None:
- raise ValueError('There is no metagraph in this EventAccumulator')
- meta_graph = meta_graph_pb2.MetaGraphDef()
- meta_graph.ParseFromString(self._meta_graph)
- return meta_graph
+ Raises:
+ KeyError: If the tag is not found.
- def RunMetadata(self, tag):
- """Given a tag, return the associated session.run() metadata.
+ Returns:
+ An array of `HistogramEvent`s.
+ """
+ return self.histograms.Items(tag)
- Args:
- tag: A string tag associated with the event.
+ def CompressedHistograms(self, tag):
+ """Given a summary tag, return all associated compressed histograms.
- Raises:
- ValueError: If the tag is not found.
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+
+ Returns:
+ An array of `CompressedHistogramEvent`s.
+ """
+ return self.compressed_histograms.Items(tag)
+
+ def Images(self, tag):
+ """Given a summary tag, return all associated images.
+
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+
+ Returns:
+ An array of `ImageEvent`s.
+ """
+ return self.images.Items(tag)
+
+ def Audio(self, tag):
+ """Given a summary tag, return all associated audio.
+
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+
+ Returns:
+ An array of `AudioEvent`s.
+ """
+ return self.audios.Items(tag)
+
+ def Tensors(self, tag):
+ """Given a summary tag, return all associated tensors.
+
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+
+ Returns:
+ An array of `TensorEvent`s.
+ """
+ return self.tensors.Items(tag)
+
+ def _MaybePurgeOrphanedData(self, event):
+ """Maybe purge orphaned data due to a TensorFlow crash.
+
+ When TensorFlow crashes at step T+O and restarts at step T, any events
+ written after step T are now "orphaned" and will be at best misleading if
+ they are included in TensorBoard.
+
+ This logic attempts to determine if there is orphaned data, and purge it
+ if it is found.
+
+ Args:
+ event: The event to use as a reference, to determine if a purge is needed.
+ """
+ if not self.purge_orphaned_data:
+ return
+ ## Check if the event happened after a crash, and purge expired tags.
+ if self.file_version and self.file_version >= 2:
+ ## If the file_version is recent enough, use the SessionLog enum
+ ## to check for restarts.
+ self._CheckForRestartAndMaybePurge(event)
+ else:
+ ## If there is no file version, default to old logic of checking for
+ ## out of order steps.
+ self._CheckForOutOfOrderStepAndMaybePurge(event)
+
+ def _CheckForRestartAndMaybePurge(self, event):
+ """Check and discard expired events using SessionLog.START.
+
+ Check for a SessionLog.START event and purge all previously seen events
+ with larger steps, because they are out of date. Because of supervisor
+ threading, it is possible that this logic will cause the first few event
+ messages to be discarded since supervisor threading does not guarantee
+ that the START message is deterministically written first.
+
+ This method is preferred over _CheckForOutOfOrderStepAndMaybePurge which
+ can inadvertently discard events due to supervisor threading.
+
+ Args:
+ event: The event to use as reference. If the event is a START event, all
+ previously seen events with a greater event.step will be purged.
+ """
+ if (
+ event.HasField("session_log")
+ and event.session_log.status == event_pb2.SessionLog.START
+ ):
+ self._Purge(event, by_tags=False)
+
+ def _CheckForOutOfOrderStepAndMaybePurge(self, event):
+ """Check for out-of-order event.step and discard expired events for
+ tags.
+
+ Check if the event is out of order relative to the global most recent step.
+ If it is, purge outdated summaries for tags that the event contains.
+
+ Args:
+ event: The event to use as reference. If the event is out-of-order, all
+ events with the same tags, but with a greater event.step will be purged.
+ """
+ if event.step < self.most_recent_step and event.HasField("summary"):
+ self._Purge(event, by_tags=True)
+ else:
+ self.most_recent_step = event.step
+ self.most_recent_wall_time = event.wall_time
+
+ def _ConvertHistogramProtoToTuple(self, histo):
+ return HistogramValue(
+ min=histo.min,
+ max=histo.max,
+ num=histo.num,
+ sum=histo.sum,
+ sum_squares=histo.sum_squares,
+ bucket_limit=list(histo.bucket_limit),
+ bucket=list(histo.bucket),
+ )
+
+ def _ProcessHistogram(self, tag, wall_time, step, histo):
+ """Processes a proto histogram by adding it to accumulated state."""
+ histo = self._ConvertHistogramProtoToTuple(histo)
+ histo_ev = HistogramEvent(wall_time, step, histo)
+ self.histograms.AddItem(tag, histo_ev)
+ self.compressed_histograms.AddItem(
+ tag, histo_ev, self._CompressHistogram
+ )
+
+ def _CompressHistogram(self, histo_ev):
+ """Callback for _ProcessHistogram."""
+ return CompressedHistogramEvent(
+ histo_ev.wall_time,
+ histo_ev.step,
+ compressor.compress_histogram_proto(
+ histo_ev.histogram_value, self._compression_bps
+ ),
+ )
+
+ def _ProcessImage(self, tag, wall_time, step, image):
+ """Processes an image by adding it to accumulated state."""
+ event = ImageEvent(
+ wall_time=wall_time,
+ step=step,
+ encoded_image_string=image.encoded_image_string,
+ width=image.width,
+ height=image.height,
+ )
+ self.images.AddItem(tag, event)
+
+ def _ProcessAudio(self, tag, wall_time, step, audio):
+        """Processes an audio value by adding it to accumulated state."""
+ event = AudioEvent(
+ wall_time=wall_time,
+ step=step,
+ encoded_audio_string=audio.encoded_audio_string,
+ content_type=audio.content_type,
+ sample_rate=audio.sample_rate,
+ length_frames=audio.length_frames,
+ )
+ self.audios.AddItem(tag, event)
+
+ def _ProcessScalar(self, tag, wall_time, step, scalar):
+ """Processes a simple value by adding it to accumulated state."""
+ sv = ScalarEvent(wall_time=wall_time, step=step, value=scalar)
+ self.scalars.AddItem(tag, sv)
+
+ def _ProcessTensor(self, tag, wall_time, step, tensor):
+ tv = TensorEvent(wall_time=wall_time, step=step, tensor_proto=tensor)
+ self.tensors.AddItem(tag, tv)
+
+ def _Purge(self, event, by_tags):
+ """Purge all events that have occurred after the given event.step.
+
+ If by_tags is True, purge all events that occurred after the given
+ event.step, but only for the tags that the event has. Non-sequential
+ event.steps suggest that a TensorFlow restart occurred, and we discard
+ the out-of-order events to display a consistent view in TensorBoard.
+
+        Discarding by tags is the safer method when we are unsure whether a restart
+ has occurred, given that threading in supervisor can cause events of
+ different tags to arrive with unsynchronized step values.
+
+ If by_tags is False, then purge all events with event.step greater than the
+ given event.step. This can be used when we are certain that a TensorFlow
+ restart has occurred and these events can be discarded.
+
+ Args:
+ event: The event to use as reference for the purge. All events with
+ the same tags, but with a greater event.step will be purged.
+ by_tags: Bool to dictate whether to discard all out-of-order events or
+ only those that are associated with the given reference event.
+ """
+ ## Keep data in reservoirs that has a step less than event.step
+ _NotExpired = lambda x: x.step < event.step
+
+ if by_tags:
+
+ def _ExpiredPerTag(value):
+ return [
+ getattr(self, x).FilterItems(_NotExpired, value.tag)
+ for x in self.accumulated_attrs
+ ]
+
+ expired_per_tags = [
+ _ExpiredPerTag(value) for value in event.summary.value
+ ]
+ expired_per_type = [sum(x) for x in zip(*expired_per_tags)]
+ else:
+ expired_per_type = [
+ getattr(self, x).FilterItems(_NotExpired)
+ for x in self.accumulated_attrs
+ ]
+
+ if sum(expired_per_type) > 0:
+ purge_msg = _GetPurgeMessage(
+ self.most_recent_step,
+ self.most_recent_wall_time,
+ event.step,
+ event.wall_time,
+ *expired_per_type
+ )
+ logger.warn(purge_msg)
+
+
+def _GetPurgeMessage(
+ most_recent_step,
+ most_recent_wall_time,
+ event_step,
+ event_wall_time,
+ num_expired_scalars,
+ num_expired_histos,
+ num_expired_comp_histos,
+ num_expired_images,
+ num_expired_audio,
+):
+ """Return the string message associated with TensorBoard purges."""
+ return (
+ "Detected out of order event.step likely caused by "
+        "a TensorFlow restart. Purging expired events from TensorBoard"
+ " display between the previous step: {} (timestamp: {}) and "
+ "current step: {} (timestamp: {}). Removing {} scalars, {} "
+ "histograms, {} compressed histograms, {} images, "
+ "and {} audio."
+ ).format(
+ most_recent_step,
+ most_recent_wall_time,
+ event_step,
+ event_wall_time,
+ num_expired_scalars,
+ num_expired_histos,
+ num_expired_comp_histos,
+ num_expired_images,
+ num_expired_audio,
+ )
- Returns:
- The metadata in form of `RunMetadata` proto.
- """
- if tag not in self._tagged_metadata:
- raise ValueError('There is no run metadata with this tag name')
- run_metadata = config_pb2.RunMetadata()
- run_metadata.ParseFromString(self._tagged_metadata[tag])
- return run_metadata
-
- def Histograms(self, tag):
- """Given a summary tag, return all associated histograms.
-
- Args:
- tag: A string tag associated with the events.
-
- Raises:
- KeyError: If the tag is not found.
-
- Returns:
- An array of `HistogramEvent`s.
- """
- return self.histograms.Items(tag)
-
- def CompressedHistograms(self, tag):
- """Given a summary tag, return all associated compressed histograms.
-
- Args:
- tag: A string tag associated with the events.
-
- Raises:
- KeyError: If the tag is not found.
-
- Returns:
- An array of `CompressedHistogramEvent`s.
- """
- return self.compressed_histograms.Items(tag)
-
- def Images(self, tag):
- """Given a summary tag, return all associated images.
-
- Args:
- tag: A string tag associated with the events.
-
- Raises:
- KeyError: If the tag is not found.
-
- Returns:
- An array of `ImageEvent`s.
- """
- return self.images.Items(tag)
-
- def Audio(self, tag):
- """Given a summary tag, return all associated audio.
-
- Args:
- tag: A string tag associated with the events.
-
- Raises:
- KeyError: If the tag is not found.
-
- Returns:
- An array of `AudioEvent`s.
- """
- return self.audios.Items(tag)
-
- def Tensors(self, tag):
- """Given a summary tag, return all associated tensors.
-
- Args:
- tag: A string tag associated with the events.
-
- Raises:
- KeyError: If the tag is not found.
-
- Returns:
- An array of `TensorEvent`s.
- """
- return self.tensors.Items(tag)
-
- def _MaybePurgeOrphanedData(self, event):
- """Maybe purge orphaned data due to a TensorFlow crash.
-
- When TensorFlow crashes at step T+O and restarts at step T, any events
- written after step T are now "orphaned" and will be at best misleading if
- they are included in TensorBoard.
-
- This logic attempts to determine if there is orphaned data, and purge it
- if it is found.
-
- Args:
- event: The event to use as a reference, to determine if a purge is needed.
- """
- if not self.purge_orphaned_data:
- return
- ## Check if the event happened after a crash, and purge expired tags.
- if self.file_version and self.file_version >= 2:
- ## If the file_version is recent enough, use the SessionLog enum
- ## to check for restarts.
- self._CheckForRestartAndMaybePurge(event)
+def _GeneratorFromPath(path):
+ """Create an event generator for file or directory at given path string."""
+ if not path:
+ raise ValueError("path must be a valid string")
+ if io_wrapper.IsTensorFlowEventsFile(path):
+ return event_file_loader.EventFileLoader(path)
else:
- ## If there is no file version, default to old logic of checking for
- ## out of order steps.
- self._CheckForOutOfOrderStepAndMaybePurge(event)
-
- def _CheckForRestartAndMaybePurge(self, event):
- """Check and discard expired events using SessionLog.START.
-
- Check for a SessionLog.START event and purge all previously seen events
- with larger steps, because they are out of date. Because of supervisor
- threading, it is possible that this logic will cause the first few event
- messages to be discarded since supervisor threading does not guarantee
- that the START message is deterministically written first.
-
- This method is preferred over _CheckForOutOfOrderStepAndMaybePurge which
- can inadvertently discard events due to supervisor threading.
-
- Args:
- event: The event to use as reference. If the event is a START event, all
- previously seen events with a greater event.step will be purged.
- """
- if event.HasField(
- 'session_log') and event.session_log.status == event_pb2.SessionLog.START:
- self._Purge(event, by_tags=False)
+ return directory_watcher.DirectoryWatcher(
+ path,
+ event_file_loader.EventFileLoader,
+ io_wrapper.IsTensorFlowEventsFile,
+ )
- def _CheckForOutOfOrderStepAndMaybePurge(self, event):
- """Check for out-of-order event.step and discard expired events for tags.
- Check if the event is out of order relative to the global most recent step.
- If it is, purge outdated summaries for tags that the event contains.
+def _ParseFileVersion(file_version):
+ """Convert the string file_version in event.proto into a float.
Args:
- event: The event to use as reference. If the event is out-of-order, all
- events with the same tags, but with a greater event.step will be purged.
- """
- if event.step < self.most_recent_step and event.HasField('summary'):
- self._Purge(event, by_tags=True)
- else:
- self.most_recent_step = event.step
- self.most_recent_wall_time = event.wall_time
-
- def _ConvertHistogramProtoToTuple(self, histo):
- return HistogramValue(min=histo.min,
- max=histo.max,
- num=histo.num,
- sum=histo.sum,
- sum_squares=histo.sum_squares,
- bucket_limit=list(histo.bucket_limit),
- bucket=list(histo.bucket))
-
- def _ProcessHistogram(self, tag, wall_time, step, histo):
- """Processes a proto histogram by adding it to accumulated state."""
- histo = self._ConvertHistogramProtoToTuple(histo)
- histo_ev = HistogramEvent(wall_time, step, histo)
- self.histograms.AddItem(tag, histo_ev)
- self.compressed_histograms.AddItem(tag, histo_ev, self._CompressHistogram)
-
- def _CompressHistogram(self, histo_ev):
- """Callback for _ProcessHistogram."""
- return CompressedHistogramEvent(
- histo_ev.wall_time,
- histo_ev.step,
- compressor.compress_histogram_proto(
- histo_ev.histogram_value, self._compression_bps))
-
- def _ProcessImage(self, tag, wall_time, step, image):
- """Processes an image by adding it to accumulated state."""
- event = ImageEvent(wall_time=wall_time,
- step=step,
- encoded_image_string=image.encoded_image_string,
- width=image.width,
- height=image.height)
- self.images.AddItem(tag, event)
-
- def _ProcessAudio(self, tag, wall_time, step, audio):
- """Processes a audio by adding it to accumulated state."""
- event = AudioEvent(wall_time=wall_time,
- step=step,
- encoded_audio_string=audio.encoded_audio_string,
- content_type=audio.content_type,
- sample_rate=audio.sample_rate,
- length_frames=audio.length_frames)
- self.audios.AddItem(tag, event)
-
- def _ProcessScalar(self, tag, wall_time, step, scalar):
- """Processes a simple value by adding it to accumulated state."""
- sv = ScalarEvent(wall_time=wall_time, step=step, value=scalar)
- self.scalars.AddItem(tag, sv)
-
- def _ProcessTensor(self, tag, wall_time, step, tensor):
- tv = TensorEvent(wall_time=wall_time, step=step, tensor_proto=tensor)
- self.tensors.AddItem(tag, tv)
-
- def _Purge(self, event, by_tags):
- """Purge all events that have occurred after the given event.step.
-
- If by_tags is True, purge all events that occurred after the given
- event.step, but only for the tags that the event has. Non-sequential
- event.steps suggest that a TensorFlow restart occurred, and we discard
- the out-of-order events to display a consistent view in TensorBoard.
-
- Discarding by tags is the safer method, when we are unsure whether a restart
- has occurred, given that threading in supervisor can cause events of
- different tags to arrive with unsynchronized step values.
-
- If by_tags is False, then purge all events with event.step greater than the
- given event.step. This can be used when we are certain that a TensorFlow
- restart has occurred and these events can be discarded.
+ file_version: String file_version from event.proto
- Args:
- event: The event to use as reference for the purge. All events with
- the same tags, but with a greater event.step will be purged.
- by_tags: Bool to dictate whether to discard all out-of-order events or
- only those that are associated with the given reference event.
+ Returns:
+ Version number as a float.
"""
- ## Keep data in reservoirs that has a step less than event.step
- _NotExpired = lambda x: x.step < event.step
-
- if by_tags:
- def _ExpiredPerTag(value):
- return [getattr(self, x).FilterItems(_NotExpired, value.tag)
- for x in self.accumulated_attrs]
-
- expired_per_tags = [_ExpiredPerTag(value)
- for value in event.summary.value]
- expired_per_type = [sum(x) for x in zip(*expired_per_tags)]
- else:
- expired_per_type = [getattr(self, x).FilterItems(_NotExpired)
- for x in self.accumulated_attrs]
-
- if sum(expired_per_type) > 0:
- purge_msg = _GetPurgeMessage(self.most_recent_step,
- self.most_recent_wall_time, event.step,
- event.wall_time, *expired_per_type)
- logger.warn(purge_msg)
-
-
-def _GetPurgeMessage(most_recent_step, most_recent_wall_time, event_step,
- event_wall_time, num_expired_scalars, num_expired_histos,
- num_expired_comp_histos, num_expired_images,
- num_expired_audio):
- """Return the string message associated with TensorBoard purges."""
- return ('Detected out of order event.step likely caused by '
- 'a TensorFlow restart. Purging expired events from Tensorboard'
- ' display between the previous step: {} (timestamp: {}) and '
- 'current step: {} (timestamp: {}). Removing {} scalars, {} '
- 'histograms, {} compressed histograms, {} images, '
- 'and {} audio.').format(most_recent_step, most_recent_wall_time,
- event_step, event_wall_time,
- num_expired_scalars, num_expired_histos,
- num_expired_comp_histos, num_expired_images,
- num_expired_audio)
-
-
-def _GeneratorFromPath(path):
- """Create an event generator for file or directory at given path string."""
- if not path:
- raise ValueError('path must be a valid string')
- if io_wrapper.IsTensorFlowEventsFile(path):
- return event_file_loader.EventFileLoader(path)
- else:
- return directory_watcher.DirectoryWatcher(
- path,
- event_file_loader.EventFileLoader,
- io_wrapper.IsTensorFlowEventsFile)
-
-
-def _ParseFileVersion(file_version):
- """Convert the string file_version in event.proto into a float.
-
- Args:
- file_version: String file_version from event.proto
-
- Returns:
- Version number as a float.
- """
- tokens = file_version.split('brain.Event:')
- try:
- return float(tokens[-1])
- except ValueError:
- ## This should never happen according to the definition of file_version
- ## specified in event.proto.
- logger.warn(
- ('Invalid event.proto file_version. Defaulting to use of '
- 'out-of-order event.step logic for purging expired events.'))
- return -1
+ tokens = file_version.split("brain.Event:")
+ try:
+ return float(tokens[-1])
+ except ValueError:
+ ## This should never happen according to the definition of file_version
+ ## specified in event.proto.
+ logger.warn(
+ (
+ "Invalid event.proto file_version. Defaulting to use of "
+ "out-of-order event.step logic for purging expired events."
+ )
+ )
+ return -1
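
Taken together, the docstrings above describe the intended read path for this module: build an `EventAccumulator` on a path containing tf event files, call `Reload()` to ingest everything written so far, then inspect `Tags()` and the per-type accessors. A minimal sketch, assuming a hypothetical logdir and tag name (and that a "loss" scalar was actually logged to that run):

    from tensorboard.backend.event_processing import event_accumulator

    acc = event_accumulator.EventAccumulator(
        "/tmp/logdir/run1",                            # hypothetical run directory
        size_guidance={event_accumulator.SCALARS: 0},  # 0 keeps every scalar event
    )
    acc.Reload()                                  # synchronously load all events so far
    print(acc.Tags()[event_accumulator.SCALARS])  # tag names seen for scalar summaries
    for ev in acc.Scalars("loss"):                # hypothetical tag; KeyError if absent
        print(ev.step, ev.wall_time, ev.value)
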
diff --git a/tensorboard/backend/event_processing/event_accumulator_test.py b/tensorboard/backend/event_processing/event_accumulator_test.py
index 7a857e78e8..89c55cf076 100644
--- a/tensorboard/backend/event_processing/event_accumulator_test.py
+++ b/tensorboard/backend/event_processing/event_accumulator_test.py
@@ -40,918 +40,1020 @@
class _EventGenerator(object):
- """Class that can add_events and then yield them back.
-
- Satisfies the EventGenerator API required for the EventAccumulator.
- Satisfies the EventWriter API required to create a tf.summary.FileWriter.
-
- Has additional convenience methods for adding test events.
- """
-
- def __init__(self, testcase, zero_out_timestamps=False):
- self._testcase = testcase
- self.items = []
- self.zero_out_timestamps = zero_out_timestamps
-
- def Load(self):
- while self.items:
- yield self.items.pop(0)
-
- def AddScalar(self, tag, wall_time=0, step=0, value=0):
- event = event_pb2.Event(
- wall_time=wall_time,
- step=step,
- summary=summary_pb2.Summary(
- value=[summary_pb2.Summary.Value(tag=tag, simple_value=value)]))
- self.AddEvent(event)
-
- def AddHistogram(self,
- tag,
- wall_time=0,
- step=0,
- hmin=1,
- hmax=2,
- hnum=3,
- hsum=4,
- hsum_squares=5,
- hbucket_limit=None,
- hbucket=None):
- histo = summary_pb2.HistogramProto(
- min=hmin,
- max=hmax,
- num=hnum,
- sum=hsum,
- sum_squares=hsum_squares,
- bucket_limit=hbucket_limit,
- bucket=hbucket)
- event = event_pb2.Event(
- wall_time=wall_time,
- step=step,
- summary=summary_pb2.Summary(
- value=[summary_pb2.Summary.Value(tag=tag, histo=histo)]))
- self.AddEvent(event)
-
- def AddImage(self,
- tag,
- wall_time=0,
- step=0,
- encoded_image_string=b'imgstr',
- width=150,
- height=100):
- image = summary_pb2.Summary.Image(
- encoded_image_string=encoded_image_string, width=width, height=height)
- event = event_pb2.Event(
- wall_time=wall_time,
- step=step,
- summary=summary_pb2.Summary(
- value=[summary_pb2.Summary.Value(tag=tag, image=image)]))
- self.AddEvent(event)
-
- def AddAudio(self,
- tag,
- wall_time=0,
- step=0,
- encoded_audio_string=b'sndstr',
- content_type='audio/wav',
- sample_rate=44100,
- length_frames=22050):
- audio = summary_pb2.Summary.Audio(
- encoded_audio_string=encoded_audio_string,
- content_type=content_type,
- sample_rate=sample_rate,
- length_frames=length_frames)
- event = event_pb2.Event(
- wall_time=wall_time,
- step=step,
- summary=summary_pb2.Summary(
- value=[summary_pb2.Summary.Value(tag=tag, audio=audio)]))
- self.AddEvent(event)
-
- def AddEvent(self, event):
- if self.zero_out_timestamps:
- event.wall_time = 0
- self.items.append(event)
-
- def add_event(self, event): # pylint: disable=invalid-name
- """Match the EventWriter API."""
- self.AddEvent(event)
-
- def get_logdir(self): # pylint: disable=invalid-name
- """Return a temp directory for asset writing."""
- return self._testcase.get_temp_dir()
-
- def close(self):
- """Closes the event writer"""
- # noop
+ """Class that can add_events and then yield them back.
+ Satisfies the EventGenerator API required for the EventAccumulator.
+ Satisfies the EventWriter API required to create a tf.summary.FileWriter.
-class EventAccumulatorTest(tf.test.TestCase):
-
- def assertTagsEqual(self, actual, expected):
- """Utility method for checking the return value of the Tags() call.
-
- It fills out the `expected` arg with the default (empty) values for every
- tag type, so that the author needs only specify the non-empty values they
- are interested in testing.
-
- Args:
- actual: The actual Accumulator tags response.
- expected: The expected tags response (empty fields may be omitted)
+ Has additional convenience methods for adding test events.
"""
- empty_tags = {
- ea.IMAGES: [],
- ea.AUDIO: [],
- ea.SCALARS: [],
- ea.HISTOGRAMS: [],
- ea.COMPRESSED_HISTOGRAMS: [],
- ea.GRAPH: False,
- ea.META_GRAPH: False,
- ea.RUN_METADATA: [],
- ea.TENSORS: [],
- }
-
- # Verifies that there are no unexpected keys in the actual response.
- # If this line fails, likely you added a new tag type, and need to update
- # the empty_tags dictionary above.
- self.assertItemsEqual(actual.keys(), empty_tags.keys())
-
- for key in actual:
- expected_value = expected.get(key, empty_tags[key])
- if isinstance(expected_value, list):
- self.assertItemsEqual(actual[key], expected_value)
- else:
- self.assertEqual(actual[key], expected_value)
-
-
-class MockingEventAccumulatorTest(EventAccumulatorTest):
-
- def setUp(self):
- super(MockingEventAccumulatorTest, self).setUp()
- self.stubs = tf.compat.v1.test.StubOutForTesting()
- self._real_constructor = ea.EventAccumulator
- self._real_generator = ea._GeneratorFromPath
-
- def _FakeAccumulatorConstructor(generator, *args, **kwargs):
- ea._GeneratorFromPath = lambda x: generator
- return self._real_constructor(generator, *args, **kwargs)
-
- ea.EventAccumulator = _FakeAccumulatorConstructor
-
- def tearDown(self):
- self.stubs.CleanUp()
- ea.EventAccumulator = self._real_constructor
- ea._GeneratorFromPath = self._real_generator
-
- def testEmptyAccumulator(self):
- gen = _EventGenerator(self)
- x = ea.EventAccumulator(gen)
- x.Reload()
- self.assertTagsEqual(x.Tags(), {})
-
- def testTags(self):
- """Tags should be found in EventAccumulator after adding some events."""
- gen = _EventGenerator(self)
- gen.AddScalar('s1')
- gen.AddScalar('s2')
- gen.AddHistogram('hst1')
- gen.AddHistogram('hst2')
- gen.AddImage('im1')
- gen.AddImage('im2')
- gen.AddAudio('snd1')
- gen.AddAudio('snd2')
- acc = ea.EventAccumulator(gen)
- acc.Reload()
- self.assertTagsEqual(acc.Tags(), {
- ea.IMAGES: ['im1', 'im2'],
- ea.AUDIO: ['snd1', 'snd2'],
- ea.SCALARS: ['s1', 's2'],
- ea.HISTOGRAMS: ['hst1', 'hst2'],
- ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'],
- })
-
- def testReload(self):
- """EventAccumulator contains suitable tags after calling Reload."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- acc.Reload()
- self.assertTagsEqual(acc.Tags(), {})
- gen.AddScalar('s1')
- gen.AddScalar('s2')
- gen.AddHistogram('hst1')
- gen.AddHistogram('hst2')
- gen.AddImage('im1')
- gen.AddImage('im2')
- gen.AddAudio('snd1')
- gen.AddAudio('snd2')
- acc.Reload()
- self.assertTagsEqual(acc.Tags(), {
- ea.IMAGES: ['im1', 'im2'],
- ea.AUDIO: ['snd1', 'snd2'],
- ea.SCALARS: ['s1', 's2'],
- ea.HISTOGRAMS: ['hst1', 'hst2'],
- ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'],
- })
-
- def testScalars(self):
- """Tests whether EventAccumulator contains scalars after adding them."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- s1 = ea.ScalarEvent(wall_time=1, step=10, value=32)
- s2 = ea.ScalarEvent(wall_time=2, step=12, value=64)
- gen.AddScalar('s1', wall_time=1, step=10, value=32)
- gen.AddScalar('s2', wall_time=2, step=12, value=64)
- acc.Reload()
- self.assertEqual(acc.Scalars('s1'), [s1])
- self.assertEqual(acc.Scalars('s2'), [s2])
-
- def testHistograms(self):
- """Tests whether histograms are inserted into EventAccumulator."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
-
- val1 = ea.HistogramValue(
- min=1,
- max=2,
- num=3,
- sum=4,
- sum_squares=5,
- bucket_limit=[1, 2, 3],
- bucket=[0, 3, 0])
- val2 = ea.HistogramValue(
- min=-2,
- max=3,
- num=4,
- sum=5,
- sum_squares=6,
- bucket_limit=[2, 3, 4],
- bucket=[1, 3, 0])
-
- hst1 = ea.HistogramEvent(wall_time=1, step=10, histogram_value=val1)
- hst2 = ea.HistogramEvent(wall_time=2, step=12, histogram_value=val2)
- gen.AddHistogram(
- 'hst1',
- wall_time=1,
- step=10,
+ def __init__(self, testcase, zero_out_timestamps=False):
+ self._testcase = testcase
+ self.items = []
+ self.zero_out_timestamps = zero_out_timestamps
+
+ def Load(self):
+ while self.items:
+ yield self.items.pop(0)
+
+ def AddScalar(self, tag, wall_time=0, step=0, value=0):
+ event = event_pb2.Event(
+ wall_time=wall_time,
+ step=step,
+ summary=summary_pb2.Summary(
+ value=[summary_pb2.Summary.Value(tag=tag, simple_value=value)]
+ ),
+ )
+ self.AddEvent(event)
+
+ def AddHistogram(
+ self,
+ tag,
+ wall_time=0,
+ step=0,
hmin=1,
hmax=2,
hnum=3,
hsum=4,
hsum_squares=5,
- hbucket_limit=[1, 2, 3],
- hbucket=[0, 3, 0])
- gen.AddHistogram(
- 'hst2',
- wall_time=2,
- step=12,
- hmin=-2,
- hmax=3,
- hnum=4,
- hsum=5,
- hsum_squares=6,
- hbucket_limit=[2, 3, 4],
- hbucket=[1, 3, 0])
- acc.Reload()
- self.assertEqual(acc.Histograms('hst1'), [hst1])
- self.assertEqual(acc.Histograms('hst2'), [hst2])
-
- def testCompressedHistograms(self):
- """Tests compressed histograms inserted into EventAccumulator."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen, compression_bps=(0, 2500, 5000, 7500, 10000))
-
- gen.AddHistogram(
- 'hst1',
- wall_time=1,
- step=10,
- hmin=1,
- hmax=2,
- hnum=3,
- hsum=4,
- hsum_squares=5,
- hbucket_limit=[1, 2, 3],
- hbucket=[0, 3, 0])
- gen.AddHistogram(
- 'hst2',
- wall_time=2,
- step=12,
- hmin=-2,
- hmax=3,
- hnum=4,
- hsum=5,
- hsum_squares=6,
- hbucket_limit=[2, 3, 4],
- hbucket=[1, 3, 0])
- acc.Reload()
-
- # Create the expected values after compressing hst1
- expected_vals1 = [
- compressor.CompressedHistogramValue(bp, val)
- for bp, val in [(0, 1.0), (2500, 1.25), (5000, 1.5), (7500, 1.75
- ), (10000, 2.0)]
- ]
- expected_cmphst1 = ea.CompressedHistogramEvent(
- wall_time=1, step=10, compressed_histogram_values=expected_vals1)
- self.assertEqual(acc.CompressedHistograms('hst1'), [expected_cmphst1])
-
- # Create the expected values after compressing hst2
- expected_vals2 = [
- compressor.CompressedHistogramValue(bp, val)
- for bp, val in [(0, -2),
- (2500, 2),
- (5000, 2 + 1 / 3),
- (7500, 2 + 2 / 3),
- (10000, 3)]
- ]
- expected_cmphst2 = ea.CompressedHistogramEvent(
- wall_time=2, step=12, compressed_histogram_values=expected_vals2)
- self.assertEqual(acc.CompressedHistograms('hst2'), [expected_cmphst2])
-
- def testImages(self):
- """Tests 2 images inserted/accessed in EventAccumulator."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- im1 = ea.ImageEvent(
- wall_time=1,
- step=10,
- encoded_image_string=b'big',
- width=400,
- height=300)
- im2 = ea.ImageEvent(
- wall_time=2,
- step=12,
- encoded_image_string=b'small',
- width=40,
- height=30)
- gen.AddImage(
- 'im1',
- wall_time=1,
- step=10,
- encoded_image_string=b'big',
- width=400,
- height=300)
- gen.AddImage(
- 'im2',
- wall_time=2,
- step=12,
- encoded_image_string=b'small',
- width=40,
- height=30)
- acc.Reload()
- self.assertEqual(acc.Images('im1'), [im1])
- self.assertEqual(acc.Images('im2'), [im2])
-
- def testAudio(self):
- """Tests 2 audio events inserted/accessed in EventAccumulator."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- snd1 = ea.AudioEvent(
- wall_time=1,
- step=10,
- encoded_audio_string=b'big',
- content_type='audio/wav',
- sample_rate=44100,
- length_frames=441000)
- snd2 = ea.AudioEvent(
- wall_time=2,
- step=12,
- encoded_audio_string=b'small',
- content_type='audio/wav',
+ hbucket_limit=None,
+ hbucket=None,
+ ):
+ histo = summary_pb2.HistogramProto(
+ min=hmin,
+ max=hmax,
+ num=hnum,
+ sum=hsum,
+ sum_squares=hsum_squares,
+ bucket_limit=hbucket_limit,
+ bucket=hbucket,
+ )
+ event = event_pb2.Event(
+ wall_time=wall_time,
+ step=step,
+ summary=summary_pb2.Summary(
+ value=[summary_pb2.Summary.Value(tag=tag, histo=histo)]
+ ),
+ )
+ self.AddEvent(event)
+
+ def AddImage(
+ self,
+ tag,
+ wall_time=0,
+ step=0,
+ encoded_image_string=b"imgstr",
+ width=150,
+ height=100,
+ ):
+ image = summary_pb2.Summary.Image(
+ encoded_image_string=encoded_image_string,
+ width=width,
+ height=height,
+ )
+ event = event_pb2.Event(
+ wall_time=wall_time,
+ step=step,
+ summary=summary_pb2.Summary(
+ value=[summary_pb2.Summary.Value(tag=tag, image=image)]
+ ),
+ )
+ self.AddEvent(event)
+
+ def AddAudio(
+ self,
+ tag,
+ wall_time=0,
+ step=0,
+ encoded_audio_string=b"sndstr",
+ content_type="audio/wav",
sample_rate=44100,
- length_frames=44100)
- gen.AddAudio(
- 'snd1',
- wall_time=1,
- step=10,
- encoded_audio_string=b'big',
- content_type='audio/wav',
- sample_rate=44100,
- length_frames=441000)
- gen.AddAudio(
- 'snd2',
- wall_time=2,
- step=12,
- encoded_audio_string=b'small',
- content_type='audio/wav',
- sample_rate=44100,
- length_frames=44100)
- acc.Reload()
- self.assertEqual(acc.Audio('snd1'), [snd1])
- self.assertEqual(acc.Audio('snd2'), [snd2])
-
- def testKeyError(self):
- """KeyError should be raised when accessing non-existing keys."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- acc.Reload()
- with self.assertRaises(KeyError):
- acc.Scalars('s1')
- with self.assertRaises(KeyError):
- acc.Scalars('hst1')
- with self.assertRaises(KeyError):
- acc.Scalars('im1')
- with self.assertRaises(KeyError):
- acc.Histograms('s1')
- with self.assertRaises(KeyError):
- acc.Histograms('im1')
- with self.assertRaises(KeyError):
- acc.Images('s1')
- with self.assertRaises(KeyError):
- acc.Images('hst1')
- with self.assertRaises(KeyError):
- acc.Audio('s1')
- with self.assertRaises(KeyError):
- acc.Audio('hst1')
-
- def testNonValueEvents(self):
- """Non-value events in the generator don't cause early exits."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- gen.AddScalar('s1', wall_time=1, step=10, value=20)
- gen.AddEvent(
- event_pb2.Event(wall_time=2, step=20, file_version='nots2'))
- gen.AddScalar('s3', wall_time=3, step=100, value=1)
- gen.AddHistogram('hst1')
- gen.AddImage('im1')
- gen.AddAudio('snd1')
-
- acc.Reload()
- self.assertTagsEqual(acc.Tags(), {
- ea.IMAGES: ['im1'],
- ea.AUDIO: ['snd1'],
- ea.SCALARS: ['s1', 's3'],
- ea.HISTOGRAMS: ['hst1'],
- ea.COMPRESSED_HISTOGRAMS: ['hst1'],
- })
-
- def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
- """Tests that events are discarded after a restart is detected.
-
- If a step value is observed to be lower than what was previously seen,
- this should force a discard of all previous items with the same tag
- that are outdated.
-
- Only file versions < 2 use this out-of-order discard logic. Later versions
- discard events based on the step value of SessionLog.START.
- """
- warnings = []
- self.stubs.Set(logger, 'warn', warnings.append)
-
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
-
- gen.AddEvent(
- event_pb2.Event(wall_time=0, step=0, file_version='brain.Event:1'))
- gen.AddScalar('s1', wall_time=1, step=100, value=20)
- gen.AddScalar('s1', wall_time=1, step=200, value=20)
- gen.AddScalar('s1', wall_time=1, step=300, value=20)
- acc.Reload()
- ## Check that number of items are what they should be
- self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200, 300])
-
- gen.AddScalar('s1', wall_time=1, step=101, value=20)
- gen.AddScalar('s1', wall_time=1, step=201, value=20)
- gen.AddScalar('s1', wall_time=1, step=301, value=20)
- acc.Reload()
- ## Check that we have discarded 200 and 300 from s1
- self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301])
-
- def testOrphanedDataNotDiscardedIfFlagUnset(self):
- """Tests that events are not discarded if purge_orphaned_data is false.
- """
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen, purge_orphaned_data=False)
-
- gen.AddEvent(
- event_pb2.Event(wall_time=0, step=0, file_version='brain.Event:1'))
- gen.AddScalar('s1', wall_time=1, step=100, value=20)
- gen.AddScalar('s1', wall_time=1, step=200, value=20)
- gen.AddScalar('s1', wall_time=1, step=300, value=20)
- acc.Reload()
- ## Check that number of items are what they should be
- self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200, 300])
-
- gen.AddScalar('s1', wall_time=1, step=101, value=20)
- gen.AddScalar('s1', wall_time=1, step=201, value=20)
- gen.AddScalar('s1', wall_time=1, step=301, value=20)
- acc.Reload()
- ## Check that we have NOT discarded 200 and 300 from s1
- self.assertEqual([x.step for x in acc.Scalars('s1')],
- [100, 200, 300, 101, 201, 301])
-
- def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
- """Tests that event discards after restart, only affect the misordered tag.
-
- If a step value is observed to be lower than what was previously seen,
- this should force a discard of all previous items that are outdated, but
- only for the out of order tag. Other tags should remain unaffected.
-
- Only file versions < 2 use this out-of-order discard logic. Later versions
- discard events based on the step value of SessionLog.START.
- """
- warnings = []
- self.stubs.Set(logger, 'warn', warnings.append)
-
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
-
- gen.AddEvent(
- event_pb2.Event(wall_time=0, step=0, file_version='brain.Event:1'))
- gen.AddScalar('s1', wall_time=1, step=100, value=20)
- gen.AddScalar('s1', wall_time=1, step=200, value=20)
- gen.AddScalar('s1', wall_time=1, step=300, value=20)
- gen.AddScalar('s1', wall_time=1, step=101, value=20)
- gen.AddScalar('s1', wall_time=1, step=201, value=20)
- gen.AddScalar('s1', wall_time=1, step=301, value=20)
-
- gen.AddScalar('s2', wall_time=1, step=101, value=20)
- gen.AddScalar('s2', wall_time=1, step=201, value=20)
- gen.AddScalar('s2', wall_time=1, step=301, value=20)
-
- acc.Reload()
- ## Check that we have discarded 200 and 300
- self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301])
-
- ## Check that s1 discards do not affect s2
- ## i.e. check that only events from the out of order tag are discarded
- self.assertEqual([x.step for x in acc.Scalars('s2')], [101, 201, 301])
-
- def testOnlySummaryEventsTriggerDiscards(self):
- """Test that file version event does not trigger data purge."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- gen.AddScalar('s1', wall_time=1, step=100, value=20)
- ev1 = event_pb2.Event(wall_time=2, step=0, file_version='brain.Event:1')
- graph_bytes = tf.compat.v1.GraphDef().SerializeToString()
- ev2 = event_pb2.Event(wall_time=3, step=0, graph_def=graph_bytes)
- gen.AddEvent(ev1)
- gen.AddEvent(ev2)
- acc.Reload()
- self.assertEqual([x.step for x in acc.Scalars('s1')], [100])
-
- def testSessionLogStartMessageDiscardsExpiredEvents(self):
- """Test that SessionLog.START message discards expired events.
-
- This discard logic is preferred over the out-of-order step discard logic,
- but this logic can only be used for event protos which have the SessionLog
- enum, which was introduced to event.proto for file_version >= brain.Event:2.
- """
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- gen.AddEvent(
- event_pb2.Event(wall_time=0, step=1, file_version='brain.Event:2'))
-
- gen.AddScalar('s1', wall_time=1, step=100, value=20)
- gen.AddScalar('s1', wall_time=1, step=200, value=20)
- gen.AddScalar('s1', wall_time=1, step=300, value=20)
- gen.AddScalar('s1', wall_time=1, step=400, value=20)
-
- gen.AddScalar('s2', wall_time=1, step=202, value=20)
- gen.AddScalar('s2', wall_time=1, step=203, value=20)
-
- slog = event_pb2.SessionLog(status=event_pb2.SessionLog.START)
- gen.AddEvent(
- event_pb2.Event(wall_time=2, step=201, session_log=slog))
- acc.Reload()
- self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200])
- self.assertEqual([x.step for x in acc.Scalars('s2')], [])
-
- def testFirstEventTimestamp(self):
- """Test that FirstEventTimestamp() returns wall_time of the first event."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- gen.AddEvent(
- event_pb2.Event(wall_time=10, step=20, file_version='brain.Event:2'))
- gen.AddScalar('s1', wall_time=30, step=40, value=20)
- self.assertEqual(acc.FirstEventTimestamp(), 10)
-
- def testReloadPopulatesFirstEventTimestamp(self):
- """Test that Reload() means FirstEventTimestamp() won't load events."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- gen.AddEvent(
- event_pb2.Event(wall_time=1, step=2, file_version='brain.Event:2'))
-
- acc.Reload()
-
- def _Die(*args, **kwargs): # pylint: disable=unused-argument
- raise RuntimeError('Load() should not be called')
-
- self.stubs.Set(gen, 'Load', _Die)
- self.assertEqual(acc.FirstEventTimestamp(), 1)
-
- def testFirstEventTimestampLoadsEvent(self):
- """Test that FirstEventTimestamp() doesn't discard the loaded event."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- gen.AddEvent(
- event_pb2.Event(wall_time=1, step=2, file_version='brain.Event:2'))
-
- self.assertEqual(acc.FirstEventTimestamp(), 1)
- acc.Reload()
- self.assertEqual(acc.file_version, 2.0)
-
- def testTFSummaryScalar(self):
- """Verify processing of tf.summary.scalar."""
- event_sink = _EventGenerator(self, zero_out_timestamps=True)
- with test_util.FileWriterCache.get(self.get_temp_dir()) as writer:
- writer.event_writer = event_sink
- with self.test_session() as sess:
- ipt = tf.compat.v1.placeholder(tf.float32)
- tf.compat.v1.summary.scalar('scalar1', ipt)
- tf.compat.v1.summary.scalar('scalar2', ipt * ipt)
- merged = tf.compat.v1.summary.merge_all()
- writer.add_graph(sess.graph)
- for i in xrange(10):
- summ = sess.run(merged, feed_dict={ipt: i})
- writer.add_summary(summ, global_step=i)
-
- accumulator = ea.EventAccumulator(event_sink)
- accumulator.Reload()
-
- seq1 = [ea.ScalarEvent(wall_time=0, step=i, value=i) for i in xrange(10)]
- seq2 = [
- ea.ScalarEvent(
- wall_time=0, step=i, value=i * i) for i in xrange(10)
- ]
-
- self.assertTagsEqual(accumulator.Tags(), {
- ea.SCALARS: ['scalar1', 'scalar2'],
- ea.GRAPH: True,
- ea.META_GRAPH: False,
- })
-
- self.assertEqual(accumulator.Scalars('scalar1'), seq1)
- self.assertEqual(accumulator.Scalars('scalar2'), seq2)
- first_value = accumulator.Scalars('scalar1')[0].value
- self.assertTrue(isinstance(first_value, float))
-
- def testTFSummaryImage(self):
- """Verify processing of tf.summary.image."""
- event_sink = _EventGenerator(self, zero_out_timestamps=True)
- with test_util.FileWriterCache.get(self.get_temp_dir()) as writer:
- writer.event_writer = event_sink
- with self.test_session() as sess:
- ipt = tf.ones([10, 4, 4, 3], tf.uint8)
- # This is an interesting example, because the old tf.image_summary op
- # would throw an error here, because it would be tag reuse.
- # Using the tf node name instead allows argument re-use to the image
- # summary.
- with tf.name_scope('1'):
- tf.compat.v1.summary.image('images', ipt, max_outputs=1)
- with tf.name_scope('2'):
- tf.compat.v1.summary.image('images', ipt, max_outputs=2)
- with tf.name_scope('3'):
- tf.compat.v1.summary.image('images', ipt, max_outputs=3)
- merged = tf.compat.v1.summary.merge_all()
- writer.add_graph(sess.graph)
- for i in xrange(10):
- summ = sess.run(merged)
- writer.add_summary(summ, global_step=i)
-
- accumulator = ea.EventAccumulator(event_sink)
- accumulator.Reload()
-
- tags = [
- u'1/images/image', u'2/images/image/0', u'2/images/image/1',
- u'3/images/image/0', u'3/images/image/1', u'3/images/image/2'
- ]
-
- self.assertTagsEqual(accumulator.Tags(), {
- ea.IMAGES: tags,
- ea.GRAPH: True,
- ea.META_GRAPH: False,
- })
-
- def testTFSummaryTensor(self):
- """Verify processing of tf.summary.tensor."""
- event_sink = _EventGenerator(self, zero_out_timestamps=True)
- with test_util.FileWriterCache.get(self.get_temp_dir()) as writer:
- writer.event_writer = event_sink
- with self.test_session() as sess:
- tf.compat.v1.summary.tensor_summary('scalar', tf.constant(1.0))
- tf.compat.v1.summary.tensor_summary('vector', tf.constant([1.0, 2.0, 3.0]))
- tf.compat.v1.summary.tensor_summary('string', tf.constant(six.b('foobar')))
- merged = tf.compat.v1.summary.merge_all()
- summ = sess.run(merged)
- writer.add_summary(summ, 0)
-
- accumulator = ea.EventAccumulator(event_sink)
- accumulator.Reload()
-
- self.assertTagsEqual(accumulator.Tags(), {
- ea.TENSORS: ['scalar', 'vector', 'string'],
- })
-
- scalar_proto = accumulator.Tensors('scalar')[0].tensor_proto
- scalar = tf.compat.v1.make_ndarray(scalar_proto)
- vector_proto = accumulator.Tensors('vector')[0].tensor_proto
- vector = tf.compat.v1.make_ndarray(vector_proto)
- string_proto = accumulator.Tensors('string')[0].tensor_proto
- string = tf.compat.v1.make_ndarray(string_proto)
-
- self.assertTrue(np.array_equal(scalar, 1.0))
- self.assertTrue(np.array_equal(vector, [1.0, 2.0, 3.0]))
- self.assertTrue(np.array_equal(string, six.b('foobar')))
+ length_frames=22050,
+ ):
+ audio = summary_pb2.Summary.Audio(
+ encoded_audio_string=encoded_audio_string,
+ content_type=content_type,
+ sample_rate=sample_rate,
+ length_frames=length_frames,
+ )
+ event = event_pb2.Event(
+ wall_time=wall_time,
+ step=step,
+ summary=summary_pb2.Summary(
+ value=[summary_pb2.Summary.Value(tag=tag, audio=audio)]
+ ),
+ )
+ self.AddEvent(event)
+
+ def AddEvent(self, event):
+ if self.zero_out_timestamps:
+ event.wall_time = 0
+ self.items.append(event)
+
+ def add_event(self, event): # pylint: disable=invalid-name
+ """Match the EventWriter API."""
+ self.AddEvent(event)
+
+ def get_logdir(self): # pylint: disable=invalid-name
+ """Return a temp directory for asset writing."""
+ return self._testcase.get_temp_dir()
+
+ def close(self):
+ """Closes the event writer."""
+ # noop
-class RealisticEventAccumulatorTest(EventAccumulatorTest):
+class EventAccumulatorTest(tf.test.TestCase):
+ def assertTagsEqual(self, actual, expected):
+ """Utility method for checking the return value of the Tags() call.
+
+        It fills out the `expected` arg with the default (empty) values for every
+        tag type, so that the author need only specify the non-empty values they
+        are interested in testing.
+
+ Args:
+ actual: The actual Accumulator tags response.
+ expected: The expected tags response (empty fields may be omitted)
+ """
+
+ empty_tags = {
+ ea.IMAGES: [],
+ ea.AUDIO: [],
+ ea.SCALARS: [],
+ ea.HISTOGRAMS: [],
+ ea.COMPRESSED_HISTOGRAMS: [],
+ ea.GRAPH: False,
+ ea.META_GRAPH: False,
+ ea.RUN_METADATA: [],
+ ea.TENSORS: [],
+ }
+
+ # Verifies that there are no unexpected keys in the actual response.
+        # If this line fails, you likely added a new tag type and need to update
+ # the empty_tags dictionary above.
+ self.assertItemsEqual(actual.keys(), empty_tags.keys())
+
+ for key in actual:
+ expected_value = expected.get(key, empty_tags[key])
+ if isinstance(expected_value, list):
+ self.assertItemsEqual(actual[key], expected_value)
+ else:
+ self.assertEqual(actual[key], expected_value)
+
+
+class MockingEventAccumulatorTest(EventAccumulatorTest):
+ def setUp(self):
+ super(MockingEventAccumulatorTest, self).setUp()
+ self.stubs = tf.compat.v1.test.StubOutForTesting()
+ self._real_constructor = ea.EventAccumulator
+ self._real_generator = ea._GeneratorFromPath
+
+ def _FakeAccumulatorConstructor(generator, *args, **kwargs):
+ ea._GeneratorFromPath = lambda x: generator
+ return self._real_constructor(generator, *args, **kwargs)
+
+ ea.EventAccumulator = _FakeAccumulatorConstructor
+
+ def tearDown(self):
+ self.stubs.CleanUp()
+ ea.EventAccumulator = self._real_constructor
+ ea._GeneratorFromPath = self._real_generator
+
+ def testEmptyAccumulator(self):
+ gen = _EventGenerator(self)
+ x = ea.EventAccumulator(gen)
+ x.Reload()
+ self.assertTagsEqual(x.Tags(), {})
+
+ def testTags(self):
+ """Tags should be found in EventAccumulator after adding some
+ events."""
+ gen = _EventGenerator(self)
+ gen.AddScalar("s1")
+ gen.AddScalar("s2")
+ gen.AddHistogram("hst1")
+ gen.AddHistogram("hst2")
+ gen.AddImage("im1")
+ gen.AddImage("im2")
+ gen.AddAudio("snd1")
+ gen.AddAudio("snd2")
+ acc = ea.EventAccumulator(gen)
+ acc.Reload()
+ self.assertTagsEqual(
+ acc.Tags(),
+ {
+ ea.IMAGES: ["im1", "im2"],
+ ea.AUDIO: ["snd1", "snd2"],
+ ea.SCALARS: ["s1", "s2"],
+ ea.HISTOGRAMS: ["hst1", "hst2"],
+ ea.COMPRESSED_HISTOGRAMS: ["hst1", "hst2"],
+ },
+ )
+
+ def testReload(self):
+ """EventAccumulator contains suitable tags after calling Reload."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ acc.Reload()
+ self.assertTagsEqual(acc.Tags(), {})
+ gen.AddScalar("s1")
+ gen.AddScalar("s2")
+ gen.AddHistogram("hst1")
+ gen.AddHistogram("hst2")
+ gen.AddImage("im1")
+ gen.AddImage("im2")
+ gen.AddAudio("snd1")
+ gen.AddAudio("snd2")
+ acc.Reload()
+ self.assertTagsEqual(
+ acc.Tags(),
+ {
+ ea.IMAGES: ["im1", "im2"],
+ ea.AUDIO: ["snd1", "snd2"],
+ ea.SCALARS: ["s1", "s2"],
+ ea.HISTOGRAMS: ["hst1", "hst2"],
+ ea.COMPRESSED_HISTOGRAMS: ["hst1", "hst2"],
+ },
+ )
+
+ def testScalars(self):
+ """Tests whether EventAccumulator contains scalars after adding
+ them."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ s1 = ea.ScalarEvent(wall_time=1, step=10, value=32)
+ s2 = ea.ScalarEvent(wall_time=2, step=12, value=64)
+ gen.AddScalar("s1", wall_time=1, step=10, value=32)
+ gen.AddScalar("s2", wall_time=2, step=12, value=64)
+ acc.Reload()
+ self.assertEqual(acc.Scalars("s1"), [s1])
+ self.assertEqual(acc.Scalars("s2"), [s2])
+
+ def testHistograms(self):
+ """Tests whether histograms are inserted into EventAccumulator."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+
+ val1 = ea.HistogramValue(
+ min=1,
+ max=2,
+ num=3,
+ sum=4,
+ sum_squares=5,
+ bucket_limit=[1, 2, 3],
+ bucket=[0, 3, 0],
+ )
+ val2 = ea.HistogramValue(
+ min=-2,
+ max=3,
+ num=4,
+ sum=5,
+ sum_squares=6,
+ bucket_limit=[2, 3, 4],
+ bucket=[1, 3, 0],
+ )
+
+ hst1 = ea.HistogramEvent(wall_time=1, step=10, histogram_value=val1)
+ hst2 = ea.HistogramEvent(wall_time=2, step=12, histogram_value=val2)
+ gen.AddHistogram(
+ "hst1",
+ wall_time=1,
+ step=10,
+ hmin=1,
+ hmax=2,
+ hnum=3,
+ hsum=4,
+ hsum_squares=5,
+ hbucket_limit=[1, 2, 3],
+ hbucket=[0, 3, 0],
+ )
+ gen.AddHistogram(
+ "hst2",
+ wall_time=2,
+ step=12,
+ hmin=-2,
+ hmax=3,
+ hnum=4,
+ hsum=5,
+ hsum_squares=6,
+ hbucket_limit=[2, 3, 4],
+ hbucket=[1, 3, 0],
+ )
+ acc.Reload()
+ self.assertEqual(acc.Histograms("hst1"), [hst1])
+ self.assertEqual(acc.Histograms("hst2"), [hst2])
+
+ def testCompressedHistograms(self):
+ """Tests compressed histograms inserted into EventAccumulator."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(
+ gen, compression_bps=(0, 2500, 5000, 7500, 10000)
+ )
+
+ gen.AddHistogram(
+ "hst1",
+ wall_time=1,
+ step=10,
+ hmin=1,
+ hmax=2,
+ hnum=3,
+ hsum=4,
+ hsum_squares=5,
+ hbucket_limit=[1, 2, 3],
+ hbucket=[0, 3, 0],
+ )
+ gen.AddHistogram(
+ "hst2",
+ wall_time=2,
+ step=12,
+ hmin=-2,
+ hmax=3,
+ hnum=4,
+ hsum=5,
+ hsum_squares=6,
+ hbucket_limit=[2, 3, 4],
+ hbucket=[1, 3, 0],
+ )
+ acc.Reload()
+
+ # Create the expected values after compressing hst1
+ expected_vals1 = [
+ compressor.CompressedHistogramValue(bp, val)
+ for bp, val in [
+ (0, 1.0),
+ (2500, 1.25),
+ (5000, 1.5),
+ (7500, 1.75),
+ (10000, 2.0),
+ ]
+ ]
+ expected_cmphst1 = ea.CompressedHistogramEvent(
+ wall_time=1, step=10, compressed_histogram_values=expected_vals1
+ )
+ self.assertEqual(acc.CompressedHistograms("hst1"), [expected_cmphst1])
+
+ # Create the expected values after compressing hst2
+ expected_vals2 = [
+ compressor.CompressedHistogramValue(bp, val)
+ for bp, val in [
+ (0, -2),
+ (2500, 2),
+ (5000, 2 + 1 / 3),
+ (7500, 2 + 2 / 3),
+ (10000, 3),
+ ]
+ ]
+ expected_cmphst2 = ea.CompressedHistogramEvent(
+ wall_time=2, step=12, compressed_histogram_values=expected_vals2
+ )
+ self.assertEqual(acc.CompressedHistograms("hst2"), [expected_cmphst2])
+
+ def testImages(self):
+ """Tests 2 images inserted/accessed in EventAccumulator."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ im1 = ea.ImageEvent(
+ wall_time=1,
+ step=10,
+ encoded_image_string=b"big",
+ width=400,
+ height=300,
+ )
+ im2 = ea.ImageEvent(
+ wall_time=2,
+ step=12,
+ encoded_image_string=b"small",
+ width=40,
+ height=30,
+ )
+ gen.AddImage(
+ "im1",
+ wall_time=1,
+ step=10,
+ encoded_image_string=b"big",
+ width=400,
+ height=300,
+ )
+ gen.AddImage(
+ "im2",
+ wall_time=2,
+ step=12,
+ encoded_image_string=b"small",
+ width=40,
+ height=30,
+ )
+ acc.Reload()
+ self.assertEqual(acc.Images("im1"), [im1])
+ self.assertEqual(acc.Images("im2"), [im2])
+
+ def testAudio(self):
+ """Tests 2 audio events inserted/accessed in EventAccumulator."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ snd1 = ea.AudioEvent(
+ wall_time=1,
+ step=10,
+ encoded_audio_string=b"big",
+ content_type="audio/wav",
+ sample_rate=44100,
+ length_frames=441000,
+ )
+ snd2 = ea.AudioEvent(
+ wall_time=2,
+ step=12,
+ encoded_audio_string=b"small",
+ content_type="audio/wav",
+ sample_rate=44100,
+ length_frames=44100,
+ )
+ gen.AddAudio(
+ "snd1",
+ wall_time=1,
+ step=10,
+ encoded_audio_string=b"big",
+ content_type="audio/wav",
+ sample_rate=44100,
+ length_frames=441000,
+ )
+ gen.AddAudio(
+ "snd2",
+ wall_time=2,
+ step=12,
+ encoded_audio_string=b"small",
+ content_type="audio/wav",
+ sample_rate=44100,
+ length_frames=44100,
+ )
+ acc.Reload()
+ self.assertEqual(acc.Audio("snd1"), [snd1])
+ self.assertEqual(acc.Audio("snd2"), [snd2])
+
+ def testKeyError(self):
+ """KeyError should be raised when accessing non-existing keys."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ acc.Reload()
+ with self.assertRaises(KeyError):
+ acc.Scalars("s1")
+ with self.assertRaises(KeyError):
+ acc.Scalars("hst1")
+ with self.assertRaises(KeyError):
+ acc.Scalars("im1")
+ with self.assertRaises(KeyError):
+ acc.Histograms("s1")
+ with self.assertRaises(KeyError):
+ acc.Histograms("im1")
+ with self.assertRaises(KeyError):
+ acc.Images("s1")
+ with self.assertRaises(KeyError):
+ acc.Images("hst1")
+ with self.assertRaises(KeyError):
+ acc.Audio("s1")
+ with self.assertRaises(KeyError):
+ acc.Audio("hst1")
+
+ def testNonValueEvents(self):
+ """Non-value events in the generator don't cause early exits."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ gen.AddScalar("s1", wall_time=1, step=10, value=20)
+ gen.AddEvent(
+ event_pb2.Event(wall_time=2, step=20, file_version="nots2")
+ )
+ gen.AddScalar("s3", wall_time=3, step=100, value=1)
+ gen.AddHistogram("hst1")
+ gen.AddImage("im1")
+ gen.AddAudio("snd1")
+
+ acc.Reload()
+ self.assertTagsEqual(
+ acc.Tags(),
+ {
+ ea.IMAGES: ["im1"],
+ ea.AUDIO: ["snd1"],
+ ea.SCALARS: ["s1", "s3"],
+ ea.HISTOGRAMS: ["hst1"],
+ ea.COMPRESSED_HISTOGRAMS: ["hst1"],
+ },
+ )
+
+ def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
+ """Tests that events are discarded after a restart is detected.
+
+ If a step value is observed to be lower than what was previously seen,
+ this should force a discard of all previous items with the same tag
+ that are outdated.
+
+ Only file versions < 2 use this out-of-order discard logic. Later versions
+ discard events based on the step value of SessionLog.START.
+ """
+ warnings = []
+ self.stubs.Set(logger, "warn", warnings.append)
+
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+
+ gen.AddEvent(
+ event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1")
+ )
+ gen.AddScalar("s1", wall_time=1, step=100, value=20)
+ gen.AddScalar("s1", wall_time=1, step=200, value=20)
+ gen.AddScalar("s1", wall_time=1, step=300, value=20)
+ acc.Reload()
+ ## Check that number of items are what they should be
+ self.assertEqual([x.step for x in acc.Scalars("s1")], [100, 200, 300])
+
+ gen.AddScalar("s1", wall_time=1, step=101, value=20)
+ gen.AddScalar("s1", wall_time=1, step=201, value=20)
+ gen.AddScalar("s1", wall_time=1, step=301, value=20)
+ acc.Reload()
+ ## Check that we have discarded 200 and 300 from s1
+ self.assertEqual(
+ [x.step for x in acc.Scalars("s1")], [100, 101, 201, 301]
+ )
+
+ def testOrphanedDataNotDiscardedIfFlagUnset(self):
+ """Tests that events are not discarded if purge_orphaned_data is
+ false."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen, purge_orphaned_data=False)
+
+ gen.AddEvent(
+ event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1")
+ )
+ gen.AddScalar("s1", wall_time=1, step=100, value=20)
+ gen.AddScalar("s1", wall_time=1, step=200, value=20)
+ gen.AddScalar("s1", wall_time=1, step=300, value=20)
+ acc.Reload()
+ ## Check that number of items are what they should be
+ self.assertEqual([x.step for x in acc.Scalars("s1")], [100, 200, 300])
+
+ gen.AddScalar("s1", wall_time=1, step=101, value=20)
+ gen.AddScalar("s1", wall_time=1, step=201, value=20)
+ gen.AddScalar("s1", wall_time=1, step=301, value=20)
+ acc.Reload()
+ ## Check that we have NOT discarded 200 and 300 from s1
+ self.assertEqual(
+ [x.step for x in acc.Scalars("s1")], [100, 200, 300, 101, 201, 301]
+ )
+
+ def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
+        """Tests that event discards after a restart only affect the misordered
+        tag.
+
+ If a step value is observed to be lower than what was previously seen,
+ this should force a discard of all previous items that are outdated, but
+ only for the out of order tag. Other tags should remain unaffected.
+
+ Only file versions < 2 use this out-of-order discard logic. Later versions
+ discard events based on the step value of SessionLog.START.
+ """
+ warnings = []
+ self.stubs.Set(logger, "warn", warnings.append)
+
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+
+ gen.AddEvent(
+ event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1")
+ )
+ gen.AddScalar("s1", wall_time=1, step=100, value=20)
+ gen.AddScalar("s1", wall_time=1, step=200, value=20)
+ gen.AddScalar("s1", wall_time=1, step=300, value=20)
+ gen.AddScalar("s1", wall_time=1, step=101, value=20)
+ gen.AddScalar("s1", wall_time=1, step=201, value=20)
+ gen.AddScalar("s1", wall_time=1, step=301, value=20)
+
+ gen.AddScalar("s2", wall_time=1, step=101, value=20)
+ gen.AddScalar("s2", wall_time=1, step=201, value=20)
+ gen.AddScalar("s2", wall_time=1, step=301, value=20)
+
+ acc.Reload()
+ ## Check that we have discarded 200 and 300
+ self.assertEqual(
+ [x.step for x in acc.Scalars("s1")], [100, 101, 201, 301]
+ )
+
+ ## Check that s1 discards do not affect s2
+ ## i.e. check that only events from the out of order tag are discarded
+ self.assertEqual([x.step for x in acc.Scalars("s2")], [101, 201, 301])
+
+ def testOnlySummaryEventsTriggerDiscards(self):
+ """Test that file version event does not trigger data purge."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ gen.AddScalar("s1", wall_time=1, step=100, value=20)
+ ev1 = event_pb2.Event(wall_time=2, step=0, file_version="brain.Event:1")
+ graph_bytes = tf.compat.v1.GraphDef().SerializeToString()
+ ev2 = event_pb2.Event(wall_time=3, step=0, graph_def=graph_bytes)
+ gen.AddEvent(ev1)
+ gen.AddEvent(ev2)
+ acc.Reload()
+ self.assertEqual([x.step for x in acc.Scalars("s1")], [100])
+
+ def testSessionLogStartMessageDiscardsExpiredEvents(self):
+ """Test that SessionLog.START message discards expired events.
+
+ This discard logic is preferred over the out-of-order step
+ discard logic, but this logic can only be used for event protos
+ which have the SessionLog enum, which was introduced to
+ event.proto for file_version >= brain.Event:2.
+ """
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ gen.AddEvent(
+ event_pb2.Event(wall_time=0, step=1, file_version="brain.Event:2")
+ )
+
+ gen.AddScalar("s1", wall_time=1, step=100, value=20)
+ gen.AddScalar("s1", wall_time=1, step=200, value=20)
+ gen.AddScalar("s1", wall_time=1, step=300, value=20)
+ gen.AddScalar("s1", wall_time=1, step=400, value=20)
+
+ gen.AddScalar("s2", wall_time=1, step=202, value=20)
+ gen.AddScalar("s2", wall_time=1, step=203, value=20)
+
+ slog = event_pb2.SessionLog(status=event_pb2.SessionLog.START)
+ gen.AddEvent(event_pb2.Event(wall_time=2, step=201, session_log=slog))
+ acc.Reload()
+ self.assertEqual([x.step for x in acc.Scalars("s1")], [100, 200])
+ self.assertEqual([x.step for x in acc.Scalars("s2")], [])
+
+ def testFirstEventTimestamp(self):
+ """Test that FirstEventTimestamp() returns wall_time of the first
+ event."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ gen.AddEvent(
+ event_pb2.Event(wall_time=10, step=20, file_version="brain.Event:2")
+ )
+ gen.AddScalar("s1", wall_time=30, step=40, value=20)
+ self.assertEqual(acc.FirstEventTimestamp(), 10)
+
+ def testReloadPopulatesFirstEventTimestamp(self):
+ """Test that Reload() means FirstEventTimestamp() won't load events."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ gen.AddEvent(
+ event_pb2.Event(wall_time=1, step=2, file_version="brain.Event:2")
+ )
+
+ acc.Reload()
+
+ def _Die(*args, **kwargs): # pylint: disable=unused-argument
+ raise RuntimeError("Load() should not be called")
+
+ self.stubs.Set(gen, "Load", _Die)
+ self.assertEqual(acc.FirstEventTimestamp(), 1)
+
+ def testFirstEventTimestampLoadsEvent(self):
+ """Test that FirstEventTimestamp() doesn't discard the loaded event."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ gen.AddEvent(
+ event_pb2.Event(wall_time=1, step=2, file_version="brain.Event:2")
+ )
+
+ self.assertEqual(acc.FirstEventTimestamp(), 1)
+ acc.Reload()
+ self.assertEqual(acc.file_version, 2.0)
+
+ def testTFSummaryScalar(self):
+ """Verify processing of tf.summary.scalar."""
+ event_sink = _EventGenerator(self, zero_out_timestamps=True)
+ with test_util.FileWriterCache.get(self.get_temp_dir()) as writer:
+ writer.event_writer = event_sink
+ with self.test_session() as sess:
+ ipt = tf.compat.v1.placeholder(tf.float32)
+ tf.compat.v1.summary.scalar("scalar1", ipt)
+ tf.compat.v1.summary.scalar("scalar2", ipt * ipt)
+ merged = tf.compat.v1.summary.merge_all()
+ writer.add_graph(sess.graph)
+ for i in xrange(10):
+ summ = sess.run(merged, feed_dict={ipt: i})
+ writer.add_summary(summ, global_step=i)
+
+ accumulator = ea.EventAccumulator(event_sink)
+ accumulator.Reload()
+
+ seq1 = [
+ ea.ScalarEvent(wall_time=0, step=i, value=i) for i in xrange(10)
+ ]
+ seq2 = [
+ ea.ScalarEvent(wall_time=0, step=i, value=i * i) for i in xrange(10)
+ ]
+
+ self.assertTagsEqual(
+ accumulator.Tags(),
+ {
+ ea.SCALARS: ["scalar1", "scalar2"],
+ ea.GRAPH: True,
+ ea.META_GRAPH: False,
+ },
+ )
+
+ self.assertEqual(accumulator.Scalars("scalar1"), seq1)
+ self.assertEqual(accumulator.Scalars("scalar2"), seq2)
+ first_value = accumulator.Scalars("scalar1")[0].value
+ self.assertTrue(isinstance(first_value, float))
+
+ def testTFSummaryImage(self):
+ """Verify processing of tf.summary.image."""
+ event_sink = _EventGenerator(self, zero_out_timestamps=True)
+ with test_util.FileWriterCache.get(self.get_temp_dir()) as writer:
+ writer.event_writer = event_sink
+ with self.test_session() as sess:
+ ipt = tf.ones([10, 4, 4, 3], tf.uint8)
+                # This is an interesting example, because the old tf.image_summary op
+                # would throw an error here due to tag reuse.
+                # Using the tf node name instead allows the same tag argument to be
+                # reused for each image summary.
+ with tf.name_scope("1"):
+ tf.compat.v1.summary.image("images", ipt, max_outputs=1)
+ with tf.name_scope("2"):
+ tf.compat.v1.summary.image("images", ipt, max_outputs=2)
+ with tf.name_scope("3"):
+ tf.compat.v1.summary.image("images", ipt, max_outputs=3)
+ merged = tf.compat.v1.summary.merge_all()
+ writer.add_graph(sess.graph)
+ for i in xrange(10):
+ summ = sess.run(merged)
+ writer.add_summary(summ, global_step=i)
+
+ accumulator = ea.EventAccumulator(event_sink)
+ accumulator.Reload()
+
+ tags = [
+ u"1/images/image",
+ u"2/images/image/0",
+ u"2/images/image/1",
+ u"3/images/image/0",
+ u"3/images/image/1",
+ u"3/images/image/2",
+ ]
+
+ self.assertTagsEqual(
+ accumulator.Tags(),
+ {ea.IMAGES: tags, ea.GRAPH: True, ea.META_GRAPH: False,},
+ )
+
+ def testTFSummaryTensor(self):
+ """Verify processing of tf.summary.tensor."""
+ event_sink = _EventGenerator(self, zero_out_timestamps=True)
+ with test_util.FileWriterCache.get(self.get_temp_dir()) as writer:
+ writer.event_writer = event_sink
+ with self.test_session() as sess:
+ tf.compat.v1.summary.tensor_summary("scalar", tf.constant(1.0))
+ tf.compat.v1.summary.tensor_summary(
+ "vector", tf.constant([1.0, 2.0, 3.0])
+ )
+ tf.compat.v1.summary.tensor_summary(
+ "string", tf.constant(six.b("foobar"))
+ )
+ merged = tf.compat.v1.summary.merge_all()
+ summ = sess.run(merged)
+ writer.add_summary(summ, 0)
+
+ accumulator = ea.EventAccumulator(event_sink)
+ accumulator.Reload()
+
+ self.assertTagsEqual(
+ accumulator.Tags(), {ea.TENSORS: ["scalar", "vector", "string"],}
+ )
+
+ scalar_proto = accumulator.Tensors("scalar")[0].tensor_proto
+ scalar = tf.compat.v1.make_ndarray(scalar_proto)
+ vector_proto = accumulator.Tensors("vector")[0].tensor_proto
+ vector = tf.compat.v1.make_ndarray(vector_proto)
+ string_proto = accumulator.Tensors("string")[0].tensor_proto
+ string = tf.compat.v1.make_ndarray(string_proto)
+
+ self.assertTrue(np.array_equal(scalar, 1.0))
+ self.assertTrue(np.array_equal(vector, [1.0, 2.0, 3.0]))
+ self.assertTrue(np.array_equal(string, six.b("foobar")))
- def testScalarsRealistically(self):
- """Test accumulator by writing values and then reading them."""
-
- def FakeScalarSummary(tag, value):
- value = summary_pb2.Summary.Value(tag=tag, simple_value=value)
- summary = summary_pb2.Summary(value=[value])
- return summary
-
- directory = os.path.join(self.get_temp_dir(), 'values_dir')
- if tf.io.gfile.isdir(directory):
- tf.io.gfile.rmtree(directory)
- tf.io.gfile.mkdir(directory)
-
- writer = test_util.FileWriter(directory, max_queue=100)
-
- with tf.Graph().as_default() as graph:
- _ = tf.constant([2.0, 1.0])
- # Add a graph to the summary writer.
- writer.add_graph(graph)
- meta_graph_def = tf.compat.v1.train.export_meta_graph(graph_def=graph.as_graph_def(
- add_shapes=True))
- writer.add_meta_graph(meta_graph_def)
-
- run_metadata = config_pb2.RunMetadata()
- device_stats = run_metadata.step_stats.dev_stats.add()
- device_stats.device = 'test device'
- writer.add_run_metadata(run_metadata, 'test run')
-
- # Write a bunch of events using the writer.
- for i in xrange(30):
- summ_id = FakeScalarSummary('id', i)
- summ_sq = FakeScalarSummary('sq', i * i)
- writer.add_summary(summ_id, i * 5)
- writer.add_summary(summ_sq, i * 5)
- writer.flush()
-
- # Verify that we can load those events properly
- acc = ea.EventAccumulator(directory)
- acc.Reload()
- self.assertTagsEqual(acc.Tags(), {
- ea.SCALARS: ['id', 'sq'],
- ea.GRAPH: True,
- ea.META_GRAPH: True,
- ea.RUN_METADATA: ['test run'],
- })
- id_events = acc.Scalars('id')
- sq_events = acc.Scalars('sq')
- self.assertEqual(30, len(id_events))
- self.assertEqual(30, len(sq_events))
- for i in xrange(30):
- self.assertEqual(i * 5, id_events[i].step)
- self.assertEqual(i * 5, sq_events[i].step)
- self.assertEqual(i, id_events[i].value)
- self.assertEqual(i * i, sq_events[i].value)
-
- # Write a few more events to test incremental reloading
- for i in xrange(30, 40):
- summ_id = FakeScalarSummary('id', i)
- summ_sq = FakeScalarSummary('sq', i * i)
- writer.add_summary(summ_id, i * 5)
- writer.add_summary(summ_sq, i * 5)
- writer.flush()
-
- # Verify we can now see all of the data
- acc.Reload()
- id_events = acc.Scalars('id')
- sq_events = acc.Scalars('sq')
- self.assertEqual(40, len(id_events))
- self.assertEqual(40, len(sq_events))
- for i in xrange(40):
- self.assertEqual(i * 5, id_events[i].step)
- self.assertEqual(i * 5, sq_events[i].step)
- self.assertEqual(i, id_events[i].value)
- self.assertEqual(i * i, sq_events[i].value)
-
- expected_graph_def = graph_pb2.GraphDef.FromString(
- graph.as_graph_def(add_shapes=True).SerializeToString())
- self.assertProtoEquals(expected_graph_def, acc.Graph())
-
- expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString(
- meta_graph_def.SerializeToString())
- self.assertProtoEquals(expected_meta_graph, acc.MetaGraph())
-
- def testGraphFromMetaGraphBecomesAvailable(self):
- """Test accumulator by writing values and then reading them."""
-
- directory = os.path.join(self.get_temp_dir(), 'metagraph_test_values_dir')
- if tf.io.gfile.isdir(directory):
- tf.io.gfile.rmtree(directory)
- tf.io.gfile.mkdir(directory)
-
- writer = test_util.FileWriter(directory, max_queue=100)
-
- with tf.Graph().as_default() as graph:
- _ = tf.constant([2.0, 1.0])
- # Add a graph to the summary writer.
- meta_graph_def = tf.compat.v1.train.export_meta_graph(graph_def=graph.as_graph_def(
- add_shapes=True))
- writer.add_meta_graph(meta_graph_def)
-
- writer.flush()
-
- # Verify that we can load those events properly
- acc = ea.EventAccumulator(directory)
- acc.Reload()
- self.assertTagsEqual(acc.Tags(), {
- ea.GRAPH: True,
- ea.META_GRAPH: True,
- })
-
- expected_graph_def = graph_pb2.GraphDef.FromString(
- graph.as_graph_def(add_shapes=True).SerializeToString())
- self.assertProtoEquals(expected_graph_def, acc.Graph())
-
- expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString(
- meta_graph_def.SerializeToString())
- self.assertProtoEquals(expected_meta_graph, acc.MetaGraph())
-
- def _writeMetadata(self, logdir, summary_metadata, nonce=''):
- """Write to disk a summary with the given metadata.
-
- Arguments:
- logdir: a string
- summary_metadata: a `SummaryMetadata` protobuf object
- nonce: optional; will be added to the end of the event file name
- to guarantee that multiple calls to this function do not stomp the
- same file
- """
- summary = summary_pb2.Summary()
- summary.value.add(
- tensor=tensor_util.make_tensor_proto(['po', 'ta', 'to'], dtype=tf.string),
- tag='you_are_it',
- metadata=summary_metadata)
- writer = test_util.FileWriter(logdir, filename_suffix=nonce)
- writer.add_summary(summary.SerializeToString())
- writer.close()
-
- def testSummaryMetadata(self):
- logdir = self.get_temp_dir()
- summary_metadata = summary_pb2.SummaryMetadata(
- display_name='current tagee', summary_description='no')
- summary_metadata.plugin_data.plugin_name = 'outlet'
- self._writeMetadata(logdir, summary_metadata)
- acc = ea.EventAccumulator(logdir)
- acc.Reload()
- self.assertProtoEquals(summary_metadata,
- acc.SummaryMetadata('you_are_it'))
-
- def testSummaryMetadata_FirstMetadataWins(self):
- logdir = self.get_temp_dir()
- summary_metadata_1 = summary_pb2.SummaryMetadata(
- display_name='current tagee',
- summary_description='no',
- plugin_data=summary_pb2.SummaryMetadata.PluginData(
- plugin_name='outlet', content=b'120v'))
- self._writeMetadata(logdir, summary_metadata_1, nonce='1')
- acc = ea.EventAccumulator(logdir)
- acc.Reload()
- summary_metadata_2 = summary_pb2.SummaryMetadata(
- display_name='tagee of the future',
- summary_description='definitely not',
- plugin_data=summary_pb2.SummaryMetadata.PluginData(
- plugin_name='plug', content=b'110v'))
- self._writeMetadata(logdir, summary_metadata_2, nonce='2')
- acc.Reload()
-
- self.assertProtoEquals(summary_metadata_1,
- acc.SummaryMetadata('you_are_it'))
-
- def testPluginTagToContent_PluginsCannotJumpOnTheBandwagon(self):
- # If there are multiple `SummaryMetadata` for a given tag, and the
- # set of plugins in the `plugin_data` of second is different from
- # that of the first, then the second set should be ignored.
- logdir = self.get_temp_dir()
- summary_metadata_1 = summary_pb2.SummaryMetadata(
- display_name='current tagee',
- summary_description='no',
- plugin_data=summary_pb2.SummaryMetadata.PluginData(
- plugin_name='outlet', content=b'120v'))
- self._writeMetadata(logdir, summary_metadata_1, nonce='1')
- acc = ea.EventAccumulator(logdir)
- acc.Reload()
- summary_metadata_2 = summary_pb2.SummaryMetadata(
- display_name='tagee of the future',
- summary_description='definitely not',
- plugin_data=summary_pb2.SummaryMetadata.PluginData(
- plugin_name='plug', content=b'110v'))
- self._writeMetadata(logdir, summary_metadata_2, nonce='2')
- acc.Reload()
-
- self.assertEqual(acc.PluginTagToContent('outlet'),
- {'you_are_it': b'120v'})
- with six.assertRaisesRegex(self, KeyError, 'plug'):
- acc.PluginTagToContent('plug')
-
-if __name__ == '__main__':
- tf.test.main()
+class RealisticEventAccumulatorTest(EventAccumulatorTest):
+ def testScalarsRealistically(self):
+ """Test accumulator by writing values and then reading them."""
+
+ def FakeScalarSummary(tag, value):
+ value = summary_pb2.Summary.Value(tag=tag, simple_value=value)
+ summary = summary_pb2.Summary(value=[value])
+ return summary
+
+ directory = os.path.join(self.get_temp_dir(), "values_dir")
+ if tf.io.gfile.isdir(directory):
+ tf.io.gfile.rmtree(directory)
+ tf.io.gfile.mkdir(directory)
+
+ writer = test_util.FileWriter(directory, max_queue=100)
+
+ with tf.Graph().as_default() as graph:
+ _ = tf.constant([2.0, 1.0])
+ # Add a graph to the summary writer.
+ writer.add_graph(graph)
+ meta_graph_def = tf.compat.v1.train.export_meta_graph(
+ graph_def=graph.as_graph_def(add_shapes=True)
+ )
+ writer.add_meta_graph(meta_graph_def)
+
+ run_metadata = config_pb2.RunMetadata()
+ device_stats = run_metadata.step_stats.dev_stats.add()
+ device_stats.device = "test device"
+ writer.add_run_metadata(run_metadata, "test run")
+
+ # Write a bunch of events using the writer.
+ for i in xrange(30):
+ summ_id = FakeScalarSummary("id", i)
+ summ_sq = FakeScalarSummary("sq", i * i)
+ writer.add_summary(summ_id, i * 5)
+ writer.add_summary(summ_sq, i * 5)
+ writer.flush()
+
+ # Verify that we can load those events properly
+ acc = ea.EventAccumulator(directory)
+ acc.Reload()
+ self.assertTagsEqual(
+ acc.Tags(),
+ {
+ ea.SCALARS: ["id", "sq"],
+ ea.GRAPH: True,
+ ea.META_GRAPH: True,
+ ea.RUN_METADATA: ["test run"],
+ },
+ )
+ id_events = acc.Scalars("id")
+ sq_events = acc.Scalars("sq")
+ self.assertEqual(30, len(id_events))
+ self.assertEqual(30, len(sq_events))
+ for i in xrange(30):
+ self.assertEqual(i * 5, id_events[i].step)
+ self.assertEqual(i * 5, sq_events[i].step)
+ self.assertEqual(i, id_events[i].value)
+ self.assertEqual(i * i, sq_events[i].value)
+
+ # Write a few more events to test incremental reloading
+ for i in xrange(30, 40):
+ summ_id = FakeScalarSummary("id", i)
+ summ_sq = FakeScalarSummary("sq", i * i)
+ writer.add_summary(summ_id, i * 5)
+ writer.add_summary(summ_sq, i * 5)
+ writer.flush()
+
+ # Verify we can now see all of the data
+ acc.Reload()
+ id_events = acc.Scalars("id")
+ sq_events = acc.Scalars("sq")
+ self.assertEqual(40, len(id_events))
+ self.assertEqual(40, len(sq_events))
+ for i in xrange(40):
+ self.assertEqual(i * 5, id_events[i].step)
+ self.assertEqual(i * 5, sq_events[i].step)
+ self.assertEqual(i, id_events[i].value)
+ self.assertEqual(i * i, sq_events[i].value)
+
+ expected_graph_def = graph_pb2.GraphDef.FromString(
+ graph.as_graph_def(add_shapes=True).SerializeToString()
+ )
+ self.assertProtoEquals(expected_graph_def, acc.Graph())
+
+ expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString(
+ meta_graph_def.SerializeToString()
+ )
+ self.assertProtoEquals(expected_meta_graph, acc.MetaGraph())
+
+ def testGraphFromMetaGraphBecomesAvailable(self):
+ """Test accumulator by writing values and then reading them."""
+
+ directory = os.path.join(
+ self.get_temp_dir(), "metagraph_test_values_dir"
+ )
+ if tf.io.gfile.isdir(directory):
+ tf.io.gfile.rmtree(directory)
+ tf.io.gfile.mkdir(directory)
+
+ writer = test_util.FileWriter(directory, max_queue=100)
+
+ with tf.Graph().as_default() as graph:
+ _ = tf.constant([2.0, 1.0])
+ # Add a graph to the summary writer.
+ meta_graph_def = tf.compat.v1.train.export_meta_graph(
+ graph_def=graph.as_graph_def(add_shapes=True)
+ )
+ writer.add_meta_graph(meta_graph_def)
+
+ writer.flush()
+
+ # Verify that we can load those events properly
+ acc = ea.EventAccumulator(directory)
+ acc.Reload()
+ self.assertTagsEqual(acc.Tags(), {ea.GRAPH: True, ea.META_GRAPH: True,})
+
+ expected_graph_def = graph_pb2.GraphDef.FromString(
+ graph.as_graph_def(add_shapes=True).SerializeToString()
+ )
+ self.assertProtoEquals(expected_graph_def, acc.Graph())
+
+ expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString(
+ meta_graph_def.SerializeToString()
+ )
+ self.assertProtoEquals(expected_meta_graph, acc.MetaGraph())
+
+ def _writeMetadata(self, logdir, summary_metadata, nonce=""):
+ """Write to disk a summary with the given metadata.
+
+ Arguments:
+ logdir: a string
+ summary_metadata: a `SummaryMetadata` protobuf object
+ nonce: optional; will be added to the end of the event file name
+ to guarantee that multiple calls to this function do not stomp the
+ same file
+ """
+
+ summary = summary_pb2.Summary()
+ summary.value.add(
+ tensor=tensor_util.make_tensor_proto(
+ ["po", "ta", "to"], dtype=tf.string
+ ),
+ tag="you_are_it",
+ metadata=summary_metadata,
+ )
+ writer = test_util.FileWriter(logdir, filename_suffix=nonce)
+ writer.add_summary(summary.SerializeToString())
+ writer.close()
+
+ def testSummaryMetadata(self):
+ logdir = self.get_temp_dir()
+ summary_metadata = summary_pb2.SummaryMetadata(
+ display_name="current tagee", summary_description="no"
+ )
+ summary_metadata.plugin_data.plugin_name = "outlet"
+ self._writeMetadata(logdir, summary_metadata)
+ acc = ea.EventAccumulator(logdir)
+ acc.Reload()
+ self.assertProtoEquals(
+ summary_metadata, acc.SummaryMetadata("you_are_it")
+ )
+
+ def testSummaryMetadata_FirstMetadataWins(self):
+ logdir = self.get_temp_dir()
+ summary_metadata_1 = summary_pb2.SummaryMetadata(
+ display_name="current tagee",
+ summary_description="no",
+ plugin_data=summary_pb2.SummaryMetadata.PluginData(
+ plugin_name="outlet", content=b"120v"
+ ),
+ )
+ self._writeMetadata(logdir, summary_metadata_1, nonce="1")
+ acc = ea.EventAccumulator(logdir)
+ acc.Reload()
+ summary_metadata_2 = summary_pb2.SummaryMetadata(
+ display_name="tagee of the future",
+ summary_description="definitely not",
+ plugin_data=summary_pb2.SummaryMetadata.PluginData(
+ plugin_name="plug", content=b"110v"
+ ),
+ )
+ self._writeMetadata(logdir, summary_metadata_2, nonce="2")
+ acc.Reload()
+
+ self.assertProtoEquals(
+ summary_metadata_1, acc.SummaryMetadata("you_are_it")
+ )
+
+ def testPluginTagToContent_PluginsCannotJumpOnTheBandwagon(self):
+ # If there are multiple `SummaryMetadata` for a given tag, and the
+        # set of plugins in the `plugin_data` of the second is different from
+ # that of the first, then the second set should be ignored.
+ logdir = self.get_temp_dir()
+ summary_metadata_1 = summary_pb2.SummaryMetadata(
+ display_name="current tagee",
+ summary_description="no",
+ plugin_data=summary_pb2.SummaryMetadata.PluginData(
+ plugin_name="outlet", content=b"120v"
+ ),
+ )
+ self._writeMetadata(logdir, summary_metadata_1, nonce="1")
+ acc = ea.EventAccumulator(logdir)
+ acc.Reload()
+ summary_metadata_2 = summary_pb2.SummaryMetadata(
+ display_name="tagee of the future",
+ summary_description="definitely not",
+ plugin_data=summary_pb2.SummaryMetadata.PluginData(
+ plugin_name="plug", content=b"110v"
+ ),
+ )
+ self._writeMetadata(logdir, summary_metadata_2, nonce="2")
+ acc.Reload()
+
+ self.assertEqual(
+ acc.PluginTagToContent("outlet"), {"you_are_it": b"120v"}
+ )
+ with six.assertRaisesRegex(self, KeyError, "plug"):
+ acc.PluginTagToContent("plug")
+
+
+if __name__ == "__main__":
+ tf.test.main()
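Aside: the reformatted tests above all drive ea.EventAccumulator through a fake in-memory
generator rather than real event files. Below is a minimal sketch of that pattern, kept
outside the patch itself; the FakeGenerator class is illustrative (the real tests use the
fuller _EventGenerator shown in the diff), and the import paths are assumed to mirror the
test module's own imports.

    from tensorboard.backend.event_processing import event_accumulator as ea
    from tensorboard.compat.proto import event_pb2
    from tensorboard.compat.proto import summary_pb2


    class FakeGenerator(object):
        """Illustrative stand-in for the test file's _EventGenerator."""

        def __init__(self):
            self.items = []

        def Load(self):
            # Drain queued events, as the accumulator's generator API expects.
            while self.items:
                yield self.items.pop(0)


    gen = FakeGenerator()
    gen.items.append(
        event_pb2.Event(
            wall_time=1,
            step=10,
            summary=summary_pb2.Summary(
                value=[summary_pb2.Summary.Value(tag="s1", simple_value=32.0)]
            ),
        )
    )
    # As in MockingEventAccumulatorTest above, stub ea._GeneratorFromPath so
    # the accumulator reads from the fake generator instead of a logdir.
    ea._GeneratorFromPath = lambda path: gen
    acc = ea.EventAccumulator(gen)
    acc.Reload()
    print(acc.Scalars("s1"))  # [ScalarEvent(wall_time=1.0, step=10, value=32.0)]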
diff --git a/tensorboard/backend/event_processing/event_file_inspector.py b/tensorboard/backend/event_processing/event_file_inspector.py
index 5504d671a5..284970944e 100644
--- a/tensorboard/backend/event_processing/event_file_inspector.py
+++ b/tensorboard/backend/event_processing/event_file_inspector.py
@@ -125,31 +125,34 @@
# Map of field names within summary.proto to the user-facing names that this
# script outputs.
-SUMMARY_TYPE_TO_FIELD = {'simple_value': 'scalars',
- 'histo': 'histograms',
- 'image': 'images',
- 'audio': 'audio'}
+SUMMARY_TYPE_TO_FIELD = {
+ "simple_value": "scalars",
+ "histo": "histograms",
+ "image": "images",
+ "audio": "audio",
+}
for summary_type in event_accumulator.SUMMARY_TYPES:
- if summary_type not in SUMMARY_TYPE_TO_FIELD:
- SUMMARY_TYPE_TO_FIELD[summary_type] = summary_type
+ if summary_type not in SUMMARY_TYPE_TO_FIELD:
+ SUMMARY_TYPE_TO_FIELD[summary_type] = summary_type
# Types of summaries that we may want to query for by tag.
TAG_FIELDS = list(SUMMARY_TYPE_TO_FIELD.values())
# Summaries that we want to see every instance of.
-LONG_FIELDS = ['sessionlog:start', 'sessionlog:stop']
+LONG_FIELDS = ["sessionlog:start", "sessionlog:stop"]
# Summaries that we only want an abridged digest of, since they would
# take too much screen real estate otherwise.
-SHORT_FIELDS = ['graph', 'sessionlog:checkpoint'] + TAG_FIELDS
+SHORT_FIELDS = ["graph", "sessionlog:checkpoint"] + TAG_FIELDS
# All summary types that we can inspect.
TRACKED_FIELDS = SHORT_FIELDS + LONG_FIELDS
# An `Observation` contains the data within each Event file that the inspector
# cares about. The inspector accumulates Observations as it processes events.
-Observation = collections.namedtuple('Observation', ['step', 'wall_time',
- 'tag'])
+Observation = collections.namedtuple(
+ "Observation", ["step", "wall_time", "tag"]
+)
# An InspectionUnit is created for each organizational structure in the event
# files visible in the final terminal output. For instance, one InspectionUnit
@@ -159,259 +162,286 @@
# The InspectionUnit contains the `name` of the organizational unit that will be
# printed to console, a `generator` that yields `Event` protos, and a mapping
# from string fields to `Observations` that the inspector creates.
-InspectionUnit = collections.namedtuple('InspectionUnit', ['name', 'generator',
- 'field_to_obs'])
-
-PRINT_SEPARATOR = '=' * 70 + '\n'
-
-
-def get_field_to_observations_map(generator, query_for_tag=''):
- """Return a field to `Observations` dict for the event generator.
-
- Args:
- generator: A generator over event protos.
- query_for_tag: A string that if specified, only create observations for
- events with this tag name.
-
- Returns:
- A dict mapping keys in `TRACKED_FIELDS` to an `Observation` list.
- """
-
- def increment(stat, event, tag=''):
- assert stat in TRACKED_FIELDS
- field_to_obs[stat].append(Observation(step=event.step,
- wall_time=event.wall_time,
- tag=tag)._asdict())
-
- field_to_obs = dict([(t, []) for t in TRACKED_FIELDS])
-
- for event in generator:
- ## Process the event
- if event.HasField('graph_def') and (not query_for_tag):
- increment('graph', event)
- if event.HasField('session_log') and (not query_for_tag):
- status = event.session_log.status
- if status == event_pb2.SessionLog.START:
- increment('sessionlog:start', event)
- elif status == event_pb2.SessionLog.STOP:
- increment('sessionlog:stop', event)
- elif status == event_pb2.SessionLog.CHECKPOINT:
- increment('sessionlog:checkpoint', event)
- elif event.HasField('summary'):
- for value in event.summary.value:
- if query_for_tag and value.tag != query_for_tag:
- continue
-
- for proto_name, display_name in SUMMARY_TYPE_TO_FIELD.items():
- if value.HasField(proto_name):
- increment(display_name, event, value.tag)
- return field_to_obs
+InspectionUnit = collections.namedtuple(
+ "InspectionUnit", ["name", "generator", "field_to_obs"]
+)
+
+PRINT_SEPARATOR = "=" * 70 + "\n"
+
+
+def get_field_to_observations_map(generator, query_for_tag=""):
+ """Return a field to `Observations` dict for the event generator.
+
+ Args:
+ generator: A generator over event protos.
+      query_for_tag: A string that, if specified, restricts observations to
+        events with this tag name.
+
+ Returns:
+ A dict mapping keys in `TRACKED_FIELDS` to an `Observation` list.
+ """
+
+ def increment(stat, event, tag=""):
+ assert stat in TRACKED_FIELDS
+ field_to_obs[stat].append(
+ Observation(
+ step=event.step, wall_time=event.wall_time, tag=tag
+ )._asdict()
+ )
+
+ field_to_obs = dict([(t, []) for t in TRACKED_FIELDS])
+
+ for event in generator:
+ ## Process the event
+ if event.HasField("graph_def") and (not query_for_tag):
+ increment("graph", event)
+ if event.HasField("session_log") and (not query_for_tag):
+ status = event.session_log.status
+ if status == event_pb2.SessionLog.START:
+ increment("sessionlog:start", event)
+ elif status == event_pb2.SessionLog.STOP:
+ increment("sessionlog:stop", event)
+ elif status == event_pb2.SessionLog.CHECKPOINT:
+ increment("sessionlog:checkpoint", event)
+ elif event.HasField("summary"):
+ for value in event.summary.value:
+ if query_for_tag and value.tag != query_for_tag:
+ continue
+
+ for proto_name, display_name in SUMMARY_TYPE_TO_FIELD.items():
+ if value.HasField(proto_name):
+ increment(display_name, event, value.tag)
+ return field_to_obs
def get_unique_tags(field_to_obs):
- """Returns a dictionary of tags that a user could query over.
+ """Returns a dictionary of tags that a user could query over.
- Args:
- field_to_obs: Dict that maps string field to `Observation` list.
+ Args:
+ field_to_obs: Dict that maps string field to `Observation` list.
- Returns:
- A dict that maps keys in `TAG_FIELDS` to a list of string tags present in
- the event files. If the dict does not have any observations of the type,
- maps to an empty list so that we can render this to console.
- """
- return {field: sorted(set([x.get('tag', '') for x in observations]))
- for field, observations in field_to_obs.items()
- if field in TAG_FIELDS}
+ Returns:
+ A dict that maps keys in `TAG_FIELDS` to a list of string tags present in
+ the event files. If the dict does not have any observations of the type,
+ maps to an empty list so that we can render this to console.
+ """
+ return {
+ field: sorted(set([x.get("tag", "") for x in observations]))
+ for field, observations in field_to_obs.items()
+ if field in TAG_FIELDS
+ }
def print_dict(d, show_missing=True):
- """Prints a shallow dict to console.
-
- Args:
- d: Dict to print.
- show_missing: Whether to show keys with empty values.
- """
- for k, v in sorted(d.items()):
- if (not v) and show_missing:
- # No instances of the key, so print missing symbol.
- print('{} -'.format(k))
- elif isinstance(v, list):
- # Value is a list, so print each item of the list.
- print(k)
- for item in v:
- print(' {}'.format(item))
- elif isinstance(v, dict):
- # Value is a dict, so print each (key, value) pair of the dict.
- print(k)
- for kk, vv in sorted(v.items()):
- print(' {:<20} {}'.format(kk, vv))
+ """Prints a shallow dict to console.
+
+ Args:
+ d: Dict to print.
+ show_missing: Whether to show keys with empty values.
+ """
+ for k, v in sorted(d.items()):
+ if (not v) and show_missing:
+ # No instances of the key, so print missing symbol.
+ print("{} -".format(k))
+ elif isinstance(v, list):
+ # Value is a list, so print each item of the list.
+ print(k)
+ for item in v:
+ print(" {}".format(item))
+ elif isinstance(v, dict):
+ # Value is a dict, so print each (key, value) pair of the dict.
+ print(k)
+ for kk, vv in sorted(v.items()):
+ print(" {:<20} {}".format(kk, vv))
def get_dict_to_print(field_to_obs):
- """Transform the field-to-obs mapping into a printable dictionary.
+ """Transform the field-to-obs mapping into a printable dictionary.
- Args:
- field_to_obs: Dict that maps string field to `Observation` list.
+ Args:
+ field_to_obs: Dict that maps string field to `Observation` list.
- Returns:
- A dict with the keys and values to print to console.
- """
+ Returns:
+ A dict with the keys and values to print to console.
+ """
- def compressed_steps(steps):
- return {'num_steps': len(set(steps)),
- 'min_step': min(steps),
- 'max_step': max(steps),
- 'last_step': steps[-1],
- 'first_step': steps[0],
- 'outoforder_steps': get_out_of_order(steps)}
+ def compressed_steps(steps):
+ return {
+ "num_steps": len(set(steps)),
+ "min_step": min(steps),
+ "max_step": max(steps),
+ "last_step": steps[-1],
+ "first_step": steps[0],
+ "outoforder_steps": get_out_of_order(steps),
+ }
- def full_steps(steps):
- return {'steps': steps, 'outoforder_steps': get_out_of_order(steps)}
+ def full_steps(steps):
+ return {"steps": steps, "outoforder_steps": get_out_of_order(steps)}
- output = {}
- for field, observations in field_to_obs.items():
- if not observations:
- output[field] = None
- continue
+ output = {}
+ for field, observations in field_to_obs.items():
+ if not observations:
+ output[field] = None
+ continue
- steps = [x['step'] for x in observations]
- if field in SHORT_FIELDS:
- output[field] = compressed_steps(steps)
- if field in LONG_FIELDS:
- output[field] = full_steps(steps)
+ steps = [x["step"] for x in observations]
+ if field in SHORT_FIELDS:
+ output[field] = compressed_steps(steps)
+ if field in LONG_FIELDS:
+ output[field] = full_steps(steps)
- return output
+ return output
def get_out_of_order(list_of_numbers):
- """Returns elements that break the monotonically non-decreasing trend.
-
- This is used to find instances of global step values that are "out-of-order",
- which may trigger TensorBoard event discarding logic.
-
- Args:
- list_of_numbers: A list of numbers.
-
- Returns:
- A list of tuples in which each tuple are two elements are adjacent, but the
- second element is lower than the first.
- """
- # TODO: Consider changing this to only check for out-of-order
- # steps within a particular tag.
- result = []
- # pylint: disable=consider-using-enumerate
- for i in range(len(list_of_numbers)):
- if i == 0:
- continue
- if list_of_numbers[i] < list_of_numbers[i - 1]:
- result.append((list_of_numbers[i - 1], list_of_numbers[i]))
- return result
+ """Returns elements that break the monotonically non-decreasing trend.
+
+ This is used to find instances of global step values that are "out-of-order",
+ which may trigger TensorBoard event discarding logic.
+
+ Args:
+ list_of_numbers: A list of numbers.
+
+ Returns:
+      A list of tuples, where each tuple holds a pair of adjacent elements in
+      which the second element is lower than the first.
+ """
+ # TODO: Consider changing this to only check for out-of-order
+ # steps within a particular tag.
+ result = []
+ # pylint: disable=consider-using-enumerate
+ for i in range(len(list_of_numbers)):
+ if i == 0:
+ continue
+ if list_of_numbers[i] < list_of_numbers[i - 1]:
+ result.append((list_of_numbers[i - 1], list_of_numbers[i]))
+ return result
def generators_from_logdir(logdir):
- """Returns a list of event generators for subdirectories with event files.
+ """Returns a list of event generators for subdirectories with event files.
- The number of generators returned should equal the number of directories
- within logdir that contain event files. If only logdir contains event files,
- returns a list of length one.
+ The number of generators returned should equal the number of directories
+ within logdir that contain event files. If only logdir contains event files,
+ returns a list of length one.
- Args:
- logdir: A log directory that contains event files.
+ Args:
+ logdir: A log directory that contains event files.
- Returns:
- List of event generators for each subdirectory with event files.
- """
- subdirs = io_wrapper.GetLogdirSubdirectories(logdir)
- generators = [
- itertools.chain(*[
- generator_from_event_file(os.path.join(subdir, f))
- for f in tf.io.gfile.listdir(subdir)
- if io_wrapper.IsTensorFlowEventsFile(os.path.join(subdir, f))
- ]) for subdir in subdirs
- ]
- return generators
+ Returns:
+ List of event generators for each subdirectory with event files.
+ """
+ subdirs = io_wrapper.GetLogdirSubdirectories(logdir)
+ generators = [
+ itertools.chain(
+ *[
+ generator_from_event_file(os.path.join(subdir, f))
+ for f in tf.io.gfile.listdir(subdir)
+ if io_wrapper.IsTensorFlowEventsFile(os.path.join(subdir, f))
+ ]
+ )
+ for subdir in subdirs
+ ]
+ return generators
def generator_from_event_file(event_file):
- """Returns a generator that yields events from an event file."""
- return event_file_loader.EventFileLoader(event_file).Load()
-
-
-def get_inspection_units(logdir='', event_file='', tag=''):
- """Returns a list of InspectionUnit objects given either logdir or event_file.
-
- If logdir is given, the number of InspectionUnits should equal the
- number of directories or subdirectories that contain event files.
-
- If event_file is given, the number of InspectionUnits should be 1.
-
- Args:
- logdir: A log directory that contains event files.
- event_file: Or, a particular event file path.
- tag: An optional tag name to query for.
-
- Returns:
- A list of InspectionUnit objects.
- """
- if logdir:
- subdirs = io_wrapper.GetLogdirSubdirectories(logdir)
- inspection_units = []
- for subdir in subdirs:
- generator = itertools.chain(*[
- generator_from_event_file(os.path.join(subdir, f))
- for f in tf.io.gfile.listdir(subdir)
- if io_wrapper.IsTensorFlowEventsFile(os.path.join(subdir, f))
- ])
- inspection_units.append(InspectionUnit(
- name=subdir,
- generator=generator,
- field_to_obs=get_field_to_observations_map(generator, tag)))
- if inspection_units:
- print('Found event files in:\n{}\n'.format('\n'.join(
- [u.name for u in inspection_units])))
- elif io_wrapper.IsTensorFlowEventsFile(logdir):
- print(
- 'It seems that {} may be an event file instead of a logdir. If this '
- 'is the case, use --event_file instead of --logdir to pass '
- 'it in.'.format(logdir))
- else:
- print('No event files found within logdir {}'.format(logdir))
- return inspection_units
- elif event_file:
- generator = generator_from_event_file(event_file)
- return [InspectionUnit(
- name=event_file,
- generator=generator,
- field_to_obs=get_field_to_observations_map(generator, tag))]
- return []
-
-
-def inspect(logdir='', event_file='', tag=''):
- """Main function for inspector that prints out a digest of event files.
-
- Args:
- logdir: A log directory that contains event files.
- event_file: Or, a particular event file path.
- tag: An optional tag name to query for.
-
- Raises:
- ValueError: If neither logdir and event_file are given, or both are given.
- """
- print(PRINT_SEPARATOR +
- 'Processing event files... (this can take a few minutes)\n' +
- PRINT_SEPARATOR)
- inspection_units = get_inspection_units(logdir, event_file, tag)
-
- for unit in inspection_units:
- if tag:
- print('Event statistics for tag {} in {}:'.format(tag, unit.name))
- else:
- # If the user is not inspecting a particular tag, also print the list of
- # all available tags that they can query.
- print('These tags are in {}:'.format(unit.name))
- print_dict(get_unique_tags(unit.field_to_obs))
- print(PRINT_SEPARATOR)
- print('Event statistics for {}:'.format(unit.name))
-
- print_dict(get_dict_to_print(unit.field_to_obs), show_missing=(not tag))
- print(PRINT_SEPARATOR)
+ """Returns a generator that yields events from an event file."""
+ return event_file_loader.EventFileLoader(event_file).Load()
+
+
+def get_inspection_units(logdir="", event_file="", tag=""):
+ """Returns a list of InspectionUnit objects given either logdir or
+ event_file.
+
+ If logdir is given, the number of InspectionUnits should equal the
+ number of directories or subdirectories that contain event files.
+
+ If event_file is given, the number of InspectionUnits should be 1.
+
+ Args:
+ logdir: A log directory that contains event files.
+ event_file: Or, a particular event file path.
+ tag: An optional tag name to query for.
+
+ Returns:
+ A list of InspectionUnit objects.
+ """
+ if logdir:
+ subdirs = io_wrapper.GetLogdirSubdirectories(logdir)
+ inspection_units = []
+ for subdir in subdirs:
+ generator = itertools.chain(
+ *[
+ generator_from_event_file(os.path.join(subdir, f))
+ for f in tf.io.gfile.listdir(subdir)
+ if io_wrapper.IsTensorFlowEventsFile(
+ os.path.join(subdir, f)
+ )
+ ]
+ )
+ inspection_units.append(
+ InspectionUnit(
+ name=subdir,
+ generator=generator,
+ field_to_obs=get_field_to_observations_map(generator, tag),
+ )
+ )
+ if inspection_units:
+ print(
+ "Found event files in:\n{}\n".format(
+ "\n".join([u.name for u in inspection_units])
+ )
+ )
+ elif io_wrapper.IsTensorFlowEventsFile(logdir):
+ print(
+ "It seems that {} may be an event file instead of a logdir. If this "
+ "is the case, use --event_file instead of --logdir to pass "
+ "it in.".format(logdir)
+ )
+ else:
+ print("No event files found within logdir {}".format(logdir))
+ return inspection_units
+ elif event_file:
+ generator = generator_from_event_file(event_file)
+ return [
+ InspectionUnit(
+ name=event_file,
+ generator=generator,
+ field_to_obs=get_field_to_observations_map(generator, tag),
+ )
+ ]
+ return []
+
+
+def inspect(logdir="", event_file="", tag=""):
+ """Main function for inspector that prints out a digest of event files.
+
+ Args:
+ logdir: A log directory that contains event files.
+ event_file: Or, a particular event file path.
+ tag: An optional tag name to query for.
+
+ Raises:
+      ValueError: If neither logdir nor event_file is given, or both are given.
+ """
+ print(
+ PRINT_SEPARATOR
+ + "Processing event files... (this can take a few minutes)\n"
+ + PRINT_SEPARATOR
+ )
+ inspection_units = get_inspection_units(logdir, event_file, tag)
+
+ for unit in inspection_units:
+ if tag:
+ print("Event statistics for tag {} in {}:".format(tag, unit.name))
+ else:
+ # If the user is not inspecting a particular tag, also print the list of
+ # all available tags that they can query.
+ print("These tags are in {}:".format(unit.name))
+ print_dict(get_unique_tags(unit.field_to_obs))
+ print(PRINT_SEPARATOR)
+ print("Event statistics for {}:".format(unit.name))
+
+ print_dict(get_dict_to_print(unit.field_to_obs), show_missing=(not tag))
+ print(PRINT_SEPARATOR)
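# A minimal sketch of driving the inspector above from Python (the logdir path
# is hypothetical): `inspect` prints a digest for every run directory found,
# and the optional `tag` argument narrows the statistics to a single tag.
from tensorboard.backend.event_processing import event_file_inspector as efi

efi.inspect(logdir="/tmp/logs")              # digest of all runs under the logdir
efi.inspect(logdir="/tmp/logs", tag="loss")  # statistics for one tag only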
diff --git a/tensorboard/backend/event_processing/event_file_inspector_test.py b/tensorboard/backend/event_processing/event_file_inspector_test.py
index 2a12a8fe9b..7648d3e3f4 100644
--- a/tensorboard/backend/event_processing/event_file_inspector_test.py
+++ b/tensorboard/backend/event_processing/event_file_inspector_test.py
@@ -31,171 +31,192 @@
class EventFileInspectorTest(tf.test.TestCase):
-
- def setUp(self):
- self.logdir = os.path.join(self.get_temp_dir(), 'tfevents')
- self._MakeDirectoryIfNotExists(self.logdir)
-
- def tearDown(self):
- shutil.rmtree(self.logdir)
-
- def _MakeDirectoryIfNotExists(self, path):
- if not os.path.exists(path):
- os.mkdir(path)
-
- def _WriteScalarSummaries(self, data, subdirs=('',)):
- # Writes data to a tempfile in subdirs, and returns generator for the data.
- # If subdirs is given, writes data identically to all subdirectories.
- for subdir_ in subdirs:
- subdir = os.path.join(self.logdir, subdir_)
- self._MakeDirectoryIfNotExists(subdir)
-
- with test_util.FileWriterCache.get(subdir) as sw:
- for datum in data:
- summary = summary_pb2.Summary()
- if 'simple_value' in datum:
- summary.value.add(tag=datum['tag'],
- simple_value=datum['simple_value'])
- sw.add_summary(summary, global_step=datum['step'])
- elif 'histo' in datum:
- summary.value.add(tag=datum['tag'],
- histo=summary_pb2.HistogramProto())
- sw.add_summary(summary, global_step=datum['step'])
- elif 'session_log' in datum:
- sw.add_session_log(datum['session_log'], global_step=datum['step'])
-
- def testEmptyLogdir(self):
- # Nothing was written to logdir
- units = efi.get_inspection_units(self.logdir)
- self.assertEqual([], units)
-
- def testGetAvailableTags(self):
- data = [{'tag': 'c', 'histo': 2, 'step': 10},
- {'tag': 'c', 'histo': 2, 'step': 11},
- {'tag': 'c', 'histo': 2, 'step': 9},
- {'tag': 'b', 'simple_value': 2, 'step': 20},
- {'tag': 'b', 'simple_value': 2, 'step': 15},
- {'tag': 'a', 'simple_value': 2, 'step': 3}]
- self._WriteScalarSummaries(data)
- units = efi.get_inspection_units(self.logdir)
- tags = efi.get_unique_tags(units[0].field_to_obs)
- self.assertEqual(['a', 'b'], tags['scalars'])
- self.assertEqual(['c'], tags['histograms'])
-
- def testInspectAll(self):
- data = [{'tag': 'c', 'histo': 2, 'step': 10},
- {'tag': 'c', 'histo': 2, 'step': 11},
- {'tag': 'c', 'histo': 2, 'step': 9},
- {'tag': 'b', 'simple_value': 2, 'step': 20},
- {'tag': 'b', 'simple_value': 2, 'step': 15},
- {'tag': 'a', 'simple_value': 2, 'step': 3}]
- self._WriteScalarSummaries(data)
- units = efi.get_inspection_units(self.logdir)
- printable = efi.get_dict_to_print(units[0].field_to_obs)
- self.assertEqual(printable['histograms']['max_step'], 11)
- self.assertEqual(printable['histograms']['min_step'], 9)
- self.assertEqual(printable['histograms']['num_steps'], 3)
- self.assertEqual(printable['histograms']['last_step'], 9)
- self.assertEqual(printable['histograms']['first_step'], 10)
- self.assertEqual(printable['histograms']['outoforder_steps'], [(11, 9)])
-
- self.assertEqual(printable['scalars']['max_step'], 20)
- self.assertEqual(printable['scalars']['min_step'], 3)
- self.assertEqual(printable['scalars']['num_steps'], 3)
- self.assertEqual(printable['scalars']['last_step'], 3)
- self.assertEqual(printable['scalars']['first_step'], 20)
- self.assertEqual(printable['scalars']['outoforder_steps'], [(20, 15),
- (15, 3)])
-
- def testInspectTag(self):
- data = [{'tag': 'c', 'histo': 2, 'step': 10},
- {'tag': 'c', 'histo': 2, 'step': 11},
- {'tag': 'c', 'histo': 2, 'step': 9},
- {'tag': 'b', 'histo': 2, 'step': 20},
- {'tag': 'b', 'simple_value': 2, 'step': 15},
- {'tag': 'a', 'simple_value': 2, 'step': 3}]
- self._WriteScalarSummaries(data)
- units = efi.get_inspection_units(self.logdir, tag='c')
- printable = efi.get_dict_to_print(units[0].field_to_obs)
- self.assertEqual(printable['histograms']['max_step'], 11)
- self.assertEqual(printable['histograms']['min_step'], 9)
- self.assertEqual(printable['histograms']['num_steps'], 3)
- self.assertEqual(printable['histograms']['last_step'], 9)
- self.assertEqual(printable['histograms']['first_step'], 10)
- self.assertEqual(printable['histograms']['outoforder_steps'], [(11, 9)])
- self.assertEqual(printable['scalars'], None)
-
- def testSessionLogSummaries(self):
- data = [
- {
- 'session_log': event_pb2.SessionLog(
- status=event_pb2.SessionLog.START),
- 'step': 0
- },
- {
- 'session_log': event_pb2.SessionLog(
- status=event_pb2.SessionLog.CHECKPOINT),
- 'step': 1
- },
- {
- 'session_log': event_pb2.SessionLog(
- status=event_pb2.SessionLog.CHECKPOINT),
- 'step': 2
- },
- {
- 'session_log': event_pb2.SessionLog(
- status=event_pb2.SessionLog.CHECKPOINT),
- 'step': 3
- },
- {
- 'session_log': event_pb2.SessionLog(
- status=event_pb2.SessionLog.STOP),
- 'step': 4
- },
- {
- 'session_log': event_pb2.SessionLog(
- status=event_pb2.SessionLog.START),
- 'step': 5
- },
- {
- 'session_log': event_pb2.SessionLog(
- status=event_pb2.SessionLog.STOP),
- 'step': 6
- },
- ]
-
- self._WriteScalarSummaries(data)
- units = efi.get_inspection_units(self.logdir)
- self.assertEqual(1, len(units))
- printable = efi.get_dict_to_print(units[0].field_to_obs)
- self.assertEqual(printable['sessionlog:start']['steps'], [0, 5])
- self.assertEqual(printable['sessionlog:stop']['steps'], [4, 6])
- self.assertEqual(printable['sessionlog:checkpoint']['num_steps'], 3)
-
- def testInspectAllWithNestedLogdirs(self):
- data = [{'tag': 'c', 'simple_value': 2, 'step': 10},
- {'tag': 'c', 'simple_value': 2, 'step': 11},
- {'tag': 'c', 'simple_value': 2, 'step': 9},
- {'tag': 'b', 'simple_value': 2, 'step': 20},
- {'tag': 'b', 'simple_value': 2, 'step': 15},
- {'tag': 'a', 'simple_value': 2, 'step': 3}]
-
- subdirs = ['eval', 'train']
- self._WriteScalarSummaries(data, subdirs=subdirs)
- units = efi.get_inspection_units(self.logdir)
- self.assertEqual(2, len(units))
- directory_names = [os.path.join(self.logdir, name) for name in subdirs]
- self.assertEqual(directory_names, sorted([unit.name for unit in units]))
-
- for unit in units:
- printable = efi.get_dict_to_print(unit.field_to_obs)['scalars']
- self.assertEqual(printable['max_step'], 20)
- self.assertEqual(printable['min_step'], 3)
- self.assertEqual(printable['num_steps'], 6)
- self.assertEqual(printable['last_step'], 3)
- self.assertEqual(printable['first_step'], 10)
- self.assertEqual(printable['outoforder_steps'], [(11, 9), (20, 15),
- (15, 3)])
-
-if __name__ == '__main__':
- tf.test.main()
+ def setUp(self):
+ self.logdir = os.path.join(self.get_temp_dir(), "tfevents")
+ self._MakeDirectoryIfNotExists(self.logdir)
+
+ def tearDown(self):
+ shutil.rmtree(self.logdir)
+
+ def _MakeDirectoryIfNotExists(self, path):
+ if not os.path.exists(path):
+ os.mkdir(path)
+
+ def _WriteScalarSummaries(self, data, subdirs=("",)):
+        # Writes `data` as summaries into an events file under each entry of
+        # `subdirs`; the same data is written identically to every subdirectory.
+ for subdir_ in subdirs:
+ subdir = os.path.join(self.logdir, subdir_)
+ self._MakeDirectoryIfNotExists(subdir)
+
+ with test_util.FileWriterCache.get(subdir) as sw:
+ for datum in data:
+ summary = summary_pb2.Summary()
+ if "simple_value" in datum:
+ summary.value.add(
+ tag=datum["tag"], simple_value=datum["simple_value"]
+ )
+ sw.add_summary(summary, global_step=datum["step"])
+ elif "histo" in datum:
+ summary.value.add(
+ tag=datum["tag"], histo=summary_pb2.HistogramProto()
+ )
+ sw.add_summary(summary, global_step=datum["step"])
+ elif "session_log" in datum:
+ sw.add_session_log(
+ datum["session_log"], global_step=datum["step"]
+ )
+
+ def testEmptyLogdir(self):
+ # Nothing was written to logdir
+ units = efi.get_inspection_units(self.logdir)
+ self.assertEqual([], units)
+
+ def testGetAvailableTags(self):
+ data = [
+ {"tag": "c", "histo": 2, "step": 10},
+ {"tag": "c", "histo": 2, "step": 11},
+ {"tag": "c", "histo": 2, "step": 9},
+ {"tag": "b", "simple_value": 2, "step": 20},
+ {"tag": "b", "simple_value": 2, "step": 15},
+ {"tag": "a", "simple_value": 2, "step": 3},
+ ]
+ self._WriteScalarSummaries(data)
+ units = efi.get_inspection_units(self.logdir)
+ tags = efi.get_unique_tags(units[0].field_to_obs)
+ self.assertEqual(["a", "b"], tags["scalars"])
+ self.assertEqual(["c"], tags["histograms"])
+
+ def testInspectAll(self):
+ data = [
+ {"tag": "c", "histo": 2, "step": 10},
+ {"tag": "c", "histo": 2, "step": 11},
+ {"tag": "c", "histo": 2, "step": 9},
+ {"tag": "b", "simple_value": 2, "step": 20},
+ {"tag": "b", "simple_value": 2, "step": 15},
+ {"tag": "a", "simple_value": 2, "step": 3},
+ ]
+ self._WriteScalarSummaries(data)
+ units = efi.get_inspection_units(self.logdir)
+ printable = efi.get_dict_to_print(units[0].field_to_obs)
+ self.assertEqual(printable["histograms"]["max_step"], 11)
+ self.assertEqual(printable["histograms"]["min_step"], 9)
+ self.assertEqual(printable["histograms"]["num_steps"], 3)
+ self.assertEqual(printable["histograms"]["last_step"], 9)
+ self.assertEqual(printable["histograms"]["first_step"], 10)
+ self.assertEqual(printable["histograms"]["outoforder_steps"], [(11, 9)])
+
+ self.assertEqual(printable["scalars"]["max_step"], 20)
+ self.assertEqual(printable["scalars"]["min_step"], 3)
+ self.assertEqual(printable["scalars"]["num_steps"], 3)
+ self.assertEqual(printable["scalars"]["last_step"], 3)
+ self.assertEqual(printable["scalars"]["first_step"], 20)
+ self.assertEqual(
+ printable["scalars"]["outoforder_steps"], [(20, 15), (15, 3)]
+ )
+
+ def testInspectTag(self):
+ data = [
+ {"tag": "c", "histo": 2, "step": 10},
+ {"tag": "c", "histo": 2, "step": 11},
+ {"tag": "c", "histo": 2, "step": 9},
+ {"tag": "b", "histo": 2, "step": 20},
+ {"tag": "b", "simple_value": 2, "step": 15},
+ {"tag": "a", "simple_value": 2, "step": 3},
+ ]
+ self._WriteScalarSummaries(data)
+ units = efi.get_inspection_units(self.logdir, tag="c")
+ printable = efi.get_dict_to_print(units[0].field_to_obs)
+ self.assertEqual(printable["histograms"]["max_step"], 11)
+ self.assertEqual(printable["histograms"]["min_step"], 9)
+ self.assertEqual(printable["histograms"]["num_steps"], 3)
+ self.assertEqual(printable["histograms"]["last_step"], 9)
+ self.assertEqual(printable["histograms"]["first_step"], 10)
+ self.assertEqual(printable["histograms"]["outoforder_steps"], [(11, 9)])
+ self.assertEqual(printable["scalars"], None)
+
+ def testSessionLogSummaries(self):
+ data = [
+ {
+ "session_log": event_pb2.SessionLog(
+ status=event_pb2.SessionLog.START
+ ),
+ "step": 0,
+ },
+ {
+ "session_log": event_pb2.SessionLog(
+ status=event_pb2.SessionLog.CHECKPOINT
+ ),
+ "step": 1,
+ },
+ {
+ "session_log": event_pb2.SessionLog(
+ status=event_pb2.SessionLog.CHECKPOINT
+ ),
+ "step": 2,
+ },
+ {
+ "session_log": event_pb2.SessionLog(
+ status=event_pb2.SessionLog.CHECKPOINT
+ ),
+ "step": 3,
+ },
+ {
+ "session_log": event_pb2.SessionLog(
+ status=event_pb2.SessionLog.STOP
+ ),
+ "step": 4,
+ },
+ {
+ "session_log": event_pb2.SessionLog(
+ status=event_pb2.SessionLog.START
+ ),
+ "step": 5,
+ },
+ {
+ "session_log": event_pb2.SessionLog(
+ status=event_pb2.SessionLog.STOP
+ ),
+ "step": 6,
+ },
+ ]
+
+ self._WriteScalarSummaries(data)
+ units = efi.get_inspection_units(self.logdir)
+ self.assertEqual(1, len(units))
+ printable = efi.get_dict_to_print(units[0].field_to_obs)
+ self.assertEqual(printable["sessionlog:start"]["steps"], [0, 5])
+ self.assertEqual(printable["sessionlog:stop"]["steps"], [4, 6])
+ self.assertEqual(printable["sessionlog:checkpoint"]["num_steps"], 3)
+
+ def testInspectAllWithNestedLogdirs(self):
+ data = [
+ {"tag": "c", "simple_value": 2, "step": 10},
+ {"tag": "c", "simple_value": 2, "step": 11},
+ {"tag": "c", "simple_value": 2, "step": 9},
+ {"tag": "b", "simple_value": 2, "step": 20},
+ {"tag": "b", "simple_value": 2, "step": 15},
+ {"tag": "a", "simple_value": 2, "step": 3},
+ ]
+
+ subdirs = ["eval", "train"]
+ self._WriteScalarSummaries(data, subdirs=subdirs)
+ units = efi.get_inspection_units(self.logdir)
+ self.assertEqual(2, len(units))
+ directory_names = [os.path.join(self.logdir, name) for name in subdirs]
+ self.assertEqual(directory_names, sorted([unit.name for unit in units]))
+
+ for unit in units:
+ printable = efi.get_dict_to_print(unit.field_to_obs)["scalars"]
+ self.assertEqual(printable["max_step"], 20)
+ self.assertEqual(printable["min_step"], 3)
+ self.assertEqual(printable["num_steps"], 6)
+ self.assertEqual(printable["last_step"], 3)
+ self.assertEqual(printable["first_step"], 10)
+ self.assertEqual(
+ printable["outoforder_steps"], [(11, 9), (20, 15), (15, 3)]
+ )
+
+
+if __name__ == "__main__":
+ tf.test.main()
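# A small illustration of the out-of-order detection exercised by the tests
# above: get_out_of_order reports each adjacent (previous, current) pair in
# which the step value decreases, matching the assertions in testInspectAll.
from tensorboard.backend.event_processing import event_file_inspector as efi

print(efi.get_out_of_order([10, 11, 9]))  # [(11, 9)]
print(efi.get_out_of_order([20, 15, 3]))  # [(20, 15), (15, 3)]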
diff --git a/tensorboard/backend/event_processing/event_file_loader.py b/tensorboard/backend/event_processing/event_file_loader.py
index 50615e257c..3a1d52bdf0 100644
--- a/tensorboard/backend/event_processing/event_file_loader.py
+++ b/tensorboard/backend/event_processing/event_file_loader.py
@@ -31,82 +31,87 @@
class RawEventFileLoader(object):
- """An iterator that yields Event protos as serialized bytestrings."""
-
- def __init__(self, file_path):
- if file_path is None:
- raise ValueError('A file path is required')
- file_path = platform_util.readahead_file_path(file_path)
- logger.debug('Opening a record reader pointing at %s', file_path)
- with tf.compat.v1.errors.raise_exception_on_not_ok_status() as status:
- self._reader = _pywrap_tensorflow.PyRecordReader_New(
- tf.compat.as_bytes(file_path), 0, tf.compat.as_bytes(''), status)
- # Store it for logging purposes.
- self._file_path = file_path
- if not self._reader:
- raise IOError('Failed to open a record reader pointing to %s' % file_path)
-
- def Load(self):
- """Loads all new events from disk as raw serialized proto bytestrings.
-
- Calling Load multiple times in a row will not 'drop' events as long as the
- return value is not iterated over.
-
- Yields:
- All event proto bytestrings in the file that have not been yielded yet.
- """
- logger.debug('Loading events from %s', self._file_path)
-
- # GetNext() expects a status argument on TF <= 1.7.
- get_next_args = inspect.getargspec(self._reader.GetNext).args # pylint: disable=deprecated-method
- # First argument is self
- legacy_get_next = (len(get_next_args) > 1)
-
- while True:
- try:
- if legacy_get_next:
- with tf.compat.v1.errors.raise_exception_on_not_ok_status() as status:
- self._reader.GetNext(status)
- else:
- self._reader.GetNext()
- except (tf.errors.DataLossError, tf.errors.OutOfRangeError) as e:
- logger.debug('Cannot read more events: %s', e)
- # We ignore partial read exceptions, because a record may be truncated.
- # PyRecordReader holds the offset prior to the failed read, so retrying
- # will succeed.
- break
- yield self._reader.record()
- logger.debug('No more events in %s', self._file_path)
+ """An iterator that yields Event protos as serialized bytestrings."""
+
+ def __init__(self, file_path):
+ if file_path is None:
+ raise ValueError("A file path is required")
+ file_path = platform_util.readahead_file_path(file_path)
+ logger.debug("Opening a record reader pointing at %s", file_path)
+ with tf.compat.v1.errors.raise_exception_on_not_ok_status() as status:
+ self._reader = _pywrap_tensorflow.PyRecordReader_New(
+ tf.compat.as_bytes(file_path), 0, tf.compat.as_bytes(""), status
+ )
+ # Store it for logging purposes.
+ self._file_path = file_path
+ if not self._reader:
+ raise IOError(
+ "Failed to open a record reader pointing to %s" % file_path
+ )
+
+ def Load(self):
+ """Loads all new events from disk as raw serialized proto bytestrings.
+
+ Calling Load multiple times in a row will not 'drop' events as long as the
+ return value is not iterated over.
+
+ Yields:
+ All event proto bytestrings in the file that have not been yielded yet.
+ """
+ logger.debug("Loading events from %s", self._file_path)
+
+ # GetNext() expects a status argument on TF <= 1.7.
+ get_next_args = inspect.getargspec(
+ self._reader.GetNext
+ ).args # pylint: disable=deprecated-method
+ # First argument is self
+ legacy_get_next = len(get_next_args) > 1
+
+ while True:
+ try:
+ if legacy_get_next:
+ with tf.compat.v1.errors.raise_exception_on_not_ok_status() as status:
+ self._reader.GetNext(status)
+ else:
+ self._reader.GetNext()
+ except (tf.errors.DataLossError, tf.errors.OutOfRangeError) as e:
+ logger.debug("Cannot read more events: %s", e)
+ # We ignore partial read exceptions, because a record may be truncated.
+ # PyRecordReader holds the offset prior to the failed read, so retrying
+ # will succeed.
+ break
+ yield self._reader.record()
+ logger.debug("No more events in %s", self._file_path)
class EventFileLoader(RawEventFileLoader):
- """An iterator that yields parsed Event protos."""
+ """An iterator that yields parsed Event protos."""
- def Load(self):
- """Loads all new events from disk.
+ def Load(self):
+ """Loads all new events from disk.
- Calling Load multiple times in a row will not 'drop' events as long as the
- return value is not iterated over.
+ Calling Load multiple times in a row will not 'drop' events as long as the
+ return value is not iterated over.
- Yields:
- All events in the file that have not been yielded yet.
- """
- for record in super(EventFileLoader, self).Load():
- yield event_pb2.Event.FromString(record)
+ Yields:
+ All events in the file that have not been yielded yet.
+ """
+ for record in super(EventFileLoader, self).Load():
+ yield event_pb2.Event.FromString(record)
class TimestampedEventFileLoader(EventFileLoader):
- """An iterator that yields (UNIX timestamp float, Event proto) pairs."""
+ """An iterator that yields (UNIX timestamp float, Event proto) pairs."""
- def Load(self):
- """Loads all new events and their wall time values from disk.
+ def Load(self):
+ """Loads all new events and their wall time values from disk.
- Calling Load multiple times in a row will not 'drop' events as long as the
- return value is not iterated over.
+ Calling Load multiple times in a row will not 'drop' events as long as the
+ return value is not iterated over.
- Yields:
- Pairs of (UNIX timestamp float, Event proto) for all events in the file
- that have not been yielded yet.
- """
- for event in super(TimestampedEventFileLoader, self).Load():
- yield (event.wall_time, event)
+ Yields:
+ Pairs of (UNIX timestamp float, Event proto) for all events in the file
+ that have not been yielded yet.
+ """
+ for event in super(TimestampedEventFileLoader, self).Load():
+ yield (event.wall_time, event)
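# A minimal sketch (hypothetical event file path) of the loaders defined above:
# EventFileLoader yields parsed Event protos, and repeated calls to Load()
# yield only the records appended since the previous call.
from tensorboard.backend.event_processing import event_file_loader

loader = event_file_loader.EventFileLoader("/tmp/logs/events.out.tfevents.123")
for event in loader.Load():
    print(event.wall_time, event.step)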
diff --git a/tensorboard/backend/event_processing/event_file_loader_test.py b/tensorboard/backend/event_processing/event_file_loader_test.py
index 99ada11ec5..2f67b4b9bf 100644
--- a/tensorboard/backend/event_processing/event_file_loader_test.py
+++ b/tensorboard/backend/event_processing/event_file_loader_test.py
@@ -29,82 +29,85 @@
class EventFileLoaderTest(tf.test.TestCase):
- # A record containing a simple event.
- RECORD = (b'\x18\x00\x00\x00\x00\x00\x00\x00\xa3\x7fK"\t\x00\x00\xc0%\xddu'
- b'\xd5A\x1a\rbrain.Event:1\xec\xf32\x8d')
-
- def _WriteToFile(self, filename, data):
- with open(filename, 'ab') as f:
- f.write(data)
-
- def _LoaderForTestFile(self, filename):
- return event_file_loader.EventFileLoader(
- os.path.join(self.get_temp_dir(), filename))
-
- def testEmptyEventFile(self):
- filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
- self._WriteToFile(filename, b'')
- loader = self._LoaderForTestFile(filename)
- self.assertEqual(len(list(loader.Load())), 0)
-
- def testSingleWrite(self):
- filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
- self._WriteToFile(filename, EventFileLoaderTest.RECORD)
- loader = self._LoaderForTestFile(filename)
- events = list(loader.Load())
- self.assertEqual(len(events), 1)
- self.assertEqual(events[0].wall_time, 1440183447.0)
- self.assertEqual(len(list(loader.Load())), 0)
-
- def testMultipleWrites(self):
- filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
- self._WriteToFile(filename, EventFileLoaderTest.RECORD)
- loader = self._LoaderForTestFile(filename)
- self.assertEqual(len(list(loader.Load())), 1)
- self._WriteToFile(filename, EventFileLoaderTest.RECORD)
- self.assertEqual(len(list(loader.Load())), 1)
-
- def testMultipleLoads(self):
- filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
- self._WriteToFile(filename, EventFileLoaderTest.RECORD)
- loader = self._LoaderForTestFile(filename)
- loader.Load()
- loader.Load()
- self.assertEqual(len(list(loader.Load())), 1)
-
- def testMultipleWritesAtOnce(self):
- filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
- self._WriteToFile(filename, EventFileLoaderTest.RECORD)
- self._WriteToFile(filename, EventFileLoaderTest.RECORD)
- loader = self._LoaderForTestFile(filename)
- self.assertEqual(len(list(loader.Load())), 2)
-
- def testMultipleWritesWithBadWrite(self):
- filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
- self._WriteToFile(filename, EventFileLoaderTest.RECORD)
- self._WriteToFile(filename, EventFileLoaderTest.RECORD)
- # Test that we ignore partial record writes at the end of the file.
- self._WriteToFile(filename, b'123')
- loader = self._LoaderForTestFile(filename)
- self.assertEqual(len(list(loader.Load())), 2)
+ # A record containing a simple event.
+ RECORD = (
+ b'\x18\x00\x00\x00\x00\x00\x00\x00\xa3\x7fK"\t\x00\x00\xc0%\xddu'
+ b"\xd5A\x1a\rbrain.Event:1\xec\xf32\x8d"
+ )
+
+ def _WriteToFile(self, filename, data):
+ with open(filename, "ab") as f:
+ f.write(data)
+
+ def _LoaderForTestFile(self, filename):
+ return event_file_loader.EventFileLoader(
+ os.path.join(self.get_temp_dir(), filename)
+ )
+
+ def testEmptyEventFile(self):
+ filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
+ self._WriteToFile(filename, b"")
+ loader = self._LoaderForTestFile(filename)
+ self.assertEqual(len(list(loader.Load())), 0)
+
+ def testSingleWrite(self):
+ filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
+ self._WriteToFile(filename, EventFileLoaderTest.RECORD)
+ loader = self._LoaderForTestFile(filename)
+ events = list(loader.Load())
+ self.assertEqual(len(events), 1)
+ self.assertEqual(events[0].wall_time, 1440183447.0)
+ self.assertEqual(len(list(loader.Load())), 0)
+
+ def testMultipleWrites(self):
+ filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
+ self._WriteToFile(filename, EventFileLoaderTest.RECORD)
+ loader = self._LoaderForTestFile(filename)
+ self.assertEqual(len(list(loader.Load())), 1)
+ self._WriteToFile(filename, EventFileLoaderTest.RECORD)
+ self.assertEqual(len(list(loader.Load())), 1)
+
+ def testMultipleLoads(self):
+ filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
+ self._WriteToFile(filename, EventFileLoaderTest.RECORD)
+ loader = self._LoaderForTestFile(filename)
+ loader.Load()
+ loader.Load()
+ self.assertEqual(len(list(loader.Load())), 1)
+
+ def testMultipleWritesAtOnce(self):
+ filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
+ self._WriteToFile(filename, EventFileLoaderTest.RECORD)
+ self._WriteToFile(filename, EventFileLoaderTest.RECORD)
+ loader = self._LoaderForTestFile(filename)
+ self.assertEqual(len(list(loader.Load())), 2)
+
+ def testMultipleWritesWithBadWrite(self):
+ filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
+ self._WriteToFile(filename, EventFileLoaderTest.RECORD)
+ self._WriteToFile(filename, EventFileLoaderTest.RECORD)
+ # Test that we ignore partial record writes at the end of the file.
+ self._WriteToFile(filename, b"123")
+ loader = self._LoaderForTestFile(filename)
+ self.assertEqual(len(list(loader.Load())), 2)
class RawEventFileLoaderTest(EventFileLoaderTest):
-
- def _LoaderForTestFile(self, filename):
- return event_file_loader.RawEventFileLoader(
- os.path.join(self.get_temp_dir(), filename))
-
- def testSingleWrite(self):
- filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
- self._WriteToFile(filename, EventFileLoaderTest.RECORD)
- loader = self._LoaderForTestFile(filename)
- event_protos = list(loader.Load())
- self.assertEqual(len(event_protos), 1)
- # Record format has a 12 byte header and a 4 byte trailer.
- expected_event_proto = EventFileLoaderTest.RECORD[12:-4]
- self.assertEqual(event_protos[0], expected_event_proto)
-
-
-if __name__ == '__main__':
- tf.test.main()
+ def _LoaderForTestFile(self, filename):
+ return event_file_loader.RawEventFileLoader(
+ os.path.join(self.get_temp_dir(), filename)
+ )
+
+ def testSingleWrite(self):
+ filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
+ self._WriteToFile(filename, EventFileLoaderTest.RECORD)
+ loader = self._LoaderForTestFile(filename)
+ event_protos = list(loader.Load())
+ self.assertEqual(len(event_protos), 1)
+ # Record format has a 12 byte header and a 4 byte trailer.
+ expected_event_proto = EventFileLoaderTest.RECORD[12:-4]
+ self.assertEqual(event_protos[0], expected_event_proto)
+
+
+if __name__ == "__main__":
+ tf.test.main()
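# A sketch of the record framing RawEventFileLoaderTest relies on when slicing
# RECORD[12:-4]: assuming standard TFRecord framing, a 12-byte header (8-byte
# little-endian payload length plus a 4-byte length CRC) precedes the
# serialized Event proto, and a 4-byte data CRC follows it.
import struct

RECORD = (
    b'\x18\x00\x00\x00\x00\x00\x00\x00\xa3\x7fK"\t\x00\x00\xc0%\xddu'
    b"\xd5A\x1a\rbrain.Event:1\xec\xf32\x8d"
)
(payload_length,) = struct.unpack("<Q", RECORD[:8])
payload = RECORD[12:-4]
assert payload_length == len(payload) == 24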
diff --git a/tensorboard/backend/event_processing/event_multiplexer.py b/tensorboard/backend/event_processing/event_multiplexer.py
index 5823a6412d..b575018abe 100644
--- a/tensorboard/backend/event_processing/event_multiplexer.py
+++ b/tensorboard/backend/event_processing/event_multiplexer.py
@@ -31,470 +31,481 @@
logger = tb_logging.get_logger()
-class EventMultiplexer(object):
- """An `EventMultiplexer` manages access to multiple `EventAccumulator`s.
-
- Each `EventAccumulator` is associated with a `run`, which is a self-contained
- TensorFlow execution. The `EventMultiplexer` provides methods for extracting
- information about events from multiple `run`s.
-
- Example usage for loading specific runs from files:
-
- ```python
- x = EventMultiplexer({'run1': 'path/to/run1', 'run2': 'path/to/run2'})
- x.Reload()
- ```
-
- Example usage for loading a directory where each subdirectory is a run
-
- ```python
- (eg:) /parent/directory/path/
- /parent/directory/path/run1/
- /parent/directory/path/run1/events.out.tfevents.1001
- /parent/directory/path/run1/events.out.tfevents.1002
-
- /parent/directory/path/run2/
- /parent/directory/path/run2/events.out.tfevents.9232
-
- /parent/directory/path/run3/
- /parent/directory/path/run3/events.out.tfevents.9232
- x = EventMultiplexer().AddRunsFromDirectory('/parent/directory/path')
- (which is equivalent to:)
- x = EventMultiplexer({'run1': '/parent/directory/path/run1', 'run2':...}
- ```
-
- If you would like to watch `/parent/directory/path`, wait for it to be created
- (if necessary) and then periodically pick up new runs, use
- `AutoloadingMultiplexer`
- @@Tensors
- """
-
- def __init__(self,
- run_path_map=None,
- size_guidance=None,
- purge_orphaned_data=True):
- """Constructor for the `EventMultiplexer`.
-
- Args:
- run_path_map: Dict `{run: path}` which specifies the
- name of a run, and the path to find the associated events. If it is
- None, then the EventMultiplexer initializes without any runs.
- size_guidance: A dictionary mapping from `tagType` to the number of items
- to store for each tag of that type. See
- `event_accumulator.EventAccumulator` for details.
- purge_orphaned_data: Whether to discard any events that were "orphaned" by
- a TensorFlow restart.
- """
- logger.info('Event Multiplexer initializing.')
- self._accumulators_mutex = threading.Lock()
- self._accumulators = {}
- self._paths = {}
- self._reload_called = False
- self._size_guidance = (size_guidance or
- event_accumulator.DEFAULT_SIZE_GUIDANCE)
- self.purge_orphaned_data = purge_orphaned_data
- if run_path_map is not None:
- logger.info('Event Multplexer doing initialization load for %s',
- run_path_map)
- for (run, path) in six.iteritems(run_path_map):
- self.AddRun(path, run)
- logger.info('Event Multiplexer done initializing')
-
- def AddRun(self, path, name=None):
- """Add a run to the multiplexer.
-
- If the name is not specified, it is the same as the path.
-
- If a run by that name exists, and we are already watching the right path,
- do nothing. If we are watching a different path, replace the event
- accumulator.
-
- If `Reload` has been called, it will `Reload` the newly created
- accumulators.
-
- Args:
- path: Path to the event files (or event directory) for given run.
- name: Name of the run to add. If not provided, is set to path.
-
- Returns:
- The `EventMultiplexer`.
- """
- name = name or path
- accumulator = None
- with self._accumulators_mutex:
- if name not in self._accumulators or self._paths[name] != path:
- if name in self._paths and self._paths[name] != path:
- # TODO(@decentralion) - Make it impossible to overwrite an old path
- # with a new path (just give the new path a distinct name)
- logger.warn('Conflict for name %s: old path %s, new path %s',
- name, self._paths[name], path)
- logger.info('Constructing EventAccumulator for %s', path)
- accumulator = event_accumulator.EventAccumulator(
- path,
- size_guidance=self._size_guidance,
- purge_orphaned_data=self.purge_orphaned_data)
- self._accumulators[name] = accumulator
- self._paths[name] = path
- if accumulator:
- if self._reload_called:
- accumulator.Reload()
- return self
-
- def AddRunsFromDirectory(self, path, name=None):
- """Load runs from a directory; recursively walks subdirectories.
-
- If path doesn't exist, no-op. This ensures that it is safe to call
- `AddRunsFromDirectory` multiple times, even before the directory is made.
-
- If path is a directory, load event files in the directory (if any exist) and
- recursively call AddRunsFromDirectory on any subdirectories. This mean you
- can call AddRunsFromDirectory at the root of a tree of event logs and
- TensorBoard will load them all.
-
- If the `EventMultiplexer` is already loaded this will cause
- the newly created accumulators to `Reload()`.
- Args:
- path: A string path to a directory to load runs from.
- name: Optionally, what name to apply to the runs. If name is provided
- and the directory contains run subdirectories, the name of each subrun
- is the concatenation of the parent name and the subdirectory name. If
- name is provided and the directory contains event files, then a run
- is added called "name" and with the events from the path.
-
- Raises:
- ValueError: If the path exists and isn't a directory.
-
- Returns:
- The `EventMultiplexer`.
- """
- logger.info('Starting AddRunsFromDirectory: %s', path)
- for subdir in io_wrapper.GetLogdirSubdirectories(path):
- logger.info('Adding events from directory %s', subdir)
- rpath = os.path.relpath(subdir, path)
- subname = os.path.join(name, rpath) if name else rpath
- self.AddRun(subdir, name=subname)
- logger.info('Done with AddRunsFromDirectory: %s', path)
- return self
-
- def Reload(self):
- """Call `Reload` on every `EventAccumulator`."""
- logger.info('Beginning EventMultiplexer.Reload()')
- self._reload_called = True
- # Build a list so we're safe even if the list of accumulators is modified
- # even while we're reloading.
- with self._accumulators_mutex:
- items = list(self._accumulators.items())
-
- names_to_delete = set()
- for name, accumulator in items:
- try:
- accumulator.Reload()
- except (OSError, IOError) as e:
- logger.error("Unable to reload accumulator '%s': %s", name, e)
- except directory_watcher.DirectoryDeletedError:
- names_to_delete.add(name)
-
- with self._accumulators_mutex:
- for name in names_to_delete:
- logger.warn("Deleting accumulator '%s'", name)
- del self._accumulators[name]
- logger.info('Finished with EventMultiplexer.Reload()')
- return self
-
- def PluginAssets(self, plugin_name):
- """Get index of runs and assets for a given plugin.
-
- Args:
- plugin_name: Name of the plugin we are checking for.
-
- Returns:
- A dictionary that maps from run_name to a list of plugin
- assets for that run.
- """
- with self._accumulators_mutex:
- # To avoid nested locks, we construct a copy of the run-accumulator map
- items = list(six.iteritems(self._accumulators))
-
- return {run: accum.PluginAssets(plugin_name) for run, accum in items}
-
- def RetrievePluginAsset(self, run, plugin_name, asset_name):
- """Return the contents for a specific plugin asset from a run.
-
- Args:
- run: The string name of the run.
- plugin_name: The string name of a plugin.
- asset_name: The string name of an asset.
-
- Returns:
- The string contents of the plugin asset.
-
- Raises:
- KeyError: If the asset is not available.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.RetrievePluginAsset(plugin_name, asset_name)
-
- def FirstEventTimestamp(self, run):
- """Return the timestamp of the first event of the given run.
-
- This may perform I/O if no events have been loaded yet for the run.
-
- Args:
- run: A string name of the run for which the timestamp is retrieved.
-
- Returns:
- The wall_time of the first event of the run, which will typically be
- seconds since the epoch.
-
- Raises:
- KeyError: If the run is not found.
- ValueError: If the run has no events loaded and there are no events on
- disk to load.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.FirstEventTimestamp()
-
- def Scalars(self, run, tag):
- """Retrieve the scalar events associated with a run and tag.
-
- Args:
- run: A string name of the run for which values are retrieved.
- tag: A string name of the tag for which values are retrieved.
-
- Raises:
- KeyError: If the run is not found, or the tag is not available for
- the given run.
-
- Returns:
- An array of `event_accumulator.ScalarEvents`.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.Scalars(tag)
-
- def Graph(self, run):
- """Retrieve the graph associated with the provided run.
-
- Args:
- run: A string name of a run to load the graph for.
-
- Raises:
- KeyError: If the run is not found.
- ValueError: If the run does not have an associated graph.
-
- Returns:
- The `GraphDef` protobuf data structure.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.Graph()
-
- def SerializedGraph(self, run):
- """Retrieve the serialized graph associated with the provided run.
-
- Args:
- run: A string name of a run to load the graph for.
-
- Raises:
- KeyError: If the run is not found.
- ValueError: If the run does not have an associated graph.
-
- Returns:
- The serialized form of the `GraphDef` protobuf data structure.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.SerializedGraph()
-
- def MetaGraph(self, run):
- """Retrieve the metagraph associated with the provided run.
-
- Args:
- run: A string name of a run to load the graph for.
-
- Raises:
- KeyError: If the run is not found.
- ValueError: If the run does not have an associated graph.
-
- Returns:
- The `MetaGraphDef` protobuf data structure.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.MetaGraph()
-
- def RunMetadata(self, run, tag):
- """Get the session.run() metadata associated with a TensorFlow run and tag.
- Args:
- run: A string name of a TensorFlow run.
- tag: A string name of the tag associated with a particular session.run().
-
- Raises:
- KeyError: If the run is not found, or the tag is not available for the
- given run.
-
- Returns:
- The metadata in the form of `RunMetadata` protobuf data structure.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.RunMetadata(tag)
-
- def Histograms(self, run, tag):
- """Retrieve the histogram events associated with a run and tag.
-
- Args:
- run: A string name of the run for which values are retrieved.
- tag: A string name of the tag for which values are retrieved.
-
- Raises:
- KeyError: If the run is not found, or the tag is not available for
- the given run.
-
- Returns:
- An array of `event_accumulator.HistogramEvents`.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.Histograms(tag)
-
- def CompressedHistograms(self, run, tag):
- """Retrieve the compressed histogram events associated with a run and tag.
-
- Args:
- run: A string name of the run for which values are retrieved.
- tag: A string name of the tag for which values are retrieved.
-
- Raises:
- KeyError: If the run is not found, or the tag is not available for
- the given run.
-
- Returns:
- An array of `event_accumulator.CompressedHistogramEvents`.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.CompressedHistograms(tag)
-
- def Images(self, run, tag):
- """Retrieve the image events associated with a run and tag.
-
- Args:
- run: A string name of the run for which values are retrieved.
- tag: A string name of the tag for which values are retrieved.
-
- Raises:
- KeyError: If the run is not found, or the tag is not available for
- the given run.
-
- Returns:
- An array of `event_accumulator.ImageEvents`.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.Images(tag)
-
- def Audio(self, run, tag):
- """Retrieve the audio events associated with a run and tag.
-
- Args:
- run: A string name of the run for which values are retrieved.
- tag: A string name of the tag for which values are retrieved.
-
- Raises:
- KeyError: If the run is not found, or the tag is not available for
- the given run.
-
- Returns:
- An array of `event_accumulator.AudioEvents`.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.Audio(tag)
-
- def Tensors(self, run, tag):
- """Retrieve the tensor events associated with a run and tag.
-
- Args:
- run: A string name of the run for which values are retrieved.
- tag: A string name of the tag for which values are retrieved.
-
- Raises:
- KeyError: If the run is not found, or the tag is not available for
- the given run.
-
- Returns:
- An array of `event_accumulator.TensorEvent`s.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.Tensors(tag)
-
- def PluginRunToTagToContent(self, plugin_name):
- """Returns a 2-layer dictionary of the form {run: {tag: content}}.
-
- The `content` referred above is the content field of the PluginData proto
- for the specified plugin within a Summary.Value proto.
-
- Args:
- plugin_name: The name of the plugin for which to fetch content.
+class EventMultiplexer(object):
+ """An `EventMultiplexer` manages access to multiple `EventAccumulator`s.
- Returns:
- A dictionary of the form {run: {tag: content}}.
- """
- mapping = {}
- for run in self.Runs():
- try:
- tag_to_content = self.GetAccumulator(run).PluginTagToContent(
- plugin_name)
- except KeyError:
- # This run lacks content for the plugin. Try the next run.
- continue
- mapping[run] = tag_to_content
- return mapping
-
- def SummaryMetadata(self, run, tag):
- """Return the summary metadata for the given tag on the given run.
-
- Args:
- run: A string name of the run for which summary metadata is to be
- retrieved.
- tag: A string name of the tag whose summary metadata is to be
- retrieved.
-
- Raises:
- KeyError: If the run is not found, or the tag is not available for
- the given run.
-
- Returns:
- A `SummaryMetadata` protobuf.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.SummaryMetadata(tag)
+ Each `EventAccumulator` is associated with a `run`, which is a self-contained
+ TensorFlow execution. The `EventMultiplexer` provides methods for extracting
+ information about events from multiple `run`s.
- def Runs(self):
- """Return all the run names in the `EventMultiplexer`.
+ Example usage for loading specific runs from files:
- Returns:
+ ```python
+ x = EventMultiplexer({'run1': 'path/to/run1', 'run2': 'path/to/run2'})
+ x.Reload()
```
- {runName: { images: [tag1, tag2, tag3],
- scalarValues: [tagA, tagB, tagC],
- histograms: [tagX, tagY, tagZ],
- compressedHistograms: [tagX, tagY, tagZ],
- graph: true, meta_graph: true}}
- ```
- """
- with self._accumulators_mutex:
- # To avoid nested locks, we construct a copy of the run-accumulator map
- items = list(six.iteritems(self._accumulators))
- return {run_name: accumulator.Tags() for run_name, accumulator in items}
- def RunPaths(self):
- """Returns a dict mapping run names to event file paths."""
- return self._paths
+ Example usage for loading a directory where each subdirectory is a run
- def GetAccumulator(self, run):
- """Returns EventAccumulator for a given run.
+ ```python
+ (eg:) /parent/directory/path/
+ /parent/directory/path/run1/
+ /parent/directory/path/run1/events.out.tfevents.1001
+ /parent/directory/path/run1/events.out.tfevents.1002
- Args:
- run: String name of run.
+ /parent/directory/path/run2/
+ /parent/directory/path/run2/events.out.tfevents.9232
- Returns:
- An EventAccumulator object.
+ /parent/directory/path/run3/
+ /parent/directory/path/run3/events.out.tfevents.9232
+ x = EventMultiplexer().AddRunsFromDirectory('/parent/directory/path')
+ (which is equivalent to:)
+ x = EventMultiplexer({'run1': '/parent/directory/path/run1', 'run2':...}
+ ```
- Raises:
- KeyError: If run does not exist.
+ If you would like to watch `/parent/directory/path`, wait for it to be created
+ (if necessary) and then periodically pick up new runs, use
+ `AutoloadingMultiplexer`
+ @@Tensors
"""
- with self._accumulators_mutex:
- return self._accumulators[run]
+
+ def __init__(
+ self, run_path_map=None, size_guidance=None, purge_orphaned_data=True
+ ):
+ """Constructor for the `EventMultiplexer`.
+
+ Args:
+ run_path_map: Dict `{run: path}` which specifies the
+ name of a run, and the path to find the associated events. If it is
+ None, then the EventMultiplexer initializes without any runs.
+ size_guidance: A dictionary mapping from `tagType` to the number of items
+ to store for each tag of that type. See
+ `event_accumulator.EventAccumulator` for details.
+ purge_orphaned_data: Whether to discard any events that were "orphaned" by
+ a TensorFlow restart.
+ """
+ logger.info("Event Multiplexer initializing.")
+ self._accumulators_mutex = threading.Lock()
+ self._accumulators = {}
+ self._paths = {}
+ self._reload_called = False
+ self._size_guidance = (
+ size_guidance or event_accumulator.DEFAULT_SIZE_GUIDANCE
+ )
+ self.purge_orphaned_data = purge_orphaned_data
+ if run_path_map is not None:
+ logger.info(
+                "Event Multiplexer doing initialization load for %s",
+ run_path_map,
+ )
+ for (run, path) in six.iteritems(run_path_map):
+ self.AddRun(path, run)
+ logger.info("Event Multiplexer done initializing")
+
+ def AddRun(self, path, name=None):
+ """Add a run to the multiplexer.
+
+ If the name is not specified, it is the same as the path.
+
+ If a run by that name exists, and we are already watching the right path,
+ do nothing. If we are watching a different path, replace the event
+ accumulator.
+
+ If `Reload` has been called, it will `Reload` the newly created
+ accumulators.
+
+ Args:
+ path: Path to the event files (or event directory) for given run.
+ name: Name of the run to add. If not provided, is set to path.
+
+ Returns:
+ The `EventMultiplexer`.
+ """
+ name = name or path
+ accumulator = None
+ with self._accumulators_mutex:
+ if name not in self._accumulators or self._paths[name] != path:
+ if name in self._paths and self._paths[name] != path:
+ # TODO(@decentralion) - Make it impossible to overwrite an old path
+ # with a new path (just give the new path a distinct name)
+ logger.warn(
+ "Conflict for name %s: old path %s, new path %s",
+ name,
+ self._paths[name],
+ path,
+ )
+ logger.info("Constructing EventAccumulator for %s", path)
+ accumulator = event_accumulator.EventAccumulator(
+ path,
+ size_guidance=self._size_guidance,
+ purge_orphaned_data=self.purge_orphaned_data,
+ )
+ self._accumulators[name] = accumulator
+ self._paths[name] = path
+ if accumulator:
+ if self._reload_called:
+ accumulator.Reload()
+ return self
+
+ def AddRunsFromDirectory(self, path, name=None):
+ """Load runs from a directory; recursively walks subdirectories.
+
+ If path doesn't exist, no-op. This ensures that it is safe to call
+ `AddRunsFromDirectory` multiple times, even before the directory is made.
+
+ If path is a directory, load event files in the directory (if any exist) and
+        recursively call AddRunsFromDirectory on any subdirectories. This means you
+ can call AddRunsFromDirectory at the root of a tree of event logs and
+ TensorBoard will load them all.
+
+        If the `EventMultiplexer` is already loaded, this will cause
+        the newly created accumulators to `Reload()`.
+
+        Args:
+ path: A string path to a directory to load runs from.
+ name: Optionally, what name to apply to the runs. If name is provided
+ and the directory contains run subdirectories, the name of each subrun
+ is the concatenation of the parent name and the subdirectory name. If
+          name is provided and the directory contains event files, then a run
+          named "name" is added, containing the events from the path.
+
+ Raises:
+ ValueError: If the path exists and isn't a directory.
+
+ Returns:
+ The `EventMultiplexer`.
+ """
+ logger.info("Starting AddRunsFromDirectory: %s", path)
+ for subdir in io_wrapper.GetLogdirSubdirectories(path):
+ logger.info("Adding events from directory %s", subdir)
+ rpath = os.path.relpath(subdir, path)
+ subname = os.path.join(name, rpath) if name else rpath
+ self.AddRun(subdir, name=subname)
+ logger.info("Done with AddRunsFromDirectory: %s", path)
+ return self
+
+ def Reload(self):
+ """Call `Reload` on every `EventAccumulator`."""
+ logger.info("Beginning EventMultiplexer.Reload()")
+ self._reload_called = True
+ # Build a list so we're safe even if the list of accumulators is modified
+ # even while we're reloading.
+ with self._accumulators_mutex:
+ items = list(self._accumulators.items())
+
+ names_to_delete = set()
+ for name, accumulator in items:
+ try:
+ accumulator.Reload()
+ except (OSError, IOError) as e:
+ logger.error("Unable to reload accumulator '%s': %s", name, e)
+ except directory_watcher.DirectoryDeletedError:
+ names_to_delete.add(name)
+
+ with self._accumulators_mutex:
+ for name in names_to_delete:
+ logger.warn("Deleting accumulator '%s'", name)
+ del self._accumulators[name]
+ logger.info("Finished with EventMultiplexer.Reload()")
+ return self
+
+ def PluginAssets(self, plugin_name):
+ """Get index of runs and assets for a given plugin.
+
+ Args:
+ plugin_name: Name of the plugin we are checking for.
+
+ Returns:
+ A dictionary that maps from run_name to a list of plugin
+ assets for that run.
+ """
+ with self._accumulators_mutex:
+ # To avoid nested locks, we construct a copy of the run-accumulator map
+ items = list(six.iteritems(self._accumulators))
+
+ return {run: accum.PluginAssets(plugin_name) for run, accum in items}
+
+ def RetrievePluginAsset(self, run, plugin_name, asset_name):
+ """Return the contents for a specific plugin asset from a run.
+
+ Args:
+ run: The string name of the run.
+ plugin_name: The string name of a plugin.
+ asset_name: The string name of an asset.
+
+ Returns:
+ The string contents of the plugin asset.
+
+ Raises:
+ KeyError: If the asset is not available.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.RetrievePluginAsset(plugin_name, asset_name)
+
+ def FirstEventTimestamp(self, run):
+ """Return the timestamp of the first event of the given run.
+
+ This may perform I/O if no events have been loaded yet for the run.
+
+ Args:
+ run: A string name of the run for which the timestamp is retrieved.
+
+ Returns:
+ The wall_time of the first event of the run, which will typically be
+ seconds since the epoch.
+
+ Raises:
+ KeyError: If the run is not found.
+ ValueError: If the run has no events loaded and there are no events on
+ disk to load.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.FirstEventTimestamp()
+
+ def Scalars(self, run, tag):
+ """Retrieve the scalar events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+
+ Returns:
+ An array of `event_accumulator.ScalarEvents`.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.Scalars(tag)
+
+ def Graph(self, run):
+ """Retrieve the graph associated with the provided run.
+
+ Args:
+ run: A string name of a run to load the graph for.
+
+ Raises:
+ KeyError: If the run is not found.
+ ValueError: If the run does not have an associated graph.
+
+ Returns:
+ The `GraphDef` protobuf data structure.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.Graph()
+
+ def SerializedGraph(self, run):
+ """Retrieve the serialized graph associated with the provided run.
+
+ Args:
+ run: A string name of a run to load the graph for.
+
+ Raises:
+ KeyError: If the run is not found.
+ ValueError: If the run does not have an associated graph.
+
+ Returns:
+ The serialized form of the `GraphDef` protobuf data structure.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.SerializedGraph()
+
+ def MetaGraph(self, run):
+ """Retrieve the metagraph associated with the provided run.
+
+ Args:
+ run: A string name of a run to load the graph for.
+
+ Raises:
+ KeyError: If the run is not found.
+ ValueError: If the run does not have an associated graph.
+
+ Returns:
+ The `MetaGraphDef` protobuf data structure.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.MetaGraph()
+
+ def RunMetadata(self, run, tag):
+ """Get the session.run() metadata associated with a TensorFlow run and
+ tag.
+
+ Args:
+ run: A string name of a TensorFlow run.
+ tag: A string name of the tag associated with a particular session.run().
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for the
+ given run.
+
+ Returns:
+ The metadata in the form of `RunMetadata` protobuf data structure.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.RunMetadata(tag)
+
+ def Histograms(self, run, tag):
+ """Retrieve the histogram events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+
+ Returns:
+ An array of `event_accumulator.HistogramEvents`.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.Histograms(tag)
+
+ def CompressedHistograms(self, run, tag):
+ """Retrieve the compressed histogram events associated with a run and
+ tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+
+ Returns:
+ An array of `event_accumulator.CompressedHistogramEvents`.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.CompressedHistograms(tag)
+
+ def Images(self, run, tag):
+ """Retrieve the image events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+
+ Returns:
+ An array of `event_accumulator.ImageEvents`.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.Images(tag)
+
+ def Audio(self, run, tag):
+ """Retrieve the audio events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+
+ Returns:
+ An array of `event_accumulator.AudioEvents`.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.Audio(tag)
+
+ def Tensors(self, run, tag):
+ """Retrieve the tensor events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+
+ Returns:
+ An array of `event_accumulator.TensorEvent`s.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.Tensors(tag)
+
+ def PluginRunToTagToContent(self, plugin_name):
+ """Returns a 2-layer dictionary of the form {run: {tag: content}}.
+
+        The `content` referred to above is the content field of the PluginData proto
+ for the specified plugin within a Summary.Value proto.
+
+ Args:
+ plugin_name: The name of the plugin for which to fetch content.
+
+ Returns:
+ A dictionary of the form {run: {tag: content}}.
+ """
+ mapping = {}
+ for run in self.Runs():
+ try:
+ tag_to_content = self.GetAccumulator(run).PluginTagToContent(
+ plugin_name
+ )
+ except KeyError:
+ # This run lacks content for the plugin. Try the next run.
+ continue
+ mapping[run] = tag_to_content
+ return mapping
+
+ def SummaryMetadata(self, run, tag):
+ """Return the summary metadata for the given tag on the given run.
+
+ Args:
+ run: A string name of the run for which summary metadata is to be
+ retrieved.
+ tag: A string name of the tag whose summary metadata is to be
+ retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+
+ Returns:
+ A `SummaryMetadata` protobuf.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.SummaryMetadata(tag)
+
+ def Runs(self):
+ """Return all the run names in the `EventMultiplexer`.
+
+ Returns:
+ ```
+ {runName: { images: [tag1, tag2, tag3],
+ scalarValues: [tagA, tagB, tagC],
+ histograms: [tagX, tagY, tagZ],
+ compressedHistograms: [tagX, tagY, tagZ],
+ graph: true, meta_graph: true}}
+ ```
+ """
+ with self._accumulators_mutex:
+ # To avoid nested locks, we construct a copy of the run-accumulator map
+ items = list(six.iteritems(self._accumulators))
+ return {run_name: accumulator.Tags() for run_name, accumulator in items}
+
+ def RunPaths(self):
+ """Returns a dict mapping run names to event file paths."""
+ return self._paths
+
+ def GetAccumulator(self, run):
+ """Returns EventAccumulator for a given run.
+
+ Args:
+ run: String name of run.
+
+ Returns:
+ An EventAccumulator object.
+
+ Raises:
+ KeyError: If run does not exist.
+ """
+ with self._accumulators_mutex:
+ return self._accumulators[run]
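
As a usage sketch for the `EventMultiplexer` surface reformatted above (the reformatting is purely cosmetic, so the API is driven the same way as before; the `/tmp/logs` layout and run names below are hypothetical and not taken from this change):

```python
# Hypothetical logdir layout: /tmp/logs/<run_name>/events.out.tfevents.*
from tensorboard.backend.event_processing import event_accumulator
from tensorboard.backend.event_processing import event_multiplexer

multiplexer = event_multiplexer.EventMultiplexer()
# Every subdirectory that directly contains an events file becomes a run.
multiplexer.AddRunsFromDirectory("/tmp/logs")
# Load (or re-load) every accumulator; safe to call periodically.
multiplexer.Reload()

for run, tags in multiplexer.Runs().items():
    for tag in tags[event_accumulator.SCALARS]:
        # Each entry is an event_accumulator.ScalarEvent
        # (wall_time, step, value).
        events = multiplexer.Scalars(run, tag)
        print("%s/%s: %d points" % (run, tag, len(events)))
```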
diff --git a/tensorboard/backend/event_processing/event_multiplexer_test.py b/tensorboard/backend/event_processing/event_multiplexer_test.py
index 6263a7df06..7c24c7eb52 100644
--- a/tensorboard/backend/event_processing/event_multiplexer_test.py
+++ b/tensorboard/backend/event_processing/event_multiplexer_test.py
@@ -28,326 +28,350 @@
def _AddEvents(path):
- if not tf.io.gfile.isdir(path):
- tf.io.gfile.makedirs(path)
- fpath = os.path.join(path, 'hypothetical.tfevents.out')
- with tf.io.gfile.GFile(fpath, 'w') as f:
- f.write('')
- return fpath
+ if not tf.io.gfile.isdir(path):
+ tf.io.gfile.makedirs(path)
+ fpath = os.path.join(path, "hypothetical.tfevents.out")
+ with tf.io.gfile.GFile(fpath, "w") as f:
+ f.write("")
+ return fpath
def _CreateCleanDirectory(path):
- if tf.io.gfile.isdir(path):
- tf.io.gfile.rmtree(path)
- tf.io.gfile.mkdir(path)
+ if tf.io.gfile.isdir(path):
+ tf.io.gfile.rmtree(path)
+ tf.io.gfile.mkdir(path)
class _FakeAccumulator(object):
-
- def __init__(self, path):
- """Constructs a fake accumulator with some fake events.
-
- Args:
- path: The path for the run that this accumulator is for.
- """
- self._path = path
- self.reload_called = False
- self._plugin_to_tag_to_content = {
- 'baz_plugin': {
- 'foo': 'foo_content',
- 'bar': 'bar_content',
+ def __init__(self, path):
+ """Constructs a fake accumulator with some fake events.
+
+ Args:
+ path: The path for the run that this accumulator is for.
+ """
+ self._path = path
+ self.reload_called = False
+ self._plugin_to_tag_to_content = {
+ "baz_plugin": {"foo": "foo_content", "bar": "bar_content",}
}
- }
- def Tags(self):
- return {event_accumulator.IMAGES: ['im1', 'im2'],
- event_accumulator.AUDIO: ['snd1', 'snd2'],
- event_accumulator.HISTOGRAMS: ['hst1', 'hst2'],
- event_accumulator.COMPRESSED_HISTOGRAMS: ['cmphst1', 'cmphst2'],
- event_accumulator.SCALARS: ['sv1', 'sv2']}
+ def Tags(self):
+ return {
+ event_accumulator.IMAGES: ["im1", "im2"],
+ event_accumulator.AUDIO: ["snd1", "snd2"],
+ event_accumulator.HISTOGRAMS: ["hst1", "hst2"],
+ event_accumulator.COMPRESSED_HISTOGRAMS: ["cmphst1", "cmphst2"],
+ event_accumulator.SCALARS: ["sv1", "sv2"],
+ }
- def FirstEventTimestamp(self):
- return 0
+ def FirstEventTimestamp(self):
+ return 0
- def _TagHelper(self, tag_name, enum):
- if tag_name not in self.Tags()[enum]:
- raise KeyError
- return ['%s/%s' % (self._path, tag_name)]
+ def _TagHelper(self, tag_name, enum):
+ if tag_name not in self.Tags()[enum]:
+ raise KeyError
+ return ["%s/%s" % (self._path, tag_name)]
- def Scalars(self, tag_name):
- return self._TagHelper(tag_name, event_accumulator.SCALARS)
+ def Scalars(self, tag_name):
+ return self._TagHelper(tag_name, event_accumulator.SCALARS)
- def Histograms(self, tag_name):
- return self._TagHelper(tag_name, event_accumulator.HISTOGRAMS)
+ def Histograms(self, tag_name):
+ return self._TagHelper(tag_name, event_accumulator.HISTOGRAMS)
- def CompressedHistograms(self, tag_name):
- return self._TagHelper(tag_name, event_accumulator.COMPRESSED_HISTOGRAMS)
+ def CompressedHistograms(self, tag_name):
+ return self._TagHelper(
+ tag_name, event_accumulator.COMPRESSED_HISTOGRAMS
+ )
- def Images(self, tag_name):
- return self._TagHelper(tag_name, event_accumulator.IMAGES)
+ def Images(self, tag_name):
+ return self._TagHelper(tag_name, event_accumulator.IMAGES)
- def Audio(self, tag_name):
- return self._TagHelper(tag_name, event_accumulator.AUDIO)
+ def Audio(self, tag_name):
+ return self._TagHelper(tag_name, event_accumulator.AUDIO)
- def Tensors(self, tag_name):
- return self._TagHelper(tag_name, event_accumulator.TENSORS)
+ def Tensors(self, tag_name):
+ return self._TagHelper(tag_name, event_accumulator.TENSORS)
- def PluginTagToContent(self, plugin_name):
- # We pre-pend the runs with the path and '_' so that we can verify that the
- # tags are associated with the correct runs.
- return {
- self._path + '_' + run: content_mapping
- for (run, content_mapping
- ) in self._plugin_to_tag_to_content[plugin_name].items()
- }
+ def PluginTagToContent(self, plugin_name):
+        # We prepend the runs with the path and '_' so that we can verify that the
+ # tags are associated with the correct runs.
+ return {
+ self._path + "_" + run: content_mapping
+ for (run, content_mapping) in self._plugin_to_tag_to_content[
+ plugin_name
+ ].items()
+ }
- def Reload(self):
- self.reload_called = True
+ def Reload(self):
+ self.reload_called = True
-def _GetFakeAccumulator(path,
- size_guidance=None,
- compression_bps=None,
- purge_orphaned_data=None):
- del size_guidance, compression_bps, purge_orphaned_data # Unused.
- return _FakeAccumulator(path)
+def _GetFakeAccumulator(
+ path, size_guidance=None, compression_bps=None, purge_orphaned_data=None
+):
+ del size_guidance, compression_bps, purge_orphaned_data # Unused.
+ return _FakeAccumulator(path)
class EventMultiplexerTest(tf.test.TestCase):
-
- def setUp(self):
- super(EventMultiplexerTest, self).setUp()
- self.stubs = tf.compat.v1.test.StubOutForTesting()
-
- self.stubs.Set(event_accumulator, 'EventAccumulator', _GetFakeAccumulator)
-
- def tearDown(self):
- self.stubs.CleanUp()
-
- def testEmptyLoader(self):
- """Tests empty EventMultiplexer creation."""
- x = event_multiplexer.EventMultiplexer()
- self.assertEqual(x.Runs(), {})
-
- def testRunNamesRespected(self):
- """Tests two EventAccumulators inserted/accessed in EventMultiplexer."""
- x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
- self.assertItemsEqual(sorted(x.Runs().keys()), ['run1', 'run2'])
- self.assertEqual(x.GetAccumulator('run1')._path, 'path1')
- self.assertEqual(x.GetAccumulator('run2')._path, 'path2')
-
- def testReload(self):
- """EventAccumulators should Reload after EventMultiplexer call it."""
- x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
- self.assertFalse(x.GetAccumulator('run1').reload_called)
- self.assertFalse(x.GetAccumulator('run2').reload_called)
- x.Reload()
- self.assertTrue(x.GetAccumulator('run1').reload_called)
- self.assertTrue(x.GetAccumulator('run2').reload_called)
-
- def testScalars(self):
- """Tests Scalars function returns suitable values."""
- x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
-
- run1_actual = x.Scalars('run1', 'sv1')
- run1_expected = ['path1/sv1']
-
- self.assertEqual(run1_expected, run1_actual)
-
- def testPluginRunToTagToContent(self):
- """Tests the method that produces the run to tag to content mapping."""
- x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
- self.assertDictEqual({
- 'run1': {
- 'path1_foo': 'foo_content',
- 'path1_bar': 'bar_content',
- },
- 'run2': {
- 'path2_foo': 'foo_content',
- 'path2_bar': 'bar_content',
- }
- }, x.PluginRunToTagToContent('baz_plugin'))
-
- def testExceptions(self):
- """KeyError should be raised when accessing non-existing keys."""
- x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
- with self.assertRaises(KeyError):
- x.Scalars('sv1', 'xxx')
-
- def testInitialization(self):
- """Tests EventMultiplexer is created properly with its params."""
- x = event_multiplexer.EventMultiplexer()
- self.assertEqual(x.Runs(), {})
- x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
- self.assertItemsEqual(x.Runs(), ['run1', 'run2'])
- self.assertEqual(x.GetAccumulator('run1')._path, 'path1')
- self.assertEqual(x.GetAccumulator('run2')._path, 'path2')
-
- def testAddRunsFromDirectory(self):
- """Tests AddRunsFromDirectory function.
-
- Tests the following scenarios:
- - When the directory does not exist.
- - When the directory is empty.
- - When the directory has empty subdirectory.
- - Contains proper EventAccumulators after adding events.
- """
- x = event_multiplexer.EventMultiplexer()
- tmpdir = self.get_temp_dir()
- join = os.path.join
- fakedir = join(tmpdir, 'fake_accumulator_directory')
- realdir = join(tmpdir, 'real_accumulator_directory')
- self.assertEqual(x.Runs(), {})
- x.AddRunsFromDirectory(fakedir)
- self.assertEqual(x.Runs(), {}, 'loading fakedir had no effect')
-
- _CreateCleanDirectory(realdir)
- x.AddRunsFromDirectory(realdir)
- self.assertEqual(x.Runs(), {}, 'loading empty directory had no effect')
-
- path1 = join(realdir, 'path1')
- tf.io.gfile.mkdir(path1)
- x.AddRunsFromDirectory(realdir)
- self.assertEqual(x.Runs(), {}, 'creating empty subdirectory had no effect')
-
- _AddEvents(path1)
- x.AddRunsFromDirectory(realdir)
- self.assertItemsEqual(x.Runs(), ['path1'], 'loaded run: path1')
- loader1 = x.GetAccumulator('path1')
- self.assertEqual(loader1._path, path1, 'has the correct path')
-
- path2 = join(realdir, 'path2')
- _AddEvents(path2)
- x.AddRunsFromDirectory(realdir)
- self.assertItemsEqual(x.Runs(), ['path1', 'path2'])
- self.assertEqual(
- x.GetAccumulator('path1'), loader1, 'loader1 not regenerated')
-
- path2_2 = join(path2, 'path2')
- _AddEvents(path2_2)
- x.AddRunsFromDirectory(realdir)
- self.assertItemsEqual(x.Runs(), ['path1', 'path2', 'path2/path2'])
- self.assertEqual(
- x.GetAccumulator('path2/path2')._path, path2_2, 'loader2 path correct')
-
- def testAddRunsFromDirectoryThatContainsEvents(self):
- x = event_multiplexer.EventMultiplexer()
- tmpdir = self.get_temp_dir()
- join = os.path.join
- realdir = join(tmpdir, 'event_containing_directory')
-
- _CreateCleanDirectory(realdir)
-
- self.assertEqual(x.Runs(), {})
-
- _AddEvents(realdir)
- x.AddRunsFromDirectory(realdir)
- self.assertItemsEqual(x.Runs(), ['.'])
-
- subdir = join(realdir, 'subdir')
- _AddEvents(subdir)
- x.AddRunsFromDirectory(realdir)
- self.assertItemsEqual(x.Runs(), ['.', 'subdir'])
-
- def testAddRunsFromDirectoryWithRunNames(self):
- x = event_multiplexer.EventMultiplexer()
- tmpdir = self.get_temp_dir()
- join = os.path.join
- realdir = join(tmpdir, 'event_containing_directory')
-
- _CreateCleanDirectory(realdir)
-
- self.assertEqual(x.Runs(), {})
-
- _AddEvents(realdir)
- x.AddRunsFromDirectory(realdir, 'foo')
- self.assertItemsEqual(x.Runs(), ['foo/.'])
-
- subdir = join(realdir, 'subdir')
- _AddEvents(subdir)
- x.AddRunsFromDirectory(realdir, 'foo')
- self.assertItemsEqual(x.Runs(), ['foo/.', 'foo/subdir'])
-
- def testAddRunsFromDirectoryWalksTree(self):
- x = event_multiplexer.EventMultiplexer()
- tmpdir = self.get_temp_dir()
- join = os.path.join
- realdir = join(tmpdir, 'event_containing_directory')
-
- _CreateCleanDirectory(realdir)
- _AddEvents(realdir)
- sub = join(realdir, 'subdirectory')
- sub1 = join(sub, '1')
- sub2 = join(sub, '2')
- sub1_1 = join(sub1, '1')
- _AddEvents(sub1)
- _AddEvents(sub2)
- _AddEvents(sub1_1)
- x.AddRunsFromDirectory(realdir)
-
- self.assertItemsEqual(x.Runs(), ['.', 'subdirectory/1', 'subdirectory/2',
- 'subdirectory/1/1'])
-
- def testAddRunsFromDirectoryThrowsException(self):
- x = event_multiplexer.EventMultiplexer()
- tmpdir = self.get_temp_dir()
-
- filepath = _AddEvents(tmpdir)
- with self.assertRaises(ValueError):
- x.AddRunsFromDirectory(filepath)
-
- def testAddRun(self):
- x = event_multiplexer.EventMultiplexer()
- x.AddRun('run1_path', 'run1')
- run1 = x.GetAccumulator('run1')
- self.assertEqual(sorted(x.Runs().keys()), ['run1'])
- self.assertEqual(run1._path, 'run1_path')
-
- x.AddRun('run1_path', 'run1')
- self.assertEqual(run1, x.GetAccumulator('run1'), 'loader not recreated')
-
- x.AddRun('run2_path', 'run1')
- new_run1 = x.GetAccumulator('run1')
- self.assertEqual(new_run1._path, 'run2_path')
- self.assertNotEqual(run1, new_run1)
-
- x.AddRun('runName3')
- self.assertItemsEqual(sorted(x.Runs().keys()), ['run1', 'runName3'])
- self.assertEqual(x.GetAccumulator('runName3')._path, 'runName3')
-
- def testAddRunMaintainsLoading(self):
- x = event_multiplexer.EventMultiplexer()
- x.Reload()
- x.AddRun('run1')
- x.AddRun('run2')
- self.assertTrue(x.GetAccumulator('run1').reload_called)
- self.assertTrue(x.GetAccumulator('run2').reload_called)
+ def setUp(self):
+ super(EventMultiplexerTest, self).setUp()
+ self.stubs = tf.compat.v1.test.StubOutForTesting()
+
+ self.stubs.Set(
+ event_accumulator, "EventAccumulator", _GetFakeAccumulator
+ )
+
+ def tearDown(self):
+ self.stubs.CleanUp()
+
+ def testEmptyLoader(self):
+ """Tests empty EventMultiplexer creation."""
+ x = event_multiplexer.EventMultiplexer()
+ self.assertEqual(x.Runs(), {})
+
+ def testRunNamesRespected(self):
+ """Tests two EventAccumulators inserted/accessed in
+ EventMultiplexer."""
+ x = event_multiplexer.EventMultiplexer(
+ {"run1": "path1", "run2": "path2"}
+ )
+ self.assertItemsEqual(sorted(x.Runs().keys()), ["run1", "run2"])
+ self.assertEqual(x.GetAccumulator("run1")._path, "path1")
+ self.assertEqual(x.GetAccumulator("run2")._path, "path2")
+
+ def testReload(self):
+        """EventAccumulators should Reload after EventMultiplexer calls it."""
+ x = event_multiplexer.EventMultiplexer(
+ {"run1": "path1", "run2": "path2"}
+ )
+ self.assertFalse(x.GetAccumulator("run1").reload_called)
+ self.assertFalse(x.GetAccumulator("run2").reload_called)
+ x.Reload()
+ self.assertTrue(x.GetAccumulator("run1").reload_called)
+ self.assertTrue(x.GetAccumulator("run2").reload_called)
+
+ def testScalars(self):
+ """Tests Scalars function returns suitable values."""
+ x = event_multiplexer.EventMultiplexer(
+ {"run1": "path1", "run2": "path2"}
+ )
+
+ run1_actual = x.Scalars("run1", "sv1")
+ run1_expected = ["path1/sv1"]
+
+ self.assertEqual(run1_expected, run1_actual)
+
+ def testPluginRunToTagToContent(self):
+ """Tests the method that produces the run to tag to content mapping."""
+ x = event_multiplexer.EventMultiplexer(
+ {"run1": "path1", "run2": "path2"}
+ )
+ self.assertDictEqual(
+ {
+ "run1": {
+ "path1_foo": "foo_content",
+ "path1_bar": "bar_content",
+ },
+ "run2": {
+ "path2_foo": "foo_content",
+ "path2_bar": "bar_content",
+ },
+ },
+ x.PluginRunToTagToContent("baz_plugin"),
+ )
+
+ def testExceptions(self):
+ """KeyError should be raised when accessing non-existing keys."""
+ x = event_multiplexer.EventMultiplexer(
+ {"run1": "path1", "run2": "path2"}
+ )
+ with self.assertRaises(KeyError):
+ x.Scalars("sv1", "xxx")
+
+ def testInitialization(self):
+ """Tests EventMultiplexer is created properly with its params."""
+ x = event_multiplexer.EventMultiplexer()
+ self.assertEqual(x.Runs(), {})
+ x = event_multiplexer.EventMultiplexer(
+ {"run1": "path1", "run2": "path2"}
+ )
+ self.assertItemsEqual(x.Runs(), ["run1", "run2"])
+ self.assertEqual(x.GetAccumulator("run1")._path, "path1")
+ self.assertEqual(x.GetAccumulator("run2")._path, "path2")
+
+ def testAddRunsFromDirectory(self):
+ """Tests AddRunsFromDirectory function.
+
+ Tests the following scenarios:
+ - When the directory does not exist.
+ - When the directory is empty.
+ - When the directory has empty subdirectory.
+ - Contains proper EventAccumulators after adding events.
+ """
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ fakedir = join(tmpdir, "fake_accumulator_directory")
+ realdir = join(tmpdir, "real_accumulator_directory")
+ self.assertEqual(x.Runs(), {})
+ x.AddRunsFromDirectory(fakedir)
+ self.assertEqual(x.Runs(), {}, "loading fakedir had no effect")
+
+ _CreateCleanDirectory(realdir)
+ x.AddRunsFromDirectory(realdir)
+ self.assertEqual(x.Runs(), {}, "loading empty directory had no effect")
+
+ path1 = join(realdir, "path1")
+ tf.io.gfile.mkdir(path1)
+ x.AddRunsFromDirectory(realdir)
+ self.assertEqual(
+ x.Runs(), {}, "creating empty subdirectory had no effect"
+ )
+
+ _AddEvents(path1)
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs(), ["path1"], "loaded run: path1")
+ loader1 = x.GetAccumulator("path1")
+ self.assertEqual(loader1._path, path1, "has the correct path")
+
+ path2 = join(realdir, "path2")
+ _AddEvents(path2)
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs(), ["path1", "path2"])
+ self.assertEqual(
+ x.GetAccumulator("path1"), loader1, "loader1 not regenerated"
+ )
+
+ path2_2 = join(path2, "path2")
+ _AddEvents(path2_2)
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs(), ["path1", "path2", "path2/path2"])
+ self.assertEqual(
+ x.GetAccumulator("path2/path2")._path,
+ path2_2,
+ "loader2 path correct",
+ )
+
+ def testAddRunsFromDirectoryThatContainsEvents(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ realdir = join(tmpdir, "event_containing_directory")
+
+ _CreateCleanDirectory(realdir)
+
+ self.assertEqual(x.Runs(), {})
+
+ _AddEvents(realdir)
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs(), ["."])
+
+ subdir = join(realdir, "subdir")
+ _AddEvents(subdir)
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs(), [".", "subdir"])
+
+ def testAddRunsFromDirectoryWithRunNames(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ realdir = join(tmpdir, "event_containing_directory")
+
+ _CreateCleanDirectory(realdir)
+
+ self.assertEqual(x.Runs(), {})
+
+ _AddEvents(realdir)
+ x.AddRunsFromDirectory(realdir, "foo")
+ self.assertItemsEqual(x.Runs(), ["foo/."])
+
+ subdir = join(realdir, "subdir")
+ _AddEvents(subdir)
+ x.AddRunsFromDirectory(realdir, "foo")
+ self.assertItemsEqual(x.Runs(), ["foo/.", "foo/subdir"])
+
+ def testAddRunsFromDirectoryWalksTree(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ realdir = join(tmpdir, "event_containing_directory")
+
+ _CreateCleanDirectory(realdir)
+ _AddEvents(realdir)
+ sub = join(realdir, "subdirectory")
+ sub1 = join(sub, "1")
+ sub2 = join(sub, "2")
+ sub1_1 = join(sub1, "1")
+ _AddEvents(sub1)
+ _AddEvents(sub2)
+ _AddEvents(sub1_1)
+ x.AddRunsFromDirectory(realdir)
+
+ self.assertItemsEqual(
+ x.Runs(),
+ [".", "subdirectory/1", "subdirectory/2", "subdirectory/1/1"],
+ )
+
+ def testAddRunsFromDirectoryThrowsException(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+
+ filepath = _AddEvents(tmpdir)
+ with self.assertRaises(ValueError):
+ x.AddRunsFromDirectory(filepath)
+
+ def testAddRun(self):
+ x = event_multiplexer.EventMultiplexer()
+ x.AddRun("run1_path", "run1")
+ run1 = x.GetAccumulator("run1")
+ self.assertEqual(sorted(x.Runs().keys()), ["run1"])
+ self.assertEqual(run1._path, "run1_path")
+
+ x.AddRun("run1_path", "run1")
+ self.assertEqual(run1, x.GetAccumulator("run1"), "loader not recreated")
+
+ x.AddRun("run2_path", "run1")
+ new_run1 = x.GetAccumulator("run1")
+ self.assertEqual(new_run1._path, "run2_path")
+ self.assertNotEqual(run1, new_run1)
+
+ x.AddRun("runName3")
+ self.assertItemsEqual(sorted(x.Runs().keys()), ["run1", "runName3"])
+ self.assertEqual(x.GetAccumulator("runName3")._path, "runName3")
+
+ def testAddRunMaintainsLoading(self):
+ x = event_multiplexer.EventMultiplexer()
+ x.Reload()
+ x.AddRun("run1")
+ x.AddRun("run2")
+ self.assertTrue(x.GetAccumulator("run1").reload_called)
+ self.assertTrue(x.GetAccumulator("run2").reload_called)
class EventMultiplexerWithRealAccumulatorTest(tf.test.TestCase):
+ def testDeletingDirectoryRemovesRun(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ run1_dir = join(tmpdir, "run1")
+ run2_dir = join(tmpdir, "run2")
+ run3_dir = join(tmpdir, "run3")
- def testDeletingDirectoryRemovesRun(self):
- x = event_multiplexer.EventMultiplexer()
- tmpdir = self.get_temp_dir()
- join = os.path.join
- run1_dir = join(tmpdir, 'run1')
- run2_dir = join(tmpdir, 'run2')
- run3_dir = join(tmpdir, 'run3')
-
- for dirname in [run1_dir, run2_dir, run3_dir]:
- _AddEvents(dirname)
+ for dirname in [run1_dir, run2_dir, run3_dir]:
+ _AddEvents(dirname)
- x.AddRun(run1_dir, 'run1')
- x.AddRun(run2_dir, 'run2')
- x.AddRun(run3_dir, 'run3')
+ x.AddRun(run1_dir, "run1")
+ x.AddRun(run2_dir, "run2")
+ x.AddRun(run3_dir, "run3")
- x.Reload()
+ x.Reload()
- # Delete the directory, then reload.
- shutil.rmtree(run2_dir)
- x.Reload()
- self.assertNotIn('run2', x.Runs().keys())
+ # Delete the directory, then reload.
+ shutil.rmtree(run2_dir)
+ x.Reload()
+ self.assertNotIn("run2", x.Runs().keys())
-if __name__ == '__main__':
- tf.test.main()
+if __name__ == "__main__":
+ tf.test.main()
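
The tests above swap the real `EventAccumulator` for `_FakeAccumulator` via `tf.compat.v1.test.StubOutForTesting`. The same substitution pattern can be sketched with the standard library's `unittest.mock`; the `_TinyFakeAccumulator` below is a trimmed-down, hypothetical stand-in rather than the fake defined in the test file.

```python
# Sketch of the stubbing pattern used by EventMultiplexerTest, expressed with
# unittest.mock instead of StubOutForTesting (Python 3 only).
import unittest.mock as mock

from tensorboard.backend.event_processing import event_accumulator
from tensorboard.backend.event_processing import event_multiplexer


class _TinyFakeAccumulator(object):
    """Minimal fake: records its path and whether Reload() was called."""

    def __init__(self, path, **unused_kwargs):
        self._path = path
        self.reload_called = False

    def Reload(self):
        self.reload_called = True


# The multiplexer constructs accumulators inside AddRun, so the patch only
# needs to be active while the constructor (which calls AddRun) runs.
with mock.patch.object(
    event_accumulator, "EventAccumulator", _TinyFakeAccumulator
):
    multiplexer = event_multiplexer.EventMultiplexer(
        {"run1": "path1", "run2": "path2"}
    )

multiplexer.Reload()
assert multiplexer.GetAccumulator("run1").reload_called
```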
diff --git a/tensorboard/backend/event_processing/io_wrapper.py b/tensorboard/backend/event_processing/io_wrapper.py
index 7c78ff2bdc..eebd8c79db 100644
--- a/tensorboard/backend/event_processing/io_wrapper.py
+++ b/tensorboard/backend/event_processing/io_wrapper.py
@@ -29,175 +29,188 @@
logger = tb_logging.get_logger()
-_ESCAPE_GLOB_CHARACTERS_REGEX = re.compile('([*?[])')
+_ESCAPE_GLOB_CHARACTERS_REGEX = re.compile("([*?[])")
def IsCloudPath(path):
- return (
- path.startswith("gs://") or
- path.startswith("s3://") or
- path.startswith("/cns/")
- )
+ return (
+ path.startswith("gs://")
+ or path.startswith("s3://")
+ or path.startswith("/cns/")
+ )
+
def PathSeparator(path):
- return '/' if IsCloudPath(path) else os.sep
+ return "/" if IsCloudPath(path) else os.sep
+
def IsTensorFlowEventsFile(path):
- """Check the path name to see if it is probably a TF Events file.
+ """Check the path name to see if it is probably a TF Events file.
- Args:
- path: A file path to check if it is an event file.
+ Args:
+ path: A file path to check if it is an event file.
- Raises:
- ValueError: If the path is an empty string.
+ Raises:
+ ValueError: If the path is an empty string.
- Returns:
- If path is formatted like a TensorFlowEventsFile.
- """
- if not path:
- raise ValueError('Path must be a nonempty string')
- return 'tfevents' in tf.compat.as_str_any(os.path.basename(path))
+ Returns:
+        Whether the path is formatted like a TensorFlow events file.
+ """
+ if not path:
+ raise ValueError("Path must be a nonempty string")
+ return "tfevents" in tf.compat.as_str_any(os.path.basename(path))
def ListDirectoryAbsolute(directory):
- """Yields all files in the given directory. The paths are absolute."""
- return (os.path.join(directory, path)
- for path in tf.io.gfile.listdir(directory))
+ """Yields all files in the given directory.
+
+ The paths are absolute.
+ """
+ return (
+ os.path.join(directory, path) for path in tf.io.gfile.listdir(directory)
+ )
def _EscapeGlobCharacters(path):
- """Escapes the glob characters in a path.
+ """Escapes the glob characters in a path.
- Python 3 has a glob.escape method, but python 2 lacks it, so we manually
- implement this method.
+    Python 3 has a glob.escape method, but Python 2 lacks it, so we manually
+ implement this method.
- Args:
- path: The absolute path to escape.
+ Args:
+ path: The absolute path to escape.
- Returns:
- The escaped path string.
- """
- drive, path = os.path.splitdrive(path)
- return '%s%s' % (drive, _ESCAPE_GLOB_CHARACTERS_REGEX.sub(r'[\1]', path))
+ Returns:
+ The escaped path string.
+ """
+ drive, path = os.path.splitdrive(path)
+ return "%s%s" % (drive, _ESCAPE_GLOB_CHARACTERS_REGEX.sub(r"[\1]", path))
def ListRecursivelyViaGlobbing(top):
- """Recursively lists all files within the directory.
-
- This method does not list subdirectories (in addition to regular files), and
- the file paths are all absolute. If the directory does not exist, this yields
- nothing.
-
- This method does so by glob-ing deeper and deeper directories, ie
- foo/*, foo/*/*, foo/*/*/* and so on until all files are listed. All file
- paths are absolute, and this method lists subdirectories too.
-
- For certain file systems, globbing via this method may prove significantly
- faster than recursively walking a directory. Specifically, TF file systems
- that implement TensorFlow's FileSystem.GetMatchingPaths method could save
- costly disk reads by using this method. However, for other file systems, this
- method might prove slower because the file system performs a walk per call to
- glob (in which case it might as well just perform 1 walk).
-
- Args:
- top: A path to a directory.
-
- Yields:
- A (dir_path, file_paths) tuple for each directory/subdirectory.
- """
- current_glob_string = os.path.join(_EscapeGlobCharacters(top), '*')
- level = 0
-
- while True:
- logger.info('GlobAndListFiles: Starting to glob level %d', level)
- glob = tf.io.gfile.glob(current_glob_string)
- logger.info(
- 'GlobAndListFiles: %d files glob-ed at level %d', len(glob), level)
-
- if not glob:
- # This subdirectory level lacks files. Terminate.
- return
-
- # Map subdirectory to a list of files.
- pairs = collections.defaultdict(list)
- for file_path in glob:
- pairs[os.path.dirname(file_path)].append(file_path)
- for dir_name, file_paths in six.iteritems(pairs):
- yield (dir_name, tuple(file_paths))
-
- if len(pairs) == 1:
- # If at any point the glob returns files that are all in a single
- # directory, replace the current globbing path with that directory as the
- # literal prefix. This should improve efficiency in cases where a single
- # subdir is significantly deeper than the rest of the sudirs.
- current_glob_string = os.path.join(list(pairs.keys())[0], '*')
-
- # Iterate to the next level of subdirectories.
- current_glob_string = os.path.join(current_glob_string, '*')
- level += 1
+ """Recursively lists all files within the directory.
+
+    Unlike `ListRecursivelyViaWalking`, this lists subdirectories in addition
+    to regular files, and the file paths are all absolute. If the directory
+    does not exist, this yields nothing.
+
+    This method works by glob-ing deeper and deeper directories, i.e.
+    foo/*, foo/*/*, foo/*/*/* and so on, until all files are listed.
+
+ For certain file systems, globbing via this method may prove significantly
+ faster than recursively walking a directory. Specifically, TF file systems
+ that implement TensorFlow's FileSystem.GetMatchingPaths method could save
+ costly disk reads by using this method. However, for other file systems, this
+ method might prove slower because the file system performs a walk per call to
+ glob (in which case it might as well just perform 1 walk).
+
+ Args:
+ top: A path to a directory.
+
+ Yields:
+ A (dir_path, file_paths) tuple for each directory/subdirectory.
+ """
+ current_glob_string = os.path.join(_EscapeGlobCharacters(top), "*")
+ level = 0
+
+ while True:
+ logger.info("GlobAndListFiles: Starting to glob level %d", level)
+ glob = tf.io.gfile.glob(current_glob_string)
+ logger.info(
+ "GlobAndListFiles: %d files glob-ed at level %d", len(glob), level
+ )
+
+ if not glob:
+ # This subdirectory level lacks files. Terminate.
+ return
+
+ # Map subdirectory to a list of files.
+ pairs = collections.defaultdict(list)
+ for file_path in glob:
+ pairs[os.path.dirname(file_path)].append(file_path)
+ for dir_name, file_paths in six.iteritems(pairs):
+ yield (dir_name, tuple(file_paths))
+
+ if len(pairs) == 1:
+ # If at any point the glob returns files that are all in a single
+ # directory, replace the current globbing path with that directory as the
+ # literal prefix. This should improve efficiency in cases where a single
+            # subdir is significantly deeper than the rest of the subdirs.
+ current_glob_string = os.path.join(list(pairs.keys())[0], "*")
+
+ # Iterate to the next level of subdirectories.
+ current_glob_string = os.path.join(current_glob_string, "*")
+ level += 1
def ListRecursivelyViaWalking(top):
- """Walks a directory tree, yielding (dir_path, file_paths) tuples.
+ """Walks a directory tree, yielding (dir_path, file_paths) tuples.
- For each of `top` and its subdirectories, yields a tuple containing the path
- to the directory and the path to each of the contained files. Note that
- unlike os.Walk()/tf.io.gfile.walk()/ListRecursivelyViaGlobbing, this does not
- list subdirectories. The file paths are all absolute. If the directory does
- not exist, this yields nothing.
+ For each of `top` and its subdirectories, yields a tuple containing the path
+ to the directory and the path to each of the contained files. Note that
+ unlike os.Walk()/tf.io.gfile.walk()/ListRecursivelyViaGlobbing, this does not
+ list subdirectories. The file paths are all absolute. If the directory does
+ not exist, this yields nothing.
- Walking may be incredibly slow on certain file systems.
+ Walking may be incredibly slow on certain file systems.
- Args:
- top: A path to a directory.
+ Args:
+ top: A path to a directory.
- Yields:
- A (dir_path, file_paths) tuple for each directory/subdirectory.
- """
- for dir_path, _, filenames in tf.io.gfile.walk(top, topdown=True):
- yield (dir_path, (os.path.join(dir_path, filename)
- for filename in filenames))
+ Yields:
+ A (dir_path, file_paths) tuple for each directory/subdirectory.
+ """
+ for dir_path, _, filenames in tf.io.gfile.walk(top, topdown=True):
+ yield (
+ dir_path,
+ (os.path.join(dir_path, filename) for filename in filenames),
+ )
def GetLogdirSubdirectories(path):
- """Obtains all subdirectories with events files.
-
- The order of the subdirectories returned is unspecified. The internal logic
- that determines order varies by scenario.
-
- Args:
- path: The path to a directory under which to find subdirectories.
-
- Returns:
- A tuple of absolute paths of all subdirectories each with at least 1 events
- file directly within the subdirectory.
-
- Raises:
- ValueError: If the path passed to the method exists and is not a directory.
- """
- if not tf.io.gfile.exists(path):
- # No directory to traverse.
- return ()
-
- if not tf.io.gfile.isdir(path):
- raise ValueError('GetLogdirSubdirectories: path exists and is not a '
- 'directory, %s' % path)
-
- if IsCloudPath(path):
- # Glob-ing for files can be significantly faster than recursively
- # walking through directories for some file systems.
- logger.info(
- 'GetLogdirSubdirectories: Starting to list directories via glob-ing.')
- traversal_method = ListRecursivelyViaGlobbing
- else:
- # For other file systems, the glob-ing based method might be slower because
- # each call to glob could involve performing a recursive walk.
- logger.info(
- 'GetLogdirSubdirectories: Starting to list directories via walking.')
- traversal_method = ListRecursivelyViaWalking
-
- return (
- subdir
- for (subdir, files) in traversal_method(path)
- if any(IsTensorFlowEventsFile(f) for f in files)
- )
+ """Obtains all subdirectories with events files.
+
+ The order of the subdirectories returned is unspecified. The internal logic
+ that determines order varies by scenario.
+
+ Args:
+ path: The path to a directory under which to find subdirectories.
+
+ Returns:
+ A tuple of absolute paths of all subdirectories each with at least 1 events
+ file directly within the subdirectory.
+
+ Raises:
+ ValueError: If the path passed to the method exists and is not a directory.
+ """
+ if not tf.io.gfile.exists(path):
+ # No directory to traverse.
+ return ()
+
+ if not tf.io.gfile.isdir(path):
+ raise ValueError(
+ "GetLogdirSubdirectories: path exists and is not a "
+ "directory, %s" % path
+ )
+
+ if IsCloudPath(path):
+ # Glob-ing for files can be significantly faster than recursively
+ # walking through directories for some file systems.
+ logger.info(
+ "GetLogdirSubdirectories: Starting to list directories via glob-ing."
+ )
+ traversal_method = ListRecursivelyViaGlobbing
+ else:
+ # For other file systems, the glob-ing based method might be slower because
+ # each call to glob could involve performing a recursive walk.
+ logger.info(
+ "GetLogdirSubdirectories: Starting to list directories via walking."
+ )
+ traversal_method = ListRecursivelyViaWalking
+
+ return (
+ subdir
+ for (subdir, files) in traversal_method(path)
+ if any(IsTensorFlowEventsFile(f) for f in files)
+ )
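
To ground the traversal helpers reformatted above, here is a short consumption sketch; the `/tmp/logs` paths are hypothetical.

```python
# Hypothetical use of the io_wrapper helpers. GetLogdirSubdirectories yields
# every directory (possibly including the logdir itself) that directly
# contains at least one file whose basename mentions "tfevents"; cloud paths
# (gs://, s3://, /cns/) are traversed by globbing, local paths by walking.
from tensorboard.backend.event_processing import io_wrapper

logdir = "/tmp/logs"

for run_dir in io_wrapper.GetLogdirSubdirectories(logdir):
    print("found run directory: %s" % run_dir)

# The per-file predicate the traversal relies on:
assert io_wrapper.IsTensorFlowEventsFile(
    "/tmp/logs/run1/events.out.tfevents.1001"
)
assert not io_wrapper.IsTensorFlowEventsFile("/tmp/logs/run1/model.ckpt")
```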
diff --git a/tensorboard/backend/event_processing/io_wrapper_test.py b/tensorboard/backend/event_processing/io_wrapper_test.py
index 3642275a7f..ab2aed8097 100644
--- a/tensorboard/backend/event_processing/io_wrapper_test.py
+++ b/tensorboard/backend/event_processing/io_wrapper_test.py
@@ -27,303 +27,268 @@
class IoWrapperTest(tf.test.TestCase):
- def setUp(self):
- self.stubs = tf.compat.v1.test.StubOutForTesting()
-
- def tearDown(self):
- self.stubs.CleanUp()
-
- def testIsCloudPathGcsIsTrue(self):
- self.assertTrue(io_wrapper.IsCloudPath('gs://bucket/foo'))
-
- def testIsCloudPathS3IsTrue(self):
- self.assertTrue(io_wrapper.IsCloudPath('s3://bucket/foo'))
-
- def testIsCloudPathCnsIsTrue(self):
- self.assertTrue(io_wrapper.IsCloudPath('/cns/foo/bar'))
-
- def testIsCloudPathFileIsFalse(self):
- self.assertFalse(io_wrapper.IsCloudPath('file:///tmp/foo'))
-
- def testIsCloudPathLocalIsFalse(self):
- self.assertFalse(io_wrapper.IsCloudPath('/tmp/foo'))
-
- def testPathSeparator(self):
- # In nix systems, path separator would be the same as that of CNS/GCS
- # making it hard to tell if something went wrong.
- self.stubs.Set(os, 'sep', '#')
-
- self.assertEqual(io_wrapper.PathSeparator('/tmp/foo'), '#')
- self.assertEqual(io_wrapper.PathSeparator('tmp/foo'), '#')
- self.assertEqual(io_wrapper.PathSeparator('/cns/tmp/foo'), '/')
- self.assertEqual(io_wrapper.PathSeparator('gs://foo'), '/')
-
- def testIsIsTensorFlowEventsFileTrue(self):
- self.assertTrue(
- io_wrapper.IsTensorFlowEventsFile(
- '/logdir/events.out.tfevents.1473720042.com'))
-
- def testIsIsTensorFlowEventsFileFalse(self):
- self.assertFalse(
- io_wrapper.IsTensorFlowEventsFile('/logdir/model.ckpt'))
-
- def testIsIsTensorFlowEventsFileWithEmptyInput(self):
- with six.assertRaisesRegex(self,
- ValueError,
- r'Path must be a nonempty string'):
- io_wrapper.IsTensorFlowEventsFile('')
-
- def testListDirectoryAbsolute(self):
- temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
- self._CreateDeepDirectoryStructure(temp_dir)
-
- expected_files = (
- 'foo',
- 'bar',
- 'quuz',
- 'a.tfevents.1',
- 'model.ckpt',
- 'waldo',
- )
- self.assertItemsEqual(
- (os.path.join(temp_dir, f) for f in expected_files),
- io_wrapper.ListDirectoryAbsolute(temp_dir))
-
- def testListRecursivelyViaGlobbing(self):
- temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
- self._CreateDeepDirectoryStructure(temp_dir)
- expected = [
- ['', [
- 'foo',
- 'bar',
- 'a.tfevents.1',
- 'model.ckpt',
- 'quuz',
- 'waldo',
- ]],
- ['bar', [
- 'b.tfevents.1',
- 'red_herring.txt',
- 'baz',
- 'quux',
- ]],
- ['bar/baz', [
- 'c.tfevents.1',
- 'd.tfevents.1',
- ]],
- ['bar/quux', [
- 'some_flume_output.txt',
- 'some_more_flume_output.txt',
- ]],
- ['quuz', [
- 'e.tfevents.1',
- 'garply',
- ]],
- ['quuz/garply', [
- 'f.tfevents.1',
- 'corge',
- 'grault',
- ]],
- ['quuz/garply/corge', [
- 'g.tfevents.1'
- ]],
- ['quuz/garply/grault', [
- 'h.tfevents.1',
- ]],
- ['waldo', [
- 'fred',
- ]],
- ['waldo/fred', [
- 'i.tfevents.1',
- ]],
- ]
- for pair in expected:
- # If this is not the top-level directory, prepend the high-level
- # directory.
- pair[0] = os.path.join(temp_dir, pair[0]) if pair[0] else temp_dir
- pair[1] = [os.path.join(pair[0], f) for f in pair[1]]
- self._CompareFilesPerSubdirectory(
- expected, io_wrapper.ListRecursivelyViaGlobbing(temp_dir))
-
- def testListRecursivelyViaGlobbingForPathWithGlobCharacters(self):
- temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
- directory_names = (
- 'ba*',
- 'ba*/subdirectory',
- 'bar',
- )
- for directory_name in directory_names:
- os.makedirs(os.path.join(temp_dir, directory_name))
-
- file_names = (
- 'ba*/a.tfevents.1',
- 'ba*/subdirectory/b.tfevents.1',
- 'bar/c.tfevents.1',
- )
- for file_name in file_names:
- open(os.path.join(temp_dir, file_name), 'w').close()
-
- expected = [
- ['', [
- 'a.tfevents.1',
- 'subdirectory',
- ]],
- ['subdirectory', [
- 'b.tfevents.1',
- ]],
- # The contents of the bar subdirectory should be excluded from
- # this listing because the * character should have been escaped.
- ]
- top = os.path.join(temp_dir, 'ba*')
- for pair in expected:
- # If this is not the top-level directory, prepend the high-level
- # directory.
- pair[0] = os.path.join(top, pair[0]) if pair[0] else top
- pair[1] = [os.path.join(pair[0], f) for f in pair[1]]
- self._CompareFilesPerSubdirectory(
- expected, io_wrapper.ListRecursivelyViaGlobbing(top))
-
- def testListRecursivelyViaWalking(self):
- temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
- self._CreateDeepDirectoryStructure(temp_dir)
- expected = [
- ['', [
- 'a.tfevents.1',
- 'model.ckpt',
- ]],
- ['foo', []],
- ['bar', [
- 'b.tfevents.1',
- 'red_herring.txt',
- ]],
- ['bar/baz', [
- 'c.tfevents.1',
- 'd.tfevents.1',
- ]],
- ['bar/quux', [
- 'some_flume_output.txt',
- 'some_more_flume_output.txt',
- ]],
- ['quuz', [
- 'e.tfevents.1',
- ]],
- ['quuz/garply', [
- 'f.tfevents.1',
- ]],
- ['quuz/garply/corge', [
- 'g.tfevents.1',
- ]],
- ['quuz/garply/grault', [
- 'h.tfevents.1',
- ]],
- ['waldo', []],
- ['waldo/fred', [
- 'i.tfevents.1',
- ]],
- ]
- for pair in expected:
- # If this is not the top-level directory, prepend the high-level
- # directory.
- pair[0] = os.path.join(temp_dir, pair[0]) if pair[0] else temp_dir
- pair[1] = [os.path.join(pair[0], f) for f in pair[1]]
- self._CompareFilesPerSubdirectory(
- expected, io_wrapper.ListRecursivelyViaWalking(temp_dir))
-
- def testGetLogdirSubdirectories(self):
- temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
- self._CreateDeepDirectoryStructure(temp_dir)
- # Only subdirectories that immediately contains at least 1 events
- # file should be listed.
- expected = [
- '',
- 'bar',
- 'bar/baz',
- 'quuz',
- 'quuz/garply',
- 'quuz/garply/corge',
- 'quuz/garply/grault',
- 'waldo/fred',
- ]
- self.assertItemsEqual(
- [(os.path.join(temp_dir, subdir) if subdir else temp_dir)
- for subdir in expected],
- io_wrapper.GetLogdirSubdirectories(temp_dir))
-
- def _CreateDeepDirectoryStructure(self, top_directory):
- """Creates a reasonable deep structure of subdirectories with files.
-
- Args:
- top_directory: The absolute path of the top level directory in
- which to create the directory structure.
- """
- # Add a few subdirectories.
- directory_names = (
- # An empty directory.
- 'foo',
- # A directory with an events file (and a text file).
- 'bar',
- # A deeper directory with events files.
- 'bar/baz',
- # A non-empty subdirectory that lacks event files (should be ignored).
- 'bar/quux',
- # This 3-level deep set of subdirectories tests logic that replaces the
- # full glob string with an absolute path prefix if there is only 1
- # subdirectory in the final mapping.
- 'quuz/garply',
- 'quuz/garply/corge',
- 'quuz/garply/grault',
- # A directory that lacks events files, but contains a subdirectory
- # with events files (first level should be ignored, second level should
- # be included).
- 'waldo',
- 'waldo/fred',
- )
- for directory_name in directory_names:
- os.makedirs(os.path.join(top_directory, directory_name))
-
- # Add a few files to the directory.
- file_names = (
- 'a.tfevents.1',
- 'model.ckpt',
- 'bar/b.tfevents.1',
- 'bar/red_herring.txt',
- 'bar/baz/c.tfevents.1',
- 'bar/baz/d.tfevents.1',
- 'bar/quux/some_flume_output.txt',
- 'bar/quux/some_more_flume_output.txt',
- 'quuz/e.tfevents.1',
- 'quuz/garply/f.tfevents.1',
- 'quuz/garply/corge/g.tfevents.1',
- 'quuz/garply/grault/h.tfevents.1',
- 'waldo/fred/i.tfevents.1',
- )
- for file_name in file_names:
- open(os.path.join(top_directory, file_name), 'w').close()
-
- def _CompareFilesPerSubdirectory(self, expected, gotten):
- """Compares iterables of (subdirectory path, list of absolute paths)
-
- Args:
- expected: The expected iterable of 2-tuples.
- gotten: The gotten iterable of 2-tuples.
- """
- expected_directory_to_listing = {
- result[0]: list(result[1]) for result in expected}
- gotten_directory_to_listing = {
- result[0]: list(result[1]) for result in gotten}
- self.assertItemsEqual(
- expected_directory_to_listing.keys(),
- gotten_directory_to_listing.keys())
-
- for subdirectory, expected_listing in expected_directory_to_listing.items():
- gotten_listing = gotten_directory_to_listing[subdirectory]
- self.assertItemsEqual(
- expected_listing,
- gotten_listing,
- 'Files for subdirectory %r must match. Expected %r. Got %r.' % (
- subdirectory, expected_listing, gotten_listing))
-
-
-
-if __name__ == '__main__':
- tf.test.main()
+ def setUp(self):
+ self.stubs = tf.compat.v1.test.StubOutForTesting()
+
+ def tearDown(self):
+ self.stubs.CleanUp()
+
+ def testIsCloudPathGcsIsTrue(self):
+ self.assertTrue(io_wrapper.IsCloudPath("gs://bucket/foo"))
+
+ def testIsCloudPathS3IsTrue(self):
+ self.assertTrue(io_wrapper.IsCloudPath("s3://bucket/foo"))
+
+ def testIsCloudPathCnsIsTrue(self):
+ self.assertTrue(io_wrapper.IsCloudPath("/cns/foo/bar"))
+
+ def testIsCloudPathFileIsFalse(self):
+ self.assertFalse(io_wrapper.IsCloudPath("file:///tmp/foo"))
+
+ def testIsCloudPathLocalIsFalse(self):
+ self.assertFalse(io_wrapper.IsCloudPath("/tmp/foo"))
+
+ def testPathSeparator(self):
+        # On *nix systems, the path separator is the same as that of CNS/GCS,
+        # making it hard to tell if something went wrong.
+ self.stubs.Set(os, "sep", "#")
+
+ self.assertEqual(io_wrapper.PathSeparator("/tmp/foo"), "#")
+ self.assertEqual(io_wrapper.PathSeparator("tmp/foo"), "#")
+ self.assertEqual(io_wrapper.PathSeparator("/cns/tmp/foo"), "/")
+ self.assertEqual(io_wrapper.PathSeparator("gs://foo"), "/")
+
+ def testIsIsTensorFlowEventsFileTrue(self):
+ self.assertTrue(
+ io_wrapper.IsTensorFlowEventsFile(
+ "/logdir/events.out.tfevents.1473720042.com"
+ )
+ )
+
+ def testIsIsTensorFlowEventsFileFalse(self):
+ self.assertFalse(
+ io_wrapper.IsTensorFlowEventsFile("/logdir/model.ckpt")
+ )
+
+ def testIsIsTensorFlowEventsFileWithEmptyInput(self):
+ with six.assertRaisesRegex(
+ self, ValueError, r"Path must be a nonempty string"
+ ):
+ io_wrapper.IsTensorFlowEventsFile("")
+
+ def testListDirectoryAbsolute(self):
+ temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
+ self._CreateDeepDirectoryStructure(temp_dir)
+
+ expected_files = (
+ "foo",
+ "bar",
+ "quuz",
+ "a.tfevents.1",
+ "model.ckpt",
+ "waldo",
+ )
+ self.assertItemsEqual(
+ (os.path.join(temp_dir, f) for f in expected_files),
+ io_wrapper.ListDirectoryAbsolute(temp_dir),
+ )
+
+ def testListRecursivelyViaGlobbing(self):
+ temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
+ self._CreateDeepDirectoryStructure(temp_dir)
+ expected = [
+ [
+ "",
+ ["foo", "bar", "a.tfevents.1", "model.ckpt", "quuz", "waldo",],
+ ],
+ ["bar", ["b.tfevents.1", "red_herring.txt", "baz", "quux",]],
+ ["bar/baz", ["c.tfevents.1", "d.tfevents.1",]],
+ [
+ "bar/quux",
+ ["some_flume_output.txt", "some_more_flume_output.txt",],
+ ],
+ ["quuz", ["e.tfevents.1", "garply",]],
+ ["quuz/garply", ["f.tfevents.1", "corge", "grault",]],
+ ["quuz/garply/corge", ["g.tfevents.1"]],
+ ["quuz/garply/grault", ["h.tfevents.1",]],
+ ["waldo", ["fred",]],
+ ["waldo/fred", ["i.tfevents.1",]],
+ ]
+ for pair in expected:
+ # If this is not the top-level directory, prepend the high-level
+ # directory.
+ pair[0] = os.path.join(temp_dir, pair[0]) if pair[0] else temp_dir
+ pair[1] = [os.path.join(pair[0], f) for f in pair[1]]
+ self._CompareFilesPerSubdirectory(
+ expected, io_wrapper.ListRecursivelyViaGlobbing(temp_dir)
+ )
+
+ def testListRecursivelyViaGlobbingForPathWithGlobCharacters(self):
+ temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
+ directory_names = (
+ "ba*",
+ "ba*/subdirectory",
+ "bar",
+ )
+ for directory_name in directory_names:
+ os.makedirs(os.path.join(temp_dir, directory_name))
+
+ file_names = (
+ "ba*/a.tfevents.1",
+ "ba*/subdirectory/b.tfevents.1",
+ "bar/c.tfevents.1",
+ )
+ for file_name in file_names:
+ open(os.path.join(temp_dir, file_name), "w").close()
+
+ expected = [
+ ["", ["a.tfevents.1", "subdirectory",]],
+ ["subdirectory", ["b.tfevents.1",]],
+ # The contents of the bar subdirectory should be excluded from
+ # this listing because the * character should have been escaped.
+ ]
+ top = os.path.join(temp_dir, "ba*")
+ for pair in expected:
+ # If this is not the top-level directory, prepend the high-level
+ # directory.
+ pair[0] = os.path.join(top, pair[0]) if pair[0] else top
+ pair[1] = [os.path.join(pair[0], f) for f in pair[1]]
+ self._CompareFilesPerSubdirectory(
+ expected, io_wrapper.ListRecursivelyViaGlobbing(top)
+ )
+
+ def testListRecursivelyViaWalking(self):
+ temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
+ self._CreateDeepDirectoryStructure(temp_dir)
+ expected = [
+ ["", ["a.tfevents.1", "model.ckpt",]],
+ ["foo", []],
+ ["bar", ["b.tfevents.1", "red_herring.txt",]],
+ ["bar/baz", ["c.tfevents.1", "d.tfevents.1",]],
+ [
+ "bar/quux",
+ ["some_flume_output.txt", "some_more_flume_output.txt",],
+ ],
+ ["quuz", ["e.tfevents.1",]],
+ ["quuz/garply", ["f.tfevents.1",]],
+ ["quuz/garply/corge", ["g.tfevents.1",]],
+ ["quuz/garply/grault", ["h.tfevents.1",]],
+ ["waldo", []],
+ ["waldo/fred", ["i.tfevents.1",]],
+ ]
+ for pair in expected:
+ # If this is not the top-level directory, prepend the high-level
+ # directory.
+ pair[0] = os.path.join(temp_dir, pair[0]) if pair[0] else temp_dir
+ pair[1] = [os.path.join(pair[0], f) for f in pair[1]]
+ self._CompareFilesPerSubdirectory(
+ expected, io_wrapper.ListRecursivelyViaWalking(temp_dir)
+ )
+
+ def testGetLogdirSubdirectories(self):
+ temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
+ self._CreateDeepDirectoryStructure(temp_dir)
+ # Only subdirectories that immediately contain at least 1 events
+ # file should be listed.
+ expected = [
+ "",
+ "bar",
+ "bar/baz",
+ "quuz",
+ "quuz/garply",
+ "quuz/garply/corge",
+ "quuz/garply/grault",
+ "waldo/fred",
+ ]
+ self.assertItemsEqual(
+ [
+ (os.path.join(temp_dir, subdir) if subdir else temp_dir)
+ for subdir in expected
+ ],
+ io_wrapper.GetLogdirSubdirectories(temp_dir),
+ )
+
+ def _CreateDeepDirectoryStructure(self, top_directory):
+ """Creates a reasonable deep structure of subdirectories with files.
+
+ Args:
+ top_directory: The absolute path of the top level directory in
+ which to create the directory structure.
+ """
+ # Add a few subdirectories.
+ directory_names = (
+ # An empty directory.
+ "foo",
+ # A directory with an events file (and a text file).
+ "bar",
+ # A deeper directory with events files.
+ "bar/baz",
+ # A non-empty subdirectory that lacks event files (should be ignored).
+ "bar/quux",
+ # This 3-level deep set of subdirectories tests logic that replaces the
+ # full glob string with an absolute path prefix if there is only 1
+ # subdirectory in the final mapping.
+ "quuz/garply",
+ "quuz/garply/corge",
+ "quuz/garply/grault",
+ # A directory that lacks events files, but contains a subdirectory
+ # with events files (first level should be ignored, second level should
+ # be included).
+ "waldo",
+ "waldo/fred",
+ )
+ for directory_name in directory_names:
+ os.makedirs(os.path.join(top_directory, directory_name))
+
+ # Add a few files to the directory.
+ file_names = (
+ "a.tfevents.1",
+ "model.ckpt",
+ "bar/b.tfevents.1",
+ "bar/red_herring.txt",
+ "bar/baz/c.tfevents.1",
+ "bar/baz/d.tfevents.1",
+ "bar/quux/some_flume_output.txt",
+ "bar/quux/some_more_flume_output.txt",
+ "quuz/e.tfevents.1",
+ "quuz/garply/f.tfevents.1",
+ "quuz/garply/corge/g.tfevents.1",
+ "quuz/garply/grault/h.tfevents.1",
+ "waldo/fred/i.tfevents.1",
+ )
+ for file_name in file_names:
+ open(os.path.join(top_directory, file_name), "w").close()
+
+ def _CompareFilesPerSubdirectory(self, expected, gotten):
+ """Compares iterables of (subdirectory path, list of absolute paths)
+
+ Args:
+ expected: The expected iterable of 2-tuples.
+ gotten: The gotten iterable of 2-tuples.
+ """
+ expected_directory_to_listing = {
+ result[0]: list(result[1]) for result in expected
+ }
+ gotten_directory_to_listing = {
+ result[0]: list(result[1]) for result in gotten
+ }
+ self.assertItemsEqual(
+ expected_directory_to_listing.keys(),
+ gotten_directory_to_listing.keys(),
+ )
+
+ for (
+ subdirectory,
+ expected_listing,
+ ) in expected_directory_to_listing.items():
+ gotten_listing = gotten_directory_to_listing[subdirectory]
+ self.assertItemsEqual(
+ expected_listing,
+ gotten_listing,
+ "Files for subdirectory %r must match. Expected %r. Got %r."
+ % (subdirectory, expected_listing, gotten_listing),
+ )
+
+
+if __name__ == "__main__":
+ tf.test.main()
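For context on the behavior these tests pin down, here is a minimal usage sketch of the io_wrapper helpers exercised above. It is illustrative only: the module path and the example logdir are assumptions, not something introduced by this diff.

from tensorboard.backend.event_processing import io_wrapper  # assumed module path

logdir = "/tmp/example_logdir"  # hypothetical log directory

# Event files are recognized by file name; checkpoints and other files are not.
io_wrapper.IsTensorFlowEventsFile("/logdir/events.out.tfevents.1473720042.com")  # True
io_wrapper.IsTensorFlowEventsFile("/logdir/model.ckpt")  # False

# Cloud-style paths (gs://, s3://, /cns/...) always use "/" as the separator,
# while local paths use os.sep.
io_wrapper.IsCloudPath("gs://bucket/foo")  # True
io_wrapper.PathSeparator("gs://foo")  # "/"

# Only subdirectories that immediately contain at least one event file are
# listed as run directories.
for subdir in io_wrapper.GetLogdirSubdirectories(logdir):
    print(subdir)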
diff --git a/tensorboard/backend/event_processing/plugin_asset_util.py b/tensorboard/backend/event_processing/plugin_asset_util.py
index 4adf4a8f80..b13385c62c 100644
--- a/tensorboard/backend/event_processing/plugin_asset_util.py
+++ b/tensorboard/backend/event_processing/plugin_asset_util.py
@@ -26,78 +26,83 @@
def _IsDirectory(parent, item):
- """Helper that returns if parent/item is a directory."""
- return tf.io.gfile.isdir(os.path.join(parent, item))
+ """Helper that returns if parent/item is a directory."""
+ return tf.io.gfile.isdir(os.path.join(parent, item))
def PluginDirectory(logdir, plugin_name):
- """Returns the plugin directory for plugin_name."""
- return os.path.join(logdir, _PLUGINS_DIR, plugin_name)
+ """Returns the plugin directory for plugin_name."""
+ return os.path.join(logdir, _PLUGINS_DIR, plugin_name)
def ListPlugins(logdir):
- """List all the plugins that have registered assets in logdir.
-
- If the plugins_dir does not exist, it returns an empty list. This maintains
- compatibility with old directories that have no plugins written.
-
- Args:
- logdir: A directory that was created by a TensorFlow events writer.
-
- Returns:
- a list of plugin names, as strings
- """
- plugins_dir = os.path.join(logdir, _PLUGINS_DIR)
- try:
- entries = tf.io.gfile.listdir(plugins_dir)
- except tf.errors.NotFoundError:
- return []
- # Strip trailing slashes, which listdir() includes for some filesystems
- # for subdirectories, after using them to bypass IsDirectory().
- return [x.rstrip('/') for x in entries
- if x.endswith('/') or _IsDirectory(plugins_dir, x)]
+ """List all the plugins that have registered assets in logdir.
+
+ If the plugins_dir does not exist, it returns an empty list. This maintains
+ compatibility with old directories that have no plugins written.
+
+ Args:
+ logdir: A directory that was created by a TensorFlow events writer.
+
+ Returns:
+ a list of plugin names, as strings
+ """
+ plugins_dir = os.path.join(logdir, _PLUGINS_DIR)
+ try:
+ entries = tf.io.gfile.listdir(plugins_dir)
+ except tf.errors.NotFoundError:
+ return []
+ # Strip trailing slashes, which listdir() includes for some filesystems
+ # for subdirectories, after using them to bypass IsDirectory().
+ return [
+ x.rstrip("/")
+ for x in entries
+ if x.endswith("/") or _IsDirectory(plugins_dir, x)
+ ]
def ListAssets(logdir, plugin_name):
- """List all the assets that are available for given plugin in a logdir.
+ """List all the assets that are available for given plugin in a logdir.
- Args:
- logdir: A directory that was created by a TensorFlow summary.FileWriter.
- plugin_name: A string name of a plugin to list assets for.
+ Args:
+ logdir: A directory that was created by a TensorFlow summary.FileWriter.
+ plugin_name: A string name of a plugin to list assets for.
- Returns:
- A string list of available plugin assets. If the plugin subdirectory does
- not exist (either because the logdir doesn't exist, or because the plugin
- didn't register) an empty list is returned.
- """
- plugin_dir = PluginDirectory(logdir, plugin_name)
- try:
- # Strip trailing slashes, which listdir() includes for some filesystems.
- return [x.rstrip('/') for x in tf.io.gfile.listdir(plugin_dir)]
- except tf.errors.NotFoundError:
- return []
+ Returns:
+ A string list of available plugin assets. If the plugin subdirectory does
+ not exist (either because the logdir doesn't exist, or because the plugin
+ didn't register) an empty list is returned.
+ """
+ plugin_dir = PluginDirectory(logdir, plugin_name)
+ try:
+ # Strip trailing slashes, which listdir() includes for some filesystems.
+ return [x.rstrip("/") for x in tf.io.gfile.listdir(plugin_dir)]
+ except tf.errors.NotFoundError:
+ return []
def RetrieveAsset(logdir, plugin_name, asset_name):
- """Retrieve a particular plugin asset from a logdir.
-
- Args:
- logdir: A directory that was created by a TensorFlow summary.FileWriter.
- plugin_name: The plugin we want an asset from.
- asset_name: The name of the requested asset.
-
- Returns:
- string contents of the plugin asset.
-
- Raises:
- KeyError: if the asset does not exist.
- """
-
- asset_path = os.path.join(PluginDirectory(logdir, plugin_name), asset_name)
- try:
- with tf.io.gfile.GFile(asset_path, "r") as f:
- return f.read()
- except tf.errors.NotFoundError:
- raise KeyError("Asset path %s not found" % asset_path)
- except tf.errors.OpError as e:
- raise KeyError("Couldn't read asset path: %s, OpError %s" % (asset_path, e))
+ """Retrieve a particular plugin asset from a logdir.
+
+ Args:
+ logdir: A directory that was created by a TensorFlow summary.FileWriter.
+ plugin_name: The plugin we want an asset from.
+ asset_name: The name of the requested asset.
+
+ Returns:
+ string contents of the plugin asset.
+
+ Raises:
+ KeyError: if the asset does not exist.
+ """
+
+ asset_path = os.path.join(PluginDirectory(logdir, plugin_name), asset_name)
+ try:
+ with tf.io.gfile.GFile(asset_path, "r") as f:
+ return f.read()
+ except tf.errors.NotFoundError:
+ raise KeyError("Asset path %s not found" % asset_path)
+ except tf.errors.OpError as e:
+ raise KeyError(
+ "Couldn't read asset path: %s, OpError %s" % (asset_path, e)
+ )
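Read together, the plugin_asset_util functions above compose into a simple enumerate-then-read flow. A minimal sketch, assuming a hypothetical logdir (the import path follows the file touched in this hunk):

from tensorboard.backend.event_processing import plugin_asset_util

logdir = "/tmp/example_logdir"  # hypothetical

# ListPlugins and ListAssets return [] when the corresponding directories do
# not exist, so these loops are simply empty for a logdir without assets.
for plugin_name in plugin_asset_util.ListPlugins(logdir):
    for asset_name in plugin_asset_util.ListAssets(logdir, plugin_name):
        try:
            contents = plugin_asset_util.RetrieveAsset(
                logdir, plugin_name, asset_name
            )
        except KeyError:
            # Raised when the asset path is missing or cannot be read.
            continue
        print(plugin_name, asset_name, len(contents))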
diff --git a/tensorboard/backend/event_processing/plugin_event_accumulator.py b/tensorboard/backend/event_processing/plugin_event_accumulator.py
index 8fbea6f53a..8a976fd8df 100644
--- a/tensorboard/backend/event_processing/plugin_event_accumulator.py
+++ b/tensorboard/backend/event_processing/plugin_event_accumulator.py
@@ -41,14 +41,14 @@
namedtuple = collections.namedtuple
-TensorEvent = namedtuple('TensorEvent', ['wall_time', 'step', 'tensor_proto'])
+TensorEvent = namedtuple("TensorEvent", ["wall_time", "step", "tensor_proto"])
## The tagTypes below are just arbitrary strings chosen to pass the type
## information of the tag from the backend to the frontend
-TENSORS = 'tensors'
-GRAPH = 'graph'
-META_GRAPH = 'meta_graph'
-RUN_METADATA = 'run_metadata'
+TENSORS = "tensors"
+GRAPH = "graph"
+META_GRAPH = "meta_graph"
+RUN_METADATA = "run_metadata"
DEFAULT_SIZE_GUIDANCE = {
TENSORS: 500,
@@ -62,550 +62,604 @@
class EventAccumulator(object):
- """An `EventAccumulator` takes an event generator, and accumulates the values.
-
- The `EventAccumulator` is intended to provide a convenient Python
- interface for loading Event data written during a TensorFlow run.
- TensorFlow writes out `Event` protobuf objects, which have a timestamp
- and step number, and often contain a `Summary`. Summaries can have
- different kinds of data stored as arbitrary tensors. The Summaries
- also have a tag, which we use to organize logically related data. The
- `EventAccumulator` supports retrieving the `Event` and `Summary` data
- by its tag.
-
- Calling `Tags()` gets a map from `tagType` (i.e., `tensors`) to the
- associated tags for those data types. Then, the functional endpoint
- (i.g., `Accumulator.Tensors(tag)`) allows for the retrieval of all
- data associated with that tag.
-
- The `Reload()` method synchronously loads all of the data written so far.
-
- Fields:
- most_recent_step: Step of last Event proto added. This should only
- be accessed from the thread that calls Reload. This is -1 if
- nothing has been loaded yet.
- most_recent_wall_time: Timestamp of last Event proto added. This is
- a float containing seconds from the UNIX epoch, or -1 if
- nothing has been loaded yet. This should only be accessed from
- the thread that calls Reload.
- path: A file path to a directory containing tf events files, or a single
- tf events file. The accumulator will load events from this path.
- tensors_by_tag: A dictionary mapping each tag name to a
- reservoir.Reservoir of tensor summaries. Each such reservoir will
- only use a single key, given by `_TENSOR_RESERVOIR_KEY`.
-
- @@Tensors
- """
-
- def __init__(self,
- path,
- size_guidance=None,
- tensor_size_guidance=None,
- purge_orphaned_data=True,
- event_file_active_filter=None):
- """Construct the `EventAccumulator`.
-
- Args:
+ """An `EventAccumulator` takes an event generator, and accumulates the
+ values.
+
+ The `EventAccumulator` is intended to provide a convenient Python
+ interface for loading Event data written during a TensorFlow run.
+ TensorFlow writes out `Event` protobuf objects, which have a timestamp
+ and step number, and often contain a `Summary`. Summaries can have
+ different kinds of data stored as arbitrary tensors. The Summaries
+ also have a tag, which we use to organize logically related data. The
+ `EventAccumulator` supports retrieving the `Event` and `Summary` data
+ by its tag.
+
+ Calling `Tags()` gets a map from `tagType` (i.e., `tensors`) to the
+ associated tags for those data types. Then, the functional endpoint
+ (e.g., `Accumulator.Tensors(tag)`) allows for the retrieval of all
+ data associated with that tag.
+
+ The `Reload()` method synchronously loads all of the data written so far.
+
+ Fields:
+ most_recent_step: Step of last Event proto added. This should only
+ be accessed from the thread that calls Reload. This is -1 if
+ nothing has been loaded yet.
+ most_recent_wall_time: Timestamp of last Event proto added. This is
+ a float containing seconds from the UNIX epoch, or -1 if
+ nothing has been loaded yet. This should only be accessed from
+ the thread that calls Reload.
path: A file path to a directory containing tf events files, or a single
- tf events file. The accumulator will load events from this path.
- size_guidance: Information on how much data the EventAccumulator should
- store in memory. The DEFAULT_SIZE_GUIDANCE tries not to store too much
- so as to avoid OOMing the client. The size_guidance should be a map
- from a `tagType` string to an integer representing the number of
- items to keep per tag for items of that `tagType`. If the size is 0,
- all events are stored.
- tensor_size_guidance: Like `size_guidance`, but allowing finer
- granularity for tensor summaries. Should be a map from the
- `plugin_name` field on the `PluginData` proto to an integer
- representing the number of items to keep per tag. Plugins for
- which there is no entry in this map will default to the value of
- `size_guidance[event_accumulator.TENSORS]`. Defaults to `{}`.
- purge_orphaned_data: Whether to discard any events that were "orphaned" by
- a TensorFlow restart.
- event_file_active_filter: Optional predicate for determining whether an
- event file latest load timestamp should be considered active. If passed,
- this will enable multifile directory loading.
- """
- size_guidance = dict(size_guidance or DEFAULT_SIZE_GUIDANCE)
- sizes = {}
- for key in DEFAULT_SIZE_GUIDANCE:
- if key in size_guidance:
- sizes[key] = size_guidance[key]
- else:
- sizes[key] = DEFAULT_SIZE_GUIDANCE[key]
- self._size_guidance = size_guidance
- self._tensor_size_guidance = dict(tensor_size_guidance or {})
-
- self._first_event_timestamp = None
-
- self._graph = None
- self._graph_from_metagraph = False
- self._meta_graph = None
- self._tagged_metadata = {}
- self.summary_metadata = {}
- self.tensors_by_tag = {}
- self._tensors_by_tag_lock = threading.Lock()
-
- # Keep a mapping from plugin name to a dict mapping from tag to plugin data
- # content obtained from the SummaryMetadata (metadata field of Value) for
- # that plugin (This is not the entire SummaryMetadata proto - only the
- # content for that plugin). The SummaryWriter only keeps the content on the
- # first event encountered per tag, so we must store that first instance of
- # content for each tag.
- self._plugin_to_tag_to_content = collections.defaultdict(dict)
- self._plugin_tag_locks = collections.defaultdict(threading.Lock)
-
- self.path = path
- self._generator = _GeneratorFromPath(path, event_file_active_filter)
- self._generator_mutex = threading.Lock()
-
- self.purge_orphaned_data = purge_orphaned_data
-
- self.most_recent_step = -1
- self.most_recent_wall_time = -1
- self.file_version = None
-
- def Reload(self):
- """Loads all events added since the last call to `Reload`.
-
- If `Reload` was never called, loads all events in the file.
-
- Returns:
- The `EventAccumulator`.
- """
- with self._generator_mutex:
- for event in self._generator.Load():
- self._ProcessEvent(event)
- return self
-
- def PluginAssets(self, plugin_name):
- """Return a list of all plugin assets for the given plugin.
-
- Args:
- plugin_name: The string name of a plugin to retrieve assets for.
-
- Returns:
- A list of string plugin asset names, or empty list if none are available.
- If the plugin was not registered, an empty list is returned.
- """
- return plugin_asset_util.ListAssets(self.path, plugin_name)
-
- def RetrievePluginAsset(self, plugin_name, asset_name):
- """Return the contents of a given plugin asset.
-
- Args:
- plugin_name: The string name of a plugin.
- asset_name: The string name of an asset.
-
- Returns:
- The string contents of the plugin asset.
-
- Raises:
- KeyError: If the asset is not available.
- """
- return plugin_asset_util.RetrieveAsset(self.path, plugin_name, asset_name)
-
- def FirstEventTimestamp(self):
- """Returns the timestamp in seconds of the first event.
-
- If the first event has been loaded (either by this method or by `Reload`,
- this returns immediately. Otherwise, it will load in the first event. Note
- that this means that calling `Reload` will cause this to block until
- `Reload` has finished.
-
- Returns:
- The timestamp in seconds of the first event that was loaded.
+ tf events file. The accumulator will load events from this path.
+ tensors_by_tag: A dictionary mapping each tag name to a
+ reservoir.Reservoir of tensor summaries. Each such reservoir will
+ only use a single key, given by `_TENSOR_RESERVOIR_KEY`.
- Raises:
- ValueError: If no events have been loaded and there were no events found
- on disk.
+ @@Tensors
"""
- if self._first_event_timestamp is not None:
- return self._first_event_timestamp
- with self._generator_mutex:
- try:
- event = next(self._generator.Load())
- self._ProcessEvent(event)
- return self._first_event_timestamp
- except StopIteration:
- raise ValueError('No event timestamp could be found')
-
- def PluginTagToContent(self, plugin_name):
- """Returns a dict mapping tags to content specific to that plugin.
-
- Args:
- plugin_name: The name of the plugin for which to fetch plugin-specific
- content.
-
- Raises:
- KeyError: if the plugin name is not found.
-
- Returns:
- A dict mapping tag names to bytestrings of plugin-specific content-- by
- convention, in the form of binary serialized protos.
- """
- if plugin_name not in self._plugin_to_tag_to_content:
- raise KeyError('Plugin %r could not be found.' % plugin_name)
- with self._plugin_tag_locks[plugin_name]:
- # Return a snapshot to avoid concurrent mutation and iteration issues.
- return dict(self._plugin_to_tag_to_content[plugin_name])
-
- def SummaryMetadata(self, tag):
- """Given a summary tag name, return the associated metadata object.
-
- Args:
- tag: The name of a tag, as a string.
-
- Raises:
- KeyError: If the tag is not found.
-
- Returns:
- A `SummaryMetadata` protobuf.
- """
- return self.summary_metadata[tag]
-
- def _ProcessEvent(self, event):
- """Called whenever an event is loaded."""
- if self._first_event_timestamp is None:
- self._first_event_timestamp = event.wall_time
-
- if event.HasField('file_version'):
- new_file_version = _ParseFileVersion(event.file_version)
- if self.file_version and self.file_version != new_file_version:
- ## This should not happen.
- logger.warn(('Found new file_version for event.proto. This will '
- 'affect purging logic for TensorFlow restarts. '
- 'Old: {0} New: {1}').format(self.file_version,
- new_file_version))
- self.file_version = new_file_version
-
- self._MaybePurgeOrphanedData(event)
-
- ## Process the event.
- # GraphDef and MetaGraphDef are handled in a special way:
- # If no graph_def Event is available, but a meta_graph_def is, and it
- # contains a graph_def, then use the meta_graph_def.graph_def as our graph.
- # If a graph_def Event is available, always prefer it to the graph_def
- # inside the meta_graph_def.
- if event.HasField('graph_def'):
- if self._graph is not None:
- logger.warn(
- ('Found more than one graph event per run, or there was '
- 'a metagraph containing a graph_def, as well as one or '
- 'more graph events. Overwriting the graph with the '
- 'newest event.'))
- self._graph = event.graph_def
- self._graph_from_metagraph = False
- elif event.HasField('meta_graph_def'):
- if self._meta_graph is not None:
- logger.warn(('Found more than one metagraph event per run. '
- 'Overwriting the metagraph with the newest event.'))
- self._meta_graph = event.meta_graph_def
- if self._graph is None or self._graph_from_metagraph:
- # We may have a graph_def in the metagraph. If so, and no
- # graph_def is directly available, use this one instead.
+ def __init__(
+ self,
+ path,
+ size_guidance=None,
+ tensor_size_guidance=None,
+ purge_orphaned_data=True,
+ event_file_active_filter=None,
+ ):
+ """Construct the `EventAccumulator`.
+
+ Args:
+ path: A file path to a directory containing tf events files, or a single
+ tf events file. The accumulator will load events from this path.
+ size_guidance: Information on how much data the EventAccumulator should
+ store in memory. The DEFAULT_SIZE_GUIDANCE tries not to store too much
+ so as to avoid OOMing the client. The size_guidance should be a map
+ from a `tagType` string to an integer representing the number of
+ items to keep per tag for items of that `tagType`. If the size is 0,
+ all events are stored.
+ tensor_size_guidance: Like `size_guidance`, but allowing finer
+ granularity for tensor summaries. Should be a map from the
+ `plugin_name` field on the `PluginData` proto to an integer
+ representing the number of items to keep per tag. Plugins for
+ which there is no entry in this map will default to the value of
+ `size_guidance[event_accumulator.TENSORS]`. Defaults to `{}`.
+ purge_orphaned_data: Whether to discard any events that were "orphaned" by
+ a TensorFlow restart.
+ event_file_active_filter: Optional predicate for determining whether an
+ event file latest load timestamp should be considered active. If passed,
+ this will enable multifile directory loading.
+ """
+ size_guidance = dict(size_guidance or DEFAULT_SIZE_GUIDANCE)
+ sizes = {}
+ for key in DEFAULT_SIZE_GUIDANCE:
+ if key in size_guidance:
+ sizes[key] = size_guidance[key]
+ else:
+ sizes[key] = DEFAULT_SIZE_GUIDANCE[key]
+ self._size_guidance = size_guidance
+ self._tensor_size_guidance = dict(tensor_size_guidance or {})
+
+ self._first_event_timestamp = None
+
+ self._graph = None
+ self._graph_from_metagraph = False
+ self._meta_graph = None
+ self._tagged_metadata = {}
+ self.summary_metadata = {}
+ self.tensors_by_tag = {}
+ self._tensors_by_tag_lock = threading.Lock()
+
+ # Keep a mapping from plugin name to a dict mapping from tag to plugin data
+ # content obtained from the SummaryMetadata (metadata field of Value) for
+ # that plugin (This is not the entire SummaryMetadata proto - only the
+ # content for that plugin). The SummaryWriter only keeps the content on the
+ # first event encountered per tag, so we must store that first instance of
+ # content for each tag.
+ self._plugin_to_tag_to_content = collections.defaultdict(dict)
+ self._plugin_tag_locks = collections.defaultdict(threading.Lock)
+
+ self.path = path
+ self._generator = _GeneratorFromPath(path, event_file_active_filter)
+ self._generator_mutex = threading.Lock()
+
+ self.purge_orphaned_data = purge_orphaned_data
+
+ self.most_recent_step = -1
+ self.most_recent_wall_time = -1
+ self.file_version = None
+
+ def Reload(self):
+ """Loads all events added since the last call to `Reload`.
+
+ If `Reload` was never called, loads all events in the file.
+
+ Returns:
+ The `EventAccumulator`.
+ """
+ with self._generator_mutex:
+ for event in self._generator.Load():
+ self._ProcessEvent(event)
+ return self
+
+ def PluginAssets(self, plugin_name):
+ """Return a list of all plugin assets for the given plugin.
+
+ Args:
+ plugin_name: The string name of a plugin to retrieve assets for.
+
+ Returns:
+ A list of string plugin asset names, or empty list if none are available.
+ If the plugin was not registered, an empty list is returned.
+ """
+ return plugin_asset_util.ListAssets(self.path, plugin_name)
+
+ def RetrievePluginAsset(self, plugin_name, asset_name):
+ """Return the contents of a given plugin asset.
+
+ Args:
+ plugin_name: The string name of a plugin.
+ asset_name: The string name of an asset.
+
+ Returns:
+ The string contents of the plugin asset.
+
+ Raises:
+ KeyError: If the asset is not available.
+ """
+ return plugin_asset_util.RetrieveAsset(
+ self.path, plugin_name, asset_name
+ )
+
+ def FirstEventTimestamp(self):
+ """Returns the timestamp in seconds of the first event.
+
+ If the first event has been loaded (either by this method or by `Reload`),
+ this returns immediately. Otherwise, it will load in the first event. Note
+ that this means that calling `Reload` will cause this to block until
+ `Reload` has finished.
+
+ Returns:
+ The timestamp in seconds of the first event that was loaded.
+
+ Raises:
+ ValueError: If no events have been loaded and there were no events found
+ on disk.
+ """
+ if self._first_event_timestamp is not None:
+ return self._first_event_timestamp
+ with self._generator_mutex:
+ try:
+ event = next(self._generator.Load())
+ self._ProcessEvent(event)
+ return self._first_event_timestamp
+
+ except StopIteration:
+ raise ValueError("No event timestamp could be found")
+
+ def PluginTagToContent(self, plugin_name):
+ """Returns a dict mapping tags to content specific to that plugin.
+
+ Args:
+ plugin_name: The name of the plugin for which to fetch plugin-specific
+ content.
+
+ Raises:
+ KeyError: if the plugin name is not found.
+
+ Returns:
+ A dict mapping tag names to bytestrings of plugin-specific content-- by
+ convention, in the form of binary serialized protos.
+ """
+ if plugin_name not in self._plugin_to_tag_to_content:
+ raise KeyError("Plugin %r could not be found." % plugin_name)
+ with self._plugin_tag_locks[plugin_name]:
+ # Return a snapshot to avoid concurrent mutation and iteration issues.
+ return dict(self._plugin_to_tag_to_content[plugin_name])
+
+ def SummaryMetadata(self, tag):
+ """Given a summary tag name, return the associated metadata object.
+
+ Args:
+ tag: The name of a tag, as a string.
+
+ Raises:
+ KeyError: If the tag is not found.
+
+ Returns:
+ A `SummaryMetadata` protobuf.
+ """
+ return self.summary_metadata[tag]
+
+ def _ProcessEvent(self, event):
+ """Called whenever an event is loaded."""
+ if self._first_event_timestamp is None:
+ self._first_event_timestamp = event.wall_time
+
+ if event.HasField("file_version"):
+ new_file_version = _ParseFileVersion(event.file_version)
+ if self.file_version and self.file_version != new_file_version:
+ ## This should not happen.
+ logger.warn(
+ (
+ "Found new file_version for event.proto. This will "
+ "affect purging logic for TensorFlow restarts. "
+ "Old: {0} New: {1}"
+ ).format(self.file_version, new_file_version)
+ )
+ self.file_version = new_file_version
+
+ self._MaybePurgeOrphanedData(event)
+
+ ## Process the event.
+ # GraphDef and MetaGraphDef are handled in a special way:
+ # If no graph_def Event is available, but a meta_graph_def is, and it
+ # contains a graph_def, then use the meta_graph_def.graph_def as our graph.
+ # If a graph_def Event is available, always prefer it to the graph_def
+ # inside the meta_graph_def.
+ if event.HasField("graph_def"):
+ if self._graph is not None:
+ logger.warn(
+ (
+ "Found more than one graph event per run, or there was "
+ "a metagraph containing a graph_def, as well as one or "
+ "more graph events. Overwriting the graph with the "
+ "newest event."
+ )
+ )
+ self._graph = event.graph_def
+ self._graph_from_metagraph = False
+ elif event.HasField("meta_graph_def"):
+ if self._meta_graph is not None:
+ logger.warn(
+ (
+ "Found more than one metagraph event per run. "
+ "Overwriting the metagraph with the newest event."
+ )
+ )
+ self._meta_graph = event.meta_graph_def
+ if self._graph is None or self._graph_from_metagraph:
+ # We may have a graph_def in the metagraph. If so, and no
+ # graph_def is directly available, use this one instead.
+ meta_graph = meta_graph_pb2.MetaGraphDef()
+ meta_graph.ParseFromString(self._meta_graph)
+ if meta_graph.graph_def:
+ if self._graph is not None:
+ logger.warn(
+ (
+ "Found multiple metagraphs containing graph_defs,"
+ "but did not find any graph events. Overwriting the "
+ "graph with the newest metagraph version."
+ )
+ )
+ self._graph_from_metagraph = True
+ self._graph = meta_graph.graph_def.SerializeToString()
+ elif event.HasField("tagged_run_metadata"):
+ tag = event.tagged_run_metadata.tag
+ if tag in self._tagged_metadata:
+ logger.warn(
+ 'Found more than one "run metadata" event with tag '
+ + tag
+ + ". Overwriting it with the newest event."
+ )
+ self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
+ elif event.HasField("summary"):
+ for value in event.summary.value:
+ value = data_compat.migrate_value(value)
+
+ if value.HasField("metadata"):
+ tag = value.tag
+ # We only store the first instance of the metadata. This check
+ # is important: the `FileWriter` does strip metadata from all
+ # values except the first one per each tag, but a new
+ # `FileWriter` is created every time a training job stops and
+ # restarts. Hence, we must also ignore non-initial metadata in
+ # this logic.
+ if tag not in self.summary_metadata:
+ self.summary_metadata[tag] = value.metadata
+ plugin_data = value.metadata.plugin_data
+ if plugin_data.plugin_name:
+ with self._plugin_tag_locks[
+ plugin_data.plugin_name
+ ]:
+ self._plugin_to_tag_to_content[
+ plugin_data.plugin_name
+ ][tag] = plugin_data.content
+ else:
+ logger.warn(
+ (
+ "This summary with tag %r is oddly not associated with a "
+ "plugin."
+ ),
+ tag,
+ )
+
+ if value.HasField("tensor"):
+ datum = value.tensor
+ tag = value.tag
+ if not tag:
+ # This tensor summary was created using the old method that used
+ # plugin assets. We must still continue to support it.
+ tag = value.node_name
+ self._ProcessTensor(tag, event.wall_time, event.step, datum)
+
+ def Tags(self):
+ """Return all tags found in the value stream.
+
+ Returns:
+ A `{tagType: ['list', 'of', 'tags']}` dictionary.
+ """
+ return {
+ TENSORS: list(self.tensors_by_tag.keys()),
+ # Use a heuristic: if the metagraph is available, but
+ # graph is not, then we assume the metagraph contains the graph.
+ GRAPH: self._graph is not None,
+ META_GRAPH: self._meta_graph is not None,
+ RUN_METADATA: list(self._tagged_metadata.keys()),
+ }
+
+ def Graph(self):
+ """Return the graph definition, if there is one.
+
+ If the graph is stored directly, return that. If no graph is stored
+ directly but a metagraph is stored containing a graph, return that.
+
+ Raises:
+ ValueError: If there is no graph for this run.
+
+ Returns:
+ The `graph_def` proto.
+ """
+ graph = graph_pb2.GraphDef()
+ if self._graph is not None:
+ graph.ParseFromString(self._graph)
+ return graph
+ raise ValueError("There is no graph in this EventAccumulator")
+
+ def SerializedGraph(self):
+ """Return the graph definition in serialized form, if there is one."""
+ return self._graph
+
+ def MetaGraph(self):
+ """Return the metagraph definition, if there is one.
+
+ Raises:
+ ValueError: If there is no metagraph for this run.
+
+ Returns:
+ The `meta_graph_def` proto.
+ """
+ if self._meta_graph is None:
+ raise ValueError("There is no metagraph in this EventAccumulator")
meta_graph = meta_graph_pb2.MetaGraphDef()
meta_graph.ParseFromString(self._meta_graph)
- if meta_graph.graph_def:
- if self._graph is not None:
- logger.warn(
- ('Found multiple metagraphs containing graph_defs,'
- 'but did not find any graph events. Overwriting the '
- 'graph with the newest metagraph version.'))
- self._graph_from_metagraph = True
- self._graph = meta_graph.graph_def.SerializeToString()
- elif event.HasField('tagged_run_metadata'):
- tag = event.tagged_run_metadata.tag
- if tag in self._tagged_metadata:
- logger.warn('Found more than one "run metadata" event with tag ' +
- tag + '. Overwriting it with the newest event.')
- self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
- elif event.HasField('summary'):
- for value in event.summary.value:
- value = data_compat.migrate_value(value)
-
- if value.HasField('metadata'):
- tag = value.tag
- # We only store the first instance of the metadata. This check
- # is important: the `FileWriter` does strip metadata from all
- # values except the first one per each tag, but a new
- # `FileWriter` is created every time a training job stops and
- # restarts. Hence, we must also ignore non-initial metadata in
- # this logic.
- if tag not in self.summary_metadata:
- self.summary_metadata[tag] = value.metadata
- plugin_data = value.metadata.plugin_data
- if plugin_data.plugin_name:
- with self._plugin_tag_locks[plugin_data.plugin_name]:
- self._plugin_to_tag_to_content[plugin_data.plugin_name][tag] = (
- plugin_data.content)
- else:
- logger.warn(
- ('This summary with tag %r is oddly not associated with a '
- 'plugin.'), tag)
-
- if value.HasField('tensor'):
- datum = value.tensor
- tag = value.tag
- if not tag:
- # This tensor summary was created using the old method that used
- # plugin assets. We must still continue to support it.
- tag = value.node_name
- self._ProcessTensor(tag, event.wall_time, event.step, datum)
-
- def Tags(self):
- """Return all tags found in the value stream.
-
- Returns:
- A `{tagType: ['list', 'of', 'tags']}` dictionary.
- """
- return {
- TENSORS: list(self.tensors_by_tag.keys()),
- # Use a heuristic: if the metagraph is available, but
- # graph is not, then we assume the metagraph contains the graph.
- GRAPH: self._graph is not None,
- META_GRAPH: self._meta_graph is not None,
- RUN_METADATA: list(self._tagged_metadata.keys())
- }
-
- def Graph(self):
- """Return the graph definition, if there is one.
-
- If the graph is stored directly, return that. If no graph is stored
- directly but a metagraph is stored containing a graph, return that.
+ return meta_graph
+
+ def RunMetadata(self, tag):
+ """Given a tag, return the associated session.run() metadata.
+
+ Args:
+ tag: A string tag associated with the event.
+
+ Raises:
+ ValueError: If the tag is not found.
+
+ Returns:
+ The metadata in form of `RunMetadata` proto.
+ """
+ if tag not in self._tagged_metadata:
+ raise ValueError("There is no run metadata with this tag name")
+
+ run_metadata = config_pb2.RunMetadata()
+ run_metadata.ParseFromString(self._tagged_metadata[tag])
+ return run_metadata
+
+ def Tensors(self, tag):
+ """Given a summary tag, return all associated tensors.
+
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+
+ Returns:
+ An array of `TensorEvent`s.
+ """
+ return self.tensors_by_tag[tag].Items(_TENSOR_RESERVOIR_KEY)
+
+ def _MaybePurgeOrphanedData(self, event):
+ """Maybe purge orphaned data due to a TensorFlow crash.
+
+ When TensorFlow crashes at step T+O and restarts at step T, any events
+ written after step T are now "orphaned" and will be at best misleading if
+ they are included in TensorBoard.
+
+ This logic attempts to determine if there is orphaned data, and purge it
+ if it is found.
+
+ Args:
+ event: The event to use as a reference, to determine if a purge is needed.
+ """
+ if not self.purge_orphaned_data:
+ return
+ ## Check if the event happened after a crash, and purge expired tags.
+ if self.file_version and self.file_version >= 2:
+ ## If the file_version is recent enough, use the SessionLog enum
+ ## to check for restarts.
+ self._CheckForRestartAndMaybePurge(event)
+ else:
+ ## If there is no file version, default to old logic of checking for
+ ## out of order steps.
+ self._CheckForOutOfOrderStepAndMaybePurge(event)
+ # After checking, update the most recent summary step and wall time.
+ if event.HasField("summary"):
+ self.most_recent_step = event.step
+ self.most_recent_wall_time = event.wall_time
+
+ def _CheckForRestartAndMaybePurge(self, event):
+ """Check and discard expired events using SessionLog.START.
+
+ Check for a SessionLog.START event and purge all previously seen events
+ with larger steps, because they are out of date. Because of supervisor
+ threading, it is possible that this logic will cause the first few event
+ messages to be discarded since supervisor threading does not guarantee
+ that the START message is deterministically written first.
+
+ This method is preferred over _CheckForOutOfOrderStepAndMaybePurge which
+ can inadvertently discard events due to supervisor threading.
+
+ Args:
+ event: The event to use as reference. If the event is a START event, all
+ previously seen events with a greater event.step will be purged.
+ """
+ if (
+ event.HasField("session_log")
+ and event.session_log.status == event_pb2.SessionLog.START
+ ):
+ self._Purge(event, by_tags=False)
+
+ def _CheckForOutOfOrderStepAndMaybePurge(self, event):
+ """Check for out-of-order event.step and discard expired events for
+ tags.
+
+ Check if the event is out of order relative to the global most recent step.
+ If it is, purge outdated summaries for tags that the event contains.
+
+ Args:
+ event: The event to use as reference. If the event is out-of-order, all
+ events with the same tags, but with a greater event.step will be purged.
+ """
+ if event.step < self.most_recent_step and event.HasField("summary"):
+ self._Purge(event, by_tags=True)
+
+ def _ProcessTensor(self, tag, wall_time, step, tensor):
+ tv = TensorEvent(wall_time=wall_time, step=step, tensor_proto=tensor)
+ with self._tensors_by_tag_lock:
+ if tag not in self.tensors_by_tag:
+ reservoir_size = self._GetTensorReservoirSize(tag)
+ self.tensors_by_tag[tag] = reservoir.Reservoir(reservoir_size)
+ self.tensors_by_tag[tag].AddItem(_TENSOR_RESERVOIR_KEY, tv)
+
+ def _GetTensorReservoirSize(self, tag):
+ default = self._size_guidance[TENSORS]
+ summary_metadata = self.summary_metadata.get(tag)
+ if summary_metadata is None:
+ return default
+ return self._tensor_size_guidance.get(
+ summary_metadata.plugin_data.plugin_name, default
+ )
+
+ def _Purge(self, event, by_tags):
+ """Purge all events that have occurred after the given event.step.
+
+ If by_tags is True, purge all events that occurred after the given
+ event.step, but only for the tags that the event has. Non-sequential
+ event.steps suggest that a TensorFlow restart occurred, and we discard
+ the out-of-order events to display a consistent view in TensorBoard.
+
+ Discarding by tags is the safer method when we are unsure whether a restart
+ has occurred, given that threading in supervisor can cause events of
+ different tags to arrive with unsynchronized step values.
+
+ If by_tags is False, then purge all events with event.step greater than the
+ given event.step. This can be used when we are certain that a TensorFlow
+ restart has occurred and these events can be discarded.
+
+ Args:
+ event: The event to use as reference for the purge. All events with
+ the same tags, but with a greater event.step will be purged.
+ by_tags: Bool to dictate whether to discard all out-of-order events or
+ only those that are associated with the given reference event.
+ """
+ ## Keep data in reservoirs that has a step less than event.step
+ _NotExpired = lambda x: x.step < event.step
+
+ num_expired = 0
+ if by_tags:
+ for value in event.summary.value:
+ if value.tag in self.tensors_by_tag:
+ tag_reservoir = self.tensors_by_tag[value.tag]
+ num_expired += tag_reservoir.FilterItems(
+ _NotExpired, _TENSOR_RESERVOIR_KEY
+ )
+ else:
+ for tag_reservoir in six.itervalues(self.tensors_by_tag):
+ num_expired += tag_reservoir.FilterItems(
+ _NotExpired, _TENSOR_RESERVOIR_KEY
+ )
+ if num_expired > 0:
+ purge_msg = _GetPurgeMessage(
+ self.most_recent_step,
+ self.most_recent_wall_time,
+ event.step,
+ event.wall_time,
+ num_expired,
+ )
+ logger.warn(purge_msg)
+
+
+def _GetPurgeMessage(
+ most_recent_step,
+ most_recent_wall_time,
+ event_step,
+ event_wall_time,
+ num_expired,
+):
+ """Return the string message associated with TensorBoard purges."""
+ return (
+ "Detected out of order event.step likely caused by a TensorFlow "
+ "restart. Purging {} expired tensor events from Tensorboard display "
+ "between the previous step: {} (timestamp: {}) and current step: {} "
+ "(timestamp: {})."
+ ).format(
+ num_expired,
+ most_recent_step,
+ most_recent_wall_time,
+ event_step,
+ event_wall_time,
+ )
- Raises:
- ValueError: If there is no graph for this run.
- Returns:
- The `graph_def` proto.
- """
- graph = graph_pb2.GraphDef()
- if self._graph is not None:
- graph.ParseFromString(self._graph)
- return graph
- raise ValueError('There is no graph in this EventAccumulator')
-
- def SerializedGraph(self):
- """Return the graph definition in serialized form, if there is one."""
- return self._graph
-
- def MetaGraph(self):
- """Return the metagraph definition, if there is one.
-
- Raises:
- ValueError: If there is no metagraph for this run.
-
- Returns:
- The `meta_graph_def` proto.
- """
- if self._meta_graph is None:
- raise ValueError('There is no metagraph in this EventAccumulator')
- meta_graph = meta_graph_pb2.MetaGraphDef()
- meta_graph.ParseFromString(self._meta_graph)
- return meta_graph
-
- def RunMetadata(self, tag):
- """Given a tag, return the associated session.run() metadata.
-
- Args:
- tag: A string tag associated with the event.
-
- Raises:
- ValueError: If the tag is not found.
-
- Returns:
- The metadata in form of `RunMetadata` proto.
- """
- if tag not in self._tagged_metadata:
- raise ValueError('There is no run metadata with this tag name')
-
- run_metadata = config_pb2.RunMetadata()
- run_metadata.ParseFromString(self._tagged_metadata[tag])
- return run_metadata
-
- def Tensors(self, tag):
- """Given a summary tag, return all associated tensors.
-
- Args:
- tag: A string tag associated with the events.
-
- Raises:
- KeyError: If the tag is not found.
-
- Returns:
- An array of `TensorEvent`s.
- """
- return self.tensors_by_tag[tag].Items(_TENSOR_RESERVOIR_KEY)
-
- def _MaybePurgeOrphanedData(self, event):
- """Maybe purge orphaned data due to a TensorFlow crash.
-
- When TensorFlow crashes at step T+O and restarts at step T, any events
- written after step T are now "orphaned" and will be at best misleading if
- they are included in TensorBoard.
-
- This logic attempts to determine if there is orphaned data, and purge it
- if it is found.
-
- Args:
- event: The event to use as a reference, to determine if a purge is needed.
- """
- if not self.purge_orphaned_data:
- return
- ## Check if the event happened after a crash, and purge expired tags.
- if self.file_version and self.file_version >= 2:
- ## If the file_version is recent enough, use the SessionLog enum
- ## to check for restarts.
- self._CheckForRestartAndMaybePurge(event)
+def _GeneratorFromPath(path, event_file_active_filter=None):
+ """Create an event generator for file or directory at given path string."""
+ if not path:
+ raise ValueError("path must be a valid string")
+ if io_wrapper.IsTensorFlowEventsFile(path):
+ return event_file_loader.EventFileLoader(path)
+ elif event_file_active_filter:
+ return directory_loader.DirectoryLoader(
+ path,
+ event_file_loader.TimestampedEventFileLoader,
+ path_filter=io_wrapper.IsTensorFlowEventsFile,
+ active_filter=event_file_active_filter,
+ )
else:
- ## If there is no file version, default to old logic of checking for
- ## out of order steps.
- self._CheckForOutOfOrderStepAndMaybePurge(event)
- # After checking, update the most recent summary step and wall time.
- if event.HasField('summary'):
- self.most_recent_step = event.step
- self.most_recent_wall_time = event.wall_time
-
- def _CheckForRestartAndMaybePurge(self, event):
- """Check and discard expired events using SessionLog.START.
-
- Check for a SessionLog.START event and purge all previously seen events
- with larger steps, because they are out of date. Because of supervisor
- threading, it is possible that this logic will cause the first few event
- messages to be discarded since supervisor threading does not guarantee
- that the START message is deterministically written first.
-
- This method is preferred over _CheckForOutOfOrderStepAndMaybePurge which
- can inadvertently discard events due to supervisor threading.
-
- Args:
- event: The event to use as reference. If the event is a START event, all
- previously seen events with a greater event.step will be purged.
- """
- if event.HasField(
- 'session_log') and event.session_log.status == event_pb2.SessionLog.START:
- self._Purge(event, by_tags=False)
+ return directory_watcher.DirectoryWatcher(
+ path,
+ event_file_loader.EventFileLoader,
+ io_wrapper.IsTensorFlowEventsFile,
+ )
- def _CheckForOutOfOrderStepAndMaybePurge(self, event):
- """Check for out-of-order event.step and discard expired events for tags.
- Check if the event is out of order relative to the global most recent step.
- If it is, purge outdated summaries for tags that the event contains.
+def _ParseFileVersion(file_version):
+ """Convert the string file_version in event.proto into a float.
Args:
- event: The event to use as reference. If the event is out-of-order, all
- events with the same tags, but with a greater event.step will be purged.
- """
- if event.step < self.most_recent_step and event.HasField('summary'):
- self._Purge(event, by_tags=True)
-
- def _ProcessTensor(self, tag, wall_time, step, tensor):
- tv = TensorEvent(wall_time=wall_time, step=step, tensor_proto=tensor)
- with self._tensors_by_tag_lock:
- if tag not in self.tensors_by_tag:
- reservoir_size = self._GetTensorReservoirSize(tag)
- self.tensors_by_tag[tag] = reservoir.Reservoir(reservoir_size)
- self.tensors_by_tag[tag].AddItem(_TENSOR_RESERVOIR_KEY, tv)
-
- def _GetTensorReservoirSize(self, tag):
- default = self._size_guidance[TENSORS]
- summary_metadata = self.summary_metadata.get(tag)
- if summary_metadata is None:
- return default
- return self._tensor_size_guidance.get(
- summary_metadata.plugin_data.plugin_name, default)
-
- def _Purge(self, event, by_tags):
- """Purge all events that have occurred after the given event.step.
-
- If by_tags is True, purge all events that occurred after the given
- event.step, but only for the tags that the event has. Non-sequential
- event.steps suggest that a TensorFlow restart occurred, and we discard
- the out-of-order events to display a consistent view in TensorBoard.
-
- Discarding by tags is the safer method, when we are unsure whether a restart
- has occurred, given that threading in supervisor can cause events of
- different tags to arrive with unsynchronized step values.
-
- If by_tags is False, then purge all events with event.step greater than the
- given event.step. This can be used when we are certain that a TensorFlow
- restart has occurred and these events can be discarded.
+ file_version: String file_version from event.proto
- Args:
- event: The event to use as reference for the purge. All events with
- the same tags, but with a greater event.step will be purged.
- by_tags: Bool to dictate whether to discard all out-of-order events or
- only those that are associated with the given reference event.
+ Returns:
+ Version number as a float.
"""
- ## Keep data in reservoirs that has a step less than event.step
- _NotExpired = lambda x: x.step < event.step
-
- num_expired = 0
- if by_tags:
- for value in event.summary.value:
- if value.tag in self.tensors_by_tag:
- tag_reservoir = self.tensors_by_tag[value.tag]
- num_expired += tag_reservoir.FilterItems(
- _NotExpired, _TENSOR_RESERVOIR_KEY)
- else:
- for tag_reservoir in six.itervalues(self.tensors_by_tag):
- num_expired += tag_reservoir.FilterItems(
- _NotExpired, _TENSOR_RESERVOIR_KEY)
- if num_expired > 0:
- purge_msg = _GetPurgeMessage(self.most_recent_step,
- self.most_recent_wall_time, event.step,
- event.wall_time, num_expired)
- logger.warn(purge_msg)
-
-
-def _GetPurgeMessage(most_recent_step, most_recent_wall_time, event_step,
- event_wall_time, num_expired):
- """Return the string message associated with TensorBoard purges."""
- return ('Detected out of order event.step likely caused by a TensorFlow '
- 'restart. Purging {} expired tensor events from Tensorboard display '
- 'between the previous step: {} (timestamp: {}) and current step: {} '
- '(timestamp: {}).'
- ).format(num_expired, most_recent_step, most_recent_wall_time,
- event_step, event_wall_time)
-
-
-def _GeneratorFromPath(path, event_file_active_filter=None):
- """Create an event generator for file or directory at given path string."""
- if not path:
- raise ValueError('path must be a valid string')
- if io_wrapper.IsTensorFlowEventsFile(path):
- return event_file_loader.EventFileLoader(path)
- elif event_file_active_filter:
- return directory_loader.DirectoryLoader(
- path,
- event_file_loader.TimestampedEventFileLoader,
- path_filter=io_wrapper.IsTensorFlowEventsFile,
- active_filter=event_file_active_filter)
- else:
- return directory_watcher.DirectoryWatcher(
- path,
- event_file_loader.EventFileLoader,
- io_wrapper.IsTensorFlowEventsFile)
-
-
-def _ParseFileVersion(file_version):
- """Convert the string file_version in event.proto into a float.
-
- Args:
- file_version: String file_version from event.proto
-
- Returns:
- Version number as a float.
- """
- tokens = file_version.split('brain.Event:')
- try:
- return float(tokens[-1])
- except ValueError:
- ## This should never happen according to the definition of file_version
- ## specified in event.proto.
- logger.warn(
- ('Invalid event.proto file_version. Defaulting to use of '
- 'out-of-order event.step logic for purging expired events.'))
- return -1
+ tokens = file_version.split("brain.Event:")
+ try:
+ return float(tokens[-1])
+ except ValueError:
+ ## This should never happen according to the definition of file_version
+ ## specified in event.proto.
+ logger.warn(
+ (
+ "Invalid event.proto file_version. Defaulting to use of "
+ "out-of-order event.step logic for purging expired events."
+ )
+ )
+ return -1
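Because this file carries most of the accumulator's public API, a short usage sketch may help while reading the reformatted docstrings above. The logdir path and the size-guidance override are placeholders for illustration, not part of this change.

from tensorboard.backend.event_processing import plugin_event_accumulator as ea

accumulator = ea.EventAccumulator(
    "/tmp/example_logdir",  # hypothetical directory of tfevents files
    size_guidance={ea.TENSORS: 0},  # 0 keeps every tensor event per tag
)
accumulator.Reload()  # synchronously load everything written so far

tags = accumulator.Tags()  # {TENSORS: [...], GRAPH: bool, META_GRAPH: bool, ...}
for tag in tags[ea.TENSORS]:
    for tensor_event in accumulator.Tensors(tag):
        # Each item is a TensorEvent(wall_time, step, tensor_proto).
        print(tag, tensor_event.step, tensor_event.wall_time)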
diff --git a/tensorboard/backend/event_processing/plugin_event_accumulator_test.py b/tensorboard/backend/event_processing/plugin_event_accumulator_test.py
index c6861a9838..17782c5f3c 100644
--- a/tensorboard/backend/event_processing/plugin_event_accumulator_test.py
+++ b/tensorboard/backend/event_processing/plugin_event_accumulator_test.py
@@ -41,720 +41,788 @@
class _EventGenerator(object):
- """Class that can add_events and then yield them back.
+ """Class that can add_events and then yield them back.
- Satisfies the EventGenerator API required for the EventAccumulator.
- Satisfies the EventWriter API required to create a tf.summary.FileWriter.
+ Satisfies the EventGenerator API required for the EventAccumulator.
+ Satisfies the EventWriter API required to create a tf.summary.FileWriter.
- Has additional convenience methods for adding test events.
- """
-
- def __init__(self, testcase, zero_out_timestamps=False):
- self._testcase = testcase
- self.items = []
- self.zero_out_timestamps = zero_out_timestamps
-
- def Load(self):
- while self.items:
- yield self.items.pop(0)
-
- def AddScalarTensor(self, tag, wall_time=0, step=0, value=0):
- """Add a rank-0 tensor event.
-
- Note: This is not related to the scalar plugin; it's just a
- convenience function to add an event whose contents aren't
- important.
+ Has additional convenience methods for adding test events.
"""
- tensor = tensor_util.make_tensor_proto(float(value))
- event = event_pb2.Event(
- wall_time=wall_time,
- step=step,
- summary=summary_pb2.Summary(
- value=[summary_pb2.Summary.Value(tag=tag, tensor=tensor)]))
- self.AddEvent(event)
- def AddEvent(self, event):
- if self.zero_out_timestamps:
- event.wall_time = 0.
- self.items.append(event)
-
- def add_event(self, event): # pylint: disable=invalid-name
- """Match the EventWriter API."""
- self.AddEvent(event)
-
- def get_logdir(self): # pylint: disable=invalid-name
- """Return a temp directory for asset writing."""
- return self._testcase.get_temp_dir()
+ def __init__(self, testcase, zero_out_timestamps=False):
+ self._testcase = testcase
+ self.items = []
+ self.zero_out_timestamps = zero_out_timestamps
+
+ def Load(self):
+ while self.items:
+ yield self.items.pop(0)
+
+ def AddScalarTensor(self, tag, wall_time=0, step=0, value=0):
+ """Add a rank-0 tensor event.
+
+ Note: This is not related to the scalar plugin; it's just a
+ convenience function to add an event whose contents aren't
+ important.
+ """
+ tensor = tensor_util.make_tensor_proto(float(value))
+ event = event_pb2.Event(
+ wall_time=wall_time,
+ step=step,
+ summary=summary_pb2.Summary(
+ value=[summary_pb2.Summary.Value(tag=tag, tensor=tensor)]
+ ),
+ )
+ self.AddEvent(event)
+
+ def AddEvent(self, event):
+ if self.zero_out_timestamps:
+ event.wall_time = 0.0
+ self.items.append(event)
+
+ def add_event(self, event): # pylint: disable=invalid-name
+ """Match the EventWriter API."""
+ self.AddEvent(event)
+
+ def get_logdir(self): # pylint: disable=invalid-name
+ """Return a temp directory for asset writing."""
+ return self._testcase.get_temp_dir()
class EventAccumulatorTest(tf.test.TestCase):
-
- def assertTagsEqual(self, actual, expected):
- """Utility method for checking the return value of the Tags() call.
-
- It fills out the `expected` arg with the default (empty) values for every
- tag type, so that the author needs only specify the non-empty values they
- are interested in testing.
-
- Args:
- actual: The actual Accumulator tags response.
- expected: The expected tags response (empty fields may be omitted)
- """
-
- empty_tags = {
- ea.GRAPH: False,
- ea.META_GRAPH: False,
- ea.RUN_METADATA: [],
- ea.TENSORS: [],
- }
-
- # Verifies that there are no unexpected keys in the actual response.
- # If this line fails, likely you added a new tag type, and need to update
- # the empty_tags dictionary above.
- self.assertItemsEqual(actual.keys(), empty_tags.keys())
-
- for key in actual:
- expected_value = expected.get(key, empty_tags[key])
- if isinstance(expected_value, list):
- self.assertItemsEqual(actual[key], expected_value)
- else:
- self.assertEqual(actual[key], expected_value)
+ def assertTagsEqual(self, actual, expected):
+ """Utility method for checking the return value of the Tags() call.
+
+ It fills out the `expected` arg with the default (empty) values for every
+ tag type, so that the author need only specify the non-empty values they
+ are interested in testing.
+
+ Args:
+ actual: The actual Accumulator tags response.
+ expected: The expected tags response (empty fields may be omitted)
+ """
+
+ empty_tags = {
+ ea.GRAPH: False,
+ ea.META_GRAPH: False,
+ ea.RUN_METADATA: [],
+ ea.TENSORS: [],
+ }
+
+ # Verifies that there are no unexpected keys in the actual response.
+ # If this line fails, you likely added a new tag type and need to update
+ # the empty_tags dictionary above.
+ self.assertItemsEqual(actual.keys(), empty_tags.keys())
+
+ for key in actual:
+ expected_value = expected.get(key, empty_tags[key])
+ if isinstance(expected_value, list):
+ self.assertItemsEqual(actual[key], expected_value)
+ else:
+ self.assertEqual(actual[key], expected_value)
class MockingEventAccumulatorTest(EventAccumulatorTest):
+ def setUp(self):
+ super(MockingEventAccumulatorTest, self).setUp()
+ self.stubs = tf.compat.v1.test.StubOutForTesting()
+ self._real_constructor = ea.EventAccumulator
+ self._real_generator = ea._GeneratorFromPath
+
+ def _FakeAccumulatorConstructor(generator, *args, **kwargs):
+ def _FakeGeneratorFromPath(path, event_file_active_filter=None):
+ return generator
+
+ ea._GeneratorFromPath = _FakeGeneratorFromPath
+ return self._real_constructor(generator, *args, **kwargs)
+
+ ea.EventAccumulator = _FakeAccumulatorConstructor
+
+ def tearDown(self):
+ self.stubs.CleanUp()
+ ea.EventAccumulator = self._real_constructor
+ ea._GeneratorFromPath = self._real_generator
+
+ def testEmptyAccumulator(self):
+ gen = _EventGenerator(self)
+ x = ea.EventAccumulator(gen)
+ x.Reload()
+ self.assertTagsEqual(x.Tags(), {})
+
+ def testReload(self):
+ """EventAccumulator contains suitable tags after calling Reload."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ acc.Reload()
+ self.assertTagsEqual(acc.Tags(), {})
+ gen.AddScalarTensor("s1", wall_time=1, step=10, value=50)
+ gen.AddScalarTensor("s2", wall_time=1, step=10, value=80)
+ acc.Reload()
+ self.assertTagsEqual(acc.Tags(), {ea.TENSORS: ["s1", "s2"],})
+
+ def testKeyError(self):
+ """KeyError should be raised when accessing non-existing keys."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ acc.Reload()
+ with self.assertRaises(KeyError):
+ acc.Tensors("s1")
+
+ def testNonValueEvents(self):
+ """Non-value events in the generator don't cause early exits."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ gen.AddScalarTensor("s1", wall_time=1, step=10, value=20)
+ gen.AddEvent(
+ event_pb2.Event(wall_time=2, step=20, file_version="nots2")
+ )
+ gen.AddScalarTensor("s3", wall_time=3, step=100, value=1)
+
+ acc.Reload()
+ self.assertTagsEqual(acc.Tags(), {ea.TENSORS: ["s1", "s3"],})
+
+ def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
+ """Tests that events are discarded after a restart is detected.
+
+ If a step value is observed to be lower than what was previously seen,
+ this should force a discard of all previous items with the same tag
+ that are outdated.
+
+ Only file versions < 2 use this out-of-order discard logic. Later versions
+ discard events based on the step value of SessionLog.START.
+ """
+ warnings = []
+ self.stubs.Set(logger, "warn", warnings.append)
+
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+
+ gen.AddEvent(
+ event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1")
+ )
+ gen.AddScalarTensor("s1", wall_time=1, step=100, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=200, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=300, value=20)
+ acc.Reload()
+ ## Check that the number of items is what it should be
+ self.assertEqual([x.step for x in acc.Tensors("s1")], [100, 200, 300])
+
+ gen.AddScalarTensor("s1", wall_time=1, step=101, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=201, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=301, value=20)
+ acc.Reload()
+ ## Check that we have discarded 200 and 300 from s1
+ self.assertEqual(
+ [x.step for x in acc.Tensors("s1")], [100, 101, 201, 301]
+ )
+
+ def testOrphanedDataNotDiscardedIfFlagUnset(self):
+ """Tests that events are not discarded if purge_orphaned_data is
+ false."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen, purge_orphaned_data=False)
+
+ gen.AddEvent(
+ event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1")
+ )
+ gen.AddScalarTensor("s1", wall_time=1, step=100, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=200, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=300, value=20)
+ acc.Reload()
+ ## Check that the number of items is what it should be
+ self.assertEqual([x.step for x in acc.Tensors("s1")], [100, 200, 300])
+
+ gen.AddScalarTensor("s1", wall_time=1, step=101, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=201, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=301, value=20)
+ acc.Reload()
+ ## Check that we have NOT discarded 200 and 300 from s1
+ self.assertEqual(
+ [x.step for x in acc.Tensors("s1")], [100, 200, 300, 101, 201, 301]
+ )
+
+ def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
+ """Tests that event discards after restart, only affect the misordered
+ tag.
+
+ If a step value is observed to be lower than what was previously seen,
+ this should force a discard of all previous items that are outdated, but
+ only for the out of order tag. Other tags should remain unaffected.
+
+ Only file versions < 2 use this out-of-order discard logic. Later versions
+ discard events based on the step value of SessionLog.START.
+ """
+ warnings = []
+ self.stubs.Set(logger, "warn", warnings.append)
+
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+
+ gen.AddEvent(
+ event_pb2.Event(wall_time=0, step=0, file_version="brain.Event:1")
+ )
+ gen.AddScalarTensor("s1", wall_time=1, step=100, value=20)
+ gen.AddScalarTensor("s2", wall_time=1, step=101, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=200, value=20)
+ gen.AddScalarTensor("s2", wall_time=1, step=201, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=300, value=20)
+ gen.AddScalarTensor("s2", wall_time=1, step=301, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=101, value=20)
+ gen.AddScalarTensor("s3", wall_time=1, step=101, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=201, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=301, value=20)
+
+ acc.Reload()
+ ## Check that we have discarded 200 and 300 for s1
+ self.assertEqual(
+ [x.step for x in acc.Tensors("s1")], [100, 101, 201, 301]
+ )
+
+ ## Check that s1 discards do not affect s2 (written before out-of-order)
+ ## or s3 (written after out-of-order).
+ ## i.e. check that only events from the out of order tag are discarded
+ self.assertEqual([x.step for x in acc.Tensors("s2")], [101, 201, 301])
+ self.assertEqual([x.step for x in acc.Tensors("s3")], [101])
+
+ def testOnlySummaryEventsTriggerDiscards(self):
+ """Test that file version event does not trigger data purge."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ gen.AddScalarTensor("s1", wall_time=1, step=100, value=20)
+ ev1 = event_pb2.Event(wall_time=2, step=0, file_version="brain.Event:1")
+ graph_bytes = tf.compat.v1.GraphDef().SerializeToString()
+ ev2 = event_pb2.Event(wall_time=3, step=0, graph_def=graph_bytes)
+ gen.AddEvent(ev1)
+ gen.AddEvent(ev2)
+ acc.Reload()
+ self.assertEqual([x.step for x in acc.Tensors("s1")], [100])
+
+ def testSessionLogStartMessageDiscardsExpiredEvents(self):
+ """Test that SessionLog.START message discards expired events.
+
+ This discard logic is preferred over the out-of-order step
+ discard logic, but this logic can only be used for event protos
+ which have the SessionLog enum, which was introduced to
+ event.proto for file_version >= brain.Event:2.
+ """
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ gen.AddEvent(
+ event_pb2.Event(wall_time=0, step=1, file_version="brain.Event:2")
+ )
+
+ gen.AddScalarTensor("s1", wall_time=1, step=100, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=200, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=300, value=20)
+ gen.AddScalarTensor("s1", wall_time=1, step=400, value=20)
+
+ gen.AddScalarTensor("s2", wall_time=1, step=202, value=20)
+ gen.AddScalarTensor("s2", wall_time=1, step=203, value=20)
+
+ slog = event_pb2.SessionLog(status=event_pb2.SessionLog.START)
+ gen.AddEvent(event_pb2.Event(wall_time=2, step=201, session_log=slog))
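+ # The SessionLog.START at step 201 should purge every event already
+ # seen with step >= 201: s1 keeps steps 100 and 200, while s2 loses
+ # both of its events.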
+ acc.Reload()
+ self.assertEqual([x.step for x in acc.Tensors("s1")], [100, 200])
+ self.assertEqual([x.step for x in acc.Tensors("s2")], [])
+
+ def testFirstEventTimestamp(self):
+ """Test that FirstEventTimestamp() returns wall_time of the first
+ event."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ gen.AddEvent(
+ event_pb2.Event(wall_time=10, step=20, file_version="brain.Event:2")
+ )
+ gen.AddScalarTensor("s1", wall_time=30, step=40, value=20)
+ self.assertEqual(acc.FirstEventTimestamp(), 10)
+
+ def testReloadPopulatesFirstEventTimestamp(self):
+ """Test that Reload() means FirstEventTimestamp() won't load events."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ gen.AddEvent(
+ event_pb2.Event(wall_time=1, step=2, file_version="brain.Event:2")
+ )
+
+ acc.Reload()
+
+ def _Die(*args, **kwargs): # pylint: disable=unused-argument
+ raise RuntimeError("Load() should not be called")
+
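+ # After Reload(), FirstEventTimestamp() must use the already-loaded
+ # event; stubbing Load() to raise proves no further loading happens.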
+ self.stubs.Set(gen, "Load", _Die)
+ self.assertEqual(acc.FirstEventTimestamp(), 1)
+
+ def testFirstEventTimestampLoadsEvent(self):
+ """Test that FirstEventTimestamp() doesn't discard the loaded event."""
+ gen = _EventGenerator(self)
+ acc = ea.EventAccumulator(gen)
+ gen.AddEvent(
+ event_pb2.Event(wall_time=1, step=2, file_version="brain.Event:2")
+ )
+
+ self.assertEqual(acc.FirstEventTimestamp(), 1)
+ acc.Reload()
+ self.assertEqual(acc.file_version, 2.0)
+
+ def testNewStyleScalarSummary(self):
+ """Verify processing of tensorboard.plugins.scalar.summary."""
+ event_sink = _EventGenerator(self, zero_out_timestamps=True)
+ writer = test_util.FileWriter(self.get_temp_dir())
+ writer.event_writer = event_sink
+ with tf.compat.v1.Graph().as_default():
+ with self.test_session() as sess:
+ step = tf.compat.v1.placeholder(tf.float32, shape=[])
+ scalar_summary.op(
+ "accuracy", 1.0 - 1.0 / (step + tf.constant(1.0))
+ )
+ scalar_summary.op("xent", 1.0 / (step + tf.constant(1.0)))
+ merged = tf.compat.v1.summary.merge_all()
+ writer.add_graph(sess.graph)
+ for i in xrange(10):
+ summ = sess.run(merged, feed_dict={step: float(i)})
+ writer.add_summary(summ, global_step=i)
+
+ accumulator = ea.EventAccumulator(event_sink)
+ accumulator.Reload()
+
+ tags = [
+ u"accuracy/scalar_summary",
+ u"xent/scalar_summary",
+ ]
+
+ self.assertTagsEqual(
+ accumulator.Tags(),
+ {ea.TENSORS: tags, ea.GRAPH: True, ea.META_GRAPH: False,},
+ )
+
+ def testNewStyleAudioSummary(self):
+ """Verify processing of tensorboard.plugins.audio.summary."""
+ event_sink = _EventGenerator(self, zero_out_timestamps=True)
+ writer = test_util.FileWriter(self.get_temp_dir())
+ writer.event_writer = event_sink
+ with tf.compat.v1.Graph().as_default():
+ with self.test_session() as sess:
+ ipt = tf.random.normal(shape=[5, 441, 2])
+ with tf.name_scope("1"):
+ audio_summary.op(
+ "one", ipt, sample_rate=44100, max_outputs=1
+ )
+ with tf.name_scope("2"):
+ audio_summary.op(
+ "two", ipt, sample_rate=44100, max_outputs=2
+ )
+ with tf.name_scope("3"):
+ audio_summary.op(
+ "three", ipt, sample_rate=44100, max_outputs=3
+ )
+ merged = tf.compat.v1.summary.merge_all()
+ writer.add_graph(sess.graph)
+ for i in xrange(10):
+ summ = sess.run(merged)
+ writer.add_summary(summ, global_step=i)
+
+ accumulator = ea.EventAccumulator(event_sink)
+ accumulator.Reload()
+
+ tags = [
+ u"1/one/audio_summary",
+ u"2/two/audio_summary",
+ u"3/three/audio_summary",
+ ]
+
+ self.assertTagsEqual(
+ accumulator.Tags(),
+ {ea.TENSORS: tags, ea.GRAPH: True, ea.META_GRAPH: False,},
+ )
+
+ def testNewStyleImageSummary(self):
+ """Verify processing of tensorboard.plugins.image.summary."""
+ event_sink = _EventGenerator(self, zero_out_timestamps=True)
+ writer = test_util.FileWriter(self.get_temp_dir())
+ writer.event_writer = event_sink
+ with tf.compat.v1.Graph().as_default():
+ with self.test_session() as sess:
+ ipt = tf.ones([10, 4, 4, 3], tf.uint8)
+ # This is an interesting example, because the old tf.image_summary op
+ # would throw an error here due to tag reuse.
+ # Using the tf node name instead allows the same tag argument to be
+ # reused across image summaries.
+ with tf.name_scope("1"):
+ image_summary.op("images", ipt, max_outputs=1)
+ with tf.name_scope("2"):
+ image_summary.op("images", ipt, max_outputs=2)
+ with tf.name_scope("3"):
+ image_summary.op("images", ipt, max_outputs=3)
+ merged = tf.compat.v1.summary.merge_all()
+ writer.add_graph(sess.graph)
+ for i in xrange(10):
+ summ = sess.run(merged)
+ writer.add_summary(summ, global_step=i)
+
+ accumulator = ea.EventAccumulator(event_sink)
+ accumulator.Reload()
+
+ tags = [
+ u"1/images/image_summary",
+ u"2/images/image_summary",
+ u"3/images/image_summary",
+ ]
+
+ self.assertTagsEqual(
+ accumulator.Tags(),
+ {ea.TENSORS: tags, ea.GRAPH: True, ea.META_GRAPH: False,},
+ )
+
+ def testTFSummaryTensor(self):
+ """Verify processing of tf.summary.tensor."""
+ event_sink = _EventGenerator(self, zero_out_timestamps=True)
+ writer = test_util.FileWriter(self.get_temp_dir())
+ writer.event_writer = event_sink
+ with tf.compat.v1.Graph().as_default():
+ with self.test_session() as sess:
+ tensor_summary = tf.compat.v1.summary.tensor_summary
+ tensor_summary("scalar", tf.constant(1.0))
+ tensor_summary("vector", tf.constant([1.0, 2.0, 3.0]))
+ tensor_summary("string", tf.constant(six.b("foobar")))
+ merged = tf.compat.v1.summary.merge_all()
+ summ = sess.run(merged)
+ writer.add_summary(summ, 0)
+
+ accumulator = ea.EventAccumulator(event_sink)
+ accumulator.Reload()
+
+ self.assertTagsEqual(
+ accumulator.Tags(), {ea.TENSORS: ["scalar", "vector", "string"],}
+ )
+
+ scalar_proto = accumulator.Tensors("scalar")[0].tensor_proto
+ scalar = tensor_util.make_ndarray(scalar_proto)
+ vector_proto = accumulator.Tensors("vector")[0].tensor_proto
+ vector = tensor_util.make_ndarray(vector_proto)
+ string_proto = accumulator.Tensors("string")[0].tensor_proto
+ string = tensor_util.make_ndarray(string_proto)
+
+ self.assertTrue(np.array_equal(scalar, 1.0))
+ self.assertTrue(np.array_equal(vector, [1.0, 2.0, 3.0]))
+ self.assertTrue(np.array_equal(string, six.b("foobar")))
+
+ def _testTFSummaryTensor_SizeGuidance(
+ self, plugin_name, tensor_size_guidance, steps, expected_count
+ ):
+ event_sink = _EventGenerator(self, zero_out_timestamps=True)
+ writer = test_util.FileWriter(self.get_temp_dir())
+ writer.event_writer = event_sink
+ with tf.compat.v1.Graph().as_default():
+ with self.test_session() as sess:
+ summary_metadata = summary_pb2.SummaryMetadata(
+ plugin_data=summary_pb2.SummaryMetadata.PluginData(
+ plugin_name=plugin_name, content=b"{}"
+ )
+ )
+ tf.compat.v1.summary.tensor_summary(
+ "scalar",
+ tf.constant(1.0),
+ summary_metadata=summary_metadata,
+ )
+ merged = tf.compat.v1.summary.merge_all()
+ for step in xrange(steps):
+ writer.add_summary(sess.run(merged), global_step=step)
+
+ accumulator = ea.EventAccumulator(
+ event_sink, tensor_size_guidance=tensor_size_guidance
+ )
+ accumulator.Reload()
+
+ tensors = accumulator.Tensors("scalar")
+ self.assertEqual(len(tensors), expected_count)
+
+ def testTFSummaryTensor_SizeGuidance_DefaultToTensorGuidance(self):
+ self._testTFSummaryTensor_SizeGuidance(
+ plugin_name="jabberwocky",
+ tensor_size_guidance={},
+ steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 1,
+ expected_count=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS],
+ )
+
+ def testTFSummaryTensor_SizeGuidance_UseSmallSingularPluginGuidance(self):
+ size = int(ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] / 2)
+ assert size < ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS], size
+ self._testTFSummaryTensor_SizeGuidance(
+ plugin_name="jabberwocky",
+ tensor_size_guidance={"jabberwocky": size},
+ steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 1,
+ expected_count=size,
+ )
+
+ def testTFSummaryTensor_SizeGuidance_UseLargeSingularPluginGuidance(self):
+ size = ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 5
+ self._testTFSummaryTensor_SizeGuidance(
+ plugin_name="jabberwocky",
+ tensor_size_guidance={"jabberwocky": size},
+ steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 10,
+ expected_count=size,
+ )
+
+ def testTFSummaryTensor_SizeGuidance_IgnoreIrrelevantGuidances(self):
+ size_small = int(ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] / 3)
+ size_large = int(ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] / 2)
+ assert size_small < size_large < ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS], (
+ size_small,
+ size_large,
+ )
+ self._testTFSummaryTensor_SizeGuidance(
+ plugin_name="jabberwocky",
+ tensor_size_guidance={
+ "jabberwocky": size_small,
+ "wnoorejbpxl": size_large,
+ },
+ steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 1,
+ expected_count=size_small,
+ )
- def setUp(self):
- super(MockingEventAccumulatorTest, self).setUp()
- self.stubs = tf.compat.v1.test.StubOutForTesting()
- self._real_constructor = ea.EventAccumulator
- self._real_generator = ea._GeneratorFromPath
-
- def _FakeAccumulatorConstructor(generator, *args, **kwargs):
- def _FakeGeneratorFromPath(path, event_file_active_filter=None):
- return generator
- ea._GeneratorFromPath = _FakeGeneratorFromPath
- return self._real_constructor(generator, *args, **kwargs)
-
- ea.EventAccumulator = _FakeAccumulatorConstructor
-
- def tearDown(self):
- self.stubs.CleanUp()
- ea.EventAccumulator = self._real_constructor
- ea._GeneratorFromPath = self._real_generator
-
- def testEmptyAccumulator(self):
- gen = _EventGenerator(self)
- x = ea.EventAccumulator(gen)
- x.Reload()
- self.assertTagsEqual(x.Tags(), {})
-
- def testReload(self):
- """EventAccumulator contains suitable tags after calling Reload."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- acc.Reload()
- self.assertTagsEqual(acc.Tags(), {})
- gen.AddScalarTensor('s1', wall_time=1, step=10, value=50)
- gen.AddScalarTensor('s2', wall_time=1, step=10, value=80)
- acc.Reload()
- self.assertTagsEqual(acc.Tags(), {
- ea.TENSORS: ['s1', 's2'],
- })
-
- def testKeyError(self):
- """KeyError should be raised when accessing non-existing keys."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- acc.Reload()
- with self.assertRaises(KeyError):
- acc.Tensors('s1')
-
- def testNonValueEvents(self):
- """Non-value events in the generator don't cause early exits."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- gen.AddScalarTensor('s1', wall_time=1, step=10, value=20)
- gen.AddEvent(
- event_pb2.Event(wall_time=2, step=20, file_version='nots2'))
- gen.AddScalarTensor('s3', wall_time=3, step=100, value=1)
-
- acc.Reload()
- self.assertTagsEqual(acc.Tags(), {
- ea.TENSORS: ['s1', 's3'],
- })
-
- def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
- """Tests that events are discarded after a restart is detected.
-
- If a step value is observed to be lower than what was previously seen,
- this should force a discard of all previous items with the same tag
- that are outdated.
-
- Only file versions < 2 use this out-of-order discard logic. Later versions
- discard events based on the step value of SessionLog.START.
- """
- warnings = []
- self.stubs.Set(logger, 'warn', warnings.append)
-
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
-
- gen.AddEvent(
- event_pb2.Event(wall_time=0, step=0, file_version='brain.Event:1'))
- gen.AddScalarTensor('s1', wall_time=1, step=100, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=200, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=300, value=20)
- acc.Reload()
- ## Check that number of items are what they should be
- self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 200, 300])
-
- gen.AddScalarTensor('s1', wall_time=1, step=101, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=201, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=301, value=20)
- acc.Reload()
- ## Check that we have discarded 200 and 300 from s1
- self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 101, 201, 301])
-
- def testOrphanedDataNotDiscardedIfFlagUnset(self):
- """Tests that events are not discarded if purge_orphaned_data is false.
- """
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen, purge_orphaned_data=False)
-
- gen.AddEvent(
- event_pb2.Event(wall_time=0, step=0, file_version='brain.Event:1'))
- gen.AddScalarTensor('s1', wall_time=1, step=100, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=200, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=300, value=20)
- acc.Reload()
- ## Check that number of items are what they should be
- self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 200, 300])
-
- gen.AddScalarTensor('s1', wall_time=1, step=101, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=201, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=301, value=20)
- acc.Reload()
- ## Check that we have NOT discarded 200 and 300 from s1
- self.assertEqual([x.step for x in acc.Tensors('s1')],
- [100, 200, 300, 101, 201, 301])
-
- def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
- """Tests that event discards after restart, only affect the misordered tag.
-
- If a step value is observed to be lower than what was previously seen,
- this should force a discard of all previous items that are outdated, but
- only for the out of order tag. Other tags should remain unaffected.
-
- Only file versions < 2 use this out-of-order discard logic. Later versions
- discard events based on the step value of SessionLog.START.
- """
- warnings = []
- self.stubs.Set(logger, 'warn', warnings.append)
-
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
-
- gen.AddEvent(
- event_pb2.Event(wall_time=0, step=0, file_version='brain.Event:1'))
- gen.AddScalarTensor('s1', wall_time=1, step=100, value=20)
- gen.AddScalarTensor('s2', wall_time=1, step=101, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=200, value=20)
- gen.AddScalarTensor('s2', wall_time=1, step=201, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=300, value=20)
- gen.AddScalarTensor('s2', wall_time=1, step=301, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=101, value=20)
- gen.AddScalarTensor('s3', wall_time=1, step=101, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=201, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=301, value=20)
-
- acc.Reload()
- ## Check that we have discarded 200 and 300 for s1
- self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 101, 201, 301])
-
- ## Check that s1 discards do not affect s2 (written before out-of-order)
- ## or s3 (written after out-of-order).
- ## i.e. check that only events from the out of order tag are discarded
- self.assertEqual([x.step for x in acc.Tensors('s2')], [101, 201, 301])
- self.assertEqual([x.step for x in acc.Tensors('s3')], [101])
-
- def testOnlySummaryEventsTriggerDiscards(self):
- """Test that file version event does not trigger data purge."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- gen.AddScalarTensor('s1', wall_time=1, step=100, value=20)
- ev1 = event_pb2.Event(wall_time=2, step=0, file_version='brain.Event:1')
- graph_bytes = tf.compat.v1.GraphDef().SerializeToString()
- ev2 = event_pb2.Event(wall_time=3, step=0, graph_def=graph_bytes)
- gen.AddEvent(ev1)
- gen.AddEvent(ev2)
- acc.Reload()
- self.assertEqual([x.step for x in acc.Tensors('s1')], [100])
-
- def testSessionLogStartMessageDiscardsExpiredEvents(self):
- """Test that SessionLog.START message discards expired events.
-
- This discard logic is preferred over the out-of-order step discard logic,
- but this logic can only be used for event protos which have the SessionLog
- enum, which was introduced to event.proto for file_version >= brain.Event:2.
- """
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- gen.AddEvent(
- event_pb2.Event(wall_time=0, step=1, file_version='brain.Event:2'))
-
- gen.AddScalarTensor('s1', wall_time=1, step=100, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=200, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=300, value=20)
- gen.AddScalarTensor('s1', wall_time=1, step=400, value=20)
-
- gen.AddScalarTensor('s2', wall_time=1, step=202, value=20)
- gen.AddScalarTensor('s2', wall_time=1, step=203, value=20)
-
- slog = event_pb2.SessionLog(status=event_pb2.SessionLog.START)
- gen.AddEvent(
- event_pb2.Event(wall_time=2, step=201, session_log=slog))
- acc.Reload()
- self.assertEqual([x.step for x in acc.Tensors('s1')], [100, 200])
- self.assertEqual([x.step for x in acc.Tensors('s2')], [])
-
- def testFirstEventTimestamp(self):
- """Test that FirstEventTimestamp() returns wall_time of the first event."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- gen.AddEvent(
- event_pb2.Event(wall_time=10, step=20, file_version='brain.Event:2'))
- gen.AddScalarTensor('s1', wall_time=30, step=40, value=20)
- self.assertEqual(acc.FirstEventTimestamp(), 10)
-
- def testReloadPopulatesFirstEventTimestamp(self):
- """Test that Reload() means FirstEventTimestamp() won't load events."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- gen.AddEvent(
- event_pb2.Event(wall_time=1, step=2, file_version='brain.Event:2'))
-
- acc.Reload()
-
- def _Die(*args, **kwargs): # pylint: disable=unused-argument
- raise RuntimeError('Load() should not be called')
-
- self.stubs.Set(gen, 'Load', _Die)
- self.assertEqual(acc.FirstEventTimestamp(), 1)
-
- def testFirstEventTimestampLoadsEvent(self):
- """Test that FirstEventTimestamp() doesn't discard the loaded event."""
- gen = _EventGenerator(self)
- acc = ea.EventAccumulator(gen)
- gen.AddEvent(
- event_pb2.Event(wall_time=1, step=2, file_version='brain.Event:2'))
-
- self.assertEqual(acc.FirstEventTimestamp(), 1)
- acc.Reload()
- self.assertEqual(acc.file_version, 2.0)
-
- def testNewStyleScalarSummary(self):
- """Verify processing of tensorboard.plugins.scalar.summary."""
- event_sink = _EventGenerator(self, zero_out_timestamps=True)
- writer = test_util.FileWriter(self.get_temp_dir())
- writer.event_writer = event_sink
- with tf.compat.v1.Graph().as_default():
- with self.test_session() as sess:
- step = tf.compat.v1.placeholder(tf.float32, shape=[])
- scalar_summary.op('accuracy', 1.0 - 1.0 / (step + tf.constant(1.0)))
- scalar_summary.op('xent', 1.0 / (step + tf.constant(1.0)))
- merged = tf.compat.v1.summary.merge_all()
- writer.add_graph(sess.graph)
- for i in xrange(10):
- summ = sess.run(merged, feed_dict={step: float(i)})
- writer.add_summary(summ, global_step=i)
-
- accumulator = ea.EventAccumulator(event_sink)
- accumulator.Reload()
-
- tags = [
- u'accuracy/scalar_summary',
- u'xent/scalar_summary',
- ]
-
- self.assertTagsEqual(accumulator.Tags(), {
- ea.TENSORS: tags,
- ea.GRAPH: True,
- ea.META_GRAPH: False,
- })
-
- def testNewStyleAudioSummary(self):
- """Verify processing of tensorboard.plugins.audio.summary."""
- event_sink = _EventGenerator(self, zero_out_timestamps=True)
- writer = test_util.FileWriter(self.get_temp_dir())
- writer.event_writer = event_sink
- with tf.compat.v1.Graph().as_default():
- with self.test_session() as sess:
- ipt = tf.random.normal(shape=[5, 441, 2])
- with tf.name_scope('1'):
- audio_summary.op('one', ipt, sample_rate=44100, max_outputs=1)
- with tf.name_scope('2'):
- audio_summary.op('two', ipt, sample_rate=44100, max_outputs=2)
- with tf.name_scope('3'):
- audio_summary.op('three', ipt, sample_rate=44100, max_outputs=3)
- merged = tf.compat.v1.summary.merge_all()
- writer.add_graph(sess.graph)
- for i in xrange(10):
- summ = sess.run(merged)
- writer.add_summary(summ, global_step=i)
-
- accumulator = ea.EventAccumulator(event_sink)
- accumulator.Reload()
-
- tags = [
- u'1/one/audio_summary',
- u'2/two/audio_summary',
- u'3/three/audio_summary',
- ]
-
- self.assertTagsEqual(accumulator.Tags(), {
- ea.TENSORS: tags,
- ea.GRAPH: True,
- ea.META_GRAPH: False,
- })
-
- def testNewStyleImageSummary(self):
- """Verify processing of tensorboard.plugins.image.summary."""
- event_sink = _EventGenerator(self, zero_out_timestamps=True)
- writer = test_util.FileWriter(self.get_temp_dir())
- writer.event_writer = event_sink
- with tf.compat.v1.Graph().as_default():
- with self.test_session() as sess:
- ipt = tf.ones([10, 4, 4, 3], tf.uint8)
- # This is an interesting example, because the old tf.image_summary op
- # would throw an error here, because it would be tag reuse.
- # Using the tf node name instead allows argument re-use to the image
- # summary.
- with tf.name_scope('1'):
- image_summary.op('images', ipt, max_outputs=1)
- with tf.name_scope('2'):
- image_summary.op('images', ipt, max_outputs=2)
- with tf.name_scope('3'):
- image_summary.op('images', ipt, max_outputs=3)
- merged = tf.compat.v1.summary.merge_all()
- writer.add_graph(sess.graph)
- for i in xrange(10):
- summ = sess.run(merged)
- writer.add_summary(summ, global_step=i)
-
- accumulator = ea.EventAccumulator(event_sink)
- accumulator.Reload()
-
- tags = [
- u'1/images/image_summary',
- u'2/images/image_summary',
- u'3/images/image_summary',
- ]
-
- self.assertTagsEqual(accumulator.Tags(), {
- ea.TENSORS: tags,
- ea.GRAPH: True,
- ea.META_GRAPH: False,
- })
-
- def testTFSummaryTensor(self):
- """Verify processing of tf.summary.tensor."""
- event_sink = _EventGenerator(self, zero_out_timestamps=True)
- writer = test_util.FileWriter(self.get_temp_dir())
- writer.event_writer = event_sink
- with tf.compat.v1.Graph().as_default():
- with self.test_session() as sess:
- tensor_summary = tf.compat.v1.summary.tensor_summary
- tensor_summary('scalar', tf.constant(1.0))
- tensor_summary('vector', tf.constant([1.0, 2.0, 3.0]))
- tensor_summary('string', tf.constant(six.b('foobar')))
- merged = tf.compat.v1.summary.merge_all()
- summ = sess.run(merged)
- writer.add_summary(summ, 0)
-
- accumulator = ea.EventAccumulator(event_sink)
- accumulator.Reload()
-
- self.assertTagsEqual(accumulator.Tags(), {
- ea.TENSORS: ['scalar', 'vector', 'string'],
- })
-
- scalar_proto = accumulator.Tensors('scalar')[0].tensor_proto
- scalar = tensor_util.make_ndarray(scalar_proto)
- vector_proto = accumulator.Tensors('vector')[0].tensor_proto
- vector = tensor_util.make_ndarray(vector_proto)
- string_proto = accumulator.Tensors('string')[0].tensor_proto
- string = tensor_util.make_ndarray(string_proto)
-
- self.assertTrue(np.array_equal(scalar, 1.0))
- self.assertTrue(np.array_equal(vector, [1.0, 2.0, 3.0]))
- self.assertTrue(np.array_equal(string, six.b('foobar')))
-
- def _testTFSummaryTensor_SizeGuidance(self,
- plugin_name,
- tensor_size_guidance,
- steps,
- expected_count):
- event_sink = _EventGenerator(self, zero_out_timestamps=True)
- writer = test_util.FileWriter(self.get_temp_dir())
- writer.event_writer = event_sink
- with tf.compat.v1.Graph().as_default():
- with self.test_session() as sess:
+
+class RealisticEventAccumulatorTest(EventAccumulatorTest):
+ def testTensorsRealistically(self):
+ """Test accumulator by writing values and then reading them."""
+
+ def FakeScalarSummary(tag, value):
+ value = summary_pb2.Summary.Value(tag=tag, simple_value=value)
+ summary = summary_pb2.Summary(value=[value])
+ return summary
+
+ directory = os.path.join(self.get_temp_dir(), "values_dir")
+ if tf.io.gfile.isdir(directory):
+ tf.io.gfile.rmtree(directory)
+ tf.io.gfile.mkdir(directory)
+
+ writer = test_util.FileWriter(directory, max_queue=100)
+
+ with tf.Graph().as_default() as graph:
+ _ = tf.constant([2.0, 1.0])
+ # Add a graph to the summary writer.
+ writer.add_graph(graph)
+ graph_def = graph.as_graph_def(add_shapes=True)
+ meta_graph_def = tf.compat.v1.train.export_meta_graph(
+ graph_def=graph_def
+ )
+ writer.add_meta_graph(meta_graph_def)
+
+ run_metadata = config_pb2.RunMetadata()
+ device_stats = run_metadata.step_stats.dev_stats.add()
+ device_stats.device = "test device"
+ writer.add_run_metadata(run_metadata, "test run")
+
+ # Write a bunch of events using the writer.
+ for i in xrange(30):
+ summ_id = FakeScalarSummary("id", i)
+ summ_sq = FakeScalarSummary("sq", i * i)
+ writer.add_summary(summ_id, i * 5)
+ writer.add_summary(summ_sq, i * 5)
+ writer.flush()
+
+ # Verify that we can load those events properly
+ acc = ea.EventAccumulator(directory)
+ acc.Reload()
+ self.assertTagsEqual(
+ acc.Tags(),
+ {
+ ea.TENSORS: ["id", "sq"],
+ ea.GRAPH: True,
+ ea.META_GRAPH: True,
+ ea.RUN_METADATA: ["test run"],
+ },
+ )
+ id_events = acc.Tensors("id")
+ sq_events = acc.Tensors("sq")
+ self.assertEqual(30, len(id_events))
+ self.assertEqual(30, len(sq_events))
+ for i in xrange(30):
+ self.assertEqual(i * 5, id_events[i].step)
+ self.assertEqual(i * 5, sq_events[i].step)
+ self.assertEqual(
+ i, tensor_util.make_ndarray(id_events[i].tensor_proto).item()
+ )
+ self.assertEqual(
+ i * i,
+ tensor_util.make_ndarray(sq_events[i].tensor_proto).item(),
+ )
+
+ # Write a few more events to test incremental reloading
+ for i in xrange(30, 40):
+ summ_id = FakeScalarSummary("id", i)
+ summ_sq = FakeScalarSummary("sq", i * i)
+ writer.add_summary(summ_id, i * 5)
+ writer.add_summary(summ_sq, i * 5)
+ writer.flush()
+
+ # Verify we can now see all of the data
+ acc.Reload()
+ id_events = acc.Tensors("id")
+ sq_events = acc.Tensors("sq")
+ self.assertEqual(40, len(id_events))
+ self.assertEqual(40, len(sq_events))
+ for i in xrange(40):
+ self.assertEqual(i * 5, id_events[i].step)
+ self.assertEqual(i * 5, sq_events[i].step)
+ self.assertEqual(
+ i, tensor_util.make_ndarray(id_events[i].tensor_proto).item()
+ )
+ self.assertEqual(
+ i * i,
+ tensor_util.make_ndarray(sq_events[i].tensor_proto).item(),
+ )
+
+ expected_graph_def = graph_pb2.GraphDef.FromString(
+ graph.as_graph_def(add_shapes=True).SerializeToString()
+ )
+ self.assertProtoEquals(expected_graph_def, acc.Graph())
+ self.assertProtoEquals(
+ expected_graph_def,
+ graph_pb2.GraphDef.FromString(acc.SerializedGraph()),
+ )
+
+ expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString(
+ meta_graph_def.SerializeToString()
+ )
+ self.assertProtoEquals(expected_meta_graph, acc.MetaGraph())
+
+ def testGraphFromMetaGraphBecomesAvailable(self):
+ """Test accumulator by writing values and then reading them."""
+
+ directory = os.path.join(
+ self.get_temp_dir(), "metagraph_test_values_dir"
+ )
+ if tf.io.gfile.isdir(directory):
+ tf.io.gfile.rmtree(directory)
+ tf.io.gfile.mkdir(directory)
+
+ writer = test_util.FileWriter(directory, max_queue=100)
+
+ with tf.Graph().as_default() as graph:
+ _ = tf.constant([2.0, 1.0])
+ # Add a graph to the summary writer.
+ graph_def = graph.as_graph_def(add_shapes=True)
+ meta_graph_def = tf.compat.v1.train.export_meta_graph(
+ graph_def=graph_def
+ )
+ writer.add_meta_graph(meta_graph_def)
+ writer.flush()
+
+ # Verify that we can load those events properly
+ acc = ea.EventAccumulator(directory)
+ acc.Reload()
+ self.assertTagsEqual(acc.Tags(), {ea.GRAPH: True, ea.META_GRAPH: True,})
+
+ expected_graph_def = graph_pb2.GraphDef.FromString(
+ graph.as_graph_def(add_shapes=True).SerializeToString()
+ )
+ self.assertProtoEquals(expected_graph_def, acc.Graph())
+ self.assertProtoEquals(
+ expected_graph_def,
+ graph_pb2.GraphDef.FromString(acc.SerializedGraph()),
+ )
+
+ expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString(
+ meta_graph_def.SerializeToString()
+ )
+ self.assertProtoEquals(expected_meta_graph, acc.MetaGraph())
+
+ def _writeMetadata(self, logdir, summary_metadata, nonce=""):
+ """Write to disk a summary with the given metadata.
+
+ Arguments:
+ logdir: a string
+ summary_metadata: a `SummaryMetadata` protobuf object
+ nonce: optional; will be added to the end of the event file name
+ to guarantee that multiple calls to this function do not stomp the
+ same file
+ """
+
+ summary = summary_pb2.Summary()
+ summary.value.add(
+ tensor=tensor_util.make_tensor_proto(
+ ["po", "ta", "to"], dtype=tf.string
+ ),
+ tag="you_are_it",
+ metadata=summary_metadata,
+ )
+ writer = test_util.FileWriter(logdir, filename_suffix=nonce)
+ writer.add_summary(summary.SerializeToString())
+ writer.close()
+
+ def testSummaryMetadata(self):
+ logdir = self.get_temp_dir()
summary_metadata = summary_pb2.SummaryMetadata(
+ display_name="current tagee",
+ summary_description="no",
plugin_data=summary_pb2.SummaryMetadata.PluginData(
- plugin_name=plugin_name, content=b'{}'))
- tf.compat.v1.summary.tensor_summary('scalar', tf.constant(1.0),
- summary_metadata=summary_metadata)
- merged = tf.compat.v1.summary.merge_all()
- for step in xrange(steps):
- writer.add_summary(sess.run(merged), global_step=step)
-
-
- accumulator = ea.EventAccumulator(
- event_sink, tensor_size_guidance=tensor_size_guidance)
- accumulator.Reload()
-
- tensors = accumulator.Tensors('scalar')
- self.assertEqual(len(tensors), expected_count)
-
- def testTFSummaryTensor_SizeGuidance_DefaultToTensorGuidance(self):
- self._testTFSummaryTensor_SizeGuidance(
- plugin_name='jabberwocky',
- tensor_size_guidance={},
- steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 1,
- expected_count=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS])
-
- def testTFSummaryTensor_SizeGuidance_UseSmallSingularPluginGuidance(self):
- size = int(ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] / 2)
- assert size < ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS], size
- self._testTFSummaryTensor_SizeGuidance(
- plugin_name='jabberwocky',
- tensor_size_guidance={'jabberwocky': size},
- steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 1,
- expected_count=size)
-
- def testTFSummaryTensor_SizeGuidance_UseLargeSingularPluginGuidance(self):
- size = ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 5
- self._testTFSummaryTensor_SizeGuidance(
- plugin_name='jabberwocky',
- tensor_size_guidance={'jabberwocky': size},
- steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 10,
- expected_count=size)
-
- def testTFSummaryTensor_SizeGuidance_IgnoreIrrelevantGuidances(self):
- size_small = int(ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] / 3)
- size_large = int(ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] / 2)
- assert size_small < size_large < ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS], (
- size_small, size_large)
- self._testTFSummaryTensor_SizeGuidance(
- plugin_name='jabberwocky',
- tensor_size_guidance={'jabberwocky': size_small,
- 'wnoorejbpxl': size_large},
- steps=ea.DEFAULT_SIZE_GUIDANCE[ea.TENSORS] + 1,
- expected_count=size_small)
-
+ plugin_name="outlet"
+ ),
+ )
+ self._writeMetadata(logdir, summary_metadata)
+ acc = ea.EventAccumulator(logdir)
+ acc.Reload()
+ self.assertProtoEquals(
+ summary_metadata, acc.SummaryMetadata("you_are_it")
+ )
+
+ def testSummaryMetadata_FirstMetadataWins(self):
+ logdir = self.get_temp_dir()
+ summary_metadata_1 = summary_pb2.SummaryMetadata(
+ display_name="current tagee",
+ summary_description="no",
+ plugin_data=summary_pb2.SummaryMetadata.PluginData(
+ plugin_name="outlet", content=b"120v"
+ ),
+ )
+ self._writeMetadata(logdir, summary_metadata_1, nonce="1")
+ acc = ea.EventAccumulator(logdir)
+ acc.Reload()
+ summary_metadata_2 = summary_pb2.SummaryMetadata(
+ display_name="tagee of the future",
+ summary_description="definitely not",
+ plugin_data=summary_pb2.SummaryMetadata.PluginData(
+ plugin_name="plug", content=b"110v"
+ ),
+ )
+ self._writeMetadata(logdir, summary_metadata_2, nonce="2")
+ acc.Reload()
+
+ self.assertProtoEquals(
+ summary_metadata_1, acc.SummaryMetadata("you_are_it")
+ )
+
+ def testPluginTagToContent_PluginsCannotJumpOnTheBandwagon(self):
+ # If there are multiple `SummaryMetadata` protos for a given tag, and
+ # the set of plugins in the `plugin_data` of the second differs from
+ # that of the first, then the second set should be ignored.
+ logdir = self.get_temp_dir()
+ summary_metadata_1 = summary_pb2.SummaryMetadata(
+ display_name="current tagee",
+ summary_description="no",
+ plugin_data=summary_pb2.SummaryMetadata.PluginData(
+ plugin_name="outlet", content=b"120v"
+ ),
+ )
+ self._writeMetadata(logdir, summary_metadata_1, nonce="1")
+ acc = ea.EventAccumulator(logdir)
+ acc.Reload()
+ summary_metadata_2 = summary_pb2.SummaryMetadata(
+ display_name="tagee of the future",
+ summary_description="definitely not",
+ plugin_data=summary_pb2.SummaryMetadata.PluginData(
+ plugin_name="plug", content=b"110v"
+ ),
+ )
+ self._writeMetadata(logdir, summary_metadata_2, nonce="2")
+ acc.Reload()
-class RealisticEventAccumulatorTest(EventAccumulatorTest):
+ self.assertEqual(
+ acc.PluginTagToContent("outlet"), {"you_are_it": b"120v"}
+ )
+ with six.assertRaisesRegex(self, KeyError, "plug"):
+ acc.PluginTagToContent("plug")
- def testTensorsRealistically(self):
- """Test accumulator by writing values and then reading them."""
-
- def FakeScalarSummary(tag, value):
- value = summary_pb2.Summary.Value(tag=tag, simple_value=value)
- summary = summary_pb2.Summary(value=[value])
- return summary
-
- directory = os.path.join(self.get_temp_dir(), 'values_dir')
- if tf.io.gfile.isdir(directory):
- tf.io.gfile.rmtree(directory)
- tf.io.gfile.mkdir(directory)
-
- writer = test_util.FileWriter(directory, max_queue=100)
-
- with tf.Graph().as_default() as graph:
- _ = tf.constant([2.0, 1.0])
- # Add a graph to the summary writer.
- writer.add_graph(graph)
- graph_def = graph.as_graph_def(add_shapes=True)
- meta_graph_def = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)
- writer.add_meta_graph(meta_graph_def)
-
- run_metadata = config_pb2.RunMetadata()
- device_stats = run_metadata.step_stats.dev_stats.add()
- device_stats.device = 'test device'
- writer.add_run_metadata(run_metadata, 'test run')
-
- # Write a bunch of events using the writer.
- for i in xrange(30):
- summ_id = FakeScalarSummary('id', i)
- summ_sq = FakeScalarSummary('sq', i * i)
- writer.add_summary(summ_id, i * 5)
- writer.add_summary(summ_sq, i * 5)
- writer.flush()
-
- # Verify that we can load those events properly
- acc = ea.EventAccumulator(directory)
- acc.Reload()
- self.assertTagsEqual(acc.Tags(), {
- ea.TENSORS: ['id', 'sq'],
- ea.GRAPH: True,
- ea.META_GRAPH: True,
- ea.RUN_METADATA: ['test run'],
- })
- id_events = acc.Tensors('id')
- sq_events = acc.Tensors('sq')
- self.assertEqual(30, len(id_events))
- self.assertEqual(30, len(sq_events))
- for i in xrange(30):
- self.assertEqual(i * 5, id_events[i].step)
- self.assertEqual(i * 5, sq_events[i].step)
- self.assertEqual(i, tensor_util.make_ndarray(id_events[i].tensor_proto).item())
- self.assertEqual(i * i, tensor_util.make_ndarray(sq_events[i].tensor_proto).item())
-
- # Write a few more events to test incremental reloading
- for i in xrange(30, 40):
- summ_id = FakeScalarSummary('id', i)
- summ_sq = FakeScalarSummary('sq', i * i)
- writer.add_summary(summ_id, i * 5)
- writer.add_summary(summ_sq, i * 5)
- writer.flush()
-
- # Verify we can now see all of the data
- acc.Reload()
- id_events = acc.Tensors('id')
- sq_events = acc.Tensors('sq')
- self.assertEqual(40, len(id_events))
- self.assertEqual(40, len(sq_events))
- for i in xrange(40):
- self.assertEqual(i * 5, id_events[i].step)
- self.assertEqual(i * 5, sq_events[i].step)
- self.assertEqual(i, tensor_util.make_ndarray(id_events[i].tensor_proto).item())
- self.assertEqual(i * i, tensor_util.make_ndarray(sq_events[i].tensor_proto).item())
-
- expected_graph_def = graph_pb2.GraphDef.FromString(
- graph.as_graph_def(add_shapes=True).SerializeToString())
- self.assertProtoEquals(expected_graph_def, acc.Graph())
- self.assertProtoEquals(expected_graph_def,
- graph_pb2.GraphDef.FromString(acc.SerializedGraph()))
-
- expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString(
- meta_graph_def.SerializeToString())
- self.assertProtoEquals(expected_meta_graph, acc.MetaGraph())
-
- def testGraphFromMetaGraphBecomesAvailable(self):
- """Test accumulator by writing values and then reading them."""
-
- directory = os.path.join(self.get_temp_dir(), 'metagraph_test_values_dir')
- if tf.io.gfile.isdir(directory):
- tf.io.gfile.rmtree(directory)
- tf.io.gfile.mkdir(directory)
-
- writer = test_util.FileWriter(directory, max_queue=100)
-
- with tf.Graph().as_default() as graph:
- _ = tf.constant([2.0, 1.0])
- # Add a graph to the summary writer.
- graph_def = graph.as_graph_def(add_shapes=True)
- meta_graph_def = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)
- writer.add_meta_graph(meta_graph_def)
- writer.flush()
-
- # Verify that we can load those events properly
- acc = ea.EventAccumulator(directory)
- acc.Reload()
- self.assertTagsEqual(acc.Tags(), {
- ea.GRAPH: True,
- ea.META_GRAPH: True,
- })
-
- expected_graph_def = graph_pb2.GraphDef.FromString(
- graph.as_graph_def(add_shapes=True).SerializeToString())
- self.assertProtoEquals(expected_graph_def, acc.Graph())
- self.assertProtoEquals(expected_graph_def,
- graph_pb2.GraphDef.FromString(acc.SerializedGraph()))
-
- expected_meta_graph = meta_graph_pb2.MetaGraphDef.FromString(
- meta_graph_def.SerializeToString())
- self.assertProtoEquals(expected_meta_graph, acc.MetaGraph())
-
- def _writeMetadata(self, logdir, summary_metadata, nonce=''):
- """Write to disk a summary with the given metadata.
-
- Arguments:
- logdir: a string
- summary_metadata: a `SummaryMetadata` protobuf object
- nonce: optional; will be added to the end of the event file name
- to guarantee that multiple calls to this function do not stomp the
- same file
- """
- summary = summary_pb2.Summary()
- summary.value.add(
- tensor=tensor_util.make_tensor_proto(['po', 'ta', 'to'], dtype=tf.string),
- tag='you_are_it',
- metadata=summary_metadata)
- writer = test_util.FileWriter(logdir, filename_suffix=nonce)
- writer.add_summary(summary.SerializeToString())
- writer.close()
-
- def testSummaryMetadata(self):
- logdir = self.get_temp_dir()
- summary_metadata = summary_pb2.SummaryMetadata(
- display_name='current tagee',
- summary_description='no',
- plugin_data=summary_pb2.SummaryMetadata.PluginData(
- plugin_name='outlet'))
- self._writeMetadata(logdir, summary_metadata)
- acc = ea.EventAccumulator(logdir)
- acc.Reload()
- self.assertProtoEquals(summary_metadata,
- acc.SummaryMetadata('you_are_it'))
-
- def testSummaryMetadata_FirstMetadataWins(self):
- logdir = self.get_temp_dir()
- summary_metadata_1 = summary_pb2.SummaryMetadata(
- display_name='current tagee',
- summary_description='no',
- plugin_data=summary_pb2.SummaryMetadata.PluginData(
- plugin_name='outlet', content=b'120v'))
- self._writeMetadata(logdir, summary_metadata_1, nonce='1')
- acc = ea.EventAccumulator(logdir)
- acc.Reload()
- summary_metadata_2 = summary_pb2.SummaryMetadata(
- display_name='tagee of the future',
- summary_description='definitely not',
- plugin_data=summary_pb2.SummaryMetadata.PluginData(
- plugin_name='plug', content=b'110v'))
- self._writeMetadata(logdir, summary_metadata_2, nonce='2')
- acc.Reload()
-
- self.assertProtoEquals(summary_metadata_1,
- acc.SummaryMetadata('you_are_it'))
-
- def testPluginTagToContent_PluginsCannotJumpOnTheBandwagon(self):
- # If there are multiple `SummaryMetadata` for a given tag, and the
- # set of plugins in the `plugin_data` of second is different from
- # that of the first, then the second set should be ignored.
- logdir = self.get_temp_dir()
- summary_metadata_1 = summary_pb2.SummaryMetadata(
- display_name='current tagee',
- summary_description='no',
- plugin_data=summary_pb2.SummaryMetadata.PluginData(
- plugin_name='outlet', content=b'120v'))
- self._writeMetadata(logdir, summary_metadata_1, nonce='1')
- acc = ea.EventAccumulator(logdir)
- acc.Reload()
- summary_metadata_2 = summary_pb2.SummaryMetadata(
- display_name='tagee of the future',
- summary_description='definitely not',
- plugin_data=summary_pb2.SummaryMetadata.PluginData(
- plugin_name='plug', content=b'110v'))
- self._writeMetadata(logdir, summary_metadata_2, nonce='2')
- acc.Reload()
-
- self.assertEqual(acc.PluginTagToContent('outlet'),
- {'you_are_it': b'120v'})
- with six.assertRaisesRegex(self, KeyError, 'plug'):
- acc.PluginTagToContent('plug')
-
-if __name__ == '__main__':
- tf.test.main()
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorboard/backend/event_processing/plugin_event_multiplexer.py b/tensorboard/backend/event_processing/plugin_event_multiplexer.py
index 566ab6f9bb..e5304f3e3e 100644
--- a/tensorboard/backend/event_processing/plugin_event_multiplexer.py
+++ b/tensorboard/backend/event_processing/plugin_event_multiplexer.py
@@ -25,439 +25,454 @@
from six.moves import queue, xrange # pylint: disable=redefined-builtin
from tensorboard.backend.event_processing import directory_watcher
-from tensorboard.backend.event_processing import plugin_event_accumulator as event_accumulator
+from tensorboard.backend.event_processing import (
+ plugin_event_accumulator as event_accumulator,
+)
from tensorboard.backend.event_processing import io_wrapper
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
-class EventMultiplexer(object):
- """An `EventMultiplexer` manages access to multiple `EventAccumulator`s.
-
- Each `EventAccumulator` is associated with a `run`, which is a self-contained
- TensorFlow execution. The `EventMultiplexer` provides methods for extracting
- information about events from multiple `run`s.
-
- Example usage for loading specific runs from files:
-
- ```python
- x = EventMultiplexer({'run1': 'path/to/run1', 'run2': 'path/to/run2'})
- x.Reload()
- ```
-
- Example usage for loading a directory where each subdirectory is a run
-
- ```python
- (eg:) /parent/directory/path/
- /parent/directory/path/run1/
- /parent/directory/path/run1/events.out.tfevents.1001
- /parent/directory/path/run1/events.out.tfevents.1002
-
- /parent/directory/path/run2/
- /parent/directory/path/run2/events.out.tfevents.9232
-
- /parent/directory/path/run3/
- /parent/directory/path/run3/events.out.tfevents.9232
- x = EventMultiplexer().AddRunsFromDirectory('/parent/directory/path')
- (which is equivalent to:)
- x = EventMultiplexer({'run1': '/parent/directory/path/run1', 'run2':...}
- ```
-
- If you would like to watch `/parent/directory/path`, wait for it to be created
- (if necessary) and then periodically pick up new runs, use
- `AutoloadingMultiplexer`
- @@Tensors
- """
-
- def __init__(self,
- run_path_map=None,
- size_guidance=None,
- tensor_size_guidance=None,
- purge_orphaned_data=True,
- max_reload_threads=None,
- event_file_active_filter=None):
- """Constructor for the `EventMultiplexer`.
-
- Args:
- run_path_map: Dict `{run: path}` which specifies the
- name of a run, and the path to find the associated events. If it is
- None, then the EventMultiplexer initializes without any runs.
- size_guidance: A dictionary mapping from `tagType` to the number of items
- to store for each tag of that type. See
- `event_accumulator.EventAccumulator` for details.
- tensor_size_guidance: A dictionary mapping from `plugin_name` to
- the number of items to store for each tag of that type. See
- `event_accumulator.EventAccumulator` for details.
- purge_orphaned_data: Whether to discard any events that were "orphaned" by
- a TensorFlow restart.
- max_reload_threads: The max number of threads that TensorBoard can use
- to reload runs. Each thread reloads one run at a time. If not provided,
- reloads runs serially (one after another).
- event_file_active_filter: Optional predicate for determining whether an
- event file latest load timestamp should be considered active. If passed,
- this will enable multifile directory loading.
- """
- logger.info('Event Multiplexer initializing.')
- self._accumulators_mutex = threading.Lock()
- self._accumulators = {}
- self._paths = {}
- self._reload_called = False
- self._size_guidance = (size_guidance or
- event_accumulator.DEFAULT_SIZE_GUIDANCE)
- self._tensor_size_guidance = tensor_size_guidance
- self.purge_orphaned_data = purge_orphaned_data
- self._max_reload_threads = max_reload_threads or 1
- self._event_file_active_filter = event_file_active_filter
- if run_path_map is not None:
- logger.info('Event Multplexer doing initialization load for %s',
- run_path_map)
- for (run, path) in six.iteritems(run_path_map):
- self.AddRun(path, run)
- logger.info('Event Multiplexer done initializing')
-
- def AddRun(self, path, name=None):
- """Add a run to the multiplexer.
-
- If the name is not specified, it is the same as the path.
-
- If a run by that name exists, and we are already watching the right path,
- do nothing. If we are watching a different path, replace the event
- accumulator.
-
- If `Reload` has been called, it will `Reload` the newly created
- accumulators.
-
- Args:
- path: Path to the event files (or event directory) for given run.
- name: Name of the run to add. If not provided, is set to path.
-
- Returns:
- The `EventMultiplexer`.
- """
- name = name or path
- accumulator = None
- with self._accumulators_mutex:
- if name not in self._accumulators or self._paths[name] != path:
- if name in self._paths and self._paths[name] != path:
- # TODO(@decentralion) - Make it impossible to overwrite an old path
- # with a new path (just give the new path a distinct name)
- logger.warn('Conflict for name %s: old path %s, new path %s',
- name, self._paths[name], path)
- logger.info('Constructing EventAccumulator for %s', path)
- accumulator = event_accumulator.EventAccumulator(
- path,
- size_guidance=self._size_guidance,
- tensor_size_guidance=self._tensor_size_guidance,
- purge_orphaned_data=self.purge_orphaned_data,
- event_file_active_filter=self._event_file_active_filter)
- self._accumulators[name] = accumulator
- self._paths[name] = path
- if accumulator:
- if self._reload_called:
- accumulator.Reload()
- return self
-
- def AddRunsFromDirectory(self, path, name=None):
- """Load runs from a directory; recursively walks subdirectories.
-
- If path doesn't exist, no-op. This ensures that it is safe to call
- `AddRunsFromDirectory` multiple times, even before the directory is made.
-
- If path is a directory, load event files in the directory (if any exist) and
- recursively call AddRunsFromDirectory on any subdirectories. This mean you
- can call AddRunsFromDirectory at the root of a tree of event logs and
- TensorBoard will load them all.
-
- If the `EventMultiplexer` is already loaded this will cause
- the newly created accumulators to `Reload()`.
- Args:
- path: A string path to a directory to load runs from.
- name: Optionally, what name to apply to the runs. If name is provided
- and the directory contains run subdirectories, the name of each subrun
- is the concatenation of the parent name and the subdirectory name. If
- name is provided and the directory contains event files, then a run
- is added called "name" and with the events from the path.
-
- Raises:
- ValueError: If the path exists and isn't a directory.
-
- Returns:
- The `EventMultiplexer`.
- """
- logger.info('Starting AddRunsFromDirectory: %s', path)
- for subdir in io_wrapper.GetLogdirSubdirectories(path):
- logger.info('Adding run from directory %s', subdir)
- rpath = os.path.relpath(subdir, path)
- subname = os.path.join(name, rpath) if name else rpath
- self.AddRun(subdir, name=subname)
- logger.info('Done with AddRunsFromDirectory: %s', path)
- return self
-
- def Reload(self):
- """Call `Reload` on every `EventAccumulator`."""
- logger.info('Beginning EventMultiplexer.Reload()')
- self._reload_called = True
- # Build a list so we're safe even if the list of accumulators is modified
- # even while we're reloading.
- with self._accumulators_mutex:
- items = list(self._accumulators.items())
- items_queue = queue.Queue()
- for item in items:
- items_queue.put(item)
-
- # Methods of built-in python containers are thread-safe so long as the GIL
- # for the thread exists, but we might as well be careful.
- names_to_delete = set()
- names_to_delete_mutex = threading.Lock()
-
- def Worker():
- """Keeps reloading accumulators til none are left."""
- while True:
- try:
- name, accumulator = items_queue.get(block=False)
- except queue.Empty:
- # No more runs to reload.
- break
-
- try:
- accumulator.Reload()
- except (OSError, IOError) as e:
- logger.error('Unable to reload accumulator %r: %s', name, e)
- except directory_watcher.DirectoryDeletedError:
- with names_to_delete_mutex:
- names_to_delete.add(name)
- finally:
- items_queue.task_done()
-
- if self._max_reload_threads > 1:
- num_threads = min(
- self._max_reload_threads, len(items))
- logger.info('Starting %d threads to reload runs', num_threads)
- for i in xrange(num_threads):
- thread = threading.Thread(target=Worker, name='Reloader %d' % i)
- thread.daemon = True
- thread.start()
- items_queue.join()
- else:
- logger.info(
- 'Reloading runs serially (one after another) on the main '
- 'thread.')
- Worker()
-
- with self._accumulators_mutex:
- for name in names_to_delete:
- logger.warn('Deleting accumulator %r', name)
- del self._accumulators[name]
- logger.info('Finished with EventMultiplexer.Reload()')
- return self
-
- def PluginAssets(self, plugin_name):
- """Get index of runs and assets for a given plugin.
-
- Args:
- plugin_name: Name of the plugin we are checking for.
-
- Returns:
- A dictionary that maps from run_name to a list of plugin
- assets for that run.
- """
- with self._accumulators_mutex:
- # To avoid nested locks, we construct a copy of the run-accumulator map
- items = list(six.iteritems(self._accumulators))
-
- return {run: accum.PluginAssets(plugin_name) for run, accum in items}
-
- def RetrievePluginAsset(self, run, plugin_name, asset_name):
- """Return the contents for a specific plugin asset from a run.
-
- Args:
- run: The string name of the run.
- plugin_name: The string name of a plugin.
- asset_name: The string name of an asset.
-
- Returns:
- The string contents of the plugin asset.
-
- Raises:
- KeyError: If the asset is not available.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.RetrievePluginAsset(plugin_name, asset_name)
-
- def FirstEventTimestamp(self, run):
- """Return the timestamp of the first event of the given run.
-
- This may perform I/O if no events have been loaded yet for the run.
- Args:
- run: A string name of the run for which the timestamp is retrieved.
-
- Returns:
- The wall_time of the first event of the run, which will typically be
- seconds since the epoch.
-
- Raises:
- KeyError: If the run is not found.
- ValueError: If the run has no events loaded and there are no events on
- disk to load.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.FirstEventTimestamp()
-
- def Graph(self, run):
- """Retrieve the graph associated with the provided run.
-
- Args:
- run: A string name of a run to load the graph for.
-
- Raises:
- KeyError: If the run is not found.
- ValueError: If the run does not have an associated graph.
-
- Returns:
- The `GraphDef` protobuf data structure.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.Graph()
-
- def SerializedGraph(self, run):
- """Retrieve the serialized graph associated with the provided run.
-
- Args:
- run: A string name of a run to load the graph for.
-
- Raises:
- KeyError: If the run is not found.
- ValueError: If the run does not have an associated graph.
-
- Returns:
- The serialized form of the `GraphDef` protobuf data structure.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.SerializedGraph()
-
- def MetaGraph(self, run):
- """Retrieve the metagraph associated with the provided run.
-
- Args:
- run: A string name of a run to load the graph for.
-
- Raises:
- KeyError: If the run is not found.
- ValueError: If the run does not have an associated graph.
-
- Returns:
- The `MetaGraphDef` protobuf data structure.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.MetaGraph()
-
- def RunMetadata(self, run, tag):
- """Get the session.run() metadata associated with a TensorFlow run and tag.
-
- Args:
- run: A string name of a TensorFlow run.
- tag: A string name of the tag associated with a particular session.run().
-
- Raises:
- KeyError: If the run is not found, or the tag is not available for the
- given run.
-
- Returns:
- The metadata in the form of `RunMetadata` protobuf data structure.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.RunMetadata(tag)
-
- def Tensors(self, run, tag):
- """Retrieve the tensor events associated with a run and tag.
-
- Args:
- run: A string name of the run for which values are retrieved.
- tag: A string name of the tag for which values are retrieved.
-
- Raises:
- KeyError: If the run is not found, or the tag is not available for
- the given run.
-
- Returns:
- An array of `event_accumulator.TensorEvent`s.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.Tensors(tag)
-
- def PluginRunToTagToContent(self, plugin_name):
- """Returns a 2-layer dictionary of the form {run: {tag: content}}.
-
- The `content` referred above is the content field of the PluginData proto
- for the specified plugin within a Summary.Value proto.
-
- Args:
- plugin_name: The name of the plugin for which to fetch content.
+class EventMultiplexer(object):
+ """An `EventMultiplexer` manages access to multiple `EventAccumulator`s.
- Returns:
- A dictionary of the form {run: {tag: content}}.
- """
- mapping = {}
- for run in self.Runs():
- try:
- tag_to_content = self.GetAccumulator(run).PluginTagToContent(
- plugin_name)
- except KeyError:
- # This run lacks content for the plugin. Try the next run.
- continue
- mapping[run] = tag_to_content
- return mapping
-
- def SummaryMetadata(self, run, tag):
- """Return the summary metadata for the given tag on the given run.
-
- Args:
- run: A string name of the run for which summary metadata is to be
- retrieved.
- tag: A string name of the tag whose summary metadata is to be
- retrieved.
-
- Raises:
- KeyError: If the run is not found, or the tag is not available for
- the given run.
-
- Returns:
- A `SummaryMetadata` protobuf.
- """
- accumulator = self.GetAccumulator(run)
- return accumulator.SummaryMetadata(tag)
+ Each `EventAccumulator` is associated with a `run`, which is a self-contained
+ TensorFlow execution. The `EventMultiplexer` provides methods for extracting
+ information about events from multiple `run`s.
- def Runs(self):
- """Return all the run names in the `EventMultiplexer`.
+ Example usage for loading specific runs from files:
- Returns:
- ```
- {runName: { scalarValues: [tagA, tagB, tagC],
- graph: true, meta_graph: true}}
+ ```python
+ x = EventMultiplexer({'run1': 'path/to/run1', 'run2': 'path/to/run2'})
+ x.Reload()
```
- """
- with self._accumulators_mutex:
- # To avoid nested locks, we construct a copy of the run-accumulator map
- items = list(six.iteritems(self._accumulators))
- return {run_name: accumulator.Tags() for run_name, accumulator in items}
- def RunPaths(self):
- """Returns a dict mapping run names to event file paths."""
- return self._paths
+  Example usage for loading a directory where each subdirectory is a run:
- def GetAccumulator(self, run):
- """Returns EventAccumulator for a given run.
+ ```python
+ (eg:) /parent/directory/path/
+ /parent/directory/path/run1/
+ /parent/directory/path/run1/events.out.tfevents.1001
+ /parent/directory/path/run1/events.out.tfevents.1002
- Args:
- run: String name of run.
+ /parent/directory/path/run2/
+ /parent/directory/path/run2/events.out.tfevents.9232
- Returns:
- An EventAccumulator object.
+ /parent/directory/path/run3/
+ /parent/directory/path/run3/events.out.tfevents.9232
+ x = EventMultiplexer().AddRunsFromDirectory('/parent/directory/path')
+ (which is equivalent to:)
+  x = EventMultiplexer({'run1': '/parent/directory/path/run1', 'run2': ...})
+ ```
- Raises:
- KeyError: If run does not exist.
+ If you would like to watch `/parent/directory/path`, wait for it to be created
+ (if necessary) and then periodically pick up new runs, use
+ `AutoloadingMultiplexer`
+ @@Tensors
"""
- with self._accumulators_mutex:
- return self._accumulators[run]
+
+ def __init__(
+ self,
+ run_path_map=None,
+ size_guidance=None,
+ tensor_size_guidance=None,
+ purge_orphaned_data=True,
+ max_reload_threads=None,
+ event_file_active_filter=None,
+ ):
+ """Constructor for the `EventMultiplexer`.
+
+ Args:
+ run_path_map: Dict `{run: path}` which specifies the
+ name of a run, and the path to find the associated events. If it is
+ None, then the EventMultiplexer initializes without any runs.
+ size_guidance: A dictionary mapping from `tagType` to the number of items
+ to store for each tag of that type. See
+ `event_accumulator.EventAccumulator` for details.
+ tensor_size_guidance: A dictionary mapping from `plugin_name` to
+ the number of items to store for each tag of that type. See
+ `event_accumulator.EventAccumulator` for details.
+ purge_orphaned_data: Whether to discard any events that were "orphaned" by
+ a TensorFlow restart.
+ max_reload_threads: The max number of threads that TensorBoard can use
+ to reload runs. Each thread reloads one run at a time. If not provided,
+ reloads runs serially (one after another).
+ event_file_active_filter: Optional predicate for determining whether an
+            event file's latest load timestamp should be considered active. If passed,
+ this will enable multifile directory loading.
+ """
+ logger.info("Event Multiplexer initializing.")
+ self._accumulators_mutex = threading.Lock()
+ self._accumulators = {}
+ self._paths = {}
+ self._reload_called = False
+ self._size_guidance = (
+ size_guidance or event_accumulator.DEFAULT_SIZE_GUIDANCE
+ )
+ self._tensor_size_guidance = tensor_size_guidance
+ self.purge_orphaned_data = purge_orphaned_data
+ self._max_reload_threads = max_reload_threads or 1
+ self._event_file_active_filter = event_file_active_filter
+ if run_path_map is not None:
+ logger.info(
+                "Event Multiplexer doing initialization load for %s",
+ run_path_map,
+ )
+ for (run, path) in six.iteritems(run_path_map):
+ self.AddRun(path, run)
+ logger.info("Event Multiplexer done initializing")
+
+ def AddRun(self, path, name=None):
+ """Add a run to the multiplexer.
+
+ If the name is not specified, it is the same as the path.
+
+ If a run by that name exists, and we are already watching the right path,
+ do nothing. If we are watching a different path, replace the event
+ accumulator.
+
+ If `Reload` has been called, it will `Reload` the newly created
+ accumulators.
+
+ Args:
+ path: Path to the event files (or event directory) for given run.
+          name: Name of the run to add. If not provided, it is set to path.
+
+ Returns:
+ The `EventMultiplexer`.
+ """
+ name = name or path
+ accumulator = None
+ with self._accumulators_mutex:
+ if name not in self._accumulators or self._paths[name] != path:
+ if name in self._paths and self._paths[name] != path:
+ # TODO(@decentralion) - Make it impossible to overwrite an old path
+ # with a new path (just give the new path a distinct name)
+ logger.warn(
+ "Conflict for name %s: old path %s, new path %s",
+ name,
+ self._paths[name],
+ path,
+ )
+ logger.info("Constructing EventAccumulator for %s", path)
+ accumulator = event_accumulator.EventAccumulator(
+ path,
+ size_guidance=self._size_guidance,
+ tensor_size_guidance=self._tensor_size_guidance,
+ purge_orphaned_data=self.purge_orphaned_data,
+ event_file_active_filter=self._event_file_active_filter,
+ )
+ self._accumulators[name] = accumulator
+ self._paths[name] = path
+ if accumulator:
+ if self._reload_called:
+ accumulator.Reload()
+ return self
+
+ def AddRunsFromDirectory(self, path, name=None):
+ """Load runs from a directory; recursively walks subdirectories.
+
+ If path doesn't exist, no-op. This ensures that it is safe to call
+ `AddRunsFromDirectory` multiple times, even before the directory is made.
+
+ If path is a directory, load event files in the directory (if any exist) and
+        recursively call AddRunsFromDirectory on any subdirectories. This means you
+ can call AddRunsFromDirectory at the root of a tree of event logs and
+ TensorBoard will load them all.
+
+        If the `EventMultiplexer` is already loaded, this will cause
+        the newly created accumulators to `Reload()`.
+
+ Args:
+ path: A string path to a directory to load runs from.
+ name: Optionally, what name to apply to the runs. If name is provided
+ and the directory contains run subdirectories, the name of each subrun
+ is the concatenation of the parent name and the subdirectory name. If
+ name is provided and the directory contains event files, then a run
+ is added called "name" and with the events from the path.
+
+ Raises:
+ ValueError: If the path exists and isn't a directory.
+
+ Returns:
+ The `EventMultiplexer`.
+ """
+ logger.info("Starting AddRunsFromDirectory: %s", path)
+ for subdir in io_wrapper.GetLogdirSubdirectories(path):
+ logger.info("Adding run from directory %s", subdir)
+ rpath = os.path.relpath(subdir, path)
+ subname = os.path.join(name, rpath) if name else rpath
+ self.AddRun(subdir, name=subname)
+ logger.info("Done with AddRunsFromDirectory: %s", path)
+ return self
+
+ def Reload(self):
+ """Call `Reload` on every `EventAccumulator`."""
+ logger.info("Beginning EventMultiplexer.Reload()")
+ self._reload_called = True
+        # Build a list so we're safe even if the list of accumulators is modified
+        # while we're reloading.
+ with self._accumulators_mutex:
+ items = list(self._accumulators.items())
+ items_queue = queue.Queue()
+ for item in items:
+ items_queue.put(item)
+
+ # Methods of built-in python containers are thread-safe so long as the GIL
+ # for the thread exists, but we might as well be careful.
+ names_to_delete = set()
+ names_to_delete_mutex = threading.Lock()
+
+ def Worker():
+            """Keeps reloading accumulators until none are left."""
+ while True:
+ try:
+ name, accumulator = items_queue.get(block=False)
+ except queue.Empty:
+ # No more runs to reload.
+ break
+
+ try:
+ accumulator.Reload()
+ except (OSError, IOError) as e:
+ logger.error("Unable to reload accumulator %r: %s", name, e)
+ except directory_watcher.DirectoryDeletedError:
+ with names_to_delete_mutex:
+ names_to_delete.add(name)
+ finally:
+ items_queue.task_done()
+
+ if self._max_reload_threads > 1:
+ num_threads = min(self._max_reload_threads, len(items))
+ logger.info("Starting %d threads to reload runs", num_threads)
+ for i in xrange(num_threads):
+ thread = threading.Thread(target=Worker, name="Reloader %d" % i)
+ thread.daemon = True
+ thread.start()
+ items_queue.join()
+ else:
+ logger.info(
+ "Reloading runs serially (one after another) on the main "
+ "thread."
+ )
+ Worker()
+
+ with self._accumulators_mutex:
+ for name in names_to_delete:
+ logger.warn("Deleting accumulator %r", name)
+ del self._accumulators[name]
+ logger.info("Finished with EventMultiplexer.Reload()")
+ return self
+
+ def PluginAssets(self, plugin_name):
+ """Get index of runs and assets for a given plugin.
+
+ Args:
+ plugin_name: Name of the plugin we are checking for.
+
+ Returns:
+ A dictionary that maps from run_name to a list of plugin
+ assets for that run.
+ """
+ with self._accumulators_mutex:
+ # To avoid nested locks, we construct a copy of the run-accumulator map
+ items = list(six.iteritems(self._accumulators))
+
+ return {run: accum.PluginAssets(plugin_name) for run, accum in items}
+
+ def RetrievePluginAsset(self, run, plugin_name, asset_name):
+ """Return the contents for a specific plugin asset from a run.
+
+ Args:
+ run: The string name of the run.
+ plugin_name: The string name of a plugin.
+ asset_name: The string name of an asset.
+
+ Returns:
+ The string contents of the plugin asset.
+
+ Raises:
+ KeyError: If the asset is not available.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.RetrievePluginAsset(plugin_name, asset_name)
+
+ def FirstEventTimestamp(self, run):
+ """Return the timestamp of the first event of the given run.
+
+ This may perform I/O if no events have been loaded yet for the run.
+
+ Args:
+ run: A string name of the run for which the timestamp is retrieved.
+
+ Returns:
+ The wall_time of the first event of the run, which will typically be
+ seconds since the epoch.
+
+ Raises:
+ KeyError: If the run is not found.
+ ValueError: If the run has no events loaded and there are no events on
+ disk to load.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.FirstEventTimestamp()
+
+ def Graph(self, run):
+ """Retrieve the graph associated with the provided run.
+
+ Args:
+ run: A string name of a run to load the graph for.
+
+ Raises:
+ KeyError: If the run is not found.
+ ValueError: If the run does not have an associated graph.
+
+ Returns:
+ The `GraphDef` protobuf data structure.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.Graph()
+
+ def SerializedGraph(self, run):
+ """Retrieve the serialized graph associated with the provided run.
+
+ Args:
+ run: A string name of a run to load the graph for.
+
+ Raises:
+ KeyError: If the run is not found.
+ ValueError: If the run does not have an associated graph.
+
+ Returns:
+ The serialized form of the `GraphDef` protobuf data structure.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.SerializedGraph()
+
+ def MetaGraph(self, run):
+ """Retrieve the metagraph associated with the provided run.
+
+ Args:
+ run: A string name of a run to load the graph for.
+
+ Raises:
+ KeyError: If the run is not found.
+ ValueError: If the run does not have an associated graph.
+
+ Returns:
+ The `MetaGraphDef` protobuf data structure.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.MetaGraph()
+
+ def RunMetadata(self, run, tag):
+ """Get the session.run() metadata associated with a TensorFlow run and
+ tag.
+
+ Args:
+ run: A string name of a TensorFlow run.
+ tag: A string name of the tag associated with a particular session.run().
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for the
+ given run.
+
+ Returns:
+ The metadata in the form of `RunMetadata` protobuf data structure.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.RunMetadata(tag)
+
+ def Tensors(self, run, tag):
+ """Retrieve the tensor events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+
+ Returns:
+ An array of `event_accumulator.TensorEvent`s.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.Tensors(tag)
+
+ def PluginRunToTagToContent(self, plugin_name):
+ """Returns a 2-layer dictionary of the form {run: {tag: content}}.
+
+        The `content` referred to above is the content field of the PluginData proto
+ for the specified plugin within a Summary.Value proto.
+
+ Args:
+ plugin_name: The name of the plugin for which to fetch content.
+
+ Returns:
+ A dictionary of the form {run: {tag: content}}.
+ """
+ mapping = {}
+ for run in self.Runs():
+ try:
+ tag_to_content = self.GetAccumulator(run).PluginTagToContent(
+ plugin_name
+ )
+ except KeyError:
+ # This run lacks content for the plugin. Try the next run.
+ continue
+ mapping[run] = tag_to_content
+ return mapping
+
+ def SummaryMetadata(self, run, tag):
+ """Return the summary metadata for the given tag on the given run.
+
+ Args:
+ run: A string name of the run for which summary metadata is to be
+ retrieved.
+ tag: A string name of the tag whose summary metadata is to be
+ retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+
+ Returns:
+ A `SummaryMetadata` protobuf.
+ """
+ accumulator = self.GetAccumulator(run)
+ return accumulator.SummaryMetadata(tag)
+
+ def Runs(self):
+ """Return all the run names in the `EventMultiplexer`.
+
+ Returns:
+ ```
+ {runName: { scalarValues: [tagA, tagB, tagC],
+ graph: true, meta_graph: true}}
+ ```
+ """
+ with self._accumulators_mutex:
+ # To avoid nested locks, we construct a copy of the run-accumulator map
+ items = list(six.iteritems(self._accumulators))
+ return {run_name: accumulator.Tags() for run_name, accumulator in items}
+
+ def RunPaths(self):
+ """Returns a dict mapping run names to event file paths."""
+ return self._paths
+
+ def GetAccumulator(self, run):
+ """Returns EventAccumulator for a given run.
+
+ Args:
+ run: String name of run.
+
+ Returns:
+ An EventAccumulator object.
+
+ Raises:
+ KeyError: If run does not exist.
+ """
+ with self._accumulators_mutex:
+ return self._accumulators[run]
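
As a reading aid for the reformatted `EventMultiplexer` above, here is a minimal usage sketch of the public API it exposes (`AddRunsFromDirectory`, `Reload`, `Runs`, `Tensors`). The log directory, run name, and tag name are hypothetical placeholders rather than values taken from this change:

```python
from tensorboard.backend.event_processing import (
    plugin_event_multiplexer as event_multiplexer,
)

# Hypothetical log directory containing one subdirectory per run.
LOGDIR = "/tmp/example_logdir"

multiplexer = event_multiplexer.EventMultiplexer(max_reload_threads=4)
multiplexer.AddRunsFromDirectory(LOGDIR)
multiplexer.Reload()  # performs the actual event-file I/O

# Maps each run name to the tags available for it, per category.
print(multiplexer.Runs())

# Raises KeyError if the run or tag is unknown.
for tensor_event in multiplexer.Tensors("run1", "loss"):
    print(tensor_event)
```
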
diff --git a/tensorboard/backend/event_processing/plugin_event_multiplexer_test.py b/tensorboard/backend/event_processing/plugin_event_multiplexer_test.py
index 0614ed1895..62207e631d 100644
--- a/tensorboard/backend/event_processing/plugin_event_multiplexer_test.py
+++ b/tensorboard/backend/event_processing/plugin_event_multiplexer_test.py
@@ -23,402 +23,439 @@
import tensorflow as tf
-from tensorboard.backend.event_processing import plugin_event_accumulator as event_accumulator
-from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer
+from tensorboard.backend.event_processing import (
+ plugin_event_accumulator as event_accumulator,
+)
+from tensorboard.backend.event_processing import (
+ plugin_event_multiplexer as event_multiplexer,
+)
from tensorboard.util import test_util
def _AddEvents(path):
- if not tf.io.gfile.isdir(path):
- tf.io.gfile.makedirs(path)
- fpath = os.path.join(path, 'hypothetical.tfevents.out')
- with tf.io.gfile.GFile(fpath, 'w') as f:
- f.write('')
- return fpath
+ if not tf.io.gfile.isdir(path):
+ tf.io.gfile.makedirs(path)
+ fpath = os.path.join(path, "hypothetical.tfevents.out")
+ with tf.io.gfile.GFile(fpath, "w") as f:
+ f.write("")
+ return fpath
def _CreateCleanDirectory(path):
- if tf.io.gfile.isdir(path):
- tf.io.gfile.rmtree(path)
- tf.io.gfile.mkdir(path)
+ if tf.io.gfile.isdir(path):
+ tf.io.gfile.rmtree(path)
+ tf.io.gfile.mkdir(path)
class _FakeAccumulator(object):
-
- def __init__(self, path):
- """Constructs a fake accumulator with some fake events.
-
- Args:
- path: The path for the run that this accumulator is for.
- """
- self._path = path
- self.reload_called = False
- self._plugin_to_tag_to_content = {
- 'baz_plugin': {
- 'foo': 'foo_content',
- 'bar': 'bar_content',
+ def __init__(self, path):
+ """Constructs a fake accumulator with some fake events.
+
+ Args:
+ path: The path for the run that this accumulator is for.
+ """
+ self._path = path
+ self.reload_called = False
+ self._plugin_to_tag_to_content = {
+ "baz_plugin": {"foo": "foo_content", "bar": "bar_content",}
}
- }
- def Tags(self):
- return {}
+ def Tags(self):
+ return {}
- def FirstEventTimestamp(self):
- return 0
+ def FirstEventTimestamp(self):
+ return 0
- def _TagHelper(self, tag_name, enum):
- if tag_name not in self.Tags()[enum]:
- raise KeyError
- return ['%s/%s' % (self._path, tag_name)]
+ def _TagHelper(self, tag_name, enum):
+ if tag_name not in self.Tags()[enum]:
+ raise KeyError
+ return ["%s/%s" % (self._path, tag_name)]
- def Tensors(self, tag_name):
- return self._TagHelper(tag_name, event_accumulator.TENSORS)
+ def Tensors(self, tag_name):
+ return self._TagHelper(tag_name, event_accumulator.TENSORS)
- def PluginTagToContent(self, plugin_name):
- # We pre-pend the runs with the path and '_' so that we can verify that the
- # tags are associated with the correct runs.
- return {
- self._path + '_' + run: content_mapping
- for (run, content_mapping
- ) in self._plugin_to_tag_to_content[plugin_name].items()
- }
+ def PluginTagToContent(self, plugin_name):
+        # We prepend the runs with the path and '_' so that we can verify that the
+ # tags are associated with the correct runs.
+ return {
+ self._path + "_" + run: content_mapping
+ for (run, content_mapping) in self._plugin_to_tag_to_content[
+ plugin_name
+ ].items()
+ }
- def Reload(self):
- self.reload_called = True
+ def Reload(self):
+ self.reload_called = True
-def _GetFakeAccumulator(path,
- size_guidance=None,
- tensor_size_guidance=None,
- purge_orphaned_data=None,
- event_file_active_filter=None):
- del size_guidance, tensor_size_guidance, purge_orphaned_data # Unused.
- del event_file_active_filter # unused
- return _FakeAccumulator(path)
+def _GetFakeAccumulator(
+ path,
+ size_guidance=None,
+ tensor_size_guidance=None,
+ purge_orphaned_data=None,
+ event_file_active_filter=None,
+):
+ del size_guidance, tensor_size_guidance, purge_orphaned_data # Unused.
+ del event_file_active_filter # unused
+ return _FakeAccumulator(path)
class EventMultiplexerTest(tf.test.TestCase):
-
- def setUp(self):
- super(EventMultiplexerTest, self).setUp()
- self.stubs = tf.compat.v1.test.StubOutForTesting()
-
- self.stubs.Set(event_accumulator, 'EventAccumulator', _GetFakeAccumulator)
-
- def tearDown(self):
- self.stubs.CleanUp()
-
- def testEmptyLoader(self):
- """Tests empty EventMultiplexer creation."""
- x = event_multiplexer.EventMultiplexer()
- self.assertEqual(x.Runs(), {})
-
- def testRunNamesRespected(self):
- """Tests two EventAccumulators inserted/accessed in EventMultiplexer."""
- x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
- self.assertItemsEqual(sorted(x.Runs().keys()), ['run1', 'run2'])
- self.assertEqual(x.GetAccumulator('run1')._path, 'path1')
- self.assertEqual(x.GetAccumulator('run2')._path, 'path2')
-
- def testReload(self):
- """EventAccumulators should Reload after EventMultiplexer call it."""
- x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
- self.assertFalse(x.GetAccumulator('run1').reload_called)
- self.assertFalse(x.GetAccumulator('run2').reload_called)
- x.Reload()
- self.assertTrue(x.GetAccumulator('run1').reload_called)
- self.assertTrue(x.GetAccumulator('run2').reload_called)
-
- def testPluginRunToTagToContent(self):
- """Tests the method that produces the run to tag to content mapping."""
- x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
- self.assertDictEqual({
- 'run1': {
- 'path1_foo': 'foo_content',
- 'path1_bar': 'bar_content',
- },
- 'run2': {
- 'path2_foo': 'foo_content',
- 'path2_bar': 'bar_content',
- }
- }, x.PluginRunToTagToContent('baz_plugin'))
-
- def testExceptions(self):
- """KeyError should be raised when accessing non-existing keys."""
- x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
- with self.assertRaises(KeyError):
- x.Tensors('sv1', 'xxx')
-
- def testInitialization(self):
- """Tests EventMultiplexer is created properly with its params."""
- x = event_multiplexer.EventMultiplexer()
- self.assertEqual(x.Runs(), {})
- x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
- self.assertItemsEqual(x.Runs(), ['run1', 'run2'])
- self.assertEqual(x.GetAccumulator('run1')._path, 'path1')
- self.assertEqual(x.GetAccumulator('run2')._path, 'path2')
-
- def testAddRunsFromDirectory(self):
- """Tests AddRunsFromDirectory function.
-
- Tests the following scenarios:
- - When the directory does not exist.
- - When the directory is empty.
- - When the directory has empty subdirectory.
- - Contains proper EventAccumulators after adding events.
- """
- x = event_multiplexer.EventMultiplexer()
- tmpdir = self.get_temp_dir()
- join = os.path.join
- fakedir = join(tmpdir, 'fake_accumulator_directory')
- realdir = join(tmpdir, 'real_accumulator_directory')
- self.assertEqual(x.Runs(), {})
- x.AddRunsFromDirectory(fakedir)
- self.assertEqual(x.Runs(), {}, 'loading fakedir had no effect')
-
- _CreateCleanDirectory(realdir)
- x.AddRunsFromDirectory(realdir)
- self.assertEqual(x.Runs(), {}, 'loading empty directory had no effect')
-
- path1 = join(realdir, 'path1')
- tf.io.gfile.mkdir(path1)
- x.AddRunsFromDirectory(realdir)
- self.assertEqual(x.Runs(), {}, 'creating empty subdirectory had no effect')
-
- _AddEvents(path1)
- x.AddRunsFromDirectory(realdir)
- self.assertItemsEqual(x.Runs(), ['path1'], 'loaded run: path1')
- loader1 = x.GetAccumulator('path1')
- self.assertEqual(loader1._path, path1, 'has the correct path')
-
- path2 = join(realdir, 'path2')
- _AddEvents(path2)
- x.AddRunsFromDirectory(realdir)
- self.assertItemsEqual(x.Runs(), ['path1', 'path2'])
- self.assertEqual(
- x.GetAccumulator('path1'), loader1, 'loader1 not regenerated')
-
- path2_2 = join(path2, 'path2')
- _AddEvents(path2_2)
- x.AddRunsFromDirectory(realdir)
- self.assertItemsEqual(x.Runs(), ['path1', 'path2', 'path2/path2'])
- self.assertEqual(
- x.GetAccumulator('path2/path2')._path, path2_2, 'loader2 path correct')
-
- def testAddRunsFromDirectoryThatContainsEvents(self):
- x = event_multiplexer.EventMultiplexer()
- tmpdir = self.get_temp_dir()
- join = os.path.join
- realdir = join(tmpdir, 'event_containing_directory')
-
- _CreateCleanDirectory(realdir)
-
- self.assertEqual(x.Runs(), {})
-
- _AddEvents(realdir)
- x.AddRunsFromDirectory(realdir)
- self.assertItemsEqual(x.Runs(), ['.'])
-
- subdir = join(realdir, 'subdir')
- _AddEvents(subdir)
- x.AddRunsFromDirectory(realdir)
- self.assertItemsEqual(x.Runs(), ['.', 'subdir'])
-
- def testAddRunsFromDirectoryWithRunNames(self):
- x = event_multiplexer.EventMultiplexer()
- tmpdir = self.get_temp_dir()
- join = os.path.join
- realdir = join(tmpdir, 'event_containing_directory')
-
- _CreateCleanDirectory(realdir)
-
- self.assertEqual(x.Runs(), {})
-
- _AddEvents(realdir)
- x.AddRunsFromDirectory(realdir, 'foo')
- self.assertItemsEqual(x.Runs(), ['foo/.'])
-
- subdir = join(realdir, 'subdir')
- _AddEvents(subdir)
- x.AddRunsFromDirectory(realdir, 'foo')
- self.assertItemsEqual(x.Runs(), ['foo/.', 'foo/subdir'])
-
- def testAddRunsFromDirectoryWalksTree(self):
- x = event_multiplexer.EventMultiplexer()
- tmpdir = self.get_temp_dir()
- join = os.path.join
- realdir = join(tmpdir, 'event_containing_directory')
-
- _CreateCleanDirectory(realdir)
- _AddEvents(realdir)
- sub = join(realdir, 'subdirectory')
- sub1 = join(sub, '1')
- sub2 = join(sub, '2')
- sub1_1 = join(sub1, '1')
- _AddEvents(sub1)
- _AddEvents(sub2)
- _AddEvents(sub1_1)
- x.AddRunsFromDirectory(realdir)
-
- self.assertItemsEqual(x.Runs(), ['.', 'subdirectory/1', 'subdirectory/2',
- 'subdirectory/1/1'])
-
- def testAddRunsFromDirectoryThrowsException(self):
- x = event_multiplexer.EventMultiplexer()
- tmpdir = self.get_temp_dir()
-
- filepath = _AddEvents(tmpdir)
- with self.assertRaises(ValueError):
- x.AddRunsFromDirectory(filepath)
-
- def testAddRun(self):
- x = event_multiplexer.EventMultiplexer()
- x.AddRun('run1_path', 'run1')
- run1 = x.GetAccumulator('run1')
- self.assertEqual(sorted(x.Runs().keys()), ['run1'])
- self.assertEqual(run1._path, 'run1_path')
-
- x.AddRun('run1_path', 'run1')
- self.assertEqual(run1, x.GetAccumulator('run1'), 'loader not recreated')
-
- x.AddRun('run2_path', 'run1')
- new_run1 = x.GetAccumulator('run1')
- self.assertEqual(new_run1._path, 'run2_path')
- self.assertNotEqual(run1, new_run1)
-
- x.AddRun('runName3')
- self.assertItemsEqual(sorted(x.Runs().keys()), ['run1', 'runName3'])
- self.assertEqual(x.GetAccumulator('runName3')._path, 'runName3')
-
- def testAddRunMaintainsLoading(self):
- x = event_multiplexer.EventMultiplexer()
- x.Reload()
- x.AddRun('run1')
- x.AddRun('run2')
- self.assertTrue(x.GetAccumulator('run1').reload_called)
- self.assertTrue(x.GetAccumulator('run2').reload_called)
-
- def testAddReloadWithMultipleThreads(self):
- x = event_multiplexer.EventMultiplexer(max_reload_threads=2)
- x.Reload()
- x.AddRun('run1')
- x.AddRun('run2')
- x.AddRun('run3')
- self.assertTrue(x.GetAccumulator('run1').reload_called)
- self.assertTrue(x.GetAccumulator('run2').reload_called)
- self.assertTrue(x.GetAccumulator('run3').reload_called)
-
- def testReloadWithMoreRunsThanThreads(self):
- patcher = tf.compat.v1.test.mock.patch('threading.Thread.start', autospec=True)
- start_mock = patcher.start()
- self.addCleanup(patcher.stop)
- patcher = tf.compat.v1.test.mock.patch(
- 'six.moves.queue.Queue.join', autospec=True)
- join_mock = patcher.start()
- self.addCleanup(patcher.stop)
-
- x = event_multiplexer.EventMultiplexer(max_reload_threads=2)
- x.AddRun('run1')
- x.AddRun('run2')
- x.AddRun('run3')
- x.Reload()
-
- # 2 threads should have been started despite how there are 3 runs.
- self.assertEqual(2, start_mock.call_count)
- self.assertEqual(1, join_mock.call_count)
-
- def testReloadWithMoreThreadsThanRuns(self):
- patcher = tf.compat.v1.test.mock.patch('threading.Thread.start', autospec=True)
- start_mock = patcher.start()
- self.addCleanup(patcher.stop)
- patcher = tf.compat.v1.test.mock.patch(
- 'six.moves.queue.Queue.join', autospec=True)
- join_mock = patcher.start()
- self.addCleanup(patcher.stop)
-
- x = event_multiplexer.EventMultiplexer(max_reload_threads=42)
- x.AddRun('run1')
- x.AddRun('run2')
- x.AddRun('run3')
- x.Reload()
-
- # 3 threads should have been started despite how the multiplexer
- # could have started up to 42 threads.
- self.assertEqual(3, start_mock.call_count)
- self.assertEqual(1, join_mock.call_count)
-
- def testReloadWith1Thread(self):
- patcher = tf.compat.v1.test.mock.patch('threading.Thread.start', autospec=True)
- start_mock = patcher.start()
- self.addCleanup(patcher.stop)
- patcher = tf.compat.v1.test.mock.patch(
- 'six.moves.queue.Queue.join', autospec=True)
- join_mock = patcher.start()
- self.addCleanup(patcher.stop)
-
- x = event_multiplexer.EventMultiplexer(max_reload_threads=1)
- x.AddRun('run1')
- x.AddRun('run2')
- x.AddRun('run3')
- x.Reload()
-
- # The multiplexer should have started no new threads.
- self.assertEqual(0, start_mock.call_count)
- self.assertEqual(0, join_mock.call_count)
+ def setUp(self):
+ super(EventMultiplexerTest, self).setUp()
+ self.stubs = tf.compat.v1.test.StubOutForTesting()
+
+ self.stubs.Set(
+ event_accumulator, "EventAccumulator", _GetFakeAccumulator
+ )
+
+ def tearDown(self):
+ self.stubs.CleanUp()
+
+ def testEmptyLoader(self):
+ """Tests empty EventMultiplexer creation."""
+ x = event_multiplexer.EventMultiplexer()
+ self.assertEqual(x.Runs(), {})
+
+ def testRunNamesRespected(self):
+ """Tests two EventAccumulators inserted/accessed in
+ EventMultiplexer."""
+ x = event_multiplexer.EventMultiplexer(
+ {"run1": "path1", "run2": "path2"}
+ )
+ self.assertItemsEqual(sorted(x.Runs().keys()), ["run1", "run2"])
+ self.assertEqual(x.GetAccumulator("run1")._path, "path1")
+ self.assertEqual(x.GetAccumulator("run2")._path, "path2")
+
+ def testReload(self):
+        """EventAccumulators should Reload after EventMultiplexer calls it."""
+ x = event_multiplexer.EventMultiplexer(
+ {"run1": "path1", "run2": "path2"}
+ )
+ self.assertFalse(x.GetAccumulator("run1").reload_called)
+ self.assertFalse(x.GetAccumulator("run2").reload_called)
+ x.Reload()
+ self.assertTrue(x.GetAccumulator("run1").reload_called)
+ self.assertTrue(x.GetAccumulator("run2").reload_called)
+
+ def testPluginRunToTagToContent(self):
+ """Tests the method that produces the run to tag to content mapping."""
+ x = event_multiplexer.EventMultiplexer(
+ {"run1": "path1", "run2": "path2"}
+ )
+ self.assertDictEqual(
+ {
+ "run1": {
+ "path1_foo": "foo_content",
+ "path1_bar": "bar_content",
+ },
+ "run2": {
+ "path2_foo": "foo_content",
+ "path2_bar": "bar_content",
+ },
+ },
+ x.PluginRunToTagToContent("baz_plugin"),
+ )
+
+ def testExceptions(self):
+ """KeyError should be raised when accessing non-existing keys."""
+ x = event_multiplexer.EventMultiplexer(
+ {"run1": "path1", "run2": "path2"}
+ )
+ with self.assertRaises(KeyError):
+ x.Tensors("sv1", "xxx")
+
+ def testInitialization(self):
+ """Tests EventMultiplexer is created properly with its params."""
+ x = event_multiplexer.EventMultiplexer()
+ self.assertEqual(x.Runs(), {})
+ x = event_multiplexer.EventMultiplexer(
+ {"run1": "path1", "run2": "path2"}
+ )
+ self.assertItemsEqual(x.Runs(), ["run1", "run2"])
+ self.assertEqual(x.GetAccumulator("run1")._path, "path1")
+ self.assertEqual(x.GetAccumulator("run2")._path, "path2")
+
+ def testAddRunsFromDirectory(self):
+ """Tests AddRunsFromDirectory function.
+
+ Tests the following scenarios:
+ - When the directory does not exist.
+ - When the directory is empty.
+ - When the directory has empty subdirectory.
+ - Contains proper EventAccumulators after adding events.
+ """
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ fakedir = join(tmpdir, "fake_accumulator_directory")
+ realdir = join(tmpdir, "real_accumulator_directory")
+ self.assertEqual(x.Runs(), {})
+ x.AddRunsFromDirectory(fakedir)
+ self.assertEqual(x.Runs(), {}, "loading fakedir had no effect")
+
+ _CreateCleanDirectory(realdir)
+ x.AddRunsFromDirectory(realdir)
+ self.assertEqual(x.Runs(), {}, "loading empty directory had no effect")
+
+ path1 = join(realdir, "path1")
+ tf.io.gfile.mkdir(path1)
+ x.AddRunsFromDirectory(realdir)
+ self.assertEqual(
+ x.Runs(), {}, "creating empty subdirectory had no effect"
+ )
+
+ _AddEvents(path1)
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs(), ["path1"], "loaded run: path1")
+ loader1 = x.GetAccumulator("path1")
+ self.assertEqual(loader1._path, path1, "has the correct path")
+
+ path2 = join(realdir, "path2")
+ _AddEvents(path2)
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs(), ["path1", "path2"])
+ self.assertEqual(
+ x.GetAccumulator("path1"), loader1, "loader1 not regenerated"
+ )
+
+ path2_2 = join(path2, "path2")
+ _AddEvents(path2_2)
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs(), ["path1", "path2", "path2/path2"])
+ self.assertEqual(
+ x.GetAccumulator("path2/path2")._path,
+ path2_2,
+ "loader2 path correct",
+ )
+
+ def testAddRunsFromDirectoryThatContainsEvents(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ realdir = join(tmpdir, "event_containing_directory")
+
+ _CreateCleanDirectory(realdir)
+
+ self.assertEqual(x.Runs(), {})
+
+ _AddEvents(realdir)
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs(), ["."])
+
+ subdir = join(realdir, "subdir")
+ _AddEvents(subdir)
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs(), [".", "subdir"])
+
+ def testAddRunsFromDirectoryWithRunNames(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ realdir = join(tmpdir, "event_containing_directory")
+
+ _CreateCleanDirectory(realdir)
+
+ self.assertEqual(x.Runs(), {})
+
+ _AddEvents(realdir)
+ x.AddRunsFromDirectory(realdir, "foo")
+ self.assertItemsEqual(x.Runs(), ["foo/."])
+
+ subdir = join(realdir, "subdir")
+ _AddEvents(subdir)
+ x.AddRunsFromDirectory(realdir, "foo")
+ self.assertItemsEqual(x.Runs(), ["foo/.", "foo/subdir"])
+
+ def testAddRunsFromDirectoryWalksTree(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ realdir = join(tmpdir, "event_containing_directory")
+
+ _CreateCleanDirectory(realdir)
+ _AddEvents(realdir)
+ sub = join(realdir, "subdirectory")
+ sub1 = join(sub, "1")
+ sub2 = join(sub, "2")
+ sub1_1 = join(sub1, "1")
+ _AddEvents(sub1)
+ _AddEvents(sub2)
+ _AddEvents(sub1_1)
+ x.AddRunsFromDirectory(realdir)
+
+ self.assertItemsEqual(
+ x.Runs(),
+ [".", "subdirectory/1", "subdirectory/2", "subdirectory/1/1"],
+ )
+
+ def testAddRunsFromDirectoryThrowsException(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+
+ filepath = _AddEvents(tmpdir)
+ with self.assertRaises(ValueError):
+ x.AddRunsFromDirectory(filepath)
+
+ def testAddRun(self):
+ x = event_multiplexer.EventMultiplexer()
+ x.AddRun("run1_path", "run1")
+ run1 = x.GetAccumulator("run1")
+ self.assertEqual(sorted(x.Runs().keys()), ["run1"])
+ self.assertEqual(run1._path, "run1_path")
+
+ x.AddRun("run1_path", "run1")
+ self.assertEqual(run1, x.GetAccumulator("run1"), "loader not recreated")
+
+ x.AddRun("run2_path", "run1")
+ new_run1 = x.GetAccumulator("run1")
+ self.assertEqual(new_run1._path, "run2_path")
+ self.assertNotEqual(run1, new_run1)
+
+ x.AddRun("runName3")
+ self.assertItemsEqual(sorted(x.Runs().keys()), ["run1", "runName3"])
+ self.assertEqual(x.GetAccumulator("runName3")._path, "runName3")
+
+ def testAddRunMaintainsLoading(self):
+ x = event_multiplexer.EventMultiplexer()
+ x.Reload()
+ x.AddRun("run1")
+ x.AddRun("run2")
+ self.assertTrue(x.GetAccumulator("run1").reload_called)
+ self.assertTrue(x.GetAccumulator("run2").reload_called)
+
+ def testAddReloadWithMultipleThreads(self):
+ x = event_multiplexer.EventMultiplexer(max_reload_threads=2)
+ x.Reload()
+ x.AddRun("run1")
+ x.AddRun("run2")
+ x.AddRun("run3")
+ self.assertTrue(x.GetAccumulator("run1").reload_called)
+ self.assertTrue(x.GetAccumulator("run2").reload_called)
+ self.assertTrue(x.GetAccumulator("run3").reload_called)
+
+ def testReloadWithMoreRunsThanThreads(self):
+ patcher = tf.compat.v1.test.mock.patch(
+ "threading.Thread.start", autospec=True
+ )
+ start_mock = patcher.start()
+ self.addCleanup(patcher.stop)
+ patcher = tf.compat.v1.test.mock.patch(
+ "six.moves.queue.Queue.join", autospec=True
+ )
+ join_mock = patcher.start()
+ self.addCleanup(patcher.stop)
+
+ x = event_multiplexer.EventMultiplexer(max_reload_threads=2)
+ x.AddRun("run1")
+ x.AddRun("run2")
+ x.AddRun("run3")
+ x.Reload()
+
+        # 2 threads should have been started even though there are 3 runs.
+ self.assertEqual(2, start_mock.call_count)
+ self.assertEqual(1, join_mock.call_count)
+
+ def testReloadWithMoreThreadsThanRuns(self):
+ patcher = tf.compat.v1.test.mock.patch(
+ "threading.Thread.start", autospec=True
+ )
+ start_mock = patcher.start()
+ self.addCleanup(patcher.stop)
+ patcher = tf.compat.v1.test.mock.patch(
+ "six.moves.queue.Queue.join", autospec=True
+ )
+ join_mock = patcher.start()
+ self.addCleanup(patcher.stop)
+
+ x = event_multiplexer.EventMultiplexer(max_reload_threads=42)
+ x.AddRun("run1")
+ x.AddRun("run2")
+ x.AddRun("run3")
+ x.Reload()
+
+        # 3 threads should have been started even though the multiplexer
+        # could have started up to 42 threads.
+ self.assertEqual(3, start_mock.call_count)
+ self.assertEqual(1, join_mock.call_count)
+
+ def testReloadWith1Thread(self):
+ patcher = tf.compat.v1.test.mock.patch(
+ "threading.Thread.start", autospec=True
+ )
+ start_mock = patcher.start()
+ self.addCleanup(patcher.stop)
+ patcher = tf.compat.v1.test.mock.patch(
+ "six.moves.queue.Queue.join", autospec=True
+ )
+ join_mock = patcher.start()
+ self.addCleanup(patcher.stop)
+
+ x = event_multiplexer.EventMultiplexer(max_reload_threads=1)
+ x.AddRun("run1")
+ x.AddRun("run2")
+ x.AddRun("run3")
+ x.Reload()
+
+ # The multiplexer should have started no new threads.
+ self.assertEqual(0, start_mock.call_count)
+ self.assertEqual(0, join_mock.call_count)
class EventMultiplexerWithRealAccumulatorTest(tf.test.TestCase):
-
- def testMultifileReload(self):
- multiplexer = event_multiplexer.EventMultiplexer(
- event_file_active_filter=lambda timestamp: True)
- logdir = self.get_temp_dir()
- run_name = 'run1'
- run_path = os.path.join(logdir, run_name)
- # Create two separate event files, using filename suffix to ensure a
- # deterministic sort order, and then simulate a write to file A, then
- # to file B, then another write to file A (with reloads after each).
- with test_util.FileWriter(run_path, filename_suffix='.a') as writer_a:
- writer_a.add_test_summary('a1', step=1)
- writer_a.flush()
- multiplexer.AddRunsFromDirectory(logdir)
- multiplexer.Reload()
- with test_util.FileWriter(run_path, filename_suffix='.b') as writer_b:
- writer_b.add_test_summary('b', step=1)
- multiplexer.Reload()
- writer_a.add_test_summary('a2', step=2)
- writer_a.flush()
- multiplexer.Reload()
- # Both event files should be treated as active, so we should load the newly
- # written data to the first file even though it's no longer the latest one.
- self.assertEqual(1, len(multiplexer.Tensors(run_name, 'a1')))
- self.assertEqual(1, len(multiplexer.Tensors(run_name, 'b')))
- self.assertEqual(1, len(multiplexer.Tensors(run_name, 'a2')))
-
- def testDeletingDirectoryRemovesRun(self):
- x = event_multiplexer.EventMultiplexer()
- tmpdir = self.get_temp_dir()
- self._add3RunsToMultiplexer(tmpdir, x)
- x.Reload()
-
- # Delete the directory, then reload.
- shutil.rmtree(os.path.join(tmpdir, 'run2'))
- x.Reload()
- self.assertNotIn('run2', x.Runs().keys())
-
- def _add3RunsToMultiplexer(self, logdir, multiplexer):
- """Creates and adds 3 runs to the multiplexer."""
- run1_dir = os.path.join(logdir, 'run1')
- run2_dir = os.path.join(logdir, 'run2')
- run3_dir = os.path.join(logdir, 'run3')
-
- for dirname in [run1_dir, run2_dir, run3_dir]:
- _AddEvents(dirname)
-
- multiplexer.AddRun(run1_dir, 'run1')
- multiplexer.AddRun(run2_dir, 'run2')
- multiplexer.AddRun(run3_dir, 'run3')
-
-
-if __name__ == '__main__':
- tf.test.main()
+ def testMultifileReload(self):
+ multiplexer = event_multiplexer.EventMultiplexer(
+ event_file_active_filter=lambda timestamp: True
+ )
+ logdir = self.get_temp_dir()
+ run_name = "run1"
+ run_path = os.path.join(logdir, run_name)
+ # Create two separate event files, using filename suffix to ensure a
+ # deterministic sort order, and then simulate a write to file A, then
+ # to file B, then another write to file A (with reloads after each).
+ with test_util.FileWriter(run_path, filename_suffix=".a") as writer_a:
+ writer_a.add_test_summary("a1", step=1)
+ writer_a.flush()
+ multiplexer.AddRunsFromDirectory(logdir)
+ multiplexer.Reload()
+ with test_util.FileWriter(
+ run_path, filename_suffix=".b"
+ ) as writer_b:
+ writer_b.add_test_summary("b", step=1)
+ multiplexer.Reload()
+ writer_a.add_test_summary("a2", step=2)
+ writer_a.flush()
+ multiplexer.Reload()
+ # Both event files should be treated as active, so we should load the newly
+ # written data to the first file even though it's no longer the latest one.
+ self.assertEqual(1, len(multiplexer.Tensors(run_name, "a1")))
+ self.assertEqual(1, len(multiplexer.Tensors(run_name, "b")))
+ self.assertEqual(1, len(multiplexer.Tensors(run_name, "a2")))
+
+ def testDeletingDirectoryRemovesRun(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ self._add3RunsToMultiplexer(tmpdir, x)
+ x.Reload()
+
+ # Delete the directory, then reload.
+ shutil.rmtree(os.path.join(tmpdir, "run2"))
+ x.Reload()
+ self.assertNotIn("run2", x.Runs().keys())
+
+ def _add3RunsToMultiplexer(self, logdir, multiplexer):
+ """Creates and adds 3 runs to the multiplexer."""
+ run1_dir = os.path.join(logdir, "run1")
+ run2_dir = os.path.join(logdir, "run2")
+ run3_dir = os.path.join(logdir, "run3")
+
+ for dirname in [run1_dir, run2_dir, run3_dir]:
+ _AddEvents(dirname)
+
+ multiplexer.AddRun(run1_dir, "run1")
+ multiplexer.AddRun(run2_dir, "run2")
+ multiplexer.AddRun(run3_dir, "run3")
+
+
+if __name__ == "__main__":
+ tf.test.main()
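
The `testMultifileReload` case above passes `event_file_active_filter=lambda timestamp: True`, so every event file in a run directory is treated as active. Below is a sketch of a more selective filter, assuming (per the constructor docstring) that the predicate receives an event file's latest load timestamp in seconds since the epoch; the one-hour window and the path are illustrative assumptions only:

```python
import time

from tensorboard.backend.event_processing import (
    plugin_event_multiplexer as event_multiplexer,
)

_ONE_HOUR_SECS = 60 * 60


def _recently_active(timestamp):
    """Treat an event file as active if it was loaded within the last hour."""
    return timestamp + _ONE_HOUR_SECS >= time.time()


multiplexer = event_multiplexer.EventMultiplexer(
    event_file_active_filter=_recently_active,  # enables multifile loading
)
multiplexer.AddRunsFromDirectory("/tmp/example_logdir")  # hypothetical path
multiplexer.Reload()
```
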
diff --git a/tensorboard/backend/event_processing/reservoir.py b/tensorboard/backend/event_processing/reservoir.py
index f0ec95689f..74648e2021 100644
--- a/tensorboard/backend/event_processing/reservoir.py
+++ b/tensorboard/backend/event_processing/reservoir.py
@@ -25,7 +25,7 @@
class Reservoir(object):
- """A map-to-arrays container, with deterministic Reservoir Sampling.
+ """A map-to-arrays container, with deterministic Reservoir Sampling.
Items are added with an associated key. Items may be retrieved by key, and
a list of keys can also be retrieved. If size is not zero, then it dictates
@@ -57,203 +57,214 @@ class Reservoir(object):
always_keep_last: Whether the latest seen sample is always at the
end of the reservoir. Defaults to True.
size: An integer of the maximum number of samples.
- """
-
- def __init__(self, size, seed=0, always_keep_last=True):
- """Creates a new reservoir.
-
- Args:
- size: The number of values to keep in the reservoir for each tag. If 0,
- all values will be kept.
- seed: The seed of the random number generator to use when sampling.
- Different values for |seed| will produce different samples from the same
- input items.
- always_keep_last: Whether to always keep the latest seen item in the
- end of the reservoir. Defaults to True.
-
- Raises:
- ValueError: If size is negative or not an integer.
"""
- if size < 0 or size != round(size):
- raise ValueError('size must be nonnegative integer, was %s' % size)
- self._buckets = collections.defaultdict(
- lambda: _ReservoirBucket(size, random.Random(seed), always_keep_last))
- # _mutex guards the keys - creating new keys, retrieving by key, etc
- # the internal items are guarded by the ReservoirBuckets' internal mutexes
- self._mutex = threading.Lock()
- self.size = size
- self.always_keep_last = always_keep_last
-
- def Keys(self):
- """Return all the keys in the reservoir.
-
- Returns:
- ['list', 'of', 'keys'] in the Reservoir.
- """
- with self._mutex:
- return list(self._buckets.keys())
-
- def Items(self, key):
- """Return items associated with given key.
-
- Args:
- key: The key for which we are finding associated items.
-
- Raises:
- KeyError: If the key is not found in the reservoir.
-
- Returns:
- [list, of, items] associated with that key.
- """
- with self._mutex:
- if key not in self._buckets:
- raise KeyError('Key %s was not found in Reservoir' % key)
- bucket = self._buckets[key]
- return bucket.Items()
- def AddItem(self, key, item, f=lambda x: x):
- """Add a new item to the Reservoir with the given tag.
+ def __init__(self, size, seed=0, always_keep_last=True):
+ """Creates a new reservoir.
+
+ Args:
+ size: The number of values to keep in the reservoir for each tag. If 0,
+ all values will be kept.
+ seed: The seed of the random number generator to use when sampling.
+ Different values for |seed| will produce different samples from the same
+ input items.
+          always_keep_last: Whether to always keep the latest seen item at the
+ end of the reservoir. Defaults to True.
+
+ Raises:
+ ValueError: If size is negative or not an integer.
+ """
+ if size < 0 or size != round(size):
+ raise ValueError("size must be nonnegative integer, was %s" % size)
+ self._buckets = collections.defaultdict(
+ lambda: _ReservoirBucket(
+ size, random.Random(seed), always_keep_last
+ )
+ )
+ # _mutex guards the keys - creating new keys, retrieving by key, etc
+ # the internal items are guarded by the ReservoirBuckets' internal mutexes
+ self._mutex = threading.Lock()
+ self.size = size
+ self.always_keep_last = always_keep_last
+
+ def Keys(self):
+ """Return all the keys in the reservoir.
+
+ Returns:
+ ['list', 'of', 'keys'] in the Reservoir.
+ """
+ with self._mutex:
+ return list(self._buckets.keys())
+
+ def Items(self, key):
+ """Return items associated with given key.
+
+ Args:
+ key: The key for which we are finding associated items.
+
+ Raises:
+ KeyError: If the key is not found in the reservoir.
+
+ Returns:
+ [list, of, items] associated with that key.
+ """
+ with self._mutex:
+ if key not in self._buckets:
+ raise KeyError("Key %s was not found in Reservoir" % key)
+ bucket = self._buckets[key]
+ return bucket.Items()
+
+ def AddItem(self, key, item, f=lambda x: x):
+ """Add a new item to the Reservoir with the given tag.
+
+ If the reservoir has not yet reached full size, the new item is guaranteed
+ to be added. If the reservoir is full, then behavior depends on the
+ always_keep_last boolean.
+
+ If always_keep_last was set to true, the new item is guaranteed to be added
+ to the reservoir, and either the previous last item will be replaced, or
+ (with low probability) an older item will be replaced.
+
+ If always_keep_last was set to false, then the new item will replace an
+ old item with low probability.
+
+ If f is provided, it will be applied to transform item (lazily, iff item is
+ going to be included in the reservoir).
+
+ Args:
+ key: The key to store the item under.
+ item: The item to add to the reservoir.
+ f: An optional function to transform the item prior to addition.
+ """
+ with self._mutex:
+ bucket = self._buckets[key]
+ bucket.AddItem(item, f)
+
+ def FilterItems(self, filterFn, key=None):
+ """Filter items within a Reservoir, using a filtering function.
+
+ Args:
+ filterFn: A function that returns True for the items to be kept.
+ key: An optional bucket key to filter. If not specified, will filter all
+            buckets.
+
+ Returns:
+ The number of items removed.
+ """
+ with self._mutex:
+ if key:
+ if key in self._buckets:
+ return self._buckets[key].FilterItems(filterFn)
+ else:
+ return 0
+ else:
+ return sum(
+ bucket.FilterItems(filterFn)
+ for bucket in self._buckets.values()
+ )
- If the reservoir has not yet reached full size, the new item is guaranteed
- to be added. If the reservoir is full, then behavior depends on the
- always_keep_last boolean.
- If always_keep_last was set to true, the new item is guaranteed to be added
- to the reservoir, and either the previous last item will be replaced, or
- (with low probability) an older item will be replaced.
-
- If always_keep_last was set to false, then the new item will replace an
- old item with low probability.
-
- If f is provided, it will be applied to transform item (lazily, iff item is
- going to be included in the reservoir).
+class _ReservoirBucket(object):
+ """A container for items from a stream, that implements reservoir sampling.
- Args:
- key: The key to store the item under.
- item: The item to add to the reservoir.
- f: An optional function to transform the item prior to addition.
+ It always stores the most recent item as its final item.
"""
- with self._mutex:
- bucket = self._buckets[key]
- bucket.AddItem(item, f)
-
- def FilterItems(self, filterFn, key=None):
- """Filter items within a Reservoir, using a filtering function.
-
- Args:
- filterFn: A function that returns True for the items to be kept.
- key: An optional bucket key to filter. If not specified, will filter all
- all buckets.
- Returns:
- The number of items removed.
- """
- with self._mutex:
- if key:
- if key in self._buckets:
- return self._buckets[key].FilterItems(filterFn)
+ def __init__(self, _max_size, _random=None, always_keep_last=True):
+ """Create the _ReservoirBucket.
+
+ Args:
+ _max_size: The maximum size the reservoir bucket may grow to. If size is
+ zero, the bucket has unbounded size.
+ _random: The random number generator to use. If not specified, defaults to
+ random.Random(0).
+ always_keep_last: Whether the latest seen item should always be included
+            at the end of the bucket.
+
+ Raises:
+ ValueError: if the size is not a nonnegative integer.
+ """
+ if _max_size < 0 or _max_size != round(_max_size):
+ raise ValueError(
+ "_max_size must be nonnegative int, was %s" % _max_size
+ )
+ self.items = []
+ # This mutex protects the internal items, ensuring that calls to Items and
+ # AddItem are thread-safe
+ self._mutex = threading.Lock()
+ self._max_size = _max_size
+ self._num_items_seen = 0
+ if _random is not None:
+ self._random = _random
else:
- return 0
- else:
- return sum(bucket.FilterItems(filterFn)
- for bucket in self._buckets.values())
-
-
-class _ReservoirBucket(object):
- """A container for items from a stream, that implements reservoir sampling.
-
- It always stores the most recent item as its final item.
- """
-
- def __init__(self, _max_size, _random=None, always_keep_last=True):
- """Create the _ReservoirBucket.
-
- Args:
- _max_size: The maximum size the reservoir bucket may grow to. If size is
- zero, the bucket has unbounded size.
- _random: The random number generator to use. If not specified, defaults to
- random.Random(0).
- always_keep_last: Whether the latest seen item should always be included
- in the end of the bucket.
-
- Raises:
- ValueError: if the size is not a nonnegative integer.
- """
- if _max_size < 0 or _max_size != round(_max_size):
- raise ValueError('_max_size must be nonnegative int, was %s' % _max_size)
- self.items = []
- # This mutex protects the internal items, ensuring that calls to Items and
- # AddItem are thread-safe
- self._mutex = threading.Lock()
- self._max_size = _max_size
- self._num_items_seen = 0
- if _random is not None:
- self._random = _random
- else:
- self._random = random.Random(0)
- self.always_keep_last = always_keep_last
-
- def AddItem(self, item, f=lambda x: x):
- """Add an item to the ReservoirBucket, replacing an old item if necessary.
-
- The new item is guaranteed to be added to the bucket, and to be the last
- element in the bucket. If the bucket has reached capacity, then an old item
- will be replaced. With probability (_max_size/_num_items_seen) a random item
- in the bucket will be popped out and the new item will be appended
- to the end. With probability (1 - _max_size/_num_items_seen)
- the last item in the bucket will be replaced.
-
- Since the O(n) replacements occur with O(1/_num_items_seen) likelihood,
- the amortized runtime is O(1).
-
- Args:
- item: The item to add to the bucket.
- f: A function to transform item before addition, if it will be kept in
- the reservoir.
- """
- with self._mutex:
- if len(self.items) < self._max_size or self._max_size == 0:
- self.items.append(f(item))
- else:
- r = self._random.randint(0, self._num_items_seen)
- if r < self._max_size:
- self.items.pop(r)
- self.items.append(f(item))
- elif self.always_keep_last:
- self.items[-1] = f(item)
- self._num_items_seen += 1
-
- def FilterItems(self, filterFn):
- """Filter items in a ReservoirBucket, using a filtering function.
-
- Filtering items from the reservoir bucket must update the
- internal state variable self._num_items_seen, which is used for determining
- the rate of replacement in reservoir sampling. Ideally, self._num_items_seen
- would contain the exact number of items that have ever seen by the
- ReservoirBucket and satisfy filterFn. However, the ReservoirBucket does not
- have access to all items seen -- it only has access to the subset of items
- that have survived sampling (self.items). Therefore, we estimate
- self._num_items_seen by scaling it by the same ratio as the ratio of items
- not removed from self.items.
-
- Args:
- filterFn: A function that returns True for items to be kept.
-
- Returns:
- The number of items removed from the bucket.
- """
- with self._mutex:
- size_before = len(self.items)
- self.items = list(filter(filterFn, self.items))
- size_diff = size_before - len(self.items)
-
- # Estimate a correction the number of items seen
- prop_remaining = len(self.items) / float(
- size_before) if size_before > 0 else 0
- self._num_items_seen = int(round(self._num_items_seen * prop_remaining))
- return size_diff
-
- def Items(self):
- """Get all the items in the bucket."""
- with self._mutex:
- return list(self.items)
+ self._random = random.Random(0)
+ self.always_keep_last = always_keep_last
+
+ def AddItem(self, item, f=lambda x: x):
+ """Add an item to the ReservoirBucket, replacing an old item if
+ necessary.
+
+ The new item is guaranteed to be added to the bucket, and to be the last
+ element in the bucket. If the bucket has reached capacity, then an old item
+ will be replaced. With probability (_max_size/_num_items_seen) a random item
+ in the bucket will be popped out and the new item will be appended
+ to the end. With probability (1 - _max_size/_num_items_seen)
+ the last item in the bucket will be replaced.
+
+ Since the O(n) replacements occur with O(1/_num_items_seen) likelihood,
+ the amortized runtime is O(1).
+
+ Args:
+ item: The item to add to the bucket.
+ f: A function to transform item before addition, if it will be kept in
+ the reservoir.
+ """
+ with self._mutex:
+ if len(self.items) < self._max_size or self._max_size == 0:
+ self.items.append(f(item))
+ else:
+ r = self._random.randint(0, self._num_items_seen)
+ if r < self._max_size:
+ self.items.pop(r)
+ self.items.append(f(item))
+ elif self.always_keep_last:
+ self.items[-1] = f(item)
+ self._num_items_seen += 1
+
+ def FilterItems(self, filterFn):
+ """Filter items in a ReservoirBucket, using a filtering function.
+
+ Filtering items from the reservoir bucket must update the
+ internal state variable self._num_items_seen, which is used for determining
+ the rate of replacement in reservoir sampling. Ideally, self._num_items_seen
+ would contain the exact number of items that have ever been seen by the
+ ReservoirBucket and satisfy filterFn. However, the ReservoirBucket does not
+ have access to all items seen -- it only has access to the subset of items
+ that have survived sampling (self.items). Therefore, we estimate
+ self._num_items_seen by scaling it by the same ratio as the ratio of items
+ not removed from self.items.
+
+ Args:
+ filterFn: A function that returns True for items to be kept.
+
+ Returns:
+ The number of items removed from the bucket.
+ """
+ with self._mutex:
+ size_before = len(self.items)
+ self.items = list(filter(filterFn, self.items))
+ size_diff = size_before - len(self.items)
+
+ # Estimate a correction to the number of items seen
+ prop_remaining = (
+ len(self.items) / float(size_before) if size_before > 0 else 0
+ )
+ self._num_items_seen = int(
+ round(self._num_items_seen * prop_remaining)
+ )
+ return size_diff
+
+ def Items(self):
+ """Get all the items in the bucket."""
+ with self._mutex:
+ return list(self.items)
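
To make the sampling rule in the reformatted AddItem above easier to follow, here is a minimal standalone sketch of the same update step (hypothetical helper names, not TensorBoard code). The bucket fills to max_size first; after that a random survivor is evicted with probability max_size / (num_items_seen + 1), and otherwise only the tail slot is overwritten so the newest item is always retained.

    import random

    def reservoir_add(items, max_size, num_items_seen, item, rng,
                      always_keep_last=True):
        # Fill phase: keep everything until the bucket reaches max_size
        # (max_size == 0 means "keep everything").
        if len(items) < max_size or max_size == 0:
            items.append(item)
        else:
            r = rng.randint(0, num_items_seen)
            if r < max_size:
                # Durable inclusion: evict a random survivor, append new item.
                items.pop(r)
                items.append(item)
            elif always_keep_last:
                # Otherwise only the tail slot changes, so the last element is
                # always the most recently added item.
                items[-1] = item
        return num_items_seen + 1

    items, seen, rng = [], 0, random.Random(0)
    for i in range(1000):
        seen = reservoir_add(items, 10, seen, i, rng)
    assert len(items) == 10 and items[-1] == 999

FilterItems then rescales the seen-count by the surviving fraction, roughly seen = int(round(seen * len(items_after) / len(items_before))), which is the estimate its docstring describes.
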
diff --git a/tensorboard/backend/event_processing/reservoir_test.py b/tensorboard/backend/event_processing/reservoir_test.py
index 50b830903e..4096918482 100644
--- a/tensorboard/backend/event_processing/reservoir_test.py
+++ b/tensorboard/backend/event_processing/reservoir_test.py
@@ -24,256 +24,264 @@
class ReservoirTest(tf.test.TestCase):
-
- def testEmptyReservoir(self):
- r = reservoir.Reservoir(1)
- self.assertFalse(r.Keys())
-
- def testRespectsSize(self):
- r = reservoir.Reservoir(42)
- self.assertEqual(r._buckets['meaning of life']._max_size, 42)
-
- def testItemsAndKeys(self):
- r = reservoir.Reservoir(42)
- r.AddItem('foo', 4)
- r.AddItem('bar', 9)
- r.AddItem('foo', 19)
- self.assertItemsEqual(r.Keys(), ['foo', 'bar'])
- self.assertEqual(r.Items('foo'), [4, 19])
- self.assertEqual(r.Items('bar'), [9])
-
- def testExceptions(self):
- with self.assertRaises(ValueError):
- reservoir.Reservoir(-1)
- with self.assertRaises(ValueError):
- reservoir.Reservoir(13.3)
-
- r = reservoir.Reservoir(12)
- with self.assertRaises(KeyError):
- r.Items('missing key')
-
- def testDeterminism(self):
- """Tests that the reservoir is deterministic."""
- key = 'key'
- r1 = reservoir.Reservoir(10)
- r2 = reservoir.Reservoir(10)
- for i in xrange(100):
- r1.AddItem('key', i)
- r2.AddItem('key', i)
-
- self.assertEqual(r1.Items(key), r2.Items(key))
-
- def testBucketDeterminism(self):
- """Tests that reservoirs are deterministic at a bucket level.
-
- This means that only the order elements are added within a bucket matters.
- """
- separate_reservoir = reservoir.Reservoir(10)
- interleaved_reservoir = reservoir.Reservoir(10)
- for i in xrange(100):
- separate_reservoir.AddItem('key1', i)
- for i in xrange(100):
- separate_reservoir.AddItem('key2', i)
- for i in xrange(100):
- interleaved_reservoir.AddItem('key1', i)
- interleaved_reservoir.AddItem('key2', i)
-
- for key in ['key1', 'key2']:
- self.assertEqual(
- separate_reservoir.Items(key), interleaved_reservoir.Items(key))
-
- def testUsesSeed(self):
- """Tests that reservoirs with different seeds keep different samples."""
- key = 'key'
- r1 = reservoir.Reservoir(10, seed=0)
- r2 = reservoir.Reservoir(10, seed=1)
- for i in xrange(100):
- r1.AddItem('key', i)
- r2.AddItem('key', i)
- self.assertNotEqual(r1.Items(key), r2.Items(key))
-
- def testFilterItemsByKey(self):
- r = reservoir.Reservoir(100, seed=0)
- for i in xrange(10):
- r.AddItem('key1', i)
- r.AddItem('key2', i)
-
- self.assertEqual(len(r.Items('key1')), 10)
- self.assertEqual(len(r.Items('key2')), 10)
-
- self.assertEqual(r.FilterItems(lambda x: x <= 7, 'key2'), 2)
- self.assertEqual(len(r.Items('key2')), 8)
- self.assertEqual(len(r.Items('key1')), 10)
-
- self.assertEqual(r.FilterItems(lambda x: x <= 3, 'key1'), 6)
- self.assertEqual(len(r.Items('key1')), 4)
- self.assertEqual(len(r.Items('key2')), 8)
+ def testEmptyReservoir(self):
+ r = reservoir.Reservoir(1)
+ self.assertFalse(r.Keys())
+
+ def testRespectsSize(self):
+ r = reservoir.Reservoir(42)
+ self.assertEqual(r._buckets["meaning of life"]._max_size, 42)
+
+ def testItemsAndKeys(self):
+ r = reservoir.Reservoir(42)
+ r.AddItem("foo", 4)
+ r.AddItem("bar", 9)
+ r.AddItem("foo", 19)
+ self.assertItemsEqual(r.Keys(), ["foo", "bar"])
+ self.assertEqual(r.Items("foo"), [4, 19])
+ self.assertEqual(r.Items("bar"), [9])
+
+ def testExceptions(self):
+ with self.assertRaises(ValueError):
+ reservoir.Reservoir(-1)
+ with self.assertRaises(ValueError):
+ reservoir.Reservoir(13.3)
+
+ r = reservoir.Reservoir(12)
+ with self.assertRaises(KeyError):
+ r.Items("missing key")
+
+ def testDeterminism(self):
+ """Tests that the reservoir is deterministic."""
+ key = "key"
+ r1 = reservoir.Reservoir(10)
+ r2 = reservoir.Reservoir(10)
+ for i in xrange(100):
+ r1.AddItem("key", i)
+ r2.AddItem("key", i)
+
+ self.assertEqual(r1.Items(key), r2.Items(key))
+
+ def testBucketDeterminism(self):
+ """Tests that reservoirs are deterministic at a bucket level.
+
+ This means that only the order in which elements are added within a
+ bucket matters.
+ """
+ separate_reservoir = reservoir.Reservoir(10)
+ interleaved_reservoir = reservoir.Reservoir(10)
+ for i in xrange(100):
+ separate_reservoir.AddItem("key1", i)
+ for i in xrange(100):
+ separate_reservoir.AddItem("key2", i)
+ for i in xrange(100):
+ interleaved_reservoir.AddItem("key1", i)
+ interleaved_reservoir.AddItem("key2", i)
+
+ for key in ["key1", "key2"]:
+ self.assertEqual(
+ separate_reservoir.Items(key), interleaved_reservoir.Items(key)
+ )
+
+ def testUsesSeed(self):
+ """Tests that reservoirs with different seeds keep different
+ samples."""
+ key = "key"
+ r1 = reservoir.Reservoir(10, seed=0)
+ r2 = reservoir.Reservoir(10, seed=1)
+ for i in xrange(100):
+ r1.AddItem("key", i)
+ r2.AddItem("key", i)
+ self.assertNotEqual(r1.Items(key), r2.Items(key))
+
+ def testFilterItemsByKey(self):
+ r = reservoir.Reservoir(100, seed=0)
+ for i in xrange(10):
+ r.AddItem("key1", i)
+ r.AddItem("key2", i)
+
+ self.assertEqual(len(r.Items("key1")), 10)
+ self.assertEqual(len(r.Items("key2")), 10)
+
+ self.assertEqual(r.FilterItems(lambda x: x <= 7, "key2"), 2)
+ self.assertEqual(len(r.Items("key2")), 8)
+ self.assertEqual(len(r.Items("key1")), 10)
+
+ self.assertEqual(r.FilterItems(lambda x: x <= 3, "key1"), 6)
+ self.assertEqual(len(r.Items("key1")), 4)
+ self.assertEqual(len(r.Items("key2")), 8)
class ReservoirBucketTest(tf.test.TestCase):
-
- def testEmptyBucket(self):
- b = reservoir._ReservoirBucket(1)
- self.assertFalse(b.Items())
-
- def testFillToSize(self):
- b = reservoir._ReservoirBucket(100)
- for i in xrange(100):
- b.AddItem(i)
- self.assertEqual(b.Items(), list(xrange(100)))
- self.assertEqual(b._num_items_seen, 100)
-
- def testDoesntOverfill(self):
- b = reservoir._ReservoirBucket(10)
- for i in xrange(1000):
- b.AddItem(i)
- self.assertEqual(len(b.Items()), 10)
- self.assertEqual(b._num_items_seen, 1000)
-
- def testMaintainsOrder(self):
- b = reservoir._ReservoirBucket(100)
- for i in xrange(10000):
- b.AddItem(i)
- items = b.Items()
- prev = -1
- for item in items:
- self.assertTrue(item > prev)
- prev = item
-
- def testKeepsLatestItem(self):
- b = reservoir._ReservoirBucket(5)
- for i in xrange(100):
- b.AddItem(i)
- last = b.Items()[-1]
- self.assertEqual(last, i)
-
- def testSizeOneBucket(self):
- b = reservoir._ReservoirBucket(1)
- for i in xrange(20):
- b.AddItem(i)
- self.assertEqual(b.Items(), [i])
- self.assertEqual(b._num_items_seen, 20)
-
- def testSizeZeroBucket(self):
- b = reservoir._ReservoirBucket(0)
- for i in xrange(20):
- b.AddItem(i)
- self.assertEqual(b.Items(), list(range(i + 1)))
- self.assertEqual(b._num_items_seen, 20)
-
- def testSizeRequirement(self):
- with self.assertRaises(ValueError):
- reservoir._ReservoirBucket(-1)
- with self.assertRaises(ValueError):
- reservoir._ReservoirBucket(10.3)
-
- def testRemovesItems(self):
- b = reservoir._ReservoirBucket(100)
- for i in xrange(10):
- b.AddItem(i)
- self.assertEqual(len(b.Items()), 10)
- self.assertEqual(b._num_items_seen, 10)
- self.assertEqual(b.FilterItems(lambda x: x <= 7), 2)
- self.assertEqual(len(b.Items()), 8)
- self.assertEqual(b._num_items_seen, 8)
-
- def testRemovesItemsWhenItemsAreReplaced(self):
- b = reservoir._ReservoirBucket(100)
- for i in xrange(10000):
- b.AddItem(i)
- self.assertEqual(b._num_items_seen, 10000)
-
- # Remove items
- num_removed = b.FilterItems(lambda x: x <= 7)
- self.assertGreater(num_removed, 92)
- self.assertEqual([], [item for item in b.Items() if item > 7])
- self.assertEqual(b._num_items_seen,
- int(round(10000 * (1 - float(num_removed) / 100))))
-
- def testLazyFunctionEvaluationAndAlwaysKeepLast(self):
-
- class FakeRandom(object):
-
- def randint(self, a, b): # pylint:disable=unused-argument
- return 999
-
- class Incrementer(object):
-
- def __init__(self):
- self.n = 0
-
- def increment_and_double(self, x):
- self.n += 1
- return x * 2
-
- # We've mocked the randomness generator, so that once it is full, the last
- # item will never get durable reservoir inclusion. Since always_keep_last is
- # false, the function should only get invoked 100 times while filling up
- # the reservoir. This laziness property is an essential performance
- # optimization.
- b = reservoir._ReservoirBucket(100, FakeRandom(), always_keep_last=False)
- incrementer = Incrementer()
- for i in xrange(1000):
- b.AddItem(i, incrementer.increment_and_double)
- self.assertEqual(incrementer.n, 100)
- self.assertEqual(b.Items(), [x * 2 for x in xrange(100)])
-
- # This time, we will always keep the last item, meaning that the function
- # should get invoked once for every item we add.
- b = reservoir._ReservoirBucket(100, FakeRandom(), always_keep_last=True)
- incrementer = Incrementer()
-
- for i in xrange(1000):
- b.AddItem(i, incrementer.increment_and_double)
- self.assertEqual(incrementer.n, 1000)
- self.assertEqual(b.Items(), [x * 2 for x in xrange(99)] + [999 * 2])
+ def testEmptyBucket(self):
+ b = reservoir._ReservoirBucket(1)
+ self.assertFalse(b.Items())
+
+ def testFillToSize(self):
+ b = reservoir._ReservoirBucket(100)
+ for i in xrange(100):
+ b.AddItem(i)
+ self.assertEqual(b.Items(), list(xrange(100)))
+ self.assertEqual(b._num_items_seen, 100)
+
+ def testDoesntOverfill(self):
+ b = reservoir._ReservoirBucket(10)
+ for i in xrange(1000):
+ b.AddItem(i)
+ self.assertEqual(len(b.Items()), 10)
+ self.assertEqual(b._num_items_seen, 1000)
+
+ def testMaintainsOrder(self):
+ b = reservoir._ReservoirBucket(100)
+ for i in xrange(10000):
+ b.AddItem(i)
+ items = b.Items()
+ prev = -1
+ for item in items:
+ self.assertTrue(item > prev)
+ prev = item
+
+ def testKeepsLatestItem(self):
+ b = reservoir._ReservoirBucket(5)
+ for i in xrange(100):
+ b.AddItem(i)
+ last = b.Items()[-1]
+ self.assertEqual(last, i)
+
+ def testSizeOneBucket(self):
+ b = reservoir._ReservoirBucket(1)
+ for i in xrange(20):
+ b.AddItem(i)
+ self.assertEqual(b.Items(), [i])
+ self.assertEqual(b._num_items_seen, 20)
+
+ def testSizeZeroBucket(self):
+ b = reservoir._ReservoirBucket(0)
+ for i in xrange(20):
+ b.AddItem(i)
+ self.assertEqual(b.Items(), list(range(i + 1)))
+ self.assertEqual(b._num_items_seen, 20)
+
+ def testSizeRequirement(self):
+ with self.assertRaises(ValueError):
+ reservoir._ReservoirBucket(-1)
+ with self.assertRaises(ValueError):
+ reservoir._ReservoirBucket(10.3)
+
+ def testRemovesItems(self):
+ b = reservoir._ReservoirBucket(100)
+ for i in xrange(10):
+ b.AddItem(i)
+ self.assertEqual(len(b.Items()), 10)
+ self.assertEqual(b._num_items_seen, 10)
+ self.assertEqual(b.FilterItems(lambda x: x <= 7), 2)
+ self.assertEqual(len(b.Items()), 8)
+ self.assertEqual(b._num_items_seen, 8)
+
+ def testRemovesItemsWhenItemsAreReplaced(self):
+ b = reservoir._ReservoirBucket(100)
+ for i in xrange(10000):
+ b.AddItem(i)
+ self.assertEqual(b._num_items_seen, 10000)
+
+ # Remove items
+ num_removed = b.FilterItems(lambda x: x <= 7)
+ self.assertGreater(num_removed, 92)
+ self.assertEqual([], [item for item in b.Items() if item > 7])
+ self.assertEqual(
+ b._num_items_seen,
+ int(round(10000 * (1 - float(num_removed) / 100))),
+ )
+
+ def testLazyFunctionEvaluationAndAlwaysKeepLast(self):
+ class FakeRandom(object):
+ def randint(self, a, b): # pylint:disable=unused-argument
+ return 999
+
+ class Incrementer(object):
+ def __init__(self):
+ self.n = 0
+
+ def increment_and_double(self, x):
+ self.n += 1
+ return x * 2
+
+ # We've mocked the randomness generator, so that once it is full, the last
+ # item will never get durable reservoir inclusion. Since always_keep_last is
+ # false, the function should only get invoked 100 times while filling up
+ # the reservoir. This laziness property is an essential performance
+ # optimization.
+ b = reservoir._ReservoirBucket(
+ 100, FakeRandom(), always_keep_last=False
+ )
+ incrementer = Incrementer()
+ for i in xrange(1000):
+ b.AddItem(i, incrementer.increment_and_double)
+ self.assertEqual(incrementer.n, 100)
+ self.assertEqual(b.Items(), [x * 2 for x in xrange(100)])
+
+ # This time, we will always keep the last item, meaning that the function
+ # should get invoked once for every item we add.
+ b = reservoir._ReservoirBucket(100, FakeRandom(), always_keep_last=True)
+ incrementer = Incrementer()
+
+ for i in xrange(1000):
+ b.AddItem(i, incrementer.increment_and_double)
+ self.assertEqual(incrementer.n, 1000)
+ self.assertEqual(b.Items(), [x * 2 for x in xrange(99)] + [999 * 2])
class ReservoirBucketStatisticalDistributionTest(tf.test.TestCase):
-
- def setUp(self):
- self.total = 1000000
- self.samples = 10000
- self.n_buckets = 100
- self.total_per_bucket = self.total // self.n_buckets
- self.assertEqual(self.total % self.n_buckets, 0, 'total must be evenly '
- 'divisible by the number of buckets')
- self.assertTrue(self.total > self.samples, 'need to have more items '
- 'than samples')
-
- def AssertBinomialQuantity(self, measured):
- p = 1.0 * self.n_buckets / self.samples
- mean = p * self.samples
- variance = p * (1 - p) * self.samples
- error = measured - mean
- # Given that the buckets were actually binomially distributed, this
- # fails with probability ~2E-9
- passed = error * error <= 36.0 * variance
- self.assertTrue(passed, 'found a bucket with measured %d '
- 'too far from expected %d' % (measured, mean))
-
- def testBucketReservoirSamplingViaStatisticalProperties(self):
- # Not related to a 'ReservoirBucket', but instead number of buckets we put
- # samples into for testing the shape of the distribution
- b = reservoir._ReservoirBucket(_max_size=self.samples)
- # add one extra item because we always keep the most recent item, which
- # would skew the distribution; we can just slice it off the end instead.
- for i in xrange(self.total + 1):
- b.AddItem(i)
-
- divbins = [0] * self.n_buckets
- modbins = [0] * self.n_buckets
- # Slice off the last item when we iterate.
- for item in b.Items()[0:-1]:
- divbins[item // self.total_per_bucket] += 1
- modbins[item % self.n_buckets] += 1
-
- for bucket_index in xrange(self.n_buckets):
- divbin = divbins[bucket_index]
- modbin = modbins[bucket_index]
- self.AssertBinomialQuantity(divbin)
- self.AssertBinomialQuantity(modbin)
-
-
-if __name__ == '__main__':
- tf.test.main()
+ def setUp(self):
+ self.total = 1000000
+ self.samples = 10000
+ self.n_buckets = 100
+ self.total_per_bucket = self.total // self.n_buckets
+ self.assertEqual(
+ self.total % self.n_buckets,
+ 0,
+ "total must be evenly " "divisible by the number of buckets",
+ )
+ self.assertTrue(
+ self.total > self.samples, "need to have more items " "than samples"
+ )
+
+ def AssertBinomialQuantity(self, measured):
+ p = 1.0 * self.n_buckets / self.samples
+ mean = p * self.samples
+ variance = p * (1 - p) * self.samples
+ error = measured - mean
+ # Given that the buckets were actually binomially distributed, this
+ # fails with probability ~2E-9
+ passed = error * error <= 36.0 * variance
+ self.assertTrue(
+ passed,
+ "found a bucket with measured %d "
+ "too far from expected %d" % (measured, mean),
+ )
+
+ def testBucketReservoirSamplingViaStatisticalProperties(self):
+ # Not related to a 'ReservoirBucket', but instead number of buckets we put
+ # samples into for testing the shape of the distribution
+ b = reservoir._ReservoirBucket(_max_size=self.samples)
+ # add one extra item because we always keep the most recent item, which
+ # would skew the distribution; we can just slice it off the end instead.
+ for i in xrange(self.total + 1):
+ b.AddItem(i)
+
+ divbins = [0] * self.n_buckets
+ modbins = [0] * self.n_buckets
+ # Slice off the last item when we iterate.
+ for item in b.Items()[0:-1]:
+ divbins[item // self.total_per_bucket] += 1
+ modbins[item % self.n_buckets] += 1
+
+ for bucket_index in xrange(self.n_buckets):
+ divbin = divbins[bucket_index]
+ modbin = modbins[bucket_index]
+ self.AssertBinomialQuantity(divbin)
+ self.AssertBinomialQuantity(modbin)
+
+
+if __name__ == "__main__":
+ tf.test.main()
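
To make the tolerance in AssertBinomialQuantity concrete, here is the arithmetic it performs with the setUp values above (a worked example only, not additional test code). The squared-error check amounts to a six-standard-deviation band around the expected per-bucket count:

    samples, n_buckets = 10000, 100
    p = 1.0 * n_buckets / samples     # 0.01
    mean = p * samples                # 100.0 expected items per bucket
    variance = p * (1 - p) * samples  # 99.0
    bound = (36.0 * variance) ** 0.5  # ~59.7, i.e. 6 standard deviations
    # A bucket fails only if its count differs from 100 by more than ~60,
    # which for a true binomial happens with probability ~2e-9.
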
diff --git a/tensorboard/backend/event_processing/sqlite_writer.py b/tensorboard/backend/event_processing/sqlite_writer.py
index bed30f570a..7ac2f05c2f 100644
--- a/tensorboard/backend/event_processing/sqlite_writer.py
+++ b/tensorboard/backend/event_processing/sqlite_writer.py
@@ -34,184 +34,215 @@
# Struct bundling a tag with its SummaryMetadata and a list of values, each of
# which are a tuple of step, wall time (as a float), and a TensorProto.
-TagData = collections.namedtuple('TagData', ['tag', 'metadata', 'values'])
+TagData = collections.namedtuple("TagData", ["tag", "metadata", "values"])
class SqliteWriter(object):
- """Sends summary data to SQLite using python's sqlite3 module."""
+ """Sends summary data to SQLite using python's sqlite3 module."""
- def __init__(self, db_connection_provider):
- """Constructs a SqliteWriterEventSink.
+ def __init__(self, db_connection_provider):
+ """Constructs a SqliteWriterEventSink.
- Args:
- db_connection_provider: Provider function for creating a DB connection.
- """
- self._db = db_connection_provider()
+ Args:
+ db_connection_provider: Provider function for creating a DB connection.
+ """
+ self._db = db_connection_provider()
- def _make_blob(self, bytestring):
- """Helper to ensure SQLite treats the given data as a BLOB."""
- # Special-case python 2 pysqlite which uses buffers for BLOB.
- if sys.version_info[0] == 2:
- return buffer(bytestring) # noqa: F821 (undefined name)
- return bytestring
+ def _make_blob(self, bytestring):
+ """Helper to ensure SQLite treats the given data as a BLOB."""
+ # Special-case python 2 pysqlite which uses buffers for BLOB.
+ if sys.version_info[0] == 2:
+ return buffer(bytestring) # noqa: F821 (undefined name)
+ return bytestring
- def _create_id(self):
- """Returns a freshly created DB-wide unique ID."""
- cursor = self._db.cursor()
- cursor.execute('INSERT INTO Ids DEFAULT VALUES')
- return cursor.lastrowid
+ def _create_id(self):
+ """Returns a freshly created DB-wide unique ID."""
+ cursor = self._db.cursor()
+ cursor.execute("INSERT INTO Ids DEFAULT VALUES")
+ return cursor.lastrowid
- def _maybe_init_user(self):
- """Returns the ID for the current user, creating the row if needed."""
- user_name = os.environ.get('USER', '') or os.environ.get('USERNAME', '')
- cursor = self._db.cursor()
- cursor.execute('SELECT user_id FROM Users WHERE user_name = ?',
- (user_name,))
- row = cursor.fetchone()
- if row:
- return row[0]
- user_id = self._create_id()
- cursor.execute(
- """
+ def _maybe_init_user(self):
+ """Returns the ID for the current user, creating the row if needed."""
+ user_name = os.environ.get("USER", "") or os.environ.get("USERNAME", "")
+ cursor = self._db.cursor()
+ cursor.execute(
+ "SELECT user_id FROM Users WHERE user_name = ?", (user_name,)
+ )
+ row = cursor.fetchone()
+ if row:
+ return row[0]
+ user_id = self._create_id()
+ cursor.execute(
+ """
INSERT INTO USERS (user_id, user_name, inserted_time)
VALUES (?, ?, ?)
""",
- (user_id, user_name, time.time()))
- return user_id
+ (user_id, user_name, time.time()),
+ )
+ return user_id
- def _maybe_init_experiment(self, experiment_name):
- """Returns the ID for the given experiment, creating the row if needed.
+ def _maybe_init_experiment(self, experiment_name):
+ """Returns the ID for the given experiment, creating the row if needed.
- Args:
- experiment_name: name of experiment.
- """
- user_id = self._maybe_init_user()
- cursor = self._db.cursor()
- cursor.execute(
+ Args:
+ experiment_name: name of experiment.
"""
+ user_id = self._maybe_init_user()
+ cursor = self._db.cursor()
+ cursor.execute(
+ """
SELECT experiment_id FROM Experiments
WHERE user_id = ? AND experiment_name = ?
""",
- (user_id, experiment_name))
- row = cursor.fetchone()
- if row:
- return row[0]
- experiment_id = self._create_id()
- # TODO: track computed time from run start times
- computed_time = 0
- cursor.execute(
- """
+ (user_id, experiment_name),
+ )
+ row = cursor.fetchone()
+ if row:
+ return row[0]
+ experiment_id = self._create_id()
+ # TODO: track computed time from run start times
+ computed_time = 0
+ cursor.execute(
+ """
INSERT INTO Experiments (
user_id, experiment_id, experiment_name,
inserted_time, started_time, is_watching
) VALUES (?, ?, ?, ?, ?, ?)
""",
- (user_id, experiment_id, experiment_name, time.time(), computed_time,
- False))
- return experiment_id
+ (
+ user_id,
+ experiment_id,
+ experiment_name,
+ time.time(),
+ computed_time,
+ False,
+ ),
+ )
+ return experiment_id
- def _maybe_init_run(self, experiment_name, run_name):
- """Returns the ID for the given run, creating the row if needed.
+ def _maybe_init_run(self, experiment_name, run_name):
+ """Returns the ID for the given run, creating the row if needed.
- Args:
- experiment_name: name of experiment containing this run.
- run_name: name of run.
- """
- experiment_id = self._maybe_init_experiment(experiment_name)
- cursor = self._db.cursor()
- cursor.execute(
+ Args:
+ experiment_name: name of experiment containing this run.
+ run_name: name of run.
"""
+ experiment_id = self._maybe_init_experiment(experiment_name)
+ cursor = self._db.cursor()
+ cursor.execute(
+ """
SELECT run_id FROM Runs
WHERE experiment_id = ? AND run_name = ?
""",
- (experiment_id, run_name))
- row = cursor.fetchone()
- if row:
- return row[0]
- run_id = self._create_id()
- # TODO: track actual run start times
- started_time = 0
- cursor.execute(
- """
+ (experiment_id, run_name),
+ )
+ row = cursor.fetchone()
+ if row:
+ return row[0]
+ run_id = self._create_id()
+ # TODO: track actual run start times
+ started_time = 0
+ cursor.execute(
+ """
INSERT INTO Runs (
experiment_id, run_id, run_name, inserted_time, started_time
) VALUES (?, ?, ?, ?, ?)
""",
- (experiment_id, run_id, run_name, time.time(), started_time))
- return run_id
+ (experiment_id, run_id, run_name, time.time(), started_time),
+ )
+ return run_id
- def _maybe_init_tags(self, run_id, tag_to_metadata):
- """Returns a tag-to-ID map for the given tags, creating rows if needed.
+ def _maybe_init_tags(self, run_id, tag_to_metadata):
+ """Returns a tag-to-ID map for the given tags, creating rows if needed.
- Args:
- run_id: the ID of the run to which these tags belong.
- tag_to_metadata: map of tag name to SummaryMetadata for the tag.
- """
- cursor = self._db.cursor()
- # TODO: for huge numbers of tags (e.g. 1000+), this is slower than just
- # querying for the known tag names explicitly; find a better tradeoff.
- cursor.execute('SELECT tag_name, tag_id FROM Tags WHERE run_id = ?',
- (run_id,))
- tag_to_id = {row[0]: row[1] for row in cursor.fetchall()
- if row[0] in tag_to_metadata}
- new_tag_data = []
- for tag, metadata in six.iteritems(tag_to_metadata):
- if tag not in tag_to_id:
- tag_id = self._create_id()
- tag_to_id[tag] = tag_id
- new_tag_data.append((run_id, tag_id, tag, time.time(),
- metadata.display_name,
- metadata.plugin_data.plugin_name,
- self._make_blob(metadata.plugin_data.content)))
- cursor.executemany(
+ Args:
+ run_id: the ID of the run to which these tags belong.
+ tag_to_metadata: map of tag name to SummaryMetadata for the tag.
"""
+ cursor = self._db.cursor()
+ # TODO: for huge numbers of tags (e.g. 1000+), this is slower than just
+ # querying for the known tag names explicitly; find a better tradeoff.
+ cursor.execute(
+ "SELECT tag_name, tag_id FROM Tags WHERE run_id = ?", (run_id,)
+ )
+ tag_to_id = {
+ row[0]: row[1]
+ for row in cursor.fetchall()
+ if row[0] in tag_to_metadata
+ }
+ new_tag_data = []
+ for tag, metadata in six.iteritems(tag_to_metadata):
+ if tag not in tag_to_id:
+ tag_id = self._create_id()
+ tag_to_id[tag] = tag_id
+ new_tag_data.append(
+ (
+ run_id,
+ tag_id,
+ tag,
+ time.time(),
+ metadata.display_name,
+ metadata.plugin_data.plugin_name,
+ self._make_blob(metadata.plugin_data.content),
+ )
+ )
+ cursor.executemany(
+ """
INSERT INTO Tags (
run_id, tag_id, tag_name, inserted_time, display_name, plugin_name,
plugin_data
) VALUES (?, ?, ?, ?, ?, ?, ?)
""",
- new_tag_data)
- return tag_to_id
+ new_tag_data,
+ )
+ return tag_to_id
- def write_summaries(self, tagged_data, experiment_name, run_name):
- """Transactionally writes the given tagged summary data to the DB.
+ def write_summaries(self, tagged_data, experiment_name, run_name):
+ """Transactionally writes the given tagged summary data to the DB.
- Args:
- tagged_data: map from tag to TagData instances.
- experiment_name: name of experiment.
- run_name: name of run.
- """
- logger.debug('Writing summaries for %s tags', len(tagged_data))
- # Connection used as context manager for auto commit/rollback on exit.
- # We still need an explicit BEGIN, because it doesn't do one on enter,
- # it waits until the first DML command - which is totally broken.
- # See: https://stackoverflow.com/a/44448465/1179226
- with self._db:
- self._db.execute('BEGIN TRANSACTION')
- run_id = self._maybe_init_run(experiment_name, run_name)
- tag_to_metadata = {
- tag: tagdata.metadata for tag, tagdata in six.iteritems(tagged_data)
- }
- tag_to_id = self._maybe_init_tags(run_id, tag_to_metadata)
- tensor_values = []
- for tag, tagdata in six.iteritems(tagged_data):
- tag_id = tag_to_id[tag]
- for step, wall_time, tensor_proto in tagdata.values:
- dtype = tensor_proto.dtype
- shape = ','.join(str(d.size) for d in tensor_proto.tensor_shape.dim)
- # Use tensor_proto.tensor_content if it's set, to skip relatively
- # expensive extraction into intermediate ndarray.
- data = self._make_blob(
- tensor_proto.tensor_content or
- tensor_util.make_ndarray(tensor_proto).tobytes())
- tensor_values.append((tag_id, step, wall_time, dtype, shape, data))
- self._db.executemany(
- """
+ Args:
+ tagged_data: map from tag to TagData instances.
+ experiment_name: name of experiment.
+ run_name: name of run.
+ """
+ logger.debug("Writing summaries for %s tags", len(tagged_data))
+ # Connection used as context manager for auto commit/rollback on exit.
+ # We still need an explicit BEGIN, because it doesn't do one on enter,
+ # it waits until the first DML command - which is totally broken.
+ # See: https://stackoverflow.com/a/44448465/1179226
+ with self._db:
+ self._db.execute("BEGIN TRANSACTION")
+ run_id = self._maybe_init_run(experiment_name, run_name)
+ tag_to_metadata = {
+ tag: tagdata.metadata
+ for tag, tagdata in six.iteritems(tagged_data)
+ }
+ tag_to_id = self._maybe_init_tags(run_id, tag_to_metadata)
+ tensor_values = []
+ for tag, tagdata in six.iteritems(tagged_data):
+ tag_id = tag_to_id[tag]
+ for step, wall_time, tensor_proto in tagdata.values:
+ dtype = tensor_proto.dtype
+ shape = ",".join(
+ str(d.size) for d in tensor_proto.tensor_shape.dim
+ )
+ # Use tensor_proto.tensor_content if it's set, to skip relatively
+ # expensive extraction into intermediate ndarray.
+ data = self._make_blob(
+ tensor_proto.tensor_content
+ or tensor_util.make_ndarray(tensor_proto).tobytes()
+ )
+ tensor_values.append(
+ (tag_id, step, wall_time, dtype, shape, data)
+ )
+ self._db.executemany(
+ """
INSERT OR REPLACE INTO Tensors (
series, step, computed_time, dtype, shape, data
) VALUES (?, ?, ?, ?, ?, ?)
""",
- tensor_values)
+ tensor_values,
+ )
# See tensorflow/contrib/tensorboard/db/schema.cc for documentation.
@@ -414,17 +445,19 @@ def write_summaries(self, tagged_data, experiment_name, run_name):
def initialize_schema(connection):
- """Initializes the TensorBoard sqlite schema using the given connection.
+ """Initializes the TensorBoard sqlite schema using the given connection.
- Args:
- connection: A sqlite DB connection.
- """
- cursor = connection.cursor()
- cursor.execute("PRAGMA application_id={}".format(_TENSORBOARD_APPLICATION_ID))
- cursor.execute("PRAGMA user_version={}".format(_TENSORBOARD_USER_VERSION))
- with connection:
- for statement in _SCHEMA_STATEMENTS:
- lines = statement.strip('\n').split('\n')
- message = lines[0] + ('...' if len(lines) > 1 else '')
- logger.debug('Running DB init statement: %s', message)
- cursor.execute(statement)
+ Args:
+ connection: A sqlite DB connection.
+ """
+ cursor = connection.cursor()
+ cursor.execute(
+ "PRAGMA application_id={}".format(_TENSORBOARD_APPLICATION_ID)
+ )
+ cursor.execute("PRAGMA user_version={}".format(_TENSORBOARD_USER_VERSION))
+ with connection:
+ for statement in _SCHEMA_STATEMENTS:
+ lines = statement.strip("\n").split("\n")
+ message = lines[0] + ("..." if len(lines) > 1 else "")
+ logger.debug("Running DB init statement: %s", message)
+ cursor.execute(statement)
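
The explicit-BEGIN comment in write_summaries above is easy to miss in the diff, so here is a tiny standalone illustration of that sqlite3 behavior. The table and connection setup are simplified stand-ins, and opening the connection in autocommit mode (isolation_level=None) is an assumption about how the connection provider is configured:

    import sqlite3

    # Assumed autocommit mode; this is what makes the explicit BEGIN necessary.
    conn = sqlite3.connect(":memory:", isolation_level=None)
    conn.execute("CREATE TABLE Tensors (series INTEGER, step INTEGER)")

    # The connection context manager commits or rolls back on exit, but it does
    # not issue BEGIN on enter, so the batch is started explicitly, the same
    # pattern write_summaries uses before its executemany calls.
    with conn:
        conn.execute("BEGIN TRANSACTION")
        conn.executemany(
            "INSERT INTO Tensors (series, step) VALUES (?, ?)",
            [(1, 0), (1, 1), (1, 2)],
        )
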
diff --git a/tensorboard/backend/experiment_id.py b/tensorboard/backend/experiment_id.py
index 7c427abe32..e695f26638 100644
--- a/tensorboard/backend/experiment_id.py
+++ b/tensorboard/backend/experiment_id.py
@@ -30,42 +30,42 @@
class ExperimentIdMiddleware(object):
- """WSGI middleware extracting experiment IDs from URL to environment.
+ """WSGI middleware extracting experiment IDs from URL to environment.
- Any request whose path matches `/experiment/SOME_EID[/...]` will have
- its first two path components stripped, and its experiment ID stored
- onto the WSGI environment with key taken from the `WSGI_ENVIRON_KEY`
- constant. All other requests will have paths unchanged and the
- experiment ID set to the empty string.
+ Any request whose path matches `/experiment/SOME_EID[/...]` will have
+ its first two path components stripped, and its experiment ID stored
+ onto the WSGI environment with key taken from the `WSGI_ENVIRON_KEY`
+ constant. All other requests will have paths unchanged and the
+ experiment ID set to the empty string.
- Instances of this class are WSGI applications (see PEP 3333).
- """
+ Instances of this class are WSGI applications (see PEP 3333).
+ """
- def __init__(self, application):
- """Initializes an `ExperimentIdMiddleware`.
+ def __init__(self, application):
+ """Initializes an `ExperimentIdMiddleware`.
- Args:
- application: The WSGI application to wrap (see PEP 3333).
- """
- self._application = application
- # Regular expression that matches the whole `/experiment/EID` prefix
- # (without any trailing slash) and captures the experiment ID.
- self._pat = re.compile(
- r"/%s/([^/]*)" % re.escape(_EXPERIMENT_PATH_COMPONENT)
- )
+ Args:
+ application: The WSGI application to wrap (see PEP 3333).
+ """
+ self._application = application
+ # Regular expression that matches the whole `/experiment/EID` prefix
+ # (without any trailing slash) and captures the experiment ID.
+ self._pat = re.compile(
+ r"/%s/([^/]*)" % re.escape(_EXPERIMENT_PATH_COMPONENT)
+ )
- def __call__(self, environ, start_response):
- path = environ.get("PATH_INFO", "")
- m = self._pat.match(path)
- if m:
- eid = m.group(1)
- new_path = path[m.end(0):]
- root = m.group(0)
- else:
- eid = ""
- new_path = path
- root = ""
- environ[WSGI_ENVIRON_KEY] = eid
- environ["PATH_INFO"] = new_path
- environ["SCRIPT_NAME"] = environ.get("SCRIPT_NAME", "") + root
- return self._application(environ, start_response)
+ def __call__(self, environ, start_response):
+ path = environ.get("PATH_INFO", "")
+ m = self._pat.match(path)
+ if m:
+ eid = m.group(1)
+ new_path = path[m.end(0) :]
+ root = m.group(0)
+ else:
+ eid = ""
+ new_path = path
+ root = ""
+ environ[WSGI_ENVIRON_KEY] = eid
+ environ["PATH_INFO"] = new_path
+ environ["SCRIPT_NAME"] = environ.get("SCRIPT_NAME", "") + root
+ return self._application(environ, start_response)
diff --git a/tensorboard/backend/experiment_id_test.py b/tensorboard/backend/experiment_id_test.py
index 698d6d4559..5ebc7c9091 100644
--- a/tensorboard/backend/experiment_id_test.py
+++ b/tensorboard/backend/experiment_id_test.py
@@ -28,66 +28,68 @@
class ExperimentIdMiddlewareTest(tb_test.TestCase):
- """Tests for `ExperimentIdMiddleware`."""
-
- def setUp(self):
- super(ExperimentIdMiddlewareTest, self).setUp()
- self.app = experiment_id.ExperimentIdMiddleware(self._echo_app)
- self.server = werkzeug_test.Client(self.app, werkzeug.BaseResponse)
-
- def _echo_app(self, environ, start_response):
- # https://www.python.org/dev/peps/pep-0333/#environ-variables
- data = {
- "eid": environ[experiment_id.WSGI_ENVIRON_KEY],
- "path": environ.get("PATH_INFO", ""),
- "script": environ.get("SCRIPT_NAME", ""),
- }
- body = json.dumps(data, sort_keys=True)
- start_response("200 OK", [("Content-Type", "application/json")])
- return [body]
-
- def _assert_ok(self, response, eid, path, script):
- self.assertEqual(response.status_code, 200)
- actual = json.loads(response.get_data())
- expected = dict(eid=eid, path=path, script=script)
- self.assertEqual(actual, expected)
-
- def test_no_experiment_empty_path(self):
- response = self.server.get("")
- self._assert_ok(response, eid="", path="", script="")
-
- def test_no_experiment_root_path(self):
- response = self.server.get("/")
- self._assert_ok(response, eid="", path="/", script="")
-
- def test_no_experiment_sub_path(self):
- response = self.server.get("/x/y")
- self._assert_ok(response, eid="", path="/x/y", script="")
-
- def test_with_experiment_empty_path(self):
- response = self.server.get("/experiment/123")
- self._assert_ok(response, eid="123", path="", script="/experiment/123")
-
- def test_with_experiment_root_path(self):
- response = self.server.get("/experiment/123/")
- self._assert_ok(response, eid="123", path="/", script="/experiment/123")
-
- def test_with_experiment_sub_path(self):
- response = self.server.get("/experiment/123/x/y")
- self._assert_ok(response, eid="123", path="/x/y", script="/experiment/123")
-
- def test_with_empty_experiment_empty_path(self):
- response = self.server.get("/experiment/")
- self._assert_ok(response, eid="", path="", script="/experiment/")
-
- def test_with_empty_experiment_root_path(self):
- response = self.server.get("/experiment//")
- self._assert_ok(response, eid="", path="/", script="/experiment/")
-
- def test_with_empty_experiment_sub_path(self):
- response = self.server.get("/experiment//x/y")
- self._assert_ok(response, eid="", path="/x/y", script="/experiment/")
+ """Tests for `ExperimentIdMiddleware`."""
+
+ def setUp(self):
+ super(ExperimentIdMiddlewareTest, self).setUp()
+ self.app = experiment_id.ExperimentIdMiddleware(self._echo_app)
+ self.server = werkzeug_test.Client(self.app, werkzeug.BaseResponse)
+
+ def _echo_app(self, environ, start_response):
+ # https://www.python.org/dev/peps/pep-0333/#environ-variables
+ data = {
+ "eid": environ[experiment_id.WSGI_ENVIRON_KEY],
+ "path": environ.get("PATH_INFO", ""),
+ "script": environ.get("SCRIPT_NAME", ""),
+ }
+ body = json.dumps(data, sort_keys=True)
+ start_response("200 OK", [("Content-Type", "application/json")])
+ return [body]
+
+ def _assert_ok(self, response, eid, path, script):
+ self.assertEqual(response.status_code, 200)
+ actual = json.loads(response.get_data())
+ expected = dict(eid=eid, path=path, script=script)
+ self.assertEqual(actual, expected)
+
+ def test_no_experiment_empty_path(self):
+ response = self.server.get("")
+ self._assert_ok(response, eid="", path="", script="")
+
+ def test_no_experiment_root_path(self):
+ response = self.server.get("/")
+ self._assert_ok(response, eid="", path="/", script="")
+
+ def test_no_experiment_sub_path(self):
+ response = self.server.get("/x/y")
+ self._assert_ok(response, eid="", path="/x/y", script="")
+
+ def test_with_experiment_empty_path(self):
+ response = self.server.get("/experiment/123")
+ self._assert_ok(response, eid="123", path="", script="/experiment/123")
+
+ def test_with_experiment_root_path(self):
+ response = self.server.get("/experiment/123/")
+ self._assert_ok(response, eid="123", path="/", script="/experiment/123")
+
+ def test_with_experiment_sub_path(self):
+ response = self.server.get("/experiment/123/x/y")
+ self._assert_ok(
+ response, eid="123", path="/x/y", script="/experiment/123"
+ )
+
+ def test_with_empty_experiment_empty_path(self):
+ response = self.server.get("/experiment/")
+ self._assert_ok(response, eid="", path="", script="/experiment/")
+
+ def test_with_empty_experiment_root_path(self):
+ response = self.server.get("/experiment//")
+ self._assert_ok(response, eid="", path="/", script="/experiment/")
+
+ def test_with_empty_experiment_sub_path(self):
+ response = self.server.get("/experiment//x/y")
+ self._assert_ok(response, eid="", path="/x/y", script="/experiment/")
if __name__ == "__main__":
- tb_test.main()
+ tb_test.main()
diff --git a/tensorboard/backend/http_util.py b/tensorboard/backend/http_util.py
index ce131b3f07..1fb9d2c675 100644
--- a/tensorboard/backend/http_util.py
+++ b/tensorboard/backend/http_util.py
@@ -27,14 +27,16 @@
import wsgiref.handlers
import six
-from six.moves.urllib import parse as urlparse # pylint: disable=wrong-import-order
+from six.moves.urllib import (
+ parse as urlparse,
+) # pylint: disable=wrong-import-order
import werkzeug
from tensorboard.backend import json_util
from tensorboard.compat import tf
-_DISALLOWED_CHAR_IN_DOMAIN = re.compile(r'\s')
+_DISALLOWED_CHAR_IN_DOMAIN = re.compile(r"\s")
# TODO(stephanwlee): Refactor this to not use the module variable but
# instead use a configurable via some kind of assets provider which would
@@ -48,204 +50,229 @@
_CSP_SCRIPT_UNSAFE_EVAL = True
_CSP_STYLE_DOMAINS_WHITELIST = []
-_EXTRACT_MIMETYPE_PATTERN = re.compile(r'^[^;\s]*')
-_EXTRACT_CHARSET_PATTERN = re.compile(r'charset=([-_0-9A-Za-z]+)')
+_EXTRACT_MIMETYPE_PATTERN = re.compile(r"^[^;\s]*")
+_EXTRACT_CHARSET_PATTERN = re.compile(r"charset=([-_0-9A-Za-z]+)")
# Allows *, gzip or x-gzip, but forbid gzip;q=0
# https://tools.ietf.org/html/rfc7231#section-5.3.4
_ALLOWS_GZIP_PATTERN = re.compile(
- r'(?:^|,|\s)(?:(?:x-)?gzip|\*)(?!;q=0)(?:\s|,|$)')
-
-_TEXTUAL_MIMETYPES = set([
- 'application/javascript',
- 'application/json',
- 'application/json+protobuf',
- 'image/svg+xml',
- 'text/css',
- 'text/csv',
- 'text/html',
- 'text/plain',
- 'text/tab-separated-values',
- 'text/x-protobuf',
-])
-
-_JSON_MIMETYPES = set([
- 'application/json',
- 'application/json+protobuf',
-])
+ r"(?:^|,|\s)(?:(?:x-)?gzip|\*)(?!;q=0)(?:\s|,|$)"
+)
-# Do not support xhtml for now.
-_HTML_MIMETYPE = 'text/html'
-
-def Respond(request,
- content,
- content_type,
- code=200,
- expires=0,
- content_encoding=None,
- encoding='utf-8',
- csp_scripts_sha256s=None):
- """Construct a werkzeug Response.
-
- Responses are transmitted to the browser with compression if: a) the browser
- supports it; b) it's sane to compress the content_type in question; and c)
- the content isn't already compressed, as indicated by the content_encoding
- parameter.
-
- Browser and proxy caching is completely disabled by default. If the expires
- parameter is greater than zero then the response will be able to be cached by
- the browser for that many seconds; however, proxies are still forbidden from
- caching so that developers can bypass the cache with Ctrl+Shift+R.
-
- For textual content that isn't JSON, the encoding parameter is used as the
- transmission charset which is automatically appended to the Content-Type
- header. That is unless of course the content_type parameter contains a
- charset parameter. If the two disagree, the characters in content will be
- transcoded to the latter.
-
- If content_type declares a JSON media type, then content MAY be a dict, list,
- tuple, or set, in which case this function has an implicit composition with
- json_util.Cleanse and json.dumps. The encoding parameter is used to decode
- byte strings within the JSON object; therefore transmitting binary data
- within JSON is not permitted. JSON is transmitted as ASCII unless the
- content_type parameter explicitly defines a charset parameter, in which case
- the serialized JSON bytes will use that instead of escape sequences.
-
- Args:
- request: A werkzeug Request object. Used mostly to check the
- Accept-Encoding header.
- content: Payload data as byte string, unicode string, or maybe JSON.
- content_type: Media type and optionally an output charset.
- code: Numeric HTTP status code to use.
- expires: Second duration for browser caching.
- content_encoding: Encoding if content is already encoded, e.g. 'gzip'.
- encoding: Input charset if content parameter has byte strings.
- csp_scripts_sha256s: List of base64 serialized sha256 of whitelisted script
- elements for script-src of the Content-Security-Policy. If it is None, the
- HTML will disallow any script to execute. It is only be used when the
- content_type is text/html.
-
- Returns:
- A werkzeug Response object (a WSGI application).
- """
-
- mimetype = _EXTRACT_MIMETYPE_PATTERN.search(content_type).group(0)
- charset_match = _EXTRACT_CHARSET_PATTERN.search(content_type)
- charset = charset_match.group(1) if charset_match else encoding
- textual = charset_match or mimetype in _TEXTUAL_MIMETYPES
- if (mimetype in _JSON_MIMETYPES and
- isinstance(content, (dict, list, set, tuple))):
- content = json.dumps(json_util.Cleanse(content, encoding),
- ensure_ascii=not charset_match)
- if charset != encoding:
- content = tf.compat.as_text(content, encoding)
- content = tf.compat.as_bytes(content, charset)
- if textual and not charset_match and mimetype not in _JSON_MIMETYPES:
- content_type += '; charset=' + charset
- gzip_accepted = _ALLOWS_GZIP_PATTERN.search(
- request.headers.get('Accept-Encoding', ''))
- # Automatically gzip uncompressed text data if accepted.
- if textual and not content_encoding and gzip_accepted:
- out = six.BytesIO()
- # Set mtime to zero to make payload for a given input deterministic.
- with gzip.GzipFile(fileobj=out, mode='wb', compresslevel=3, mtime=0) as f:
- f.write(content)
- content = out.getvalue()
- content_encoding = 'gzip'
-
- content_length = len(content)
- direct_passthrough = False
- # Automatically streamwise-gunzip precompressed data if not accepted.
- if content_encoding == 'gzip' and not gzip_accepted:
- gzip_file = gzip.GzipFile(fileobj=six.BytesIO(content), mode='rb')
- # Last 4 bytes of gzip formatted data (little-endian) store the original
- # content length mod 2^32; we just assume it's the content length. That
- # means we can't streamwise-gunzip >4 GB precompressed file; this is ok.
- content_length = struct.unpack('<I', content[-4:])[0]
- content = werkzeug.wsgi.wrap_file(request.environ, gzip_file)
- content_encoding = None
- direct_passthrough = True
-
- headers = []
- headers.append(('Content-Length', str(content_length)))
- if content_encoding:
- headers.append(('Content-Encoding', content_encoding))
- if expires > 0:
- e = wsgiref.handlers.format_date_time(time.time() + float(expires))
- headers.append(('Expires', e))
- headers.append(('Cache-Control', 'private, max-age=%d' % expires))
- else:
- headers.append(('Expires', '0'))
- headers.append(('Cache-Control', 'no-cache, must-revalidate'))
- if mimetype == _HTML_MIMETYPE:
- _validate_global_whitelist(_CSP_IMG_DOMAINS_WHITELIST)
- _validate_global_whitelist(_CSP_STYLE_DOMAINS_WHITELIST)
- _validate_global_whitelist(_CSP_FONT_DOMAINS_WHITELIST)
- _validate_global_whitelist(_CSP_FRAME_DOMAINS_WHITELIST)
- _validate_global_whitelist(_CSP_SCRIPT_DOMAINS_WHITELIST)
-
- frags = _CSP_SCRIPT_DOMAINS_WHITELIST + [
- "'self'" if _CSP_SCRIPT_SELF else '',
- "'unsafe-eval'" if _CSP_SCRIPT_UNSAFE_EVAL else '',
- ] + [
- "'sha256-{}'".format(sha256) for sha256 in (csp_scripts_sha256s or [])
+_TEXTUAL_MIMETYPES = set(
+ [
+ "application/javascript",
+ "application/json",
+ "application/json+protobuf",
+ "image/svg+xml",
+ "text/css",
+ "text/csv",
+ "text/html",
+ "text/plain",
+ "text/tab-separated-values",
+ "text/x-protobuf",
]
- script_srcs = _create_csp_string(*frags)
-
- csp_string = ';'.join([
- "default-src 'self'",
- 'font-src %s' % _create_csp_string(
- "'self'",
- *_CSP_FONT_DOMAINS_WHITELIST
- ),
- 'frame-ancestors *',
- # Dynamic plugins are rendered inside an iframe.
- 'frame-src %s' % _create_csp_string(
- "'self'",
- *_CSP_FRAME_DOMAINS_WHITELIST
- ),
- 'img-src %s' % _create_csp_string(
- "'self'",
- # used by favicon
- 'data:',
- # used by What-If tool for image sprites.
- 'blob:',
- *_CSP_IMG_DOMAINS_WHITELIST
- ),
- "object-src 'none'",
- 'style-src %s' % _create_csp_string(
- "'self'",
- # used by google-chart
- 'https://www.gstatic.com',
- 'data:',
- # inline styles: Polymer templates + d3 uses inline styles.
- "'unsafe-inline'",
- *_CSP_STYLE_DOMAINS_WHITELIST
- ),
- "script-src %s" % script_srcs,
- ])
-
- headers.append(('Content-Security-Policy', csp_string))
-
- if request.method == 'HEAD':
- content = None
-
- return werkzeug.wrappers.Response(
- response=content, status=code, headers=headers, content_type=content_type,
- direct_passthrough=direct_passthrough)
+)
+
+_JSON_MIMETYPES = set(["application/json", "application/json+protobuf",])
+
+# Do not support xhtml for now.
+_HTML_MIMETYPE = "text/html"
+
+
+def Respond(
+ request,
+ content,
+ content_type,
+ code=200,
+ expires=0,
+ content_encoding=None,
+ encoding="utf-8",
+ csp_scripts_sha256s=None,
+):
+ """Construct a werkzeug Response.
+
+ Responses are transmitted to the browser with compression if: a) the browser
+ supports it; b) it's sane to compress the content_type in question; and c)
+ the content isn't already compressed, as indicated by the content_encoding
+ parameter.
+
+ Browser and proxy caching is completely disabled by default. If the expires
+ parameter is greater than zero then the response will be able to be cached by
+ the browser for that many seconds; however, proxies are still forbidden from
+ caching so that developers can bypass the cache with Ctrl+Shift+R.
+
+ For textual content that isn't JSON, the encoding parameter is used as the
+ transmission charset which is automatically appended to the Content-Type
+ header. That is unless of course the content_type parameter contains a
+ charset parameter. If the two disagree, the characters in content will be
+ transcoded to the latter.
+
+ If content_type declares a JSON media type, then content MAY be a dict, list,
+ tuple, or set, in which case this function has an implicit composition with
+ json_util.Cleanse and json.dumps. The encoding parameter is used to decode
+ byte strings within the JSON object; therefore transmitting binary data
+ within JSON is not permitted. JSON is transmitted as ASCII unless the
+ content_type parameter explicitly defines a charset parameter, in which case
+ the serialized JSON bytes will use that instead of escape sequences.
+
+ Args:
+ request: A werkzeug Request object. Used mostly to check the
+ Accept-Encoding header.
+ content: Payload data as byte string, unicode string, or maybe JSON.
+ content_type: Media type and optionally an output charset.
+ code: Numeric HTTP status code to use.
+ expires: Second duration for browser caching.
+ content_encoding: Encoding if content is already encoded, e.g. 'gzip'.
+ encoding: Input charset if content parameter has byte strings.
+ csp_scripts_sha256s: List of base64 serialized sha256 of whitelisted script
+ elements for script-src of the Content-Security-Policy. If it is None, the
+ HTML will disallow any script from executing. It is only used when the
+ content_type is text/html.
+
+ Returns:
+ A werkzeug Response object (a WSGI application).
+ """
+
+ mimetype = _EXTRACT_MIMETYPE_PATTERN.search(content_type).group(0)
+ charset_match = _EXTRACT_CHARSET_PATTERN.search(content_type)
+ charset = charset_match.group(1) if charset_match else encoding
+ textual = charset_match or mimetype in _TEXTUAL_MIMETYPES
+ if mimetype in _JSON_MIMETYPES and isinstance(
+ content, (dict, list, set, tuple)
+ ):
+ content = json.dumps(
+ json_util.Cleanse(content, encoding), ensure_ascii=not charset_match
+ )
+ if charset != encoding:
+ content = tf.compat.as_text(content, encoding)
+ content = tf.compat.as_bytes(content, charset)
+ if textual and not charset_match and mimetype not in _JSON_MIMETYPES:
+ content_type += "; charset=" + charset
+ gzip_accepted = _ALLOWS_GZIP_PATTERN.search(
+ request.headers.get("Accept-Encoding", "")
+ )
+ # Automatically gzip uncompressed text data if accepted.
+ if textual and not content_encoding and gzip_accepted:
+ out = six.BytesIO()
+ # Set mtime to zero to make payload for a given input deterministic.
+ with gzip.GzipFile(
+ fileobj=out, mode="wb", compresslevel=3, mtime=0
+ ) as f:
+ f.write(content)
+ content = out.getvalue()
+ content_encoding = "gzip"
+
+ content_length = len(content)
+ direct_passthrough = False
+ # Automatically streamwise-gunzip precompressed data if not accepted.
+ if content_encoding == "gzip" and not gzip_accepted:
+ gzip_file = gzip.GzipFile(fileobj=six.BytesIO(content), mode="rb")
+ # Last 4 bytes of gzip formatted data (little-endian) store the original
+ # content length mod 2^32; we just assume it's the content length. That
+ # means we can't streamwise-gunzip >4 GB precompressed file; this is ok.
+ content_length = struct.unpack(" 0:
+ e = wsgiref.handlers.format_date_time(time.time() + float(expires))
+ headers.append(("Expires", e))
+ headers.append(("Cache-Control", "private, max-age=%d" % expires))
+ else:
+ headers.append(("Expires", "0"))
+ headers.append(("Cache-Control", "no-cache, must-revalidate"))
+ if mimetype == _HTML_MIMETYPE:
+ _validate_global_whitelist(_CSP_IMG_DOMAINS_WHITELIST)
+ _validate_global_whitelist(_CSP_STYLE_DOMAINS_WHITELIST)
+ _validate_global_whitelist(_CSP_FONT_DOMAINS_WHITELIST)
+ _validate_global_whitelist(_CSP_FRAME_DOMAINS_WHITELIST)
+ _validate_global_whitelist(_CSP_SCRIPT_DOMAINS_WHITELIST)
+
+ frags = (
+ _CSP_SCRIPT_DOMAINS_WHITELIST
+ + [
+ "'self'" if _CSP_SCRIPT_SELF else "",
+ "'unsafe-eval'" if _CSP_SCRIPT_UNSAFE_EVAL else "",
+ ]
+ + [
+ "'sha256-{}'".format(sha256)
+ for sha256 in (csp_scripts_sha256s or [])
+ ]
+ )
+ script_srcs = _create_csp_string(*frags)
+
+ csp_string = ";".join(
+ [
+ "default-src 'self'",
+ "font-src %s"
+ % _create_csp_string("'self'", *_CSP_FONT_DOMAINS_WHITELIST),
+ "frame-ancestors *",
+ # Dynamic plugins are rendered inside an iframe.
+ "frame-src %s"
+ % _create_csp_string("'self'", *_CSP_FRAME_DOMAINS_WHITELIST),
+ "img-src %s"
+ % _create_csp_string(
+ "'self'",
+ # used by favicon
+ "data:",
+ # used by What-If tool for image sprites.
+ "blob:",
+ *_CSP_IMG_DOMAINS_WHITELIST
+ ),
+ "object-src 'none'",
+ "style-src %s"
+ % _create_csp_string(
+ "'self'",
+ # used by google-chart
+ "https://www.gstatic.com",
+ "data:",
+ # inline styles: Polymer templates + d3 uses inline styles.
+ "'unsafe-inline'",
+ *_CSP_STYLE_DOMAINS_WHITELIST
+ ),
+ "script-src %s" % script_srcs,
+ ]
+ )
+
+ headers.append(("Content-Security-Policy", csp_string))
+
+ if request.method == "HEAD":
+ content = None
+
+ return werkzeug.wrappers.Response(
+ response=content,
+ status=code,
+ headers=headers,
+ content_type=content_type,
+ direct_passthrough=direct_passthrough,
+ )
+
def _validate_global_whitelist(whitelists):
- for domain in whitelists:
- url = urlparse.urlparse(domain)
- if not url.scheme == 'https' or not url.netloc:
- raise ValueError('Expected all whitelist to be a https URL: %r' % domain)
- if ';' in domain:
- raise ValueError('Expected whitelist domain to not contain ";": %r' % domain)
- if _DISALLOWED_CHAR_IN_DOMAIN.search(domain):
- raise ValueError(
- 'Expected whitelist domain to not contain a whitespace: %r' % domain)
+ for domain in whitelists:
+ url = urlparse.urlparse(domain)
+ if not url.scheme == "https" or not url.netloc:
+ raise ValueError(
+ "Expected all whitelist to be a https URL: %r" % domain
+ )
+ if ";" in domain:
+ raise ValueError(
+ 'Expected whitelist domain to not contain ";": %r' % domain
+ )
+ if _DISALLOWED_CHAR_IN_DOMAIN.search(domain):
+ raise ValueError(
+ "Expected whitelist domain to not contain a whitespace: %r"
+ % domain
+ )
+
def _create_csp_string(*csp_fragments):
- csp_string = ' '.join([frag for frag in csp_fragments if frag])
- return csp_string if csp_string else "'none'"
+ csp_string = " ".join([frag for frag in csp_fragments if frag])
+ return csp_string if csp_string else "'none'"
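
Two gzip details that Respond relies on above are worth seeing in isolation: compressing with mtime=0 keeps the payload deterministic for a given input, and the last four bytes of a gzip stream (the ISIZE trailer, little-endian) store the uncompressed length mod 2^32, which is what the streamwise-gunzip branch unpacks to recover Content-Length. A small standalone check, not part of the module:

    import gzip
    import struct

    import six

    payload = b"hello hello hello world"
    out = six.BytesIO()
    # mtime=0 gives byte-for-byte identical output for identical input.
    with gzip.GzipFile(fileobj=out, mode="wb", compresslevel=3, mtime=0) as f:
        f.write(payload)
    compressed = out.getvalue()

    # ISIZE trailer: uncompressed length mod 2**32, little-endian.
    assert struct.unpack("<I", compressed[-4:])[0] == len(payload)
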
diff --git a/tensorboard/backend/http_util_test.py b/tensorboard/backend/http_util_test.py
index d165450196..de4e31b611 100644
--- a/tensorboard/backend/http_util_test.py
+++ b/tensorboard/backend/http_util_test.py
@@ -26,10 +26,10 @@
import six
try:
- # python version >= 3.3
- from unittest import mock
+ # python version >= 3.3
+ from unittest import mock
except ImportError:
- import mock # pylint: disable=unused-import
+ import mock # pylint: disable=unused-import
from werkzeug import test as wtest
from werkzeug import wrappers
@@ -39,286 +39,375 @@
class RespondTest(tb_test.TestCase):
-
- def testHelloWorld(self):
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, 'hello world', 'text/html')
- self.assertEqual(r.status_code, 200)
- self.assertEqual(r.response, [six.b('hello world')])
- self.assertEqual(r.headers.get('Content-Length'), '18')
-
- def testHeadRequest_doesNotWrite(self):
- builder = wtest.EnvironBuilder(method='HEAD')
- env = builder.get_environ()
- request = wrappers.Request(env)
- r = http_util.Respond(request, 'hello world', 'text/html')
- self.assertEqual(r.status_code, 200)
- self.assertEqual(r.response, [])
- self.assertEqual(r.headers.get('Content-Length'), '18')
-
- def testPlainText_appendsUtf8ToContentType(self):
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, 'hello', 'text/plain')
- h = r.headers
- self.assertEqual(h.get('Content-Type'), 'text/plain; charset=utf-8')
-
- def testContentLength_isInBytes(self):
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, '爱', 'text/plain')
- self.assertEqual(r.headers.get('Content-Length'), '3')
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, '爱'.encode('utf-8'), 'text/plain')
- self.assertEqual(r.headers.get('Content-Length'), '3')
-
- def testResponseCharsetTranscoding(self):
- bean = '要依法治国是赞美那些谁是公义的和惩罚恶人。 - 韩非'
-
- # input is unicode string, output is gbk string
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, bean, 'text/plain; charset=gbk')
- self.assertEqual(r.response, [bean.encode('gbk')])
-
- # input is utf-8 string, output is gbk string
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, bean.encode('utf-8'), 'text/plain; charset=gbk')
- self.assertEqual(r.response, [bean.encode('gbk')])
-
- # input is object with unicode strings, output is gbk json
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, {'red': bean}, 'application/json; charset=gbk')
- self.assertEqual(r.response, [b'{"red": "' + bean.encode('gbk') + b'"}'])
-
- # input is object with utf-8 strings, output is gbk json
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(
- q, {'red': bean.encode('utf-8')}, 'application/json; charset=gbk')
- self.assertEqual(r.response, [b'{"red": "' + bean.encode('gbk') + b'"}'])
-
- # input is object with gbk strings, output is gbk json
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(
- q, {'red': bean.encode('gbk')},
- 'application/json; charset=gbk',
- encoding='gbk')
- self.assertEqual(r.response, [b'{"red": "' + bean.encode('gbk') + b'"}'])
-
- def testAcceptGzip_compressesResponse(self):
- fall_of_hyperion_canto1_stanza1 = '\n'.join([
- 'Fanatics have their dreams, wherewith they weave',
- 'A paradise for a sect; the savage too',
- 'From forth the loftiest fashion of his sleep',
- 'Guesses at Heaven; pity these have not',
- 'Trac\'d upon vellum or wild Indian leaf',
- 'The shadows of melodious utterance.',
- 'But bare of laurel they live, dream, and die;',
- 'For Poesy alone can tell her dreams,',
- 'With the fine spell of words alone can save',
- 'Imagination from the sable charm',
- 'And dumb enchantment. Who alive can say,',
- '\'Thou art no Poet may\'st not tell thy dreams?\'',
- 'Since every man whose soul is not a clod',
- 'Hath visions, and would speak, if he had loved',
- 'And been well nurtured in his mother tongue.',
- 'Whether the dream now purpos\'d to rehearse',
- 'Be poet\'s or fanatic\'s will be known',
- 'When this warm scribe my hand is in the grave.',
- ])
-
- e1 = wtest.EnvironBuilder(headers={'Accept-Encoding': '*'}).get_environ()
- any_encoding = wrappers.Request(e1)
-
- r = http_util.Respond(
- any_encoding, fall_of_hyperion_canto1_stanza1, 'text/plain')
- self.assertEqual(r.headers.get('Content-Encoding'), 'gzip')
- self.assertEqual(_gunzip(r.response[0]), # pylint: disable=unsubscriptable-object
- fall_of_hyperion_canto1_stanza1.encode('utf-8'))
-
- e2 = wtest.EnvironBuilder(headers={'Accept-Encoding': 'gzip'}).get_environ()
- gzip_encoding = wrappers.Request(e2)
-
- r = http_util.Respond(
- gzip_encoding, fall_of_hyperion_canto1_stanza1, 'text/plain')
- self.assertEqual(r.headers.get('Content-Encoding'), 'gzip')
- self.assertEqual(_gunzip(r.response[0]), # pylint: disable=unsubscriptable-object
- fall_of_hyperion_canto1_stanza1.encode('utf-8'))
-
- r = http_util.Respond(
- any_encoding, fall_of_hyperion_canto1_stanza1, 'image/png')
- self.assertEqual(
- r.response, [fall_of_hyperion_canto1_stanza1.encode('utf-8')])
-
- def testAcceptGzip_alreadyCompressed_sendsPrecompressedResponse(self):
- gzip_text = _gzip(b'hello hello hello world')
- e = wtest.EnvironBuilder(headers={'Accept-Encoding': 'gzip'}).get_environ()
- q = wrappers.Request(e)
- r = http_util.Respond(q, gzip_text, 'text/plain', content_encoding='gzip')
- self.assertEqual(r.response, [gzip_text]) # Still singly zipped
-
- def testPrecompressedResponse_noAcceptGzip_decompressesResponse(self):
- orig_text = b'hello hello hello world'
- gzip_text = _gzip(orig_text)
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, gzip_text, 'text/plain', content_encoding='gzip')
- # Streaming gunzip produces file-wrapper application iterator as response,
- # so rejoin it into the full response before comparison.
- full_response = b''.join(r.response)
- self.assertEqual(full_response, orig_text)
-
- def testPrecompressedResponse_streamingDecompression_catchesBadSize(self):
- orig_text = b'hello hello hello world'
- gzip_text = _gzip(orig_text)
- # Corrupt the gzipped data's stored content size (last 4 bytes).
- bad_text = gzip_text[:-4] + _bitflip(gzip_text[-4:])
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, bad_text, 'text/plain', content_encoding='gzip')
- # Streaming gunzip defers actual unzipping until response is used; once
- # we iterate over the whole file-wrapper application iterator, the
- # underlying GzipFile should be closed, and throw the size check error.
- with six.assertRaisesRegex(self, IOError, 'Incorrect length'):
- _ = list(r.response)
-
- def testJson_getsAutoSerialized(self):
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, [1, 2, 3], 'application/json')
- self.assertEqual(r.response, [b'[1, 2, 3]'])
-
- def testExpires_setsCruiseControl(self):
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, 'hello world', 'text/html', expires=60)
- self.assertEqual(r.headers.get('Cache-Control'), 'private, max-age=60')
-
- def testCsp(self):
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(
- q, 'hello', 'text/html', csp_scripts_sha256s=['abcdefghi'])
- expected_csp = (
- "default-src 'self';font-src 'self';frame-ancestors *;"
- "frame-src 'self';img-src 'self' data: blob:;object-src 'none';"
- "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';"
- "script-src 'self' 'unsafe-eval' 'sha256-abcdefghi'"
- )
- self.assertEqual(r.headers.get('Content-Security-Policy'), expected_csp)
-
- @mock.patch.object(http_util, '_CSP_SCRIPT_SELF', False)
- def testCsp_noHash(self):
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=None)
- expected_csp = (
- "default-src 'self';font-src 'self';frame-ancestors *;"
- "frame-src 'self';img-src 'self' data: blob:;object-src 'none';"
- "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';"
- "script-src 'unsafe-eval'"
- )
- self.assertEqual(r.headers.get('Content-Security-Policy'), expected_csp)
-
- @mock.patch.object(http_util, '_CSP_SCRIPT_SELF', False)
- @mock.patch.object(http_util, '_CSP_SCRIPT_UNSAFE_EVAL', False)
- def testCsp_noHash_noUnsafeEval(self):
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=None)
- expected_csp = (
- "default-src 'self';font-src 'self';frame-ancestors *;"
- "frame-src 'self';img-src 'self' data: blob:;object-src 'none';"
- "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';"
- "script-src 'none'"
+ def testHelloWorld(self):
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(q, "hello world", "text/html")
+ self.assertEqual(r.status_code, 200)
+ self.assertEqual(r.response, [six.b("hello world")])
+ self.assertEqual(r.headers.get("Content-Length"), "18")
+
+ def testHeadRequest_doesNotWrite(self):
+ builder = wtest.EnvironBuilder(method="HEAD")
+ env = builder.get_environ()
+ request = wrappers.Request(env)
+ r = http_util.Respond(request, "hello world", "text/html")
+ self.assertEqual(r.status_code, 200)
+ self.assertEqual(r.response, [])
+ self.assertEqual(r.headers.get("Content-Length"), "18")
+
+ def testPlainText_appendsUtf8ToContentType(self):
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(q, "hello", "text/plain")
+ h = r.headers
+ self.assertEqual(h.get("Content-Type"), "text/plain; charset=utf-8")
+
+ def testContentLength_isInBytes(self):
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(q, "爱", "text/plain")
+ self.assertEqual(r.headers.get("Content-Length"), "3")
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(q, "爱".encode("utf-8"), "text/plain")
+ self.assertEqual(r.headers.get("Content-Length"), "3")
+
+ def testResponseCharsetTranscoding(self):
+ bean = "要依法治国是赞美那些谁是公义的和惩罚恶人。 - 韩非"
+
+ # input is unicode string, output is gbk string
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(q, bean, "text/plain; charset=gbk")
+ self.assertEqual(r.response, [bean.encode("gbk")])
+
+ # input is utf-8 string, output is gbk string
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(
+ q, bean.encode("utf-8"), "text/plain; charset=gbk"
+ )
+ self.assertEqual(r.response, [bean.encode("gbk")])
+
+ # input is object with unicode strings, output is gbk json
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(q, {"red": bean}, "application/json; charset=gbk")
+ self.assertEqual(
+ r.response, [b'{"red": "' + bean.encode("gbk") + b'"}']
+ )
+
+ # input is object with utf-8 strings, output is gbk json
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(
+ q, {"red": bean.encode("utf-8")}, "application/json; charset=gbk"
+ )
+ self.assertEqual(
+ r.response, [b'{"red": "' + bean.encode("gbk") + b'"}']
+ )
+
+ # input is object with gbk strings, output is gbk json
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(
+ q,
+ {"red": bean.encode("gbk")},
+ "application/json; charset=gbk",
+ encoding="gbk",
+ )
+ self.assertEqual(
+ r.response, [b'{"red": "' + bean.encode("gbk") + b'"}']
+ )
+
+ def testAcceptGzip_compressesResponse(self):
+ fall_of_hyperion_canto1_stanza1 = "\n".join(
+ [
+ "Fanatics have their dreams, wherewith they weave",
+ "A paradise for a sect; the savage too",
+ "From forth the loftiest fashion of his sleep",
+ "Guesses at Heaven; pity these have not",
+ "Trac'd upon vellum or wild Indian leaf",
+ "The shadows of melodious utterance.",
+ "But bare of laurel they live, dream, and die;",
+ "For Poesy alone can tell her dreams,",
+ "With the fine spell of words alone can save",
+ "Imagination from the sable charm",
+ "And dumb enchantment. Who alive can say,",
+ "'Thou art no Poet may'st not tell thy dreams?'",
+ "Since every man whose soul is not a clod",
+ "Hath visions, and would speak, if he had loved",
+ "And been well nurtured in his mother tongue.",
+ "Whether the dream now purpos'd to rehearse",
+ "Be poet's or fanatic's will be known",
+ "When this warm scribe my hand is in the grave.",
+ ]
+ )
+
+ e1 = wtest.EnvironBuilder(
+ headers={"Accept-Encoding": "*"}
+ ).get_environ()
+ any_encoding = wrappers.Request(e1)
+
+ r = http_util.Respond(
+ any_encoding, fall_of_hyperion_canto1_stanza1, "text/plain"
+ )
+ self.assertEqual(r.headers.get("Content-Encoding"), "gzip")
+ self.assertEqual(
+ _gunzip(r.response[0]), # pylint: disable=unsubscriptable-object
+ fall_of_hyperion_canto1_stanza1.encode("utf-8"),
+ )
+
+ e2 = wtest.EnvironBuilder(
+ headers={"Accept-Encoding": "gzip"}
+ ).get_environ()
+ gzip_encoding = wrappers.Request(e2)
+
+ r = http_util.Respond(
+ gzip_encoding, fall_of_hyperion_canto1_stanza1, "text/plain"
+ )
+ self.assertEqual(r.headers.get("Content-Encoding"), "gzip")
+ self.assertEqual(
+ _gunzip(r.response[0]), # pylint: disable=unsubscriptable-object
+ fall_of_hyperion_canto1_stanza1.encode("utf-8"),
+ )
+
+ r = http_util.Respond(
+ any_encoding, fall_of_hyperion_canto1_stanza1, "image/png"
+ )
+ self.assertEqual(
+ r.response, [fall_of_hyperion_canto1_stanza1.encode("utf-8")]
+ )
+
+ def testAcceptGzip_alreadyCompressed_sendsPrecompressedResponse(self):
+ gzip_text = _gzip(b"hello hello hello world")
+ e = wtest.EnvironBuilder(
+ headers={"Accept-Encoding": "gzip"}
+ ).get_environ()
+ q = wrappers.Request(e)
+ r = http_util.Respond(
+ q, gzip_text, "text/plain", content_encoding="gzip"
+ )
+ self.assertEqual(r.response, [gzip_text]) # Still singly zipped
+
+ def testPrecompressedResponse_noAcceptGzip_decompressesResponse(self):
+ orig_text = b"hello hello hello world"
+ gzip_text = _gzip(orig_text)
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(
+ q, gzip_text, "text/plain", content_encoding="gzip"
+ )
+        # Streaming gunzip produces a file-wrapper application iterator as
+        # the response, so rejoin it into the full response before comparing.
+ full_response = b"".join(r.response)
+ self.assertEqual(full_response, orig_text)
+
+ def testPrecompressedResponse_streamingDecompression_catchesBadSize(self):
+ orig_text = b"hello hello hello world"
+ gzip_text = _gzip(orig_text)
+ # Corrupt the gzipped data's stored content size (last 4 bytes).
+ bad_text = gzip_text[:-4] + _bitflip(gzip_text[-4:])
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(
+ q, bad_text, "text/plain", content_encoding="gzip"
+ )
+        # Streaming gunzip defers the actual unzipping until the response is
+        # used; once we iterate over the whole file-wrapper application
+        # iterator, the underlying GzipFile should be closed and should raise
+        # the size-check error.
+ with six.assertRaisesRegex(self, IOError, "Incorrect length"):
+ _ = list(r.response)
+
+ def testJson_getsAutoSerialized(self):
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(q, [1, 2, 3], "application/json")
+ self.assertEqual(r.response, [b"[1, 2, 3]"])
+
+ def testExpires_setsCruiseControl(self):
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(q, "hello world", "text/html", expires=60)
+ self.assertEqual(r.headers.get("Cache-Control"), "private, max-age=60")
+
+ def testCsp(self):
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(
+ q, "hello", "text/html", csp_scripts_sha256s=["abcdefghi"]
+ )
+ expected_csp = (
+ "default-src 'self';font-src 'self';frame-ancestors *;"
+ "frame-src 'self';img-src 'self' data: blob:;object-src 'none';"
+ "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';"
+ "script-src 'self' 'unsafe-eval' 'sha256-abcdefghi'"
+ )
+ self.assertEqual(r.headers.get("Content-Security-Policy"), expected_csp)
+
+ @mock.patch.object(http_util, "_CSP_SCRIPT_SELF", False)
+ def testCsp_noHash(self):
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(
+ q, "hello", "text/html", csp_scripts_sha256s=None
+ )
+ expected_csp = (
+ "default-src 'self';font-src 'self';frame-ancestors *;"
+ "frame-src 'self';img-src 'self' data: blob:;object-src 'none';"
+ "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';"
+ "script-src 'unsafe-eval'"
+ )
+ self.assertEqual(r.headers.get("Content-Security-Policy"), expected_csp)
+
+ @mock.patch.object(http_util, "_CSP_SCRIPT_SELF", False)
+ @mock.patch.object(http_util, "_CSP_SCRIPT_UNSAFE_EVAL", False)
+ def testCsp_noHash_noUnsafeEval(self):
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(
+ q, "hello", "text/html", csp_scripts_sha256s=None
+ )
+ expected_csp = (
+ "default-src 'self';font-src 'self';frame-ancestors *;"
+ "frame-src 'self';img-src 'self' data: blob:;object-src 'none';"
+ "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';"
+ "script-src 'none'"
+ )
+ self.assertEqual(r.headers.get("Content-Security-Policy"), expected_csp)
+
+ @mock.patch.object(http_util, "_CSP_SCRIPT_SELF", True)
+ @mock.patch.object(http_util, "_CSP_SCRIPT_UNSAFE_EVAL", False)
+ def testCsp_onlySelf(self):
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(
+ q, "hello", "text/html", csp_scripts_sha256s=None
+ )
+ expected_csp = (
+ "default-src 'self';font-src 'self';frame-ancestors *;"
+ "frame-src 'self';img-src 'self' data: blob:;object-src 'none';"
+ "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';"
+ "script-src 'self'"
+ )
+ self.assertEqual(r.headers.get("Content-Security-Policy"), expected_csp)
+
+ @mock.patch.object(http_util, "_CSP_SCRIPT_UNSAFE_EVAL", False)
+ def testCsp_disableUnsafeEval(self):
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(
+ q, "hello", "text/html", csp_scripts_sha256s=["abcdefghi"]
+ )
+ expected_csp = (
+ "default-src 'self';font-src 'self';frame-ancestors *;"
+ "frame-src 'self';img-src 'self' data: blob:;object-src 'none';"
+ "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';"
+ "script-src 'self' 'sha256-abcdefghi'"
+ )
+ self.assertEqual(r.headers.get("Content-Security-Policy"), expected_csp)
+
+ @mock.patch.object(
+ http_util, "_CSP_IMG_DOMAINS_WHITELIST", ["https://example.com"]
)
- self.assertEqual(r.headers.get('Content-Security-Policy'), expected_csp)
-
- @mock.patch.object(http_util, '_CSP_SCRIPT_SELF', True)
- @mock.patch.object(http_util, '_CSP_SCRIPT_UNSAFE_EVAL', False)
- def testCsp_onlySelf(self):
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=None)
- expected_csp = (
- "default-src 'self';font-src 'self';frame-ancestors *;"
- "frame-src 'self';img-src 'self' data: blob:;object-src 'none';"
- "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';"
- "script-src 'self'"
+ @mock.patch.object(
+ http_util,
+ "_CSP_SCRIPT_DOMAINS_WHITELIST",
+ ["https://tensorflow.org/tensorboard"],
)
- self.assertEqual(r.headers.get('Content-Security-Policy'), expected_csp)
-
- @mock.patch.object(http_util, '_CSP_SCRIPT_UNSAFE_EVAL', False)
- def testCsp_disableUnsafeEval(self):
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(
- q, 'hello', 'text/html', csp_scripts_sha256s=['abcdefghi'])
- expected_csp = (
- "default-src 'self';font-src 'self';frame-ancestors *;"
- "frame-src 'self';img-src 'self' data: blob:;object-src 'none';"
- "style-src 'self' https://www.gstatic.com data: 'unsafe-inline';"
- "script-src 'self' 'sha256-abcdefghi'"
+ @mock.patch.object(
+ http_util, "_CSP_STYLE_DOMAINS_WHITELIST", ["https://googol.com"]
)
- self.assertEqual(r.headers.get('Content-Security-Policy'), expected_csp)
-
- @mock.patch.object(http_util, '_CSP_IMG_DOMAINS_WHITELIST', ['https://example.com'])
- @mock.patch.object(http_util, '_CSP_SCRIPT_DOMAINS_WHITELIST',
- ['https://tensorflow.org/tensorboard'])
- @mock.patch.object(http_util, '_CSP_STYLE_DOMAINS_WHITELIST', ['https://googol.com'])
- @mock.patch.object(http_util, '_CSP_FRAME_DOMAINS_WHITELIST', ['https://myframe.com'])
- def testCsp_globalDomainWhiteList(self):
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- r = http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=['abcd'])
- expected_csp = (
- "default-src 'self';font-src 'self';frame-ancestors *;"
- "frame-src 'self' https://myframe.com;"
- "img-src 'self' data: blob: https://example.com;"
- "object-src 'none';style-src 'self' https://www.gstatic.com data: "
- "'unsafe-inline' https://googol.com;script-src "
- "https://tensorflow.org/tensorboard 'self' 'unsafe-eval' 'sha256-abcd'"
+ @mock.patch.object(
+ http_util, "_CSP_FRAME_DOMAINS_WHITELIST", ["https://myframe.com"]
)
- self.assertEqual(r.headers.get('Content-Security-Policy'), expected_csp)
-
- def testCsp_badGlobalDomainWhiteList(self):
- q = wrappers.Request(wtest.EnvironBuilder().get_environ())
- configs = [
- '_CSP_SCRIPT_DOMAINS_WHITELIST',
- '_CSP_IMG_DOMAINS_WHITELIST',
- '_CSP_STYLE_DOMAINS_WHITELIST',
- '_CSP_FONT_DOMAINS_WHITELIST',
- '_CSP_FRAME_DOMAINS_WHITELIST',
- ]
-
- for config in configs:
- with mock.patch.object(http_util, config, ['http://tensorflow.org']):
- with self.assertRaisesRegex(
- ValueError, '^Expected all whitelist to be a https URL'):
- http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=['abcd'])
-
- # Cannot grant more trust to a script from a remote source.
- with mock.patch.object(http_util, config,
- ["'strict-dynamic' 'unsafe-eval' https://tensorflow.org/"]):
- with self.assertRaisesRegex(
- ValueError, '^Expected all whitelist to be a https URL'):
- http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=['abcd'])
-
- # Attempt to terminate the script-src to specify a new one that allows ALL!
- with mock.patch.object(http_util, config, ['https://tensorflow.org;script-src *']):
- with self.assertRaisesRegex(
- ValueError, '^Expected whitelist domain to not contain ";"'):
- http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=['abcd'])
-
- # Attempt to use whitespace, delimit character, to specify a new one.
- with mock.patch.object(http_util, config, ['https://tensorflow.org *']):
- with self.assertRaisesRegex(
- ValueError, '^Expected whitelist domain to not contain a whitespace'):
- http_util.Respond(q, 'hello', 'text/html', csp_scripts_sha256s=['abcd'])
+ def testCsp_globalDomainWhiteList(self):
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ r = http_util.Respond(
+ q, "hello", "text/html", csp_scripts_sha256s=["abcd"]
+ )
+ expected_csp = (
+ "default-src 'self';font-src 'self';frame-ancestors *;"
+ "frame-src 'self' https://myframe.com;"
+ "img-src 'self' data: blob: https://example.com;"
+ "object-src 'none';style-src 'self' https://www.gstatic.com data: "
+ "'unsafe-inline' https://googol.com;script-src "
+ "https://tensorflow.org/tensorboard 'self' 'unsafe-eval' 'sha256-abcd'"
+ )
+ self.assertEqual(r.headers.get("Content-Security-Policy"), expected_csp)
+
+ def testCsp_badGlobalDomainWhiteList(self):
+ q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+ configs = [
+ "_CSP_SCRIPT_DOMAINS_WHITELIST",
+ "_CSP_IMG_DOMAINS_WHITELIST",
+ "_CSP_STYLE_DOMAINS_WHITELIST",
+ "_CSP_FONT_DOMAINS_WHITELIST",
+ "_CSP_FRAME_DOMAINS_WHITELIST",
+ ]
+
+ for config in configs:
+ with mock.patch.object(
+ http_util, config, ["http://tensorflow.org"]
+ ):
+ with self.assertRaisesRegex(
+ ValueError, "^Expected all whitelist to be a https URL"
+ ):
+ http_util.Respond(
+ q,
+ "hello",
+ "text/html",
+ csp_scripts_sha256s=["abcd"],
+ )
+
+ # Cannot grant more trust to a script from a remote source.
+ with mock.patch.object(
+ http_util,
+ config,
+ ["'strict-dynamic' 'unsafe-eval' https://tensorflow.org/"],
+ ):
+ with self.assertRaisesRegex(
+ ValueError, "^Expected all whitelist to be a https URL"
+ ):
+ http_util.Respond(
+ q,
+ "hello",
+ "text/html",
+ csp_scripts_sha256s=["abcd"],
+ )
+
+ # Attempt to terminate the script-src to specify a new one that allows ALL!
+ with mock.patch.object(
+ http_util, config, ["https://tensorflow.org;script-src *"]
+ ):
+ with self.assertRaisesRegex(
+ ValueError, '^Expected whitelist domain to not contain ";"'
+ ):
+ http_util.Respond(
+ q,
+ "hello",
+ "text/html",
+ csp_scripts_sha256s=["abcd"],
+ )
+
+            # Attempt to use whitespace, the delimiter character, to specify a new one.
+ with mock.patch.object(
+ http_util, config, ["https://tensorflow.org *"]
+ ):
+ with self.assertRaisesRegex(
+ ValueError,
+ "^Expected whitelist domain to not contain a whitespace",
+ ):
+ http_util.Respond(
+ q,
+ "hello",
+ "text/html",
+ csp_scripts_sha256s=["abcd"],
+ )
def _gzip(bs):
- out = six.BytesIO()
- with gzip.GzipFile(fileobj=out, mode='wb') as f:
- f.write(bs)
- return out.getvalue()
+ out = six.BytesIO()
+ with gzip.GzipFile(fileobj=out, mode="wb") as f:
+ f.write(bs)
+ return out.getvalue()
def _gunzip(bs):
- with gzip.GzipFile(fileobj=six.BytesIO(bs), mode='rb') as f:
- return f.read()
+ with gzip.GzipFile(fileobj=six.BytesIO(bs), mode="rb") as f:
+ return f.read()
+
def _bitflip(bs):
- # Return bytestring with all its bits flipped.
- return b''.join(struct.pack('B', 0xFF ^ struct.unpack_from('B', bs, i)[0])
- for i in range(len(bs)))
+ # Return bytestring with all its bits flipped.
+ return b"".join(
+ struct.pack("B", 0xFF ^ struct.unpack_from("B", bs, i)[0])
+ for i in range(len(bs))
+ )
+
-if __name__ == '__main__':
- tb_test.main()
+if __name__ == "__main__":
+ tb_test.main()
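
A minimal usage sketch (assuming the TensorBoard package and its werkzeug dependency are importable, as in the test file above): `http_util.Respond` can be driven directly with werkzeug's test helpers, and dict/list payloads are serialized to JSON automatically.

    from werkzeug import test as wtest
    from werkzeug import wrappers

    from tensorboard.backend import http_util

    # Build a bare WSGI environ and wrap it in a Request, as the tests do.
    request = wrappers.Request(wtest.EnvironBuilder().get_environ())

    # Dicts and lists are auto-serialized to JSON before being sent.
    response = http_util.Respond(
        request, {"hello": "world"}, "application/json"
    )
    print(response.status_code)  # 200
    print(response.response)     # [b'{"hello": "world"}']
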
diff --git a/tensorboard/backend/json_util.py b/tensorboard/backend/json_util.py
index 65b2c01c7c..5ca23afc48 100644
--- a/tensorboard/backend/json_util.py
+++ b/tensorboard/backend/json_util.py
@@ -17,10 +17,10 @@
Python provides no way to override how json.dumps serializes
Infinity/-Infinity/NaN; if allow_nan is true, it encodes them as
-Infinity/-Infinity/NaN, in violation of the JSON spec and in violation of what
-JSON.parse accepts. If it's false, it throws a ValueError, Neither subclassing
-JSONEncoder nor passing a function in the |default| keyword argument overrides
-this.
+Infinity/-Infinity/NaN, in violation of the JSON spec and in violation
+of what JSON.parse accepts. If it's false, it raises a ValueError.
+Neither subclassing JSONEncoder nor passing a function in the |default|
+keyword argument overrides this.
"""
from __future__ import absolute_import
@@ -33,46 +33,45 @@
from tensorboard.compat import tf
-_INFINITY = float('inf')
-_NEGATIVE_INFINITY = float('-inf')
+_INFINITY = float("inf")
+_NEGATIVE_INFINITY = float("-inf")
-def Cleanse(obj, encoding='utf-8'):
- """Makes Python object appropriate for JSON serialization.
+def Cleanse(obj, encoding="utf-8"):
+ """Makes Python object appropriate for JSON serialization.
- - Replaces instances of Infinity/-Infinity/NaN with strings.
- - Turns byte strings into unicode strings.
- - Turns sets into sorted lists.
- - Turns tuples into lists.
+ - Replaces instances of Infinity/-Infinity/NaN with strings.
+ - Turns byte strings into unicode strings.
+ - Turns sets into sorted lists.
+ - Turns tuples into lists.
- Args:
- obj: Python data structure.
- encoding: Charset used to decode byte strings.
+ Args:
+ obj: Python data structure.
+ encoding: Charset used to decode byte strings.
- Returns:
- Unicode JSON data structure.
- """
- if isinstance(obj, int):
- return obj
- elif isinstance(obj, float):
- if obj == _INFINITY:
- return 'Infinity'
- elif obj == _NEGATIVE_INFINITY:
- return '-Infinity'
- elif math.isnan(obj):
- return 'NaN'
+ Returns:
+ Unicode JSON data structure.
+ """
+ if isinstance(obj, int):
+ return obj
+ elif isinstance(obj, float):
+ if obj == _INFINITY:
+ return "Infinity"
+ elif obj == _NEGATIVE_INFINITY:
+ return "-Infinity"
+ elif math.isnan(obj):
+ return "NaN"
+ else:
+ return obj
+ elif isinstance(obj, bytes):
+ return tf.compat.as_text(obj, encoding)
+ elif isinstance(obj, (list, tuple)):
+ return [Cleanse(i, encoding) for i in obj]
+ elif isinstance(obj, set):
+ return [Cleanse(i, encoding) for i in sorted(obj)]
+ elif isinstance(obj, dict):
+ return collections.OrderedDict(
+ (Cleanse(k, encoding), Cleanse(v, encoding)) for k, v in obj.items()
+ )
else:
- return obj
- elif isinstance(obj, bytes):
- return tf.compat.as_text(obj, encoding)
- elif isinstance(obj, (list, tuple)):
- return [Cleanse(i, encoding) for i in obj]
- elif isinstance(obj, set):
- return [Cleanse(i, encoding) for i in sorted(obj)]
- elif isinstance(obj, dict):
- return collections.OrderedDict(
- (Cleanse(k, encoding), Cleanse(v, encoding))
- for k, v in obj.items()
- )
- else:
- return obj
+ return obj
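
A short sketch of what `Cleanse` produces for the branches above; it assumes TensorFlow (or the compat stub) is importable, since byte strings are decoded via `tf.compat.as_text`.

    from tensorboard.backend import json_util

    # Non-finite floats become strings, bytes become text, tuples and sets
    # become lists, and dicts come back as an OrderedDict of cleansed items
    # (key order follows the input dict's iteration order).
    cleansed = json_util.Cleanse(
        {b"loss": float("inf"), "dims": (3, 4), "tags": {"b", "a"}}
    )
    print(cleansed)
    # OrderedDict([('loss', 'Infinity'), ('dims', [3, 4]), ('tags', ['a', 'b'])])
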
diff --git a/tensorboard/backend/json_util_test.py b/tensorboard/backend/json_util_test.py
index a9b3424dbc..e0d030be52 100644
--- a/tensorboard/backend/json_util_test.py
+++ b/tensorboard/backend/json_util_test.py
@@ -23,54 +23,55 @@
from tensorboard import test as tb_test
from tensorboard.backend import json_util
-_INFINITY = float('inf')
+_INFINITY = float("inf")
class CleanseTest(tb_test.TestCase):
-
- def _assertWrapsAs(self, to_wrap, expected):
- """Asserts that |to_wrap| becomes |expected| when wrapped."""
- actual = json_util.Cleanse(to_wrap)
- for a, e in zip(actual, expected):
- self.assertEqual(e, a)
-
- def testWrapsPrimitives(self):
- self._assertWrapsAs(_INFINITY, 'Infinity')
- self._assertWrapsAs(-_INFINITY, '-Infinity')
- self._assertWrapsAs(float('nan'), 'NaN')
-
- def testWrapsObjectValues(self):
- self._assertWrapsAs({'x': _INFINITY}, {'x': 'Infinity'})
-
- def testWrapsObjectKeys(self):
- self._assertWrapsAs({_INFINITY: 'foo'}, {'Infinity': 'foo'})
-
- def testWrapsInListsAndTuples(self):
- self._assertWrapsAs([_INFINITY], ['Infinity'])
- # map() returns a list even if the argument is a tuple.
- self._assertWrapsAs((_INFINITY,), ['Infinity',])
-
- def testWrapsRecursively(self):
- self._assertWrapsAs({'x': [_INFINITY]}, {'x': ['Infinity']})
-
- def testOrderedDict_preservesOrder(self):
- # dict iteration order is not specified prior to Python 3.7, and is
- # observably different from insertion order in CPython 2.7.
- od = collections.OrderedDict()
- for c in string.ascii_lowercase:
- od[c] = c
- self.assertEqual(len(od), 26, od)
- self.assertEqual(list(od), list(json_util.Cleanse(od)))
-
- def testTuple_turnsIntoList(self):
- self.assertEqual(json_util.Cleanse(('a', 'b')), ['a', 'b'])
-
- def testSet_turnsIntoSortedList(self):
- self.assertEqual(json_util.Cleanse(set(['b', 'a'])), ['a', 'b'])
-
- def testByteString_turnsIntoUnicodeString(self):
- self.assertEqual(json_util.Cleanse(b'\xc2\xa3'), u'\u00a3') # is # sterling
-
-
-if __name__ == '__main__':
- tb_test.main()
+ def _assertWrapsAs(self, to_wrap, expected):
+ """Asserts that |to_wrap| becomes |expected| when wrapped."""
+ actual = json_util.Cleanse(to_wrap)
+ for a, e in zip(actual, expected):
+ self.assertEqual(e, a)
+
+ def testWrapsPrimitives(self):
+ self._assertWrapsAs(_INFINITY, "Infinity")
+ self._assertWrapsAs(-_INFINITY, "-Infinity")
+ self._assertWrapsAs(float("nan"), "NaN")
+
+ def testWrapsObjectValues(self):
+ self._assertWrapsAs({"x": _INFINITY}, {"x": "Infinity"})
+
+ def testWrapsObjectKeys(self):
+ self._assertWrapsAs({_INFINITY: "foo"}, {"Infinity": "foo"})
+
+ def testWrapsInListsAndTuples(self):
+ self._assertWrapsAs([_INFINITY], ["Infinity"])
+ # map() returns a list even if the argument is a tuple.
+ self._assertWrapsAs((_INFINITY,), ["Infinity",])
+
+ def testWrapsRecursively(self):
+ self._assertWrapsAs({"x": [_INFINITY]}, {"x": ["Infinity"]})
+
+ def testOrderedDict_preservesOrder(self):
+ # dict iteration order is not specified prior to Python 3.7, and is
+ # observably different from insertion order in CPython 2.7.
+ od = collections.OrderedDict()
+ for c in string.ascii_lowercase:
+ od[c] = c
+ self.assertEqual(len(od), 26, od)
+ self.assertEqual(list(od), list(json_util.Cleanse(od)))
+
+ def testTuple_turnsIntoList(self):
+ self.assertEqual(json_util.Cleanse(("a", "b")), ["a", "b"])
+
+ def testSet_turnsIntoSortedList(self):
+ self.assertEqual(json_util.Cleanse(set(["b", "a"])), ["a", "b"])
+
+ def testByteString_turnsIntoUnicodeString(self):
+ self.assertEqual(
+ json_util.Cleanse(b"\xc2\xa3"), u"\u00a3"
+ ) # is # sterling
+
+
+if __name__ == "__main__":
+ tb_test.main()
diff --git a/tensorboard/backend/path_prefix.py b/tensorboard/backend/path_prefix.py
index 2192deb9f7..13ffeb61d0 100644
--- a/tensorboard/backend/path_prefix.py
+++ b/tensorboard/backend/path_prefix.py
@@ -27,39 +27,45 @@
class PathPrefixMiddleware(object):
- """WSGI middleware for path prefixes.
+ """WSGI middleware for path prefixes.
- All requests to this middleware must begin with the specified path
- prefix (otherwise, a 404 will be returned immediately). Requests will
- be forwarded to the underlying application with the path prefix
- stripped and appended to `SCRIPT_NAME` (see the WSGI spec, PEP 3333,
- for details).
- """
+ All requests to this middleware must begin with the specified path
+ prefix (otherwise, a 404 will be returned immediately). Requests
+ will be forwarded to the underlying application with the path prefix
+ stripped and appended to `SCRIPT_NAME` (see the WSGI spec, PEP 3333,
+ for details).
+ """
- def __init__(self, application, path_prefix):
- """Initializes this middleware.
+ def __init__(self, application, path_prefix):
+ """Initializes this middleware.
- Args:
- application: The WSGI application to wrap (see PEP 3333).
- path_prefix: A string path prefix to be stripped from incoming
- requests. If empty, this middleware is a no-op. If non-empty,
- the path prefix must start with a slash and not end with one
- (e.g., "/tensorboard").
- """
- if path_prefix.endswith("/"):
- raise ValueError("Path prefix must not end with slash: %r" % path_prefix)
- if path_prefix and not path_prefix.startswith("/"):
- raise ValueError(
- "Non-empty path prefix must start with slash: %r" % path_prefix
- )
- self._application = application
- self._path_prefix = path_prefix
- self._strict_prefix = self._path_prefix + "/"
+ Args:
+ application: The WSGI application to wrap (see PEP 3333).
+ path_prefix: A string path prefix to be stripped from incoming
+ requests. If empty, this middleware is a no-op. If non-empty,
+ the path prefix must start with a slash and not end with one
+ (e.g., "/tensorboard").
+ """
+ if path_prefix.endswith("/"):
+ raise ValueError(
+ "Path prefix must not end with slash: %r" % path_prefix
+ )
+ if path_prefix and not path_prefix.startswith("/"):
+ raise ValueError(
+ "Non-empty path prefix must start with slash: %r" % path_prefix
+ )
+ self._application = application
+ self._path_prefix = path_prefix
+ self._strict_prefix = self._path_prefix + "/"
- def __call__(self, environ, start_response):
- path = environ.get("PATH_INFO", "")
- if path != self._path_prefix and not path.startswith(self._strict_prefix):
- raise errors.NotFoundError()
- environ["PATH_INFO"] = path[len(self._path_prefix):]
- environ["SCRIPT_NAME"] = environ.get("SCRIPT_NAME", "") + self._path_prefix
- return self._application(environ, start_response)
+ def __call__(self, environ, start_response):
+ path = environ.get("PATH_INFO", "")
+ if path != self._path_prefix and not path.startswith(
+ self._strict_prefix
+ ):
+ raise errors.NotFoundError()
+ environ["PATH_INFO"] = path[len(self._path_prefix) :]
+ environ["SCRIPT_NAME"] = (
+ environ.get("SCRIPT_NAME", "") + self._path_prefix
+ )
+ return self._application(environ, start_response)
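
A minimal sketch of the middleware's effect on `SCRIPT_NAME` and `PATH_INFO`; the echo app here is hypothetical, not part of the codebase.

    import werkzeug
    from werkzeug import test as werkzeug_test

    from tensorboard.backend import path_prefix


    def echo_app(environ, start_response):
        start_response("200 OK", [("Content-Type", "text/plain")])
        body = "%s|%s" % (environ["SCRIPT_NAME"], environ["PATH_INFO"])
        return [body.encode("utf-8")]


    app = path_prefix.PathPrefixMiddleware(echo_app, "/tensorboard")
    client = werkzeug_test.Client(app, werkzeug.BaseResponse)
    print(client.get("/tensorboard/data").get_data())  # b'/tensorboard|/data'
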
diff --git a/tensorboard/backend/path_prefix_test.py b/tensorboard/backend/path_prefix_test.py
index 80dda313c1..6b62fbf76f 100644
--- a/tensorboard/backend/path_prefix_test.py
+++ b/tensorboard/backend/path_prefix_test.py
@@ -29,85 +29,85 @@
class PathPrefixMiddlewareTest(tb_test.TestCase):
- """Tests for `PathPrefixMiddleware`."""
-
- def _echo_app(self, environ, start_response):
- # https://www.python.org/dev/peps/pep-0333/#environ-variables
- data = {
- "path": environ.get("PATH_INFO", ""),
- "script": environ.get("SCRIPT_NAME", ""),
- }
- body = json.dumps(data, sort_keys=True)
- start_response("200 OK", [("Content-Type", "application/json")])
- return [body]
-
- def _assert_ok(self, response, path, script):
- self.assertEqual(response.status_code, 200)
- actual = json.loads(response.get_data())
- expected = dict(path=path, script=script)
- self.assertEqual(actual, expected)
-
- def test_bad_path_prefix_without_leading_slash(self):
- with self.assertRaises(ValueError) as cm:
- path_prefix.PathPrefixMiddleware(self._echo_app, "hmm")
- msg = str(cm.exception)
- self.assertIn("must start with slash", msg)
- self.assertIn(repr("hmm"), msg)
-
- def test_bad_path_prefix_with_trailing_slash(self):
- with self.assertRaises(ValueError) as cm:
- path_prefix.PathPrefixMiddleware(self._echo_app, "/hmm/")
- msg = str(cm.exception)
- self.assertIn("must not end with slash", msg)
- self.assertIn(repr("/hmm/"), msg)
-
- def test_empty_path_prefix(self):
- app = path_prefix.PathPrefixMiddleware(self._echo_app, "")
- server = werkzeug_test.Client(app, werkzeug.BaseResponse)
-
- with self.subTest("at empty"):
- self._assert_ok(server.get(""), path="", script="")
-
- with self.subTest("at root"):
- self._assert_ok(server.get("/"), path="/", script="")
-
- with self.subTest("at subpath"):
- response = server.get("/foo/bar")
- self._assert_ok(server.get("/foo/bar"), path="/foo/bar", script="")
-
- def test_nonempty_path_prefix(self):
- app = path_prefix.PathPrefixMiddleware(self._echo_app, "/pfx")
- server = werkzeug_test.Client(app, werkzeug.BaseResponse)
-
- with self.subTest("at root"):
- response = server.get("/pfx")
- self._assert_ok(response, path="", script="/pfx")
-
- with self.subTest("at root with slash"):
- response = server.get("/pfx/")
- self._assert_ok(response, path="/", script="/pfx")
-
- with self.subTest("at subpath"):
- response = server.get("/pfx/foo/bar")
- self._assert_ok(response, path="/foo/bar", script="/pfx")
-
- with self.subTest("at non-path-component extension"):
- with self.assertRaises(errors.NotFoundError):
- server.get("/pfxz")
-
- with self.subTest("above path prefix"):
- with self.assertRaises(errors.NotFoundError):
- server.get("/hmm")
-
- def test_composition(self):
- app = self._echo_app
- app = path_prefix.PathPrefixMiddleware(app, "/bar")
- app = path_prefix.PathPrefixMiddleware(app, "/foo")
- server = werkzeug_test.Client(app, werkzeug.BaseResponse)
-
- response = server.get("/foo/bar/baz/quux")
- self._assert_ok(response, path="/baz/quux", script="/foo/bar")
+ """Tests for `PathPrefixMiddleware`."""
+
+ def _echo_app(self, environ, start_response):
+ # https://www.python.org/dev/peps/pep-0333/#environ-variables
+ data = {
+ "path": environ.get("PATH_INFO", ""),
+ "script": environ.get("SCRIPT_NAME", ""),
+ }
+ body = json.dumps(data, sort_keys=True)
+ start_response("200 OK", [("Content-Type", "application/json")])
+ return [body]
+
+ def _assert_ok(self, response, path, script):
+ self.assertEqual(response.status_code, 200)
+ actual = json.loads(response.get_data())
+ expected = dict(path=path, script=script)
+ self.assertEqual(actual, expected)
+
+ def test_bad_path_prefix_without_leading_slash(self):
+ with self.assertRaises(ValueError) as cm:
+ path_prefix.PathPrefixMiddleware(self._echo_app, "hmm")
+ msg = str(cm.exception)
+ self.assertIn("must start with slash", msg)
+ self.assertIn(repr("hmm"), msg)
+
+ def test_bad_path_prefix_with_trailing_slash(self):
+ with self.assertRaises(ValueError) as cm:
+ path_prefix.PathPrefixMiddleware(self._echo_app, "/hmm/")
+ msg = str(cm.exception)
+ self.assertIn("must not end with slash", msg)
+ self.assertIn(repr("/hmm/"), msg)
+
+ def test_empty_path_prefix(self):
+ app = path_prefix.PathPrefixMiddleware(self._echo_app, "")
+ server = werkzeug_test.Client(app, werkzeug.BaseResponse)
+
+ with self.subTest("at empty"):
+ self._assert_ok(server.get(""), path="", script="")
+
+ with self.subTest("at root"):
+ self._assert_ok(server.get("/"), path="/", script="")
+
+ with self.subTest("at subpath"):
+ response = server.get("/foo/bar")
+ self._assert_ok(server.get("/foo/bar"), path="/foo/bar", script="")
+
+ def test_nonempty_path_prefix(self):
+ app = path_prefix.PathPrefixMiddleware(self._echo_app, "/pfx")
+ server = werkzeug_test.Client(app, werkzeug.BaseResponse)
+
+ with self.subTest("at root"):
+ response = server.get("/pfx")
+ self._assert_ok(response, path="", script="/pfx")
+
+ with self.subTest("at root with slash"):
+ response = server.get("/pfx/")
+ self._assert_ok(response, path="/", script="/pfx")
+
+ with self.subTest("at subpath"):
+ response = server.get("/pfx/foo/bar")
+ self._assert_ok(response, path="/foo/bar", script="/pfx")
+
+ with self.subTest("at non-path-component extension"):
+ with self.assertRaises(errors.NotFoundError):
+ server.get("/pfxz")
+
+ with self.subTest("above path prefix"):
+ with self.assertRaises(errors.NotFoundError):
+ server.get("/hmm")
+
+ def test_composition(self):
+ app = self._echo_app
+ app = path_prefix.PathPrefixMiddleware(app, "/bar")
+ app = path_prefix.PathPrefixMiddleware(app, "/foo")
+ server = werkzeug_test.Client(app, werkzeug.BaseResponse)
+
+ response = server.get("/foo/bar/baz/quux")
+ self._assert_ok(response, path="/baz/quux", script="/foo/bar")
if __name__ == "__main__":
- tb_test.main()
+ tb_test.main()
diff --git a/tensorboard/backend/process_graph.py b/tensorboard/backend/process_graph.py
index adf7f0e1b3..71a6b1c8e7 100644
--- a/tensorboard/backend/process_graph.py
+++ b/tensorboard/backend/process_graph.py
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Graph post-processing logic. Used by both TensorBoard and mldash."""
+"""Graph post-processing logic.
+
+Used by both TensorBoard and mldash.
+"""
from __future__ import absolute_import
@@ -22,47 +25,53 @@
from tensorboard.compat import tf
-def prepare_graph_for_ui(graph, limit_attr_size=1024,
- large_attrs_key='_too_large_attrs'):
- """Prepares (modifies in-place) the graph to be served to the front-end.
+def prepare_graph_for_ui(
+ graph, limit_attr_size=1024, large_attrs_key="_too_large_attrs"
+):
+ """Prepares (modifies in-place) the graph to be served to the front-end.
- For now, it supports filtering out attributes that are
- too large to be shown in the graph UI.
+ For now, it supports filtering out attributes that are
+ too large to be shown in the graph UI.
- Args:
- graph: The GraphDef proto message.
- limit_attr_size: Maximum allowed size in bytes, before the attribute
- is considered large. Default is 1024 (1KB). Must be > 0 or None.
- If None, there will be no filtering.
- large_attrs_key: The attribute key that will be used for storing attributes
- that are too large. Default is '_too_large_attrs'. Must be != None if
- `limit_attr_size` is != None.
+ Args:
+ graph: The GraphDef proto message.
+ limit_attr_size: Maximum allowed size in bytes, before the attribute
+ is considered large. Default is 1024 (1KB). Must be > 0 or None.
+ If None, there will be no filtering.
+ large_attrs_key: The attribute key that will be used for storing attributes
+ that are too large. Default is '_too_large_attrs'. Must be != None if
+ `limit_attr_size` is != None.
- Raises:
- ValueError: If `large_attrs_key is None` while `limit_attr_size != None`.
- ValueError: If `limit_attr_size` is defined, but <= 0.
- """
- # Check input for validity.
- if limit_attr_size is not None:
- if large_attrs_key is None:
- raise ValueError('large_attrs_key must be != None when limit_attr_size'
- '!= None.')
+ Raises:
+ ValueError: If `large_attrs_key is None` while `limit_attr_size != None`.
+ ValueError: If `limit_attr_size` is defined, but <= 0.
+ """
+ # Check input for validity.
+ if limit_attr_size is not None:
+ if large_attrs_key is None:
+ raise ValueError(
+ "large_attrs_key must be != None when limit_attr_size"
+ "!= None."
+ )
- if limit_attr_size <= 0:
- raise ValueError('limit_attr_size must be > 0, but is %d' %
- limit_attr_size)
+ if limit_attr_size <= 0:
+ raise ValueError(
+ "limit_attr_size must be > 0, but is %d" % limit_attr_size
+ )
- # Filter only if a limit size is defined.
- if limit_attr_size is not None:
- for node in graph.node:
- # Go through all the attributes and filter out ones bigger than the
- # limit.
- keys = list(node.attr.keys())
- for key in keys:
- size = node.attr[key].ByteSize()
- if size > limit_attr_size or size < 0:
- del node.attr[key]
- # Add the attribute key to the list of "too large" attributes.
- # This is used in the info card in the graph UI to show the user
- # that some attributes are too large to be shown.
- node.attr[large_attrs_key].list.s.append(tf.compat.as_bytes(key))
+ # Filter only if a limit size is defined.
+ if limit_attr_size is not None:
+ for node in graph.node:
+ # Go through all the attributes and filter out ones bigger than the
+ # limit.
+ keys = list(node.attr.keys())
+ for key in keys:
+ size = node.attr[key].ByteSize()
+ if size > limit_attr_size or size < 0:
+ del node.attr[key]
+ # Add the attribute key to the list of "too large" attributes.
+ # This is used in the info card in the graph UI to show the user
+ # that some attributes are too large to be shown.
+ node.attr[large_attrs_key].list.s.append(
+ tf.compat.as_bytes(key)
+ )
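
A sketch of the in-place filtering above, assuming TensorFlow is available to build a `GraphDef`; the node and attribute names are made up for illustration.

    import tensorflow as tf

    from tensorboard.backend import process_graph

    graph_def = tf.compat.v1.GraphDef()
    node = graph_def.node.add(name="weights", op="Const")
    node.attr["value"].tensor.tensor_content = b"\x00" * 4096  # > 1 KB

    process_graph.prepare_graph_for_ui(graph_def, limit_attr_size=1024)
    print("value" in node.attr)                  # False: stripped in place
    print(node.attr["_too_large_attrs"].list.s)  # [b'value']
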
diff --git a/tensorboard/backend/security_validator.py b/tensorboard/backend/security_validator.py
index 388a74d9c4..ec0353a601 100644
--- a/tensorboard/backend/security_validator.py
+++ b/tensorboard/backend/security_validator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Validates responses and their security features"""
+"""Validates responses and their security features."""
from __future__ import absolute_import
from __future__ import division
@@ -48,145 +48,147 @@
def _maybe_raise_value_error(error_msg):
- logger.warn("In 3.0, this warning will become an error:\n%s" % error_msg)
- # TODO(3.x): raise a value error.
+ logger.warn("In 3.0, this warning will become an error:\n%s" % error_msg)
+ # TODO(3.x): raise a value error.
class SecurityValidatorMiddleware(object):
- """WSGI middleware validating security on response.
+ """WSGI middleware validating security on response.
- It validates:
- - responses have Content-Type
- - responses have X-Content-Type-Options: nosniff
- - text/html responses have CSP header. It also validates whether the CSP
- headers pass basic requirement. e.g., default-src should be present, cannot
- use "*" directive, and others. For more complete list, please refer to
- _validate_csp_policies.
+ It validates:
+ - responses have Content-Type
+ - responses have X-Content-Type-Options: nosniff
+    - text/html responses have a CSP header. It also validates whether the
+    CSP headers meet basic requirements, e.g., default-src must be present,
+    the "*" directive cannot be used, and so on. For a more complete list,
+    please refer to _validate_csp_policies.
- Instances of this class are WSGI applications (see PEP 3333).
- """
-
- def __init__(self, application):
- """Initializes an `SecurityValidatorMiddleware`.
-
- Args:
- application: The WSGI application to wrap (see PEP 3333).
+ Instances of this class are WSGI applications (see PEP 3333).
"""
- self._application = application
-
- def __call__(self, environ, start_response):
-
- def start_response_proxy(status, headers, exc_info=None):
- self._validate_headers(headers)
- return start_response(status, headers, exc_info)
-
- return self._application(environ, start_response_proxy)
-
- def _validate_headers(self, headers_list):
- headers = Headers(headers_list)
- self._validate_content_type(headers)
- self._validate_x_content_type_options(headers)
- self._validate_csp_headers(headers)
-
- def _validate_content_type(self, headers):
- if headers.get("Content-Type"):
- return
-
- _maybe_raise_value_error("Content-Type is required on a Response")
-
- def _validate_x_content_type_options(self, headers):
- option = headers.get("X-Content-Type-Options")
- if option == "nosniff":
- return
-
- _maybe_raise_value_error(
- 'X-Content-Type-Options is required to be "nosniff"'
- )
-
- def _validate_csp_headers(self, headers):
- mime_type, _ = http.parse_options_header(headers.get("Content-Type"))
- if mime_type != _HTML_MIME_TYPE:
- return
-
- csp_texts = headers.get_all("Content-Security-Policy")
- policies = []
-
- for csp_text in csp_texts:
- policies += self._parse_serialized_csp(csp_text)
-
- self._validate_csp_policies(policies)
-
- def _validate_csp_policies(self, policies):
- has_default_src = False
- violations = []
-
- for directive in policies:
- name = directive.name
- for value in directive.value:
- has_default_src = has_default_src or name == _CSP_DEFAULT_SRC
-
- if value in _CSP_IGNORE.get(name, []):
- # There are cases where certain directives are legitimate.
- continue
-
- # TensorBoard follows principle of least privilege. However, to make it
- # easier to conform to the security policy for plugin authors,
- # TensorBoard trusts request and resources originating its server. Also,
- # it can selectively trust domains as long as they use https protocol.
- # Lastly, it can allow 'none' directive.
- # TODO(stephanwlee): allow configuration for whitelist of domains for
- # stricter enforcement.
- # TODO(stephanwlee): deprecate the sha-based whitelisting.
- if (
- value == "'self'" or value == "'none'"
- or value.startswith("https:")
- or value.startswith("'sha256-")
- ):
- continue
-
- msg = "Illegal Content-Security-Policy for {name}: {value}".format(
- name=name, value=value
- )
- violations.append(msg)
- if not has_default_src:
- violations.append("Requires default-src for Content-Security-Policy")
+ def __init__(self, application):
+ """Initializes an `SecurityValidatorMiddleware`.
- if violations:
- _maybe_raise_value_error("\n".join(violations))
+ Args:
+ application: The WSGI application to wrap (see PEP 3333).
+ """
+ self._application = application
- def _parse_serialized_csp(self, csp_text):
- # See https://www.w3.org/TR/CSP/#parse-serialized-policy.
- # Below Steps are based on the algorithm stated in above spec.
- # Deviations:
- # - it does not warn and ignore duplicative directive (Step 2.5)
+ def __call__(self, environ, start_response):
+ def start_response_proxy(status, headers, exc_info=None):
+ self._validate_headers(headers)
+ return start_response(status, headers, exc_info)
- # Step 2
- csp_srcs = csp_text.split(";")
+ return self._application(environ, start_response_proxy)
- policy = []
- for token in csp_srcs:
- # Step 2.1
- token = token.strip()
+ def _validate_headers(self, headers_list):
+ headers = Headers(headers_list)
+ self._validate_content_type(headers)
+ self._validate_x_content_type_options(headers)
+ self._validate_csp_headers(headers)
- if not token:
- # Step 2.2
- continue
+ def _validate_content_type(self, headers):
+ if headers.get("Content-Type"):
+ return
- # Step 2.3
- token_frag = token.split(None, 1)
- name = token_frag[0]
+ _maybe_raise_value_error("Content-Type is required on a Response")
- values = token_frag[1] if len(token_frag) == 2 else ""
+ def _validate_x_content_type_options(self, headers):
+ option = headers.get("X-Content-Type-Options")
+ if option == "nosniff":
+ return
- # Step 2.4
- name = name.lower()
-
- # Step 2.6
- value = values.split()
- # Step 2.7
- directive = Directive(name=name, value=value)
- # Step 2.8
- policy.append(directive)
+ _maybe_raise_value_error(
+ 'X-Content-Type-Options is required to be "nosniff"'
+ )
- return policy
+ def _validate_csp_headers(self, headers):
+ mime_type, _ = http.parse_options_header(headers.get("Content-Type"))
+ if mime_type != _HTML_MIME_TYPE:
+ return
+
+ csp_texts = headers.get_all("Content-Security-Policy")
+ policies = []
+
+ for csp_text in csp_texts:
+ policies += self._parse_serialized_csp(csp_text)
+
+ self._validate_csp_policies(policies)
+
+ def _validate_csp_policies(self, policies):
+ has_default_src = False
+ violations = []
+
+ for directive in policies:
+ name = directive.name
+ for value in directive.value:
+ has_default_src = has_default_src or name == _CSP_DEFAULT_SRC
+
+ if value in _CSP_IGNORE.get(name, []):
+ # There are cases where certain directives are legitimate.
+ continue
+
+                # TensorBoard follows the principle of least privilege.
+                # However, to make it easier for plugin authors to conform to
+                # the security policy, TensorBoard trusts requests and
+                # resources originating from its own server. It can also
+                # selectively trust domains as long as they use the https
+                # protocol, and it allows the 'none' directive.
+                # TODO(stephanwlee): allow configuring a whitelist of domains
+                # for stricter enforcement.
+                # TODO(stephanwlee): deprecate the sha-based whitelisting.
+ if (
+ value == "'self'"
+ or value == "'none'"
+ or value.startswith("https:")
+ or value.startswith("'sha256-")
+ ):
+ continue
+
+ msg = "Illegal Content-Security-Policy for {name}: {value}".format(
+ name=name, value=value
+ )
+ violations.append(msg)
+
+ if not has_default_src:
+ violations.append(
+ "Requires default-src for Content-Security-Policy"
+ )
+
+ if violations:
+ _maybe_raise_value_error("\n".join(violations))
+
+ def _parse_serialized_csp(self, csp_text):
+ # See https://www.w3.org/TR/CSP/#parse-serialized-policy.
+        # The steps below follow the algorithm stated in the spec above.
+        # Deviations:
+        # - it does not warn about and ignore duplicate directives (Step 2.5)
+
+ # Step 2
+ csp_srcs = csp_text.split(";")
+
+ policy = []
+ for token in csp_srcs:
+ # Step 2.1
+ token = token.strip()
+
+ if not token:
+ # Step 2.2
+ continue
+
+ # Step 2.3
+ token_frag = token.split(None, 1)
+ name = token_frag[0]
+
+ values = token_frag[1] if len(token_frag) == 2 else ""
+
+ # Step 2.4
+ name = name.lower()
+
+ # Step 2.6
+ value = values.split()
+ # Step 2.7
+ directive = Directive(name=name, value=value)
+ # Step 2.8
+ policy.append(directive)
+
+ return policy
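
A minimal sketch of wrapping a WSGI app with the validator; the inner app is hypothetical and deliberately omits the nosniff header so the middleware logs its warning.

    import werkzeug
    from werkzeug import test as werkzeug_test

    from tensorboard.backend import security_validator


    @werkzeug.Request.application
    def inner_app(request):
        # No X-Content-Type-Options header, so the validator will complain.
        return werkzeug.Response("ok", content_type="text/plain")


    app = security_validator.SecurityValidatorMiddleware(inner_app)
    client = werkzeug_test.Client(app, werkzeug.BaseResponse)
    client.get("/")  # warns: X-Content-Type-Options is required to be "nosniff"
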
diff --git a/tensorboard/backend/security_validator_test.py b/tensorboard/backend/security_validator_test.py
index 9a972c9685..a3cd343746 100644
--- a/tensorboard/backend/security_validator_test.py
+++ b/tensorboard/backend/security_validator_test.py
@@ -19,10 +19,10 @@
from __future__ import print_function
try:
- # python version >= 3.3
- from unittest import mock
+ # python version >= 3.3
+ from unittest import mock
except ImportError:
- import mock # pylint: disable=unused-import
+ import mock # pylint: disable=unused-import
import werkzeug
from werkzeug import test as werkzeug_test
@@ -43,159 +43,154 @@ def create_headers(
x_content_type_options="nosniff",
content_security_policy="",
):
- return Headers(
- {
- "Content-Type": content_type,
- "X-Content-Type-Options": x_content_type_options,
- "Content-Security-Policy": content_security_policy,
- }
- )
-
-
-class SecurityValidatorMiddlewareTest(tb_test.TestCase):
- """Tests for `SecurityValidatorMiddleware`."""
-
- def make_request_and_maybe_assert_warn(
- self,
- headers,
- expected_warn_substr,
- ):
-
- @werkzeug.Request.application
- def _simple_app(req):
- return werkzeug.Response("OK", headers=headers)
-
- app = security_validator.SecurityValidatorMiddleware(_simple_app)
- server = werkzeug_test.Client(app, BaseResponse)
-
- with mock.patch.object(logger, "warn") as mock_warn:
- server.get("")
-
- if expected_warn_substr is None:
- mock_warn.assert_not_called()
- else:
- mock_warn.assert_called_with(_WARN_PREFIX + expected_warn_substr)
-
- def make_request_and_assert_no_warn(
- self,
- headers,
- ):
- self.make_request_and_maybe_assert_warn(headers, None)
-
- def test_validate_content_type(self):
- self.make_request_and_assert_no_warn(
- create_headers(content_type="application/json"),
+ return Headers(
+ {
+ "Content-Type": content_type,
+ "X-Content-Type-Options": x_content_type_options,
+ "Content-Security-Policy": content_security_policy,
+ }
)
- self.make_request_and_maybe_assert_warn(
- create_headers(content_type=""),
- "Content-Type is required on a Response"
- )
- def test_validate_x_content_type_options(self):
- self.make_request_and_assert_no_warn(
- create_headers(x_content_type_options="nosniff")
- )
-
- self.make_request_and_maybe_assert_warn(
- create_headers(x_content_type_options=""),
- 'X-Content-Type-Options is required to be "nosniff"',
- )
-
- def test_validate_csp_text_html(self):
- self.make_request_and_assert_no_warn(
- create_headers(
- content_type="text/html; charset=UTF-8",
- content_security_policy=(
- "DEFAult-src 'self';script-src https://google.com;"
- "style-src 'self' https://example; object-src "
+class SecurityValidatorMiddlewareTest(tb_test.TestCase):
+ """Tests for `SecurityValidatorMiddleware`."""
+
+ def make_request_and_maybe_assert_warn(
+ self, headers, expected_warn_substr,
+ ):
+ @werkzeug.Request.application
+ def _simple_app(req):
+ return werkzeug.Response("OK", headers=headers)
+
+ app = security_validator.SecurityValidatorMiddleware(_simple_app)
+ server = werkzeug_test.Client(app, BaseResponse)
+
+ with mock.patch.object(logger, "warn") as mock_warn:
+ server.get("")
+
+ if expected_warn_substr is None:
+ mock_warn.assert_not_called()
+ else:
+ mock_warn.assert_called_with(_WARN_PREFIX + expected_warn_substr)
+
+ def make_request_and_assert_no_warn(
+ self, headers,
+ ):
+ self.make_request_and_maybe_assert_warn(headers, None)
+
+ def test_validate_content_type(self):
+ self.make_request_and_assert_no_warn(
+ create_headers(content_type="application/json"),
+ )
+
+ self.make_request_and_maybe_assert_warn(
+ create_headers(content_type=""),
+ "Content-Type is required on a Response",
+ )
+
+ def test_validate_x_content_type_options(self):
+ self.make_request_and_assert_no_warn(
+ create_headers(x_content_type_options="nosniff")
+ )
+
+ self.make_request_and_maybe_assert_warn(
+ create_headers(x_content_type_options=""),
+ 'X-Content-Type-Options is required to be "nosniff"',
+ )
+
+ def test_validate_csp_text_html(self):
+ self.make_request_and_assert_no_warn(
+ create_headers(
+ content_type="text/html; charset=UTF-8",
+ content_security_policy=(
+ "DEFAult-src 'self';script-src https://google.com;"
+ "style-src 'self' https://example; object-src "
+ ),
),
- ),
- )
+ )
- self.make_request_and_maybe_assert_warn(
- create_headers(
- content_type="text/html; charset=UTF-8",
- content_security_policy="",
- ),
- "Requires default-src for Content-Security-Policy",
- )
+ self.make_request_and_maybe_assert_warn(
+ create_headers(
+ content_type="text/html; charset=UTF-8",
+ content_security_policy="",
+ ),
+ "Requires default-src for Content-Security-Policy",
+ )
- self.make_request_and_maybe_assert_warn(
- create_headers(
- content_type="text/html; charset=UTF-8",
- content_security_policy="default-src *",
- ),
- "Illegal Content-Security-Policy for default-src: *",
- )
+ self.make_request_and_maybe_assert_warn(
+ create_headers(
+ content_type="text/html; charset=UTF-8",
+ content_security_policy="default-src *",
+ ),
+ "Illegal Content-Security-Policy for default-src: *",
+ )
- self.make_request_and_maybe_assert_warn(
- create_headers(
- content_type="text/html; charset=UTF-8",
- content_security_policy="default-src 'self';script-src *",
- ),
- "Illegal Content-Security-Policy for script-src: *",
- )
+ self.make_request_and_maybe_assert_warn(
+ create_headers(
+ content_type="text/html; charset=UTF-8",
+ content_security_policy="default-src 'self';script-src *",
+ ),
+ "Illegal Content-Security-Policy for script-src: *",
+ )
+
+ self.make_request_and_maybe_assert_warn(
+ create_headers(
+ content_type="text/html; charset=UTF-8",
+ content_security_policy=(
+ "script-src * 'sha256-foo' 'nonce-bar';"
+ "style-src http://google.com;object-src *;"
+ "img-src 'unsafe-inline';default-src 'self';"
+ "script-src * 'strict-dynamic'"
+ ),
+ ),
+ "\n".join(
+ [
+ "Illegal Content-Security-Policy for script-src: *",
+ "Illegal Content-Security-Policy for script-src: 'nonce-bar'",
+ "Illegal Content-Security-Policy for style-src: http://google.com",
+ "Illegal Content-Security-Policy for object-src: *",
+ "Illegal Content-Security-Policy for img-src: 'unsafe-inline'",
+ "Illegal Content-Security-Policy for script-src: *",
+ "Illegal Content-Security-Policy for script-src: 'strict-dynamic'",
+ ]
+ ),
+ )
- self.make_request_and_maybe_assert_warn(
- create_headers(
+ def test_validate_csp_multiple_csp_headers(self):
+ base_headers = create_headers(
content_type="text/html; charset=UTF-8",
content_security_policy=(
- "script-src * 'sha256-foo' 'nonce-bar';"
- "style-src http://google.com;object-src *;"
- "img-src 'unsafe-inline';default-src 'self';"
- "script-src * 'strict-dynamic'"
+ "script-src * 'sha256-foo';" "style-src http://google.com"
),
- ),
- "\n".join(
- [
- "Illegal Content-Security-Policy for script-src: *",
- "Illegal Content-Security-Policy for script-src: 'nonce-bar'",
- "Illegal Content-Security-Policy for style-src: http://google.com",
- "Illegal Content-Security-Policy for object-src: *",
- "Illegal Content-Security-Policy for img-src: 'unsafe-inline'",
- "Illegal Content-Security-Policy for script-src: *",
- "Illegal Content-Security-Policy for script-src: 'strict-dynamic'",
- ]
- ),
- )
-
- def test_validate_csp_multiple_csp_headers(self):
- base_headers = create_headers(
- content_type="text/html; charset=UTF-8",
- content_security_policy=(
- "script-src * 'sha256-foo';"
- "style-src http://google.com"
- ),
- )
- base_headers.add(
- "Content-Security-Policy",
- "default-src 'self';script-src 'nonce-bar';object-src *",
- )
-
- self.make_request_and_maybe_assert_warn(
- base_headers,
- "\n".join(
- [
- "Illegal Content-Security-Policy for script-src: *",
- "Illegal Content-Security-Policy for style-src: http://google.com",
- "Illegal Content-Security-Policy for script-src: 'nonce-bar'",
- "Illegal Content-Security-Policy for object-src: *",
- ]
- ),
- )
-
- def test_validate_csp_non_text_html(self):
- self.make_request_and_assert_no_warn(
- create_headers(
- content_type="application/xhtml",
- content_security_policy=(
- "script-src * 'sha256-foo' 'nonce-bar';"
- "style-src http://google.com;object-src *;"
+ )
+ base_headers.add(
+ "Content-Security-Policy",
+ "default-src 'self';script-src 'nonce-bar';object-src *",
+ )
+
+ self.make_request_and_maybe_assert_warn(
+ base_headers,
+ "\n".join(
+ [
+ "Illegal Content-Security-Policy for script-src: *",
+ "Illegal Content-Security-Policy for style-src: http://google.com",
+ "Illegal Content-Security-Policy for script-src: 'nonce-bar'",
+ "Illegal Content-Security-Policy for object-src: *",
+ ]
),
- ),
- )
+ )
+
+ def test_validate_csp_non_text_html(self):
+ self.make_request_and_assert_no_warn(
+ create_headers(
+ content_type="application/xhtml",
+ content_security_policy=(
+ "script-src * 'sha256-foo' 'nonce-bar';"
+ "style-src http://google.com;object-src *;"
+ ),
+ ),
+ )
if __name__ == "__main__":
- tb_test.main()
+ tb_test.main()
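The tests above exercise a response-header validator wrapped as WSGI middleware: it warns when a response omits Content-Type, does not send `X-Content-Type-Options: nosniff`, or carries a Content-Security-Policy whose directives allow overly broad sources. As a rough illustration of the directive scan those warnings imply (a hypothetical sketch, not the actual `security_validator` implementation), the checks amount to something like:

    # Hypothetical sketch of a CSP scan that would yield warnings of the form
    # "Illegal Content-Security-Policy for script-src: *".
    _DISALLOWED_SOURCES = {"*", "'unsafe-inline'", "'unsafe-eval'"}

    def find_bad_csp_sources(csp_value):
        """Yield (directive, source) pairs that a validator might warn about."""
        for directive in csp_value.split(";"):
            parts = directive.strip().split()
            if not parts:
                continue
            name = parts[0].lower()  # directive names are case-insensitive
            for source in parts[1:]:
                if source in _DISALLOWED_SOURCES or source.startswith("http://"):
                    yield (name, source)

    # list(find_bad_csp_sources("default-src 'self';script-src *"))
    # -> [("script-src", "*")]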
diff --git a/tensorboard/compat/__init__.py b/tensorboard/compat/__init__.py
index 07ee19b17c..78cd610e9c 100644
--- a/tensorboard/compat/__init__.py
+++ b/tensorboard/compat/__init__.py
@@ -14,9 +14,9 @@
"""Compatibility interfaces for TensorBoard.
-This module provides logic for importing variations on the TensorFlow APIs, as
-lazily loaded imports to help avoid circular dependency issues and defer the
-search and loading of the module until necessary.
+This module provides logic for importing variations on the TensorFlow
+APIs, as lazily loaded imports to help avoid circular dependency issues
+and defer the search and loading of the module until necessary.
"""
from __future__ import absolute_import
@@ -28,78 +28,82 @@
import tensorboard.lazy as _lazy
-@_lazy.lazy_load('tensorboard.compat.tf')
+@_lazy.lazy_load("tensorboard.compat.tf")
def tf():
- """Provide the root module of a TF-like API for use within TensorBoard.
-
- By default this is equivalent to `import tensorflow as tf`, but it can be used
- in combination with //tensorboard/compat:tensorflow (to fall back to a stub TF
- API implementation if the real one is not available) or with
- //tensorboard/compat:no_tensorflow (to force unconditional use of the stub).
-
- Returns:
- The root module of a TF-like API, if available.
-
- Raises:
- ImportError: if a TF-like API is not available.
- """
- try:
- from tensorboard.compat import notf
- except ImportError:
+ """Provide the root module of a TF-like API for use within TensorBoard.
+
+ By default this is equivalent to `import tensorflow as tf`, but it can be used
+ in combination with //tensorboard/compat:tensorflow (to fall back to a stub TF
+ API implementation if the real one is not available) or with
+ //tensorboard/compat:no_tensorflow (to force unconditional use of the stub).
+
+ Returns:
+ The root module of a TF-like API, if available.
+
+ Raises:
+ ImportError: if a TF-like API is not available.
+ """
try:
- import tensorflow
- return tensorflow
+ from tensorboard.compat import notf
except ImportError:
- pass
- from tensorboard.compat import tensorflow_stub
- return tensorflow_stub
+ try:
+ import tensorflow
+
+ return tensorflow
+ except ImportError:
+ pass
+ from tensorboard.compat import tensorflow_stub
+ return tensorflow_stub
-@_lazy.lazy_load('tensorboard.compat.tf2')
+
+@_lazy.lazy_load("tensorboard.compat.tf2")
def tf2():
- """Provide the root module of a TF-2.0 API for use within TensorBoard.
+ """Provide the root module of a TF-2.0 API for use within TensorBoard.
- Returns:
- The root module of a TF-2.0 API, if available.
+ Returns:
+ The root module of a TF-2.0 API, if available.
- Raises:
- ImportError: if a TF-2.0 API is not available.
- """
- # Import the `tf` compat API from this file and check if it's already TF 2.0.
- if tf.__version__.startswith('2.'):
- return tf
- elif hasattr(tf, 'compat') and hasattr(tf.compat, 'v2'):
- # As a fallback, try `tensorflow.compat.v2` if it's defined.
- return tf.compat.v2
- raise ImportError('cannot import tensorflow 2.0 API')
+ Raises:
+ ImportError: if a TF-2.0 API is not available.
+ """
+ # Import the `tf` compat API from this file and check if it's already TF 2.0.
+ if tf.__version__.startswith("2."):
+ return tf
+ elif hasattr(tf, "compat") and hasattr(tf.compat, "v2"):
+ # As a fallback, try `tensorflow.compat.v2` if it's defined.
+ return tf.compat.v2
+ raise ImportError("cannot import tensorflow 2.0 API")
# TODO(https://github.com/tensorflow/tensorboard/issues/1711): remove this
-@_lazy.lazy_load('tensorboard.compat._pywrap_tensorflow')
+@_lazy.lazy_load("tensorboard.compat._pywrap_tensorflow")
def _pywrap_tensorflow():
- """Provide pywrap_tensorflow access in TensorBoard.
+ """Provide pywrap_tensorflow access in TensorBoard.
- pywrap_tensorflow cannot be accessed from tf.python.pywrap_tensorflow
- and needs to be imported using
- `from tensorflow.python import pywrap_tensorflow`. Therefore, we provide
- a separate accessor function for it here.
+ pywrap_tensorflow cannot be accessed from tf.python.pywrap_tensorflow
+ and needs to be imported using
+ `from tensorflow.python import pywrap_tensorflow`. Therefore, we provide
+ a separate accessor function for it here.
- NOTE: pywrap_tensorflow is not part of TensorFlow API and this
- dependency will go away soon.
+ NOTE: pywrap_tensorflow is not part of TensorFlow API and this
+ dependency will go away soon.
- Returns:
- pywrap_tensorflow import, if available.
+ Returns:
+ pywrap_tensorflow import, if available.
- Raises:
- ImportError: if we couldn't import pywrap_tensorflow.
- """
- try:
- from tensorboard.compat import notf
- except ImportError:
+ Raises:
+ ImportError: if we couldn't import pywrap_tensorflow.
+ """
try:
- from tensorflow.python import pywrap_tensorflow
- return pywrap_tensorflow
+ from tensorboard.compat import notf
except ImportError:
- pass
- from tensorboard.compat.tensorflow_stub import pywrap_tensorflow
- return pywrap_tensorflow
+ try:
+ from tensorflow.python import pywrap_tensorflow
+
+ return pywrap_tensorflow
+ except ImportError:
+ pass
+ from tensorboard.compat.tensorflow_stub import pywrap_tensorflow
+
+ return pywrap_tensorflow
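Elsewhere in TensorBoard these lazy modules are consumed instead of importing TensorFlow directly, which is what lets the same code run against real TensorFlow or the bundled stub. A minimal usage sketch (the fallback order is the one described in the docstrings above):

    from tensorboard.compat import tf  # resolved lazily on first attribute access

    def describe_tf_backend():
        # Resolves to real TensorFlow when it is importable and notf is absent;
        # otherwise to tensorboard.compat.tensorflow_stub, whose __version__ is
        # the sentinel string "stub" (see the stub __init__ further below).
        return "TF-like API version: %s" % tf.__version__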
diff --git a/tensorboard/compat/proto/proto_test.py b/tensorboard/compat/proto/proto_test.py
index 49885f2b0f..d05559b6df 100644
--- a/tensorboard/compat/proto/proto_test.py
+++ b/tensorboard/compat/proto/proto_test.py
@@ -15,8 +15,8 @@
"""Proto match tests between `tensorboard.compat.proto` and TensorFlow.
These tests verify that the local copy of TensorFlow protos are the same
-as those available directly from TensorFlow. Local protos are used to build
-`tensorboard-notf` without a TensorFlow dependency.
+as those available directly from TensorFlow. Local protos are used to
+build `tensorboard-notf` without a TensorFlow dependency.
"""
from __future__ import absolute_import
@@ -32,104 +32,162 @@
# Keep this list synced with BUILD in current directory
PROTO_IMPORTS = [
- ('tensorflow.core.framework.allocation_description_pb2',
- 'tensorboard.compat.proto.allocation_description_pb2'),
- ('tensorflow.core.framework.api_def_pb2',
- 'tensorboard.compat.proto.api_def_pb2'),
- ('tensorflow.core.framework.attr_value_pb2',
- 'tensorboard.compat.proto.attr_value_pb2'),
- ('tensorflow.core.protobuf.cluster_pb2',
- 'tensorboard.compat.proto.cluster_pb2'),
- ('tensorflow.core.protobuf.config_pb2',
- 'tensorboard.compat.proto.config_pb2'),
- ('tensorflow.core.framework.cost_graph_pb2',
- 'tensorboard.compat.proto.cost_graph_pb2'),
- ('tensorflow.python.framework.cpp_shape_inference_pb2',
- 'tensorboard.compat.proto.cpp_shape_inference_pb2'),
- ('tensorflow.core.protobuf.debug_pb2',
- 'tensorboard.compat.proto.debug_pb2'),
- ('tensorflow.core.util.event_pb2',
- 'tensorboard.compat.proto.event_pb2'),
- ('tensorflow.core.framework.function_pb2',
- 'tensorboard.compat.proto.function_pb2'),
- ('tensorflow.core.framework.graph_pb2',
- 'tensorboard.compat.proto.graph_pb2'),
- ('tensorflow.core.protobuf.meta_graph_pb2',
- 'tensorboard.compat.proto.meta_graph_pb2'),
- ('tensorflow.core.framework.node_def_pb2',
- 'tensorboard.compat.proto.node_def_pb2'),
- ('tensorflow.core.framework.op_def_pb2',
- 'tensorboard.compat.proto.op_def_pb2'),
- ('tensorflow.core.framework.resource_handle_pb2',
- 'tensorboard.compat.proto.resource_handle_pb2'),
- ('tensorflow.core.protobuf.rewriter_config_pb2',
- 'tensorboard.compat.proto.rewriter_config_pb2'),
- ('tensorflow.core.protobuf.saved_object_graph_pb2',
- 'tensorboard.compat.proto.saved_object_graph_pb2'),
- ('tensorflow.core.protobuf.saver_pb2',
- 'tensorboard.compat.proto.saver_pb2'),
- ('tensorflow.core.framework.step_stats_pb2',
- 'tensorboard.compat.proto.step_stats_pb2'),
- ('tensorflow.core.protobuf.struct_pb2',
- 'tensorboard.compat.proto.struct_pb2'),
- ('tensorflow.core.framework.summary_pb2',
- 'tensorboard.compat.proto.summary_pb2'),
- ('tensorflow.core.framework.tensor_pb2',
- 'tensorboard.compat.proto.tensor_pb2'),
- ('tensorflow.core.framework.tensor_description_pb2',
- 'tensorboard.compat.proto.tensor_description_pb2'),
- ('tensorflow.core.framework.tensor_shape_pb2',
- 'tensorboard.compat.proto.tensor_shape_pb2'),
- ('tensorflow.core.profiler.tfprof_log_pb2',
- 'tensorboard.compat.proto.tfprof_log_pb2'),
- ('tensorflow.core.protobuf.trackable_object_graph_pb2',
- 'tensorboard.compat.proto.trackable_object_graph_pb2'),
- ('tensorflow.core.framework.types_pb2',
- 'tensorboard.compat.proto.types_pb2'),
- ('tensorflow.core.framework.variable_pb2',
- 'tensorboard.compat.proto.variable_pb2'),
- ('tensorflow.core.framework.versions_pb2',
- 'tensorboard.compat.proto.versions_pb2'),
+ (
+ "tensorflow.core.framework.allocation_description_pb2",
+ "tensorboard.compat.proto.allocation_description_pb2",
+ ),
+ (
+ "tensorflow.core.framework.api_def_pb2",
+ "tensorboard.compat.proto.api_def_pb2",
+ ),
+ (
+ "tensorflow.core.framework.attr_value_pb2",
+ "tensorboard.compat.proto.attr_value_pb2",
+ ),
+ (
+ "tensorflow.core.protobuf.cluster_pb2",
+ "tensorboard.compat.proto.cluster_pb2",
+ ),
+ (
+ "tensorflow.core.protobuf.config_pb2",
+ "tensorboard.compat.proto.config_pb2",
+ ),
+ (
+ "tensorflow.core.framework.cost_graph_pb2",
+ "tensorboard.compat.proto.cost_graph_pb2",
+ ),
+ (
+ "tensorflow.python.framework.cpp_shape_inference_pb2",
+ "tensorboard.compat.proto.cpp_shape_inference_pb2",
+ ),
+ (
+ "tensorflow.core.protobuf.debug_pb2",
+ "tensorboard.compat.proto.debug_pb2",
+ ),
+ ("tensorflow.core.util.event_pb2", "tensorboard.compat.proto.event_pb2"),
+ (
+ "tensorflow.core.framework.function_pb2",
+ "tensorboard.compat.proto.function_pb2",
+ ),
+ (
+ "tensorflow.core.framework.graph_pb2",
+ "tensorboard.compat.proto.graph_pb2",
+ ),
+ (
+ "tensorflow.core.protobuf.meta_graph_pb2",
+ "tensorboard.compat.proto.meta_graph_pb2",
+ ),
+ (
+ "tensorflow.core.framework.node_def_pb2",
+ "tensorboard.compat.proto.node_def_pb2",
+ ),
+ (
+ "tensorflow.core.framework.op_def_pb2",
+ "tensorboard.compat.proto.op_def_pb2",
+ ),
+ (
+ "tensorflow.core.framework.resource_handle_pb2",
+ "tensorboard.compat.proto.resource_handle_pb2",
+ ),
+ (
+ "tensorflow.core.protobuf.rewriter_config_pb2",
+ "tensorboard.compat.proto.rewriter_config_pb2",
+ ),
+ (
+ "tensorflow.core.protobuf.saved_object_graph_pb2",
+ "tensorboard.compat.proto.saved_object_graph_pb2",
+ ),
+ (
+ "tensorflow.core.protobuf.saver_pb2",
+ "tensorboard.compat.proto.saver_pb2",
+ ),
+ (
+ "tensorflow.core.framework.step_stats_pb2",
+ "tensorboard.compat.proto.step_stats_pb2",
+ ),
+ (
+ "tensorflow.core.protobuf.struct_pb2",
+ "tensorboard.compat.proto.struct_pb2",
+ ),
+ (
+ "tensorflow.core.framework.summary_pb2",
+ "tensorboard.compat.proto.summary_pb2",
+ ),
+ (
+ "tensorflow.core.framework.tensor_pb2",
+ "tensorboard.compat.proto.tensor_pb2",
+ ),
+ (
+ "tensorflow.core.framework.tensor_description_pb2",
+ "tensorboard.compat.proto.tensor_description_pb2",
+ ),
+ (
+ "tensorflow.core.framework.tensor_shape_pb2",
+ "tensorboard.compat.proto.tensor_shape_pb2",
+ ),
+ (
+ "tensorflow.core.profiler.tfprof_log_pb2",
+ "tensorboard.compat.proto.tfprof_log_pb2",
+ ),
+ (
+ "tensorflow.core.protobuf.trackable_object_graph_pb2",
+ "tensorboard.compat.proto.trackable_object_graph_pb2",
+ ),
+ (
+ "tensorflow.core.framework.types_pb2",
+ "tensorboard.compat.proto.types_pb2",
+ ),
+ (
+ "tensorflow.core.framework.variable_pb2",
+ "tensorboard.compat.proto.variable_pb2",
+ ),
+ (
+ "tensorflow.core.framework.versions_pb2",
+ "tensorboard.compat.proto.versions_pb2",
+ ),
]
PROTO_REPLACEMENTS = [
- ('tensorflow/core/framework/', 'tensorboard/compat/proto/'),
- ('tensorflow/core/protobuf/', 'tensorboard/compat/proto/'),
- ('tensorflow/core/profiler/', 'tensorboard/compat/proto/'),
- ('tensorflow/python/framework/', 'tensorboard/compat/proto/'),
- ('tensorflow/core/util/', 'tensorboard/compat/proto/'),
- ('package: "tensorflow.tfprof"', 'package: "tensorboard"'),
- ('package: "tensorflow"', 'package: "tensorboard"'),
- ('type_name: ".tensorflow.tfprof', 'type_name: ".tensorboard'),
- ('type_name: ".tensorflow', 'type_name: ".tensorboard'),
+ ("tensorflow/core/framework/", "tensorboard/compat/proto/"),
+ ("tensorflow/core/protobuf/", "tensorboard/compat/proto/"),
+ ("tensorflow/core/profiler/", "tensorboard/compat/proto/"),
+ ("tensorflow/python/framework/", "tensorboard/compat/proto/"),
+ ("tensorflow/core/util/", "tensorboard/compat/proto/"),
+ ('package: "tensorflow.tfprof"', 'package: "tensorboard"'),
+ ('package: "tensorflow"', 'package: "tensorboard"'),
+ ('type_name: ".tensorflow.tfprof', 'type_name: ".tensorboard'),
+ ('type_name: ".tensorflow', 'type_name: ".tensorboard'),
]
class ProtoMatchTest(tf.test.TestCase):
+ def test_each_proto_matches_tensorflow(self):
+ for tf_path, tb_path in PROTO_IMPORTS:
+ tf_pb2 = importlib.import_module(tf_path)
+ tb_pb2 = importlib.import_module(tb_path)
+ expected = descriptor_pb2.FileDescriptorProto()
+ actual = descriptor_pb2.FileDescriptorProto()
+ tf_pb2.DESCRIPTOR.CopyToProto(expected)
+ tb_pb2.DESCRIPTOR.CopyToProto(actual)
- def test_each_proto_matches_tensorflow(self):
- for tf_path, tb_path in PROTO_IMPORTS:
- tf_pb2 = importlib.import_module(tf_path)
- tb_pb2 = importlib.import_module(tb_path)
- expected = descriptor_pb2.FileDescriptorProto()
- actual = descriptor_pb2.FileDescriptorProto()
- tf_pb2.DESCRIPTOR.CopyToProto(expected)
- tb_pb2.DESCRIPTOR.CopyToProto(actual)
+ # Convert expected to be actual since this matches the
+ # replacements done in proto/update.sh
+ actual = str(actual)
+ expected = str(expected)
+ for orig, repl in PROTO_REPLACEMENTS:
+ expected = expected.replace(orig, repl)
- # Convert expected to be actual since this matches the
- # replacements done in proto/update.sh
- actual = str(actual)
- expected = str(expected)
- for orig, repl in PROTO_REPLACEMENTS:
- expected = expected.replace(orig, repl)
+ diff = difflib.unified_diff(
+ actual.splitlines(1), expected.splitlines(1)
+ )
+ diff = "".join(diff)
- diff = difflib.unified_diff(actual.splitlines(1),
- expected.splitlines(1))
- diff = ''.join(diff)
+ self.assertEquals(
+ diff,
+ "",
+ "{} and {} did not match:\n{}".format(tf_path, tb_path, diff),
+ )
- self.assertEquals(diff, '',
- '{} and {} did not match:\n{}'.format(tf_path, tb_path, diff))
-
-if __name__ == '__main__':
- tf.test.main()
+if __name__ == "__main__":
+ tf.test.main()
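For reference, the `PROTO_REPLACEMENTS` table exists so the text form of TensorFlow's descriptors can be rewritten onto the vendored paths and package names before the comparison. A self-contained illustration using a subset of the table:

    # How the replacement table maps TF descriptor text onto the vendored copies.
    replacements = [
        ("tensorflow/core/framework/", "tensorboard/compat/proto/"),
        ('package: "tensorflow"', 'package: "tensorboard"'),
    ]
    text = 'name: "tensorflow/core/framework/summary.proto"\npackage: "tensorflow"'
    for orig, repl in replacements:
        text = text.replace(orig, repl)
    assert text == (
        'name: "tensorboard/compat/proto/summary.proto"\npackage: "tensorboard"'
    )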
diff --git a/tensorboard/compat/tensorflow_stub/__init__.py b/tensorboard/compat/tensorflow_stub/__init__.py
index b94f654a28..8b3cf3ffa2 100644
--- a/tensorboard/compat/tensorflow_stub/__init__.py
+++ b/tensorboard/compat/tensorflow_stub/__init__.py
@@ -38,4 +38,4 @@
compat.v1.errors = errors
# Set a fake __version__ to help distinguish this as our own stub API.
-__version__ = 'stub'
+__version__ = "stub"
diff --git a/tensorboard/compat/tensorflow_stub/app.py b/tensorboard/compat/tensorflow_stub/app.py
index b49e6e47b4..6c3f265a41 100644
--- a/tensorboard/compat/tensorflow_stub/app.py
+++ b/tensorboard/compat/tensorflow_stub/app.py
@@ -31,13 +31,13 @@ def _usage(shorthelp):
shorthelp: bool, if True, prints only flags from the main module,
rather than all flags.
"""
- doc = _sys.modules['__main__'].__doc__
+ doc = _sys.modules["__main__"].__doc__
if not doc:
- doc = '\nUSAGE: %s [flags]\n' % _sys.argv[0]
- doc = flags.text_wrap(doc, indent=' ', firstline_indent='')
+ doc = "\nUSAGE: %s [flags]\n" % _sys.argv[0]
+ doc = flags.text_wrap(doc, indent=" ", firstline_indent="")
else:
# Replace all '%s' with sys.argv[0], and all '%%' with '%'.
- num_specifiers = doc.count('%') - 2 * doc.count('%%')
+ num_specifiers = doc.count("%") - 2 * doc.count("%%")
try:
doc %= (_sys.argv[0],) * num_specifiers
except (OverflowError, TypeError, ValueError):
@@ -50,9 +50,9 @@ def _usage(shorthelp):
try:
_sys.stdout.write(doc)
if flag_str:
- _sys.stdout.write('\nflags:\n')
+ _sys.stdout.write("\nflags:\n")
_sys.stdout.write(flag_str)
- _sys.stdout.write('\n')
+ _sys.stdout.write("\n")
except IOError as e:
# We avoid printing a huge backtrace if we get EPIPE, because
# "foo.par --help | less" is a frequent use case.
@@ -62,24 +62,27 @@ def _usage(shorthelp):
class _HelpFlag(flags.BooleanFlag):
"""Special boolean flag that displays usage and raises SystemExit."""
- NAME = 'help'
- SHORT_NAME = 'h'
+
+ NAME = "help"
+ SHORT_NAME = "h"
def __init__(self):
super(_HelpFlag, self).__init__(
- self.NAME, False, 'show this help', short_name=self.SHORT_NAME)
+ self.NAME, False, "show this help", short_name=self.SHORT_NAME
+ )
def parse(self, arg):
if arg:
_usage(shorthelp=True)
print()
- print('Try --helpfull to get a list of all flags.')
+ print("Try --helpfull to get a list of all flags.")
_sys.exit(1)
class _HelpshortFlag(_HelpFlag):
"""--helpshort is an alias for --help."""
- NAME = 'helpshort'
+
+ NAME = "helpshort"
SHORT_NAME = None
@@ -87,12 +90,7 @@ class _HelpfullFlag(flags.BooleanFlag):
"""Display help for flags in main module and all dependent modules."""
def __init__(self):
- super(
- _HelpfullFlag,
- self).__init__(
- 'helpfull',
- False,
- 'show full help')
+ super(_HelpfullFlag, self).__init__("helpfull", False, "show full help")
def parse(self, arg):
if arg:
@@ -122,7 +120,7 @@ def run(main=None, argv=None):
# Parse known flags.
argv = flags.FLAGS(_sys.argv if argv is None else argv, known_only=True)
- main = main or _sys.modules['__main__'].main
+ main = main or _sys.modules["__main__"].main
# Call the main function, passing through any arguments
# to the final program.
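The `run()` entry point keeps the familiar `tf.compat.v1.app.run` contract: parse known flags, then call `main` from the `__main__` module (or the one passed in), forwarding the remaining arguments. A small usage sketch against the stub (assuming flags are defined before `run()` parses them, as with absl):

    from absl import flags
    from tensorboard.compat.tensorflow_stub import app

    flags.DEFINE_string("logdir", "", "Directory of event files to read.")

    def main(argv):
        # argv holds the program name plus whatever non-flag arguments remain
        # after known-only parsing.
        print("logdir=%s" % flags.FLAGS.logdir)

    if __name__ == "__main__":
        app.run(main)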
diff --git a/tensorboard/compat/tensorflow_stub/compat/__init__.py b/tensorboard/compat/tensorflow_stub/compat/__init__.py
index a7dd0d582e..1938e8d65c 100644
--- a/tensorboard/compat/tensorflow_stub/compat/__init__.py
+++ b/tensorboard/compat/tensorflow_stub/compat/__init__.py
@@ -39,7 +39,8 @@
def as_bytes(bytes_or_text, encoding="utf-8"):
- """Converts either bytes or unicode to `bytes`, using utf-8 encoding for text.
+ """Converts either bytes or unicode to `bytes`, using utf-8 encoding for
+ text.
Args:
bytes_or_text: A `bytes`, `str`, or `unicode` object.
@@ -56,7 +57,9 @@ def as_bytes(bytes_or_text, encoding="utf-8"):
elif isinstance(bytes_or_text, bytes):
return bytes_or_text
else:
- raise TypeError("Expected binary or unicode string, got %r" % (bytes_or_text,))
+ raise TypeError(
+ "Expected binary or unicode string, got %r" % (bytes_or_text,)
+ )
def as_text(bytes_or_text, encoding="utf-8"):
@@ -77,7 +80,9 @@ def as_text(bytes_or_text, encoding="utf-8"):
elif isinstance(bytes_or_text, bytes):
return bytes_or_text.decode(encoding)
else:
- raise TypeError("Expected binary or unicode string, got %r" % bytes_or_text)
+ raise TypeError(
+ "Expected binary or unicode string, got %r" % bytes_or_text
+ )
# Convert an object to a `str` in both Python 2 and 3.
@@ -109,7 +114,8 @@ def as_str_any(value):
# @tf_export('compat.path_to_str')
def path_to_str(path):
- """Returns the file system path representation of a `PathLike` object, else as it is.
+ """Returns the file system path representation of a `PathLike` object, else
+ as it is.
Args:
path: An object that can be converted to path representation.
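Nothing changes behaviorally in the reflowed helpers above; for reference, `as_bytes` and `as_text` round-trip between `bytes` and text (UTF-8 by default) and reject other types:

    from tensorboard.compat.tensorflow_stub import compat

    assert compat.as_bytes(u"logdir/\u00e9") == b"logdir/\xc3\xa9"
    assert compat.as_text(b"logdir/\xc3\xa9") == u"logdir/\u00e9"
    assert compat.as_bytes(b"raw") == b"raw"  # bytes pass through unchanged
    # compat.as_bytes(3.14) raises:
    #   TypeError: Expected binary or unicode string, got 3.14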
diff --git a/tensorboard/compat/tensorflow_stub/dtypes.py b/tensorboard/compat/tensorflow_stub/dtypes.py
index 9115a5b8a0..6a1f864123 100644
--- a/tensorboard/compat/tensorflow_stub/dtypes.py
+++ b/tensorboard/compat/tensorflow_stub/dtypes.py
@@ -74,7 +74,6 @@ def __init__(self, type_enum):
Raises:
TypeError: If `type_enum` is not a value `types_pb2.DataType`.
-
"""
# TODO(mrry): Make the necessary changes (using __new__) to ensure
# that calling this returns one of the interned values.
@@ -136,7 +135,7 @@ def as_datatype_enum(self):
@property
def is_bool(self):
- """Returns whether this is a boolean data type"""
+ """Returns whether this is a boolean data type."""
return self.base_dtype == bool
@property
@@ -150,9 +149,11 @@ def is_integer(self):
@property
def is_floating(self):
- """Returns whether this is a (non-quantized, real) floating point type."""
+ """Returns whether this is a (non-quantized, real) floating point
+ type."""
return (
- self.is_numpy_compatible and np.issubdtype(self.as_numpy_dtype, np.floating)
+ self.is_numpy_compatible
+ and np.issubdtype(self.as_numpy_dtype, np.floating)
) or self.base_dtype == bfloat16
@property
@@ -186,7 +187,6 @@ def min(self):
Raises:
TypeError: if this is a non-numeric, unordered, or quantized type.
-
"""
if self.is_quantized or self.base_dtype in (
bool,
@@ -214,7 +214,6 @@ def max(self):
Raises:
TypeError: if this is a non-numeric, unordered, or quantized type.
-
"""
if self.is_quantized or self.base_dtype in (
bool,
@@ -239,6 +238,7 @@ def max(self):
@property
def limits(self, clip_negative=True):
"""Return intensity limits, i.e. (min, max) tuple, of the dtype.
+
Args:
clip_negative : bool, optional
If True, clip the negative range (i.e. return 0 for min intensity)
@@ -247,7 +247,9 @@ def limits(self, clip_negative=True):
min, max : tuple
Lower and upper intensity limits.
"""
- min, max = dtype_range[self.as_numpy_dtype] # pylint: disable=redefined-builtin
+ min, max = dtype_range[
+ self.as_numpy_dtype
+ ] # pylint: disable=redefined-builtin
if clip_negative:
min = 0 # pylint: disable=redefined-builtin
return min, max
@@ -329,9 +331,9 @@ def size(self):
np.uint16: (0, 65535),
np.int8: (-128, 127),
np.int16: (-32768, 32767),
- np.int64: (-2 ** 63, 2 ** 63 - 1),
+ np.int64: (-(2 ** 63), 2 ** 63 - 1),
np.uint64: (0, 2 ** 64 - 1),
- np.int32: (-2 ** 31, 2 ** 31 - 1),
+ np.int32: (-(2 ** 31), 2 ** 31 - 1),
np.uint32: (0, 2 ** 32 - 1),
np.float32: (-1, 1),
np.float64: (-1, 1),
@@ -523,7 +525,9 @@ def size(self):
types_pb2.DT_RESOURCE_REF: "resource_ref",
types_pb2.DT_VARIANT_REF: "variant_ref",
}
-_STRING_TO_TF = {value: _INTERN_TABLE[key] for key, value in _TYPE_TO_STRING.items()}
+_STRING_TO_TF = {
+ value: _INTERN_TABLE[key] for key, value in _TYPE_TO_STRING.items()
+}
# Add non-canonical aliases.
_STRING_TO_TF["half"] = float16
_STRING_TO_TF["half_ref"] = float16_ref
@@ -687,4 +691,6 @@ def as_dtype(type_value):
"Cannot convert {} to a dtype. {}".format(type_value, e)
)
- raise TypeError("Cannot convert value %r to a TensorFlow DType." % type_value)
+ raise TypeError(
+ "Cannot convert value %r to a TensorFlow DType." % type_value
+ )
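One hunk above is worth a note: Black rewrites `-2 ** 63` as `-(2 ** 63)` in `dtype_range`. That is purely presentational, because `**` already binds tighter than unary minus in Python, so the bounds are identical:

    # Exponentiation binds tighter than unary minus, so the added parentheses
    # do not change the int64/int32 lower bounds.
    assert -2 ** 63 == -(2 ** 63) == -9223372036854775808
    assert -2 ** 31 == -(2 ** 31) == -2147483648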
diff --git a/tensorboard/compat/tensorflow_stub/error_codes.py b/tensorboard/compat/tensorflow_stub/error_codes.py
index d599ba24db..6d8a12e746 100644
--- a/tensorboard/compat/tensorflow_stub/error_codes.py
+++ b/tensorboard/compat/tensorflow_stub/error_codes.py
@@ -21,62 +21,62 @@
# The operation was cancelled (typically by the caller).
CANCELLED = 1
-'''
+"""
Unknown error. An example of where this error may be returned is
if a Status value received from another address space belongs to
an error-space that is not known in this address space. Also
errors raised by APIs that do not return enough error information
may be converted to this error.
-'''
+"""
UNKNOWN = 2
-'''
+"""
Client specified an invalid argument. Note that this differs
from FAILED_PRECONDITION. INVALID_ARGUMENT indicates arguments
that are problematic regardless of the state of the system
(e.g., a malformed file name).
-'''
+"""
INVALID_ARGUMENT = 3
-'''
+"""
Deadline expired before operation could complete. For operations
that change the state of the system, this error may be returned
even if the operation has completed successfully. For example, a
successful response from a server could have been delayed long
enough for the deadline to expire.
-'''
+"""
DEADLINE_EXCEEDED = 4
-'''
+"""
Some requested entity (e.g., file or directory) was not found.
For privacy reasons, this code *may* be returned when the client
does not have the access right to the entity.
-'''
+"""
NOT_FOUND = 5
-'''
+"""
Some entity that we attempted to create (e.g., file or directory)
already exists.
-'''
+"""
ALREADY_EXISTS = 6
-'''
+"""
The caller does not have permission to execute the specified
operation. PERMISSION_DENIED must not be used for rejections
caused by exhausting some resource (use RESOURCE_EXHAUSTED
instead for those errors). PERMISSION_DENIED must not be
used if the caller can not be identified (use UNAUTHENTICATED
instead for those errors).
-'''
+"""
PERMISSION_DENIED = 7
-'''
+"""
Some resource has been exhausted, perhaps a per-user quota, or
perhaps the entire file system is out of space.
-'''
+"""
RESOURCE_EXHAUSTED = 8
-'''
+"""
Operation was rejected because the system is not in a state
required for the operation's execution. For example, directory
to be deleted may be non-empty, an rmdir operation is applied to
@@ -95,19 +95,19 @@
REST Get/Update/Delete on a resource and the resource on the
server does not match the condition. E.g., conflicting
read-modify-write on the same resource.
-'''
+"""
FAILED_PRECONDITION = 9
-'''
+"""
The operation was aborted, typically due to a concurrency issue
like sequencer check failures, transaction aborts, etc.
See litmus test above for deciding between FAILED_PRECONDITION,
ABORTED, and UNAVAILABLE.
-'''
+"""
ABORTED = 10
-'''
+"""
Operation tried to iterate past the valid input range. E.g., seeking or
reading past end of file.
@@ -123,39 +123,39 @@
error) when it applies so that callers who are iterating through
a space can easily look for an OUT_OF_RANGE error to detect when
they are done.
-'''
+"""
OUT_OF_RANGE = 11
# Operation is not implemented or not supported/enabled in this service.
UNIMPLEMENTED = 12
-'''
+"""
Internal errors. Means some invariant expected by the underlying
system has been broken. If you see one of these errors,
something is very broken.
-'''
+"""
INTERNAL = 13
-'''
+"""
The service is currently unavailable. This is a most likely a
transient condition and may be corrected by retrying with
a backoff.
See litmus test above for deciding between FAILED_PRECONDITION,
ABORTED, and UNAVAILABLE.
-'''
+"""
UNAVAILABLE = 14
# Unrecoverable data loss or corruption.
DATA_LOSS = 15
-'''
+"""
The request does not have valid authentication credentials for the
operation.
-'''
+"""
UNAUTHENTICATED = 16
-'''
+"""
An extra enum entry to prevent people from writing code that
fails to compile when a new code is added.
@@ -165,5 +165,5 @@
Nobody should rely on the value (currently 20) listed here. It
may change in the future.
-'''
+"""
DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_ = 20
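A note on this file: the triple-quoted blocks between the constants are bare string expressions, not docstrings attached to anything, so switching ''' to """ is quote style only. The module itself is a flat table of integer codes mirroring TensorFlow's canonical error codes, e.g.:

    from tensorboard.compat.tensorflow_stub import error_codes

    assert error_codes.NOT_FOUND == 5
    assert error_codes.PERMISSION_DENIED == 7
    assert error_codes.UNAVAILABLE == 14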
diff --git a/tensorboard/compat/tensorflow_stub/errors.py b/tensorboard/compat/tensorflow_stub/errors.py
index 4c8ba89613..3f9f6cfc8d 100644
--- a/tensorboard/compat/tensorflow_stub/errors.py
+++ b/tensorboard/compat/tensorflow_stub/errors.py
@@ -82,7 +82,8 @@ def node_def(self):
def __str__(self):
if self._op is not None:
output = [
- "%s\n\nCaused by op %r, defined at:\n" % (self.message, self._op.name)
+ "%s\n\nCaused by op %r, defined at:\n"
+ % (self.message, self._op.name)
]
curr_traceback_list = traceback.format_list(self._op.traceback)
output.extend(curr_traceback_list)
@@ -95,7 +96,9 @@ def __str__(self):
% (original_op.name,)
)
prev_traceback_list = curr_traceback_list
- curr_traceback_list = traceback.format_list(original_op.traceback)
+ curr_traceback_list = traceback.format_list(
+ original_op.traceback
+ )
# Attempt to elide large common subsequences of the subsequent
# stack traces.
@@ -104,7 +107,9 @@ def __str__(self):
is_eliding = False
elide_count = 0
last_elided_line = None
- for line, line_in_prev in zip(curr_traceback_list, prev_traceback_list):
+ for line, line_in_prev in zip(
+ curr_traceback_list, prev_traceback_list
+ ):
if line == line_in_prev:
if is_eliding:
elide_count += 1
@@ -199,8 +204,6 @@ def __init__(self, node_def, op, message):
super(CancelledError, self).__init__(node_def, op, message, CANCELLED)
-
-
# @tf_export("errors.UnknownError")
class UnknownError(OpError):
"""Unknown error.
@@ -259,7 +262,8 @@ def __init__(self, node_def, op, message):
# @tf_export("errors.NotFoundError")
class NotFoundError(OpError):
- """Raised when a requested entity (e.g., a file or directory) was not found.
+ """Raised when a requested entity (e.g., a file or directory) was not
+ found.
For example, running the
@{tf.WholeFileReader.read}
@@ -288,7 +292,9 @@ class AlreadyExistsError(OpError):
def __init__(self, node_def, op, message):
"""Creates an `AlreadyExistsError`."""
- super(AlreadyExistsError, self).__init__(node_def, op, message, ALREADY_EXISTS)
+ super(AlreadyExistsError, self).__init__(
+ node_def, op, message, ALREADY_EXISTS
+ )
# @tf_export("errors.PermissionDeniedError")
@@ -345,7 +351,8 @@ def __init__(self, node_def, op, message):
# @tf_export("errors.FailedPreconditionError")
class FailedPreconditionError(OpError):
- """Operation was rejected because the system is not in a state to execute it.
+ """Operation was rejected because the system is not in a state to execute
+ it.
This exception is most commonly raised when running an operation
that reads a @{tf.Variable}
@@ -394,7 +401,9 @@ class OutOfRangeError(OpError):
def __init__(self, node_def, op, message):
"""Creates an `OutOfRangeError`."""
- super(OutOfRangeError, self).__init__(node_def, op, message, OUT_OF_RANGE)
+ super(OutOfRangeError, self).__init__(
+ node_def, op, message, OUT_OF_RANGE
+ )
# @tf_export("errors.UnimplementedError")
@@ -412,7 +421,9 @@ class UnimplementedError(OpError):
def __init__(self, node_def, op, message):
"""Creates an `UnimplementedError`."""
- super(UnimplementedError, self).__init__(node_def, op, message, UNIMPLEMENTED)
+ super(UnimplementedError, self).__init__(
+ node_def, op, message, UNIMPLEMENTED
+ )
# @tf_export("errors.InternalError")
@@ -441,7 +452,9 @@ class UnavailableError(OpError):
def __init__(self, node_def, op, message):
"""Creates an `UnavailableError`."""
- super(UnavailableError, self).__init__(node_def, op, message, UNAVAILABLE)
+ super(UnavailableError, self).__init__(
+ node_def, op, message, UNAVAILABLE
+ )
# @tf_export("errors.DataLossError")
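These stub exceptions keep TensorFlow's `(node_def, op, message)` constructor shape, which is why the filesystem code later in this diff raises them with `None, None, message`. A small usage sketch:

    from tensorboard.compat.tensorflow_stub import errors

    def require_exists(path, exists_fn):
        # node_def and op are None because the error does not come from a
        # graph operation; only the message matters here.
        if not exists_fn(path):
            raise errors.NotFoundError(None, None, "Not Found: " + path)

    try:
        require_exists("missing.txt", lambda p: False)
    except errors.NotFoundError as e:
        print(e)  # Not Found: missing.txt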
diff --git a/tensorboard/compat/tensorflow_stub/flags.py b/tensorboard/compat/tensorflow_stub/flags.py
index dfbcdd5cb5..98f67aba21 100644
--- a/tensorboard/compat/tensorflow_stub/flags.py
+++ b/tensorboard/compat/tensorflow_stub/flags.py
@@ -13,7 +13,10 @@
# limitations under the License.
# ==============================================================================
-"""Import router for absl.flags. See https://github.com/abseil/abseil-py."""
+"""Import router for absl.flags.
+
+See https://github.com/abseil/abseil-py.
+"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -62,9 +65,9 @@ def wrapper(*args, **kwargs):
class _FlagValuesWrapper(object):
"""Wrapper class for absl.flags.FLAGS.
- The difference is that tf.compat.v1.flags.FLAGS implicitly parses flags with sys.argv
- when accessing the FLAGS values before it's explicitly parsed,
- while absl.flags.FLAGS raises an exception.
+ The difference is that tf.compat.v1.flags.FLAGS implicitly parses
+ flags with sys.argv when accessing the FLAGS values before it's
+ explicitly parsed, while absl.flags.FLAGS raises an exception.
"""
def __init__(self, flags_object):
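The docstring above captures the wrapper's whole purpose; conceptually it is a thin proxy along these lines (a hypothetical simplification, not the actual class body):

    import sys

    from absl import flags as absl_flags

    class FlagValuesProxySketch(object):
        """Hypothetical stand-in for _FlagValuesWrapper."""

        def __init__(self, flags_object):
            self._flags = flags_object

        def __getattr__(self, name):
            # absl raises UnparsedFlagAccessError on access before parsing;
            # the wrapper instead parses sys.argv implicitly, matching the
            # tf.compat.v1.flags behavior described above.
            if not self._flags.is_parsed():
                self._flags(sys.argv, known_only=True)
            return getattr(self._flags, name)

    FLAGS = FlagValuesProxySketch(absl_flags.FLAGS)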
diff --git a/tensorboard/compat/tensorflow_stub/io/gfile.py b/tensorboard/compat/tensorflow_stub/io/gfile.py
index d0dd2f588a..e56bb96917 100644
--- a/tensorboard/compat/tensorflow_stub/io/gfile.py
+++ b/tensorboard/compat/tensorflow_stub/io/gfile.py
@@ -14,10 +14,10 @@
# ==============================================================================
"""A limited reimplementation of the TensorFlow FileIO API.
-The TensorFlow version wraps the C++ FileSystem API. Here we provide a pure
-Python implementation, limited to the features required for TensorBoard. This
-allows running TensorBoard without depending on TensorFlow for file operations.
-
+The TensorFlow version wraps the C++ FileSystem API. Here we provide a
+pure Python implementation, limited to the features required for
+TensorBoard. This allows running TensorBoard without depending on
+TensorFlow for file operations.
"""
from __future__ import absolute_import
from __future__ import division
@@ -34,9 +34,11 @@
import sys
import tempfile
import uuid
+
try:
import botocore.exceptions
import boto3
+
S3_ENABLED = True
except ImportError:
S3_ENABLED = False
@@ -62,7 +64,7 @@
def register_filesystem(prefix, filesystem):
- if ':' in prefix:
+ if ":" in prefix:
raise ValueError("Filesystem prefix cannot contain a :")
_REGISTERED_FILESYSTEMS[prefix] = filesystem
@@ -119,7 +121,8 @@ def read(self, filename, binary_mode=False, size=None, continue_from=None):
encoding = None if binary_mode else "utf8"
if not exists(filename):
raise errors.NotFoundError(
- None, None, 'Not Found: ' + compat.as_text(filename))
+ None, None, "Not Found: " + compat.as_text(filename)
+ )
offset = None
if continue_from is not None:
offset = continue_from.get("opaque_offset", None)
@@ -134,8 +137,8 @@ def read(self, filename, binary_mode=False, size=None, continue_from=None):
return (data, continuation_token)
def write(self, filename, file_content, binary_mode=False):
- """Writes string file contents to a file, overwriting any
- existing contents.
+ """Writes string file contents to a file, overwriting any existing
+ contents.
Args:
filename: string, a path
@@ -166,8 +169,7 @@ def glob(self, filename):
return [
# Convert the filenames to string from bytes.
compat.as_str_any(matching_filename)
- for matching_filename in py_glob.glob(
- compat.as_bytes(filename))
+ for matching_filename in py_glob.glob(compat.as_bytes(filename))
]
else:
return [
@@ -175,7 +177,8 @@ def glob(self, filename):
compat.as_str_any(matching_filename)
for single_filename in filename
for matching_filename in py_glob.glob(
- compat.as_bytes(single_filename))
+ compat.as_bytes(single_filename)
+ )
]
def isdir(self, dirname):
@@ -197,7 +200,8 @@ def makedirs(self, path):
os.makedirs(path)
except FileExistsError:
raise errors.AlreadyExistsError(
- None, None, "Directory already exists")
+ None, None, "Directory already exists"
+ )
def stat(self, filename):
"""Returns file statistics for a given path."""
@@ -221,10 +225,10 @@ def bucket_and_path(self, url):
"""Split an S3-prefixed URL into bucket and path."""
url = compat.as_str_any(url)
if url.startswith("s3://"):
- url = url[len("s3://"):]
+ url = url[len("s3://") :]
idx = url.index("/")
bucket = url[:idx]
- path = url[(idx + 1):]
+ path = url[(idx + 1) :]
return bucket, path
def exists(self, filename):
@@ -270,44 +274,44 @@ def read(self, filename, binary_mode=False, size=None, continue_from=None):
if continue_from is not None:
offset = continue_from.get("byte_offset", 0)
- endpoint = ''
+ endpoint = ""
if size is not None:
# TODO(orionr): This endpoint risks splitting a multi-byte
# character or splitting \r and \n in the case of CRLFs,
# producing decoding errors below.
endpoint = offset + size
- if offset != 0 or endpoint != '':
+ if offset != 0 or endpoint != "":
# Asked for a range, so modify the request
- args['Range'] = 'bytes={}-{}'.format(offset, endpoint)
+ args["Range"] = "bytes={}-{}".format(offset, endpoint)
try:
- stream = s3.Object(bucket, path).get(**args)['Body'].read()
+ stream = s3.Object(bucket, path).get(**args)["Body"].read()
except botocore.exceptions.ClientError as exc:
- if exc.response['Error']['Code'] == '416':
+ if exc.response["Error"]["Code"] == "416":
if size is not None:
# Asked for too much, so request just to the end. Do this
# in a second request so we don't check length in all cases.
client = boto3.client("s3")
obj = client.head_object(Bucket=bucket, Key=path)
- content_length = obj['ContentLength']
+ content_length = obj["ContentLength"]
endpoint = min(content_length, offset + size)
if offset == endpoint:
# Asked for no bytes, so just return empty
- stream = b''
+ stream = b""
else:
- args['Range'] = 'bytes={}-{}'.format(offset, endpoint)
- stream = s3.Object(bucket, path).get(**args)['Body'].read()
+ args["Range"] = "bytes={}-{}".format(offset, endpoint)
+ stream = s3.Object(bucket, path).get(**args)["Body"].read()
else:
raise
# `stream` should contain raw bytes here (i.e., there has been neither
# decoding nor newline translation), so the byte offset increases by
# the expected amount.
- continuation_token = {'byte_offset': (offset + len(stream))}
+ continuation_token = {"byte_offset": (offset + len(stream))}
if binary_mode:
return (bytes(stream), continuation_token)
else:
- return (stream.decode('utf-8'), continuation_token)
+ return (stream.decode("utf-8"), continuation_token)
def write(self, filename, file_content, binary_mode=False):
"""Writes string file contents to a file.
@@ -322,7 +326,7 @@ def write(self, filename, file_content, binary_mode=False):
# Always convert to bytes for writing
if binary_mode:
if not isinstance(file_content, six.binary_type):
- raise TypeError('File content type must be bytes')
+ raise TypeError("File content type must be bytes")
else:
file_content = compat.as_bytes(file_content)
client.put_object(Body=file_content, Bucket=bucket, Key=path)
@@ -330,11 +334,12 @@ def write(self, filename, file_content, binary_mode=False):
def glob(self, filename):
"""Returns a list of files that match the given pattern(s)."""
# Only support prefix with * at the end and no ? in the string
- star_i = filename.find('*')
- quest_i = filename.find('?')
+ star_i = filename.find("*")
+ quest_i = filename.find("?")
if quest_i >= 0:
raise NotImplementedError(
- "{} not supported by compat glob".format(filename))
+ "{} not supported by compat glob".format(filename)
+ )
if star_i != len(filename) - 1:
# Just return empty so we can use glob from directory watcher
#
@@ -349,7 +354,7 @@ def glob(self, filename):
keys = []
for r in p.paginate(Bucket=bucket, Prefix=path):
for o in r.get("Contents", []):
- key = o["Key"][len(path):]
+ key = o["Key"][len(path) :]
if key: # Skip the base dir, which would add an empty string
keys.append(filename + key)
return keys
@@ -375,9 +380,10 @@ def listdir(self, dirname):
keys = []
for r in p.paginate(Bucket=bucket, Prefix=path, Delimiter="/"):
keys.extend(
- o["Prefix"][len(path):-1] for o in r.get("CommonPrefixes", []))
+ o["Prefix"][len(path) : -1] for o in r.get("CommonPrefixes", [])
+ )
for o in r.get("Contents", []):
- key = o["Key"][len(path):]
+ key = o["Key"][len(path) :]
if key: # Skip the base dir, which would add an empty string
keys.append(key)
return keys
@@ -386,12 +392,13 @@ def makedirs(self, dirname):
"""Creates a directory and all parent/intermediate directories."""
if self.exists(dirname):
raise errors.AlreadyExistsError(
- None, None, "Directory already exists")
+ None, None, "Directory already exists"
+ )
client = boto3.client("s3")
bucket, path = self.bucket_and_path(dirname)
if not path.endswith("/"):
path += "/" # This will make sure we don't override a file
- client.put_object(Body='', Bucket=bucket, Key=path)
+ client.put_object(Body="", Bucket=bucket, Key=path)
def stat(self, filename):
"""Returns file statistics for a given path."""
@@ -401,9 +408,9 @@ def stat(self, filename):
bucket, path = self.bucket_and_path(filename)
try:
obj = client.head_object(Bucket=bucket, Key=path)
- return StatData(obj['ContentLength'])
+ return StatData(obj["ContentLength"])
except botocore.exceptions.ClientError as exc:
- if exc.response['Error']['Code'] == '404':
+ if exc.response["Error"]["Code"] == "404":
raise errors.NotFoundError(None, None, "Could not find file")
else:
raise
@@ -418,12 +425,13 @@ class GFile(object):
# Only methods needed for TensorBoard are implemented.
def __init__(self, filename, mode):
- if mode not in ('r', 'rb', 'br', 'w', 'wb', 'bw'):
+ if mode not in ("r", "rb", "br", "w", "wb", "bw"):
raise NotImplementedError(
- "mode {} not supported by compat GFile".format(mode))
+ "mode {} not supported by compat GFile".format(mode)
+ )
self.filename = compat.as_bytes(filename)
self.fs = get_filesystem(self.filename)
- self.fs_supports_append = hasattr(self.fs, 'append')
+ self.fs_supports_append = hasattr(self.fs, "append")
self.buff = None
# The buffer offset and the buffer chunk size are measured in the
# natural units of the underlying stream, i.e. bytes for binary mode,
@@ -433,8 +441,8 @@ def __init__(self, filename, mode):
self.continuation_token = None
self.write_temp = None
self.write_started = False
- self.binary_mode = 'b' in mode
- self.write_mode = 'w' in mode
+ self.binary_mode = "b" in mode
+ self.write_mode = "w" in mode
self.closed = False
def __enter__(self):
@@ -453,7 +461,7 @@ def _read_buffer_to_offset(self, new_buff_offset):
old_buff_offset = self.buff_offset
read_size = min(len(self.buff), new_buff_offset) - old_buff_offset
self.buff_offset += read_size
- return self.buff[old_buff_offset:old_buff_offset + read_size]
+ return self.buff[old_buff_offset : old_buff_offset + read_size]
def read(self, n=None):
"""Reads contents of file to a string.
@@ -467,7 +475,8 @@ def read(self, n=None):
"""
if self.write_mode:
raise errors.PermissionDeniedError(
- None, None, "File not opened in read mode")
+ None, None, "File not opened in read mode"
+ )
result = None
if self.buff and len(self.buff) > self.buff_offset:
@@ -485,7 +494,8 @@ def read(self, n=None):
# read from filesystem
read_size = max(self.buff_chunk_size, n) if n is not None else None
(self.buff, self.continuation_token) = self.fs.read(
- self.filename, self.binary_mode, read_size, self.continuation_token)
+ self.filename, self.binary_mode, read_size, self.continuation_token
+ )
self.buff_offset = 0
# add from filesystem
@@ -499,18 +509,20 @@ def read(self, n=None):
return result
def write(self, file_content):
- """Writes string file contents to file, clearing contents of the
- file on first write and then appending on subsequent calls.
+ """Writes string file contents to file, clearing contents of the file
+ on first write and then appending on subsequent calls.
Args:
file_content: string, the contents
"""
if not self.write_mode:
raise errors.PermissionDeniedError(
- None, None, "File not opened in write mode")
+ None, None, "File not opened in write mode"
+ )
if self.closed:
raise errors.FailedPreconditionError(
- None, None, "File already closed")
+ None, None, "File already closed"
+ )
if self.fs_supports_append:
if not self.write_started:
@@ -535,12 +547,12 @@ def __next__(self):
if not self.buff:
# read one unit into the buffer
line = self.read(1)
- if line and (line[-1] == '\n' or not self.buff):
+ if line and (line[-1] == "\n" or not self.buff):
return line
if not self.buff:
raise StopIteration()
else:
- index = self.buff.find('\n', self.buff_offset)
+ index = self.buff.find("\n", self.buff_offset)
if index != -1:
# include line until now plus newline
chunk = self.read(index + 1 - self.buff_offset)
@@ -550,7 +562,7 @@ def __next__(self):
# read one unit past end of buffer
chunk = self.read(len(self.buff) + 1 - self.buff_offset)
line = line + chunk if line else chunk
- if line and (line[-1] == '\n' or not self.buff):
+ if line and (line[-1] == "\n" or not self.buff):
return line
if not self.buff:
raise StopIteration()
@@ -561,7 +573,8 @@ def next(self):
def flush(self):
if self.closed:
raise errors.FailedPreconditionError(
- None, None, "File already closed")
+ None, None, "File already closed"
+ )
if not self.fs_supports_append:
if self.write_temp is not None:
@@ -576,7 +589,7 @@ def flush(self):
def close(self):
self.flush()
- if self.write_temp is not None:
+ if self.write_temp is not None:
self.write_temp.close()
self.write_temp = None
self.write_started = False
@@ -723,39 +736,40 @@ def stat(filename):
"""
return get_filesystem(filename).stat(filename)
+
# Used for tests only
def _write_string_to_file(filename, file_content):
- """Writes a string to a given file.
+ """Writes a string to a given file.
- Args:
- filename: string, path to a file
- file_content: string, contents that need to be written to the file
+ Args:
+ filename: string, path to a file
+ file_content: string, contents that need to be written to the file
- Raises:
- errors.OpError: If there are errors during the operation.
- """
- with GFile(filename, mode="w") as f:
- f.write(compat.as_text(file_content))
+ Raises:
+ errors.OpError: If there are errors during the operation.
+ """
+ with GFile(filename, mode="w") as f:
+ f.write(compat.as_text(file_content))
# Used for tests only
def _read_file_to_string(filename, binary_mode=False):
- """Reads the entire contents of a file to a string.
-
- Args:
- filename: string, path to a file
- binary_mode: whether to open the file in binary mode or not. This changes
- the type of the object returned.
-
- Returns:
- contents of the file as a string or bytes.
-
- Raises:
- errors.OpError: Raises variety of errors that are subtypes e.g.
- `NotFoundError` etc.
- """
- if binary_mode:
- f = GFile(filename, mode="rb")
- else:
- f = GFile(filename, mode="r")
- return f.read()
+ """Reads the entire contents of a file to a string.
+
+ Args:
+ filename: string, path to a file
+ binary_mode: whether to open the file in binary mode or not. This changes
+ the type of the object returned.
+
+ Returns:
+ contents of the file as a string or bytes.
+
+ Raises:
+ errors.OpError: Raises variety of errors that are subtypes e.g.
+ `NotFoundError` etc.
+ """
+ if binary_mode:
+ f = GFile(filename, mode="rb")
+ else:
+ f = GFile(filename, mode="r")
+ return f.read()
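Taken together, this module exposes a small gfile-style surface (`exists`, `glob`, `isdir`, `listdir`, `makedirs`, `stat`, and the buffered `GFile`) and routes each path to a registered filesystem by prefix: plain paths stay on the local implementation, while `s3://` paths use the S3 one when boto3/botocore are importable. A brief usage sketch on a local path:

    from tensorboard.compat.tensorflow_stub.io import gfile

    path = "/tmp/gfile_example.txt"

    # Text mode accepts unicode and overwrites on the first write() call.
    with gfile.GFile(path, "w") as f:
        f.write(u"hello\nworld\n")

    with gfile.GFile(path, "r") as f:
        assert f.read() == u"hello\nworld\n"

    assert gfile.exists(path)
    assert gfile.stat(path).length == 12  # byte length via StatData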
diff --git a/tensorboard/compat/tensorflow_stub/io/gfile_s3_test.py b/tensorboard/compat/tensorflow_stub/io/gfile_s3_test.py
index 85c3c5a71b..67a9dbf18c 100644
--- a/tensorboard/compat/tensorflow_stub/io/gfile_s3_test.py
+++ b/tensorboard/compat/tensorflow_stub/io/gfile_s3_test.py
@@ -33,11 +33,10 @@
class GFileTest(unittest.TestCase):
-
@mock_s3
def testExists(self):
temp_dir = self._CreateDeepS3Structure()
- ckpt_path = self._PathJoin(temp_dir, 'model.ckpt')
+ ckpt_path = self._PathJoin(temp_dir, "model.ckpt")
self.assertTrue(gfile.exists(temp_dir))
self.assertTrue(gfile.exists(ckpt_path))
@@ -47,19 +46,19 @@ def testGlob(self):
# S3 glob includes subdirectory content, which standard
# filesystem does not. However, this is good for perf.
expected = [
- 'a.tfevents.1',
- 'bar/b.tfevents.1',
- 'bar/baz/c.tfevents.1',
- 'bar/baz/d.tfevents.1',
- 'bar/quux/some_flume_output.txt',
- 'bar/quux/some_more_flume_output.txt',
- 'bar/red_herring.txt',
- 'model.ckpt',
- 'quuz/e.tfevents.1',
- 'quuz/garply/corge/g.tfevents.1',
- 'quuz/garply/f.tfevents.1',
- 'quuz/garply/grault/h.tfevents.1',
- 'waldo/fred/i.tfevents.1'
+ "a.tfevents.1",
+ "bar/b.tfevents.1",
+ "bar/baz/c.tfevents.1",
+ "bar/baz/d.tfevents.1",
+ "bar/quux/some_flume_output.txt",
+ "bar/quux/some_more_flume_output.txt",
+ "bar/red_herring.txt",
+ "model.ckpt",
+ "quuz/e.tfevents.1",
+ "quuz/garply/corge/g.tfevents.1",
+ "quuz/garply/f.tfevents.1",
+ "quuz/garply/grault/h.tfevents.1",
+ "waldo/fred/i.tfevents.1",
]
expected_listing = [self._PathJoin(temp_dir, f) for f in expected]
gotten_listing = gfile.glob(self._PathJoin(temp_dir, "*"))
@@ -67,8 +66,9 @@ def testGlob(self):
self,
expected_listing,
gotten_listing,
- 'Files must match. Expected %r. Got %r.' % (
- expected_listing, gotten_listing))
+ "Files must match. Expected %r. Got %r."
+ % (expected_listing, gotten_listing),
+ )
@mock_s3
def testIsdir(self):
@@ -82,11 +82,11 @@ def testListdir(self):
expected_files = [
# Empty directory not returned
# 'foo',
- 'bar',
- 'quuz',
- 'a.tfevents.1',
- 'model.ckpt',
- 'waldo',
+ "bar",
+ "quuz",
+ "a.tfevents.1",
+ "model.ckpt",
+ "waldo",
]
gotten_files = gfile.listdir(temp_dir)
six.assertCountEqual(self, expected_files, gotten_files)
@@ -94,14 +94,14 @@ def testListdir(self):
@mock_s3
def testMakeDirs(self):
temp_dir = self._CreateDeepS3Structure()
- new_dir = self._PathJoin(temp_dir, 'newdir', 'subdir', 'subsubdir')
+ new_dir = self._PathJoin(temp_dir, "newdir", "subdir", "subsubdir")
gfile.makedirs(new_dir)
self.assertTrue(gfile.isdir(new_dir))
@mock_s3
def testMakeDirsAlreadyExists(self):
temp_dir = self._CreateDeepS3Structure()
- new_dir = self._PathJoin(temp_dir, 'bar', 'baz')
+ new_dir = self._PathJoin(temp_dir, "bar", "baz")
with self.assertRaises(errors.AlreadyExistsError):
gfile.makedirs(new_dir)
@@ -110,40 +110,21 @@ def testWalk(self):
temp_dir = self._CreateDeepS3Structure()
self._CreateDeepS3Structure(temp_dir)
expected = [
- ['', [
- 'a.tfevents.1',
- 'model.ckpt',
- ]],
+ ["", ["a.tfevents.1", "model.ckpt",]],
# Empty directory not returned
# ['foo', []],
- ['bar', [
- 'b.tfevents.1',
- 'red_herring.txt',
- ]],
- ['bar/baz', [
- 'c.tfevents.1',
- 'd.tfevents.1',
- ]],
- ['bar/quux', [
- 'some_flume_output.txt',
- 'some_more_flume_output.txt',
- ]],
- ['quuz', [
- 'e.tfevents.1',
- ]],
- ['quuz/garply', [
- 'f.tfevents.1',
- ]],
- ['quuz/garply/corge', [
- 'g.tfevents.1',
- ]],
- ['quuz/garply/grault', [
- 'h.tfevents.1',
- ]],
- ['waldo', []],
- ['waldo/fred', [
- 'i.tfevents.1',
- ]],
+ ["bar", ["b.tfevents.1", "red_herring.txt",]],
+ ["bar/baz", ["c.tfevents.1", "d.tfevents.1",]],
+ [
+ "bar/quux",
+ ["some_flume_output.txt", "some_more_flume_output.txt",],
+ ],
+ ["quuz", ["e.tfevents.1",]],
+ ["quuz/garply", ["f.tfevents.1",]],
+ ["quuz/garply/corge", ["g.tfevents.1",]],
+ ["quuz/garply/grault", ["h.tfevents.1",]],
+ ["waldo", []],
+ ["waldo/fred", ["i.tfevents.1",]],
]
for pair in expected:
# If this is not the top-level directory, prepend the high-level
@@ -154,21 +135,21 @@ def testWalk(self):
@mock_s3
def testStat(self):
- ckpt_content = 'asdfasdfasdffoobarbuzz'
+ ckpt_content = "asdfasdfasdffoobarbuzz"
temp_dir = self._CreateDeepS3Structure(ckpt_content=ckpt_content)
- ckpt_path = self._PathJoin(temp_dir, 'model.ckpt')
+ ckpt_path = self._PathJoin(temp_dir, "model.ckpt")
ckpt_stat = gfile.stat(ckpt_path)
self.assertEqual(ckpt_stat.length, len(ckpt_content))
- bad_ckpt_path = self._PathJoin(temp_dir, 'bad_model.ckpt')
+ bad_ckpt_path = self._PathJoin(temp_dir, "bad_model.ckpt")
with self.assertRaises(errors.NotFoundError):
gfile.stat(bad_ckpt_path)
@mock_s3
def testRead(self):
- ckpt_content = 'asdfasdfasdffoobarbuzz'
+ ckpt_content = "asdfasdfasdffoobarbuzz"
temp_dir = self._CreateDeepS3Structure(ckpt_content=ckpt_content)
- ckpt_path = self._PathJoin(temp_dir, 'model.ckpt')
- with gfile.GFile(ckpt_path, 'r') as f:
+ ckpt_path = self._PathJoin(temp_dir, "model.ckpt")
+ with gfile.GFile(ckpt_path, "r") as f:
f.buff_chunk_size = 4 # Test buffering by reducing chunk size
ckpt_read = f.read()
self.assertEqual(ckpt_content, ckpt_read)
@@ -176,120 +157,125 @@ def testRead(self):
@mock_s3
def testReadLines(self):
ckpt_lines = (
- [u'\n'] + [u'line {}\n'.format(i) for i in range(10)] + [u' ']
+ [u"\n"] + [u"line {}\n".format(i) for i in range(10)] + [u" "]
)
- ckpt_content = u''.join(ckpt_lines)
+ ckpt_content = u"".join(ckpt_lines)
temp_dir = self._CreateDeepS3Structure(ckpt_content=ckpt_content)
- ckpt_path = self._PathJoin(temp_dir, 'model.ckpt')
- with gfile.GFile(ckpt_path, 'r') as f:
+ ckpt_path = self._PathJoin(temp_dir, "model.ckpt")
+ with gfile.GFile(ckpt_path, "r") as f:
f.buff_chunk_size = 4 # Test buffering by reducing chunk size
ckpt_read_lines = list(f)
self.assertEqual(ckpt_lines, ckpt_read_lines)
@mock_s3
def testReadWithOffset(self):
- ckpt_content = 'asdfasdfasdffoobarbuzz'
- ckpt_b_content = b'asdfasdfasdffoobarbuzz'
+ ckpt_content = "asdfasdfasdffoobarbuzz"
+ ckpt_b_content = b"asdfasdfasdffoobarbuzz"
temp_dir = self._CreateDeepS3Structure(ckpt_content=ckpt_content)
- ckpt_path = self._PathJoin(temp_dir, 'model.ckpt')
- with gfile.GFile(ckpt_path, 'r') as f:
+ ckpt_path = self._PathJoin(temp_dir, "model.ckpt")
+ with gfile.GFile(ckpt_path, "r") as f:
f.buff_chunk_size = 4 # Test buffering by reducing chunk size
ckpt_read = f.read(12)
- self.assertEqual('asdfasdfasdf', ckpt_read)
+ self.assertEqual("asdfasdfasdf", ckpt_read)
ckpt_read = f.read(6)
- self.assertEqual('foobar', ckpt_read)
+ self.assertEqual("foobar", ckpt_read)
ckpt_read = f.read(1)
- self.assertEqual('b', ckpt_read)
+ self.assertEqual("b", ckpt_read)
ckpt_read = f.read()
- self.assertEqual('uzz', ckpt_read)
+ self.assertEqual("uzz", ckpt_read)
ckpt_read = f.read(1000)
- self.assertEqual('', ckpt_read)
- with gfile.GFile(ckpt_path, 'rb') as f:
+ self.assertEqual("", ckpt_read)
+ with gfile.GFile(ckpt_path, "rb") as f:
ckpt_read = f.read()
self.assertEqual(ckpt_b_content, ckpt_read)
@mock_s3
def testWrite(self):
temp_dir = self._CreateDeepS3Structure()
- ckpt_path = os.path.join(temp_dir, 'model2.ckpt')
- ckpt_content = u'asdfasdfasdffoobarbuzz'
- with gfile.GFile(ckpt_path, 'w') as f:
+ ckpt_path = os.path.join(temp_dir, "model2.ckpt")
+ ckpt_content = u"asdfasdfasdffoobarbuzz"
+ with gfile.GFile(ckpt_path, "w") as f:
f.write(ckpt_content)
- with gfile.GFile(ckpt_path, 'r') as f:
+ with gfile.GFile(ckpt_path, "r") as f:
ckpt_read = f.read()
self.assertEqual(ckpt_content, ckpt_read)
@mock_s3
def testOverwrite(self):
temp_dir = self._CreateDeepS3Structure()
- ckpt_path = os.path.join(temp_dir, 'model2.ckpt')
- ckpt_content = u'asdfasdfasdffoobarbuzz'
- with gfile.GFile(ckpt_path, 'w') as f:
- f.write(u'original')
- with gfile.GFile(ckpt_path, 'w') as f:
+ ckpt_path = os.path.join(temp_dir, "model2.ckpt")
+ ckpt_content = u"asdfasdfasdffoobarbuzz"
+ with gfile.GFile(ckpt_path, "w") as f:
+ f.write(u"original")
+ with gfile.GFile(ckpt_path, "w") as f:
f.write(ckpt_content)
- with gfile.GFile(ckpt_path, 'r') as f:
+ with gfile.GFile(ckpt_path, "r") as f:
ckpt_read = f.read()
self.assertEqual(ckpt_content, ckpt_read)
@mock_s3
def testWriteMultiple(self):
temp_dir = self._CreateDeepS3Structure()
- ckpt_path = os.path.join(temp_dir, 'model2.ckpt')
- ckpt_content = u'asdfasdfasdffoobarbuzz' * 5
- with gfile.GFile(ckpt_path, 'w') as f:
+ ckpt_path = os.path.join(temp_dir, "model2.ckpt")
+ ckpt_content = u"asdfasdfasdffoobarbuzz" * 5
+ with gfile.GFile(ckpt_path, "w") as f:
for i in range(0, len(ckpt_content), 3):
- f.write(ckpt_content[i:i + 3])
+ f.write(ckpt_content[i : i + 3])
# Test periodic flushing of the file
if i % 9 == 0:
f.flush()
- with gfile.GFile(ckpt_path, 'r') as f:
+ with gfile.GFile(ckpt_path, "r") as f:
ckpt_read = f.read()
self.assertEqual(ckpt_content, ckpt_read)
@mock_s3
def testWriteEmpty(self):
temp_dir = self._CreateDeepS3Structure()
- ckpt_path = os.path.join(temp_dir, 'model2.ckpt')
- ckpt_content = u''
- with gfile.GFile(ckpt_path, 'w') as f:
+ ckpt_path = os.path.join(temp_dir, "model2.ckpt")
+ ckpt_content = u""
+ with gfile.GFile(ckpt_path, "w") as f:
f.write(ckpt_content)
- with gfile.GFile(ckpt_path, 'r') as f:
+ with gfile.GFile(ckpt_path, "r") as f:
ckpt_read = f.read()
self.assertEqual(ckpt_content, ckpt_read)
@mock_s3
def testWriteBinary(self):
temp_dir = self._CreateDeepS3Structure()
- ckpt_path = os.path.join(temp_dir, 'model.ckpt')
- ckpt_content = b'asdfasdfasdffoobarbuzz'
- with gfile.GFile(ckpt_path, 'wb') as f:
+ ckpt_path = os.path.join(temp_dir, "model.ckpt")
+ ckpt_content = b"asdfasdfasdffoobarbuzz"
+ with gfile.GFile(ckpt_path, "wb") as f:
f.write(ckpt_content)
- with gfile.GFile(ckpt_path, 'rb') as f:
+ with gfile.GFile(ckpt_path, "rb") as f:
ckpt_read = f.read()
self.assertEqual(ckpt_content, ckpt_read)
@mock_s3
def testWriteMultipleBinary(self):
temp_dir = self._CreateDeepS3Structure()
- ckpt_path = os.path.join(temp_dir, 'model2.ckpt')
- ckpt_content = b'asdfasdfasdffoobarbuzz' * 5
- with gfile.GFile(ckpt_path, 'wb') as f:
+ ckpt_path = os.path.join(temp_dir, "model2.ckpt")
+ ckpt_content = b"asdfasdfasdffoobarbuzz" * 5
+ with gfile.GFile(ckpt_path, "wb") as f:
for i in range(0, len(ckpt_content), 3):
- f.write(ckpt_content[i:i + 3])
+ f.write(ckpt_content[i : i + 3])
# Test periodic flushing of the file
if i % 9 == 0:
f.flush()
- with gfile.GFile(ckpt_path, 'rb') as f:
+ with gfile.GFile(ckpt_path, "rb") as f:
ckpt_read = f.read()
self.assertEqual(ckpt_content, ckpt_read)
def _PathJoin(self, *args):
- """Join directory and path with slash and not local separator"""
+ """Join directory and path with slash and not local separator."""
return "/".join(args)
- def _CreateDeepS3Structure(self, top_directory='top_dir', ckpt_content='',
- region_name='us-east-1', bucket_name='test'):
+ def _CreateDeepS3Structure(
+ self,
+ top_directory="top_dir",
+ ckpt_content="",
+ region_name="us-east-1",
+ bucket_name="test",
+ ):
"""Creates a reasonable deep structure of S3 subdirectories with files.
Args:
@@ -302,62 +288,62 @@ def _CreateDeepS3Structure(self, top_directory='top_dir', ckpt_content='',
Returns:
S3 URL of the top directory in the form 's3://bucket/path'
"""
- s3_top_url = 's3://{}/{}'.format(bucket_name, top_directory)
+ s3_top_url = "s3://{}/{}".format(bucket_name, top_directory)
# Add a few subdirectories.
directory_names = (
# An empty directory.
- 'foo',
+ "foo",
# A directory with an events file (and a text file).
- 'bar',
+ "bar",
# A deeper directory with events files.
- 'bar/baz',
+ "bar/baz",
# A non-empty subdir that lacks event files (should be ignored).
- 'bar/quux',
+ "bar/quux",
# This 3-level deep set of subdirectories tests logic that replaces
# the full glob string with an absolute path prefix if there is
# only 1 subdirectory in the final mapping.
- 'quuz/garply',
- 'quuz/garply/corge',
- 'quuz/garply/grault',
+ "quuz/garply",
+ "quuz/garply/corge",
+ "quuz/garply/grault",
# A directory that lacks events files, but contains a subdirectory
# with events files (first level should be ignored, second level
# should be included).
- 'waldo',
- 'waldo/fred',
+ "waldo",
+ "waldo/fred",
)
- client = boto3.client('s3', region_name=region_name)
+ client = boto3.client("s3", region_name=region_name)
client.create_bucket(Bucket=bucket_name)
- client.put_object(Body='', Bucket=bucket_name, Key=top_directory)
+ client.put_object(Body="", Bucket=bucket_name, Key=top_directory)
for directory_name in directory_names:
# Add an end slash
- path = top_directory + '/' + directory_name + '/'
+ path = top_directory + "/" + directory_name + "/"
# Create an empty object so the location exists
- client.put_object(Body='', Bucket=bucket_name, Key=directory_name)
+ client.put_object(Body="", Bucket=bucket_name, Key=directory_name)
# Add a few files to the directory.
file_names = (
- 'a.tfevents.1',
- 'model.ckpt',
- 'bar/b.tfevents.1',
- 'bar/red_herring.txt',
- 'bar/baz/c.tfevents.1',
- 'bar/baz/d.tfevents.1',
- 'bar/quux/some_flume_output.txt',
- 'bar/quux/some_more_flume_output.txt',
- 'quuz/e.tfevents.1',
- 'quuz/garply/f.tfevents.1',
- 'quuz/garply/corge/g.tfevents.1',
- 'quuz/garply/grault/h.tfevents.1',
- 'waldo/fred/i.tfevents.1',
+ "a.tfevents.1",
+ "model.ckpt",
+ "bar/b.tfevents.1",
+ "bar/red_herring.txt",
+ "bar/baz/c.tfevents.1",
+ "bar/baz/d.tfevents.1",
+ "bar/quux/some_flume_output.txt",
+ "bar/quux/some_more_flume_output.txt",
+ "quuz/e.tfevents.1",
+ "quuz/garply/f.tfevents.1",
+ "quuz/garply/corge/g.tfevents.1",
+ "quuz/garply/grault/h.tfevents.1",
+ "waldo/fred/i.tfevents.1",
)
for file_name in file_names:
# Build the file's full key under the top directory
- path = top_directory + '/' + file_name
- if file_name == 'model.ckpt':
+ path = top_directory + "/" + file_name
+ if file_name == "model.ckpt":
content = ckpt_content
else:
- content = ''
+ content = ""
client.put_object(Body=content, Bucket=bucket_name, Key=path)
return s3_top_url
@@ -369,14 +355,18 @@ def _CompareFilesPerSubdirectory(self, expected, gotten):
gotten: The gotten iterable of 2-tuples.
"""
expected_directory_to_files = {
- result[0]: list(result[1]) for result in expected}
+ result[0]: list(result[1]) for result in expected
+ }
gotten_directory_to_files = {
# Note we ignore subdirectories and just compare files
- result[0]: list(result[2]) for result in gotten}
+ result[0]: list(result[2])
+ for result in gotten
+ }
six.assertCountEqual(
self,
expected_directory_to_files.keys(),
- gotten_directory_to_files.keys())
+ gotten_directory_to_files.keys(),
+ )
for subdir, expected_listing in expected_directory_to_files.items():
gotten_listing = gotten_directory_to_files[subdir]
@@ -384,9 +374,10 @@ def _CompareFilesPerSubdirectory(self, expected, gotten):
self,
expected_listing,
gotten_listing,
- 'Files for subdir %r must match. Expected %r. Got %r.' % (
- subdir, expected_listing, gotten_listing))
+ "Files for subdir %r must match. Expected %r. Got %r."
+ % (subdir, expected_listing, gotten_listing),
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
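
The S3 tests above rely on moto's `@mock_s3` decorator so that every `boto3` call hits an in-memory fake rather than real AWS; `_CreateDeepS3Structure` just seeds that fake bucket. A minimal sketch of the same pattern, reusing the bucket and key names from the tests (the `demo_roundtrip` helper is illustrative, not part of the patch):

# Minimal sketch of the moto + boto3 pattern used by the S3 tests above.
import boto3
from moto import mock_s3  # moto's S3 mock decorator, as used by the tests


@mock_s3
def demo_roundtrip():
    # Inside the decorator, boto3 talks to an in-memory S3 implementation.
    client = boto3.client("s3", region_name="us-east-1")
    client.create_bucket(Bucket="test")
    client.put_object(Body=b"hello", Bucket="test", Key="top_dir/model.ckpt")
    obj = client.get_object(Bucket="test", Key="top_dir/model.ckpt")
    return obj["Body"].read()


assert demo_roundtrip() == b"hello"
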
diff --git a/tensorboard/compat/tensorflow_stub/io/gfile_test.py b/tensorboard/compat/tensorflow_stub/io/gfile_test.py
index 3151a9f6f2..7a3ae7e8de 100644
--- a/tensorboard/compat/tensorflow_stub/io/gfile_test.py
+++ b/tensorboard/compat/tensorflow_stub/io/gfile_test.py
@@ -32,7 +32,7 @@ class GFileTest(tb_test.TestCase):
def testExists(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
- ckpt_path = os.path.join(temp_dir, 'model.ckpt')
+ ckpt_path = os.path.join(temp_dir, "model.ckpt")
self.assertTrue(gfile.exists(temp_dir))
self.assertTrue(gfile.exists(ckpt_path))
@@ -40,12 +40,12 @@ def testGlob(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
expected = [
- 'foo',
- 'bar',
- 'a.tfevents.1',
- 'model.ckpt',
- 'quuz',
- 'waldo',
+ "foo",
+ "bar",
+ "a.tfevents.1",
+ "model.ckpt",
+ "quuz",
+ "waldo",
]
expected_listing = [os.path.join(temp_dir, f) for f in expected]
gotten_listing = gfile.glob(os.path.join(temp_dir, "*"))
@@ -53,8 +53,9 @@ def testGlob(self):
self,
expected_listing,
gotten_listing,
- 'Files must match. Expected %r. Got %r.' % (
- expected_listing, gotten_listing))
+ "Files must match. Expected %r. Got %r."
+ % (expected_listing, gotten_listing),
+ )
def testIsdir(self):
temp_dir = self.get_temp_dir()
@@ -64,29 +65,26 @@ def testListdir(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
expected_files = (
- 'foo',
- 'bar',
- 'quuz',
- 'a.tfevents.1',
- 'model.ckpt',
- 'waldo',
+ "foo",
+ "bar",
+ "quuz",
+ "a.tfevents.1",
+ "model.ckpt",
+ "waldo",
)
- six.assertCountEqual(
- self,
- expected_files,
- gfile.listdir(temp_dir))
+ six.assertCountEqual(self, expected_files, gfile.listdir(temp_dir))
def testMakeDirs(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
- new_dir = os.path.join(temp_dir, 'newdir', 'subdir', 'subsubdir')
+ new_dir = os.path.join(temp_dir, "newdir", "subdir", "subsubdir")
gfile.makedirs(new_dir)
self.assertTrue(gfile.isdir(new_dir))
def testMakeDirsAlreadyExists(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
- new_dir = os.path.join(temp_dir, 'bar', 'baz')
+ new_dir = os.path.join(temp_dir, "bar", "baz")
with self.assertRaises(errors.AlreadyExistsError):
gfile.makedirs(new_dir)
@@ -94,69 +92,53 @@ def testWalk(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
expected = [
- ['', [
- 'a.tfevents.1',
- 'model.ckpt',
- ]],
- ['foo', []],
- ['bar', [
- 'b.tfevents.1',
- 'red_herring.txt',
- ]],
- ['bar/baz', [
- 'c.tfevents.1',
- 'd.tfevents.1',
- ]],
- ['bar/quux', [
- 'some_flume_output.txt',
- 'some_more_flume_output.txt',
- ]],
- ['quuz', [
- 'e.tfevents.1',
- ]],
- ['quuz/garply', [
- 'f.tfevents.1',
- ]],
- ['quuz/garply/corge', [
- 'g.tfevents.1',
- ]],
- ['quuz/garply/grault', [
- 'h.tfevents.1',
- ]],
- ['waldo', []],
- ['waldo/fred', [
- 'i.tfevents.1',
- ]],
+ ["", ["a.tfevents.1", "model.ckpt",]],
+ ["foo", []],
+ ["bar", ["b.tfevents.1", "red_herring.txt",]],
+ ["bar/baz", ["c.tfevents.1", "d.tfevents.1",]],
+ [
+ "bar/quux",
+ ["some_flume_output.txt", "some_more_flume_output.txt",],
+ ],
+ ["quuz", ["e.tfevents.1",]],
+ ["quuz/garply", ["f.tfevents.1",]],
+ ["quuz/garply/corge", ["g.tfevents.1",]],
+ ["quuz/garply/grault", ["h.tfevents.1",]],
+ ["waldo", []],
+ ["waldo/fred", ["i.tfevents.1",]],
]
for pair in expected:
# If this is not the top-level directory, prepend the high-level
# directory.
- pair[0] = os.path.join(temp_dir,
- pair[0].replace('/', os.path.sep)) if pair[0] else temp_dir
+ pair[0] = (
+ os.path.join(temp_dir, pair[0].replace("/", os.path.sep))
+ if pair[0]
+ else temp_dir
+ )
gotten = gfile.walk(temp_dir)
self._CompareFilesPerSubdirectory(expected, gotten)
def testStat(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
- ckpt_path = os.path.join(temp_dir, 'model.ckpt')
- ckpt_content = 'asdfasdfasdffoobarbuzz'
- with open(ckpt_path, 'w') as f:
+ ckpt_path = os.path.join(temp_dir, "model.ckpt")
+ ckpt_content = "asdfasdfasdffoobarbuzz"
+ with open(ckpt_path, "w") as f:
f.write(ckpt_content)
ckpt_stat = gfile.stat(ckpt_path)
self.assertEqual(ckpt_stat.length, len(ckpt_content))
- bad_ckpt_path = os.path.join(temp_dir, 'bad_model.ckpt')
+ bad_ckpt_path = os.path.join(temp_dir, "bad_model.ckpt")
with self.assertRaises(errors.NotFoundError):
gfile.stat(bad_ckpt_path)
def testRead(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
- ckpt_path = os.path.join(temp_dir, 'model.ckpt')
- ckpt_content = 'asdfasdfasdffoobarbuzz'
- with open(ckpt_path, 'w') as f:
+ ckpt_path = os.path.join(temp_dir, "model.ckpt")
+ ckpt_content = "asdfasdfasdffoobarbuzz"
+ with open(ckpt_path, "w") as f:
f.write(ckpt_content)
- with gfile.GFile(ckpt_path, 'r') as f:
+ with gfile.GFile(ckpt_path, "r") as f:
f.buff_chunk_size = 4 # Test buffering by reducing chunk size
ckpt_read = f.read()
self.assertEqual(ckpt_content, ckpt_read)
@@ -164,24 +146,24 @@ def testRead(self):
def testReadLines(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
- ckpt_path = os.path.join(temp_dir, 'model.ckpt')
+ ckpt_path = os.path.join(temp_dir, "model.ckpt")
# Note \r\n, which io.read will automatically replace with \n.
# That substitution desynchronizes character offsets (omitting \r) from
# the underlying byte offsets (counting \r). Multibyte characters would
# similarly cause desynchronization.
raw_ckpt_lines = (
- [u'\r\n'] + [u'line {}\r\n'.format(i) for i in range(10)] + [u' ']
+ [u"\r\n"] + [u"line {}\r\n".format(i) for i in range(10)] + [u" "]
)
- expected_ckpt_lines = ( # without \r
- [u'\n'] + [u'line {}\n'.format(i) for i in range(10)] + [u' ']
+ expected_ckpt_lines = ( # without \r
+ [u"\n"] + [u"line {}\n".format(i) for i in range(10)] + [u" "]
)
# Write out newlines as given (i.e., \r\n) regardless of OS, so as to
# test translation on read.
- with io.open(ckpt_path, 'w', newline='') as f:
- data = u''.join(raw_ckpt_lines)
+ with io.open(ckpt_path, "w", newline="") as f:
+ data = u"".join(raw_ckpt_lines)
f.write(data)
- with gfile.GFile(ckpt_path, 'r') as f:
+ with gfile.GFile(ckpt_path, "r") as f:
f.buff_chunk_size = 4 # Test buffering by reducing chunk size
read_ckpt_lines = list(f)
self.assertEqual(expected_ckpt_lines, read_ckpt_lines)
@@ -189,100 +171,100 @@ def testReadLines(self):
def testReadWithOffset(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
- ckpt_path = os.path.join(temp_dir, 'model.ckpt')
- ckpt_content = 'asdfasdfasdffoobarbuzz'
- ckpt_b_content = b'asdfasdfasdffoobarbuzz'
- with open(ckpt_path, 'w') as f:
+ ckpt_path = os.path.join(temp_dir, "model.ckpt")
+ ckpt_content = "asdfasdfasdffoobarbuzz"
+ ckpt_b_content = b"asdfasdfasdffoobarbuzz"
+ with open(ckpt_path, "w") as f:
f.write(ckpt_content)
- with gfile.GFile(ckpt_path, 'r') as f:
+ with gfile.GFile(ckpt_path, "r") as f:
f.buff_chunk_size = 4 # Test buffering by reducing chunk size
ckpt_read = f.read(12)
- self.assertEqual('asdfasdfasdf', ckpt_read)
+ self.assertEqual("asdfasdfasdf", ckpt_read)
ckpt_read = f.read(6)
- self.assertEqual('foobar', ckpt_read)
+ self.assertEqual("foobar", ckpt_read)
ckpt_read = f.read(1)
- self.assertEqual('b', ckpt_read)
+ self.assertEqual("b", ckpt_read)
ckpt_read = f.read()
- self.assertEqual('uzz', ckpt_read)
+ self.assertEqual("uzz", ckpt_read)
ckpt_read = f.read(1000)
- self.assertEqual('', ckpt_read)
- with gfile.GFile(ckpt_path, 'rb') as f:
+ self.assertEqual("", ckpt_read)
+ with gfile.GFile(ckpt_path, "rb") as f:
ckpt_read = f.read()
self.assertEqual(ckpt_b_content, ckpt_read)
def testWrite(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
- ckpt_path = os.path.join(temp_dir, 'model2.ckpt')
- ckpt_content = u'asdfasdfasdffoobarbuzz'
- with gfile.GFile(ckpt_path, 'w') as f:
+ ckpt_path = os.path.join(temp_dir, "model2.ckpt")
+ ckpt_content = u"asdfasdfasdffoobarbuzz"
+ with gfile.GFile(ckpt_path, "w") as f:
f.write(ckpt_content)
- with open(ckpt_path, 'r') as f:
+ with open(ckpt_path, "r") as f:
ckpt_read = f.read()
self.assertEqual(ckpt_content, ckpt_read)
def testOverwrite(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
- ckpt_path = os.path.join(temp_dir, 'model2.ckpt')
- ckpt_content = u'asdfasdfasdffoobarbuzz'
- with gfile.GFile(ckpt_path, 'w') as f:
- f.write(u'original')
- with gfile.GFile(ckpt_path, 'w') as f:
+ ckpt_path = os.path.join(temp_dir, "model2.ckpt")
+ ckpt_content = u"asdfasdfasdffoobarbuzz"
+ with gfile.GFile(ckpt_path, "w") as f:
+ f.write(u"original")
+ with gfile.GFile(ckpt_path, "w") as f:
f.write(ckpt_content)
- with open(ckpt_path, 'r') as f:
+ with open(ckpt_path, "r") as f:
ckpt_read = f.read()
self.assertEqual(ckpt_content, ckpt_read)
def testWriteMultiple(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
- ckpt_path = os.path.join(temp_dir, 'model2.ckpt')
- ckpt_content = u'asdfasdfasdffoobarbuzz' * 5
- with gfile.GFile(ckpt_path, 'w') as f:
+ ckpt_path = os.path.join(temp_dir, "model2.ckpt")
+ ckpt_content = u"asdfasdfasdffoobarbuzz" * 5
+ with gfile.GFile(ckpt_path, "w") as f:
for i in range(0, len(ckpt_content), 3):
- f.write(ckpt_content[i:i + 3])
+ f.write(ckpt_content[i : i + 3])
# Test periodic flushing of the file
if i % 9 == 0:
f.flush()
- with open(ckpt_path, 'r') as f:
+ with open(ckpt_path, "r") as f:
ckpt_read = f.read()
self.assertEqual(ckpt_content, ckpt_read)
def testWriteEmpty(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
- ckpt_path = os.path.join(temp_dir, 'model2.ckpt')
- ckpt_content = u''
- with gfile.GFile(ckpt_path, 'w') as f:
+ ckpt_path = os.path.join(temp_dir, "model2.ckpt")
+ ckpt_content = u""
+ with gfile.GFile(ckpt_path, "w") as f:
f.write(ckpt_content)
- with open(ckpt_path, 'r') as f:
+ with open(ckpt_path, "r") as f:
ckpt_read = f.read()
self.assertEqual(ckpt_content, ckpt_read)
def testWriteBinary(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
- ckpt_path = os.path.join(temp_dir, 'model2.ckpt')
- ckpt_content = b'asdfasdfasdffoobarbuzz'
- with gfile.GFile(ckpt_path, 'wb') as f:
+ ckpt_path = os.path.join(temp_dir, "model2.ckpt")
+ ckpt_content = b"asdfasdfasdffoobarbuzz"
+ with gfile.GFile(ckpt_path, "wb") as f:
f.write(ckpt_content)
- with open(ckpt_path, 'rb') as f:
+ with open(ckpt_path, "rb") as f:
ckpt_read = f.read()
self.assertEqual(ckpt_content, ckpt_read)
def testWriteMultipleBinary(self):
temp_dir = self.get_temp_dir()
self._CreateDeepDirectoryStructure(temp_dir)
- ckpt_path = os.path.join(temp_dir, 'model2.ckpt')
- ckpt_content = b'asdfasdfasdffoobarbuzz' * 5
- with gfile.GFile(ckpt_path, 'wb') as f:
+ ckpt_path = os.path.join(temp_dir, "model2.ckpt")
+ ckpt_content = b"asdfasdfasdffoobarbuzz" * 5
+ with gfile.GFile(ckpt_path, "wb") as f:
for i in range(0, len(ckpt_content), 3):
- f.write(ckpt_content[i:i + 3])
+ f.write(ckpt_content[i : i + 3])
# Test periodic flushing of the file
if i % 9 == 0:
f.flush()
- with open(ckpt_path, 'rb') as f:
+ with open(ckpt_path, "rb") as f:
ckpt_read = f.read()
self.assertEqual(ckpt_content, ckpt_read)
@@ -296,46 +278,46 @@ def _CreateDeepDirectoryStructure(self, top_directory):
# Add a few subdirectories.
directory_names = (
# An empty directory.
- 'foo',
+ "foo",
# A directory with an events file (and a text file).
- 'bar',
+ "bar",
# A deeper directory with events files.
- 'bar/baz',
+ "bar/baz",
# A non-empty subdir that lacks event files (should be ignored).
- 'bar/quux',
+ "bar/quux",
# This 3-level deep set of subdirectories tests logic that replaces
# the full glob string with an absolute path prefix if there is
# only 1 subdirectory in the final mapping.
- 'quuz/garply',
- 'quuz/garply/corge',
- 'quuz/garply/grault',
+ "quuz/garply",
+ "quuz/garply/corge",
+ "quuz/garply/grault",
# A directory that lacks events files, but contains a subdirectory
# with events files (first level should be ignored, second level
# should be included).
- 'waldo',
- 'waldo/fred',
+ "waldo",
+ "waldo/fred",
)
for directory_name in directory_names:
os.makedirs(os.path.join(top_directory, directory_name))
# Add a few files to the directory.
file_names = (
- 'a.tfevents.1',
- 'model.ckpt',
- 'bar/b.tfevents.1',
- 'bar/red_herring.txt',
- 'bar/baz/c.tfevents.1',
- 'bar/baz/d.tfevents.1',
- 'bar/quux/some_flume_output.txt',
- 'bar/quux/some_more_flume_output.txt',
- 'quuz/e.tfevents.1',
- 'quuz/garply/f.tfevents.1',
- 'quuz/garply/corge/g.tfevents.1',
- 'quuz/garply/grault/h.tfevents.1',
- 'waldo/fred/i.tfevents.1',
+ "a.tfevents.1",
+ "model.ckpt",
+ "bar/b.tfevents.1",
+ "bar/red_herring.txt",
+ "bar/baz/c.tfevents.1",
+ "bar/baz/d.tfevents.1",
+ "bar/quux/some_flume_output.txt",
+ "bar/quux/some_more_flume_output.txt",
+ "quuz/e.tfevents.1",
+ "quuz/garply/f.tfevents.1",
+ "quuz/garply/corge/g.tfevents.1",
+ "quuz/garply/grault/h.tfevents.1",
+ "waldo/fred/i.tfevents.1",
)
for file_name in file_names:
- open(os.path.join(top_directory, file_name), 'w').close()
+ open(os.path.join(top_directory, file_name), "w").close()
def _CompareFilesPerSubdirectory(self, expected, gotten):
"""Compares iterables of (subdirectory path, list of absolute paths)
@@ -345,14 +327,18 @@ def _CompareFilesPerSubdirectory(self, expected, gotten):
gotten: The gotten iterable of 2-tuples.
"""
expected_directory_to_files = {
- result[0]: list(result[1]) for result in expected}
+ result[0]: list(result[1]) for result in expected
+ }
gotten_directory_to_files = {
# Note we ignore subdirectories and just compare files
- result[0]: list(result[2]) for result in gotten}
+ result[0]: list(result[2])
+ for result in gotten
+ }
six.assertCountEqual(
self,
expected_directory_to_files.keys(),
- gotten_directory_to_files.keys())
+ gotten_directory_to_files.keys(),
+ )
for subdir, expected_listing in expected_directory_to_files.items():
gotten_listing = gotten_directory_to_files[subdir]
@@ -360,9 +346,10 @@ def _CompareFilesPerSubdirectory(self, expected, gotten):
self,
expected_listing,
gotten_listing,
- 'Files for subdir %r must match. Expected %r. Got %r.' % (
- subdir, expected_listing, gotten_listing))
+ "Files for subdir %r must match. Expected %r. Got %r."
+ % (subdir, expected_listing, gotten_listing),
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
tb_test.main()
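
The `testReadLines` comment above notes that text-mode reads translate `\r\n` to `\n`, so character offsets drift away from the underlying byte offsets; that is exactly what the buffered `GFile` reader has to cope with. A small sketch of the effect using only the standard library (file name and contents are illustrative):

import io
import os
import tempfile

path = os.path.join(tempfile.mkdtemp(), "newlines.txt")
# Write "\r\n" verbatim, exactly as testReadLines does above.
with io.open(path, "w", newline="") as f:
    f.write(u"line 0\r\nline 1\r\n")

with io.open(path, "rb") as f:
    num_bytes = len(f.read())  # 16: the "\r" bytes are still on disk
with io.open(path, "r") as f:
    num_chars = len(f.read())  # 14: universal newlines collapse "\r\n" to "\n"

assert (num_bytes, num_chars) == (16, 14)
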
diff --git a/tensorboard/compat/tensorflow_stub/io/gfile_tf_test.py b/tensorboard/compat/tensorflow_stub/io/gfile_tf_test.py
index b6ce254280..38b0b61d75 100644
--- a/tensorboard/compat/tensorflow_stub/io/gfile_tf_test.py
+++ b/tensorboard/compat/tensorflow_stub/io/gfile_tf_test.py
@@ -35,247 +35,274 @@
# the TensorFlow FileIO API (to the extent that it's implemented at all).
# Many of the TF tests are removed because they do not apply here.
-class FileIoTest(tb_test.TestCase):
- def setUp(self):
- self._base_dir = os.path.join(self.get_temp_dir(), "base_dir")
- gfile.makedirs(self._base_dir)
-
- # It was a temp_dir anyway
- # def tearDown(self):
- # gfile.delete_recursively(self._base_dir)
-
- def testEmptyFilename(self):
- f = gfile.GFile("", mode="r")
- with self.assertRaises(errors.NotFoundError):
- _ = f.read()
-
- def testFileDoesntExist(self):
- file_path = os.path.join(self._base_dir, "temp_file")
- self.assertFalse(gfile.exists(file_path))
- with self.assertRaises(errors.NotFoundError):
- _ = gfile._read_file_to_string(file_path)
-
- def testWriteToString(self):
- file_path = os.path.join(self._base_dir, "temp_file")
- gfile._write_string_to_file(file_path, "testing")
- self.assertTrue(gfile.exists(file_path))
- file_contents = gfile._read_file_to_string(file_path)
- self.assertEqual("testing", file_contents)
-
- def testReadBinaryMode(self):
- file_path = os.path.join(self._base_dir, "temp_file")
- gfile._write_string_to_file(file_path, "testing")
- with gfile.GFile(file_path, mode="rb") as f:
- self.assertEqual(b"testing", f.read())
-
- def testWriteBinaryMode(self):
- file_path = os.path.join(self._base_dir, "temp_file")
- gfile.GFile(file_path, "wb").write(compat.as_bytes("testing"))
- with gfile.GFile(file_path, mode="r") as f:
- self.assertEqual("testing", f.read())
-
- def testMultipleFiles(self):
- file_prefix = os.path.join(self._base_dir, "temp_file")
- for i in range(5000):
-
- with gfile.GFile(file_prefix + str(i), mode="w") as f:
- f.write("testing")
- f.flush()
-
- with gfile.GFile(file_prefix + str(i), mode="r") as f:
- self.assertEqual("testing", f.read())
-
- def testMultipleWrites(self):
- file_path = os.path.join(self._base_dir, "temp_file")
- with gfile.GFile(file_path, mode="w") as f:
- f.write("line1\n")
- f.write("line2")
- file_contents = gfile._read_file_to_string(file_path)
- self.assertEqual("line1\nline2", file_contents)
-
- def testFileWriteBadMode(self):
- file_path = os.path.join(self._base_dir, "temp_file")
- with self.assertRaises(errors.PermissionDeniedError):
- gfile.GFile(file_path, mode="r").write("testing")
-
- def testFileReadBadMode(self):
- file_path = os.path.join(self._base_dir, "temp_file")
- gfile.GFile(file_path, mode="w").write("testing")
- self.assertTrue(gfile.exists(file_path))
- with self.assertRaises(errors.PermissionDeniedError):
- gfile.GFile(file_path, mode="w").read()
-
- def testIsDirectory(self):
- dir_path = os.path.join(self._base_dir, "test_dir")
- # Failure for a non-existing dir.
- self.assertFalse(gfile.isdir(dir_path))
- gfile.makedirs(dir_path)
- self.assertTrue(gfile.isdir(dir_path))
- file_path = os.path.join(dir_path, "test_file")
- gfile.GFile(file_path, mode="w").write("test")
- # False for a file.
- self.assertFalse(gfile.isdir(file_path))
- # Test that the value returned from `stat()` has `is_directory` set.
- # file_statistics = gfile.stat(dir_path)
- # self.assertTrue(file_statistics.is_directory)
-
- def testListDirectory(self):
- dir_path = os.path.join(self._base_dir, "test_dir")
- gfile.makedirs(dir_path)
- files = ["file1.txt", "file2.txt", "file3.txt"]
- for name in files:
- file_path = os.path.join(dir_path, name)
- gfile.GFile(file_path, mode="w").write("testing")
- subdir_path = os.path.join(dir_path, "sub_dir")
- gfile.makedirs(subdir_path)
- subdir_file_path = os.path.join(subdir_path, "file4.txt")
- gfile.GFile(subdir_file_path, mode="w").write("testing")
- dir_list = gfile.listdir(dir_path)
- self.assertItemsEqual(files + ["sub_dir"], dir_list)
-
- def testListDirectoryFailure(self):
- dir_path = os.path.join(self._base_dir, "test_dir")
- with self.assertRaises(errors.NotFoundError):
- gfile.listdir(dir_path)
-
- def _setupWalkDirectories(self, dir_path):
- # Creating a file structure as follows
- # test_dir -> file: file1.txt; dirs: subdir1_1, subdir1_2, subdir1_3
- # subdir1_1 -> file: file3.txt
- # subdir1_2 -> dir: subdir2
- gfile.makedirs(dir_path)
- gfile.GFile(
- os.path.join(dir_path, "file1.txt"), mode="w").write("testing")
- sub_dirs1 = ["subdir1_1", "subdir1_2", "subdir1_3"]
- for name in sub_dirs1:
- gfile.makedirs(os.path.join(dir_path, name))
- gfile.GFile(
- os.path.join(dir_path, "subdir1_1/file2.txt"),
- mode="w").write("testing")
- gfile.makedirs(os.path.join(dir_path, "subdir1_2/subdir2"))
-
- def testWalkInOrder(self):
- dir_path = os.path.join(self._base_dir, "test_dir")
- self._setupWalkDirectories(dir_path)
- # Now test the walk (topdown = True)
- all_dirs = []
- all_subdirs = []
- all_files = []
- for (w_dir, w_subdirs, w_files) in gfile.walk(dir_path, topdown=True):
- all_dirs.append(w_dir)
- all_subdirs.append(w_subdirs)
- all_files.append(w_files)
- self.assertItemsEqual(all_dirs, [dir_path] + [
- os.path.join(dir_path, item)
- for item in
- ["subdir1_1", "subdir1_2", "subdir1_2/subdir2", "subdir1_3"]
- ])
- self.assertEqual(dir_path, all_dirs[0])
- self.assertLess(
- all_dirs.index(os.path.join(dir_path, "subdir1_2")),
- all_dirs.index(os.path.join(dir_path, "subdir1_2/subdir2")))
- self.assertItemsEqual(all_subdirs[1:5], [[], ["subdir2"], [], []])
- self.assertItemsEqual(all_subdirs[0],
- ["subdir1_1", "subdir1_2", "subdir1_3"])
- self.assertItemsEqual(all_files, [["file1.txt"], ["file2.txt"], [], [], []])
- self.assertLess(
- all_files.index(["file1.txt"]), all_files.index(["file2.txt"]))
-
- def testWalkPostOrder(self):
- dir_path = os.path.join(self._base_dir, "test_dir")
- self._setupWalkDirectories(dir_path)
- # Now test the walk (topdown = False)
- all_dirs = []
- all_subdirs = []
- all_files = []
- for (w_dir, w_subdirs, w_files) in gfile.walk(dir_path, topdown=False):
- all_dirs.append(w_dir)
- all_subdirs.append(w_subdirs)
- all_files.append(w_files)
- self.assertItemsEqual(all_dirs, [
- os.path.join(dir_path, item)
- for item in
- ["subdir1_1", "subdir1_2/subdir2", "subdir1_2", "subdir1_3"]
- ] + [dir_path])
- self.assertEqual(dir_path, all_dirs[4])
- self.assertLess(
- all_dirs.index(os.path.join(dir_path, "subdir1_2/subdir2")),
- all_dirs.index(os.path.join(dir_path, "subdir1_2")))
- self.assertItemsEqual(all_subdirs[0:4], [[], [], ["subdir2"], []])
- self.assertItemsEqual(all_subdirs[4],
- ["subdir1_1", "subdir1_2", "subdir1_3"])
- self.assertItemsEqual(all_files, [["file2.txt"], [], [], [], ["file1.txt"]])
- self.assertLess(
- all_files.index(["file2.txt"]), all_files.index(["file1.txt"]))
-
- def testWalkFailure(self):
- dir_path = os.path.join(self._base_dir, "test_dir")
- # Try walking a directory that wasn't created.
- all_dirs = []
- all_subdirs = []
- all_files = []
- for (w_dir, w_subdirs, w_files) in gfile.walk(dir_path, topdown=False):
- all_dirs.append(w_dir)
- all_subdirs.append(w_subdirs)
- all_files.append(w_files)
- self.assertItemsEqual(all_dirs, [])
- self.assertItemsEqual(all_subdirs, [])
- self.assertItemsEqual(all_files, [])
-
- def testStat(self):
- file_path = os.path.join(self._base_dir, "temp_file")
- gfile.GFile(file_path, mode="w").write("testing")
- file_statistics = gfile.stat(file_path)
- os_statistics = os.stat(file_path)
- self.assertEqual(7, file_statistics.length)
-
- def testRead(self):
- file_path = os.path.join(self._base_dir, "temp_file")
- with gfile.GFile(file_path, mode="w") as f:
- f.write("testing1\ntesting2\ntesting3\n\ntesting5")
- with gfile.GFile(file_path, mode="r") as f:
- self.assertEqual(36, gfile.stat(file_path).length)
- self.assertEqual("testing1\n", f.read(9))
- self.assertEqual("testing2\n", f.read(9))
- self.assertEqual("t", f.read(1))
- self.assertEqual("esting3\n\ntesting5", f.read())
-
- def testReadingIterator(self):
- file_path = os.path.join(self._base_dir, "temp_file")
- data = ["testing1\n", "testing2\n", "testing3\n", "\n", "testing5"]
- with gfile.GFile(file_path, mode="w") as f:
- f.write("".join(data))
- with gfile.GFile(file_path, mode="r") as f:
- actual_data = []
- for line in f:
- actual_data.append(line)
- self.assertSequenceEqual(actual_data, data)
-
- def testUTF8StringPath(self):
- file_path = os.path.join(self._base_dir, "UTF8测试_file")
- gfile._write_string_to_file(file_path, "testing")
- with gfile.GFile(file_path, mode="rb") as f:
- self.assertEqual(b"testing", f.read())
-
- def testEof(self):
- """Test that reading past EOF does not raise an exception."""
-
- file_path = os.path.join(self._base_dir, "temp_file")
-
- with gfile.GFile(file_path, mode="w") as f:
- content = "testing"
- f.write(content)
- f.flush()
- with gfile.GFile(file_path, mode="r") as f:
- self.assertEqual(content, f.read(len(content) + 1))
-
- def testUTF8StringPathExists(self):
- file_path = os.path.join(self._base_dir, "UTF8测试_file_exist")
- gfile._write_string_to_file(file_path, "testing")
- v = gfile.exists(file_path)
- self.assertEqual(v, True)
+class FileIoTest(tb_test.TestCase):
+ def setUp(self):
+ self._base_dir = os.path.join(self.get_temp_dir(), "base_dir")
+ gfile.makedirs(self._base_dir)
+
+ # It was a temp_dir anyway
+ # def tearDown(self):
+ # gfile.delete_recursively(self._base_dir)
+
+ def testEmptyFilename(self):
+ f = gfile.GFile("", mode="r")
+ with self.assertRaises(errors.NotFoundError):
+ _ = f.read()
+
+ def testFileDoesntExist(self):
+ file_path = os.path.join(self._base_dir, "temp_file")
+ self.assertFalse(gfile.exists(file_path))
+ with self.assertRaises(errors.NotFoundError):
+ _ = gfile._read_file_to_string(file_path)
+
+ def testWriteToString(self):
+ file_path = os.path.join(self._base_dir, "temp_file")
+ gfile._write_string_to_file(file_path, "testing")
+ self.assertTrue(gfile.exists(file_path))
+ file_contents = gfile._read_file_to_string(file_path)
+ self.assertEqual("testing", file_contents)
+
+ def testReadBinaryMode(self):
+ file_path = os.path.join(self._base_dir, "temp_file")
+ gfile._write_string_to_file(file_path, "testing")
+ with gfile.GFile(file_path, mode="rb") as f:
+ self.assertEqual(b"testing", f.read())
+
+ def testWriteBinaryMode(self):
+ file_path = os.path.join(self._base_dir, "temp_file")
+ gfile.GFile(file_path, "wb").write(compat.as_bytes("testing"))
+ with gfile.GFile(file_path, mode="r") as f:
+ self.assertEqual("testing", f.read())
+
+ def testMultipleFiles(self):
+ file_prefix = os.path.join(self._base_dir, "temp_file")
+ for i in range(5000):
+
+ with gfile.GFile(file_prefix + str(i), mode="w") as f:
+ f.write("testing")
+ f.flush()
+
+ with gfile.GFile(file_prefix + str(i), mode="r") as f:
+ self.assertEqual("testing", f.read())
+
+ def testMultipleWrites(self):
+ file_path = os.path.join(self._base_dir, "temp_file")
+ with gfile.GFile(file_path, mode="w") as f:
+ f.write("line1\n")
+ f.write("line2")
+ file_contents = gfile._read_file_to_string(file_path)
+ self.assertEqual("line1\nline2", file_contents)
+
+ def testFileWriteBadMode(self):
+ file_path = os.path.join(self._base_dir, "temp_file")
+ with self.assertRaises(errors.PermissionDeniedError):
+ gfile.GFile(file_path, mode="r").write("testing")
+
+ def testFileReadBadMode(self):
+ file_path = os.path.join(self._base_dir, "temp_file")
+ gfile.GFile(file_path, mode="w").write("testing")
+ self.assertTrue(gfile.exists(file_path))
+ with self.assertRaises(errors.PermissionDeniedError):
+ gfile.GFile(file_path, mode="w").read()
+
+ def testIsDirectory(self):
+ dir_path = os.path.join(self._base_dir, "test_dir")
+ # Failure for a non-existing dir.
+ self.assertFalse(gfile.isdir(dir_path))
+ gfile.makedirs(dir_path)
+ self.assertTrue(gfile.isdir(dir_path))
+ file_path = os.path.join(dir_path, "test_file")
+ gfile.GFile(file_path, mode="w").write("test")
+ # False for a file.
+ self.assertFalse(gfile.isdir(file_path))
+ # Test that the value returned from `stat()` has `is_directory` set.
+ # file_statistics = gfile.stat(dir_path)
+ # self.assertTrue(file_statistics.is_directory)
+
+ def testListDirectory(self):
+ dir_path = os.path.join(self._base_dir, "test_dir")
+ gfile.makedirs(dir_path)
+ files = ["file1.txt", "file2.txt", "file3.txt"]
+ for name in files:
+ file_path = os.path.join(dir_path, name)
+ gfile.GFile(file_path, mode="w").write("testing")
+ subdir_path = os.path.join(dir_path, "sub_dir")
+ gfile.makedirs(subdir_path)
+ subdir_file_path = os.path.join(subdir_path, "file4.txt")
+ gfile.GFile(subdir_file_path, mode="w").write("testing")
+ dir_list = gfile.listdir(dir_path)
+ self.assertItemsEqual(files + ["sub_dir"], dir_list)
+
+ def testListDirectoryFailure(self):
+ dir_path = os.path.join(self._base_dir, "test_dir")
+ with self.assertRaises(errors.NotFoundError):
+ gfile.listdir(dir_path)
+
+ def _setupWalkDirectories(self, dir_path):
+ # Creating a file structure as follows
+ # test_dir -> file: file1.txt; dirs: subdir1_1, subdir1_2, subdir1_3
+ # subdir1_1 -> file: file3.txt
+ # subdir1_1 -> file: file2.txt
+ # subdir1_2 -> dir: subdir2
+ gfile.makedirs(dir_path)
+ gfile.GFile(os.path.join(dir_path, "file1.txt"), mode="w").write(
+ "testing"
+ )
+ sub_dirs1 = ["subdir1_1", "subdir1_2", "subdir1_3"]
+ for name in sub_dirs1:
+ gfile.makedirs(os.path.join(dir_path, name))
+ gfile.GFile(
+ os.path.join(dir_path, "subdir1_1/file2.txt"), mode="w"
+ ).write("testing")
+ gfile.makedirs(os.path.join(dir_path, "subdir1_2/subdir2"))
+
+ def testWalkInOrder(self):
+ dir_path = os.path.join(self._base_dir, "test_dir")
+ self._setupWalkDirectories(dir_path)
+ # Now test the walk (topdown = True)
+ all_dirs = []
+ all_subdirs = []
+ all_files = []
+ for (w_dir, w_subdirs, w_files) in gfile.walk(dir_path, topdown=True):
+ all_dirs.append(w_dir)
+ all_subdirs.append(w_subdirs)
+ all_files.append(w_files)
+ self.assertItemsEqual(
+ all_dirs,
+ [dir_path]
+ + [
+ os.path.join(dir_path, item)
+ for item in [
+ "subdir1_1",
+ "subdir1_2",
+ "subdir1_2/subdir2",
+ "subdir1_3",
+ ]
+ ],
+ )
+ self.assertEqual(dir_path, all_dirs[0])
+ self.assertLess(
+ all_dirs.index(os.path.join(dir_path, "subdir1_2")),
+ all_dirs.index(os.path.join(dir_path, "subdir1_2/subdir2")),
+ )
+ self.assertItemsEqual(all_subdirs[1:5], [[], ["subdir2"], [], []])
+ self.assertItemsEqual(
+ all_subdirs[0], ["subdir1_1", "subdir1_2", "subdir1_3"]
+ )
+ self.assertItemsEqual(
+ all_files, [["file1.txt"], ["file2.txt"], [], [], []]
+ )
+ self.assertLess(
+ all_files.index(["file1.txt"]), all_files.index(["file2.txt"])
+ )
+
+ def testWalkPostOrder(self):
+ dir_path = os.path.join(self._base_dir, "test_dir")
+ self._setupWalkDirectories(dir_path)
+ # Now test the walk (topdown = False)
+ all_dirs = []
+ all_subdirs = []
+ all_files = []
+ for (w_dir, w_subdirs, w_files) in gfile.walk(dir_path, topdown=False):
+ all_dirs.append(w_dir)
+ all_subdirs.append(w_subdirs)
+ all_files.append(w_files)
+ self.assertItemsEqual(
+ all_dirs,
+ [
+ os.path.join(dir_path, item)
+ for item in [
+ "subdir1_1",
+ "subdir1_2/subdir2",
+ "subdir1_2",
+ "subdir1_3",
+ ]
+ ]
+ + [dir_path],
+ )
+ self.assertEqual(dir_path, all_dirs[4])
+ self.assertLess(
+ all_dirs.index(os.path.join(dir_path, "subdir1_2/subdir2")),
+ all_dirs.index(os.path.join(dir_path, "subdir1_2")),
+ )
+ self.assertItemsEqual(all_subdirs[0:4], [[], [], ["subdir2"], []])
+ self.assertItemsEqual(
+ all_subdirs[4], ["subdir1_1", "subdir1_2", "subdir1_3"]
+ )
+ self.assertItemsEqual(
+ all_files, [["file2.txt"], [], [], [], ["file1.txt"]]
+ )
+ self.assertLess(
+ all_files.index(["file2.txt"]), all_files.index(["file1.txt"])
+ )
+
+ def testWalkFailure(self):
+ dir_path = os.path.join(self._base_dir, "test_dir")
+ # Try walking a directory that wasn't created.
+ all_dirs = []
+ all_subdirs = []
+ all_files = []
+ for (w_dir, w_subdirs, w_files) in gfile.walk(dir_path, topdown=False):
+ all_dirs.append(w_dir)
+ all_subdirs.append(w_subdirs)
+ all_files.append(w_files)
+ self.assertItemsEqual(all_dirs, [])
+ self.assertItemsEqual(all_subdirs, [])
+ self.assertItemsEqual(all_files, [])
+
+ def testStat(self):
+ file_path = os.path.join(self._base_dir, "temp_file")
+ gfile.GFile(file_path, mode="w").write("testing")
+ file_statistics = gfile.stat(file_path)
+ os_statistics = os.stat(file_path)
+ self.assertEqual(7, file_statistics.length)
+
+ def testRead(self):
+ file_path = os.path.join(self._base_dir, "temp_file")
+ with gfile.GFile(file_path, mode="w") as f:
+ f.write("testing1\ntesting2\ntesting3\n\ntesting5")
+ with gfile.GFile(file_path, mode="r") as f:
+ self.assertEqual(36, gfile.stat(file_path).length)
+ self.assertEqual("testing1\n", f.read(9))
+ self.assertEqual("testing2\n", f.read(9))
+ self.assertEqual("t", f.read(1))
+ self.assertEqual("esting3\n\ntesting5", f.read())
+
+ def testReadingIterator(self):
+ file_path = os.path.join(self._base_dir, "temp_file")
+ data = ["testing1\n", "testing2\n", "testing3\n", "\n", "testing5"]
+ with gfile.GFile(file_path, mode="w") as f:
+ f.write("".join(data))
+ with gfile.GFile(file_path, mode="r") as f:
+ actual_data = []
+ for line in f:
+ actual_data.append(line)
+ self.assertSequenceEqual(actual_data, data)
+
+ def testUTF8StringPath(self):
+ file_path = os.path.join(self._base_dir, "UTF8测试_file")
+ gfile._write_string_to_file(file_path, "testing")
+ with gfile.GFile(file_path, mode="rb") as f:
+ self.assertEqual(b"testing", f.read())
+
+ def testEof(self):
+ """Test that reading past EOF does not raise an exception."""
+
+ file_path = os.path.join(self._base_dir, "temp_file")
+
+ with gfile.GFile(file_path, mode="w") as f:
+ content = "testing"
+ f.write(content)
+ f.flush()
+ with gfile.GFile(file_path, mode="r") as f:
+ self.assertEqual(content, f.read(len(content) + 1))
+
+ def testUTF8StringPathExists(self):
+ file_path = os.path.join(self._base_dir, "UTF8测试_file_exist")
+ gfile._write_string_to_file(file_path, "testing")
+ v = gfile.exists(file_path)
+ self.assertEqual(v, True)
if __name__ == "__main__":
- tb_test.main()
+ tb_test.main()
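
The walk tests above check pre-order (`topdown=True`) against post-order (`topdown=False`) traversal. The stub's `gfile.walk` mirrors `os.walk` in this respect, so the ordering being asserted can be reproduced with the standard library alone (directory names here are illustrative):

import os
import tempfile

root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "subdir1_2", "subdir2"))

pre = [d for (d, _, _) in os.walk(root, topdown=True)]
post = [d for (d, _, _) in os.walk(root, topdown=False)]

parent = os.path.join(root, "subdir1_2")
child = os.path.join(root, "subdir1_2", "subdir2")
# Pre-order visits a parent before its children; post-order reverses that.
assert pre.index(parent) < pre.index(child)
assert post.index(child) < post.index(parent)
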
diff --git a/tensorboard/compat/tensorflow_stub/pywrap_tensorflow.py b/tensorboard/compat/tensorflow_stub/pywrap_tensorflow.py
index a41563e698..649ac598eb 100644
--- a/tensorboard/compat/tensorflow_stub/pywrap_tensorflow.py
+++ b/tensorboard/compat/tensorflow_stub/pywrap_tensorflow.py
@@ -41,11 +41,11 @@ def TF_bfloat16_type():
def masked_crc32c(data):
x = u32(crc32c(data))
- return u32(((x >> 15) | u32(x << 17)) + 0xa282ead8)
+ return u32(((x >> 15) | u32(x << 17)) + 0xA282EAD8)
def u32(x):
- return x & 0xffffffff
+ return x & 0xFFFFFFFF
# fmt: off
@@ -125,6 +125,7 @@ def u32(x):
def crc_update(crc, data):
"""Update CRC-32C checksum with data.
+
Args:
crc: 32-bit checksum to update as long.
data: byte array, string or iterable over bytes.
@@ -139,13 +140,14 @@ def crc_update(crc, data):
crc ^= _MASK
for b in buf:
- table_index = (crc ^ b) & 0xff
+ table_index = (crc ^ b) & 0xFF
crc = (CRC_TABLE[table_index] ^ (crc >> 8)) & _MASK
return crc ^ _MASK
def crc_finalize(crc):
"""Finalize CRC-32C checksum.
+
This function should be called as last step of crc calculation.
Args:
crc: 32-bit checksum as long.
@@ -157,6 +159,7 @@ def crc_finalize(crc):
def crc32c(data):
"""Compute CRC-32C checksum of the data.
+
Args:
data: byte array, string or iterable over bytes.
Returns:
@@ -166,28 +169,34 @@ def crc32c(data):
class PyRecordReader_New:
- def __init__(self, filename=None, start_offset=0, compression_type=None,
- status=None):
+ def __init__(
+ self, filename=None, start_offset=0, compression_type=None, status=None
+ ):
if filename is None:
raise errors.NotFoundError(
- None, None, 'No filename provided, cannot read Events')
+ None, None, "No filename provided, cannot read Events"
+ )
if not gfile.exists(filename):
raise errors.NotFoundError(
- None, None,
- '{} does not point to valid Events file'.format(filename))
+ None,
+ None,
+ "{} does not point to valid Events file".format(filename),
+ )
if start_offset:
raise errors.UnimplementedError(
- None, None, 'start offset not supported by compat reader')
+ None, None, "start offset not supported by compat reader"
+ )
if compression_type:
# TODO: Handle gzip and zlib compressed files
raise errors.UnimplementedError(
- None, None, 'compression not supported by compat reader')
+ None, None, "compression not supported by compat reader"
+ )
self.filename = filename
self.start_offset = start_offset
self.compression_type = compression_type
self.status = status
self.curr_event = None
- self.file_handle = gfile.GFile(self.filename, 'rb')
+ self.file_handle = gfile.GFile(self.filename, "rb")
def GetNext(self):
# Read the header
@@ -195,18 +204,17 @@ def GetNext(self):
header_str = self.file_handle.read(8)
if len(header_str) != 8:
# Hit EOF so raise and exit
- raise errors.OutOfRangeError(None, None, 'No more events to read')
- header = struct.unpack('Q', header_str)
+ raise errors.OutOfRangeError(None, None, "No more events to read")
+ header = struct.unpack("Q", header_str)
# Read the crc32, which is 4 bytes, and check it against
# the crc32 of the header
crc_header_str = self.file_handle.read(4)
- crc_header = struct.unpack('I', crc_header_str)
+ crc_header = struct.unpack("I", crc_header_str)
header_crc_calc = masked_crc32c(header_str)
if header_crc_calc != crc_header[0]:
raise errors.DataLossError(
- None, None,
- '{} failed header crc32 check'.format(self.filename)
+ None, None, "{} failed header crc32 check".format(self.filename)
)
# The length of the header tells us how many bytes the Event
@@ -221,11 +229,12 @@ def GetNext(self):
# has no crc32, in which case we skip.
crc_event_str = self.file_handle.read(4)
if crc_event_str:
- crc_event = struct.unpack('I', crc_event_str)
+ crc_event = struct.unpack("I", crc_event_str)
if event_crc_calc != crc_event[0]:
raise errors.DataLossError(
- None, None,
- '{} failed event crc32 check'.format(self.filename)
+ None,
+ None,
+ "{} failed event crc32 check".format(self.filename),
)
# Set the current event to be read later by record() call
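
`PyRecordReader_New.GetNext` above parses TFRecord framing: an 8-byte length, a 4-byte masked CRC-32C of those length bytes, the event payload, then a 4-byte masked CRC-32C of the payload. A minimal sketch of the header step, assuming `crc32c_fn` is any CRC-32C (Castagnoli) implementation such as the table-driven `crc32c` defined earlier in this file:

import struct


def u32(x):
    return x & 0xFFFFFFFF


def masked_crc32c(data, crc32c_fn):
    # Same masking formula as masked_crc32c() in pywrap_tensorflow.py above.
    x = u32(crc32c_fn(data))
    return u32(((x >> 15) | u32(x << 17)) + 0xA282EAD8)


def read_record_header(f, crc32c_fn):
    """Read one record's 8-byte length and verify its 4-byte masked CRC."""
    header_str = f.read(8)
    if len(header_str) != 8:
        raise EOFError("no more records")  # GetNext() raises OutOfRangeError
    (length,) = struct.unpack("Q", header_str)  # native order, as in GetNext()
    (stored_crc,) = struct.unpack("I", f.read(4))
    if masked_crc32c(header_str, crc32c_fn) != stored_crc:
        raise ValueError("header failed crc32 check")
    return length
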
diff --git a/tensorboard/compat/tensorflow_stub/tensor_shape.py b/tensorboard/compat/tensorflow_stub/tensor_shape.py
index 1133b0d087..bdb3f7e075 100644
--- a/tensorboard/compat/tensorflow_stub/tensor_shape.py
+++ b/tensorboard/compat/tensorflow_stub/tensor_shape.py
@@ -49,7 +49,8 @@ def __str__(self):
return "?" if value is None else str(value)
def __eq__(self, other):
- """Returns true if `other` has the same known value as this Dimension."""
+ """Returns true if `other` has the same known value as this
+ Dimension."""
try:
other = as_dimension(other)
except (TypeError, ValueError):
@@ -98,10 +99,15 @@ def is_convertible_with(self, other):
True if this Dimension and `other` are convertible.
"""
other = as_dimension(other)
- return self._value is None or other.value is None or self._value == other.value
+ return (
+ self._value is None
+ or other.value is None
+ or self._value == other.value
+ )
def assert_is_convertible_with(self, other):
- """Raises an exception if `other` is not convertible with this Dimension.
+ """Raises an exception if `other` is not convertible with this
+ Dimension.
Args:
other: Another Dimension.
@@ -111,10 +117,13 @@ def assert_is_convertible_with(self, other):
is_convertible_with).
"""
if not self.is_convertible_with(other):
- raise ValueError("Dimensions %s and %s are not convertible" % (self, other))
+ raise ValueError(
+ "Dimensions %s and %s are not convertible" % (self, other)
+ )
def merge_with(self, other):
- """Returns a Dimension that combines the information in `self` and `other`.
+ """Returns a Dimension that combines the information in `self` and
+ `other`.
Dimensions are combined as follows:
@@ -433,7 +442,8 @@ def __gt__(self, other):
return self._value > other.value
def __ge__(self, other):
- """Returns True if `self` is known to be greater than or equal to `other`.
+ """Returns True if `self` is known to be greater than or equal to
+ `other`.
Dimensions are compared as follows:
@@ -554,7 +564,8 @@ def __str__(self):
@property
def dims(self):
- """Returns a list of Dimensions, or None if the shape is unspecified."""
+ """Returns a list of Dimensions, or None if the shape is
+ unspecified."""
return self._dims
@dims.setter
@@ -573,9 +584,12 @@ def ndims(self):
return self._ndims
def __len__(self):
- """Returns the rank of this shape, or raises ValueError if unspecified."""
+ """Returns the rank of this shape, or raises ValueError if
+ unspecified."""
if self._dims is None:
- raise ValueError("Cannot take the length of Shape with unknown rank.")
+ raise ValueError(
+ "Cannot take the length of Shape with unknown rank."
+ )
return self.ndims
def __bool__(self):
@@ -586,7 +600,8 @@ def __bool__(self):
__nonzero__ = __bool__
def __iter__(self):
- """Returns `self.dims` if the rank is known, otherwise raises ValueError."""
+ """Returns `self.dims` if the rank is known, otherwise raises
+ ValueError."""
if self._dims is None:
raise ValueError("Cannot iterate over a shape with unknown rank.")
else:
@@ -637,7 +652,8 @@ def __getitem__(self, key):
return Dimension(None)
def num_elements(self):
- """Returns the total number of elements, or none for incomplete shapes."""
+ """Returns the total number of elements, or none for incomplete
+ shapes."""
if self.is_fully_defined():
size = 1
for dim in self._dims:
@@ -647,7 +663,8 @@ def num_elements(self):
return None
def merge_with(self, other):
- """Returns a `TensorShape` combining the information in `self` and `other`.
+ """Returns a `TensorShape` combining the information in `self` and
+ `other`.
The dimensions in `self` and `other` are merged elementwise,
according to the rules defined for `Dimension.merge_with()`.
@@ -673,7 +690,9 @@ def merge_with(self, other):
new_dims.append(dim.merge_with(other[i]))
return TensorShape(new_dims)
except ValueError:
- raise ValueError("Shapes %s and %s are not convertible" % (self, other))
+ raise ValueError(
+ "Shapes %s and %s are not convertible" % (self, other)
+ )
def concatenate(self, other):
"""Returns the concatenation of the dimension in `self` and `other`.
@@ -699,7 +718,8 @@ def concatenate(self, other):
return TensorShape(self._dims + other.dims)
def assert_same_rank(self, other):
- """Raises an exception if `self` and `other` do not have convertible ranks.
+ """Raises an exception if `self` and `other` do not have convertible
+ ranks.
Args:
other: Another `TensorShape`.
@@ -716,7 +736,8 @@ def assert_same_rank(self, other):
)
def assert_has_rank(self, rank):
- """Raises an exception if `self` is not convertible with the given `rank`.
+ """Raises an exception if `self` is not convertible with the given
+ `rank`.
Args:
rank: An integer.
@@ -762,7 +783,9 @@ def with_rank_at_least(self, rank):
`rank`.
"""
if self.ndims is not None and self.ndims < rank:
- raise ValueError("Shape %s must have rank at least %d" % (self, rank))
+ raise ValueError(
+ "Shape %s must have rank at least %d" % (self, rank)
+ )
else:
return self
@@ -781,7 +804,9 @@ def with_rank_at_most(self, rank):
`rank`.
"""
if self.ndims is not None and self.ndims > rank:
- raise ValueError("Shape %s must have rank at most %d" % (self, rank))
+ raise ValueError(
+ "Shape %s must have rank at most %d" % (self, rank)
+ )
else:
return self
@@ -821,7 +846,6 @@ def is_convertible_with(self, other):
Returns:
True iff `self` is convertible with `other`.
-
"""
other = as_shape(other)
if self._dims is not None and other.dims is not None:
@@ -833,7 +857,8 @@ def is_convertible_with(self, other):
return True
def assert_is_convertible_with(self, other):
- """Raises exception if `self` and `other` do not represent the same shape.
+ """Raises exception if `self` and `other` do not represent the same
+ shape.
This method can be used to assert that there exists a shape that both
`self` and `other` represent.
@@ -845,10 +870,13 @@ def assert_is_convertible_with(self, other):
ValueError: If `self` and `other` do not represent the same shape.
"""
if not self.is_convertible_with(other):
- raise ValueError("Shapes %s and %s are inconvertible" % (self, other))
+ raise ValueError(
+ "Shapes %s and %s are inconvertible" % (self, other)
+ )
def most_specific_convertible_shape(self, other):
- """Returns the most specific TensorShape convertible with `self` and `other`.
+ """Returns the most specific TensorShape convertible with `self` and
+ `other`.
* TensorShape([None, 1]) is the most specific TensorShape convertible with
both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
@@ -868,7 +896,11 @@ def most_specific_convertible_shape(self, other):
"""
other = as_shape(other)
- if self._dims is None or other.dims is None or self.ndims != other.ndims:
+ if (
+ self._dims is None
+ or other.dims is None
+ or self.ndims != other.ndims
+ ):
return unknown_shape()
dims = [(Dimension(None))] * self.ndims
@@ -884,7 +916,8 @@ def is_fully_defined(self):
)
def assert_is_fully_defined(self):
- """Raises an exception if `self` is not fully defined in every dimension.
+ """Raises an exception if `self` is not fully defined in every
+ dimension.
Raises:
ValueError: If `self` does not have a known value for every dimension.
@@ -902,7 +935,9 @@ def as_list(self):
ValueError: If `self` is an unknown shape with an unknown rank.
"""
if self._dims is None:
- raise ValueError("as_list() is not defined on an unknown TensorShape.")
+ raise ValueError(
+ "as_list() is not defined on an unknown TensorShape."
+ )
return [dim.value for dim in self._dims]
def as_proto(self):
@@ -934,7 +969,9 @@ def __ne__(self, other):
except TypeError:
return NotImplemented
if self.ndims is None or other.ndims is None:
- raise ValueError("The inequality of unknown TensorShapes is undefined.")
+ raise ValueError(
+ "The inequality of unknown TensorShapes is undefined."
+ )
if self.ndims != other.ndims:
return True
return self._dims != other.dims
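
The `Dimension`/`TensorShape` docstrings above describe how unknown dimensions combine with known ones under `merge_with`, `is_convertible_with`, and `most_specific_convertible_shape`. A brief usage sketch of those rules, assuming the stub is importable from the path shown in this diff:

# Usage sketch of the merge/convertibility rules documented above.
from tensorboard.compat.tensorflow_stub import tensor_shape

a = tensor_shape.TensorShape([None, 1])
b = tensor_shape.TensorShape([2, 1])
c = tensor_shape.TensorShape([5, 1])

# An unknown dimension (None) is convertible with anything, so merging
# fills it in with the known value.
assert a.is_convertible_with(b)
assert a.merge_with(b).as_list() == [2, 1]

# Two different known dimensions are not convertible.
assert not b.is_convertible_with(c)

# Per the docstring, [None, 1] is the most specific shape convertible
# with both [2, 1] and [5, 1].
assert b.most_specific_convertible_shape(c).as_list() == [None, 1]
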
diff --git a/tensorboard/data/provider.py b/tensorboard/data/provider.py
index ad2dee8d49..047c574f28 100644
--- a/tensorboard/data/provider.py
+++ b/tensorboard/data/provider.py
@@ -27,750 +27,769 @@
@six.add_metaclass(abc.ABCMeta)
class DataProvider(object):
- """Interface for reading TensorBoard scalar, tensor, and blob data.
+ """Interface for reading TensorBoard scalar, tensor, and blob data.
- These APIs are under development and subject to change. For instance,
- providers may be asked to implement more filtering mechanisms, such as
- downsampling strategies or domain restriction by step or wall time.
+ These APIs are under development and subject to change. For instance,
+ providers may be asked to implement more filtering mechanisms, such as
+ downsampling strategies or domain restriction by step or wall time.
- Unless otherwise noted, any methods on this class may raise errors
- defined in `tensorboard.errors`, like `tensorboard.errors.NotFoundError`.
- """
-
- def data_location(self, experiment_id):
- """Render a human-readable description of the data source.
-
- For instance, this might return a path to a directory on disk.
-
- The default implementation always returns the empty string.
-
- Args:
- experiment_id: ID of enclosing experiment.
-
- Returns:
- A string, which may be empty.
+ Unless otherwise noted, any methods on this class may raise errors
+ defined in `tensorboard.errors`, like `tensorboard.errors.NotFoundError`.
"""
- return ""
-
- @abc.abstractmethod
- def list_runs(self, experiment_id):
- """List all runs within an experiment.
- Args:
- experiment_id: ID of enclosing experiment.
+ def data_location(self, experiment_id):
+ """Render a human-readable description of the data source.
+
+ For instance, this might return a path to a directory on disk.
+
+ The default implementation always returns the empty string.
+
+ Args:
+ experiment_id: ID of enclosing experiment.
+
+ Returns:
+ A string, which may be empty.
+ """
+ return ""
+
+ @abc.abstractmethod
+ def list_runs(self, experiment_id):
+ """List all runs within an experiment.
+
+ Args:
+ experiment_id: ID of enclosing experiment.
+
+ Returns:
+ A collection of `Run` values.
+
+ Raises:
+ tensorboard.errors.PublicError: See `DataProvider` class docstring.
+ """
+ pass
+
+ @abc.abstractmethod
+ def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None):
+ """List metadata about scalar time series.
+
+ Args:
+ experiment_id: ID of enclosing experiment.
+ plugin_name: String name of the TensorBoard plugin that created
+ the data to be queried. Required.
+ run_tag_filter: Optional `RunTagFilter` value. If omitted, all
+ runs and tags will be included.
+
+ The result will only contain keys for run-tag combinations that
+ actually exist, which may not include all entries in the
+ `run_tag_filter`.
+
+ Returns:
+ A nested map `d` such that `d[run][tag]` is a `ScalarTimeSeries`
+ value.
+
+ Raises:
+ tensorboard.errors.PublicError: See `DataProvider` class docstring.
+ """
+ pass
+
+ @abc.abstractmethod
+ def read_scalars(
+ self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
+ ):
+ """Read values from scalar time series.
+
+ Args:
+ experiment_id: ID of enclosing experiment.
+ plugin_name: String name of the TensorBoard plugin that created
+ the data to be queried. Required.
+ downsample: Integer number of steps to which to downsample the
+ results (e.g., `1000`). Required.
+ run_tag_filter: Optional `RunTagFilter` value. If provided, a time
+ series will only be included in the result if its run and tag
+ both pass this filter. If `None`, all time series will be
+ included.
+
+ The result will only contain keys for run-tag combinations that
+ actually exist, which may not include all entries in the
+ `run_tag_filter`.
+
+ Returns:
+ A nested map `d` such that `d[run][tag]` is a list of
+ `ScalarDatum` values sorted by step.
+
+ Raises:
+ tensorboard.errors.PublicError: See `DataProvider` class docstring.
+ """
+ pass
+
+ def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
+ """List metadata about tensor time series.
+
+ Args:
+ experiment_id: ID of enclosing experiment.
+ plugin_name: String name of the TensorBoard plugin that created
+ the data to be queried. Required.
+ run_tag_filter: Optional `RunTagFilter` value. If omitted, all
+ runs and tags will be included.
+
+ The result will only contain keys for run-tag combinations that
+ actually exist, which may not include all entries in the
+ `run_tag_filter`.
+
+ Returns:
+ A nested map `d` such that `d[run][tag]` is a `TensorTimeSeries`
+ value.
+
+ Raises:
+ tensorboard.errors.PublicError: See `DataProvider` class docstring.
+ """
+ pass
+
+ def read_tensors(
+ self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
+ ):
+ """Read values from tensor time series.
+
+ Args:
+ experiment_id: ID of enclosing experiment.
+ plugin_name: String name of the TensorBoard plugin that created
+ the data to be queried. Required.
+ downsample: Integer number of steps to which to downsample the
+ results (e.g., `1000`). Required.
+ run_tag_filter: Optional `RunTagFilter` value. If provided, a time
+ series will only be included in the result if its run and tag
+ both pass this filter. If `None`, all time series will be
+ included.
+
+ The result will only contain keys for run-tag combinations that
+ actually exist, which may not include all entries in the
+ `run_tag_filter`.
+
+ Returns:
+ A nested map `d` such that `d[run][tag]` is a list of
+ `TensorDatum` values sorted by step.
+
+ Raises:
+ tensorboard.errors.PublicError: See `DataProvider` class docstring.
+ """
+ pass
+
+ def list_blob_sequences(
+ self, experiment_id, plugin_name, run_tag_filter=None
+ ):
+ """List metadata about blob sequence time series.
+
+ Args:
+ experiment_id: ID of enclosing experiment.
+ plugin_name: String name of the TensorBoard plugin that created the data
+ to be queried. Required.
+ run_tag_filter: Optional `RunTagFilter` value. If omitted, all runs and
+ tags will be included. The result will only contain keys for run-tag
+ combinations that actually exist, which may not include all entries in
+ the `run_tag_filter`.
+
+ Returns:
+ A nested map `d` such that `d[run][tag]` is a `BlobSequenceTimeSeries`
+ value.
+
+ Raises:
+ tensorboard.errors.PublicError: See `DataProvider` class docstring.
+ """
+ pass
+
+ def read_blob_sequences(
+ self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
+ ):
+ """Read values from blob sequence time series.
+
+ Args:
+ experiment_id: ID of enclosing experiment.
+ plugin_name: String name of the TensorBoard plugin that created the data
+ to be queried. Required.
+ downsample: Integer number of steps to which to downsample the results
+ (e.g., `1000`). Required.
+ run_tag_filter: Optional `RunTagFilter` value. If provided, a time series
+ will only be included in the result if its run and tag both pass this
+ filter. If `None`, all time series will be included. The result will
+ only contain keys for run-tag combinations that actually exist, which
+ may not include all entries in the `run_tag_filter`.
+
+ Returns:
+ A nested map `d` such that `d[run][tag]` is a list of
+ `BlobSequenceDatum` values sorted by step.
+
+ Raises:
+ tensorboard.errors.PublicError: See `DataProvider` class docstring.
+ """
+ pass
+
+ def read_blob(self, blob_key):
+ """Read data for a single blob.
+
+ Args:
+ blob_key: A key identifying the desired blob, as provided by
+ `read_blob_sequences(...)`.
+
+ Returns:
+ Raw binary data as `bytes`.
+
+ Raises:
+ tensorboard.errors.PublicError: See `DataProvider` class docstring.
+ """
+ pass
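# A minimal usage sketch of the blob-sequence methods documented above:
# `read_blob_sequences` yields `BlobSequenceDatum` values whose `values`
# tuples hold `BlobReference`s, and each reference's `blob_key` can be
# handed to `read_blob` for the raw bytes. The experiment ID, the downsample
# count, and the "images" plugin name are assumptions made for illustration.
def fetch_first_blobs(data_provider, experiment_id):
    """Return the raw bytes of the first blob for each run/tag pair."""
    sequences = data_provider.read_blob_sequences(
        experiment_id, plugin_name="images", downsample=10
    )
    result = {}
    for run, tags in sequences.items():
        for tag, data in tags.items():
            if not data or not data[0].values:
                continue  # no datum (or empty sequence) for this time series
            first_ref = data[0].values[0]  # BlobReference for element 0
            result[(run, tag)] = data_provider.read_blob(first_ref.blob_key)
    return result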
- Returns:
- A collection of `Run` values.
- Raises:
- tensorboard.errors.PublicError: See `DataProvider` class docstring.
+class Run(object):
+ """Metadata about a run.
+
+ Attributes:
+ run_id: A unique opaque string identifier for this run.
+ run_name: A user-facing name for this run (as a `str`).
+ start_time: The wall time of the earliest recorded event in this
+ run, as `float` seconds since epoch, or `None` if this run has no
+ recorded events.
"""
- pass
-
- @abc.abstractmethod
- def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None):
- """List metadata about scalar time series.
- Args:
- experiment_id: ID of enclosing experiment.
- plugin_name: String name of the TensorBoard plugin that created
- the data to be queried. Required.
- run_tag_filter: Optional `RunTagFilter` value. If omitted, all
- runs and tags will be included.
+ __slots__ = ("_run_id", "_run_name", "_start_time")
+
+ def __init__(self, run_id, run_name, start_time):
+ self._run_id = run_id
+ self._run_name = run_name
+ self._start_time = start_time
+
+ @property
+ def run_id(self):
+ return self._run_id
+
+ @property
+ def run_name(self):
+ return self._run_name
+
+ @property
+ def start_time(self):
+ return self._start_time
+
+ def __eq__(self, other):
+ if not isinstance(other, Run):
+ return False
+ if self._run_id != other._run_id:
+ return False
+ if self._run_name != other._run_name:
+ return False
+ if self._start_time != other._start_time:
+ return False
+ return True
+
+ def __hash__(self):
+ return hash((self._run_id, self._run_name, self._start_time))
+
+ def __repr__(self):
+ return "Run(%s)" % ", ".join(
+ (
+ "run_id=%r" % (self._run_id,),
+ "run_name=%r" % (self._run_name,),
+ "start_time=%r" % (self._start_time,),
+ )
+ )
- The result will only contain keys for run-tag combinations that
- actually exist, which may not include all entries in the
- `run_tag_filter`.
- Returns:
- A nested map `d` such that `d[run][tag]` is a `ScalarTimeSeries`
- value.
+class _TimeSeries(object):
+ """Metadata about time series data for a particular run and tag.
- Raises:
- tensorboard.errors.PublicError: See `DataProvider` class docstring.
+ Superclass of `ScalarTimeSeries`, `TensorTimeSeries`, and
+ `BlobSequenceTimeSeries`.
"""
- pass
-
- @abc.abstractmethod
- def read_scalars(
- self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
- ):
- """Read values from scalar time series.
-
- Args:
- experiment_id: ID of enclosing experiment.
- plugin_name: String name of the TensorBoard plugin that created
- the data to be queried. Required.
- downsample: Integer number of steps to which to downsample the
- results (e.g., `1000`). Required.
- run_tag_filter: Optional `RunTagFilter` value. If provided, a time
- series will only be included in the result if its run and tag
- both pass this filter. If `None`, all time series will be
- included.
-
- The result will only contain keys for run-tag combinations that
- actually exist, which may not include all entries in the
- `run_tag_filter`.
-
- Returns:
- A nested map `d` such that `d[run][tag]` is a list of
- `ScalarDatum` values sorted by step.
-
- Raises:
- tensorboard.errors.PublicError: See `DataProvider` class docstring.
- """
- pass
- def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
- """List metadata about tensor time series.
+ __slots__ = (
+ "_max_step",
+ "_max_wall_time",
+ "_plugin_content",
+ "_description",
+ "_display_name",
+ )
- Args:
- experiment_id: ID of enclosing experiment.
- plugin_name: String name of the TensorBoard plugin that created
- the data to be queried. Required.
- run_tag_filter: Optional `RunTagFilter` value. If omitted, all
- runs and tags will be included.
+ def __init__(
+ self, max_step, max_wall_time, plugin_content, description, display_name
+ ):
+ self._max_step = max_step
+ self._max_wall_time = max_wall_time
+ self._plugin_content = plugin_content
+ self._description = description
+ self._display_name = display_name
- The result will only contain keys for run-tag combinations that
- actually exist, which may not include all entries in the
- `run_tag_filter`.
+ @property
+ def max_step(self):
+ return self._max_step
- Returns:
- A nested map `d` such that `d[run][tag]` is a `TensorTimeSeries`
- value.
+ @property
+ def max_wall_time(self):
+ return self._max_wall_time
- Raises:
- tensorboard.errors.PublicError: See `DataProvider` class docstring.
- """
- pass
-
- def read_tensors(
- self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
- ):
- """Read values from tensor time series.
-
- Args:
- experiment_id: ID of enclosing experiment.
- plugin_name: String name of the TensorBoard plugin that created
- the data to be queried. Required.
- downsample: Integer number of steps to which to downsample the
- results (e.g., `1000`). Required.
- run_tag_filter: Optional `RunTagFilter` value. If provided, a time
- series will only be included in the result if its run and tag
- both pass this filter. If `None`, all time series will be
- included.
-
- The result will only contain keys for run-tag combinations that
- actually exist, which may not include all entries in the
- `run_tag_filter`.
-
- Returns:
- A nested map `d` such that `d[run][tag]` is a list of
- `TensorDatum` values sorted by step.
-
- Raises:
- tensorboard.errors.PublicError: See `DataProvider` class docstring.
- """
- pass
-
- def list_blob_sequences(
- self, experiment_id, plugin_name, run_tag_filter=None
- ):
- """List metadata about blob sequence time series.
-
- Args:
- experiment_id: ID of enclosing experiment.
- plugin_name: String name of the TensorBoard plugin that created the data
- to be queried. Required.
- run_tag_filter: Optional `RunTagFilter` value. If omitted, all runs and
- tags will be included. The result will only contain keys for run-tag
- combinations that actually exist, which may not include all entries in
- the `run_tag_filter`.
-
- Returns:
- A nested map `d` such that `d[run][tag]` is a `BlobSequenceTimeSeries`
- value.
-
- Raises:
- tensorboard.errors.PublicError: See `DataProvider` class docstring.
- """
- pass
-
- def read_blob_sequences(
- self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
- ):
- """Read values from blob sequence time series.
-
- Args:
- experiment_id: ID of enclosing experiment.
- plugin_name: String name of the TensorBoard plugin that created the data
- to be queried. Required.
- downsample: Integer number of steps to which to downsample the results
- (e.g., `1000`). Required.
- run_tag_filter: Optional `RunTagFilter` value. If provided, a time series
- will only be included in the result if its run and tag both pass this
- filter. If `None`, all time series will be included. The result will
- only contain keys for run-tag combinations that actually exist, which
- may not include all entries in the `run_tag_filter`.
-
- Returns:
- A nested map `d` such that `d[run][tag]` is a list of
- `BlobSequenceDatum` values sorted by step.
-
- Raises:
- tensorboard.errors.PublicError: See `DataProvider` class docstring.
- """
- pass
+ @property
+ def plugin_content(self):
+ return self._plugin_content
- def read_blob(self, blob_key):
- """Read data for a single blob.
+ @property
+ def description(self):
+ return self._description
- Args:
- blob_key: A key identifying the desired blob, as provided by
- `read_blob_sequences(...)`.
+ @property
+ def display_name(self):
+ return self._display_name
- Returns:
- Raw binary data as `bytes`.
- Raises:
- tensorboard.errors.PublicError: See `DataProvider` class docstring.
+class ScalarTimeSeries(_TimeSeries):
+ """Metadata about a scalar time series for a particular run and tag.
+
+ Attributes:
+ max_step: The largest step value of any datum in this scalar time series; a
+ nonnegative integer.
+ max_wall_time: The largest wall time of any datum in this time series, as
+ `float` seconds since epoch.
+ plugin_content: A bytestring of arbitrary plugin-specific metadata for this
+ time series, as provided to `tf.summary.write` in the
+ `plugin_data.content` field of the `metadata` argument.
+ description: An optional long-form Markdown description, as a `str` that is
+ empty if no description was specified.
+ display_name: An optional user-facing display name, as a `str` that is
+ empty if no display name was specified. Deprecated; may be removed soon.
"""
- pass
+ def __eq__(self, other):
+ if not isinstance(other, ScalarTimeSeries):
+ return False
+ if self._max_step != other._max_step:
+ return False
+ if self._max_wall_time != other._max_wall_time:
+ return False
+ if self._plugin_content != other._plugin_content:
+ return False
+ if self._description != other._description:
+ return False
+ if self._display_name != other._display_name:
+ return False
+ return True
+
+ def __hash__(self):
+ return hash(
+ (
+ self._max_step,
+ self._max_wall_time,
+ self._plugin_content,
+ self._description,
+ self._display_name,
+ )
+ )
+
+ def __repr__(self):
+ return "ScalarTimeSeries(%s)" % ", ".join(
+ (
+ "max_step=%r" % (self._max_step,),
+ "max_wall_time=%r" % (self._max_wall_time,),
+ "plugin_content=%r" % (self._plugin_content,),
+ "description=%r" % (self._description,),
+ "display_name=%r" % (self._display_name,),
+ )
+ )
-class Run(object):
- """Metadata about a run.
-
- Attributes:
- run_id: A unique opaque string identifier for this run.
- run_name: A user-facing name for this run (as a `str`).
- start_time: The wall time of the earliest recorded event in this
- run, as `float` seconds since epoch, or `None` if this run has no
- recorded events.
- """
-
- __slots__ = ("_run_id", "_run_name", "_start_time")
-
- def __init__(self, run_id, run_name, start_time):
- self._run_id = run_id
- self._run_name = run_name
- self._start_time = start_time
-
- @property
- def run_id(self):
- return self._run_id
-
- @property
- def run_name(self):
- return self._run_name
-
- @property
- def start_time(self):
- return self._start_time
-
- def __eq__(self, other):
- if not isinstance(other, Run):
- return False
- if self._run_id != other._run_id:
- return False
- if self._run_name != other._run_name:
- return False
- if self._start_time != other._start_time:
- return False
- return True
-
- def __hash__(self):
- return hash((self._run_id, self._run_name, self._start_time))
-
- def __repr__(self):
- return "Run(%s)" % ", ".join((
- "run_id=%r" % (self._run_id,),
- "run_name=%r" % (self._run_name,),
- "start_time=%r" % (self._start_time,),
- ))
+class ScalarDatum(object):
+ """A single datum in a scalar time series for a run and tag.
+
+ Attributes:
+ step: The global step at which this datum occurred; an integer. This
+ is a unique key among data of this time series.
+ wall_time: The real-world time at which this datum occurred, as
+ `float` seconds since epoch.
+ value: The scalar value for this datum; a `float`.
+ """
-class _TimeSeries(object):
- """Metadata about time series data for a particular run and tag.
+ __slots__ = ("_step", "_wall_time", "_value")
+
+ def __init__(self, step, wall_time, value):
+ self._step = step
+ self._wall_time = wall_time
+ self._value = value
+
+ @property
+ def step(self):
+ return self._step
+
+ @property
+ def wall_time(self):
+ return self._wall_time
+
+ @property
+ def value(self):
+ return self._value
+
+ def __eq__(self, other):
+ if not isinstance(other, ScalarDatum):
+ return False
+ if self._step != other._step:
+ return False
+ if self._wall_time != other._wall_time:
+ return False
+ if self._value != other._value:
+ return False
+ return True
+
+ def __hash__(self):
+ return hash((self._step, self._wall_time, self._value))
+
+ def __repr__(self):
+ return "ScalarDatum(%s)" % ", ".join(
+ (
+ "step=%r" % (self._step,),
+ "wall_time=%r" % (self._wall_time,),
+ "value=%r" % (self._value,),
+ )
+ )
- Superclass of `ScalarTimeSeries` and `BlobSequenceTimeSeries`.
- """
- __slots__ = (
- "_max_step",
- "_max_wall_time",
- "_plugin_content",
- "_description",
- "_display_name",
- )
+class TensorTimeSeries(_TimeSeries):
+ """Metadata about a tensor time series for a particular run and tag.
+
+ Attributes:
+ max_step: The largest step value of any datum in this tensor time series; a
+ nonnegative integer.
+ max_wall_time: The largest wall time of any datum in this time series, as
+ `float` seconds since epoch.
+ plugin_content: A bytestring of arbitrary plugin-specific metadata for this
+ time series, as provided to `tf.summary.write` in the
+ `plugin_data.content` field of the `metadata` argument.
+ description: An optional long-form Markdown description, as a `str` that is
+ empty if no description was specified.
+ display_name: An optional user-facing display name, as a `str` that is
+ empty if no display name was specified. Deprecated; may be removed soon.
+ """
- def __init__(
- self, max_step, max_wall_time, plugin_content, description, display_name
- ):
- self._max_step = max_step
- self._max_wall_time = max_wall_time
- self._plugin_content = plugin_content
- self._description = description
- self._display_name = display_name
+ def __eq__(self, other):
+ if not isinstance(other, TensorTimeSeries):
+ return False
+ if self._max_step != other._max_step:
+ return False
+ if self._max_wall_time != other._max_wall_time:
+ return False
+ if self._plugin_content != other._plugin_content:
+ return False
+ if self._description != other._description:
+ return False
+ if self._display_name != other._display_name:
+ return False
+ return True
+
+ def __hash__(self):
+ return hash(
+ (
+ self._max_step,
+ self._max_wall_time,
+ self._plugin_content,
+ self._description,
+ self._display_name,
+ )
+ )
+
+ def __repr__(self):
+ return "TensorTimeSeries(%s)" % ", ".join(
+ (
+ "max_step=%r" % (self._max_step,),
+ "max_wall_time=%r" % (self._max_wall_time,),
+ "plugin_content=%r" % (self._plugin_content,),
+ "description=%r" % (self._description,),
+ "display_name=%r" % (self._display_name,),
+ )
+ )
- @property
- def max_step(self):
- return self._max_step
- @property
- def max_wall_time(self):
- return self._max_wall_time
+class TensorDatum(object):
+ """A single datum in a tensor time series for a run and tag.
+
+ Attributes:
+ step: The global step at which this datum occurred; an integer. This
+ is a unique key among data of this time series.
+ wall_time: The real-world time at which this datum occurred, as
+ `float` seconds since epoch.
+ numpy: The `numpy.ndarray` value with the tensor contents of this
+ datum.
+ """
- @property
- def plugin_content(self):
- return self._plugin_content
+ __slots__ = ("_step", "_wall_time", "_numpy")
+
+ def __init__(self, step, wall_time, numpy):
+ self._step = step
+ self._wall_time = wall_time
+ self._numpy = numpy
+
+ @property
+ def step(self):
+ return self._step
+
+ @property
+ def wall_time(self):
+ return self._wall_time
+
+ @property
+ def numpy(self):
+ return self._numpy
+
+ def __eq__(self, other):
+ if not isinstance(other, TensorDatum):
+ return False
+ if self._step != other._step:
+ return False
+ if self._wall_time != other._wall_time:
+ return False
+ if not np.array_equal(self._numpy, other._numpy):
+ return False
+ return True
+
+ # Unhashable type: numpy arrays are mutable.
+ __hash__ = None
+
+ def __repr__(self):
+ return "TensorDatum(%s)" % ", ".join(
+ (
+ "step=%r" % (self._step,),
+ "wall_time=%r" % (self._wall_time,),
+ "numpy=%r" % (self._numpy,),
+ )
+ )
- @property
- def description(self):
- return self._description
- @property
- def display_name(self):
- return self._display_name
+class BlobSequenceTimeSeries(_TimeSeries):
+ """Metadata about a blob sequence time series for a particular run and tag.
+
+ Attributes:
+ max_step: The largest step value of any datum in this blob sequence time
+ series; a nonnegative integer.
+ max_wall_time: The largest wall time of any datum in this time series, as
+ `float` seconds since epoch.
+ latest_max_index: The largest index in the sequence at max_step (0-based,
+ inclusive).
+ plugin_content: A bytestring of arbitrary plugin-specific metadata for this
+ time series, as provided to `tf.summary.write` in the
+ `plugin_data.content` field of the `metadata` argument.
+ description: An optional long-form Markdown description, as a `str` that is
+ empty if no description was specified.
+ display_name: An optional user-facing display name, as a `str` that is
+ empty if no display name was specified. Deprecated; may be removed soon.
+ """
+ __slots__ = ("_latest_max_index",)
+
+ def __init__(
+ self,
+ max_step,
+ max_wall_time,
+ latest_max_index,
+ plugin_content,
+ description,
+ display_name,
+ ):
+ super(BlobSequenceTimeSeries, self).__init__(
+ max_step, max_wall_time, plugin_content, description, display_name
+ )
+ self._latest_max_index = latest_max_index
+
+ @property
+ def latest_max_index(self):
+ return self._latest_max_index
+
+ def __eq__(self, other):
+ if not isinstance(other, BlobSequenceTimeSeries):
+ return False
+ if self._max_step != other._max_step:
+ return False
+ if self._max_wall_time != other._max_wall_time:
+ return False
+ if self._latest_max_index != other._latest_max_index:
+ return False
+ if self._plugin_content != other._plugin_content:
+ return False
+ if self._description != other._description:
+ return False
+ if self._display_name != other._display_name:
+ return False
+ return True
+
+ def __hash__(self):
+ return hash(
+ (
+ self._max_step,
+ self._max_wall_time,
+ self._latest_max_index,
+ self._plugin_content,
+ self._description,
+ self._display_name,
+ )
+ )
+
+ def __repr__(self):
+ return "BlobSequenceTimeSeries(%s)" % ", ".join(
+ (
+ "max_step=%r" % (self._max_step,),
+ "max_wall_time=%r" % (self._max_wall_time,),
+ "latest_max_index=%r" % (self._latest_max_index,),
+ "plugin_content=%r" % (self._plugin_content,),
+ "description=%r" % (self._description,),
+ "display_name=%r" % (self._display_name,),
+ )
+ )
-class ScalarTimeSeries(_TimeSeries):
- """Metadata about a scalar time series for a particular run and tag.
-
- Attributes:
- max_step: The largest step value of any datum in this scalar time series; a
- nonnegative integer.
- max_wall_time: The largest wall time of any datum in this time series, as
- `float` seconds since epoch.
- plugin_content: A bytestring of arbitrary plugin-specific metadata for this
- time series, as provided to `tf.summary.write` in the
- `plugin_data.content` field of the `metadata` argument.
- description: An optional long-form Markdown description, as a `str` that is
- empty if no description was specified.
- display_name: An optional long-form Markdown description, as a `str` that is
- empty if no description was specified. Deprecated; may be removed soon.
- """
-
- def __eq__(self, other):
- if not isinstance(other, ScalarTimeSeries):
- return False
- if self._max_step != other._max_step:
- return False
- if self._max_wall_time != other._max_wall_time:
- return False
- if self._plugin_content != other._plugin_content:
- return False
- if self._description != other._description:
- return False
- if self._display_name != other._display_name:
- return False
- return True
-
- def __hash__(self):
- return hash((
- self._max_step,
- self._max_wall_time,
- self._plugin_content,
- self._description,
- self._display_name,
- ))
-
- def __repr__(self):
- return "ScalarTimeSeries(%s)" % ", ".join((
- "max_step=%r" % (self._max_step,),
- "max_wall_time=%r" % (self._max_wall_time,),
- "plugin_content=%r" % (self._plugin_content,),
- "description=%r" % (self._description,),
- "display_name=%r" % (self._display_name,),
- ))
+class BlobReference(object):
+ """A reference to a blob.
+
+ Attributes:
+ blob_key: A string containing a key uniquely identifying a blob, which
+ may be dereferenced via `provider.read_blob(blob_key)`.
+
+ These keys must be constructed such that they can be included directly in
+ a URL, with no further encoding. Concretely, this means that they consist
+ exclusively of "unreserved characters" per RFC 3986, namely
+ [a-zA-Z0-9._~-]. These keys are case-sensitive; it may be wise for
+ implementations to normalize case to reduce confusion. The empty string
+ is not a valid key.
+
+ Blob keys must not contain information that should be kept secret.
+ Privacy-sensitive applications should use random keys (e.g. UUIDs), or
+ encrypt keys containing secret fields.
+ url: (optional) A string containing a URL from which the blob data may be
+ fetched directly, bypassing the data provider. URLs may be a vector
+ for data leaks (e.g. via browser history, web proxies, etc.), so these
+ URLs should not expose secret information.
+ """
-class ScalarDatum(object):
- """A single datum in a scalar time series for a run and tag.
-
- Attributes:
- step: The global step at which this datum occurred; an integer. This
- is a unique key among data of this time series.
- wall_time: The real-world time at which this datum occurred, as
- `float` seconds since epoch.
- value: The scalar value for this datum; a `float`.
- """
-
- __slots__ = ("_step", "_wall_time", "_value")
-
- def __init__(self, step, wall_time, value):
- self._step = step
- self._wall_time = wall_time
- self._value = value
-
- @property
- def step(self):
- return self._step
-
- @property
- def wall_time(self):
- return self._wall_time
-
- @property
- def value(self):
- return self._value
-
- def __eq__(self, other):
- if not isinstance(other, ScalarDatum):
- return False
- if self._step != other._step:
- return False
- if self._wall_time != other._wall_time:
- return False
- if self._value != other._value:
- return False
- return True
-
- def __hash__(self):
- return hash((self._step, self._wall_time, self._value))
-
- def __repr__(self):
- return "ScalarDatum(%s)" % ", ".join((
- "step=%r" % (self._step,),
- "wall_time=%r" % (self._wall_time,),
- "value=%r" % (self._value,),
- ))
+ __slots__ = ("_url", "_blob_key")
+ def __init__(self, blob_key, url=None):
+ self._blob_key = blob_key
+ self._url = url
-class TensorTimeSeries(_TimeSeries):
- """Metadata about a tensor time series for a particular run and tag.
-
- Attributes:
- max_step: The largest step value of any datum in this tensor time series; a
- nonnegative integer.
- max_wall_time: The largest wall time of any datum in this time series, as
- `float` seconds since epoch.
- plugin_content: A bytestring of arbitrary plugin-specific metadata for this
- time series, as provided to `tf.summary.write` in the
- `plugin_data.content` field of the `metadata` argument.
- description: An optional long-form Markdown description, as a `str` that is
- empty if no description was specified.
- display_name: An optional long-form Markdown description, as a `str` that is
- empty if no description was specified. Deprecated; may be removed soon.
- """
-
- def __eq__(self, other):
- if not isinstance(other, TensorTimeSeries):
- return False
- if self._max_step != other._max_step:
- return False
- if self._max_wall_time != other._max_wall_time:
- return False
- if self._plugin_content != other._plugin_content:
- return False
- if self._description != other._description:
- return False
- if self._display_name != other._display_name:
- return False
- return True
-
- def __hash__(self):
- return hash((
- self._max_step,
- self._max_wall_time,
- self._plugin_content,
- self._description,
- self._display_name,
- ))
-
- def __repr__(self):
- return "TensorTimeSeries(%s)" % ", ".join((
- "max_step=%r" % (self._max_step,),
- "max_wall_time=%r" % (self._max_wall_time,),
- "plugin_content=%r" % (self._plugin_content,),
- "description=%r" % (self._description,),
- "display_name=%r" % (self._display_name,),
- ))
+ @property
+ def blob_key(self):
+ """Provide a key uniquely identifying a blob.
+ Callers should consider these keys to be opaque, i.e., to have
+ no intrinsic meaning. Some data providers may use random IDs,
+ while others may encode information into the key, in which case
+ callers must make no attempt to decode it.
+ """
+ return self._blob_key
-class TensorDatum(object):
- """A single datum in a tensor time series for a run and tag.
-
- Attributes:
- step: The global step at which this datum occurred; an integer. This
- is a unique key among data of this time series.
- wall_time: The real-world time at which this datum occurred, as
- `float` seconds since epoch.
- numpy: The `numpy.ndarray` value with the tensor contents of this
- datum.
- """
-
- __slots__ = ("_step", "_wall_time", "_numpy")
-
- def __init__(self, step, wall_time, numpy):
- self._step = step
- self._wall_time = wall_time
- self._numpy = numpy
-
- @property
- def step(self):
- return self._step
-
- @property
- def wall_time(self):
- return self._wall_time
-
- @property
- def numpy(self):
- return self._numpy
-
- def __eq__(self, other):
- if not isinstance(other, TensorDatum):
- return False
- if self._step != other._step:
- return False
- if self._wall_time != other._wall_time:
- return False
- if not np.array_equal(self._numpy, other._numpy):
- return False
- return True
-
- # Unhashable type: numpy arrays are mutable.
- __hash__ = None
-
- def __repr__(self):
- return "TensorDatum(%s)" % ", ".join((
- "step=%r" % (self._step,),
- "wall_time=%r" % (self._wall_time,),
- "numpy=%r" % (self._numpy,),
- ))
+ @property
+ def url(self):
+ """Provide the direct-access URL for this blob, if available.
+ Note that this property is *not* expected to construct a URL to
+ the data-loading endpoint provided by TensorBoard. If it returns
+ None, then the caller should proceed to use `blob_key` to build
+ the URL, as needed.
+ """
+ return self._url
-class BlobSequenceTimeSeries(_TimeSeries):
- """Metadata about a blob sequence time series for a particular run and tag.
-
- Attributes:
- max_step: The largest step value of any datum in this scalar time series; a
- nonnegative integer.
- max_wall_time: The largest wall time of any datum in this time series, as
- `float` seconds since epoch.
- latest_max_index: The largest index in the sequence at max_step (0-based,
- inclusive).
- plugin_content: A bytestring of arbitrary plugin-specific metadata for this
- time series, as provided to `tf.summary.write` in the
- `plugin_data.content` field of the `metadata` argument.
- description: An optional long-form Markdown description, as a `str` that is
- empty if no description was specified.
- display_name: An optional long-form Markdown description, as a `str` that is
- empty if no description was specified. Deprecated; may be removed soon.
- """
-
- __slots__ = ("_latest_max_index",)
-
- def __init__(
- self,
- max_step,
- max_wall_time,
- latest_max_index,
- plugin_content,
- description,
- display_name,
- ):
- super(BlobSequenceTimeSeries, self).__init__(
- max_step, max_wall_time, plugin_content, description, display_name
- )
- self._latest_max_index = latest_max_index
-
- @property
- def latest_max_index(self):
- return self._latest_max_index
-
- def __eq__(self, other):
- if not isinstance(other, BlobSequenceTimeSeries):
- return False
- if self._max_step != other._max_step:
- return False
- if self._max_wall_time != other._max_wall_time:
- return False
- if self._latest_max_index != other._latest_max_index:
- return False
- if self._plugin_content != other._plugin_content:
- return False
- if self._description != other._description:
- return False
- if self._display_name != other._display_name:
- return False
- return True
-
- def __hash__(self):
- return hash((
- self._max_step,
- self._max_wall_time,
- self._latest_max_index,
- self._plugin_content,
- self._description,
- self._display_name,
- ))
-
- def __repr__(self):
- return "BlobSequenceTimeSeries(%s)" % ", ".join((
- "max_step=%r" % (self._max_step,),
- "max_wall_time=%r" % (self._max_wall_time,),
- "latest_max_index=%r" % (self._latest_max_index,),
- "plugin_content=%r" % (self._plugin_content,),
- "description=%r" % (self._description,),
- "display_name=%r" % (self._display_name,),
- ))
+ def __eq__(self, other):
+ if not isinstance(other, BlobReference):
+ return False
+ if self._blob_key != other._blob_key:
+ return False
+ if self._url != other._url:
+ return False
+ return True
+ def __hash__(self):
+ return hash((self._blob_key, self._url))
-class BlobReference(object):
- """A reference to a blob.
-
- Attributes:
- blob_key: A string containing a key uniquely identifying a blob, which
- may be dereferenced via `provider.read_blob(blob_key)`.
-
- These keys must be constructed such that they can be included directly in
- a URL, with no further encoding. Concretely, this means that they consist
- exclusively of "unreserved characters" per RFC 3986, namely
- [a-zA-Z0-9._~-]. These keys are case-sensitive; it may be wise for
- implementations to normalize case to reduce confusion. The empty string
- is not a valid key.
-
- Blob keys must not contain information that should be kept secret.
- Privacy-sensitive applications should use random keys (e.g. UUIDs), or
- encrypt keys containing secret fields.
- url: (optional) A string containing a URL from which the blob data may be
- fetched directly, bypassing the data provider. URLs may be a vector
- for data leaks (e.g. via browser history, web proxies, etc.), so these
- URLs should not expose secret information.
- """
-
- __slots__ = ("_url", "_blob_key")
-
- def __init__(self, blob_key, url=None):
- self._blob_key = blob_key
- self._url = url
-
- @property
- def blob_key(self):
- """Provide a key uniquely identifying a blob.
-
- Callers should consider these keys to be opaque-- i.e., to have no intrinsic
- meaning. Some data providers may use random IDs; but others may encode
- information into the key, in which case callers must make no attempt to
- decode it.
- """
- return self._blob_key
+ def __repr__(self):
+ return "BlobReference(%s)" % ", ".join(
+ ("blob_key=%r" % (self._blob_key,), "url=%r" % (self._url,))
+ )
- @property
- def url(self):
- """Provide the direct-access URL for this blob, if available.
- Note that this method is *not* expected to construct a URL to the
- data-loading endpoint provided by TensorBoard. If this method returns
- None, then the caller should proceed to use `blob_key()` to build the URL,
- as needed.
+class BlobSequenceDatum(object):
+ """A single datum in a blob sequence time series for a run and tag.
+
+ Attributes:
+ step: The global step at which this datum occurred; an integer. This is a
+ unique key among data of this time series.
+ wall_time: The real-world time at which this datum occurred, as `float`
+ seconds since epoch.
+ values: A tuple of `BlobReference` objects, providing access to elements of
+ this sequence.
"""
- return self._url
-
- def __eq__(self, other):
- if not isinstance(other, BlobReference):
- return False
- if self._blob_key != other._blob_key:
- return False
- if self._url != other._url:
- return False
- return True
-
- def __hash__(self):
- return hash((self._blob_key, self._url))
-
- def __repr__(self):
- return "BlobReference(%s)" % ", ".join(
- ("blob_key=%r" % (self._blob_key,), "url=%r" % (self._url,))
- )
-
-class BlobSequenceDatum(object):
- """A single datum in a blob sequence time series for a run and tag.
-
- Attributes:
- step: The global step at which this datum occurred; an integer. This is a
- unique key among data of this time series.
- wall_time: The real-world time at which this datum occurred, as `float`
- seconds since epoch.
- values: A tuple of `BlobReference` objects, providing access to elements of
- this sequence.
- """
-
- __slots__ = ("_step", "_wall_time", "_values")
-
- def __init__(self, step, wall_time, values):
- self._step = step
- self._wall_time = wall_time
- self._values = values
-
- @property
- def step(self):
- return self._step
-
- @property
- def wall_time(self):
- return self._wall_time
-
- @property
- def values(self):
- return self._values
-
- def __eq__(self, other):
- if not isinstance(other, BlobSequenceDatum):
- return False
- if self._step != other._step:
- return False
- if self._wall_time != other._wall_time:
- return False
- if self._values != other._values:
- return False
- return True
-
- def __hash__(self):
- return hash((self._step, self._wall_time, self._values))
-
- def __repr__(self):
- return "BlobSequenceDatum(%s)" % ", ".join((
- "step=%r" % (self._step,),
- "wall_time=%r" % (self._wall_time,),
- "values=%r" % (self._values,),
- ))
+ __slots__ = ("_step", "_wall_time", "_values")
+
+ def __init__(self, step, wall_time, values):
+ self._step = step
+ self._wall_time = wall_time
+ self._values = values
+
+ @property
+ def step(self):
+ return self._step
+
+ @property
+ def wall_time(self):
+ return self._wall_time
+
+ @property
+ def values(self):
+ return self._values
+
+ def __eq__(self, other):
+ if not isinstance(other, BlobSequenceDatum):
+ return False
+ if self._step != other._step:
+ return False
+ if self._wall_time != other._wall_time:
+ return False
+ if self._values != other._values:
+ return False
+ return True
+
+ def __hash__(self):
+ return hash((self._step, self._wall_time, self._values))
+
+ def __repr__(self):
+ return "BlobSequenceDatum(%s)" % ", ".join(
+ (
+ "step=%r" % (self._step,),
+ "wall_time=%r" % (self._wall_time,),
+ "values=%r" % (self._values,),
+ )
+ )
class RunTagFilter(object):
- """Filters data by run and tag names."""
-
- def __init__(self, runs=None, tags=None):
- """Construct a `RunTagFilter`.
-
- A time series passes this filter if both its run *and* its tag are
- included in the corresponding whitelists.
-
- Order and multiplicity are ignored; `runs` and `tags` are treated as
- sets.
-
- Args:
- runs: Collection of run names, as strings, or `None` to admit all
- runs.
- tags: Collection of tag names, as strings, or `None` to admit all
- tags.
- """
- self._runs = None if runs is None else frozenset(runs)
- self._tags = None if tags is None else frozenset(tags)
-
- @property
- def runs(self):
- return self._runs
-
- @property
- def tags(self):
- return self._tags
-
- def __repr__(self):
- return "RunTagFilter(%s)" % ", ".join((
- "runs=%r" % (self._runs,),
- "tags=%r" % (self._tags,),
- ))
+ """Filters data by run and tag names."""
+
+ def __init__(self, runs=None, tags=None):
+ """Construct a `RunTagFilter`.
+
+ A time series passes this filter if both its run *and* its tag are
+ included in the corresponding whitelists.
+
+ Order and multiplicity are ignored; `runs` and `tags` are treated as
+ sets.
+
+ Args:
+ runs: Collection of run names, as strings, or `None` to admit all
+ runs.
+ tags: Collection of tag names, as strings, or `None` to admit all
+ tags.
+ """
+ self._runs = None if runs is None else frozenset(runs)
+ self._tags = None if tags is None else frozenset(tags)
+
+ @property
+ def runs(self):
+ return self._runs
+
+ @property
+ def tags(self):
+ return self._tags
+
+ def __repr__(self):
+ return "RunTagFilter(%s)" % ", ".join(
+ ("runs=%r" % (self._runs,), "tags=%r" % (self._tags,),)
+ )
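# A minimal sketch of how the provider API reformatted above fits together:
# a `RunTagFilter` narrows the query, `read_scalars` returns the nested
# `d[run][tag]` map of `ScalarDatum` values described in its docstring, and
# each datum exposes `step`, `wall_time`, and `value`. The provider instance,
# experiment ID, the "loss" tag, and the "scalars" plugin name are
# assumptions made for illustration.
from tensorboard.data import provider


def latest_loss(data_provider, experiment_id):
    """Return the most recent scalar value of the 'loss' tag for each run."""
    rtf = provider.RunTagFilter(tags=["loss"])
    scalars = data_provider.read_scalars(
        experiment_id,
        plugin_name="scalars",
        downsample=1000,
        run_tag_filter=rtf,
    )
    out = {}
    for run, tags in scalars.items():
        data = tags.get("loss", [])
        if data:  # data is sorted by step, so the last datum is the newest
            out[run] = data[-1].value
    return out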
diff --git a/tensorboard/data/provider_test.py b/tensorboard/data/provider_test.py
index ae2e72f36f..8a8e8fc17c 100644
--- a/tensorboard/data/provider_test.py
+++ b/tensorboard/data/provider_test.py
@@ -26,270 +26,300 @@
class DataProviderTest(tb_test.TestCase):
- def test_abstract(self):
- with six.assertRaisesRegex(self, TypeError, "abstract class"):
- provider.DataProvider()
+ def test_abstract(self):
+ with six.assertRaisesRegex(self, TypeError, "abstract class"):
+ provider.DataProvider()
class RunTest(tb_test.TestCase):
- def test_eq(self):
- a1 = provider.Run(run_id="a", run_name="aa", start_time=1.25)
- a2 = provider.Run(run_id="a", run_name="aa", start_time=1.25)
- b = provider.Run(run_id="b", run_name="bb", start_time=-1.75)
- self.assertEqual(a1, a2)
- self.assertNotEqual(a1, b)
- self.assertNotEqual(b, object())
-
- def test_repr(self):
- x = provider.Run(run_id="alpha", run_name="bravo", start_time=1.25)
- repr_ = repr(x)
- self.assertIn(repr(x.run_id), repr_)
- self.assertIn(repr(x.run_name), repr_)
- self.assertIn(repr(x.start_time), repr_)
+ def test_eq(self):
+ a1 = provider.Run(run_id="a", run_name="aa", start_time=1.25)
+ a2 = provider.Run(run_id="a", run_name="aa", start_time=1.25)
+ b = provider.Run(run_id="b", run_name="bb", start_time=-1.75)
+ self.assertEqual(a1, a2)
+ self.assertNotEqual(a1, b)
+ self.assertNotEqual(b, object())
+
+ def test_repr(self):
+ x = provider.Run(run_id="alpha", run_name="bravo", start_time=1.25)
+ repr_ = repr(x)
+ self.assertIn(repr(x.run_id), repr_)
+ self.assertIn(repr(x.run_name), repr_)
+ self.assertIn(repr(x.start_time), repr_)
class ScalarTimeSeriesTest(tb_test.TestCase):
- def test_repr(self):
- x = provider.ScalarTimeSeries(
- max_step=77,
- max_wall_time=1234.5,
- plugin_content=b"AB\xCD\xEF!\x00",
- description="test test",
- display_name="one two",
- )
- repr_ = repr(x)
- self.assertIn(repr(x.max_step), repr_)
- self.assertIn(repr(x.max_wall_time), repr_)
- self.assertIn(repr(x.plugin_content), repr_)
- self.assertIn(repr(x.description), repr_)
- self.assertIn(repr(x.display_name), repr_)
-
- def test_eq(self):
- x1 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two")
- x2 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two")
- x3 = provider.ScalarTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum")
- self.assertEqual(x1, x2)
- self.assertNotEqual(x1, x3)
- self.assertNotEqual(x1, object())
-
- def test_hash(self):
- x1 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two")
- x2 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two")
- x3 = provider.ScalarTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum")
- self.assertEqual(hash(x1), hash(x2))
- # The next check is technically not required by the `__hash__`
- # contract, but _should_ pass; failure on this assertion would at
- # least warrant some scrutiny.
- self.assertNotEqual(hash(x1), hash(x3))
+ def test_repr(self):
+ x = provider.ScalarTimeSeries(
+ max_step=77,
+ max_wall_time=1234.5,
+ plugin_content=b"AB\xCD\xEF!\x00",
+ description="test test",
+ display_name="one two",
+ )
+ repr_ = repr(x)
+ self.assertIn(repr(x.max_step), repr_)
+ self.assertIn(repr(x.max_wall_time), repr_)
+ self.assertIn(repr(x.plugin_content), repr_)
+ self.assertIn(repr(x.description), repr_)
+ self.assertIn(repr(x.display_name), repr_)
+
+ def test_eq(self):
+ x1 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two")
+ x2 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two")
+ x3 = provider.ScalarTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum")
+ self.assertEqual(x1, x2)
+ self.assertNotEqual(x1, x3)
+ self.assertNotEqual(x1, object())
+
+ def test_hash(self):
+ x1 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two")
+ x2 = provider.ScalarTimeSeries(77, 1234.5, b"\x12", "one", "two")
+ x3 = provider.ScalarTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum")
+ self.assertEqual(hash(x1), hash(x2))
+ # The next check is technically not required by the `__hash__`
+ # contract, but _should_ pass; failure on this assertion would at
+ # least warrant some scrutiny.
+ self.assertNotEqual(hash(x1), hash(x3))
class ScalarDatumTest(tb_test.TestCase):
- def test_repr(self):
- x = provider.ScalarDatum(step=123, wall_time=234.5, value=-0.125)
- repr_ = repr(x)
- self.assertIn(repr(x.step), repr_)
- self.assertIn(repr(x.wall_time), repr_)
- self.assertIn(repr(x.value), repr_)
-
- def test_eq(self):
- x1 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25)
- x2 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25)
- x3 = provider.ScalarDatum(step=23, wall_time=3.25, value=-0.5)
- self.assertEqual(x1, x2)
- self.assertNotEqual(x1, x3)
- self.assertNotEqual(x1, object())
-
- def test_hash(self):
- x1 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25)
- x2 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25)
- x3 = provider.ScalarDatum(step=23, wall_time=3.25, value=-0.5)
- self.assertEqual(hash(x1), hash(x2))
- # The next check is technically not required by the `__hash__`
- # contract, but _should_ pass; failure on this assertion would at
- # least warrant some scrutiny.
- self.assertNotEqual(hash(x1), hash(x3))
+ def test_repr(self):
+ x = provider.ScalarDatum(step=123, wall_time=234.5, value=-0.125)
+ repr_ = repr(x)
+ self.assertIn(repr(x.step), repr_)
+ self.assertIn(repr(x.wall_time), repr_)
+ self.assertIn(repr(x.value), repr_)
+
+ def test_eq(self):
+ x1 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25)
+ x2 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25)
+ x3 = provider.ScalarDatum(step=23, wall_time=3.25, value=-0.5)
+ self.assertEqual(x1, x2)
+ self.assertNotEqual(x1, x3)
+ self.assertNotEqual(x1, object())
+
+ def test_hash(self):
+ x1 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25)
+ x2 = provider.ScalarDatum(step=12, wall_time=0.25, value=1.25)
+ x3 = provider.ScalarDatum(step=23, wall_time=3.25, value=-0.5)
+ self.assertEqual(hash(x1), hash(x2))
+ # The next check is technically not required by the `__hash__`
+ # contract, but _should_ pass; failure on this assertion would at
+ # least warrant some scrutiny.
+ self.assertNotEqual(hash(x1), hash(x3))
class TensorTimeSeriesTest(tb_test.TestCase):
- def test_repr(self):
- x = provider.TensorTimeSeries(
- max_step=77,
- max_wall_time=1234.5,
- plugin_content=b"AB\xCD\xEF!\x00",
- description="test test",
- display_name="one two",
- )
- repr_ = repr(x)
- self.assertIn(repr(x.max_step), repr_)
- self.assertIn(repr(x.max_wall_time), repr_)
- self.assertIn(repr(x.plugin_content), repr_)
- self.assertIn(repr(x.description), repr_)
- self.assertIn(repr(x.display_name), repr_)
-
- def test_eq(self):
- x1 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two")
- x2 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two")
- x3 = provider.TensorTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum")
- self.assertEqual(x1, x2)
- self.assertNotEqual(x1, x3)
- self.assertNotEqual(x1, object())
-
- def test_hash(self):
- x1 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two")
- x2 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two")
- x3 = provider.TensorTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum")
- self.assertEqual(hash(x1), hash(x2))
- # The next check is technically not required by the `__hash__`
- # contract, but _should_ pass; failure on this assertion would at
- # least warrant some scrutiny.
- self.assertNotEqual(hash(x1), hash(x3))
+ def test_repr(self):
+ x = provider.TensorTimeSeries(
+ max_step=77,
+ max_wall_time=1234.5,
+ plugin_content=b"AB\xCD\xEF!\x00",
+ description="test test",
+ display_name="one two",
+ )
+ repr_ = repr(x)
+ self.assertIn(repr(x.max_step), repr_)
+ self.assertIn(repr(x.max_wall_time), repr_)
+ self.assertIn(repr(x.plugin_content), repr_)
+ self.assertIn(repr(x.description), repr_)
+ self.assertIn(repr(x.display_name), repr_)
+
+ def test_eq(self):
+ x1 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two")
+ x2 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two")
+ x3 = provider.TensorTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum")
+ self.assertEqual(x1, x2)
+ self.assertNotEqual(x1, x3)
+ self.assertNotEqual(x1, object())
+
+ def test_hash(self):
+ x1 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two")
+ x2 = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two")
+ x3 = provider.TensorTimeSeries(66, 4321.0, b"\x7F", "hmm", "hum")
+ self.assertEqual(hash(x1), hash(x2))
+ # The next check is technically not required by the `__hash__`
+ # contract, but _should_ pass; failure on this assertion would at
+ # least warrant some scrutiny.
+ self.assertNotEqual(hash(x1), hash(x3))
class TensorDatumTest(tb_test.TestCase):
- def test_repr(self):
- x = provider.TensorDatum(step=123, wall_time=234.5, numpy=np.array(-0.25))
- repr_ = repr(x)
- self.assertIn(repr(x.step), repr_)
- self.assertIn(repr(x.wall_time), repr_)
- self.assertIn(repr(x.numpy), repr_)
-
- def test_eq(self):
- nd = np.array
- x1 = provider.TensorDatum(step=12, wall_time=0.25, numpy=nd([1.0, 2.0]))
- x2 = provider.TensorDatum(step=12, wall_time=0.25, numpy=nd([1.0, 2.0]))
- x3 = provider.TensorDatum(step=23, wall_time=3.25, numpy=nd([-0.5, -2.5]))
- self.assertEqual(x1, x2)
- self.assertNotEqual(x1, x3)
- self.assertNotEqual(x1, object())
-
- def test_eq_with_rank0_tensor(self):
- x1 = provider.TensorDatum(step=12, wall_time=0.25, numpy=np.array([1.25]))
- x2 = provider.TensorDatum(step=12, wall_time=0.25, numpy=np.array([1.25]))
- x3 = provider.TensorDatum(step=23, wall_time=3.25, numpy=np.array([1.25]))
- self.assertEqual(x1, x2)
- self.assertNotEqual(x1, x3)
- self.assertNotEqual(x1, object())
-
- def test_hash(self):
- x = provider.TensorDatum(step=12, wall_time=0.25, numpy=np.array([1.25]))
- with six.assertRaisesRegex(self, TypeError, "unhashable type"):
- hash(x)
+ def test_repr(self):
+ x = provider.TensorDatum(
+ step=123, wall_time=234.5, numpy=np.array(-0.25)
+ )
+ repr_ = repr(x)
+ self.assertIn(repr(x.step), repr_)
+ self.assertIn(repr(x.wall_time), repr_)
+ self.assertIn(repr(x.numpy), repr_)
+
+ def test_eq(self):
+ nd = np.array
+ x1 = provider.TensorDatum(step=12, wall_time=0.25, numpy=nd([1.0, 2.0]))
+ x2 = provider.TensorDatum(step=12, wall_time=0.25, numpy=nd([1.0, 2.0]))
+ x3 = provider.TensorDatum(
+ step=23, wall_time=3.25, numpy=nd([-0.5, -2.5])
+ )
+ self.assertEqual(x1, x2)
+ self.assertNotEqual(x1, x3)
+ self.assertNotEqual(x1, object())
+
+ def test_eq_with_rank0_tensor(self):
+ x1 = provider.TensorDatum(
+ step=12, wall_time=0.25, numpy=np.array([1.25])
+ )
+ x2 = provider.TensorDatum(
+ step=12, wall_time=0.25, numpy=np.array([1.25])
+ )
+ x3 = provider.TensorDatum(
+ step=23, wall_time=3.25, numpy=np.array([1.25])
+ )
+ self.assertEqual(x1, x2)
+ self.assertNotEqual(x1, x3)
+ self.assertNotEqual(x1, object())
+
+ def test_hash(self):
+ x = provider.TensorDatum(
+ step=12, wall_time=0.25, numpy=np.array([1.25])
+ )
+ with six.assertRaisesRegex(self, TypeError, "unhashable type"):
+ hash(x)
class BlobSequenceTimeSeriesTest(tb_test.TestCase):
-
- def test_repr(self):
- x = provider.BlobSequenceTimeSeries(
- max_step=77,
- max_wall_time=1234.5,
- latest_max_index=6,
- plugin_content=b"AB\xCD\xEF!\x00",
- description="test test",
- display_name="one two",
- )
- repr_ = repr(x)
- self.assertIn(repr(x.max_step), repr_)
- self.assertIn(repr(x.max_wall_time), repr_)
- self.assertIn(repr(x.latest_max_index), repr_)
- self.assertIn(repr(x.plugin_content), repr_)
- self.assertIn(repr(x.description), repr_)
- self.assertIn(repr(x.display_name), repr_)
-
- def test_eq(self):
- x1 = provider.BlobSequenceTimeSeries(77, 1234.5, 6, b"\x12", "one", "two")
- x2 = provider.BlobSequenceTimeSeries(77, 1234.5, 6, b"\x12", "one", "two")
- x3 = provider.BlobSequenceTimeSeries(66, 4321.0, 7, b"\x7F", "hmm", "hum")
- self.assertEqual(x1, x2)
- self.assertNotEqual(x1, x3)
- self.assertNotEqual(x1, object())
-
- def test_hash(self):
- x1 = provider.BlobSequenceTimeSeries(77, 1234.5, 6, b"\x12", "one", "two")
- x2 = provider.BlobSequenceTimeSeries(77, 1234.5, 6, b"\x12", "one", "two")
- x3 = provider.BlobSequenceTimeSeries(66, 4321.0, 7, b"\x7F", "hmm", "hum")
- self.assertEqual(hash(x1), hash(x2))
- # The next check is technically not required by the `__hash__`
- # contract, but _should_ pass; failure on this assertion would at
- # least warrant some scrutiny.
- self.assertNotEqual(hash(x1), hash(x3))
+ def test_repr(self):
+ x = provider.BlobSequenceTimeSeries(
+ max_step=77,
+ max_wall_time=1234.5,
+ latest_max_index=6,
+ plugin_content=b"AB\xCD\xEF!\x00",
+ description="test test",
+ display_name="one two",
+ )
+ repr_ = repr(x)
+ self.assertIn(repr(x.max_step), repr_)
+ self.assertIn(repr(x.max_wall_time), repr_)
+ self.assertIn(repr(x.latest_max_index), repr_)
+ self.assertIn(repr(x.plugin_content), repr_)
+ self.assertIn(repr(x.description), repr_)
+ self.assertIn(repr(x.display_name), repr_)
+
+ def test_eq(self):
+ x1 = provider.BlobSequenceTimeSeries(
+ 77, 1234.5, 6, b"\x12", "one", "two"
+ )
+ x2 = provider.BlobSequenceTimeSeries(
+ 77, 1234.5, 6, b"\x12", "one", "two"
+ )
+ x3 = provider.BlobSequenceTimeSeries(
+ 66, 4321.0, 7, b"\x7F", "hmm", "hum"
+ )
+ self.assertEqual(x1, x2)
+ self.assertNotEqual(x1, x3)
+ self.assertNotEqual(x1, object())
+
+ def test_hash(self):
+ x1 = provider.BlobSequenceTimeSeries(
+ 77, 1234.5, 6, b"\x12", "one", "two"
+ )
+ x2 = provider.BlobSequenceTimeSeries(
+ 77, 1234.5, 6, b"\x12", "one", "two"
+ )
+ x3 = provider.BlobSequenceTimeSeries(
+ 66, 4321.0, 7, b"\x7F", "hmm", "hum"
+ )
+ self.assertEqual(hash(x1), hash(x2))
+ # The next check is technically not required by the `__hash__`
+ # contract, but _should_ pass; failure on this assertion would at
+ # least warrant some scrutiny.
+ self.assertNotEqual(hash(x1), hash(x3))
class BlobReferenceTest(tb_test.TestCase):
-
- def test_repr(self):
- x = provider.BlobReference(url="foo", blob_key="baz")
- repr_ = repr(x)
- self.assertIn(repr(x.url), repr_)
- self.assertIn(repr(x.blob_key), repr_)
-
- def test_eq(self):
- x1 = provider.BlobReference(url="foo", blob_key="baz")
- x2 = provider.BlobReference(url="foo", blob_key="baz")
- x3 = provider.BlobReference(url="foo", blob_key="qux")
- self.assertEqual(x1, x2)
- self.assertNotEqual(x1, x3)
- self.assertNotEqual(x1, object())
-
- def test_hash(self):
- x1 = provider.BlobReference(url="foo", blob_key="baz")
- x2 = provider.BlobReference(url="foo", blob_key="baz")
- x3 = provider.BlobReference(url="foo", blob_key="qux")
- self.assertEqual(hash(x1), hash(x2))
- # The next check is technically not required by the `__hash__`
- # contract, but _should_ pass; failure on this assertion would at
- # least warrant some scrutiny.
- self.assertNotEqual(hash(x1), hash(x3))
+ def test_repr(self):
+ x = provider.BlobReference(url="foo", blob_key="baz")
+ repr_ = repr(x)
+ self.assertIn(repr(x.url), repr_)
+ self.assertIn(repr(x.blob_key), repr_)
+
+ def test_eq(self):
+ x1 = provider.BlobReference(url="foo", blob_key="baz")
+ x2 = provider.BlobReference(url="foo", blob_key="baz")
+ x3 = provider.BlobReference(url="foo", blob_key="qux")
+ self.assertEqual(x1, x2)
+ self.assertNotEqual(x1, x3)
+ self.assertNotEqual(x1, object())
+
+ def test_hash(self):
+ x1 = provider.BlobReference(url="foo", blob_key="baz")
+ x2 = provider.BlobReference(url="foo", blob_key="baz")
+ x3 = provider.BlobReference(url="foo", blob_key="qux")
+ self.assertEqual(hash(x1), hash(x2))
+ # The next check is technically not required by the `__hash__`
+ # contract, but _should_ pass; failure on this assertion would at
+ # least warrant some scrutiny.
+ self.assertNotEqual(hash(x1), hash(x3))
class BlobSequenceDatumTest(tb_test.TestCase):
-
- def test_repr(self):
- x = provider.BlobSequenceDatum(
- step=123, wall_time=234.5, values=("foo", "bar", "baz"))
- repr_ = repr(x)
- self.assertIn(repr(x.step), repr_)
- self.assertIn(repr(x.wall_time), repr_)
- self.assertIn(repr(x.values), repr_)
-
- def test_eq(self):
- x1 = provider.BlobSequenceDatum(
- step=12, wall_time=0.25, values=("foo", "bar", "baz"))
- x2 = provider.BlobSequenceDatum(
- step=12, wall_time=0.25, values=("foo", "bar", "baz"))
- x3 = provider.BlobSequenceDatum(step=23, wall_time=3.25, values=("qux",))
- self.assertEqual(x1, x2)
- self.assertNotEqual(x1, x3)
- self.assertNotEqual(x1, object())
-
- def test_hash(self):
- x1 = provider.BlobSequenceDatum(
- step=12, wall_time=0.25, values=("foo", "bar", "baz"))
- x2 = provider.BlobSequenceDatum(
- step=12, wall_time=0.25, values=("foo", "bar", "baz"))
- x3 = provider.BlobSequenceDatum(step=23, wall_time=3.25, values=("qux",))
- self.assertEqual(hash(x1), hash(x2))
- # The next check is technically not required by the `__hash__`
- # contract, but _should_ pass; failure on this assertion would at
- # least warrant some scrutiny.
- self.assertNotEqual(hash(x1), hash(x3))
+ def test_repr(self):
+ x = provider.BlobSequenceDatum(
+ step=123, wall_time=234.5, values=("foo", "bar", "baz")
+ )
+ repr_ = repr(x)
+ self.assertIn(repr(x.step), repr_)
+ self.assertIn(repr(x.wall_time), repr_)
+ self.assertIn(repr(x.values), repr_)
+
+ def test_eq(self):
+ x1 = provider.BlobSequenceDatum(
+ step=12, wall_time=0.25, values=("foo", "bar", "baz")
+ )
+ x2 = provider.BlobSequenceDatum(
+ step=12, wall_time=0.25, values=("foo", "bar", "baz")
+ )
+ x3 = provider.BlobSequenceDatum(
+ step=23, wall_time=3.25, values=("qux",)
+ )
+ self.assertEqual(x1, x2)
+ self.assertNotEqual(x1, x3)
+ self.assertNotEqual(x1, object())
+
+ def test_hash(self):
+ x1 = provider.BlobSequenceDatum(
+ step=12, wall_time=0.25, values=("foo", "bar", "baz")
+ )
+ x2 = provider.BlobSequenceDatum(
+ step=12, wall_time=0.25, values=("foo", "bar", "baz")
+ )
+ x3 = provider.BlobSequenceDatum(
+ step=23, wall_time=3.25, values=("qux",)
+ )
+ self.assertEqual(hash(x1), hash(x2))
+ # The next check is technically not required by the `__hash__`
+ # contract, but _should_ pass; failure on this assertion would at
+ # least warrant some scrutiny.
+ self.assertNotEqual(hash(x1), hash(x3))
class RunTagFilterTest(tb_test.TestCase):
- def test_defensive_copy(self):
- runs = ["r1"]
- tags = ["t1"]
- f = provider.RunTagFilter(runs, tags)
- runs.append("r2")
- tags.pop()
- self.assertEqual(frozenset(f.runs), frozenset(["r1"]))
- self.assertEqual(frozenset(f.tags), frozenset(["t1"]))
-
- def test_repr(self):
- x = provider.RunTagFilter(runs=["one", "two"], tags=["three", "four"])
- repr_ = repr(x)
- self.assertIn(repr(x.runs), repr_)
- self.assertIn(repr(x.tags), repr_)
+ def test_defensive_copy(self):
+ runs = ["r1"]
+ tags = ["t1"]
+ f = provider.RunTagFilter(runs, tags)
+ runs.append("r2")
+ tags.pop()
+ self.assertEqual(frozenset(f.runs), frozenset(["r1"]))
+ self.assertEqual(frozenset(f.tags), frozenset(["t1"]))
+
+ def test_repr(self):
+ x = provider.RunTagFilter(runs=["one", "two"], tags=["three", "four"])
+ repr_ = repr(x)
+ self.assertIn(repr(x.runs), repr_)
+ self.assertIn(repr(x.tags), repr_)
if __name__ == "__main__":
- tb_test.main()
+ tb_test.main()
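# A small sketch of the value-type semantics exercised by the tests above:
# `TensorDatum` compares by content (via `np.array_equal`) but is deliberately
# unhashable, since numpy arrays are mutable, while `ScalarDatum` is hashable.
# The concrete steps, times, and values below are made up for illustration.
import numpy as np

from tensorboard.data import provider

a = provider.TensorDatum(step=1, wall_time=0.5, numpy=np.array([1.0, 2.0]))
b = provider.TensorDatum(step=1, wall_time=0.5, numpy=np.array([1.0, 2.0]))
assert a == b  # content equality, element-wise on the numpy payload
try:
    hash(a)
except TypeError:
    pass  # expected: `__hash__` is set to None on TensorDatum
s1 = provider.ScalarDatum(step=1, wall_time=0.5, value=2.0)
s2 = provider.ScalarDatum(step=1, wall_time=0.5, value=2.0)
assert hash(s1) == hash(s2)  # equal scalar data hash equally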
diff --git a/tensorboard/data_compat.py b/tensorboard/data_compat.py
index fe50829d58..4232b62f7d 100644
--- a/tensorboard/data_compat.py
+++ b/tensorboard/data_compat.py
@@ -30,81 +30,89 @@
def migrate_value(value):
- """Convert `value` to a new-style value, if necessary and possible.
-
- An "old-style" value is a value that uses any `value` field other than
- the `tensor` field. A "new-style" value is a value that uses the
- `tensor` field. TensorBoard continues to support old-style values on
- disk; this method converts them to new-style values so that further
- code need only deal with one data format.
-
- Arguments:
- value: A `Summary.Value` object. This argument is not modified.
-
- Returns:
- If the `value` is an old-style value for which there is a new-style
- equivalent, the result is the new-style value. Otherwise---if the
- value is already new-style or does not yet have a new-style
- equivalent---the value will be returned unchanged.
-
- :type value: Summary.Value
- :rtype: Summary.Value
- """
- handler = {
- 'histo': _migrate_histogram_value,
- 'image': _migrate_image_value,
- 'audio': _migrate_audio_value,
- 'simple_value': _migrate_scalar_value,
- }.get(value.WhichOneof('value'))
- return handler(value) if handler else value
+ """Convert `value` to a new-style value, if necessary and possible.
+
+ An "old-style" value is a value that uses any `value` field other than
+ the `tensor` field. A "new-style" value is a value that uses the
+ `tensor` field. TensorBoard continues to support old-style values on
+ disk; this method converts them to new-style values so that further
+ code need only deal with one data format.
+
+ Arguments:
+ value: A `Summary.Value` object. This argument is not modified.
+
+ Returns:
+ If the `value` is an old-style value for which there is a new-style
+ equivalent, the result is the new-style value. Otherwise---if the
+ value is already new-style or does not yet have a new-style
+ equivalent---the value will be returned unchanged.
+
+ :type value: Summary.Value
+ :rtype: Summary.Value
+ """
+ handler = {
+ "histo": _migrate_histogram_value,
+ "image": _migrate_image_value,
+ "audio": _migrate_audio_value,
+ "simple_value": _migrate_scalar_value,
+ }.get(value.WhichOneof("value"))
+ return handler(value) if handler else value
def make_summary(tag, metadata, data):
tensor_proto = tensor_util.make_tensor_proto(data)
- return summary_pb2.Summary.Value(tag=tag,
- metadata=metadata,
- tensor=tensor_proto)
+ return summary_pb2.Summary.Value(
+ tag=tag, metadata=metadata, tensor=tensor_proto
+ )
def _migrate_histogram_value(value):
- histogram_value = value.histo
- bucket_lefts = [histogram_value.min] + histogram_value.bucket_limit[:-1]
- bucket_rights = histogram_value.bucket_limit[:-1] + [histogram_value.max]
- bucket_counts = histogram_value.bucket
- buckets = np.array([bucket_lefts, bucket_rights, bucket_counts], dtype=np.float32).transpose()
+ histogram_value = value.histo
+ bucket_lefts = [histogram_value.min] + histogram_value.bucket_limit[:-1]
+ bucket_rights = histogram_value.bucket_limit[:-1] + [histogram_value.max]
+ bucket_counts = histogram_value.bucket
+ buckets = np.array(
+ [bucket_lefts, bucket_rights, bucket_counts], dtype=np.float32
+ ).transpose()
- summary_metadata = histogram_metadata.create_summary_metadata(
- display_name=value.metadata.display_name or value.tag,
- description=value.metadata.summary_description)
+ summary_metadata = histogram_metadata.create_summary_metadata(
+ display_name=value.metadata.display_name or value.tag,
+ description=value.metadata.summary_description,
+ )
- return make_summary(value.tag, summary_metadata, buckets)
+ return make_summary(value.tag, summary_metadata, buckets)
def _migrate_image_value(value):
- image_value = value.image
- data = [tf.compat.as_bytes(str(image_value.width)),
- tf.compat.as_bytes(str(image_value.height)),
- tf.compat.as_bytes(image_value.encoded_image_string)]
+ image_value = value.image
+ data = [
+ tf.compat.as_bytes(str(image_value.width)),
+ tf.compat.as_bytes(str(image_value.height)),
+ tf.compat.as_bytes(image_value.encoded_image_string),
+ ]
- summary_metadata = image_metadata.create_summary_metadata(
- display_name=value.metadata.display_name or value.tag,
- description=value.metadata.summary_description)
- return make_summary(value.tag, summary_metadata, data)
+ summary_metadata = image_metadata.create_summary_metadata(
+ display_name=value.metadata.display_name or value.tag,
+ description=value.metadata.summary_description,
+ )
+ return make_summary(value.tag, summary_metadata, data)
def _migrate_audio_value(value):
- audio_value = value.audio
- data = [[audio_value.encoded_audio_string, b'']] # empty label
- summary_metadata = audio_metadata.create_summary_metadata(
- display_name=value.metadata.display_name or value.tag,
- description=value.metadata.summary_description,
- encoding=audio_metadata.Encoding.Value('WAV'))
- return make_summary(value.tag, summary_metadata, data)
+ audio_value = value.audio
+ data = [[audio_value.encoded_audio_string, b""]] # empty label
+ summary_metadata = audio_metadata.create_summary_metadata(
+ display_name=value.metadata.display_name or value.tag,
+ description=value.metadata.summary_description,
+ encoding=audio_metadata.Encoding.Value("WAV"),
+ )
+ return make_summary(value.tag, summary_metadata, data)
def _migrate_scalar_value(value):
- scalar_value = value.simple_value
- summary_metadata = scalar_metadata.create_summary_metadata(
- display_name=value.metadata.display_name or value.tag,
- description=value.metadata.summary_description)
- return make_summary(value.tag, summary_metadata, scalar_value)
+ scalar_value = value.simple_value
+ summary_metadata = scalar_metadata.create_summary_metadata(
+ display_name=value.metadata.display_name or value.tag,
+ description=value.metadata.summary_description,
+ )
+ return make_summary(value.tag, summary_metadata, scalar_value)
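[Note, not part of the diff: a minimal sketch of calling `migrate_value` on an old-style scalar value. The tag and number are made up; the assertions follow the `_migrate_scalar_value` path above, and "scalars" is assumed to be the plugin name that `scalar_metadata.create_summary_metadata` records.]

    from tensorboard import data_compat
    from tensorboard.compat.proto import summary_pb2

    old = summary_pb2.Summary.Value(tag="loss", simple_value=0.25)
    new = data_compat.migrate_value(old)
    assert new.tag == "loss"
    assert new.HasField("tensor")  # the scalar now lives in the `tensor` field
    assert new.metadata.plugin_data.plugin_name == "scalars"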
diff --git a/tensorboard/data_compat_test.py b/tensorboard/data_compat_test.py
index c5e395fe84..73ca9e86a4 100644
--- a/tensorboard/data_compat_test.py
+++ b/tensorboard/data_compat_test.py
@@ -32,179 +32,204 @@
from tensorboard.util import tensor_util
-
class MigrateValueTest(tf.test.TestCase):
- """Tests for `migrate_value`.
-
- These tests should ensure that all first-party new-style values are
- passed through unchanged, that all supported old-style values are
- converted to new-style values, and that other old-style values are
- passed through unchanged.
- """
-
- def _value_from_op(self, op):
- with tf.compat.v1.Session() as sess:
- summary_pbtxt = sess.run(op)
- summary = summary_pb2.Summary()
- summary.ParseFromString(summary_pbtxt)
- # There may be multiple values (e.g., for an image summary that emits
- # multiple images in one batch). That's fine; we'll choose any
- # representative value, assuming that they're homogeneous.
- assert summary.value
- return summary.value[0]
-
- def _assert_noop(self, value):
- original_pbtxt = value.SerializeToString()
- result = data_compat.migrate_value(value)
- self.assertEqual(value, result)
- self.assertEqual(original_pbtxt, value.SerializeToString())
-
- def test_scalar(self):
- with tf.compat.v1.Graph().as_default():
- old_op = tf.compat.v1.summary.scalar('important_constants', tf.constant(0x5f3759df))
- old_value = self._value_from_op(old_op)
- assert old_value.HasField('simple_value'), old_value
- new_value = data_compat.migrate_value(old_value)
-
- self.assertEqual('important_constants', new_value.tag)
- expected_metadata = scalar_metadata.create_summary_metadata(
- display_name='important_constants',
- description='')
- self.assertEqual(expected_metadata, new_value.metadata)
- self.assertTrue(new_value.HasField('tensor'))
- data = tensor_util.make_ndarray(new_value.tensor)
- self.assertEqual((), data.shape)
- low_precision_value = np.array(0x5f3759df).astype('float32').item()
- self.assertEqual(low_precision_value, data.item())
-
- def test_audio(self):
- with tf.compat.v1.Graph().as_default():
- audio = tf.reshape(tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2))
- old_op = tf.compat.v1.summary.audio('k488', audio, 44100)
- old_value = self._value_from_op(old_op)
- assert old_value.HasField('audio'), old_value
- new_value = data_compat.migrate_value(old_value)
-
- self.assertEqual('k488/audio/0', new_value.tag)
- expected_metadata = audio_metadata.create_summary_metadata(
- display_name='k488/audio/0',
- description='',
- encoding=audio_metadata.Encoding.Value('WAV'))
- self.assertEqual(expected_metadata, new_value.metadata)
- self.assertTrue(new_value.HasField('tensor'))
- data = tensor_util.make_ndarray(new_value.tensor)
- self.assertEqual((1, 2), data.shape)
- self.assertEqual(tf.compat.as_bytes(old_value.audio.encoded_audio_string),
- data[0][0])
- self.assertEqual(b'', data[0][1]) # empty label
-
- def test_text(self):
- with tf.compat.v1.Graph().as_default():
- op = tf.compat.v1.summary.text('lorem_ipsum', tf.constant('dolor sit amet'))
- value = self._value_from_op(op)
- assert value.HasField('tensor'), value
- self._assert_noop(value)
-
- def test_fully_populated_tensor(self):
- with tf.compat.v1.Graph().as_default():
- metadata = summary_pb2.SummaryMetadata(
- plugin_data=summary_pb2.SummaryMetadata.PluginData(
- plugin_name='font_of_wisdom',
- content=b'adobe_garamond'))
- op = tf.compat.v1.summary.tensor_summary(
- name='tensorpocalypse',
- tensor=tf.constant([[0.0, 2.0], [float('inf'), float('nan')]]),
- display_name='TENSORPOCALYPSE',
- summary_description='look on my works ye mighty and despair',
- summary_metadata=metadata)
- value = self._value_from_op(op)
- assert value.HasField('tensor'), value
- self._assert_noop(value)
-
- def test_image(self):
- with tf.compat.v1.Graph().as_default():
- old_op = tf.compat.v1.summary.image('mona_lisa',
- tf.image.convert_image_dtype(
- tf.random.normal(shape=[1, 400, 200, 3]),
- tf.uint8,
- saturate=True))
- old_value = self._value_from_op(old_op)
- assert old_value.HasField('image'), old_value
- new_value = data_compat.migrate_value(old_value)
-
- self.assertEqual('mona_lisa/image/0', new_value.tag)
- expected_metadata = image_metadata.create_summary_metadata(
- display_name='mona_lisa/image/0', description='')
- self.assertEqual(expected_metadata, new_value.metadata)
- self.assertTrue(new_value.HasField('tensor'))
- (width, height, data) = tensor_util.make_ndarray(new_value.tensor)
- self.assertEqual(b'200', width)
- self.assertEqual(b'400', height)
- self.assertEqual(
- tf.compat.as_bytes(old_value.image.encoded_image_string), data)
-
- def test_histogram(self):
- with tf.compat.v1.Graph().as_default():
- old_op = tf.compat.v1.summary.histogram('important_data',
- tf.random.normal(shape=[23, 45]))
- old_value = self._value_from_op(old_op)
- assert old_value.HasField('histo'), old_value
- new_value = data_compat.migrate_value(old_value)
-
- self.assertEqual('important_data', new_value.tag)
- expected_metadata = histogram_metadata.create_summary_metadata(
- display_name='important_data', description='')
- self.assertEqual(expected_metadata, new_value.metadata)
- self.assertTrue(new_value.HasField('tensor'))
- buckets = tensor_util.make_ndarray(new_value.tensor)
- self.assertEqual(old_value.histo.min, buckets[0][0])
- self.assertEqual(old_value.histo.max, buckets[-1][1])
- self.assertEqual(23 * 45, buckets[:, 2].astype(int).sum())
-
- def test_new_style_histogram(self):
- with tf.compat.v1.Graph().as_default():
- op = histogram_summary.op('important_data',
- tf.random.normal(shape=[10, 10]),
- bucket_count=100,
- display_name='Important data',
- description='secrets of the universe')
- value = self._value_from_op(op)
- assert value.HasField('tensor'), value
- self._assert_noop(value)
-
- def test_new_style_image(self):
- with tf.compat.v1.Graph().as_default():
- op = image_summary.op(
- 'mona_lisa',
- tf.image.convert_image_dtype(
- tf.random.normal(shape=[1, 400, 200, 3]), tf.uint8, saturate=True),
- display_name='The Mona Lisa',
- description='A renowned portrait by da Vinci.')
- value = self._value_from_op(op)
- assert value.HasField('tensor'), value
- self._assert_noop(value)
-
- def test_new_style_audio(self):
- with tf.compat.v1.Graph().as_default():
- audio = tf.reshape(tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2))
- op = audio_summary.op('k488',
- tf.cast(audio, tf.float32),
- sample_rate=44100,
- display_name='Piano Concerto No.23',
- description='In **A major**.')
- value = self._value_from_op(op)
- assert value.HasField('tensor'), value
- self._assert_noop(value)
-
- def test_new_style_scalar(self):
- with tf.compat.v1.Graph().as_default():
- op = scalar_summary.op('important_constants', tf.constant(0x5f3759df),
- display_name='Important constants',
- description='evil floating point bit magic')
- value = self._value_from_op(op)
- assert value.HasField('tensor'), value
- self._assert_noop(value)
-
-
-if __name__ == '__main__':
- tf.test.main()
+ """Tests for `migrate_value`.
+
+ These tests should ensure that all first-party new-style values are
+ passed through unchanged, that all supported old-style values are
+ converted to new-style values, and that other old-style values are
+ passed through unchanged.
+ """
+
+ def _value_from_op(self, op):
+ with tf.compat.v1.Session() as sess:
+ summary_pbtxt = sess.run(op)
+ summary = summary_pb2.Summary()
+ summary.ParseFromString(summary_pbtxt)
+ # There may be multiple values (e.g., for an image summary that emits
+ # multiple images in one batch). That's fine; we'll choose any
+ # representative value, assuming that they're homogeneous.
+ assert summary.value
+ return summary.value[0]
+
+ def _assert_noop(self, value):
+ original_pbtxt = value.SerializeToString()
+ result = data_compat.migrate_value(value)
+ self.assertEqual(value, result)
+ self.assertEqual(original_pbtxt, value.SerializeToString())
+
+ def test_scalar(self):
+ with tf.compat.v1.Graph().as_default():
+ old_op = tf.compat.v1.summary.scalar(
+ "important_constants", tf.constant(0x5F3759DF)
+ )
+ old_value = self._value_from_op(old_op)
+ assert old_value.HasField("simple_value"), old_value
+ new_value = data_compat.migrate_value(old_value)
+
+ self.assertEqual("important_constants", new_value.tag)
+ expected_metadata = scalar_metadata.create_summary_metadata(
+ display_name="important_constants", description=""
+ )
+ self.assertEqual(expected_metadata, new_value.metadata)
+ self.assertTrue(new_value.HasField("tensor"))
+ data = tensor_util.make_ndarray(new_value.tensor)
+ self.assertEqual((), data.shape)
+ low_precision_value = np.array(0x5F3759DF).astype("float32").item()
+ self.assertEqual(low_precision_value, data.item())
+
+ def test_audio(self):
+ with tf.compat.v1.Graph().as_default():
+ audio = tf.reshape(tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2))
+ old_op = tf.compat.v1.summary.audio("k488", audio, 44100)
+ old_value = self._value_from_op(old_op)
+ assert old_value.HasField("audio"), old_value
+ new_value = data_compat.migrate_value(old_value)
+
+ self.assertEqual("k488/audio/0", new_value.tag)
+ expected_metadata = audio_metadata.create_summary_metadata(
+ display_name="k488/audio/0",
+ description="",
+ encoding=audio_metadata.Encoding.Value("WAV"),
+ )
+ self.assertEqual(expected_metadata, new_value.metadata)
+ self.assertTrue(new_value.HasField("tensor"))
+ data = tensor_util.make_ndarray(new_value.tensor)
+ self.assertEqual((1, 2), data.shape)
+ self.assertEqual(
+ tf.compat.as_bytes(old_value.audio.encoded_audio_string), data[0][0]
+ )
+ self.assertEqual(b"", data[0][1]) # empty label
+
+ def test_text(self):
+ with tf.compat.v1.Graph().as_default():
+ op = tf.compat.v1.summary.text(
+ "lorem_ipsum", tf.constant("dolor sit amet")
+ )
+ value = self._value_from_op(op)
+ assert value.HasField("tensor"), value
+ self._assert_noop(value)
+
+ def test_fully_populated_tensor(self):
+ with tf.compat.v1.Graph().as_default():
+ metadata = summary_pb2.SummaryMetadata(
+ plugin_data=summary_pb2.SummaryMetadata.PluginData(
+ plugin_name="font_of_wisdom", content=b"adobe_garamond"
+ )
+ )
+ op = tf.compat.v1.summary.tensor_summary(
+ name="tensorpocalypse",
+ tensor=tf.constant([[0.0, 2.0], [float("inf"), float("nan")]]),
+ display_name="TENSORPOCALYPSE",
+ summary_description="look on my works ye mighty and despair",
+ summary_metadata=metadata,
+ )
+ value = self._value_from_op(op)
+ assert value.HasField("tensor"), value
+ self._assert_noop(value)
+
+ def test_image(self):
+ with tf.compat.v1.Graph().as_default():
+ old_op = tf.compat.v1.summary.image(
+ "mona_lisa",
+ tf.image.convert_image_dtype(
+ tf.random.normal(shape=[1, 400, 200, 3]),
+ tf.uint8,
+ saturate=True,
+ ),
+ )
+ old_value = self._value_from_op(old_op)
+ assert old_value.HasField("image"), old_value
+ new_value = data_compat.migrate_value(old_value)
+
+ self.assertEqual("mona_lisa/image/0", new_value.tag)
+ expected_metadata = image_metadata.create_summary_metadata(
+ display_name="mona_lisa/image/0", description=""
+ )
+ self.assertEqual(expected_metadata, new_value.metadata)
+ self.assertTrue(new_value.HasField("tensor"))
+ (width, height, data) = tensor_util.make_ndarray(new_value.tensor)
+ self.assertEqual(b"200", width)
+ self.assertEqual(b"400", height)
+ self.assertEqual(
+ tf.compat.as_bytes(old_value.image.encoded_image_string), data
+ )
+
+ def test_histogram(self):
+ with tf.compat.v1.Graph().as_default():
+ old_op = tf.compat.v1.summary.histogram(
+ "important_data", tf.random.normal(shape=[23, 45])
+ )
+ old_value = self._value_from_op(old_op)
+ assert old_value.HasField("histo"), old_value
+ new_value = data_compat.migrate_value(old_value)
+
+ self.assertEqual("important_data", new_value.tag)
+ expected_metadata = histogram_metadata.create_summary_metadata(
+ display_name="important_data", description=""
+ )
+ self.assertEqual(expected_metadata, new_value.metadata)
+ self.assertTrue(new_value.HasField("tensor"))
+ buckets = tensor_util.make_ndarray(new_value.tensor)
+ self.assertEqual(old_value.histo.min, buckets[0][0])
+ self.assertEqual(old_value.histo.max, buckets[-1][1])
+ self.assertEqual(23 * 45, buckets[:, 2].astype(int).sum())
+
+ def test_new_style_histogram(self):
+ with tf.compat.v1.Graph().as_default():
+ op = histogram_summary.op(
+ "important_data",
+ tf.random.normal(shape=[10, 10]),
+ bucket_count=100,
+ display_name="Important data",
+ description="secrets of the universe",
+ )
+ value = self._value_from_op(op)
+ assert value.HasField("tensor"), value
+ self._assert_noop(value)
+
+ def test_new_style_image(self):
+ with tf.compat.v1.Graph().as_default():
+ op = image_summary.op(
+ "mona_lisa",
+ tf.image.convert_image_dtype(
+ tf.random.normal(shape=[1, 400, 200, 3]),
+ tf.uint8,
+ saturate=True,
+ ),
+ display_name="The Mona Lisa",
+ description="A renowned portrait by da Vinci.",
+ )
+ value = self._value_from_op(op)
+ assert value.HasField("tensor"), value
+ self._assert_noop(value)
+
+ def test_new_style_audio(self):
+ with tf.compat.v1.Graph().as_default():
+ audio = tf.reshape(tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2))
+ op = audio_summary.op(
+ "k488",
+ tf.cast(audio, tf.float32),
+ sample_rate=44100,
+ display_name="Piano Concerto No.23",
+ description="In **A major**.",
+ )
+ value = self._value_from_op(op)
+ assert value.HasField("tensor"), value
+ self._assert_noop(value)
+
+ def test_new_style_scalar(self):
+ with tf.compat.v1.Graph().as_default():
+ op = scalar_summary.op(
+ "important_constants",
+ tf.constant(0x5F3759DF),
+ display_name="Important constants",
+ description="evil floating point bit magic",
+ )
+ value = self._value_from_op(op)
+ assert value.HasField("tensor"), value
+ self._assert_noop(value)
+
+
+if __name__ == "__main__":
+ tf.test.main()
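[Note, not part of the diff: a small sketch of the bucket-tensor layout that test_histogram above relies on; the numbers are made up.]

    import numpy as np

    # _migrate_histogram_value emits one row per bucket: [left, right, count].
    buckets = np.array(
        [[-1.0, 0.0, 3.0], [0.0, 1.0, 7.0]], dtype=np.float32
    )
    assert buckets[0][0] == -1.0                   # histogram min
    assert buckets[-1][1] == 1.0                   # histogram max
    assert buckets[:, 2].astype(int).sum() == 10   # total element count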
diff --git a/tensorboard/default.py b/tensorboard/default.py
index 9ef04d2adc..963a8fef14 100644
--- a/tensorboard/default.py
+++ b/tensorboard/default.py
@@ -47,7 +47,7 @@
from tensorboard.plugins.hparams import hparams_plugin
from tensorboard.plugins.image import images_plugin
from tensorboard.plugins.interactive_inference import (
- interactive_inference_plugin_loader
+ interactive_inference_plugin_loader,
)
from tensorboard.plugins.pr_curve import pr_curves_plugin
from tensorboard.plugins.profile import profile_plugin_loader
@@ -80,39 +80,42 @@
mesh_plugin.MeshPlugin,
]
+
def get_plugins():
- """Returns a list specifying TensorBoard's default first-party plugins.
+ """Returns a list specifying TensorBoard's default first-party plugins.
- Plugins are specified in this list either via a TBLoader instance to load the
- plugin, or the TBPlugin class itself which will be loaded using a BasicLoader.
+ Plugins are specified in this list either via a TBLoader instance to load the
+ plugin, or the TBPlugin class itself which will be loaded using a BasicLoader.
- This list can be passed to the `tensorboard.program.TensorBoard` API.
+ This list can be passed to the `tensorboard.program.TensorBoard` API.
- Returns:
- The list of default plugins.
+ Returns:
+ The list of default plugins.
- :rtype: list[Type[base_plugin.TBLoader] | Type[base_plugin.TBPlugin]]
- """
+ :rtype: list[Type[base_plugin.TBLoader] | Type[base_plugin.TBPlugin]]
+ """
- return _PLUGINS[:]
+ return _PLUGINS[:]
def get_dynamic_plugins():
- """Returns a list specifying TensorBoard's dynamically loaded plugins.
+ """Returns a list specifying TensorBoard's dynamically loaded plugins.
- A dynamic TensorBoard plugin is specified using entry_points [1] and it is
- the robust way to integrate plugins into TensorBoard.
+ A dynamic TensorBoard plugin is specified using entry_points [1] and it is
+ the robust way to integrate plugins into TensorBoard.
- This list can be passed to the `tensorboard.program.TensorBoard` API.
+ This list can be passed to the `tensorboard.program.TensorBoard` API.
- Returns:
- The list of dynamic plugins.
+ Returns:
+ The list of dynamic plugins.
- :rtype: list[Type[base_plugin.TBLoader] | Type[base_plugin.TBPlugin]]
+ :rtype: list[Type[base_plugin.TBLoader] | Type[base_plugin.TBPlugin]]
- [1]: https://packaging.python.org/specifications/entry-points/
- """
- return [
- entry_point.load()
- for entry_point in pkg_resources.iter_entry_points('tensorboard_plugins')
- ]
+ [1]: https://packaging.python.org/specifications/entry-points/
+ """
+ return [
+ entry_point.load()
+ for entry_point in pkg_resources.iter_entry_points(
+ "tensorboard_plugins"
+ )
+ ]
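[Note, not part of the diff: a sketch of how a third-party package would hook into the `tensorboard_plugins` entry-point group that `get_dynamic_plugins` iterates over. The package and class names here are hypothetical; a real first-party example is tensorboard/examples/plugins/example_basic below.]

    # setup.py of a hypothetical "tensorboard_plugin_foo" distribution
    import setuptools

    setuptools.setup(
        name="tensorboard_plugin_foo",
        version="0.1.0",
        packages=["tensorboard_plugin_foo"],
        entry_points={
            "tensorboard_plugins": [
                "foo = tensorboard_plugin_foo.plugin:FooPlugin",
            ],
        },
    )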
diff --git a/tensorboard/default_test.py b/tensorboard/default_test.py
index 39665af169..17bc6c802b 100644
--- a/tensorboard/default_test.py
+++ b/tensorboard/default_test.py
@@ -19,10 +19,10 @@
from __future__ import print_function
try:
- # python version >= 3.3
- from unittest import mock
+ # python version >= 3.3
+ from unittest import mock
except ImportError:
- import mock # pylint: disable=unused-import
+ import mock # pylint: disable=unused-import
import pkg_resources
@@ -32,43 +32,42 @@
class FakePlugin(base_plugin.TBPlugin):
- """FakePlugin for testing."""
+ """FakePlugin for testing."""
- plugin_name = 'fake'
+ plugin_name = "fake"
class FakeEntryPoint(pkg_resources.EntryPoint):
- """EntryPoint class that fake loads FakePlugin."""
+ """EntryPoint class that fake loads FakePlugin."""
- @classmethod
- def create(cls):
- """Creates an instance of FakeEntryPoint.
+ @classmethod
+ def create(cls):
+ """Creates an instance of FakeEntryPoint.
- Returns:
- instance of FakeEntryPoint
- """
- return cls('foo', 'bar')
+ Returns:
+ instance of FakeEntryPoint
+ """
+ return cls("foo", "bar")
- def load(self):
- """Returns FakePlugin instead of resolving module.
+ def load(self):
+ """Returns FakePlugin instead of resolving module.
- Returns:
- FakePlugin
- """
- return FakePlugin
+ Returns:
+ FakePlugin
+ """
+ return FakePlugin
class DefaultTest(test.TestCase):
+ @mock.patch.object(pkg_resources, "iter_entry_points")
+ def test_get_dynamic_plugin(self, mock_iter_entry_points):
+ mock_iter_entry_points.return_value = [FakeEntryPoint.create()]
- @mock.patch.object(pkg_resources, 'iter_entry_points')
- def test_get_dynamic_plugin(self, mock_iter_entry_points):
- mock_iter_entry_points.return_value = [FakeEntryPoint.create()]
+ actual_plugins = default.get_dynamic_plugins()
- actual_plugins = default.get_dynamic_plugins()
-
- mock_iter_entry_points.assert_called_with('tensorboard_plugins')
- self.assertEqual(actual_plugins, [FakePlugin])
+ mock_iter_entry_points.assert_called_with("tensorboard_plugins")
+ self.assertEqual(actual_plugins, [FakePlugin])
if __name__ == "__main__":
- test.main()
+ test.main()
diff --git a/tensorboard/defs/tb_proto_library_test.py b/tensorboard/defs/tb_proto_library_test.py
index 425e1d06f9..3fc3bb5763 100644
--- a/tensorboard/defs/tb_proto_library_test.py
+++ b/tensorboard/defs/tb_proto_library_test.py
@@ -26,19 +26,19 @@
class TbProtoLibraryTest(tb_test.TestCase):
- """Tests for `tb_proto_library`."""
+ """Tests for `tb_proto_library`."""
- def tests_with_deps(self):
- foo = test_base_pb2.Foo()
- foo.foo = 1
- bar = test_downstream_pb2.Bar()
- bar.foo.foo = 1
- self.assertEqual(foo, bar.foo)
+ def tests_with_deps(self):
+ foo = test_base_pb2.Foo()
+ foo.foo = 1
+ bar = test_downstream_pb2.Bar()
+ bar.foo.foo = 1
+ self.assertEqual(foo, bar.foo)
- def test_service_deps(self):
- self.assertIsInstance(test_base_pb2_grpc.FooServiceServicer, type)
- self.assertIsInstance(test_downstream_pb2_grpc.BarServiceServicer, type)
+ def test_service_deps(self):
+ self.assertIsInstance(test_base_pb2_grpc.FooServiceServicer, type)
+ self.assertIsInstance(test_downstream_pb2_grpc.BarServiceServicer, type)
if __name__ == "__main__":
- tb_test.main()
+ tb_test.main()
diff --git a/tensorboard/defs/web_test_python_stub.template.py b/tensorboard/defs/web_test_python_stub.template.py
index bce71ad88d..59b40f1515 100644
--- a/tensorboard/defs/web_test_python_stub.template.py
+++ b/tensorboard/defs/web_test_python_stub.template.py
@@ -20,10 +20,8 @@
from tensorboard.functionaltests import wct_test_driver
-Test = wct_test_driver.create_test_class(
- "{BINARY_PATH}",
- "{WEB_PATH}")
+Test = wct_test_driver.create_test_class("{BINARY_PATH}", "{WEB_PATH}")
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/tensorboard/encode_png_benchmark.py b/tensorboard/encode_png_benchmark.py
index 5550bd74cf..703482a7e3 100644
--- a/tensorboard/encode_png_benchmark.py
+++ b/tensorboard/encode_png_benchmark.py
@@ -64,80 +64,91 @@
def bench(image, thread_count):
- """Encode `image` to PNG on `thread_count` threads in parallel.
-
- Returns:
- A `float` representing number of seconds that it takes all threads
- to finish encoding `image`.
- """
- threads = [threading.Thread(target=lambda: encoder.encode_png(image))
- for _ in xrange(thread_count)]
- start_time = datetime.datetime.now()
- for thread in threads:
- thread.start()
- for thread in threads:
- thread.join()
- end_time = datetime.datetime.now()
- delta = (end_time - start_time).total_seconds()
- return delta
+ """Encode `image` to PNG on `thread_count` threads in parallel.
+
+ Returns:
+ A `float` representing number of seconds that it takes all threads
+ to finish encoding `image`.
+ """
+ threads = [
+ threading.Thread(target=lambda: encoder.encode_png(image))
+ for _ in xrange(thread_count)
+ ]
+ start_time = datetime.datetime.now()
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ end_time = datetime.datetime.now()
+ delta = (end_time - start_time).total_seconds()
+ return delta
def _image_of_size(image_size):
- """Generate a square RGB test image of the given side length."""
- return np.random.uniform(0, 256, [image_size, image_size, 3]).astype(np.uint8)
+ """Generate a square RGB test image of the given side length."""
+ return np.random.uniform(0, 256, [image_size, image_size, 3]).astype(
+ np.uint8
+ )
def _format_line(headers, fields):
- """Format a line of a table.
-
- Arguments:
- headers: A list of strings that are used as the table headers.
- fields: A list of the same length as `headers` where `fields[i]` is
- the entry for `headers[i]` in this row. Elements can be of
- arbitrary types. Pass `headers` to print the header row.
-
- Returns:
- A pretty string.
- """
- assert len(fields) == len(headers), (fields, headers)
- fields = ["%2.4f" % field if isinstance(field, float) else str(field)
- for field in fields]
- return ' '.join(' ' * max(0, len(header) - len(field)) + field
- for (header, field) in zip(headers, fields))
+ """Format a line of a table.
+
+ Arguments:
+ headers: A list of strings that are used as the table headers.
+ fields: A list of the same length as `headers` where `fields[i]` is
+ the entry for `headers[i]` in this row. Elements can be of
+ arbitrary types. Pass `headers` to print the header row.
+
+ Returns:
+ A pretty string.
+ """
+ assert len(fields) == len(headers), (fields, headers)
+ fields = [
+ "%2.4f" % field if isinstance(field, float) else str(field)
+ for field in fields
+ ]
+ return " ".join(
+ " " * max(0, len(header) - len(field)) + field
+ for (header, field) in zip(headers, fields)
+ )
def main(unused_argv):
- logging.set_verbosity(logging.INFO)
- np.random.seed(0)
-
- thread_counts = [1, 2, 4, 6, 8, 10, 12, 14, 16, 32]
-
- logger.info("Warming up...")
- warmup_image = _image_of_size(256)
- for thread_count in thread_counts:
- bench(warmup_image, thread_count)
-
- logger.info("Running...")
- results = {}
- image = _image_of_size(4096)
- headers = ('THREADS', 'TOTAL_TIME', 'UNIT_TIME', 'SPEEDUP', 'PARALLELISM')
- logger.info(_format_line(headers, headers))
- for thread_count in thread_counts:
- time.sleep(1.0)
- total_time = min(bench(image, thread_count)
- for _ in xrange(3)) # best-of-three timing
- unit_time = total_time / thread_count
- if total_time < 2.0:
- logger.warn("This benchmark is running too quickly! This "
- "may cause misleading timing data. Consider "
- "increasing the image size until it takes at "
- "least 2.0s to encode one image.")
- results[thread_count] = unit_time
- speedup = results[1] / results[thread_count]
- parallelism = speedup / thread_count
- fields = (thread_count, total_time, unit_time, speedup, parallelism)
- logger.info(_format_line(headers, fields))
-
-
-if __name__ == '__main__':
- app.run(main)
+ logging.set_verbosity(logging.INFO)
+ np.random.seed(0)
+
+ thread_counts = [1, 2, 4, 6, 8, 10, 12, 14, 16, 32]
+
+ logger.info("Warming up...")
+ warmup_image = _image_of_size(256)
+ for thread_count in thread_counts:
+ bench(warmup_image, thread_count)
+
+ logger.info("Running...")
+ results = {}
+ image = _image_of_size(4096)
+ headers = ("THREADS", "TOTAL_TIME", "UNIT_TIME", "SPEEDUP", "PARALLELISM")
+ logger.info(_format_line(headers, headers))
+ for thread_count in thread_counts:
+ time.sleep(1.0)
+ total_time = min(
+ bench(image, thread_count) for _ in xrange(3)
+ ) # best-of-three timing
+ unit_time = total_time / thread_count
+ if total_time < 2.0:
+ logger.warn(
+ "This benchmark is running too quickly! This "
+ "may cause misleading timing data. Consider "
+ "increasing the image size until it takes at "
+ "least 2.0s to encode one image."
+ )
+ results[thread_count] = unit_time
+ speedup = results[1] / results[thread_count]
+ parallelism = speedup / thread_count
+ fields = (thread_count, total_time, unit_time, speedup, parallelism)
+ logger.info(_format_line(headers, fields))
+
+
+if __name__ == "__main__":
+ app.run(main)
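[Note, not part of the diff: a small worked example of the arithmetic in `main` above, with made-up timings. Suppose one thread encodes the image in 2.0s and four threads finish their four encodes after 3.2s of wall time.]

    total_time = 3.2               # wall time for the 4-thread run
    unit_time = total_time / 4     # 0.8s of wall time per encoded image
    speedup = 2.0 / unit_time      # 2.5x: per-image throughput vs. 1 thread
    parallelism = speedup / 4      # 0.625: fraction of ideal 4x scaling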
diff --git a/tensorboard/errors.py b/tensorboard/errors.py
index f937eaa545..4ad906c1c3 100644
--- a/tensorboard/errors.py
+++ b/tensorboard/errors.py
@@ -30,67 +30,67 @@
class PublicError(RuntimeError):
- """An error whose text does not contain sensitive information."""
+ """An error whose text does not contain sensitive information."""
- http_code = 500 # default; subclasses should override
+ http_code = 500 # default; subclasses should override
- def __init__(self, details):
- super(PublicError, self).__init__(details)
+ def __init__(self, details):
+ super(PublicError, self).__init__(details)
class InvalidArgumentError(PublicError):
- """Client specified an invalid argument.
+ """Client specified an invalid argument.
- The text of this error is assumed not to contain sensitive data,
- and so may appear in (e.g.) the response body of a failed HTTP
- request.
+ The text of this error is assumed not to contain sensitive data,
+ and so may appear in (e.g.) the response body of a failed HTTP
+ request.
- Corresponds to HTTP 400 Bad Request or Google error code `INVALID_ARGUMENT`.
- """
+ Corresponds to HTTP 400 Bad Request or Google error code `INVALID_ARGUMENT`.
+ """
- http_code = 400
+ http_code = 400
- def __init__(self, details=None):
- msg = _format_message("Invalid argument", details)
- super(InvalidArgumentError, self).__init__(msg)
+ def __init__(self, details=None):
+ msg = _format_message("Invalid argument", details)
+ super(InvalidArgumentError, self).__init__(msg)
class NotFoundError(PublicError):
- """Some requested entity (e.g., file or directory) was not found.
+ """Some requested entity (e.g., file or directory) was not found.
- The text of this error is assumed not to contain sensitive data,
- and so may appear in (e.g.) the response body of a failed HTTP
- request.
+ The text of this error is assumed not to contain sensitive data,
+ and so may appear in (e.g.) the response body of a failed HTTP
+ request.
- Corresponds to HTTP 404 Not Found or Google error code `NOT_FOUND`.
- """
+ Corresponds to HTTP 404 Not Found or Google error code `NOT_FOUND`.
+ """
- http_code = 404
+ http_code = 404
- def __init__(self, details=None):
- msg = _format_message("Not found", details)
- super(NotFoundError, self).__init__(msg)
+ def __init__(self, details=None):
+ msg = _format_message("Not found", details)
+ super(NotFoundError, self).__init__(msg)
class PermissionDeniedError(PublicError):
- """The caller does not have permission to execute the specified operation.
+ """The caller does not have permission to execute the specified operation.
- The text of this error is assumed not to contain sensitive data,
- and so may appear in (e.g.) the response body of a failed HTTP
- request.
+ The text of this error is assumed not to contain sensitive data,
+ and so may appear in (e.g.) the response body of a failed HTTP
+ request.
- Corresponds to HTTP 403 Forbidden or Google error code `PERMISSION_DENIED`.
- """
+ Corresponds to HTTP 403 Forbidden or Google error code `PERMISSION_DENIED`.
+ """
- http_code = 403
+ http_code = 403
- def __init__(self, details=None):
- msg = _format_message("Permission denied", details)
- super(PermissionDeniedError, self).__init__(msg)
+ def __init__(self, details=None):
+ msg = _format_message("Permission denied", details)
+ super(PermissionDeniedError, self).__init__(msg)
def _format_message(code_name, details):
- if details is None:
- return code_name
- else:
- return "%s: %s" % (code_name, details)
+ if details is None:
+ return code_name
+ else:
+ return "%s: %s" % (code_name, details)
diff --git a/tensorboard/errors_test.py b/tensorboard/errors_test.py
index b0c6ac0156..050065a6d3 100644
--- a/tensorboard/errors_test.py
+++ b/tensorboard/errors_test.py
@@ -23,52 +23,49 @@
class InvalidArgumentErrorTest(tb_test.TestCase):
+ def test_no_details(self):
+ e = errors.InvalidArgumentError()
+ expected_msg = "Invalid argument"
+ self.assertEqual(str(e), expected_msg)
- def test_no_details(self):
- e = errors.InvalidArgumentError()
- expected_msg = "Invalid argument"
- self.assertEqual(str(e), expected_msg)
+ def test_with_details(self):
+ e = errors.InvalidArgumentError("expected absolute path; got './foo'")
+ expected_msg = "Invalid argument: expected absolute path; got './foo'"
+ self.assertEqual(str(e), expected_msg)
- def test_with_details(self):
- e = errors.InvalidArgumentError("expected absolute path; got './foo'")
- expected_msg = "Invalid argument: expected absolute path; got './foo'"
- self.assertEqual(str(e), expected_msg)
-
- def test_http_code(self):
- self.assertEqual(errors.InvalidArgumentError().http_code, 400)
+ def test_http_code(self):
+ self.assertEqual(errors.InvalidArgumentError().http_code, 400)
class NotFoundErrorTest(tb_test.TestCase):
+ def test_no_details(self):
+ e = errors.NotFoundError()
+ expected_msg = "Not found"
+ self.assertEqual(str(e), expected_msg)
- def test_no_details(self):
- e = errors.NotFoundError()
- expected_msg = "Not found"
- self.assertEqual(str(e), expected_msg)
-
- def test_with_details(self):
- e = errors.NotFoundError("no scalar data for run=foo, tag=bar")
- expected_msg = "Not found: no scalar data for run=foo, tag=bar"
- self.assertEqual(str(e), expected_msg)
+ def test_with_details(self):
+ e = errors.NotFoundError("no scalar data for run=foo, tag=bar")
+ expected_msg = "Not found: no scalar data for run=foo, tag=bar"
+ self.assertEqual(str(e), expected_msg)
- def test_http_code(self):
- self.assertEqual(errors.NotFoundError().http_code, 404)
+ def test_http_code(self):
+ self.assertEqual(errors.NotFoundError().http_code, 404)
class PermissionDeniedErrorTest(tb_test.TestCase):
+ def test_no_details(self):
+ e = errors.PermissionDeniedError()
+ expected_msg = "Permission denied"
+ self.assertEqual(str(e), expected_msg)
- def test_no_details(self):
- e = errors.PermissionDeniedError()
- expected_msg = "Permission denied"
- self.assertEqual(str(e), expected_msg)
-
- def test_with_details(self):
- e = errors.PermissionDeniedError("this data is top secret")
- expected_msg = "Permission denied: this data is top secret"
- self.assertEqual(str(e), expected_msg)
+ def test_with_details(self):
+ e = errors.PermissionDeniedError("this data is top secret")
+ expected_msg = "Permission denied: this data is top secret"
+ self.assertEqual(str(e), expected_msg)
- def test_http_code(self):
- self.assertEqual(errors.PermissionDeniedError().http_code, 403)
+ def test_http_code(self):
+ self.assertEqual(errors.PermissionDeniedError().http_code, 403)
if __name__ == "__main__":
- tb_test.main()
+ tb_test.main()
diff --git a/tensorboard/examples/plugins/example_basic/setup.py b/tensorboard/examples/plugins/example_basic/setup.py
index e774ecbeec..54da6e0c50 100644
--- a/tensorboard/examples/plugins/example_basic/setup.py
+++ b/tensorboard/examples/plugins/example_basic/setup.py
@@ -25,9 +25,7 @@
version="0.1.0",
description="Sample TensorBoard plugin.",
packages=["tensorboard_plugin_example"],
- package_data={
- "tensorboard_plugin_example": ["static/**"],
- },
+ package_data={"tensorboard_plugin_example": ["static/**"],},
entry_points={
"tensorboard_plugins": [
"example_basic = tensorboard_plugin_example.plugin:ExamplePlugin",
diff --git a/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/demo.py b/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/demo.py
index 3f7260481b..536e2c1b35 100644
--- a/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/demo.py
+++ b/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/demo.py
@@ -29,18 +29,17 @@
def main(unused_argv):
- writer = tf.summary.create_file_writer("demo_logs")
- with writer.as_default():
- summary_v2.greeting(
- "guestbook",
- "Alice",
- step=0,
- description="Sign your name!",
- )
- summary_v2.greeting("guestbook", "Bob", step=1) # no need for `description`
- summary_v2.greeting("guestbook", "Cheryl", step=2)
- summary_v2.greeting("more_names", "David", step=4)
+ writer = tf.summary.create_file_writer("demo_logs")
+ with writer.as_default():
+ summary_v2.greeting(
+ "guestbook", "Alice", step=0, description="Sign your name!",
+ )
+ summary_v2.greeting(
+ "guestbook", "Bob", step=1
+ ) # no need for `description`
+ summary_v2.greeting("guestbook", "Cheryl", step=2)
+ summary_v2.greeting("more_names", "David", step=4)
if __name__ == "__main__":
- app.run(main)
+ app.run(main)
diff --git a/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/plugin.py b/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/plugin.py
index acfccb3dbd..a8bacd0750 100644
--- a/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/plugin.py
+++ b/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/plugin.py
@@ -32,59 +32,66 @@
class ExamplePlugin(base_plugin.TBPlugin):
- plugin_name = metadata.PLUGIN_NAME
+ plugin_name = metadata.PLUGIN_NAME
- def __init__(self, context):
- self._multiplexer = context.multiplexer
+ def __init__(self, context):
+ self._multiplexer = context.multiplexer
- def is_active(self):
- return bool(self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME))
+ def is_active(self):
+ return bool(
+ self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME)
+ )
- def get_plugin_apps(self):
- return {
- "/index.js": self._serve_js,
- "/tags": self._serve_tags,
- "/greetings": self._serve_greetings,
- }
+ def get_plugin_apps(self):
+ return {
+ "/index.js": self._serve_js,
+ "/tags": self._serve_tags,
+ "/greetings": self._serve_greetings,
+ }
- def frontend_metadata(self):
- return base_plugin.FrontendMetadata(es_module_path="/index.js")
+ def frontend_metadata(self):
+ return base_plugin.FrontendMetadata(es_module_path="/index.js")
- @wrappers.Request.application
- def _serve_js(self, request):
- del request # unused
- filepath = os.path.join(os.path.dirname(__file__), "static", "index.js")
- with open(filepath) as infile:
- contents = infile.read()
- return werkzeug.Response(contents, content_type="application/javascript")
+ @wrappers.Request.application
+ def _serve_js(self, request):
+ del request # unused
+ filepath = os.path.join(os.path.dirname(__file__), "static", "index.js")
+ with open(filepath) as infile:
+ contents = infile.read()
+ return werkzeug.Response(
+ contents, content_type="application/javascript"
+ )
- @wrappers.Request.application
- def _serve_tags(self, request):
- del request # unused
- mapping = self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME)
- result = {run: {} for run in self._multiplexer.Runs()}
- for (run, tag_to_content) in six.iteritems(mapping):
- for tag in tag_to_content:
- summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
- result[run][tag] = {
- u"description": summary_metadata.summary_description,
- }
- contents = json.dumps(result, sort_keys=True)
- return werkzeug.Response(contents, content_type="application/json")
+ @wrappers.Request.application
+ def _serve_tags(self, request):
+ del request # unused
+ mapping = self._multiplexer.PluginRunToTagToContent(
+ metadata.PLUGIN_NAME
+ )
+ result = {run: {} for run in self._multiplexer.Runs()}
+ for (run, tag_to_content) in six.iteritems(mapping):
+ for tag in tag_to_content:
+ summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
+ result[run][tag] = {
+ u"description": summary_metadata.summary_description,
+ }
+ contents = json.dumps(result, sort_keys=True)
+ return werkzeug.Response(contents, content_type="application/json")
- @wrappers.Request.application
- def _serve_greetings(self, request):
- run = request.args.get("run")
- tag = request.args.get("tag")
- if run is None or tag is None:
- raise werkzeug.exceptions.BadRequest("Must specify run and tag")
- try:
- data = [
- np.asscalar(tensor_util.make_ndarray(event.tensor_proto))
- .decode("utf-8")
- for event in self._multiplexer.Tensors(run, tag)
- ]
- except KeyError:
- raise werkzeug.exceptions.BadRequest("Invalid run or tag")
- contents = json.dumps(data, sort_keys=True)
- return werkzeug.Response(contents, content_type="application/json")
+ @wrappers.Request.application
+ def _serve_greetings(self, request):
+ run = request.args.get("run")
+ tag = request.args.get("tag")
+ if run is None or tag is None:
+ raise werkzeug.exceptions.BadRequest("Must specify run and tag")
+ try:
+ data = [
+ np.asscalar(
+ tensor_util.make_ndarray(event.tensor_proto)
+ ).decode("utf-8")
+ for event in self._multiplexer.Tensors(run, tag)
+ ]
+ except KeyError:
+ raise werkzeug.exceptions.BadRequest("Invalid run or tag")
+ contents = json.dumps(data, sort_keys=True)
+ return werkzeug.Response(contents, content_type="application/json")
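[Note, not part of the diff: a sketch of standing up `ExamplePlugin` against a log directory such as the one demo.py writes, outside a full TensorBoard process. The multiplexer/TBContext wiring mirrors how TensorBoard's own plugin tests do it, but this exact snippet is illustrative only.]

    from tensorboard.backend.event_processing import plugin_event_multiplexer
    from tensorboard.plugins import base_plugin
    from tensorboard_plugin_example import plugin as plugin_lib

    multiplexer = plugin_event_multiplexer.EventMultiplexer()
    multiplexer.AddRunsFromDirectory("demo_logs")
    multiplexer.Reload()
    context = base_plugin.TBContext(multiplexer=multiplexer)
    plugin = plugin_lib.ExamplePlugin(context)
    print(plugin.is_active())                # True once greeting data exists
    print(sorted(plugin.get_plugin_apps()))  # ['/greetings', '/index.js', '/tags']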
diff --git a/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/summary_v2.py b/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/summary_v2.py
index 04c80d0f58..fd8da23c0d 100644
--- a/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/summary_v2.py
+++ b/tensorboard/examples/plugins/example_basic/tensorboard_plugin_example/summary_v2.py
@@ -26,42 +26,42 @@
def greeting(name, guest, step=None, description=None):
- """Write a "greeting" summary.
+ """Write a "greeting" summary.
- Arguments:
- name: A name for this summary. The summary tag used for TensorBoard will
- be this name prefixed by any active name scopes.
- guest: A rank-0 string `Tensor`.
- step: Explicit `int64`-castable monotonic step value for this summary. If
- omitted, this defaults to `tf.summary.experimental.get_step()`, which must
- not be None.
- description: Optional long-form description for this summary, as a
- constant `str`. Markdown is supported. Defaults to empty.
+ Arguments:
+ name: A name for this summary. The summary tag used for TensorBoard will
+ be this name prefixed by any active name scopes.
+ guest: A rank-0 string `Tensor`.
+ step: Explicit `int64`-castable monotonic step value for this summary. If
+ omitted, this defaults to `tf.summary.experimental.get_step()`, which must
+ not be None.
+ description: Optional long-form description for this summary, as a
+ constant `str`. Markdown is supported. Defaults to empty.
- Returns:
- True on success, or false if no summary was written because no default
- summary writer was available.
+ Returns:
+ True on success, or false if no summary was written because no default
+ summary writer was available.
- Raises:
- ValueError: if a default writer exists, but no step was provided and
- `tf.summary.experimental.get_step()` is None.
- """
- with tf.summary.experimental.summary_scope(
- name, "greeting_summary", values=[guest, step],
- ) as (tag, _):
- return tf.summary.write(
- tag=tag,
- tensor=tf.strings.join(["Hello, ", guest, "!"]),
- step=step,
- metadata=_create_summary_metadata(description),
- )
+ Raises:
+ ValueError: if a default writer exists, but no step was provided and
+ `tf.summary.experimental.get_step()` is None.
+ """
+ with tf.summary.experimental.summary_scope(
+ name, "greeting_summary", values=[guest, step],
+ ) as (tag, _):
+ return tf.summary.write(
+ tag=tag,
+ tensor=tf.strings.join(["Hello, ", guest, "!"]),
+ step=step,
+ metadata=_create_summary_metadata(description),
+ )
def _create_summary_metadata(description):
- return summary_pb2.SummaryMetadata(
- summary_description=description,
- plugin_data=summary_pb2.SummaryMetadata.PluginData(
- plugin_name=metadata.PLUGIN_NAME,
- content=b"", # no need for summary-specific metadata
- ),
- )
+ return summary_pb2.SummaryMetadata(
+ summary_description=description,
+ plugin_data=summary_pb2.SummaryMetadata.PluginData(
+ plugin_name=metadata.PLUGIN_NAME,
+ content=b"", # no need for summary-specific metadata
+ ),
+ )
diff --git a/tensorboard/functionaltests/core_test.py b/tensorboard/functionaltests/core_test.py
index 0a3a6f13f0..3953c5b99f 100644
--- a/tensorboard/functionaltests/core_test.py
+++ b/tensorboard/functionaltests/core_test.py
@@ -37,171 +37,186 @@
class BasicTest(unittest.TestCase):
- """Tests that the basic chrome is displayed when there is no data."""
-
- @classmethod
- def setUpClass(cls):
- src_dir = os.environ["TEST_SRCDIR"]
- binary = os.path.join(src_dir,
- "org_tensorflow_tensorboard/tensorboard/tensorboard")
- cls.logdir = tempfile.mkdtemp(prefix='core_test_%s_logdir_' % cls.__name__)
- cls.setUpData()
- cls.port = _BASE_PORT + _PORT_OFFSETS[cls]
- cls.process = subprocess.Popen(
- [binary, "--port", str(cls.port), "--logdir", cls.logdir])
-
- @classmethod
- def setUpData(cls):
- # Overridden by DashboardsTest.
- pass
-
- @classmethod
- def tearDownClass(cls):
- cls.process.kill()
- cls.process.wait()
-
- def setUp(self):
- self.driver = webtest.new_webdriver_session()
- self.driver.get("http://localhost:%s" % self.port)
- self.wait = wait.WebDriverWait(self.driver, 10)
-
- def tearDown(self):
- try:
- self.driver.quit()
- finally:
- self.driver = None
- self.wait = None
-
- def testToolbarTitleDisplays(self):
- self.wait.until(
- expected_conditions.text_to_be_present_in_element((
- by.By.CLASS_NAME, "toolbar-title"), "TensorBoard"))
-
- def testLogdirDisplays(self):
- self.wait.until(
- expected_conditions.text_to_be_present_in_element((
- by.By.ID, "data_location"), self.logdir))
+ """Tests that the basic chrome is displayed when there is no data."""
+
+ @classmethod
+ def setUpClass(cls):
+ src_dir = os.environ["TEST_SRCDIR"]
+ binary = os.path.join(
+ src_dir, "org_tensorflow_tensorboard/tensorboard/tensorboard"
+ )
+ cls.logdir = tempfile.mkdtemp(
+ prefix="core_test_%s_logdir_" % cls.__name__
+ )
+ cls.setUpData()
+ cls.port = _BASE_PORT + _PORT_OFFSETS[cls]
+ cls.process = subprocess.Popen(
+ [binary, "--port", str(cls.port), "--logdir", cls.logdir]
+ )
+
+ @classmethod
+ def setUpData(cls):
+ # Overridden by DashboardsTest.
+ pass
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.process.kill()
+ cls.process.wait()
+
+ def setUp(self):
+ self.driver = webtest.new_webdriver_session()
+ self.driver.get("http://localhost:%s" % self.port)
+ self.wait = wait.WebDriverWait(self.driver, 10)
+
+ def tearDown(self):
+ try:
+ self.driver.quit()
+ finally:
+ self.driver = None
+ self.wait = None
+
+ def testToolbarTitleDisplays(self):
+ self.wait.until(
+ expected_conditions.text_to_be_present_in_element(
+ (by.By.CLASS_NAME, "toolbar-title"), "TensorBoard"
+ )
+ )
+
+ def testLogdirDisplays(self):
+ self.wait.until(
+ expected_conditions.text_to_be_present_in_element(
+ (by.By.ID, "data_location"), self.logdir
+ )
+ )
+
class DashboardsTest(BasicTest):
- """Tests basic behavior when there is some data in TensorBoard.
-
- This extends `BasicTest`, so it inherits its methods to test that the
- basic chrome is displayed. We also check that we can navigate around
- the various dashboards.
- """
-
- @classmethod
- def setUpData(cls):
- scalars_demo.run_all(cls.logdir, verbose=False)
- audio_demo.run_all(cls.logdir, verbose=False)
-
- def testLogdirDisplays(self):
- # TensorBoard doesn't have logdir display when there is data
- pass
-
- def testDashboardSelection(self):
- """Test that we can navigate among the different dashboards."""
- selectors = {
- "scalars_tab": "paper-tab[data-dashboard=scalars]",
- "audio_tab": "paper-tab[data-dashboard=audio]",
- "graphs_tab": "paper-tab[data-dashboard=graphs]",
- "inactive_dropdown": "paper-dropdown-menu[label*=Inactive]",
- "images_menu_item": "paper-item[data-dashboard=images]",
- "reload_button": "paper-icon-button#reload-button",
- }
- elements = {}
- for (name, selector) in selectors.items():
- locator = (by.By.CSS_SELECTOR, selector)
- self.wait.until(expected_conditions.presence_of_element_located(locator))
- elements[name] = self.driver.find_element_by_css_selector(selector)
-
- # The implementation of paper-* components doesn't seem to play nice
- # with Selenium's `element.is_selected()` and `element.is_enabled()`
- # methods. Instead, we check the appropriate WAI-ARIA attributes.
- # (Also, though the docs for `get_attribute` say that the string
- # `"false"` is returned as `False`, this appears to be the case
- # _sometimes_ but not _always_, so we should take special care to
- # handle that.)
- def is_selected(element):
- attribute = element.get_attribute("aria-selected")
- return attribute and attribute != "false"
-
- def is_enabled(element):
- attribute = element.get_attribute("aria-disabled")
- is_disabled = attribute and attribute != "false"
- return not is_disabled
-
- def assert_selected_dashboard(polymer_component_name):
- expected = {polymer_component_name}
- actual = {
- container.find_element_by_css_selector("*").tag_name # first child
- for container
- in self.driver.find_elements_by_css_selector(".dashboard-container")
- if container.is_displayed()
- }
- self.assertEqual(expected, actual)
-
- # The scalar and audio dashboards should be active, and the scalar
- # dashboard should be selected by default. The images menu item
- # should not be visible, as it's within the drop-down menu.
- self.assertTrue(elements["scalars_tab"].is_displayed())
- self.assertTrue(elements["audio_tab"].is_displayed())
- self.assertTrue(elements["graphs_tab"].is_displayed())
- self.assertTrue(is_selected(elements["scalars_tab"]))
- self.assertFalse(is_selected(elements["audio_tab"]))
- self.assertFalse(elements["images_menu_item"].is_displayed())
- self.assertFalse(is_selected(elements["images_menu_item"]))
- assert_selected_dashboard("tf-scalar-dashboard")
-
- # While we're on the scalar dashboard, we should be allowed to
- # reload the data.
- self.assertTrue(is_enabled(elements["reload_button"]))
-
- # We should be able to activate the audio dashboard.
- elements["audio_tab"].click()
- self.assertFalse(is_selected(elements["scalars_tab"]))
- self.assertTrue(is_selected(elements["audio_tab"]))
- self.assertFalse(is_selected(elements["graphs_tab"]))
- self.assertFalse(is_selected(elements["images_menu_item"]))
- assert_selected_dashboard("tf-audio-dashboard")
- self.assertTrue(is_enabled(elements["reload_button"]))
-
- # We should then be able to open the dropdown and navigate to the
- # image dashboard. (We have to wait until it's visible because of the
- # dropdown menu's animations.)
- elements["inactive_dropdown"].click()
- self.wait.until(
- expected_conditions.visibility_of(elements["images_menu_item"]))
- self.assertTrue(elements["images_menu_item"].is_displayed())
- elements["images_menu_item"].click()
- self.assertFalse(is_selected(elements["scalars_tab"]))
- self.assertFalse(is_selected(elements["audio_tab"]))
- self.assertFalse(is_selected(elements["graphs_tab"]))
- self.assertTrue(is_selected(elements["images_menu_item"]))
- assert_selected_dashboard("tf-image-dashboard")
- self.assertTrue(is_enabled(elements["reload_button"]))
-
- # Next, we should be able to navigate back to an active dashboard.
- # If we choose the graphs dashboard, the reload feature should be
- # disabled.
- elements["graphs_tab"].click()
- self.assertFalse(elements["images_menu_item"].is_displayed())
- self.assertFalse(is_selected(elements["scalars_tab"]))
- self.assertFalse(is_selected(elements["audio_tab"]))
- self.assertTrue(is_selected(elements["graphs_tab"]))
- self.assertFalse(is_selected(elements["images_menu_item"]))
- assert_selected_dashboard("tf-graph-dashboard")
- self.assertFalse(is_enabled(elements["reload_button"]))
-
- # Finally, we should be able to navigate back to the scalar dashboard.
- elements["scalars_tab"].click()
- self.assertTrue(is_selected(elements["scalars_tab"]))
- self.assertFalse(is_selected(elements["audio_tab"]))
- self.assertFalse(is_selected(elements["graphs_tab"]))
- self.assertFalse(is_selected(elements["images_menu_item"]))
- assert_selected_dashboard("tf-scalar-dashboard")
- self.assertTrue(is_enabled(elements["reload_button"]))
+ """Tests basic behavior when there is some data in TensorBoard.
+
+ This extends `BasicTest`, so it inherits its methods to test that
+ the basic chrome is displayed. We also check that we can navigate
+ around the various dashboards.
+ """
+
+ @classmethod
+ def setUpData(cls):
+ scalars_demo.run_all(cls.logdir, verbose=False)
+ audio_demo.run_all(cls.logdir, verbose=False)
+
+ def testLogdirDisplays(self):
+ # TensorBoard doesn't have logdir display when there is data
+ pass
+
+ def testDashboardSelection(self):
+ """Test that we can navigate among the different dashboards."""
+ selectors = {
+ "scalars_tab": "paper-tab[data-dashboard=scalars]",
+ "audio_tab": "paper-tab[data-dashboard=audio]",
+ "graphs_tab": "paper-tab[data-dashboard=graphs]",
+ "inactive_dropdown": "paper-dropdown-menu[label*=Inactive]",
+ "images_menu_item": "paper-item[data-dashboard=images]",
+ "reload_button": "paper-icon-button#reload-button",
+ }
+ elements = {}
+ for (name, selector) in selectors.items():
+ locator = (by.By.CSS_SELECTOR, selector)
+ self.wait.until(
+ expected_conditions.presence_of_element_located(locator)
+ )
+ elements[name] = self.driver.find_element_by_css_selector(selector)
+
+ # The implementation of paper-* components doesn't seem to play nice
+ # with Selenium's `element.is_selected()` and `element.is_enabled()`
+ # methods. Instead, we check the appropriate WAI-ARIA attributes.
+ # (Also, though the docs for `get_attribute` say that the string
+ # `"false"` is returned as `False`, this appears to be the case
+ # _sometimes_ but not _always_, so we should take special care to
+ # handle that.)
+ def is_selected(element):
+ attribute = element.get_attribute("aria-selected")
+ return attribute and attribute != "false"
+
+ def is_enabled(element):
+ attribute = element.get_attribute("aria-disabled")
+ is_disabled = attribute and attribute != "false"
+ return not is_disabled
+
+ def assert_selected_dashboard(polymer_component_name):
+ expected = {polymer_component_name}
+ actual = {
+ container.find_element_by_css_selector(
+ "*"
+ ).tag_name # first child
+ for container in self.driver.find_elements_by_css_selector(
+ ".dashboard-container"
+ )
+ if container.is_displayed()
+ }
+ self.assertEqual(expected, actual)
+
+ # The scalar and audio dashboards should be active, and the scalar
+ # dashboard should be selected by default. The images menu item
+ # should not be visible, as it's within the drop-down menu.
+ self.assertTrue(elements["scalars_tab"].is_displayed())
+ self.assertTrue(elements["audio_tab"].is_displayed())
+ self.assertTrue(elements["graphs_tab"].is_displayed())
+ self.assertTrue(is_selected(elements["scalars_tab"]))
+ self.assertFalse(is_selected(elements["audio_tab"]))
+ self.assertFalse(elements["images_menu_item"].is_displayed())
+ self.assertFalse(is_selected(elements["images_menu_item"]))
+ assert_selected_dashboard("tf-scalar-dashboard")
+
+ # While we're on the scalar dashboard, we should be allowed to
+ # reload the data.
+ self.assertTrue(is_enabled(elements["reload_button"]))
+
+ # We should be able to activate the audio dashboard.
+ elements["audio_tab"].click()
+ self.assertFalse(is_selected(elements["scalars_tab"]))
+ self.assertTrue(is_selected(elements["audio_tab"]))
+ self.assertFalse(is_selected(elements["graphs_tab"]))
+ self.assertFalse(is_selected(elements["images_menu_item"]))
+ assert_selected_dashboard("tf-audio-dashboard")
+ self.assertTrue(is_enabled(elements["reload_button"]))
+
+ # We should then be able to open the dropdown and navigate to the
+ # image dashboard. (We have to wait until it's visible because of the
+ # dropdown menu's animations.)
+ elements["inactive_dropdown"].click()
+ self.wait.until(
+ expected_conditions.visibility_of(elements["images_menu_item"])
+ )
+ self.assertTrue(elements["images_menu_item"].is_displayed())
+ elements["images_menu_item"].click()
+ self.assertFalse(is_selected(elements["scalars_tab"]))
+ self.assertFalse(is_selected(elements["audio_tab"]))
+ self.assertFalse(is_selected(elements["graphs_tab"]))
+ self.assertTrue(is_selected(elements["images_menu_item"]))
+ assert_selected_dashboard("tf-image-dashboard")
+ self.assertTrue(is_enabled(elements["reload_button"]))
+
+ # Next, we should be able to navigate back to an active dashboard.
+ # If we choose the graphs dashboard, the reload feature should be
+ # disabled.
+ elements["graphs_tab"].click()
+ self.assertFalse(elements["images_menu_item"].is_displayed())
+ self.assertFalse(is_selected(elements["scalars_tab"]))
+ self.assertFalse(is_selected(elements["audio_tab"]))
+ self.assertTrue(is_selected(elements["graphs_tab"]))
+ self.assertFalse(is_selected(elements["images_menu_item"]))
+ assert_selected_dashboard("tf-graph-dashboard")
+ self.assertFalse(is_enabled(elements["reload_button"]))
+
+ # Finally, we should be able to navigate back to the scalar dashboard.
+ elements["scalars_tab"].click()
+ self.assertTrue(is_selected(elements["scalars_tab"]))
+ self.assertFalse(is_selected(elements["audio_tab"]))
+ self.assertFalse(is_selected(elements["graphs_tab"]))
+ self.assertFalse(is_selected(elements["images_menu_item"]))
+ assert_selected_dashboard("tf-scalar-dashboard")
+ self.assertTrue(is_enabled(elements["reload_button"]))
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
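
The dashboard-selection test above reads WAI-ARIA attributes rather than relying on Selenium's `is_selected()`/`is_enabled()`, because the paper-* elements do not expose their state reliably through those methods. A minimal standalone sketch of that pattern, with a placeholder URL and selector that are not part of this change:

    # Sketch only: reading component state from WAI-ARIA attributes with
    # Selenium (Selenium 3-style API, matching the test driver above).
    from selenium import webdriver

    def aria_selected(element):
        # get_attribute() may return None, "true", or the literal string
        # "false"; treat anything that is not a truthy non-"false" string
        # as "not selected".
        value = element.get_attribute("aria-selected")
        return bool(value) and value != "false"

    driver = webdriver.Chrome()
    try:
        driver.get("http://localhost:6006/")  # placeholder URL
        tab = driver.find_element_by_css_selector(
            "paper-tab[data-dashboard=scalars]"  # placeholder selector
        )
        print("scalars tab selected:", aria_selected(tab))
    finally:
        driver.quit()
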
diff --git a/tensorboard/functionaltests/wct_test_driver.py b/tensorboard/functionaltests/wct_test_driver.py
index d0e09eb9a7..9fbecd806b 100644
--- a/tensorboard/functionaltests/wct_test_driver.py
+++ b/tensorboard/functionaltests/wct_test_driver.py
@@ -33,114 +33,124 @@
# As emitted in the "Listening on:" line of the WebfilesServer output.
# We extract only the port because the hostname can reroute through corp
# DNS and force auth, which fails in tests.
-_URL_RE = re.compile(br'http://[^:]*:([0-9]+)/')
+_URL_RE = re.compile(br"http://[^:]*:([0-9]+)/")
_SUITE_PASSED_RE = re.compile(r'.*test suite passed"$')
-_SUITE_FAILED_RE = re.compile(r'.*failing test.*')
+_SUITE_FAILED_RE = re.compile(r".*failing test.*")
+
def create_test_class(binary_path, web_path):
- """Create a unittest.TestCase class to run WebComponentTester tests.
-
- Arguments:
- binary_path: relative path to a `tf_web_library` target;
- e.g.: "tensorboard/components/vz_foo/test/test_web_library"
- web_path: absolute web path to the tests page in the above web
- library; e.g.: "/vz-foo/test/tests.html"
-
- Result:
- A new subclass of `unittest.TestCase`. Bind this to a variable in
- the test file's main module.
- """
-
- class BrowserLogIndicatesResult(object):
- def __init__(self):
- self.passed = False
- self.log = []
-
- def __call__(self, driver):
- # Scan through the log entries and search for a line indicating whether
- # the test passed or failed. The method 'driver.get_log' also seems to
- # clear the log so we aggregate it in self.log for printing later on.
- new_log = driver.get_log("browser")
- new_messages = [entry["message"] for entry in new_log]
- self.log = self.log + new_log
- if self._log_matches(new_messages, _SUITE_FAILED_RE):
- self.passed = False
- return True
- if self._log_matches(new_messages, _SUITE_PASSED_RE):
- self.passed = True
- return True
- # Here, we still don't know.
- return False
-
- def _log_matches(self, messages, regexp):
- for message in messages:
- if regexp.match(message):
- return True
- return False
-
- class WebComponentTesterTest(unittest.TestCase):
- """Tests that a family of unit tests completes successfully."""
-
- def setUp(cls):
- src_dir = os.environ["TEST_SRCDIR"]
- binary = os.path.join(
- src_dir,
- "org_tensorflow_tensorboard/" + binary_path)
- cls.process = subprocess.Popen(
- [binary], stdin=None, stdout=None, stderr=subprocess.PIPE)
-
- lines = []
- hit_eof = False
- while True:
- line = cls.process.stderr.readline()
- if line == b"":
- # b"" means reached EOF; b"\n" means empty line.
- hit_eof = True
- break
- lines.append(line)
- if b"Listening on:" in line:
- match = _URL_RE.search(line)
- if match:
- cls.port = int(match.group(1))
- break
- else:
- raise ValueError("Failed to parse listening-on line: %r" % line)
- if len(lines) >= 1024:
- # Sanity check---something is wrong. Let us fail fast rather
- # than spending the 15-minute test timeout consuming
- # potentially empty logs.
- hit_eof = True
- break
- if hit_eof:
- full_output = "\n".join(repr(line) for line in lines)
- raise ValueError(
- "Did not find listening-on line in output:\n%s" % full_output)
-
- def tearDown(cls):
- cls.process.kill()
- cls.process.wait()
-
- def test(self):
- driver = webtest.new_webdriver_session(
- capabilities={"loggingPrefs": {"browser": "ALL"}}
- )
- url = "http://localhost:%s%s" % (self.port, web_path)
- driver.get(url)
- browser_log_indicates_result = BrowserLogIndicatesResult()
- try:
- wait.WebDriverWait(driver, 10).until(browser_log_indicates_result)
- if not browser_log_indicates_result.passed:
- self.fail()
- finally:
- # Print log as an aid for debugging.
- log = browser_log_indicates_result.log + driver.get_log("browser")
- self._print_log(log)
-
- def _print_log(self, entries):
- print("Browser log follows:")
- print("--------------------")
- print(" | ".join(entries[0].keys()))
- for entry in entries:
- print(" | ".join(str(v) for v in entry.values()))
-
- return WebComponentTesterTest
+ """Create a unittest.TestCase class to run WebComponentTester tests.
+
+ Arguments:
+ binary_path: relative path to a `tf_web_library` target;
+ e.g.: "tensorboard/components/vz_foo/test/test_web_library"
+ web_path: absolute web path to the tests page in the above web
+ library; e.g.: "/vz-foo/test/tests.html"
+
+ Result:
+ A new subclass of `unittest.TestCase`. Bind this to a variable in
+ the test file's main module.
+ """
+
+ class BrowserLogIndicatesResult(object):
+ def __init__(self):
+ self.passed = False
+ self.log = []
+
+ def __call__(self, driver):
+ # Scan through the log entries and search for a line indicating whether
+ # the test passed or failed. The method 'driver.get_log' also seems to
+ # clear the log so we aggregate it in self.log for printing later on.
+ new_log = driver.get_log("browser")
+ new_messages = [entry["message"] for entry in new_log]
+ self.log = self.log + new_log
+ if self._log_matches(new_messages, _SUITE_FAILED_RE):
+ self.passed = False
+ return True
+ if self._log_matches(new_messages, _SUITE_PASSED_RE):
+ self.passed = True
+ return True
+ # Here, we still don't know.
+ return False
+
+ def _log_matches(self, messages, regexp):
+ for message in messages:
+ if regexp.match(message):
+ return True
+ return False
+
+ class WebComponentTesterTest(unittest.TestCase):
+ """Tests that a family of unit tests completes successfully."""
+
+ def setUp(cls):
+ src_dir = os.environ["TEST_SRCDIR"]
+ binary = os.path.join(
+ src_dir, "org_tensorflow_tensorboard/" + binary_path
+ )
+ cls.process = subprocess.Popen(
+ [binary], stdin=None, stdout=None, stderr=subprocess.PIPE
+ )
+
+ lines = []
+ hit_eof = False
+ while True:
+ line = cls.process.stderr.readline()
+ if line == b"":
+ # b"" means reached EOF; b"\n" means empty line.
+ hit_eof = True
+ break
+ lines.append(line)
+ if b"Listening on:" in line:
+ match = _URL_RE.search(line)
+ if match:
+ cls.port = int(match.group(1))
+ break
+ else:
+ raise ValueError(
+ "Failed to parse listening-on line: %r" % line
+ )
+ if len(lines) >= 1024:
+ # Sanity check---something is wrong. Let us fail fast rather
+ # than spending the 15-minute test timeout consuming
+ # potentially empty logs.
+ hit_eof = True
+ break
+ if hit_eof:
+ full_output = "\n".join(repr(line) for line in lines)
+ raise ValueError(
+ "Did not find listening-on line in output:\n%s"
+ % full_output
+ )
+
+ def tearDown(cls):
+ cls.process.kill()
+ cls.process.wait()
+
+ def test(self):
+ driver = webtest.new_webdriver_session(
+ capabilities={"loggingPrefs": {"browser": "ALL"}}
+ )
+ url = "http://localhost:%s%s" % (self.port, web_path)
+ driver.get(url)
+ browser_log_indicates_result = BrowserLogIndicatesResult()
+ try:
+ wait.WebDriverWait(driver, 10).until(
+ browser_log_indicates_result
+ )
+ if not browser_log_indicates_result.passed:
+ self.fail()
+ finally:
+ # Print log as an aid for debugging.
+ log = browser_log_indicates_result.log + driver.get_log(
+ "browser"
+ )
+ self._print_log(log)
+
+ def _print_log(self, entries):
+ print("Browser log follows:")
+ print("--------------------")
+ print(" | ".join(entries[0].keys()))
+ for entry in entries:
+ print(" | ".join(str(v) for v in entry.values()))
+
+ return WebComponentTesterTest
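
The driver above blocks on a custom callable passed to `WebDriverWait`, scanning the browser console until a pass/fail line shows up. A condensed sketch of the same polling pattern outside Bazel; the URL and the pass/fail regular expressions are placeholders, not the ones used by WebComponentTester:

    # Sketch only: a callable WebDriverWait condition that polls browser logs.
    import re

    from selenium import webdriver
    from selenium.webdriver.support import wait

    PASSED_RE = re.compile(r"test suite passed")  # placeholder pattern
    FAILED_RE = re.compile(r"failing test")  # placeholder pattern

    class LogIndicatesResult(object):
        """Returns True once the console log decides pass or fail."""

        def __init__(self):
            self.passed = False

        def __call__(self, driver):
            # get_log() drains the buffer, so each poll sees only new entries.
            messages = [e["message"] for e in driver.get_log("browser")]
            if any(FAILED_RE.search(m) for m in messages):
                self.passed = False
                return True
            if any(PASSED_RE.search(m) for m in messages):
                self.passed = True
                return True
            return False  # undecided; keep polling

    driver = webdriver.Chrome(
        desired_capabilities={"loggingPrefs": {"browser": "ALL"}}
    )
    try:
        driver.get("http://localhost:8080/tests.html")  # placeholder URL
        condition = LogIndicatesResult()
        wait.WebDriverWait(driver, 10).until(condition)
        print("passed" if condition.passed else "failed")
    finally:
        driver.quit()
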
diff --git a/tensorboard/lazy.py b/tensorboard/lazy.py
index a891903965..65d51e489b 100644
--- a/tensorboard/lazy.py
+++ b/tensorboard/lazy.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""TensorBoard is a webapp for understanding TensorFlow runs and graphs.
-"""
+"""TensorBoard is a webapp for understanding TensorFlow runs and graphs."""
from __future__ import absolute_import
from __future__ import division
@@ -25,69 +24,80 @@
def lazy_load(name):
- """Decorator to define a function that lazily loads the module 'name'.
-
- This can be used to defer importing troublesome dependencies - e.g. ones that
- are large and infrequently used, or that cause a dependency cycle -
- until they are actually used.
-
- Args:
- name: the fully-qualified name of the module; typically the last segment
- of 'name' matches the name of the decorated function
-
- Returns:
- Decorator function that produces a lazy-loading module 'name' backed by the
- underlying decorated function.
- """
- def wrapper(load_fn):
- # Wrap load_fn to call it exactly once and update __dict__ afterwards to
- # make future lookups efficient (only failed lookups call __getattr__).
- @_memoize
- def load_once(self):
- if load_once.loading:
- raise ImportError("Circular import when resolving LazyModule %r" % name)
- load_once.loading = True
- try:
- module = load_fn()
- finally:
+ """Decorator to define a function that lazily loads the module 'name'.
+
+ This can be used to defer importing troublesome dependencies - e.g. ones that
+ are large and infrequently used, or that cause a dependency cycle -
+ until they are actually used.
+
+ Args:
+ name: the fully-qualified name of the module; typically the last segment
+ of 'name' matches the name of the decorated function
+
+ Returns:
+ Decorator function that produces a lazy-loading module 'name' backed by the
+ underlying decorated function.
+ """
+
+ def wrapper(load_fn):
+ # Wrap load_fn to call it exactly once and update __dict__ afterwards to
+ # make future lookups efficient (only failed lookups call __getattr__).
+ @_memoize
+ def load_once(self):
+ if load_once.loading:
+ raise ImportError(
+ "Circular import when resolving LazyModule %r" % name
+ )
+ load_once.loading = True
+ try:
+ module = load_fn()
+ finally:
+ load_once.loading = False
+ self.__dict__.update(module.__dict__)
+ load_once.loaded = True
+ return module
+
load_once.loading = False
- self.__dict__.update(module.__dict__)
- load_once.loaded = True
- return module
- load_once.loading = False
- load_once.loaded = False
+ load_once.loaded = False
+
+ # Define a module that proxies getattr() and dir() to the result of calling
+ # load_once() the first time it's needed. The class is nested so we can close
+ # over load_once() and avoid polluting the module's attrs with our own state.
+ class LazyModule(types.ModuleType):
+ def __getattr__(self, attr_name):
+ return getattr(load_once(self), attr_name)
- # Define a module that proxies getattr() and dir() to the result of calling
- # load_once() the first time it's needed. The class is nested so we can close
- # over load_once() and avoid polluting the module's attrs with our own state.
- class LazyModule(types.ModuleType):
- def __getattr__(self, attr_name):
- return getattr(load_once(self), attr_name)
+ def __dir__(self):
+ return dir(load_once(self))
- def __dir__(self):
- return dir(load_once(self))
+ def __repr__(self):
+ if load_once.loaded:
+ return "<%r via LazyModule (loaded)>" % load_once(self)
+ return (
+                    "<module %r via LazyModule (not yet loaded)>"
+ % self.__name__
+ )
- def __repr__(self):
- if load_once.loaded:
- return '<%r via LazyModule (loaded)>' % load_once(self)
-        return '<module %r via LazyModule (not yet loaded)>' % self.__name__
+ return LazyModule(name)
- return LazyModule(name)
- return wrapper
+ return wrapper
def _memoize(f):
- """Memoizing decorator for f, which must have exactly 1 hashable argument."""
- nothing = object() # Unique "no value" sentinel object.
- cache = {}
- # Use a reentrant lock so that if f references the resulting wrapper we die
- # with recursion depth exceeded instead of deadlocking.
- lock = threading.RLock()
- @functools.wraps(f)
- def wrapper(arg):
- if cache.get(arg, nothing) is nothing:
- with lock:
+ """Memoizing decorator for f, which must have exactly 1 hashable
+ argument."""
+ nothing = object() # Unique "no value" sentinel object.
+ cache = {}
+ # Use a reentrant lock so that if f references the resulting wrapper we die
+ # with recursion depth exceeded instead of deadlocking.
+ lock = threading.RLock()
+
+ @functools.wraps(f)
+ def wrapper(arg):
if cache.get(arg, nothing) is nothing:
- cache[arg] = f(arg)
- return cache[arg]
- return wrapper
+ with lock:
+ if cache.get(arg, nothing) is nothing:
+ cache[arg] = f(arg)
+ return cache[arg]
+
+ return wrapper
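
For context on what the reformatted `lazy.py` provides: a function decorated with `lazy.lazy_load` is replaced by a module object whose first attribute access runs the decorated body, after which lookups hit the module's `__dict__` directly. A small usage sketch; the module name passed to the decorator is illustrative only:

    from tensorboard import lazy

    @lazy.lazy_load("tensorboard.demo_json")  # illustrative name
    def demo_json():
        # Not executed until the first attribute access on `demo_json`.
        import json

        return json

    print(repr(demo_json))      # "<module ... (not yet loaded)>"
    print(demo_json.dumps({}))  # triggers the import, then proxies to json
    print(repr(demo_json))      # "<... via LazyModule (loaded)>"
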
diff --git a/tensorboard/lazy_test.py b/tensorboard/lazy_test.py
index 8ff3f7a204..9732b9df68 100644
--- a/tensorboard/lazy_test.py
+++ b/tensorboard/lazy_test.py
@@ -25,85 +25,95 @@
class LazyTest(unittest.TestCase):
-
- def test_self_composition(self):
- """A lazy module should be able to load another lazy module."""
- # This test can fail if the `LazyModule` implementation stores the
- # cached module as a field on the module itself rather than a
- # closure value. (See pull request review comments on #1781 for
- # details.)
-
- @lazy.lazy_load("inner")
- def inner():
- import collections
- return collections
-
- @lazy.lazy_load("outer")
- def outer():
- return inner
-
- x1 = outer.namedtuple
- x2 = inner.namedtuple
- self.assertEqual(x1, x2)
-
- def test_lazy_cycle(self):
- """A cycle among lazy modules should error, not deadlock or spin."""
- # This test can fail if `_memoize` uses a non-reentrant lock. (See
- # pull request review comments on #1781 for details.)
-
- @lazy.lazy_load("inner")
- def inner():
- return outer.foo
-
- @lazy.lazy_load("outer")
- def outer():
- return inner
-
- expected_message = "Circular import when resolving LazyModule 'inner'"
- with six.assertRaisesRegex(self, ImportError, expected_message):
- outer.bar
-
- def test_repr_before_load(self):
- @lazy.lazy_load("foo")
- def foo():
- self.fail("Should not need to resolve this module.")
-    self.assertEquals(repr(foo), "<module 'foo' via LazyModule (not yet loaded)>")
-
- def test_repr_after_load(self):
- import collections
- @lazy.lazy_load("foo")
- def foo():
- return collections
- foo.namedtuple
- self.assertEquals(repr(foo), "<%r via LazyModule (loaded)>" % collections)
-
- def test_failed_load_idempotent(self):
- expected_message = "you will never stop me"
- @lazy.lazy_load("bad")
- def bad():
- raise ValueError(expected_message)
- with six.assertRaisesRegex(self, ValueError, expected_message):
- bad.day
- with six.assertRaisesRegex(self, ValueError, expected_message):
- bad.day
-
- def test_loads_only_once_even_when_result_equal_to_everything(self):
- # This would fail if the implementation of `_memoize` used `==`
- # rather than `is` to check for the sentinel value.
- class EqualToEverything(object):
- def __eq__(self, other):
- return True
-
- count_box = [0]
- @lazy.lazy_load("foo")
- def foo():
- count_box[0] += 1
- return EqualToEverything()
-
- dir(foo)
- dir(foo)
- self.assertEqual(count_box[0], 1)
-
-
-if __name__ == '__main__':
- unittest.main()
+ def test_self_composition(self):
+ """A lazy module should be able to load another lazy module."""
+ # This test can fail if the `LazyModule` implementation stores the
+ # cached module as a field on the module itself rather than a
+ # closure value. (See pull request review comments on #1781 for
+ # details.)
+
+ @lazy.lazy_load("inner")
+ def inner():
+ import collections
+
+ return collections
+
+ @lazy.lazy_load("outer")
+ def outer():
+ return inner
+
+ x1 = outer.namedtuple
+ x2 = inner.namedtuple
+ self.assertEqual(x1, x2)
+
+ def test_lazy_cycle(self):
+ """A cycle among lazy modules should error, not deadlock or spin."""
+ # This test can fail if `_memoize` uses a non-reentrant lock. (See
+ # pull request review comments on #1781 for details.)
+
+ @lazy.lazy_load("inner")
+ def inner():
+ return outer.foo
+
+ @lazy.lazy_load("outer")
+ def outer():
+ return inner
+
+ expected_message = "Circular import when resolving LazyModule 'inner'"
+ with six.assertRaisesRegex(self, ImportError, expected_message):
+ outer.bar
+
+ def test_repr_before_load(self):
+ @lazy.lazy_load("foo")
+ def foo():
+ self.fail("Should not need to resolve this module.")
+
+ self.assertEquals(
+            repr(foo), "<module 'foo' via LazyModule (not yet loaded)>"
+ )
+
+ def test_repr_after_load(self):
+ import collections
+
+ @lazy.lazy_load("foo")
+ def foo():
+ return collections
+
+ foo.namedtuple
+ self.assertEquals(
+ repr(foo), "<%r via LazyModule (loaded)>" % collections
+ )
+
+ def test_failed_load_idempotent(self):
+ expected_message = "you will never stop me"
+
+ @lazy.lazy_load("bad")
+ def bad():
+ raise ValueError(expected_message)
+
+ with six.assertRaisesRegex(self, ValueError, expected_message):
+ bad.day
+ with six.assertRaisesRegex(self, ValueError, expected_message):
+ bad.day
+
+ def test_loads_only_once_even_when_result_equal_to_everything(self):
+ # This would fail if the implementation of `_memoize` used `==`
+ # rather than `is` to check for the sentinel value.
+ class EqualToEverything(object):
+ def __eq__(self, other):
+ return True
+
+ count_box = [0]
+
+ @lazy.lazy_load("foo")
+ def foo():
+ count_box[0] += 1
+ return EqualToEverything()
+
+ dir(foo)
+ dir(foo)
+ self.assertEqual(count_box[0], 1)
+
+
+if __name__ == "__main__":
+ unittest.main()
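
The tests above lean on two properties of `_memoize` in `lazy.py`: the cache sentinel is compared with `is` (so a result that compares equal to everything still loads only once), and the lock is reentrant (so a cyclic load surfaces as an error rather than deadlocking). A generic sketch of that double-checked-locking memoizer, independent of TensorBoard:

    import functools
    import threading

    def memoize_one_arg(f):
        """Memoize a function of one hashable argument."""
        nothing = object()  # unique sentinel; compared with `is`, never `==`
        cache = {}
        lock = threading.RLock()  # reentrant: recursion errors, not deadlock

        @functools.wraps(f)
        def wrapper(arg):
            if cache.get(arg, nothing) is nothing:  # fast path, no lock
                with lock:
                    if cache.get(arg, nothing) is nothing:  # re-check held
                        cache[arg] = f(arg)
            return cache[arg]

        return wrapper

    @memoize_one_arg
    def square(x):
        print("computing", x)
        return x * x

    square(3)  # prints "computing 3"
    square(3)  # cached; nothing printed
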
diff --git a/tensorboard/lib_test.py b/tensorboard/lib_test.py
index aab06b023e..b7d890fc16 100644
--- a/tensorboard/lib_test.py
+++ b/tensorboard/lib_test.py
@@ -22,22 +22,22 @@
class ReloadTensorBoardTest(unittest.TestCase):
+ def test_functional_after_reload(self):
+ self.assertNotIn("tensorboard", sys.modules)
+ import tensorboard as tensorboard # it makes the Google sync happy
- def test_functional_after_reload(self):
- self.assertNotIn("tensorboard", sys.modules)
- import tensorboard as tensorboard # it makes the Google sync happy
- submodules = ["notebook", "program", "summary"]
- dirs_before = {
- module_name: dir(getattr(tensorboard, module_name))
- for module_name in submodules
- }
- tensorboard = moves.reload_module(tensorboard)
- dirs_after = {
- module_name: dir(getattr(tensorboard, module_name))
- for module_name in submodules
- }
- self.assertEqual(dirs_before, dirs_after)
+ submodules = ["notebook", "program", "summary"]
+ dirs_before = {
+ module_name: dir(getattr(tensorboard, module_name))
+ for module_name in submodules
+ }
+ tensorboard = moves.reload_module(tensorboard)
+ dirs_after = {
+ module_name: dir(getattr(tensorboard, module_name))
+ for module_name in submodules
+ }
+ self.assertEqual(dirs_before, dirs_after)
-if __name__ == '__main__':
- unittest.main()
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tensorboard/main.py b/tensorboard/main.py
index 2495f335dc..0976831220 100644
--- a/tensorboard/main.py
+++ b/tensorboard/main.py
@@ -32,7 +32,7 @@
# pattern of reads used by TensorBoard for logdirs. See for details:
# https://github.com/tensorflow/tensorboard/issues/1225
# This must be set before the first import of tensorflow.
-os.environ['GCS_READ_CACHE_DISABLED'] = '1'
+os.environ["GCS_READ_CACHE_DISABLED"] = "1"
import sys
@@ -47,33 +47,39 @@
logger = tb_logging.get_logger()
+
def run_main():
- """Initializes flags and calls main()."""
- program.setup_environment()
-
- if getattr(tf, '__version__', 'stub') == 'stub':
- print("TensorFlow installation not found - running with reduced feature set.",
- file=sys.stderr)
-
- tensorboard = program.TensorBoard(
- default.get_plugins() + default.get_dynamic_plugins(),
- program.get_default_assets_zip_provider(),
- subcommands=[uploader_main.UploaderSubcommand()])
- try:
- from absl import app
- # Import this to check that app.run() will accept the flags_parser argument.
- from absl.flags import argparse_flags
- app.run(tensorboard.main, flags_parser=tensorboard.configure)
- raise AssertionError("absl.app.run() shouldn't return")
- except ImportError:
- pass
- except base_plugin.FlagsError as e:
- print("Error: %s" % e, file=sys.stderr)
- sys.exit(1)
-
- tensorboard.configure(sys.argv)
- sys.exit(tensorboard.main())
-
-
-if __name__ == '__main__':
- run_main()
+ """Initializes flags and calls main()."""
+ program.setup_environment()
+
+ if getattr(tf, "__version__", "stub") == "stub":
+ print(
+ "TensorFlow installation not found - running with reduced feature set.",
+ file=sys.stderr,
+ )
+
+ tensorboard = program.TensorBoard(
+ default.get_plugins() + default.get_dynamic_plugins(),
+ program.get_default_assets_zip_provider(),
+ subcommands=[uploader_main.UploaderSubcommand()],
+ )
+ try:
+ from absl import app
+
+ # Import this to check that app.run() will accept the flags_parser argument.
+ from absl.flags import argparse_flags
+
+ app.run(tensorboard.main, flags_parser=tensorboard.configure)
+ raise AssertionError("absl.app.run() shouldn't return")
+ except ImportError:
+ pass
+ except base_plugin.FlagsError as e:
+ print("Error: %s" % e, file=sys.stderr)
+ sys.exit(1)
+
+ tensorboard.configure(sys.argv)
+ sys.exit(tensorboard.main())
+
+
+if __name__ == "__main__":
+ run_main()
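
`run_main` prefers `absl.app.run` with a custom `flags_parser` and only falls back to calling `configure()`/`main()` directly when absl is unavailable. A stripped-down sketch of that dispatch, with hypothetical `parse_flags` and `serve` functions standing in for TensorBoard's own entry points:

    import sys

    def parse_flags(argv):
        # Hypothetical stand-in for TensorBoard's `configure`.
        return argv[1:]

    def serve(parsed):
        # Hypothetical stand-in for TensorBoard's `main`.
        print("serving with flags:", parsed)
        return 0

    def run():
        try:
            from absl import app

            # app.run() calls sys.exit() itself and never returns normally.
            app.run(serve, flags_parser=parse_flags)
            raise AssertionError("absl.app.run() shouldn't return")
        except ImportError:
            pass  # absl not installed; use the plain path below
        sys.exit(serve(parse_flags(sys.argv)))
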
diff --git a/tensorboard/manager.py b/tensorboard/manager.py
index f6495bdc00..589c227fb7 100644
--- a/tensorboard/manager.py
+++ b/tensorboard/manager.py
@@ -41,12 +41,7 @@
# https://github.com/tensorflow/tensorboard/issues/2017.
_FieldType = collections.namedtuple(
"_FieldType",
- (
- "serialized_type",
- "runtime_type",
- "serialize",
- "deserialize",
- ),
+ ("serialized_type", "runtime_type", "serialize", "deserialize",),
)
_type_int = _FieldType(
serialized_type=int,
@@ -62,270 +57,268 @@
)
# Information about a running TensorBoard instance.
-_TENSORBOARD_INFO_FIELDS = collections.OrderedDict((
- ("version", _type_str),
- ("start_time", _type_int), # seconds since epoch
- ("pid", _type_int),
- ("port", _type_int),
- ("path_prefix", _type_str), # may be empty
- ("logdir", _type_str), # may be empty
- ("db", _type_str), # may be empty
- ("cache_key", _type_str), # opaque, as given by `cache_key` below
-))
+_TENSORBOARD_INFO_FIELDS = collections.OrderedDict(
+ (
+ ("version", _type_str),
+ ("start_time", _type_int), # seconds since epoch
+ ("pid", _type_int),
+ ("port", _type_int),
+ ("path_prefix", _type_str), # may be empty
+ ("logdir", _type_str), # may be empty
+ ("db", _type_str), # may be empty
+ ("cache_key", _type_str), # opaque, as given by `cache_key` below
+ )
+)
TensorBoardInfo = collections.namedtuple(
- "TensorBoardInfo",
- _TENSORBOARD_INFO_FIELDS,
+ "TensorBoardInfo", _TENSORBOARD_INFO_FIELDS,
)
def data_source_from_info(info):
- """Format the data location for the given TensorBoardInfo.
+ """Format the data location for the given TensorBoardInfo.
- Args:
- info: A TensorBoardInfo value.
+ Args:
+ info: A TensorBoardInfo value.
- Returns:
- A human-readable string describing the logdir or database connection
- used by the server: e.g., "logdir /tmp/logs".
- """
- if info.db:
- return "db %s" % info.db
- else:
- return "logdir %s" % info.logdir
+ Returns:
+ A human-readable string describing the logdir or database connection
+ used by the server: e.g., "logdir /tmp/logs".
+ """
+ if info.db:
+ return "db %s" % info.db
+ else:
+ return "logdir %s" % info.logdir
def _info_to_string(info):
- """Convert a `TensorBoardInfo` to string form to be stored on disk.
-
- The format returned by this function is opaque and should only be
- interpreted by `_info_from_string`.
-
- Args:
- info: A valid `TensorBoardInfo` object.
-
- Raises:
- ValueError: If any field on `info` is not of the correct type.
-
- Returns:
- A string representation of the provided `TensorBoardInfo`.
- """
- for key in _TENSORBOARD_INFO_FIELDS:
- field_type = _TENSORBOARD_INFO_FIELDS[key]
- if not isinstance(getattr(info, key), field_type.runtime_type):
- raise ValueError(
- "expected %r of type %s, but found: %r" %
- (key, field_type.runtime_type, getattr(info, key))
- )
- if info.version != version.VERSION:
- raise ValueError(
- "expected 'version' to be %r, but found: %r" %
- (version.VERSION, info.version)
- )
- json_value = {
- k: _TENSORBOARD_INFO_FIELDS[k].serialize(getattr(info, k))
- for k in _TENSORBOARD_INFO_FIELDS
- }
- return json.dumps(json_value, sort_keys=True, indent=4)
+ """Convert a `TensorBoardInfo` to string form to be stored on disk.
+
+ The format returned by this function is opaque and should only be
+ interpreted by `_info_from_string`.
+
+ Args:
+ info: A valid `TensorBoardInfo` object.
+
+ Raises:
+ ValueError: If any field on `info` is not of the correct type.
+
+ Returns:
+ A string representation of the provided `TensorBoardInfo`.
+ """
+ for key in _TENSORBOARD_INFO_FIELDS:
+ field_type = _TENSORBOARD_INFO_FIELDS[key]
+ if not isinstance(getattr(info, key), field_type.runtime_type):
+ raise ValueError(
+ "expected %r of type %s, but found: %r"
+ % (key, field_type.runtime_type, getattr(info, key))
+ )
+ if info.version != version.VERSION:
+ raise ValueError(
+ "expected 'version' to be %r, but found: %r"
+ % (version.VERSION, info.version)
+ )
+ json_value = {
+ k: _TENSORBOARD_INFO_FIELDS[k].serialize(getattr(info, k))
+ for k in _TENSORBOARD_INFO_FIELDS
+ }
+ return json.dumps(json_value, sort_keys=True, indent=4)
def _info_from_string(info_string):
- """Parse a `TensorBoardInfo` object from its string representation.
-
- Args:
- info_string: A string representation of a `TensorBoardInfo`, as
- produced by a previous call to `_info_to_string`.
-
- Returns:
- A `TensorBoardInfo` value.
-
- Raises:
- ValueError: If the provided string is not valid JSON, or if it is
- missing any required fields, or if any field is of incorrect type.
- """
-
- try:
- json_value = json.loads(info_string)
- except ValueError:
- raise ValueError("invalid JSON: %r" % (info_string,))
- if not isinstance(json_value, dict):
- raise ValueError("not a JSON object: %r" % (json_value,))
- expected_keys = frozenset(_TENSORBOARD_INFO_FIELDS)
- actual_keys = frozenset(json_value)
- missing_keys = expected_keys - actual_keys
- if missing_keys:
- raise ValueError(
- "TensorBoardInfo missing keys: %r"
- % (sorted(missing_keys),)
- )
- # For forward compatibility, silently ignore unknown keys.
+ """Parse a `TensorBoardInfo` object from its string representation.
- # Validate and deserialize fields.
- fields = {}
- for key in _TENSORBOARD_INFO_FIELDS:
- field_type = _TENSORBOARD_INFO_FIELDS[key]
- if not isinstance(json_value[key], field_type.serialized_type):
- raise ValueError(
- "expected %r of type %s, but found: %r" %
- (key, field_type.serialized_type, json_value[key])
- )
- fields[key] = field_type.deserialize(json_value[key])
+ Args:
+ info_string: A string representation of a `TensorBoardInfo`, as
+ produced by a previous call to `_info_to_string`.
- return TensorBoardInfo(**fields)
+ Returns:
+ A `TensorBoardInfo` value.
+
+ Raises:
+ ValueError: If the provided string is not valid JSON, or if it is
+ missing any required fields, or if any field is of incorrect type.
+ """
+
+ try:
+ json_value = json.loads(info_string)
+ except ValueError:
+ raise ValueError("invalid JSON: %r" % (info_string,))
+ if not isinstance(json_value, dict):
+ raise ValueError("not a JSON object: %r" % (json_value,))
+ expected_keys = frozenset(_TENSORBOARD_INFO_FIELDS)
+ actual_keys = frozenset(json_value)
+ missing_keys = expected_keys - actual_keys
+ if missing_keys:
+ raise ValueError(
+ "TensorBoardInfo missing keys: %r" % (sorted(missing_keys),)
+ )
+ # For forward compatibility, silently ignore unknown keys.
+
+ # Validate and deserialize fields.
+ fields = {}
+ for key in _TENSORBOARD_INFO_FIELDS:
+ field_type = _TENSORBOARD_INFO_FIELDS[key]
+ if not isinstance(json_value[key], field_type.serialized_type):
+ raise ValueError(
+ "expected %r of type %s, but found: %r"
+ % (key, field_type.serialized_type, json_value[key])
+ )
+ fields[key] = field_type.deserialize(json_value[key])
+
+ return TensorBoardInfo(**fields)
def cache_key(working_directory, arguments, configure_kwargs):
- """Compute a `TensorBoardInfo.cache_key` field.
-
- The format returned by this function is opaque. Clients may only
- inspect it by comparing it for equality with other results from this
- function.
-
- Args:
- working_directory: The directory from which TensorBoard was launched
- and relative to which paths like `--logdir` and `--db` are
- resolved.
- arguments: The command-line args to TensorBoard, as `sys.argv[1:]`.
- Should be a list (or tuple), not an unparsed string. If you have a
- raw shell command, use `shlex.split` before passing it to this
- function.
- configure_kwargs: A dictionary of additional argument values to
- override the textual `arguments`, with the same semantics as in
- `tensorboard.program.TensorBoard.configure`. May be an empty
- dictionary.
-
- Returns:
- A string such that if two (prospective or actual) TensorBoard
- invocations have the same cache key then it is safe to use one in
- place of the other. The converse is not guaranteed: it is often safe
- to change the order of TensorBoard arguments, or to explicitly set
- them to their default values, or to move them between `arguments`
- and `configure_kwargs`, but such invocations may yield distinct
- cache keys.
- """
- if not isinstance(arguments, (list, tuple)):
- raise TypeError(
- "'arguments' should be a list of arguments, but found: %r "
- "(use `shlex.split` if given a string)"
- % (arguments,)
+ """Compute a `TensorBoardInfo.cache_key` field.
+
+ The format returned by this function is opaque. Clients may only
+ inspect it by comparing it for equality with other results from this
+ function.
+
+ Args:
+ working_directory: The directory from which TensorBoard was launched
+ and relative to which paths like `--logdir` and `--db` are
+ resolved.
+ arguments: The command-line args to TensorBoard, as `sys.argv[1:]`.
+ Should be a list (or tuple), not an unparsed string. If you have a
+ raw shell command, use `shlex.split` before passing it to this
+ function.
+ configure_kwargs: A dictionary of additional argument values to
+ override the textual `arguments`, with the same semantics as in
+ `tensorboard.program.TensorBoard.configure`. May be an empty
+ dictionary.
+
+ Returns:
+ A string such that if two (prospective or actual) TensorBoard
+ invocations have the same cache key then it is safe to use one in
+ place of the other. The converse is not guaranteed: it is often safe
+ to change the order of TensorBoard arguments, or to explicitly set
+ them to their default values, or to move them between `arguments`
+ and `configure_kwargs`, but such invocations may yield distinct
+ cache keys.
+ """
+ if not isinstance(arguments, (list, tuple)):
+ raise TypeError(
+ "'arguments' should be a list of arguments, but found: %r "
+ "(use `shlex.split` if given a string)" % (arguments,)
+ )
+ datum = {
+ "working_directory": working_directory,
+ "arguments": arguments,
+ "configure_kwargs": configure_kwargs,
+ }
+ raw = base64.b64encode(
+ json.dumps(datum, sort_keys=True, separators=(",", ":")).encode("utf-8")
)
- datum = {
- "working_directory": working_directory,
- "arguments": arguments,
- "configure_kwargs": configure_kwargs,
- }
- raw = base64.b64encode(
- json.dumps(datum, sort_keys=True, separators=(",", ":")).encode("utf-8")
- )
- # `raw` is of type `bytes`, even though it only contains ASCII
- # characters; we want it to be `str` in both Python 2 and 3.
- return str(raw.decode("ascii"))
+ # `raw` is of type `bytes`, even though it only contains ASCII
+ # characters; we want it to be `str` in both Python 2 and 3.
+ return str(raw.decode("ascii"))
def _get_info_dir():
- """Get path to directory in which to store info files.
-
- The directory returned by this function is "owned" by this module. If
- the contents of the directory are modified other than via the public
- functions of this module, subsequent behavior is undefined.
-
- The directory will be created if it does not exist.
- """
- path = os.path.join(tempfile.gettempdir(), ".tensorboard-info")
- try:
- os.makedirs(path)
- except OSError as e:
- if e.errno == errno.EEXIST and os.path.isdir(path):
- pass
+ """Get path to directory in which to store info files.
+
+ The directory returned by this function is "owned" by this module. If
+ the contents of the directory are modified other than via the public
+ functions of this module, subsequent behavior is undefined.
+
+ The directory will be created if it does not exist.
+ """
+ path = os.path.join(tempfile.gettempdir(), ".tensorboard-info")
+ try:
+ os.makedirs(path)
+ except OSError as e:
+ if e.errno == errno.EEXIST and os.path.isdir(path):
+ pass
+ else:
+ raise
else:
- raise
- else:
- os.chmod(path, 0o777)
- return path
+ os.chmod(path, 0o777)
+ return path
def _get_info_file_path():
- """Get path to info file for the current process.
+ """Get path to info file for the current process.
- As with `_get_info_dir`, the info directory will be created if it does
- not exist.
- """
- return os.path.join(_get_info_dir(), "pid-%d.info" % os.getpid())
+ As with `_get_info_dir`, the info directory will be created if it
+ does not exist.
+ """
+ return os.path.join(_get_info_dir(), "pid-%d.info" % os.getpid())
def write_info_file(tensorboard_info):
- """Write TensorBoardInfo to the current process's info file.
+ """Write TensorBoardInfo to the current process's info file.
- This should be called by `main` once the server is ready. When the
- server shuts down, `remove_info_file` should be called.
+ This should be called by `main` once the server is ready. When the
+ server shuts down, `remove_info_file` should be called.
- Args:
- tensorboard_info: A valid `TensorBoardInfo` object.
+ Args:
+ tensorboard_info: A valid `TensorBoardInfo` object.
- Raises:
- ValueError: If any field on `info` is not of the correct type.
- """
- payload = "%s\n" % _info_to_string(tensorboard_info)
- with open(_get_info_file_path(), "w") as outfile:
- outfile.write(payload)
+ Raises:
+ ValueError: If any field on `info` is not of the correct type.
+ """
+ payload = "%s\n" % _info_to_string(tensorboard_info)
+ with open(_get_info_file_path(), "w") as outfile:
+ outfile.write(payload)
def remove_info_file():
- """Remove the current process's TensorBoardInfo file, if it exists.
-
- If the file does not exist, no action is taken and no error is raised.
- """
- try:
- os.unlink(_get_info_file_path())
- except OSError as e:
- if e.errno == errno.ENOENT:
- # The user may have wiped their temporary directory or something.
- # Not a problem: we're already in the state that we want to be in.
- pass
- else:
- raise
+ """Remove the current process's TensorBoardInfo file, if it exists.
+
+ If the file does not exist, no action is taken and no error is
+ raised.
+ """
+ try:
+ os.unlink(_get_info_file_path())
+ except OSError as e:
+ if e.errno == errno.ENOENT:
+ # The user may have wiped their temporary directory or something.
+ # Not a problem: we're already in the state that we want to be in.
+ pass
+ else:
+ raise
def get_all():
- """Return TensorBoardInfo values for running TensorBoard processes.
-
- This function may not provide a perfect snapshot of the set of running
- processes. Its result set may be incomplete if the user has cleaned
- their /tmp/ directory while TensorBoard processes are running. It may
- contain extraneous entries if TensorBoard processes exited uncleanly
- (e.g., with SIGKILL or SIGQUIT).
-
- Entries in the info directory that do not represent valid
- `TensorBoardInfo` values will be silently ignored.
-
- Returns:
- A fresh list of `TensorBoardInfo` objects.
- """
- info_dir = _get_info_dir()
- results = []
- for filename in os.listdir(info_dir):
- filepath = os.path.join(info_dir, filename)
- try:
- with open(filepath) as infile:
- contents = infile.read()
- except IOError as e:
- if e.errno == errno.EACCES:
- # May have been written by this module in a process whose
- # `umask` includes some bits of 0o444.
- continue
- else:
- raise
- try:
- info = _info_from_string(contents)
- except ValueError:
- # Ignore unrecognized files, logging at debug only.
- tb_logging.get_logger().debug(
- "invalid info file: %r",
- filepath,
- exc_info=True,
- )
- else:
- results.append(info)
- return results
+ """Return TensorBoardInfo values for running TensorBoard processes.
+
+ This function may not provide a perfect snapshot of the set of running
+ processes. Its result set may be incomplete if the user has cleaned
+ their /tmp/ directory while TensorBoard processes are running. It may
+ contain extraneous entries if TensorBoard processes exited uncleanly
+ (e.g., with SIGKILL or SIGQUIT).
+
+ Entries in the info directory that do not represent valid
+ `TensorBoardInfo` values will be silently ignored.
+
+ Returns:
+ A fresh list of `TensorBoardInfo` objects.
+ """
+ info_dir = _get_info_dir()
+ results = []
+ for filename in os.listdir(info_dir):
+ filepath = os.path.join(info_dir, filename)
+ try:
+ with open(filepath) as infile:
+ contents = infile.read()
+ except IOError as e:
+ if e.errno == errno.EACCES:
+ # May have been written by this module in a process whose
+ # `umask` includes some bits of 0o444.
+ continue
+ else:
+ raise
+ try:
+ info = _info_from_string(contents)
+ except ValueError:
+ # Ignore unrecognized files, logging at debug only.
+ tb_logging.get_logger().debug(
+ "invalid info file: %r", filepath, exc_info=True,
+ )
+ else:
+ results.append(info)
+ return results
# The following five types enumerate the possible return values of the
@@ -375,102 +368,102 @@ def get_all():
def start(arguments, timeout=datetime.timedelta(seconds=60)):
- """Start a new TensorBoard instance, or reuse a compatible one.
-
- If the cache key determined by the provided arguments and the current
- working directory (see `cache_key`) matches the cache key of a running
- TensorBoard process (see `get_all`), that process will be reused.
-
- Otherwise, a new TensorBoard process will be spawned with the provided
- arguments, using the `tensorboard` binary from the system path.
-
- Args:
- arguments: List of strings to be passed as arguments to
- `tensorboard`. (If you have a raw command-line string, see
- `shlex.split`.)
- timeout: `datetime.timedelta` object describing how long to wait for
- the subprocess to initialize a TensorBoard server and write its
- `TensorBoardInfo` file. If the info file is not written within
- this time period, `start` will assume that the subprocess is stuck
- in a bad state, and will give up on waiting for it and return a
- `StartTimedOut` result. Note that in such a case the subprocess
- will not be killed. Default value is 60 seconds.
-
- Returns:
- A `StartReused`, `StartLaunched`, `StartFailed`, or `StartTimedOut`
- object.
- """
- match = _find_matching_instance(
- cache_key(
- working_directory=os.getcwd(),
- arguments=arguments,
- configure_kwargs={},
- ),
- )
- if match:
- return StartReused(info=match)
-
- (stdout_fd, stdout_path) = tempfile.mkstemp(prefix=".tensorboard-stdout-")
- (stderr_fd, stderr_path) = tempfile.mkstemp(prefix=".tensorboard-stderr-")
- start_time_seconds = time.time()
- explicit_tb = os.environ.get("TENSORBOARD_BINARY", None)
- try:
- p = subprocess.Popen(
- ["tensorboard" if explicit_tb is None else explicit_tb] + arguments,
- stdout=stdout_fd,
- stderr=stderr_fd,
+ """Start a new TensorBoard instance, or reuse a compatible one.
+
+ If the cache key determined by the provided arguments and the current
+ working directory (see `cache_key`) matches the cache key of a running
+ TensorBoard process (see `get_all`), that process will be reused.
+
+ Otherwise, a new TensorBoard process will be spawned with the provided
+ arguments, using the `tensorboard` binary from the system path.
+
+ Args:
+ arguments: List of strings to be passed as arguments to
+ `tensorboard`. (If you have a raw command-line string, see
+ `shlex.split`.)
+ timeout: `datetime.timedelta` object describing how long to wait for
+ the subprocess to initialize a TensorBoard server and write its
+ `TensorBoardInfo` file. If the info file is not written within
+ this time period, `start` will assume that the subprocess is stuck
+ in a bad state, and will give up on waiting for it and return a
+ `StartTimedOut` result. Note that in such a case the subprocess
+ will not be killed. Default value is 60 seconds.
+
+ Returns:
+ A `StartReused`, `StartLaunched`, `StartFailed`, or `StartTimedOut`
+ object.
+ """
+ match = _find_matching_instance(
+ cache_key(
+ working_directory=os.getcwd(),
+ arguments=arguments,
+ configure_kwargs={},
+ ),
)
- except OSError as e:
- return StartExecFailed(os_error=e, explicit_binary=explicit_tb)
- finally:
- os.close(stdout_fd)
- os.close(stderr_fd)
-
- poll_interval_seconds = 0.5
- end_time_seconds = start_time_seconds + timeout.total_seconds()
- while time.time() < end_time_seconds:
- time.sleep(poll_interval_seconds)
- subprocess_result = p.poll()
- if subprocess_result is not None:
- return StartFailed(
- exit_code=subprocess_result,
- stdout=_maybe_read_file(stdout_path),
- stderr=_maybe_read_file(stderr_path),
- )
- for info in get_all():
- if info.pid == p.pid and info.start_time >= start_time_seconds:
- return StartLaunched(info=info)
- else:
- return StartTimedOut(pid=p.pid)
+ if match:
+ return StartReused(info=match)
+
+ (stdout_fd, stdout_path) = tempfile.mkstemp(prefix=".tensorboard-stdout-")
+ (stderr_fd, stderr_path) = tempfile.mkstemp(prefix=".tensorboard-stderr-")
+ start_time_seconds = time.time()
+ explicit_tb = os.environ.get("TENSORBOARD_BINARY", None)
+ try:
+ p = subprocess.Popen(
+ ["tensorboard" if explicit_tb is None else explicit_tb] + arguments,
+ stdout=stdout_fd,
+ stderr=stderr_fd,
+ )
+ except OSError as e:
+ return StartExecFailed(os_error=e, explicit_binary=explicit_tb)
+ finally:
+ os.close(stdout_fd)
+ os.close(stderr_fd)
+
+ poll_interval_seconds = 0.5
+ end_time_seconds = start_time_seconds + timeout.total_seconds()
+ while time.time() < end_time_seconds:
+ time.sleep(poll_interval_seconds)
+ subprocess_result = p.poll()
+ if subprocess_result is not None:
+ return StartFailed(
+ exit_code=subprocess_result,
+ stdout=_maybe_read_file(stdout_path),
+ stderr=_maybe_read_file(stderr_path),
+ )
+ for info in get_all():
+ if info.pid == p.pid and info.start_time >= start_time_seconds:
+ return StartLaunched(info=info)
+ else:
+ return StartTimedOut(pid=p.pid)
def _find_matching_instance(cache_key):
- """Find a running TensorBoard instance compatible with the cache key.
+ """Find a running TensorBoard instance compatible with the cache key.
- Returns:
- A `TensorBoardInfo` object, or `None` if none matches the cache key.
- """
- infos = get_all()
- candidates = [info for info in infos if info.cache_key == cache_key]
- for candidate in sorted(candidates, key=lambda x: x.port):
- # TODO(@wchargin): Check here that the provided port is still live.
- return candidate
- return None
+ Returns:
+ A `TensorBoardInfo` object, or `None` if none matches the cache key.
+ """
+ infos = get_all()
+ candidates = [info for info in infos if info.cache_key == cache_key]
+ for candidate in sorted(candidates, key=lambda x: x.port):
+ # TODO(@wchargin): Check here that the provided port is still live.
+ return candidate
+ return None
def _maybe_read_file(filename):
- """Read the given file, if it exists.
-
- Args:
- filename: A path to a file.
-
- Returns:
- A string containing the file contents, or `None` if the file does
- not exist.
- """
- try:
- with open(filename) as infile:
- return infile.read()
- except IOError as e:
- if e.errno == errno.ENOENT:
- return None
+ """Read the given file, if it exists.
+
+ Args:
+ filename: A path to a file.
+
+ Returns:
+ A string containing the file contents, or `None` if the file does
+ not exist.
+ """
+ try:
+ with open(filename) as infile:
+ return infile.read()
+ except IOError as e:
+ if e.errno == errno.ENOENT:
+ return None
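
The end-to-end test that follows exercises `manager.start` through each of its result types. A minimal sketch of driving the same API from user code and branching on the returned value; the logdir path is a placeholder:

    from tensorboard import manager

    result = manager.start(["--logdir", "/tmp/demo-logs", "--port", "0"])
    if isinstance(result, manager.StartReused):
        info = result.info
        print("Reusing TensorBoard on port %d (pid %d)" % (info.port, info.pid))
    elif isinstance(result, manager.StartLaunched):
        print("Launched TensorBoard on port %d" % result.info.port)
    elif isinstance(result, manager.StartFailed):
        print(
            "TensorBoard exited with code %d:\n%s"
            % (result.exit_code, result.stderr)
        )
    else:
        # StartTimedOut or StartExecFailed; inspect `result` for details.
        print("TensorBoard did not come up:", result)
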
diff --git a/tensorboard/manager_e2e_test.py b/tensorboard/manager_e2e_test.py
index 463054c482..9b8843ca4d 100644
--- a/tensorboard/manager_e2e_test.py
+++ b/tensorboard/manager_e2e_test.py
@@ -35,338 +35,344 @@
import tensorflow as tf
try:
- # python version >= 3.3
- from unittest import mock
+ # python version >= 3.3
+ from unittest import mock
except ImportError:
- import mock # pylint: disable=unused-import
+ import mock # pylint: disable=unused-import
from tensorboard import manager
class ManagerEndToEndTest(tf.test.TestCase):
-
- def setUp(self):
- super(ManagerEndToEndTest, self).setUp()
-
- # Spy on subprocesses spawned so that we can kill them.
- self.popens = []
- class PopenSpy(subprocess.Popen):
- def __init__(p, *args, **kwargs):
- super(PopenSpy, p).__init__(*args, **kwargs)
- self.popens.append(p)
- popen_patcher = mock.patch.object(subprocess, "Popen", PopenSpy)
- popen_patcher.start()
-
- # Make sure that temporary files (including .tensorboard-info) are
- # created under our purview.
- self.tmproot = os.path.join(self.get_temp_dir(), "tmproot")
- os.mkdir(self.tmproot)
- self._patch_environ({"TMPDIR": self.tmproot})
- tempfile.tempdir = None # force `gettempdir` to reinitialize from env
- self.assertEqual(tempfile.gettempdir(), self.tmproot)
- self.info_dir = manager._get_info_dir() # ensure that directory exists
-
- # Add our Bazel-provided `tensorboard` to the system path. (The
- # //tensorboard:tensorboard target is made available in the same
- # directory as //tensorboard:manager_e2e_test.)
- tensorboard_binary_dir = os.path.dirname(os.environ["TEST_BINARY"])
- self._patch_environ({
- "PATH": os.pathsep.join((tensorboard_binary_dir, os.environ["PATH"])),
- })
- self._ensure_tensorboard_on_path(tensorboard_binary_dir)
-
- def tearDown(self):
- failed_kills = []
- for p in self.popens:
- try:
- p.kill()
- except Exception as e:
- if isinstance(e, OSError) and e.errno == errno.ESRCH:
- # ESRCH 3 No such process: e.g., it already exited.
- pass
- else:
- # We really want to make sure to try to kill all these
- # processes. Continue killing; fail the test later.
- failed_kills.append(e)
- for p in self.popens:
- p.wait()
- self.assertEqual(failed_kills, [])
-
- def _patch_environ(self, partial_environ):
- patcher = mock.patch.dict(os.environ, partial_environ)
- patcher.start()
- self.addCleanup(patcher.stop)
-
- def _ensure_tensorboard_on_path(self, expected_binary_dir):
- """Ensure that `tensorboard(1)` refers to our own binary.
-
- Raises:
- subprocess.CalledProcessError: If there is no `tensorboard` on the
- path.
- AssertionError: If the `tensorboard` on the path is not under the
- provided directory.
- """
- # In Python 3.3+, we could use `shutil.which` to inspect the path.
- # For Python 2 compatibility, we shell out to the host platform's
- # standard utility.
- command = "where" if os.name == "nt" else "which"
- binary = subprocess.check_output([command, "tensorboard"])
- self.assertTrue(
- binary.startswith(expected_binary_dir.encode("utf-8")),
- "expected %r to start with %r" % (binary, expected_binary_dir),
- )
-
- def _stub_tensorboard(self, name, program):
- """Install a stub version of TensorBoard.
-
- Args:
- name: A short description of the stub's behavior. This will appear
- in the file path, which may appear in error messages.
- program: The contents of the stub: this should probably be a
- string that starts with "#!/bin/sh" and then contains a POSIX
- shell script.
- """
- tempdir = tempfile.mkdtemp(prefix="tensorboard-stub-%s-" % name)
- # (this directory is under our test directory; no need to clean it up)
- filepath = os.path.join(tempdir, "tensorboard")
- with open(filepath, "w") as outfile:
- outfile.write(program)
- os.chmod(filepath, 0o777)
- self._patch_environ({
- "PATH": os.pathsep.join((tempdir, os.environ["PATH"])),
- })
- self._ensure_tensorboard_on_path(expected_binary_dir=tempdir)
-
- def _assert_live(self, info, expected_logdir):
- url = "http://localhost:%d%s/data/logdir" % (info.port, info.path_prefix)
- with contextlib.closing(urllib.request.urlopen(url)) as infile:
- data = infile.read()
- self.assertEqual(json.loads(data), {"logdir": expected_logdir})
-
- def test_simple_start(self):
- start_result = manager.start(["--logdir=./logs", "--port=0"])
- self.assertIsInstance(start_result, manager.StartLaunched)
- self._assert_live(start_result.info, expected_logdir="./logs")
-
- def test_reuse(self):
- r1 = manager.start(["--logdir=./logs", "--port=0"])
- self.assertIsInstance(r1, manager.StartLaunched)
- r2 = manager.start(["--logdir=./logs", "--port=0"])
- self.assertIsInstance(r2, manager.StartReused)
- self.assertEqual(r1.info, r2.info)
- infos = manager.get_all()
- self.assertEqual(infos, [r1.info])
- self._assert_live(r1.info, expected_logdir="./logs")
-
- def test_launch_new_because_incompatible(self):
- r1 = manager.start(["--logdir=./logs", "--port=0"])
- self.assertIsInstance(r1, manager.StartLaunched)
- r2 = manager.start(["--logdir=./adders", "--port=0"])
- self.assertIsInstance(r2, manager.StartLaunched)
- self.assertNotEqual(r1.info.port, r2.info.port)
- self.assertNotEqual(r1.info.pid, r2.info.pid)
- infos = manager.get_all()
- self.assertItemsEqual(infos, [r1.info, r2.info])
- self._assert_live(r1.info, expected_logdir="./logs")
- self._assert_live(r2.info, expected_logdir="./adders")
-
- def test_launch_new_because_info_file_deleted(self):
- r1 = manager.start(["--logdir=./logs", "--port=0"])
- self.assertIsInstance(r1, manager.StartLaunched)
-
- # Now suppose that someone comes and wipes /tmp/...
- self.assertEqual(len(manager.get_all()), 1, manager.get_all())
- shutil.rmtree(self.tmproot)
- os.mkdir(self.tmproot)
- self.assertEqual(len(manager.get_all()), 0, manager.get_all())
-
- # ...so that starting even the same command forces a relaunch:
- r2 = manager.start(["--logdir=./logs", "--port=0"])
- self.assertIsInstance(r2, manager.StartLaunched) # (picked a new port)
- self.assertEqual(r1.info.cache_key, r2.info.cache_key)
- infos = manager.get_all()
- self.assertItemsEqual(infos, [r2.info])
- self._assert_live(r1.info, expected_logdir="./logs")
- self._assert_live(r2.info, expected_logdir="./logs")
-
- def test_reuse_after_kill(self):
- if os.name == "nt":
- self.skipTest("Can't send SIGTERM or SIGINT on Windows.")
- r1 = manager.start(["--logdir=./logs", "--port=0"])
- self.assertIsInstance(r1, manager.StartLaunched)
- os.kill(r1.info.pid, signal.SIGTERM)
- os.waitpid(r1.info.pid, 0)
- r2 = manager.start(["--logdir=./logs", "--port=0"])
- self.assertIsInstance(r2, manager.StartLaunched)
- self.assertEqual(r1.info.cache_key, r2.info.cache_key)
- # It's not technically guaranteed by POSIX that the following holds,
- # but it will unless the OS preemptively recycles PIDs or we somehow
- # cycled exactly through the whole PID space. Neither Linux nor
- # macOS recycles PIDs, so we should be fine.
- self.assertNotEqual(r1.info.pid, r2.info.pid)
- self._assert_live(r2.info, expected_logdir="./logs")
-
- def test_exit_failure(self):
- if os.name == "nt":
- # TODO(@wchargin): This could in principle work on Windows.
- self.skipTest("Requires a POSIX shell for the stub script.")
- self._stub_tensorboard(
- name="fail-with-77",
- program=textwrap.dedent(
- r"""
+ def setUp(self):
+ super(ManagerEndToEndTest, self).setUp()
+
+ # Spy on subprocesses spawned so that we can kill them.
+ self.popens = []
+
+ class PopenSpy(subprocess.Popen):
+ def __init__(p, *args, **kwargs):
+ super(PopenSpy, p).__init__(*args, **kwargs)
+ self.popens.append(p)
+
+ popen_patcher = mock.patch.object(subprocess, "Popen", PopenSpy)
+ popen_patcher.start()
+
+ # Make sure that temporary files (including .tensorboard-info) are
+ # created under our purview.
+ self.tmproot = os.path.join(self.get_temp_dir(), "tmproot")
+ os.mkdir(self.tmproot)
+ self._patch_environ({"TMPDIR": self.tmproot})
+ tempfile.tempdir = None # force `gettempdir` to reinitialize from env
+ self.assertEqual(tempfile.gettempdir(), self.tmproot)
+ self.info_dir = manager._get_info_dir() # ensure that directory exists
+
+ # Add our Bazel-provided `tensorboard` to the system path. (The
+ # //tensorboard:tensorboard target is made available in the same
+ # directory as //tensorboard:manager_e2e_test.)
+ tensorboard_binary_dir = os.path.dirname(os.environ["TEST_BINARY"])
+ self._patch_environ(
+ {
+ "PATH": os.pathsep.join(
+ (tensorboard_binary_dir, os.environ["PATH"])
+ ),
+ }
+ )
+ self._ensure_tensorboard_on_path(tensorboard_binary_dir)
+
+ def tearDown(self):
+ failed_kills = []
+ for p in self.popens:
+ try:
+ p.kill()
+ except Exception as e:
+ if isinstance(e, OSError) and e.errno == errno.ESRCH:
+ # ESRCH 3 No such process: e.g., it already exited.
+ pass
+ else:
+ # We really want to make sure to try to kill all these
+ # processes. Continue killing; fail the test later.
+ failed_kills.append(e)
+ for p in self.popens:
+ p.wait()
+ self.assertEqual(failed_kills, [])
+
+ def _patch_environ(self, partial_environ):
+ patcher = mock.patch.dict(os.environ, partial_environ)
+ patcher.start()
+ self.addCleanup(patcher.stop)
+
+ def _ensure_tensorboard_on_path(self, expected_binary_dir):
+ """Ensure that `tensorboard(1)` refers to our own binary.
+
+ Raises:
+ subprocess.CalledProcessError: If there is no `tensorboard` on the
+ path.
+ AssertionError: If the `tensorboard` on the path is not under the
+ provided directory.
+ """
+ # In Python 3.3+, we could use `shutil.which` to inspect the path.
+ # For Python 2 compatibility, we shell out to the host platform's
+ # standard utility.
+ command = "where" if os.name == "nt" else "which"
+ binary = subprocess.check_output([command, "tensorboard"])
+ self.assertTrue(
+ binary.startswith(expected_binary_dir.encode("utf-8")),
+ "expected %r to start with %r" % (binary, expected_binary_dir),
+ )
+
+ def _stub_tensorboard(self, name, program):
+ """Install a stub version of TensorBoard.
+
+ Args:
+ name: A short description of the stub's behavior. This will appear
+ in the file path, which may appear in error messages.
+ program: The contents of the stub: this should probably be a
+ string that starts with "#!/bin/sh" and then contains a POSIX
+ shell script.
+ """
+ tempdir = tempfile.mkdtemp(prefix="tensorboard-stub-%s-" % name)
+ # (this directory is under our test directory; no need to clean it up)
+ filepath = os.path.join(tempdir, "tensorboard")
+ with open(filepath, "w") as outfile:
+ outfile.write(program)
+ os.chmod(filepath, 0o777)
+ self._patch_environ(
+ {"PATH": os.pathsep.join((tempdir, os.environ["PATH"])),}
+ )
+ self._ensure_tensorboard_on_path(expected_binary_dir=tempdir)
+
+ def _assert_live(self, info, expected_logdir):
+ url = "http://localhost:%d%s/data/logdir" % (
+ info.port,
+ info.path_prefix,
+ )
+ with contextlib.closing(urllib.request.urlopen(url)) as infile:
+ data = infile.read()
+ self.assertEqual(json.loads(data), {"logdir": expected_logdir})
+
+ def test_simple_start(self):
+ start_result = manager.start(["--logdir=./logs", "--port=0"])
+ self.assertIsInstance(start_result, manager.StartLaunched)
+ self._assert_live(start_result.info, expected_logdir="./logs")
+
+ def test_reuse(self):
+ r1 = manager.start(["--logdir=./logs", "--port=0"])
+ self.assertIsInstance(r1, manager.StartLaunched)
+ r2 = manager.start(["--logdir=./logs", "--port=0"])
+ self.assertIsInstance(r2, manager.StartReused)
+ self.assertEqual(r1.info, r2.info)
+ infos = manager.get_all()
+ self.assertEqual(infos, [r1.info])
+ self._assert_live(r1.info, expected_logdir="./logs")
+
+ def test_launch_new_because_incompatible(self):
+ r1 = manager.start(["--logdir=./logs", "--port=0"])
+ self.assertIsInstance(r1, manager.StartLaunched)
+ r2 = manager.start(["--logdir=./adders", "--port=0"])
+ self.assertIsInstance(r2, manager.StartLaunched)
+ self.assertNotEqual(r1.info.port, r2.info.port)
+ self.assertNotEqual(r1.info.pid, r2.info.pid)
+ infos = manager.get_all()
+ self.assertItemsEqual(infos, [r1.info, r2.info])
+ self._assert_live(r1.info, expected_logdir="./logs")
+ self._assert_live(r2.info, expected_logdir="./adders")
+
+ def test_launch_new_because_info_file_deleted(self):
+ r1 = manager.start(["--logdir=./logs", "--port=0"])
+ self.assertIsInstance(r1, manager.StartLaunched)
+
+ # Now suppose that someone comes and wipes /tmp/...
+ self.assertEqual(len(manager.get_all()), 1, manager.get_all())
+ shutil.rmtree(self.tmproot)
+ os.mkdir(self.tmproot)
+ self.assertEqual(len(manager.get_all()), 0, manager.get_all())
+
+ # ...so that starting even the same command forces a relaunch:
+ r2 = manager.start(["--logdir=./logs", "--port=0"])
+ self.assertIsInstance(r2, manager.StartLaunched) # (picked a new port)
+ self.assertEqual(r1.info.cache_key, r2.info.cache_key)
+ infos = manager.get_all()
+ self.assertItemsEqual(infos, [r2.info])
+ self._assert_live(r1.info, expected_logdir="./logs")
+ self._assert_live(r2.info, expected_logdir="./logs")
+
+ def test_reuse_after_kill(self):
+ if os.name == "nt":
+ self.skipTest("Can't send SIGTERM or SIGINT on Windows.")
+ r1 = manager.start(["--logdir=./logs", "--port=0"])
+ self.assertIsInstance(r1, manager.StartLaunched)
+ os.kill(r1.info.pid, signal.SIGTERM)
+ os.waitpid(r1.info.pid, 0)
+ r2 = manager.start(["--logdir=./logs", "--port=0"])
+ self.assertIsInstance(r2, manager.StartLaunched)
+ self.assertEqual(r1.info.cache_key, r2.info.cache_key)
+ # It's not technically guaranteed by POSIX that the following holds,
+ # but it will unless the OS preemptively recycles PIDs or we somehow
+ # cycled exactly through the whole PID space. Neither Linux nor
+        # macOS eagerly recycles PIDs, so we should be fine.
+ self.assertNotEqual(r1.info.pid, r2.info.pid)
+ self._assert_live(r2.info, expected_logdir="./logs")
+
+ def test_exit_failure(self):
+ if os.name == "nt":
+ # TODO(@wchargin): This could in principle work on Windows.
+ self.skipTest("Requires a POSIX shell for the stub script.")
+ self._stub_tensorboard(
+ name="fail-with-77",
+ program=textwrap.dedent(
+ r"""
#!/bin/sh
printf >&2 'fatal: something bad happened\n'
printf 'also some stdout\n'
exit 77
""".lstrip(),
- ),
- )
- start_result = manager.start(["--logdir=./logs", "--port=0"])
- self.assertIsInstance(start_result, manager.StartFailed)
- self.assertEqual(
- start_result,
- manager.StartFailed(
- exit_code=77,
- stderr="fatal: something bad happened\n",
- stdout="also some stdout\n",
- ),
- )
- self.assertEqual(manager.get_all(), [])
-
- def test_exit_success(self):
- # TensorBoard exiting with success but not writing the info file is
- # still a failure to launch.
- if os.name == "nt":
- # TODO(@wchargin): This could in principle work on Windows.
- self.skipTest("Requires a POSIX shell for the stub script.")
- self._stub_tensorboard(
- name="fail-with-0",
- program=textwrap.dedent(
- r"""
+ ),
+ )
+ start_result = manager.start(["--logdir=./logs", "--port=0"])
+ self.assertIsInstance(start_result, manager.StartFailed)
+ self.assertEqual(
+ start_result,
+ manager.StartFailed(
+ exit_code=77,
+ stderr="fatal: something bad happened\n",
+ stdout="also some stdout\n",
+ ),
+ )
+ self.assertEqual(manager.get_all(), [])
+
+ def test_exit_success(self):
+ # TensorBoard exiting with success but not writing the info file is
+ # still a failure to launch.
+ if os.name == "nt":
+ # TODO(@wchargin): This could in principle work on Windows.
+ self.skipTest("Requires a POSIX shell for the stub script.")
+ self._stub_tensorboard(
+ name="fail-with-0",
+ program=textwrap.dedent(
+ r"""
#!/bin/sh
printf >&2 'info: something good happened\n'
printf 'also some standard output\n'
exit 0
""".lstrip(),
- ),
- )
- start_result = manager.start(["--logdir=./logs", "--port=0"])
- self.assertIsInstance(start_result, manager.StartFailed)
- self.assertEqual(
- start_result,
- manager.StartFailed(
- exit_code=0,
- stderr="info: something good happened\n",
- stdout="also some standard output\n",
- ),
- )
- self.assertEqual(manager.get_all(), [])
-
- def test_failure_unreadable_stdio(self):
- if os.name == "nt":
- # TODO(@wchargin): This could in principle work on Windows.
- self.skipTest("Requires a POSIX shell for the stub script.")
- self._stub_tensorboard(
- name="fail-and-nuke-tmp",
- program=textwrap.dedent(
- r"""
+ ),
+ )
+ start_result = manager.start(["--logdir=./logs", "--port=0"])
+ self.assertIsInstance(start_result, manager.StartFailed)
+ self.assertEqual(
+ start_result,
+ manager.StartFailed(
+ exit_code=0,
+ stderr="info: something good happened\n",
+ stdout="also some standard output\n",
+ ),
+ )
+ self.assertEqual(manager.get_all(), [])
+
+ def test_failure_unreadable_stdio(self):
+ if os.name == "nt":
+ # TODO(@wchargin): This could in principle work on Windows.
+ self.skipTest("Requires a POSIX shell for the stub script.")
+ self._stub_tensorboard(
+ name="fail-and-nuke-tmp",
+ program=textwrap.dedent(
+ r"""
#!/bin/sh
rm -r %s
exit 22
- """.lstrip() % pipes.quote(self.tmproot),
- ),
- )
- start_result = manager.start(["--logdir=./logs", "--port=0"])
- self.assertIsInstance(start_result, manager.StartFailed)
- self.assertEqual(
- start_result,
- manager.StartFailed(
- exit_code=22,
- stderr=None,
- stdout=None,
- ),
- )
- self.assertEqual(manager.get_all(), [])
-
- def test_timeout(self):
- if os.name == "nt":
- # TODO(@wchargin): This could in principle work on Windows.
- self.skipTest("Requires a POSIX shell for the stub script.")
- tempdir = tempfile.mkdtemp()
- pid_file = os.path.join(tempdir, "pidfile")
- self._stub_tensorboard(
- name="wait-a-minute",
- program=textwrap.dedent(
- r"""
+ """.lstrip()
+ % pipes.quote(self.tmproot),
+ ),
+ )
+ start_result = manager.start(["--logdir=./logs", "--port=0"])
+ self.assertIsInstance(start_result, manager.StartFailed)
+ self.assertEqual(
+ start_result,
+ manager.StartFailed(exit_code=22, stderr=None, stdout=None,),
+ )
+ self.assertEqual(manager.get_all(), [])
+
+ def test_timeout(self):
+ if os.name == "nt":
+ # TODO(@wchargin): This could in principle work on Windows.
+ self.skipTest("Requires a POSIX shell for the stub script.")
+ tempdir = tempfile.mkdtemp()
+ pid_file = os.path.join(tempdir, "pidfile")
+ self._stub_tensorboard(
+ name="wait-a-minute",
+ program=textwrap.dedent(
+ r"""
#!/bin/sh
printf >%s '%%s' "$$"
printf >&2 'warn: I am tired\n'
sleep 60
- """.lstrip() % pipes.quote(os.path.realpath(pid_file)),
- ),
- )
- start_result = manager.start(
- ["--logdir=./logs", "--port=0"],
- timeout=datetime.timedelta(seconds=1),
- )
- self.assertIsInstance(start_result, manager.StartTimedOut)
- with open(pid_file) as infile:
- expected_pid = int(infile.read())
- self.assertEqual(start_result, manager.StartTimedOut(pid=expected_pid))
- self.assertEqual(manager.get_all(), [])
-
- def test_tensorboard_binary_environment_variable(self):
- if os.name == "nt":
- # TODO(@wchargin): This could in principle work on Windows.
- self.skipTest("Requires a POSIX shell for the stub script.")
- tempdir = tempfile.mkdtemp()
- filepath = os.path.join(tempdir, "tensorbad")
- program = textwrap.dedent(
- r"""
+ """.lstrip()
+ % pipes.quote(os.path.realpath(pid_file)),
+ ),
+ )
+ start_result = manager.start(
+ ["--logdir=./logs", "--port=0"],
+ timeout=datetime.timedelta(seconds=1),
+ )
+ self.assertIsInstance(start_result, manager.StartTimedOut)
+ with open(pid_file) as infile:
+ expected_pid = int(infile.read())
+ self.assertEqual(start_result, manager.StartTimedOut(pid=expected_pid))
+ self.assertEqual(manager.get_all(), [])
+
+ def test_tensorboard_binary_environment_variable(self):
+ if os.name == "nt":
+ # TODO(@wchargin): This could in principle work on Windows.
+ self.skipTest("Requires a POSIX shell for the stub script.")
+ tempdir = tempfile.mkdtemp()
+ filepath = os.path.join(tempdir, "tensorbad")
+ program = textwrap.dedent(
+ r"""
#!/bin/sh
printf >&2 'tensorbad: fatal: something bad happened\n'
printf 'tensorbad: also some stdout\n'
exit 77
""".lstrip()
- )
- with open(filepath, "w") as outfile:
- outfile.write(program)
- os.chmod(filepath, 0o777)
- self._patch_environ({"TENSORBOARD_BINARY": filepath})
-
- start_result = manager.start(["--logdir=./logs", "--port=0"])
- self.assertIsInstance(start_result, manager.StartFailed)
- self.assertEqual(
- start_result,
- manager.StartFailed(
- exit_code=77,
- stderr="tensorbad: fatal: something bad happened\n",
- stdout="tensorbad: also some stdout\n",
- ),
- )
- self.assertEqual(manager.get_all(), [])
-
- def test_exec_failure_with_explicit_binary(self):
- path = os.path.join(".", "non", "existent")
- self._patch_environ({"TENSORBOARD_BINARY": path})
-
- start_result = manager.start(["--logdir=./logs", "--port=0"])
- self.assertIsInstance(start_result, manager.StartExecFailed)
- self.assertEqual(start_result.os_error.errno, errno.ENOENT)
- self.assertEqual(start_result.explicit_binary, path)
-
- def test_exec_failure_with_no_explicit_binary(self):
- if os.name == "nt":
- # Can't use ENOENT without an absolute path (it's not treated as
- # an exec failure).
- self.skipTest("Not clear how to trigger this case on Windows.")
- self._patch_environ({"PATH": "nope"})
-
- start_result = manager.start(["--logdir=./logs", "--port=0"])
- self.assertIsInstance(start_result, manager.StartExecFailed)
- self.assertEqual(start_result.os_error.errno, errno.ENOENT)
- self.assertIs(start_result.explicit_binary, None)
+ )
+ with open(filepath, "w") as outfile:
+ outfile.write(program)
+ os.chmod(filepath, 0o777)
+ self._patch_environ({"TENSORBOARD_BINARY": filepath})
+
+ start_result = manager.start(["--logdir=./logs", "--port=0"])
+ self.assertIsInstance(start_result, manager.StartFailed)
+ self.assertEqual(
+ start_result,
+ manager.StartFailed(
+ exit_code=77,
+ stderr="tensorbad: fatal: something bad happened\n",
+ stdout="tensorbad: also some stdout\n",
+ ),
+ )
+ self.assertEqual(manager.get_all(), [])
+
+ def test_exec_failure_with_explicit_binary(self):
+ path = os.path.join(".", "non", "existent")
+ self._patch_environ({"TENSORBOARD_BINARY": path})
+
+ start_result = manager.start(["--logdir=./logs", "--port=0"])
+ self.assertIsInstance(start_result, manager.StartExecFailed)
+ self.assertEqual(start_result.os_error.errno, errno.ENOENT)
+ self.assertEqual(start_result.explicit_binary, path)
+
+ def test_exec_failure_with_no_explicit_binary(self):
+ if os.name == "nt":
+ # Can't use ENOENT without an absolute path (it's not treated as
+ # an exec failure).
+ self.skipTest("Not clear how to trigger this case on Windows.")
+ self._patch_environ({"PATH": "nope"})
+
+ start_result = manager.start(["--logdir=./logs", "--port=0"])
+ self.assertIsInstance(start_result, manager.StartExecFailed)
+ self.assertEqual(start_result.os_error.errno, errno.ENOENT)
+ self.assertIs(start_result.explicit_binary, None)
if __name__ == "__main__":
- tf.test.main()
+ tf.test.main()
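
The end-to-end tests above drive `manager.start`, which returns one of `StartLaunched`, `StartReused`, `StartFailed`, `StartExecFailed`, or `StartTimedOut`; callers are expected to branch on the result type. A minimal sketch of that calling pattern follows, assuming only the names exercised by the tests above; the logdir, port, and timeout values are arbitrary placeholders.

import datetime

from tensorboard import manager

# Start (or reuse) a TensorBoard serving ./logs on an auto-selected port.
result = manager.start(
    ["--logdir=./logs", "--port=0"],
    timeout=datetime.timedelta(seconds=60),
)
if isinstance(result, manager.StartLaunched):
    print("Launched TensorBoard on port %d" % result.info.port)
elif isinstance(result, manager.StartReused):
    print(
        "Reusing TensorBoard on port %d (pid %d)"
        % (result.info.port, result.info.pid)
    )
elif isinstance(result, manager.StartFailed):
    print("TensorBoard exited with status %d" % result.exit_code)
elif isinstance(result, manager.StartExecFailed):
    print("Could not execute TensorBoard: %s" % result.os_error)
elif isinstance(result, manager.StartTimedOut):
    print("Timed out; TensorBoard may still be running as pid %d" % result.pid)
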
diff --git a/tensorboard/manager_test.py b/tensorboard/manager_test.py
index 805d867e76..d76cc40751 100644
--- a/tensorboard/manager_test.py
+++ b/tensorboard/manager_test.py
@@ -28,10 +28,10 @@
import six
try:
- # python version >= 3.3
- from unittest import mock
+ # python version >= 3.3
+ from unittest import mock
except ImportError:
- import mock # pylint: disable=unused-import
+ import mock # pylint: disable=unused-import
from tensorboard import manager
from tensorboard import test as tb_test
@@ -40,365 +40,365 @@
def _make_info(i=0):
- """Make a sample TensorBoardInfo object.
-
- Args:
- i: Seed; vary this value to produce slightly different outputs.
-
- Returns:
- A type-correct `TensorBoardInfo` object.
- """
- return manager.TensorBoardInfo(
- version=version.VERSION,
- start_time=1548973541 + i,
- port=6060 + i,
- pid=76540 + i,
- path_prefix="/foo",
- logdir="~/my_data/",
- db="",
- cache_key="asdf",
- )
+ """Make a sample TensorBoardInfo object.
+
+ Args:
+ i: Seed; vary this value to produce slightly different outputs.
+
+ Returns:
+ A type-correct `TensorBoardInfo` object.
+ """
+ return manager.TensorBoardInfo(
+ version=version.VERSION,
+ start_time=1548973541 + i,
+ port=6060 + i,
+ pid=76540 + i,
+ path_prefix="/foo",
+ logdir="~/my_data/",
+ db="",
+ cache_key="asdf",
+ )
class TensorBoardInfoTest(tb_test.TestCase):
- """Unit tests for TensorBoardInfo typechecking and serialization."""
-
- def test_roundtrip_serialization(self):
- # This is also tested indirectly as part of `manager` integration
- # tests, in `test_get_all`.
- info = _make_info()
- also_info = manager._info_from_string(manager._info_to_string(info))
- self.assertEqual(also_info, info)
-
- def test_serialization_rejects_bad_types(self):
- bad_time = datetime.datetime.fromtimestamp(1549061116) # not an int
- info = _make_info()._replace(start_time=bad_time)
- with six.assertRaisesRegex(
- self,
- ValueError,
- r"expected 'start_time' of type.*int.*, but found: datetime\."):
- manager._info_to_string(info)
-
- def test_serialization_rejects_wrong_version(self):
- info = _make_info()._replace(version="reversion")
- with six.assertRaisesRegex(
- self,
- ValueError,
- "expected 'version' to be '.*', but found: 'reversion'"):
- manager._info_to_string(info)
-
- def test_deserialization_rejects_bad_json(self):
- bad_input = "parse me if you dare"
- with six.assertRaisesRegex(
- self,
- ValueError,
- "invalid JSON:"):
- manager._info_from_string(bad_input)
-
- def test_deserialization_rejects_non_object_json(self):
- bad_input = "[1, 2]"
- with six.assertRaisesRegex(
- self,
- ValueError,
- re.escape("not a JSON object: [1, 2]")):
- manager._info_from_string(bad_input)
-
- def test_deserialization_rejects_missing_version(self):
- info = _make_info()
- json_value = json.loads(manager._info_to_string(info))
- del json_value["version"]
- bad_input = json.dumps(json_value)
- with six.assertRaisesRegex(
- self,
- ValueError,
- re.escape("missing keys: ['version']")):
- manager._info_from_string(bad_input)
-
- def test_deserialization_accepts_future_version(self):
- info = _make_info()
- json_value = json.loads(manager._info_to_string(info))
- json_value["version"] = "99.99.99a20991232"
- input_ = json.dumps(json_value)
- result = manager._info_from_string(input_)
- self.assertEqual(result.version, "99.99.99a20991232")
-
- def test_deserialization_ignores_extra_keys(self):
- info = _make_info()
- json_value = json.loads(manager._info_to_string(info))
- json_value["unlikely"] = "story"
- bad_input = json.dumps(json_value)
- result = manager._info_from_string(bad_input)
- self.assertIsInstance(result, manager.TensorBoardInfo)
-
- def test_deserialization_rejects_missing_keys(self):
- info = _make_info()
- json_value = json.loads(manager._info_to_string(info))
- del json_value["start_time"]
- bad_input = json.dumps(json_value)
- with six.assertRaisesRegex(
- self,
- ValueError,
- re.escape("missing keys: ['start_time']")):
- manager._info_from_string(bad_input)
-
- def test_deserialization_rejects_bad_types(self):
- info = _make_info()
- json_value = json.loads(manager._info_to_string(info))
- json_value["start_time"] = "2001-02-03T04:05:06"
- bad_input = json.dumps(json_value)
- with six.assertRaisesRegex(
- self,
- ValueError,
- "expected 'start_time' of type.*int.*, but found:.*"
- "'2001-02-03T04:05:06'"):
- manager._info_from_string(bad_input)
-
- def test_logdir_data_source_format(self):
- info = _make_info()._replace(logdir="~/foo", db="")
- self.assertEqual(manager.data_source_from_info(info), "logdir ~/foo")
-
- def test_db_data_source_format(self):
- info = _make_info()._replace(logdir="", db="sqlite:~/bar")
- self.assertEqual(manager.data_source_from_info(info), "db sqlite:~/bar")
+ """Unit tests for TensorBoardInfo typechecking and serialization."""
+
+ def test_roundtrip_serialization(self):
+ # This is also tested indirectly as part of `manager` integration
+ # tests, in `test_get_all`.
+ info = _make_info()
+ also_info = manager._info_from_string(manager._info_to_string(info))
+ self.assertEqual(also_info, info)
+
+ def test_serialization_rejects_bad_types(self):
+ bad_time = datetime.datetime.fromtimestamp(1549061116) # not an int
+ info = _make_info()._replace(start_time=bad_time)
+ with six.assertRaisesRegex(
+ self,
+ ValueError,
+ r"expected 'start_time' of type.*int.*, but found: datetime\.",
+ ):
+ manager._info_to_string(info)
+
+ def test_serialization_rejects_wrong_version(self):
+ info = _make_info()._replace(version="reversion")
+ with six.assertRaisesRegex(
+ self,
+ ValueError,
+ "expected 'version' to be '.*', but found: 'reversion'",
+ ):
+ manager._info_to_string(info)
+
+ def test_deserialization_rejects_bad_json(self):
+ bad_input = "parse me if you dare"
+ with six.assertRaisesRegex(self, ValueError, "invalid JSON:"):
+ manager._info_from_string(bad_input)
+
+ def test_deserialization_rejects_non_object_json(self):
+ bad_input = "[1, 2]"
+ with six.assertRaisesRegex(
+ self, ValueError, re.escape("not a JSON object: [1, 2]")
+ ):
+ manager._info_from_string(bad_input)
+
+ def test_deserialization_rejects_missing_version(self):
+ info = _make_info()
+ json_value = json.loads(manager._info_to_string(info))
+ del json_value["version"]
+ bad_input = json.dumps(json_value)
+ with six.assertRaisesRegex(
+ self, ValueError, re.escape("missing keys: ['version']")
+ ):
+ manager._info_from_string(bad_input)
+
+ def test_deserialization_accepts_future_version(self):
+ info = _make_info()
+ json_value = json.loads(manager._info_to_string(info))
+ json_value["version"] = "99.99.99a20991232"
+ input_ = json.dumps(json_value)
+ result = manager._info_from_string(input_)
+ self.assertEqual(result.version, "99.99.99a20991232")
+
+ def test_deserialization_ignores_extra_keys(self):
+ info = _make_info()
+ json_value = json.loads(manager._info_to_string(info))
+ json_value["unlikely"] = "story"
+ bad_input = json.dumps(json_value)
+ result = manager._info_from_string(bad_input)
+ self.assertIsInstance(result, manager.TensorBoardInfo)
+
+ def test_deserialization_rejects_missing_keys(self):
+ info = _make_info()
+ json_value = json.loads(manager._info_to_string(info))
+ del json_value["start_time"]
+ bad_input = json.dumps(json_value)
+ with six.assertRaisesRegex(
+ self, ValueError, re.escape("missing keys: ['start_time']")
+ ):
+ manager._info_from_string(bad_input)
+
+ def test_deserialization_rejects_bad_types(self):
+ info = _make_info()
+ json_value = json.loads(manager._info_to_string(info))
+ json_value["start_time"] = "2001-02-03T04:05:06"
+ bad_input = json.dumps(json_value)
+ with six.assertRaisesRegex(
+ self,
+ ValueError,
+ "expected 'start_time' of type.*int.*, but found:.*"
+ "'2001-02-03T04:05:06'",
+ ):
+ manager._info_from_string(bad_input)
+
+ def test_logdir_data_source_format(self):
+ info = _make_info()._replace(logdir="~/foo", db="")
+ self.assertEqual(manager.data_source_from_info(info), "logdir ~/foo")
+
+ def test_db_data_source_format(self):
+ info = _make_info()._replace(logdir="", db="sqlite:~/bar")
+ self.assertEqual(manager.data_source_from_info(info), "db sqlite:~/bar")
class CacheKeyTest(tb_test.TestCase):
- """Unit tests for `manager.cache_key`."""
+ """Unit tests for `manager.cache_key`."""
- def test_result_is_str(self):
- result = manager.cache_key(
- working_directory="/home/me",
- arguments=["--logdir", "something"],
- configure_kwargs={},
- )
- self.assertIsInstance(result, str)
-
- def test_depends_on_working_directory(self):
- results = [
- manager.cache_key(
- working_directory=d,
+ def test_result_is_str(self):
+ result = manager.cache_key(
+ working_directory="/home/me",
arguments=["--logdir", "something"],
configure_kwargs={},
)
- for d in ("/home/me", "/home/you")
- ]
- self.assertEqual(len(results), len(set(results)))
-
- def test_depends_on_arguments(self):
- results = [
- manager.cache_key(
+ self.assertIsInstance(result, str)
+
+ def test_depends_on_working_directory(self):
+ results = [
+ manager.cache_key(
+ working_directory=d,
+ arguments=["--logdir", "something"],
+ configure_kwargs={},
+ )
+ for d in ("/home/me", "/home/you")
+ ]
+ self.assertEqual(len(results), len(set(results)))
+
+ def test_depends_on_arguments(self):
+ results = [
+ manager.cache_key(
+ working_directory="/home/me",
+ arguments=arguments,
+ configure_kwargs={},
+ )
+ for arguments in (
+ ["--logdir=something"],
+ ["--logdir", "something"],
+ ["--logdir", "", "something"],
+ ["--logdir", "", "something", ""],
+ )
+ ]
+ self.assertEqual(len(results), len(set(results)))
+
+ def test_depends_on_configure_kwargs(self):
+ results = [
+ manager.cache_key(
+ working_directory="/home/me",
+ arguments=[],
+ configure_kwargs=configure_kwargs,
+ )
+ for configure_kwargs in (
+ {"logdir": "something"},
+ {"logdir": "something_else"},
+ {"logdir": "something", "port": "6006"},
+ )
+ ]
+ self.assertEqual(len(results), len(set(results)))
+
+ def test_arguments_and_configure_kwargs_independent(self):
+ # This test documents current behavior; its existence shouldn't be
+ # interpreted as mandating the behavior. In fact, it would be nice
+ # for `arguments` and `configure_kwargs` to be semantically merged
+ # in the cache key computation, but we don't currently do that.
+ results = [
+ manager.cache_key(
+ working_directory="/home/me",
+ arguments=["--logdir", "something"],
+ configure_kwargs={},
+ ),
+ manager.cache_key(
+ working_directory="/home/me",
+ arguments=[],
+ configure_kwargs={"logdir": "something"},
+ ),
+ ]
+ self.assertEqual(len(results), len(set(results)))
+
+ def test_arguments_list_vs_tuple_irrelevant(self):
+ with_list = manager.cache_key(
working_directory="/home/me",
- arguments=arguments,
+ arguments=["--logdir", "something"],
configure_kwargs={},
)
- for arguments in (
- ["--logdir=something"],
- ["--logdir", "something"],
- ["--logdir", "", "something"],
- ["--logdir", "", "something", ""],
- )
- ]
- self.assertEqual(len(results), len(set(results)))
-
- def test_depends_on_configure_kwargs(self):
- results = [
- manager.cache_key(
- working_directory="/home/me",
- arguments=[],
- configure_kwargs=configure_kwargs,
- )
- for configure_kwargs in (
- {"logdir": "something"},
- {"logdir": "something_else"},
- {"logdir": "something", "port": "6006"},
- )
- ]
- self.assertEqual(len(results), len(set(results)))
-
- def test_arguments_and_configure_kwargs_independent(self):
- # This test documents current behavior; its existence shouldn't be
- # interpreted as mandating the behavior. In fact, it would be nice
- # for `arguments` and `configure_kwargs` to be semantically merged
- # in the cache key computation, but we don't currently do that.
- results = [
- manager.cache_key(
+ with_tuple = manager.cache_key(
working_directory="/home/me",
- arguments=["--logdir", "something"],
+ arguments=("--logdir", "something"),
configure_kwargs={},
- ),
- manager.cache_key(
- working_directory="/home/me",
- arguments=[],
- configure_kwargs={"logdir": "something"},
- ),
- ]
- self.assertEqual(len(results), len(set(results)))
-
- def test_arguments_list_vs_tuple_irrelevant(self):
- with_list = manager.cache_key(
- working_directory="/home/me",
- arguments=["--logdir", "something"],
- configure_kwargs={},
- )
- with_tuple = manager.cache_key(
- working_directory="/home/me",
- arguments=("--logdir", "something"),
- configure_kwargs={},
- )
- self.assertEqual(with_list, with_tuple)
+ )
+ self.assertEqual(with_list, with_tuple)
class TensorBoardInfoIoTest(tb_test.TestCase):
- """Tests for `write_info_file`, `remove_info_file`, and `get_all`."""
-
- def setUp(self):
- super(TensorBoardInfoIoTest, self).setUp()
- patcher = mock.patch.dict(os.environ, {"TMPDIR": self.get_temp_dir()})
- patcher.start()
- self.addCleanup(patcher.stop)
- tempfile.tempdir = None # force `gettempdir` to reinitialize from env
- self.info_dir = manager._get_info_dir() # ensure that directory exists
-
- def _list_info_dir(self):
- return os.listdir(self.info_dir)
-
- def assertMode(self, path, expected):
- """Assert that the permission bits of a file are as expected.
-
- Args:
- path: File to stat.
- expected: `int`; a subset of 0o777.
-
- Raises:
- AssertionError: If the permissions bits of `path` do not match
- `expected`.
- """
- stat_result = os.stat(path)
- format_mode = lambda m: "0o%03o" % m
- self.assertEqual(
- format_mode(stat_result.st_mode & 0o777),
- format_mode(expected),
- )
-
- def test_fails_if_info_dir_name_is_taken_by_a_regular_file(self):
- os.rmdir(self.info_dir)
- with open(self.info_dir, "w") as outfile:
- pass
- with self.assertRaises(OSError) as cm:
- manager._get_info_dir()
- self.assertEqual(cm.exception.errno, errno.EEXIST, cm.exception)
-
- @mock.patch("os.getpid", lambda: 76540)
- def test_directory_world_accessible(self):
- """Test that the TensorBoardInfo directory is world-accessible.
+ """Tests for `write_info_file`, `remove_info_file`, and `get_all`."""
+
+ def setUp(self):
+ super(TensorBoardInfoIoTest, self).setUp()
+ patcher = mock.patch.dict(os.environ, {"TMPDIR": self.get_temp_dir()})
+ patcher.start()
+ self.addCleanup(patcher.stop)
+ tempfile.tempdir = None # force `gettempdir` to reinitialize from env
+ self.info_dir = manager._get_info_dir() # ensure that directory exists
+
+ def _list_info_dir(self):
+ return os.listdir(self.info_dir)
+
+ def assertMode(self, path, expected):
+ """Assert that the permission bits of a file are as expected.
+
+ Args:
+ path: File to stat.
+ expected: `int`; a subset of 0o777.
+
+ Raises:
+ AssertionError: If the permissions bits of `path` do not match
+ `expected`.
+ """
+ stat_result = os.stat(path)
+ format_mode = lambda m: "0o%03o" % m
+ self.assertEqual(
+ format_mode(stat_result.st_mode & 0o777), format_mode(expected),
+ )
- Regression test for issue #2010:
-    https://github.com/tensorflow/tensorboard/issues/2010
- """
- if os.name == "nt":
- self.skipTest("Windows does not use POSIX-style permissions.")
- os.rmdir(self.info_dir)
- # The default umask is typically 0o022, in which case this test is
- # nontrivial. In the unlikely case that the umask is 0o000, we'll
- # still be covered by the "restrictive umask" test case below.
- manager.write_info_file(_make_info())
- self.assertMode(self.info_dir, 0o777)
- self.assertEqual(self._list_info_dir(), ["pid-76540.info"])
-
- @mock.patch("os.getpid", lambda: 76540)
- def test_writing_file_with_restrictive_umask(self):
- if os.name == "nt":
- self.skipTest("Windows does not use POSIX-style permissions.")
- os.rmdir(self.info_dir)
- # Even if umask prevents owner-access, our I/O should still work.
- old_umask = os.umask(0o777)
- try:
- # Sanity-check that, without special accommodation, this would
- # create inaccessible directories...
- sanity_dir = os.path.join(self.get_temp_dir(), "canary")
- os.mkdir(sanity_dir)
- self.assertMode(sanity_dir, 0o000)
-
- manager.write_info_file(_make_info())
- self.assertMode(self.info_dir, 0o777)
- self.assertEqual(self._list_info_dir(), ["pid-76540.info"])
- finally:
- self.assertEqual(oct(os.umask(old_umask)), oct(0o777))
-
- @mock.patch("os.getpid", lambda: 76540)
- def test_write_remove_info_file(self):
- info = _make_info()
- self.assertEqual(self._list_info_dir(), [])
- manager.write_info_file(info)
- filename = "pid-76540.info"
- expected_filepath = os.path.join(self.info_dir, filename)
- self.assertEqual(self._list_info_dir(), [filename])
- with open(expected_filepath) as infile:
- self.assertEqual(manager._info_from_string(infile.read()), info)
- manager.remove_info_file()
- self.assertEqual(self._list_info_dir(), [])
-
- def test_write_info_file_rejects_bad_types(self):
- # The particulars of validation are tested more thoroughly in
- # `TensorBoardInfoTest` above.
- bad_time = datetime.datetime.fromtimestamp(1549061116)
- info = _make_info()._replace(start_time=bad_time)
- with six.assertRaisesRegex(
- self,
- ValueError,
- r"expected 'start_time' of type.*int.*, but found: datetime\."):
- manager.write_info_file(info)
- self.assertEqual(self._list_info_dir(), [])
-
- def test_write_info_file_rejects_wrong_version(self):
- # The particulars of validation are tested more thoroughly in
- # `TensorBoardInfoTest` above.
- info = _make_info()._replace(version="reversion")
- with six.assertRaisesRegex(
- self,
- ValueError,
- "expected 'version' to be '.*', but found: 'reversion'"):
- manager.write_info_file(info)
- self.assertEqual(self._list_info_dir(), [])
-
- def test_remove_nonexistent(self):
- # Should be a no-op, except to create the info directory if
- # necessary. In particular, should not raise any exception.
- manager.remove_info_file()
-
- def test_get_all(self):
- def add_info(i):
- with mock.patch("os.getpid", lambda: 76540 + i):
- manager.write_info_file(_make_info(i))
- def remove_info(i):
- with mock.patch("os.getpid", lambda: 76540 + i):
+ def test_fails_if_info_dir_name_is_taken_by_a_regular_file(self):
+ os.rmdir(self.info_dir)
+ with open(self.info_dir, "w") as outfile:
+ pass
+ with self.assertRaises(OSError) as cm:
+ manager._get_info_dir()
+ self.assertEqual(cm.exception.errno, errno.EEXIST, cm.exception)
+
+ @mock.patch("os.getpid", lambda: 76540)
+ def test_directory_world_accessible(self):
+ """Test that the TensorBoardInfo directory is world-accessible.
+
+ Regression test for issue #2010:
+        https://github.com/tensorflow/tensorboard/issues/2010
+ """
+ if os.name == "nt":
+ self.skipTest("Windows does not use POSIX-style permissions.")
+ os.rmdir(self.info_dir)
+ # The default umask is typically 0o022, in which case this test is
+ # nontrivial. In the unlikely case that the umask is 0o000, we'll
+ # still be covered by the "restrictive umask" test case below.
+ manager.write_info_file(_make_info())
+ self.assertMode(self.info_dir, 0o777)
+ self.assertEqual(self._list_info_dir(), ["pid-76540.info"])
+
+ @mock.patch("os.getpid", lambda: 76540)
+ def test_writing_file_with_restrictive_umask(self):
+ if os.name == "nt":
+ self.skipTest("Windows does not use POSIX-style permissions.")
+ os.rmdir(self.info_dir)
+ # Even if umask prevents owner-access, our I/O should still work.
+ old_umask = os.umask(0o777)
+ try:
+ # Sanity-check that, without special accommodation, this would
+ # create inaccessible directories...
+ sanity_dir = os.path.join(self.get_temp_dir(), "canary")
+ os.mkdir(sanity_dir)
+ self.assertMode(sanity_dir, 0o000)
+
+ manager.write_info_file(_make_info())
+ self.assertMode(self.info_dir, 0o777)
+ self.assertEqual(self._list_info_dir(), ["pid-76540.info"])
+ finally:
+ self.assertEqual(oct(os.umask(old_umask)), oct(0o777))
+
+ @mock.patch("os.getpid", lambda: 76540)
+ def test_write_remove_info_file(self):
+ info = _make_info()
+ self.assertEqual(self._list_info_dir(), [])
+ manager.write_info_file(info)
+ filename = "pid-76540.info"
+ expected_filepath = os.path.join(self.info_dir, filename)
+ self.assertEqual(self._list_info_dir(), [filename])
+ with open(expected_filepath) as infile:
+ self.assertEqual(manager._info_from_string(infile.read()), info)
manager.remove_info_file()
- self.assertItemsEqual(manager.get_all(), [])
- add_info(1)
- self.assertItemsEqual(manager.get_all(), [_make_info(1)])
- add_info(2)
- self.assertItemsEqual(manager.get_all(), [_make_info(1), _make_info(2)])
- remove_info(1)
- self.assertItemsEqual(manager.get_all(), [_make_info(2)])
- add_info(3)
- self.assertItemsEqual(manager.get_all(), [_make_info(2), _make_info(3)])
- remove_info(3)
- self.assertItemsEqual(manager.get_all(), [_make_info(2)])
- remove_info(2)
- self.assertItemsEqual(manager.get_all(), [])
-
- def test_get_all_ignores_bad_files(self):
- with open(os.path.join(self.info_dir, "pid-1234.info"), "w") as outfile:
- outfile.write("good luck parsing this\n")
- with open(os.path.join(self.info_dir, "pid-5678.info"), "w") as outfile:
- outfile.write('{"valid_json":"yes","valid_tbinfo":"no"}\n')
- with open(os.path.join(self.info_dir, "pid-9012.info"), "w") as outfile:
- outfile.write('if a tbinfo has st_mode==0, does it make a sound?\n')
- os.chmod(os.path.join(self.info_dir, "pid-9012.info"), 0o000)
- with mock.patch.object(tb_logging.get_logger(), "debug") as fn:
- self.assertEqual(manager.get_all(), [])
- self.assertEqual(fn.call_count, 2) # 2 invalid, 1 unreadable (silent)
+ self.assertEqual(self._list_info_dir(), [])
+
+ def test_write_info_file_rejects_bad_types(self):
+ # The particulars of validation are tested more thoroughly in
+ # `TensorBoardInfoTest` above.
+ bad_time = datetime.datetime.fromtimestamp(1549061116)
+ info = _make_info()._replace(start_time=bad_time)
+ with six.assertRaisesRegex(
+ self,
+ ValueError,
+ r"expected 'start_time' of type.*int.*, but found: datetime\.",
+ ):
+ manager.write_info_file(info)
+ self.assertEqual(self._list_info_dir(), [])
+
+ def test_write_info_file_rejects_wrong_version(self):
+ # The particulars of validation are tested more thoroughly in
+ # `TensorBoardInfoTest` above.
+ info = _make_info()._replace(version="reversion")
+ with six.assertRaisesRegex(
+ self,
+ ValueError,
+ "expected 'version' to be '.*', but found: 'reversion'",
+ ):
+ manager.write_info_file(info)
+ self.assertEqual(self._list_info_dir(), [])
+
+ def test_remove_nonexistent(self):
+ # Should be a no-op, except to create the info directory if
+ # necessary. In particular, should not raise any exception.
+ manager.remove_info_file()
+
+ def test_get_all(self):
+ def add_info(i):
+ with mock.patch("os.getpid", lambda: 76540 + i):
+ manager.write_info_file(_make_info(i))
+
+ def remove_info(i):
+ with mock.patch("os.getpid", lambda: 76540 + i):
+ manager.remove_info_file()
+
+ self.assertItemsEqual(manager.get_all(), [])
+ add_info(1)
+ self.assertItemsEqual(manager.get_all(), [_make_info(1)])
+ add_info(2)
+ self.assertItemsEqual(manager.get_all(), [_make_info(1), _make_info(2)])
+ remove_info(1)
+ self.assertItemsEqual(manager.get_all(), [_make_info(2)])
+ add_info(3)
+ self.assertItemsEqual(manager.get_all(), [_make_info(2), _make_info(3)])
+ remove_info(3)
+ self.assertItemsEqual(manager.get_all(), [_make_info(2)])
+ remove_info(2)
+ self.assertItemsEqual(manager.get_all(), [])
+
+ def test_get_all_ignores_bad_files(self):
+ with open(os.path.join(self.info_dir, "pid-1234.info"), "w") as outfile:
+ outfile.write("good luck parsing this\n")
+ with open(os.path.join(self.info_dir, "pid-5678.info"), "w") as outfile:
+ outfile.write('{"valid_json":"yes","valid_tbinfo":"no"}\n')
+ with open(os.path.join(self.info_dir, "pid-9012.info"), "w") as outfile:
+ outfile.write("if a tbinfo has st_mode==0, does it make a sound?\n")
+ os.chmod(os.path.join(self.info_dir, "pid-9012.info"), 0o000)
+ with mock.patch.object(tb_logging.get_logger(), "debug") as fn:
+ self.assertEqual(manager.get_all(), [])
+ self.assertEqual(fn.call_count, 2) # 2 invalid, 1 unreadable (silent)
if __name__ == "__main__":
- tb_test.main()
+ tb_test.main()
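
The `TensorBoardInfoIoTest` cases above revolve around the info-file protocol: each running TensorBoard writes a `pid-<pid>.info` file describing itself into a world-accessible directory under the system temp dir, and `get_all` returns every entry that still parses. A minimal sketch of that round trip, using only functions covered by the tests above; the concrete field values are placeholders mirroring `_make_info`.

import os

from tensorboard import manager
from tensorboard import version

# Describe this process as if it were a TensorBoard server.
info = manager.TensorBoardInfo(
    version=version.VERSION,
    start_time=1548973541,  # placeholder; must be an int (seconds since epoch)
    port=6060,
    pid=os.getpid(),
    path_prefix="/foo",
    logdir="~/my_data/",
    db="",
    cache_key="asdf",
)
manager.write_info_file(info)  # writes pid-<os.getpid()>.info
assert info in manager.get_all()  # now visible to other clients
manager.remove_info_file()  # removes this process's entry, if any
assert info not in manager.get_all()
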
diff --git a/tensorboard/notebook.py b/tensorboard/notebook.py
index 05c0c304e2..406aad3a76 100644
--- a/tensorboard/notebook.py
+++ b/tensorboard/notebook.py
@@ -30,13 +30,15 @@
import time
try:
- import html
- html_escape = html.escape
- del html
+ import html
+
+ html_escape = html.escape
+ del html
except ImportError:
- import cgi
- html_escape = cgi.escape
- del cgi
+ import cgi
+
+ html_escape = cgi.escape
+ del cgi
from tensorboard import manager
@@ -49,304 +51,299 @@
def _get_context():
- """Determine the most specific context that we're in.
-
- Returns:
- _CONTEXT_COLAB: If in Colab with an IPython notebook context.
- _CONTEXT_IPYTHON: If not in Colab, but we are in an IPython notebook
- context (e.g., from running `jupyter notebook` at the command
- line).
- _CONTEXT_NONE: Otherwise (e.g., by running a Python script at the
- command-line or using the `ipython` interactive shell).
- """
- # In Colab, the `google.colab` module is available, but the shell
- # returned by `IPython.get_ipython` does not have a `get_trait`
- # method.
- try:
- import google.colab
- import IPython
- except ImportError:
- pass
- else:
- if IPython.get_ipython() is not None:
- # We'll assume that we're in a Colab notebook context.
- return _CONTEXT_COLAB
-
- # In an IPython command line shell or Jupyter notebook, we can
- # directly query whether we're in a notebook context.
- try:
- import IPython
- except ImportError:
- pass
- else:
- ipython = IPython.get_ipython()
- if ipython is not None and ipython.has_trait("kernel"):
- return _CONTEXT_IPYTHON
-
- # Otherwise, we're not in a known notebook context.
- return _CONTEXT_NONE
+ """Determine the most specific context that we're in.
+
+ Returns:
+ _CONTEXT_COLAB: If in Colab with an IPython notebook context.
+ _CONTEXT_IPYTHON: If not in Colab, but we are in an IPython notebook
+ context (e.g., from running `jupyter notebook` at the command
+ line).
+ _CONTEXT_NONE: Otherwise (e.g., by running a Python script at the
+ command-line or using the `ipython` interactive shell).
+ """
+ # In Colab, the `google.colab` module is available, but the shell
+ # returned by `IPython.get_ipython` does not have a `get_trait`
+ # method.
+ try:
+ import google.colab
+ import IPython
+ except ImportError:
+ pass
+ else:
+ if IPython.get_ipython() is not None:
+ # We'll assume that we're in a Colab notebook context.
+ return _CONTEXT_COLAB
+
+ # In an IPython command line shell or Jupyter notebook, we can
+ # directly query whether we're in a notebook context.
+ try:
+ import IPython
+ except ImportError:
+ pass
+ else:
+ ipython = IPython.get_ipython()
+ if ipython is not None and ipython.has_trait("kernel"):
+ return _CONTEXT_IPYTHON
+
+ # Otherwise, we're not in a known notebook context.
+ return _CONTEXT_NONE
def load_ipython_extension(ipython):
- """Deprecated: use `%load_ext tensorboard` instead.
+ """Deprecated: use `%load_ext tensorboard` instead.
Raises:
RuntimeError: Always.
"""
- raise RuntimeError(
- "Use '%load_ext tensorboard' instead of '%load_ext tensorboard.notebook'."
- )
+ raise RuntimeError(
+ "Use '%load_ext tensorboard' instead of '%load_ext tensorboard.notebook'."
+ )
def _load_ipython_extension(ipython):
- """Load the TensorBoard notebook extension.
+ """Load the TensorBoard notebook extension.
- Intended to be called from `%load_ext tensorboard`. Do not invoke this
- directly.
+ Intended to be called from `%load_ext tensorboard`. Do not invoke this
+ directly.
- Args:
- ipython: An `IPython.InteractiveShell` instance.
- """
- _register_magics(ipython)
+ Args:
+ ipython: An `IPython.InteractiveShell` instance.
+ """
+ _register_magics(ipython)
def _register_magics(ipython):
- """Register IPython line/cell magics.
+ """Register IPython line/cell magics.
- Args:
- ipython: An `InteractiveShell` instance.
- """
- ipython.register_magic_function(
- _start_magic,
- magic_kind="line",
- magic_name="tensorboard",
- )
+ Args:
+ ipython: An `InteractiveShell` instance.
+ """
+ ipython.register_magic_function(
+ _start_magic, magic_kind="line", magic_name="tensorboard",
+ )
def _start_magic(line):
- """Implementation of the `%tensorboard` line magic."""
- return start(line)
+ """Implementation of the `%tensorboard` line magic."""
+ return start(line)
def start(args_string):
- """Launch and display a TensorBoard instance as if at the command line.
-
- Args:
- args_string: Command-line arguments to TensorBoard, to be
- interpreted by `shlex.split`: e.g., "--logdir ./logs --port 0".
- Shell metacharacters are not supported: e.g., "--logdir 2>&1" will
- point the logdir at the literal directory named "2>&1".
- """
- context = _get_context()
- try:
- import IPython
- import IPython.display
- except ImportError:
- IPython = None
-
- if context == _CONTEXT_NONE:
- handle = None
- print("Launching TensorBoard...")
- else:
- handle = IPython.display.display(
- IPython.display.Pretty("Launching TensorBoard..."),
- display_id=True,
- )
-
- def print_or_update(message):
- if handle is None:
- print(message)
+ """Launch and display a TensorBoard instance as if at the command line.
+
+ Args:
+ args_string: Command-line arguments to TensorBoard, to be
+ interpreted by `shlex.split`: e.g., "--logdir ./logs --port 0".
+ Shell metacharacters are not supported: e.g., "--logdir 2>&1" will
+ point the logdir at the literal directory named "2>&1".
+ """
+ context = _get_context()
+ try:
+ import IPython
+ import IPython.display
+ except ImportError:
+ IPython = None
+
+ if context == _CONTEXT_NONE:
+ handle = None
+ print("Launching TensorBoard...")
else:
- handle.update(IPython.display.Pretty(message))
+ handle = IPython.display.display(
+ IPython.display.Pretty("Launching TensorBoard..."), display_id=True,
+ )
- parsed_args = shlex.split(args_string, comments=True, posix=True)
- start_result = manager.start(parsed_args)
+ def print_or_update(message):
+ if handle is None:
+ print(message)
+ else:
+ handle.update(IPython.display.Pretty(message))
- if isinstance(start_result, manager.StartLaunched):
- _display(
- port=start_result.info.port,
- print_message=False,
- display_handle=handle,
- )
+ parsed_args = shlex.split(args_string, comments=True, posix=True)
+ start_result = manager.start(parsed_args)
- elif isinstance(start_result, manager.StartReused):
- template = (
- "Reusing TensorBoard on port {port} (pid {pid}), started {delta} ago. "
- "(Use '!kill {pid}' to kill it.)"
- )
- message = template.format(
- port=start_result.info.port,
- pid=start_result.info.pid,
- delta=_time_delta_from_info(start_result.info),
- )
- print_or_update(message)
- _display(
- port=start_result.info.port,
- print_message=False,
- display_handle=None,
- )
+ if isinstance(start_result, manager.StartLaunched):
+ _display(
+ port=start_result.info.port,
+ print_message=False,
+ display_handle=handle,
+ )
- elif isinstance(start_result, manager.StartFailed):
- def format_stream(name, value):
- if value == "":
- return ""
- elif value is None:
-        return "\n<could not read %s>" % name
- else:
- return "\nContents of %s:\n%s" % (name, value.strip())
- message = (
- "ERROR: Failed to launch TensorBoard (exited with %d).%s%s" %
- (
- start_result.exit_code,
- format_stream("stderr", start_result.stderr),
- format_stream("stdout", start_result.stdout),
+ elif isinstance(start_result, manager.StartReused):
+ template = (
+ "Reusing TensorBoard on port {port} (pid {pid}), started {delta} ago. "
+ "(Use '!kill {pid}' to kill it.)"
)
- )
- print_or_update(message)
+ message = template.format(
+ port=start_result.info.port,
+ pid=start_result.info.pid,
+ delta=_time_delta_from_info(start_result.info),
+ )
+ print_or_update(message)
+ _display(
+ port=start_result.info.port,
+ print_message=False,
+ display_handle=None,
+ )
+
+ elif isinstance(start_result, manager.StartFailed):
+
+ def format_stream(name, value):
+ if value == "":
+ return ""
+ elif value is None:
+                return "\n<could not read %s>" % name
+ else:
+ return "\nContents of %s:\n%s" % (name, value.strip())
+
+ message = (
+ "ERROR: Failed to launch TensorBoard (exited with %d).%s%s"
+ % (
+ start_result.exit_code,
+ format_stream("stderr", start_result.stderr),
+ format_stream("stdout", start_result.stdout),
+ )
+ )
+ print_or_update(message)
- elif isinstance(start_result, manager.StartExecFailed):
- the_tensorboard_binary = (
- "%r (set by the `TENSORBOARD_BINARY` environment variable)"
+ elif isinstance(start_result, manager.StartExecFailed):
+ the_tensorboard_binary = (
+ "%r (set by the `TENSORBOARD_BINARY` environment variable)"
% (start_result.explicit_binary,)
- if start_result.explicit_binary is not None
- else "`tensorboard`"
- )
- if start_result.os_error.errno == errno.ENOENT:
- message = (
- "ERROR: Could not find %s. Please ensure that your PATH contains "
- "an executable `tensorboard` program, or explicitly specify the path "
- "to a TensorBoard binary by setting the `TENSORBOARD_BINARY` "
- "environment variable."
- % (the_tensorboard_binary,)
- )
- else:
- message = (
- "ERROR: Failed to start %s: %s"
- % (the_tensorboard_binary, start_result.os_error)
- )
- print_or_update(textwrap.fill(message))
-
- elif isinstance(start_result, manager.StartTimedOut):
- message = (
- "ERROR: Timed out waiting for TensorBoard to start. "
- "It may still be running as pid %d."
- % start_result.pid
- )
- print_or_update(message)
+ if start_result.explicit_binary is not None
+ else "`tensorboard`"
+ )
+ if start_result.os_error.errno == errno.ENOENT:
+ message = (
+ "ERROR: Could not find %s. Please ensure that your PATH contains "
+ "an executable `tensorboard` program, or explicitly specify the path "
+ "to a TensorBoard binary by setting the `TENSORBOARD_BINARY` "
+ "environment variable." % (the_tensorboard_binary,)
+ )
+ else:
+ message = "ERROR: Failed to start %s: %s" % (
+ the_tensorboard_binary,
+ start_result.os_error,
+ )
+ print_or_update(textwrap.fill(message))
+
+ elif isinstance(start_result, manager.StartTimedOut):
+ message = (
+ "ERROR: Timed out waiting for TensorBoard to start. "
+ "It may still be running as pid %d." % start_result.pid
+ )
+ print_or_update(message)
- else:
- raise TypeError(
- "Unexpected result from `manager.start`: %r.\n"
- "This is a TensorBoard bug; please report it."
- % start_result
- )
+ else:
+ raise TypeError(
+ "Unexpected result from `manager.start`: %r.\n"
+ "This is a TensorBoard bug; please report it." % start_result
+ )
def _time_delta_from_info(info):
- """Format the elapsed time for the given TensorBoardInfo.
+ """Format the elapsed time for the given TensorBoardInfo.
- Args:
- info: A TensorBoardInfo value.
+ Args:
+ info: A TensorBoardInfo value.
- Returns:
- A human-readable string describing the time since the server
- described by `info` started: e.g., "2 days, 0:48:58".
- """
- delta_seconds = int(time.time()) - info.start_time
- return str(datetime.timedelta(seconds=delta_seconds))
+ Returns:
+ A human-readable string describing the time since the server
+ described by `info` started: e.g., "2 days, 0:48:58".
+ """
+ delta_seconds = int(time.time()) - info.start_time
+ return str(datetime.timedelta(seconds=delta_seconds))
def display(port=None, height=None):
- """Display a TensorBoard instance already running on this machine.
-
- Args:
- port: The port on which the TensorBoard server is listening, as an
- `int`, or `None` to automatically select the most recently
- launched TensorBoard.
- height: The height of the frame into which to render the TensorBoard
- UI, as an `int` number of pixels, or `None` to use a default value
- (currently 800).
- """
- _display(port=port, height=height, print_message=True, display_handle=None)
+ """Display a TensorBoard instance already running on this machine.
+ Args:
+ port: The port on which the TensorBoard server is listening, as an
+ `int`, or `None` to automatically select the most recently
+ launched TensorBoard.
+ height: The height of the frame into which to render the TensorBoard
+ UI, as an `int` number of pixels, or `None` to use a default value
+ (currently 800).
+ """
+ _display(port=port, height=height, print_message=True, display_handle=None)
-def _display(port=None, height=None, print_message=False, display_handle=None):
- """Internal version of `display`.
-
- Args:
- port: As with `display`.
- height: As with `display`.
- print_message: True to print which TensorBoard instance was selected
- for display (if applicable), or False otherwise.
- display_handle: If not None, an IPython display handle into which to
- render TensorBoard.
- """
- if height is None:
- height = 800
-
- if port is None:
- infos = manager.get_all()
- if not infos:
- raise ValueError("Can't display TensorBoard: no known instances running.")
- else:
- info = max(manager.get_all(), key=lambda x: x.start_time)
- port = info.port
- else:
- infos = [i for i in manager.get_all() if i.port == port]
- info = (
- max(infos, key=lambda x: x.start_time)
- if infos
- else None
- )
- if print_message:
- if info is not None:
- message = (
- "Selecting TensorBoard with {data_source} "
- "(started {delta} ago; port {port}, pid {pid})."
- ).format(
- data_source=manager.data_source_from_info(info),
- delta=_time_delta_from_info(info),
- port=info.port,
- pid=info.pid,
- )
- print(message)
+def _display(port=None, height=None, print_message=False, display_handle=None):
+ """Internal version of `display`.
+
+ Args:
+ port: As with `display`.
+ height: As with `display`.
+ print_message: True to print which TensorBoard instance was selected
+ for display (if applicable), or False otherwise.
+ display_handle: If not None, an IPython display handle into which to
+ render TensorBoard.
+ """
+ if height is None:
+ height = 800
+
+ if port is None:
+ infos = manager.get_all()
+ if not infos:
+ raise ValueError(
+ "Can't display TensorBoard: no known instances running."
+ )
+ else:
+ info = max(manager.get_all(), key=lambda x: x.start_time)
+ port = info.port
else:
- # The user explicitly provided a port, and we don't have any
- # additional information. There's nothing useful to say.
- pass
-
- fn = {
- _CONTEXT_COLAB: _display_colab,
- _CONTEXT_IPYTHON: _display_ipython,
- _CONTEXT_NONE: _display_cli,
- }[_get_context()]
- return fn(port=port, height=height, display_handle=display_handle)
+ infos = [i for i in manager.get_all() if i.port == port]
+ info = max(infos, key=lambda x: x.start_time) if infos else None
+
+ if print_message:
+ if info is not None:
+ message = (
+ "Selecting TensorBoard with {data_source} "
+ "(started {delta} ago; port {port}, pid {pid})."
+ ).format(
+ data_source=manager.data_source_from_info(info),
+ delta=_time_delta_from_info(info),
+ port=info.port,
+ pid=info.pid,
+ )
+ print(message)
+ else:
+ # The user explicitly provided a port, and we don't have any
+ # additional information. There's nothing useful to say.
+ pass
+
+ fn = {
+ _CONTEXT_COLAB: _display_colab,
+ _CONTEXT_IPYTHON: _display_ipython,
+ _CONTEXT_NONE: _display_cli,
+ }[_get_context()]
+ return fn(port=port, height=height, display_handle=display_handle)
def _display_colab(port, height, display_handle):
- """Display a TensorBoard instance in a Colab output frame.
-
- The Colab VM is not directly exposed to the network, so the Colab
- runtime provides a service worker tunnel to proxy requests from the
- end user's browser through to servers running on the Colab VM: the
-  output frame may issue requests to https://localhost:<port> (HTTPS
- only), which will be forwarded to the specified port on the VM.
-
- It does not suffice to create an `iframe` and let the service worker
- redirect its traffic (`