Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions tensorboard/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,31 @@ py_library(
],
)

py_library(
name = "manager",
srcs = ["manager.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorboard:internal"],
deps = [
":version",
"@org_pythonhosted_six",
],
)

py_test(
name = "manager_test",
size = "small",
srcs = ["manager_test.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorboard:internal"],
deps = [
":manager",
":version",
"//tensorboard:expect_tensorflow_installed",
"@org_pythonhosted_six",
],
)

py_library(
name = "program",
srcs = ["program.py"],
Expand Down
152 changes: 152 additions & 0 deletions tensorboard/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Private utilities for managing multiple TensorBoard processes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import datetime
import json

import six

from tensorboard import version


# Type descriptors for `TensorboardInfo` fields.
_FieldType = collections.namedtuple(
"_FieldType",
(
"serialized_type",
"runtime_type",
"serialize",
"deserialize",
),
)
_type_timestamp = _FieldType(
serialized_type=int, # seconds since epoch
runtime_type=datetime.datetime, # microseconds component ignored
serialize=lambda dt: int(dt.strftime("%s")),
deserialize=lambda n: datetime.datetime.fromtimestamp(n),
)
_type_int = _FieldType(
serialized_type=int,
runtime_type=int,
serialize=lambda n: n,
deserialize=lambda n: n,
)
_type_str = _FieldType(
serialized_type=six.text_type, # `json.loads` always gives Unicode
runtime_type=str,
serialize=six.text_type,
deserialize=str,
)

# Information about a running TensorBoard instance.
_TENSORBOARD_INFO_FIELDS = collections.OrderedDict((
("version", _type_str),
("start_time", _type_timestamp),
("pid", _type_int),
("port", _type_int),
("path_prefix", _type_str), # may be empty
("logdir", _type_str), # may be empty
("db", _type_str), # may be empty
("cache_key", _type_str), # opaque
))
TensorboardInfo = collections.namedtuple(
"TensorboardInfo",
_TENSORBOARD_INFO_FIELDS,
)

def _info_to_string(info):
"""Convert a `TensorboardInfo` to string form to be stored on disk.

The format returned by this function is opaque and should only be
interpreted by `_info_from_string`.

Args:
info: A valid `TensorboardInfo` object.

Raises:
ValueError: If any field on `info` is not of the correct type.

Returns:
A string representation of the provided `TensorboardInfo`.
"""
for key in _TENSORBOARD_INFO_FIELDS:
field_type = _TENSORBOARD_INFO_FIELDS[key]
if not isinstance(getattr(info, key), field_type.runtime_type):
raise ValueError(
"expected %r of type %s, but found: %r" %
(key, field_type.runtime_type, getattr(info, key))
)
if info.version != version.VERSION:
raise ValueError(
"expected 'version' to be %r, but found: %r" %
(version.VERSION, info.version)
)
json_value = {
k: _TENSORBOARD_INFO_FIELDS[k].serialize(getattr(info, k))
for k in _TENSORBOARD_INFO_FIELDS
}
return json.dumps(json_value, sort_keys=True, indent=4)


def _info_from_string(info_string):
"""Parse a `TensorboardInfo` object from its string representation.

Args:
info_string: A string representation of a `TensorboardInfo`, as
produced by a previous call to `_info_to_string`.

Returns:
A `TensorboardInfo` value.

Raises:
ValueError: If the provided string is not valid JSON, or if it does
not represent a JSON object with a "version" field whose value is
`tensorboard.version.VERSION`, or if it has the wrong set of
fields, or if at least one field is of invalid type.
"""

try:
json_value = json.loads(info_string)
except ValueError:
raise ValueError("invalid JSON: %r" % (info_string,))
if not isinstance(json_value, dict):
raise ValueError("not a JSON object: %r" % (json_value,))
if json_value.get("version") != version.VERSION:
raise ValueError("incompatible version: %r" % (json_value,))
expected_keys = frozenset(_TENSORBOARD_INFO_FIELDS)
actual_keys = frozenset(json_value)
if expected_keys != actual_keys:
raise ValueError(
"bad keys on TensorboardInfo (missing: %s; extraneous: %s)"
% (expected_keys - actual_keys, actual_keys - expected_keys)
)

# Validate and deserialize fields.
for key in _TENSORBOARD_INFO_FIELDS:
field_type = _TENSORBOARD_INFO_FIELDS[key]
if not isinstance(json_value[key], field_type.serialized_type):
raise ValueError(
"expected %r of type %s, but found: %r" %
(key, field_type.serialized_type, json_value[key])
)
json_value[key] = field_type.deserialize(json_value[key])

return TensorboardInfo(**json_value)
154 changes: 154 additions & 0 deletions tensorboard/manager_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for `tensorboard.manager`."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import datetime
import json
import re

import six
import tensorflow as tf

from tensorboard import manager
from tensorboard import version


def _make_info(i=0):
"""Make a sample TensorboardInfo object.

Args:
i: Seed; vary this value to produce slightly different outputs.

Returns:
A type-correct `TensorboardInfo` object.
"""
return manager.TensorboardInfo(
version=version.VERSION,
start_time=datetime.datetime.fromtimestamp(1548973541 + i),
port=6060 + i,
pid=76540 + i,
path_prefix="/foo",
logdir="~/my_data/",
db="",
cache_key="asdf",
)


class TensorboardInfoTest(tf.test.TestCase):
"""Unit tests for TensorboardInfo typechecking and serialization."""

def test_roundtrip_serialization(self):
# This will also be tested indirectly as part of `manager`
# integration tests.
info = _make_info()
also_info = manager._info_from_string(manager._info_to_string(info))
self.assertEqual(also_info, info)

def test_serialization_rejects_bad_types(self):
info = _make_info()._replace(start_time=1549061116) # not a datetime
with six.assertRaisesRegex(
self,
ValueError,
"expected 'start_time' of type.*datetime.*, but found: 1549061116"):
manager._info_to_string(info)

def test_serialization_rejects_wrong_version(self):
info = _make_info()._replace(version="reversion")
with six.assertRaisesRegex(
self,
ValueError,
"expected 'version' to be '.*', but found: 'reversion'"):
manager._info_to_string(info)

def test_deserialization_rejects_bad_json(self):
bad_input = "parse me if you dare"
with six.assertRaisesRegex(
self,
ValueError,
"invalid JSON:"):
manager._info_from_string(bad_input)

def test_deserialization_rejects_non_object_json(self):
bad_input = "[1, 2]"
with six.assertRaisesRegex(
self,
ValueError,
re.escape("not a JSON object: [1, 2]")):
manager._info_from_string(bad_input)

def test_deserialization_rejects_missing_version(self):
info = _make_info()
json_value = json.loads(manager._info_to_string(info))
del json_value["version"]
bad_input = json.dumps(json_value)
with six.assertRaisesRegex(
self,
ValueError,
"incompatible version:"):
manager._info_from_string(bad_input)

def test_deserialization_rejects_bad_version(self):
info = _make_info()
json_value = json.loads(manager._info_to_string(info))
json_value["version"] = "not likely"
bad_input = json.dumps(json_value)
with six.assertRaisesRegex(
self,
ValueError,
"incompatible version:.*not likely"):
manager._info_from_string(bad_input)

def test_deserialization_rejects_extra_keys(self):
info = _make_info()
json_value = json.loads(manager._info_to_string(info))
json_value["unlikely"] = "story"
bad_input = json.dumps(json_value)
with six.assertRaisesRegex(
self,
ValueError,
"bad keys on TensorboardInfo"):
manager._info_from_string(bad_input)

def test_deserialization_rejects_missing_keys(self):
info = _make_info()
json_value = json.loads(manager._info_to_string(info))
del json_value["start_time"]
bad_input = json.dumps(json_value)
with six.assertRaisesRegex(
self,
ValueError,
"bad keys on TensorboardInfo"):
manager._info_from_string(bad_input)

def test_deserialization_rejects_bad_types(self):
info = _make_info()
json_value = json.loads(manager._info_to_string(info))
json_value["start_time"] = "2001-02-03T04:05:06"
bad_input = json.dumps(json_value)
with six.assertRaisesRegex(
self,
ValueError,
"expected 'start_time' of type.*int.*, but found:.*"
"'2001-02-03T04:05:06'"):
manager._info_from_string(bad_input)



if __name__ == "__main__":
tf.test.main()