diff --git a/tensorboard/BUILD b/tensorboard/BUILD index baccb184cc..6dd10fd00f 100644 --- a/tensorboard/BUILD +++ b/tensorboard/BUILD @@ -76,6 +76,7 @@ py_library( visibility = ["//tensorboard:internal"], deps = [ ":version", + "//tensorboard/util:tb_logging", "@org_pythonhosted_six", ], ) @@ -89,6 +90,7 @@ py_test( deps = [ ":manager", ":version", + "//tensorboard/util:tb_logging", "//tensorboard:expect_tensorflow_installed", "@org_pythonhosted_six", ], diff --git a/tensorboard/manager.py b/tensorboard/manager.py index 280aa86cdb..64508efcd3 100644 --- a/tensorboard/manager.py +++ b/tensorboard/manager.py @@ -21,11 +21,15 @@ import base64 import collections import datetime +import errno import json +import os +import tempfile import six from tensorboard import version +from tensorboard.util import tb_logging # Type descriptors for `TensorboardInfo` fields. @@ -199,3 +203,104 @@ def cache_key(working_directory, arguments, configure_kwargs): # `raw` is of type `bytes`, even though it only contains ASCII # characters; we want it to be `str` in both Python 2 and 3. return str(raw.decode("ascii")) + + +def _get_info_dir(): + """Get path to directory in which to store info files. + + The directory returned by this function is "owned" by this module. If + the contents of the directory are modified other than via the public + functions of this module, subsequent behavior is undefined. + + The directory will be created if it does not exist. + """ + path = os.path.join(tempfile.gettempdir(), ".tensorboard-info") + try: + os.makedirs(path) + except OSError as e: + if e.errno == errno.EEXIST and os.path.isdir(path): + pass + else: + raise + return path + + +def _get_info_file_path(): + """Get path to info file for the current process. + + As with `_get_info_dir`, the info directory will be created if it does + not exist. 
+ """ + return os.path.join(_get_info_dir(), "pid-%d.info" % os.getpid()) + + +def write_info_file(tensorboard_info): + """Write TensorboardInfo to the current process's info file. + + This should be called by `main` once the server is ready. When the + server shuts down, `remove_info_file` should be called. + + Args: + tensorboard_info: A valid `TensorboardInfo` object. + + Raises: + ValueError: If any field on `tensorboard_info` is not of the correct type. + """ + payload = "%s\n" % _info_to_string(tensorboard_info) + with open(_get_info_file_path(), "w") as outfile: + outfile.write(payload) + + +def remove_info_file(): + """Remove the current process's TensorboardInfo file, if it exists. + + If the file does not exist, no action is taken and no error is raised. + """ + try: + os.unlink(_get_info_file_path()) + except OSError as e: + if e.errno == errno.ENOENT: + # The user may have wiped their temporary directory or something. + # Not a problem: we're already in the state that we want to be in. + pass + else: + raise + + +def get_all(): + """Return TensorboardInfo values for running TensorBoard processes. + + This function may not provide a perfect snapshot of the set of running + processes. Its result set may be incomplete if the user has cleaned + their /tmp/ directory while TensorBoard processes are running. It may + contain extraneous entries if TensorBoard processes exited uncleanly + (e.g., with SIGKILL or SIGQUIT). + + Returns: + A fresh list of `TensorboardInfo` objects. + """ + info_dir = _get_info_dir() + results = [] + for filename in os.listdir(info_dir): + filepath = os.path.join(info_dir, filename) + try: + with open(filepath) as infile: + contents = infile.read() + except IOError as e: + if e.errno == errno.EACCES: + # May have been written by this module in a process whose + # `umask` includes some bits of 0o444.
+ continue + else: + raise + try: + info = _info_from_string(contents) + except ValueError: + tb_logging.get_logger().warning( + "invalid info file: %r", + filepath, + exc_info=True, + ) + else: + results.append(info) + return results diff --git a/tensorboard/manager_test.py b/tensorboard/manager_test.py index eb0b6f2062..6f9cf15f81 100644 --- a/tensorboard/manager_test.py +++ b/tensorboard/manager_test.py @@ -19,14 +19,24 @@ from __future__ import print_function import datetime +import errno import json +import os import re +import tempfile import six import tensorflow as tf +try: + # python version >= 3.3 + from unittest import mock # pylint: disable=g-import-not-at-top +except ImportError: + import mock # pylint: disable=g-import-not-at-top,unused-import + from tensorboard import manager from tensorboard import version +from tensorboard.util import tb_logging def _make_info(i=0): @@ -54,8 +64,8 @@ class TensorboardInfoTest(tf.test.TestCase): """Unit tests for TensorboardInfo typechecking and serialization.""" def test_roundtrip_serialization(self): - # This will also be tested indirectly as part of `manager` - # integration tests. + # This is also tested indirectly as part of `manager` integration + # tests, in `test_get_all`. 
info = _make_info() also_info = manager._info_from_string(manager._info_to_string(info)) self.assertEqual(also_info, info) @@ -235,5 +245,101 @@ def test_arguments_list_vs_tuple_irrelevant(self): self.assertEqual(with_list, with_tuple) +class TensorboardInfoIoTest(tf.test.TestCase): + """Tests for `write_info_file`, `remove_info_file`, and `get_all`.""" + + def setUp(self): + super(TensorboardInfoIoTest, self).setUp() + patcher = mock.patch.dict(os.environ, {"TMPDIR": self.get_temp_dir()}) + patcher.start() + self.addCleanup(patcher.stop) + tempfile.tempdir = None # force `gettempdir` to reinitialize from env + self.info_dir = manager._get_info_dir() # ensure that directory exists + + def _list_info_dir(self): + return os.listdir(self.info_dir) + + def test_fails_if_info_dir_name_is_taken_by_a_regular_file(self): + os.rmdir(self.info_dir) + with open(self.info_dir, "w") as outfile: + pass + with self.assertRaises(OSError) as cm: + manager._get_info_dir() + self.assertEqual(cm.exception.errno, errno.EEXIST, cm.exception) + + @mock.patch("os.getpid", lambda: 76540) + def test_write_remove_info_file(self): + info = _make_info() + self.assertEqual(self._list_info_dir(), []) + manager.write_info_file(info) + filename = "pid-76540.info" + expected_filepath = os.path.join(self.info_dir, filename) + self.assertEqual(self._list_info_dir(), [filename]) + with open(expected_filepath) as infile: + self.assertEqual(manager._info_from_string(infile.read()), info) + manager.remove_info_file() + self.assertEqual(self._list_info_dir(), []) + + def test_write_info_file_rejects_bad_types(self): + # The particulars of validation are tested more thoroughly in + # `TensorboardInfoTest` above. 
+ info = _make_info()._replace(start_time=1549061116) + with six.assertRaisesRegex( + self, + ValueError, + "expected 'start_time' of type.*datetime.*, but found: 1549061116"): + manager.write_info_file(info) + self.assertEqual(self._list_info_dir(), []) + + def test_write_info_file_rejects_wrong_version(self): + # The particulars of validation are tested more thoroughly in + # `TensorboardInfoTest` above. + info = _make_info()._replace(version="reversion") + with six.assertRaisesRegex( + self, + ValueError, + "expected 'version' to be '.*', but found: 'reversion'"): + manager.write_info_file(info) + self.assertEqual(self._list_info_dir(), []) + + def test_remove_nonexistent(self): + # Should be a no-op, except to create the info directory if + # necessary. In particular, should not raise any exception. + manager.remove_info_file() + + def test_get_all(self): + def add_info(i): + with mock.patch("os.getpid", lambda: 76540 + i): + manager.write_info_file(_make_info(i)) + def remove_info(i): + with mock.patch("os.getpid", lambda: 76540 + i): + manager.remove_info_file() + self.assertItemsEqual(manager.get_all(), []) + add_info(1) + self.assertItemsEqual(manager.get_all(), [_make_info(1)]) + add_info(2) + self.assertItemsEqual(manager.get_all(), [_make_info(1), _make_info(2)]) + remove_info(1) + self.assertItemsEqual(manager.get_all(), [_make_info(2)]) + add_info(3) + self.assertItemsEqual(manager.get_all(), [_make_info(2), _make_info(3)]) + remove_info(3) + self.assertItemsEqual(manager.get_all(), [_make_info(2)]) + remove_info(2) + self.assertItemsEqual(manager.get_all(), []) + + def test_get_all_ignores_bad_files(self): + with open(os.path.join(self.info_dir, "pid-1234.info"), "w") as outfile: + outfile.write("good luck parsing this\n") + with open(os.path.join(self.info_dir, "pid-5678.info"), "w") as outfile: + outfile.write('{"valid_json":"yes","valid_tbinfo":"no"}\n') + with open(os.path.join(self.info_dir, "pid-9012.info"), "w") as outfile: + outfile.write('if 
a tbinfo has st_mode==0, does it make a sound?\n') + os.chmod(os.path.join(self.info_dir, "pid-9012.info"), 0o000) + with mock.patch.object(tb_logging.get_logger(), "warning") as fn: + self.assertEqual(manager.get_all(), []) + self.assertEqual(fn.call_count, 2) # 2 invalid, 1 unreadable (silent) + + if __name__ == "__main__": tf.test.main()