diff --git a/tensorboard/BUILD b/tensorboard/BUILD
index 4ce94840ae..912ff3f22d 100644
--- a/tensorboard/BUILD
+++ b/tensorboard/BUILD
@@ -69,17 +69,47 @@ py_library(
     ],
 )
 
+py_library(
+    name = "manager",
+    srcs = ["manager.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorboard:internal"],
+    deps = [
+        ":version",
+        "//tensorboard/util:tb_logging",
+        "@org_pythonhosted_six",
+    ],
+)
+
+py_test(
+    name = "manager_test",
+    size = "small",
+    srcs = ["manager_test.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorboard:internal"],
+    deps = [
+        ":manager",
+        ":version",
+        "//tensorboard:expect_tensorflow_installed",
+        "//tensorboard/util:tb_logging",
+        "@org_pythonhosted_mock",
+        "@org_pythonhosted_six",
+    ],
+)
+
 py_library(
     name = "program",
     srcs = ["program.py"],
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
+        ":manager",
         ":version",
         "//tensorboard/backend:application",
         "//tensorboard/backend/event_processing:event_file_inspector",
         "//tensorboard/util",
         "@org_pocoo_werkzeug",
+        "@org_pythonhosted_six",
     ],
 )
 
diff --git a/tensorboard/manager.py b/tensorboard/manager.py
new file mode 100644
index 0000000000..c15d48fff7
--- /dev/null
+++ b/tensorboard/manager.py
@@ -0,0 +1,285 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Private utilities for managing multiple TensorBoard processes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import base64
+import collections
+import datetime
+import errno
+import json
+import os
+import subprocess
+import tempfile
+import time
+
+import six
+
+from tensorboard import version
+from tensorboard.util import tb_logging
+
+
+# Information about a running TensorBoard instance.
+TensorboardInfo = collections.namedtuple(
+    "TensorboardInfo",
+    (
+        "version",  # tensorboard.version.VERSION
+        "start_time",  # datetime.datetime (microseconds ignored)
+        "pid",
+        "port",
+        "path_prefix",  # str (maybe empty)
+        "logdir",  # str (maybe empty)
+        "db",  # str (maybe empty)
+        "cache_key",  # opaque string
+    ),
+)
+
+
+def _info_to_string(info):
+  """Convert a `TensorboardInfo` to string form to be stored on disk.
+
+  The format returned by this function is opaque and should only be
+  interpreted by `_info_from_string`.
+
+  Args:
+    info: A valid `TensorboardInfo` object.
+
+  Raises:
+    ValueError: If any field on `info` is not of the correct type, or if
+      `info.version` is not `tensorboard.version.VERSION`.
+
+  Returns:
+    A string representation of the provided `TensorboardInfo`.
+  """
+  field_types = {
+      "version": str,
+      "start_time": datetime.datetime,
+      "pid": int,
+      "port": int,
+      "path_prefix": str,
+      "logdir": str,
+      "db": str,
+      "cache_key": str,
+  }
+  assert frozenset(field_types) == frozenset(TensorboardInfo._fields)
+  for key in field_types:
+    if not isinstance(getattr(info, key), field_types[key]):
+      raise ValueError(
+          "expected %r of type %s, but found: %r" %
+          (key, field_types[key], getattr(info, key))
+      )
+  if info.version != version.VERSION:
+    raise ValueError(
+        "expected 'version' to be %r, but found: %r" %
+        (version.VERSION, info.version)
+    )
+  json_value = info._asdict()
+  # `strftime("%s")` is a platform-specific extension that is not
+  # available everywhere (e.g., Windows). `time.mktime` is the portable
+  # inverse of the `fromtimestamp` call in `_info_from_string`.
+  json_value["start_time"] = int(time.mktime(info.start_time.timetuple()))
+  return json.dumps(json_value, sort_keys=True, indent=4)
+
+
+def _info_from_string(info_string):
+  """Parse a `TensorboardInfo` object from its string representation.
+
+  Args:
+    info_string: A string representation of a `TensorboardInfo`, as
+      produced by a previous call to `_info_to_string`.
+
+  Returns:
+    A `TensorboardInfo` value.
+
+  Raises:
+    ValueError: If the provided string is not valid JSON, or if it does
+      not represent a JSON object with a "version" field whose value is
+      `tensorboard.version.VERSION`, or if it has the wrong set of
+      fields, or if at least one field is of invalid type.
+  """
+  json_value = json.loads(info_string)  # may raise ValueError
+  if not isinstance(json_value, dict):
+    raise ValueError("not a JSON object: %r" % (json_value,))
+  if json_value.get("version") != version.VERSION:
+    raise ValueError("incompatible version: %r" % (json_value,))
+
+  field_types = {
+      "version": six.text_type,
+      "start_time": int,
+      "pid": int,
+      "port": int,
+      "path_prefix": six.text_type,
+      "logdir": six.text_type,
+      "db": six.text_type,
+      "cache_key": six.text_type,
+  }
+  assert frozenset(field_types) == frozenset(TensorboardInfo._fields)
+
+  # Check the key set before looking at any values, so that a missing
+  # field is reported as the documented `ValueError`, not a `KeyError`.
+  expected_keys = frozenset(field_types)
+  actual_keys = frozenset(json_value)
+  if expected_keys != actual_keys:
+    raise ValueError(
+        "bad keys on TensorboardInfo (missing: %s; extraneous: %s)"
+        % (expected_keys - actual_keys, actual_keys - expected_keys)
+    )
+
+  for key in field_types:
+    if not isinstance(json_value[key], field_types[key]):
+      raise ValueError(
+          "expected %r of type %s, but found: %r" %
+          (key, field_types[key], json_value[key])
+      )
+    if field_types[key] is six.text_type:
+      # Python 2 compatibility kludge.
+      json_value[key] = str(json_value[key])
+
+  json_value["start_time"] = (
+      datetime.datetime.fromtimestamp(json_value["start_time"])
+  )
+  return TensorboardInfo(**json_value)
+
+
+def _get_info_dir():
+  """Get path to directory in which to store info files.
+
+  The directory will be created if it does not exist.
+  """
+  path = os.path.join(tempfile.gettempdir(), ".tensorboard-info")
+  try:
+    os.mkdir(path)
+  except OSError as e:
+    # Several TensorBoard processes may start at once, so checking for
+    # existence before creating would be racy. Instead, attempt the
+    # creation and tolerate the directory already being there.
+    if e.errno != errno.EEXIST or not os.path.isdir(path):
+      raise
+  return path
+
+
+def _info_file_path():
+  """Get path to info file for the current process."""
+  return os.path.join(_get_info_dir(), "pid-%d.info" % os.getpid())
+
+
+def write_info_file(tensorboard_info):
+  """Write TensorboardInfo to the current process's info file.
+
+  This should be called by `main` once the server is ready. When the
+  server shuts down, `remove_info_file` should be called.
+
+  Args:
+    tensorboard_info: A valid `TensorboardInfo` object.
+
+  Raises:
+    ValueError: If any field on `tensorboard_info` is not of the correct
+      type.
+  """
+  payload = "%s\n" % _info_to_string(tensorboard_info)
+  with open(_info_file_path(), "w") as outfile:
+    outfile.write(payload)
+
+
+def remove_info_file():
+  """Remove the current process's TensorboardInfo file, if it exists.
+
+  If the file does not exist, no action is taken and no error is raised.
+  """
+  try:
+    os.unlink(_info_file_path())
+  except OSError as e:
+    if e.errno == errno.ENOENT:
+      # The user may have wiped their temporary directory or something.
+      # Not a problem: we're already in the state that we want to be in.
+      pass
+    else:
+      raise
+
+
+def cache_key(working_directory, arguments):
+  """Compute a `TensorboardInfo.cache_key` field.
+
+  Args:
+    working_directory: The directory from which TensorBoard was launched
+      and relative to which paths like `--logdir` and `--db` are
+      resolved.
+    arguments: The command-line args to TensorBoard: `sys.argv[1:]`.
+      Should be a list (or tuple), not an unparsed string. If you have a
+      raw shell command, use `shlex.split` before passing it to this
+      function.
+
+  Returns:
+    A string such that if two (prospective or actual) TensorBoard
+    invocations have the same cache key then it is safe to use one in
+    place of the other.
+
+  Raises:
+    TypeError: If `arguments` is not a list or tuple.
+  """
+  if not isinstance(arguments, (list, tuple)):
+    raise TypeError(
+        "'arguments' should be a list of arguments, but found: %r "
+        "(use `shlex.split` if given a string)"
+        % (arguments,)
+    )
+  datum = {"working_directory": working_directory, "arguments": arguments}
+  raw = base64.b64encode(
+      json.dumps(datum, sort_keys=True, separators=(",", ":")).encode("utf-8")
+  )
+  # `b64encode` returns `bytes`, but `TensorboardInfo.cache_key` must be
+  # a `str` on both Python 2 and 3. The encoded value is pure ASCII.
+  return str(raw.decode("ascii"))
+
+
+def get_all():
+  """Return TensorboardInfo values for running TensorBoard processes.
+
+  This function may not provide a perfect snapshot of the set of running
+  processes. Its result set may be incomplete if the user has cleaned
+  their /tmp/ directory while TensorBoard processes are running. It may
+  contain extraneous entries if TensorBoard processes exited uncleanly
+  (e.g., with SIGKILL).
+
+  Returns:
+    A list of `TensorboardInfo` objects.
+  """
+  info_dir = _get_info_dir()
+  results = []
+  for filename in os.listdir(info_dir):
+    filepath = os.path.join(info_dir, filename)
+    try:
+      with open(filepath) as infile:
+        contents = infile.read()
+    except IOError as e:
+      if e.errno == errno.ENOENT:
+        # The file vanished after the directory was listed (e.g., the
+        # process that owned it exited). Skip it.
+        continue
+      raise
+    try:
+      info = _info_from_string(contents)
+    except ValueError:
+      tb_logging.get_logger().warning(
+          "invalid info file: %r",
+          filepath,
+          exc_info=True,
+      )
+    else:
+      results.append(info)
+  return results
diff --git a/tensorboard/manager_test.py b/tensorboard/manager_test.py
new file mode 100644
index 0000000000..293882faf0
--- /dev/null
+++ b/tensorboard/manager_test.py
@@ -0,0 +1,175 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for `tensorboard.manager`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import datetime
+import json
+import os
+import tempfile
+
+import six
+import tensorflow as tf
+
+try:
+  # python version >= 3.3
+  from unittest import mock  # pylint: disable=g-import-not-at-top
+except ImportError:
+  import mock  # pylint: disable=g-import-not-at-top,unused-import
+
+from tensorboard import manager
+from tensorboard import version
+from tensorboard.util import tb_logging
+
+
+class ManagerTest(tf.test.TestCase):
+  def setUp(self):
+    super(ManagerTest, self).setUp()
+    self.info_dir = os.path.join(self.get_temp_dir(), ".test-tensorboard-info")
+    os.mkdir(self.info_dir)
+    patcher = mock.patch(
+        "tensorboard.manager._get_info_dir",
+        lambda: self.info_dir,
+    )
+    patcher.start()
+    self.addCleanup(patcher.stop)
+
+  def _make_info(self, i=0):
+    return manager.TensorboardInfo(
+        version=version.VERSION,
+        start_time=datetime.datetime.fromtimestamp(1548973541 + i),
+        port=6060 + i,
+        pid=76540 + i,
+        path_prefix="/foo",
+        logdir="~/my_data/",
+        db="",
+        cache_key="asdf",
+    )
+
+  @mock.patch("os.getpid", lambda: 76540)
+  def test_write_remove_info_file(self):
+    info = self._make_info()
+    manager.write_info_file(info)
+    filename = "pid-76540.info"
+    expected_filepath = os.path.join(self.info_dir, filename)
+    self.assertEqual(os.listdir(self.info_dir), [filename])
+    with open(expected_filepath) as infile:
+      self.assertEqual(manager._info_from_string(infile.read()), info)
+    manager.remove_info_file()
+    self.assertEqual(os.listdir(self.info_dir), [])
+
+  def test_write_info_file_rejects_bad_types(self):
+    info = self._make_info()._replace(start_time=1549061116)
+    with six.assertRaisesRegex(
+        self,
+        ValueError,
+        "expected 'start_time' of type.*datetime.*, but found: 1549061116",
+    ):
+      manager.write_info_file(info)
+    self.assertEqual(os.listdir(self.info_dir), [])
+
+  def test_write_info_file_rejects_wrong_version(self):
+    info = self._make_info()._replace(version="reversion")
+    with six.assertRaisesRegex(
+        self,
+        ValueError,
+        "expected 'version' to be '.*', but found: 'reversion'",
+    ):
+      manager.write_info_file(info)
+    self.assertEqual(os.listdir(self.info_dir), [])
+
+  def test_tensorboardinfo_serde_roundtrip(self):
+    # This is also tested as part of integration tests below.
+    info = self._make_info()
+    also_info = manager._info_from_string(manager._info_to_string(info))
+    self.assertEqual(also_info, info)
+
+  def test_remove_nonexistent(self):
+    # Should be a no-op, except to create the info directory.
+    manager.remove_info_file()
+
+  def test_cache_key_differs_by_working_directory(self):
+    results = [
+        manager.cache_key(
+            working_directory=d,
+            arguments=["--logdir", "something"],
+        )
+        for d in ("/home/me", "/home/you")
+    ]
+    self.assertEqual(len(results), len(set(results)))
+
+  def test_cache_key_differs_by_arguments(self):
+    results = [
+        manager.cache_key(
+            working_directory="/home/me",
+            arguments=arguments,
+        )
+        for arguments in (
+            ["--logdir=something"],
+            ["--logdir", "something"],
+            ["--logdir", "", "something"],
+            ["--logdir", "", "something", ""],
+        )
+    ]
+    self.assertEqual(len(results), len(set(results)))
+
+  def test_cache_key_rejects_string_arguments(self):
+    with six.assertRaisesRegex(self, TypeError, "should be a list"):
+      manager.cache_key(
+          working_directory="/home/me",
+          arguments="--logdir=something",
+      )
+
+  def test_get_all(self):
+
+    def add_info(i):
+      with mock.patch("os.getpid", lambda: 76540 + i):
+        manager.write_info_file(self._make_info(i))
+
+    def remove_info(i):
+      with mock.patch("os.getpid", lambda: 76540 + i):
+        manager.remove_info_file()
+
+    make_info = self._make_info
+
+    self.assertItemsEqual(manager.get_all(), [])
+    add_info(1)
+    self.assertItemsEqual(manager.get_all(), [make_info(1)])
+    add_info(2)
+    self.assertItemsEqual(manager.get_all(), [make_info(1), make_info(2)])
+    remove_info(1)
+    self.assertItemsEqual(manager.get_all(), [make_info(2)])
+    add_info(3)
+    self.assertItemsEqual(manager.get_all(), [make_info(2), make_info(3)])
+    remove_info(3)
+    self.assertItemsEqual(manager.get_all(), [make_info(2)])
+    remove_info(2)
+    self.assertItemsEqual(manager.get_all(), [])
+
+  def test_get_all_ignores_bad_files(self):
+    with open(os.path.join(self.info_dir, "pid-1234.info"), "w") as outfile:
+      outfile.write("good luck parsing this\n")
+    with open(os.path.join(self.info_dir, "pid-5678.info"), "w") as outfile:
+      outfile.write('{"valid_json":"yes","valid_tbinfo":"no"}\n')
+    with mock.patch.object(tb_logging.get_logger(), "warning") as fn:
+      self.assertEqual(manager.get_all(), [])
+    self.assertEqual(fn.call_count, 2)
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorboard/program.py b/tensorboard/program.py
index 96f84ac4f9..efbeaedd62 100644
--- a/tensorboard/program.py
+++ b/tensorboard/program.py
@@ -32,16 +32,23 @@
 from abc import ABCMeta
 from abc import abstractmethod
 import argparse
