diff --git a/tensorboard/BUILD b/tensorboard/BUILD index eb256de9fd..baccb184cc 100644 --- a/tensorboard/BUILD +++ b/tensorboard/BUILD @@ -69,6 +69,31 @@ py_library( ], ) +py_library( + name = "manager", + srcs = ["manager.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorboard:internal"], + deps = [ + ":version", + "@org_pythonhosted_six", + ], +) + +py_test( + name = "manager_test", + size = "small", + srcs = ["manager_test.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorboard:internal"], + deps = [ + ":manager", + ":version", + "//tensorboard:expect_tensorflow_installed", + "@org_pythonhosted_six", + ], +) + py_library( name = "program", srcs = ["program.py"], diff --git a/tensorboard/manager.py b/tensorboard/manager.py new file mode 100644 index 0000000000..de412619b9 --- /dev/null +++ b/tensorboard/manager.py @@ -0,0 +1,152 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Private utilities for managing multiple TensorBoard processes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import datetime +import json + +import six + +from tensorboard import version + + +# Type descriptors for `TensorboardInfo` fields. +_FieldType = collections.namedtuple( + "_FieldType", + ( + "serialized_type", + "runtime_type", + "serialize", + "deserialize", + ), +) +_type_timestamp = _FieldType( + serialized_type=int, # seconds since epoch + runtime_type=datetime.datetime, # microseconds component ignored + serialize=lambda dt: int(dt.strftime("%s")), + deserialize=lambda n: datetime.datetime.fromtimestamp(n), +) +_type_int = _FieldType( + serialized_type=int, + runtime_type=int, + serialize=lambda n: n, + deserialize=lambda n: n, +) +_type_str = _FieldType( + serialized_type=six.text_type, # `json.loads` always gives Unicode + runtime_type=str, + serialize=six.text_type, + deserialize=str, +) + +# Information about a running TensorBoard instance. +_TENSORBOARD_INFO_FIELDS = collections.OrderedDict(( + ("version", _type_str), + ("start_time", _type_timestamp), + ("pid", _type_int), + ("port", _type_int), + ("path_prefix", _type_str), # may be empty + ("logdir", _type_str), # may be empty + ("db", _type_str), # may be empty + ("cache_key", _type_str), # opaque +)) +TensorboardInfo = collections.namedtuple( + "TensorboardInfo", + _TENSORBOARD_INFO_FIELDS, +) + +def _info_to_string(info): + """Convert a `TensorboardInfo` to string form to be stored on disk. + + The format returned by this function is opaque and should only be + interpreted by `_info_from_string`. + + Args: + info: A valid `TensorboardInfo` object. + + Raises: + ValueError: If any field on `info` is not of the correct type. + + Returns: + A string representation of the provided `TensorboardInfo`. + """ + for key in _TENSORBOARD_INFO_FIELDS: + field_type = _TENSORBOARD_INFO_FIELDS[key] + if not isinstance(getattr(info, key), field_type.runtime_type): + raise ValueError( + "expected %r of type %s, but found: %r" % + (key, field_type.runtime_type, getattr(info, key)) + ) + if info.version != version.VERSION: + raise ValueError( + "expected 'version' to be %r, but found: %r" % + (version.VERSION, info.version) + ) + json_value = { + k: _TENSORBOARD_INFO_FIELDS[k].serialize(getattr(info, k)) + for k in _TENSORBOARD_INFO_FIELDS + } + return json.dumps(json_value, sort_keys=True, indent=4) + + +def _info_from_string(info_string): + """Parse a `TensorboardInfo` object from its string representation. + + Args: + info_string: A string representation of a `TensorboardInfo`, as + produced by a previous call to `_info_to_string`. + + Returns: + A `TensorboardInfo` value. + + Raises: + ValueError: If the provided string is not valid JSON, or if it does + not represent a JSON object with a "version" field whose value is + `tensorboard.version.VERSION`, or if it has the wrong set of + fields, or if at least one field is of invalid type. + """ + + try: + json_value = json.loads(info_string) + except ValueError: + raise ValueError("invalid JSON: %r" % (info_string,)) + if not isinstance(json_value, dict): + raise ValueError("not a JSON object: %r" % (json_value,)) + if json_value.get("version") != version.VERSION: + raise ValueError("incompatible version: %r" % (json_value,)) + expected_keys = frozenset(_TENSORBOARD_INFO_FIELDS) + actual_keys = frozenset(json_value) + if expected_keys != actual_keys: + raise ValueError( + "bad keys on TensorboardInfo (missing: %s; extraneous: %s)" + % (expected_keys - actual_keys, actual_keys - expected_keys) + ) + + # Validate and deserialize fields. + for key in _TENSORBOARD_INFO_FIELDS: + field_type = _TENSORBOARD_INFO_FIELDS[key] + if not isinstance(json_value[key], field_type.serialized_type): + raise ValueError( + "expected %r of type %s, but found: %r" % + (key, field_type.serialized_type, json_value[key]) + ) + json_value[key] = field_type.deserialize(json_value[key]) + + return TensorboardInfo(**json_value) diff --git a/tensorboard/manager_test.py b/tensorboard/manager_test.py new file mode 100644 index 0000000000..093bba4c16 --- /dev/null +++ b/tensorboard/manager_test.py @@ -0,0 +1,154 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for `tensorboard.manager`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime +import json +import re + +import six +import tensorflow as tf + +from tensorboard import manager +from tensorboard import version + + +def _make_info(i=0): + """Make a sample TensorboardInfo object. + + Args: + i: Seed; vary this value to produce slightly different outputs. + + Returns: + A type-correct `TensorboardInfo` object. + """ + return manager.TensorboardInfo( + version=version.VERSION, + start_time=datetime.datetime.fromtimestamp(1548973541 + i), + port=6060 + i, + pid=76540 + i, + path_prefix="/foo", + logdir="~/my_data/", + db="", + cache_key="asdf", + ) + + +class TensorboardInfoTest(tf.test.TestCase): + """Unit tests for TensorboardInfo typechecking and serialization.""" + + def test_roundtrip_serialization(self): + # This will also be tested indirectly as part of `manager` + # integration tests. + info = _make_info() + also_info = manager._info_from_string(manager._info_to_string(info)) + self.assertEqual(also_info, info) + + def test_serialization_rejects_bad_types(self): + info = _make_info()._replace(start_time=1549061116) # not a datetime + with six.assertRaisesRegex( + self, + ValueError, + "expected 'start_time' of type.*datetime.*, but found: 1549061116"): + manager._info_to_string(info) + + def test_serialization_rejects_wrong_version(self): + info = _make_info()._replace(version="reversion") + with six.assertRaisesRegex( + self, + ValueError, + "expected 'version' to be '.*', but found: 'reversion'"): + manager._info_to_string(info) + + def test_deserialization_rejects_bad_json(self): + bad_input = "parse me if you dare" + with six.assertRaisesRegex( + self, + ValueError, + "invalid JSON:"): + manager._info_from_string(bad_input) + + def test_deserialization_rejects_non_object_json(self): + bad_input = "[1, 2]" + with six.assertRaisesRegex( + self, + ValueError, + re.escape("not a JSON object: [1, 2]")): + manager._info_from_string(bad_input) + + def test_deserialization_rejects_missing_version(self): + info = _make_info() + json_value = json.loads(manager._info_to_string(info)) + del json_value["version"] + bad_input = json.dumps(json_value) + with six.assertRaisesRegex( + self, + ValueError, + "incompatible version:"): + manager._info_from_string(bad_input) + + def test_deserialization_rejects_bad_version(self): + info = _make_info() + json_value = json.loads(manager._info_to_string(info)) + json_value["version"] = "not likely" + bad_input = json.dumps(json_value) + with six.assertRaisesRegex( + self, + ValueError, + "incompatible version:.*not likely"): + manager._info_from_string(bad_input) + + def test_deserialization_rejects_extra_keys(self): + info = _make_info() + json_value = json.loads(manager._info_to_string(info)) + json_value["unlikely"] = "story" + bad_input = json.dumps(json_value) + with six.assertRaisesRegex( + self, + ValueError, + "bad keys on TensorboardInfo"): + manager._info_from_string(bad_input) + + def test_deserialization_rejects_missing_keys(self): + info = _make_info() + json_value = json.loads(manager._info_to_string(info)) + del json_value["start_time"] + bad_input = json.dumps(json_value) + with six.assertRaisesRegex( + self, + ValueError, + "bad keys on TensorboardInfo"): + manager._info_from_string(bad_input) + + def test_deserialization_rejects_bad_types(self): + info = _make_info() + json_value = json.loads(manager._info_to_string(info)) + json_value["start_time"] = "2001-02-03T04:05:06" + bad_input = json.dumps(json_value) + with six.assertRaisesRegex( + self, + ValueError, + "expected 'start_time' of type.*int.*, but found:.*" + "'2001-02-03T04:05:06'"): + manager._info_from_string(bad_input) + + + +if __name__ == "__main__": + tf.test.main()