+import atexit
+import datetime
 from collections import defaultdict
 import errno
 import os
+import signal
 import socket
 import sys
+import time
 import threading
 import inspect
 
+import six
+from six.moves import urllib
 from werkzeug import serving
 
+from tensorboard import manager
 from tensorboard import version
 from tensorboard.backend import application
 from tensorboard.backend.event_processing import event_file_inspector as efi
@@ -200,6 +207,7 @@ def main(self, ignored_argv=('',)):
 
     :rtype: int
     """
+    self._install_sigterm_handler()
     if self.flags.inspect:
       logger.info('Not bringing up TensorBoard, but inspecting event files.')
       event_file = os.path.expanduser(self.flags.event_file)
@@ -209,6 +217,7 @@ def main(self, ignored_argv=('',)):
       server = self._make_server()
       sys.stderr.write('TensorBoard %s at %s (Press CTRL+C to quit)\n' %
                        (version.VERSION, server.get_url()))
+      self._register_info(server)
       sys.stderr.flush()
       server.serve_forever()
       return 0
@@ -218,6 +227,50 @@ def main(self, ignored_argv=('',)):
       sys.stderr.flush()
       return -1
 
+  def _install_sigterm_handler(self):
+    """Make SIGTERM exit via `sys.exit` so that `atexit` hooks run."""
+    old_sigterm_handler = None  # set below
+    def handle_sigterm(signal_number, stack_frame):
+      del signal_number  # unused
+      del stack_frame  # unused
+      # In case we get SIGTERMed again while running atexit handlers,
+      # take the hint and actually die. `signal.signal` reports a
+      # previous handler that was not installed from Python as `None`,
+      # which cannot be passed back in, so use the default action then.
+      restored_handler = old_sigterm_handler
+      if restored_handler is None:
+        restored_handler = signal.SIG_DFL
+      signal.signal(signal.SIGTERM, restored_handler)
+      sys.stderr.write("TensorBoard caught SIGTERM; exiting.\n")
+      # The main thread is the only non-daemon thread, so it suffices to
+      # exit hence.
+      sys.exit(0)
+    old_sigterm_handler = signal.signal(signal.SIGTERM, handle_sigterm)
+
+  def _register_info(self, server):
+    """Write this process's info file and arrange for its removal.
+
+    Args:
+      server: The server that is about to start serving; its
+        `get_url()` result supplies the recorded port.
+    """
+    server_url = urllib.parse.urlparse(server.get_url())
+    info = manager.TensorboardInfo(
+        version=version.VERSION,
+        start_time=datetime.datetime.now(),
+        port=server_url.port,
+        pid=os.getpid(),
+        path_prefix=self.flags.path_prefix,
+        logdir=self.flags.logdir,
+        db=self.flags.db,
+        cache_key=manager.cache_key(
+            working_directory=os.getcwd(),
+            arguments=sys.argv[1:],
+        ),
+    )
+    atexit.register(manager.remove_info_file)
+    manager.write_info_file(info)
+
   def launch(self):
     """Python API for launching TensorBoard.
 