Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tensorboard/plugins/hparams/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ py_library(
"download_data.py",
"get_experiment.py",
"hparams_plugin.py",
"json_format_compat.py",
"list_metric_evals.py",
"list_session_groups.py",
"metrics.py",
Expand Down Expand Up @@ -115,6 +116,15 @@ py_test(
],
)

# Unit tests for json_format_compat.py.
py_test(
    name = "json_format_compat_test",
    size = "small",
    srcs = ["json_format_compat_test.py"],
    deps = [
        ":hparams_plugin",
    ],
)

py_binary(
name = "hparams_demo",
srcs = ["hparams_demo.py"],
Expand Down
17 changes: 12 additions & 5 deletions tensorboard/plugins/hparams/backend_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from tensorboard.data import provider
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import json_format_compat
from tensorboard.plugins.hparams import metadata
from google.protobuf import json_format
from tensorboard.plugins.scalar import metadata as scalar_metadata
Expand Down Expand Up @@ -282,11 +283,6 @@ def _compute_hparam_info_from_values(self, name, values):
# If all values have the same type, then that is the type used.
# Otherwise, the returned type is DATA_TYPE_STRING.
result = api_pb2.HParamInfo(name=name, type=api_pb2.DATA_TYPE_UNSET)
distinct_values = set(
_protobuf_value_to_string(v)
for v in values
if _protobuf_value_type(v)
)
for v in values:
v_type = _protobuf_value_type(v)
if not v_type:
Expand All @@ -304,6 +300,11 @@ def _compute_hparam_info_from_values(self, name, values):
return None

if result.type == api_pb2.DATA_TYPE_STRING:
distinct_values = set(
_protobuf_value_to_string(v)
for v in values
if _can_be_converted_to_string(v)
)
result.domain_discrete.extend(distinct_values)

if result.type == api_pb2.DATA_TYPE_BOOL:
Expand Down Expand Up @@ -452,6 +453,12 @@ def _find_longest_parent_path(path_set, path):
return path


def _can_be_converted_to_string(value):
    """Returns whether a protobuf Value qualifies for string conversion.

    A value qualifies only when `_protobuf_value_type` yields a truthy type
    for it and `json_format_compat.is_serializable_value` accepts it.

    Args:
      value: The protobuf Value to check.

    Returns:
      False if `_protobuf_value_type(value)` is falsy; otherwise the result
      of `json_format_compat.is_serializable_value(value)`.
    """
    has_known_type = bool(_protobuf_value_type(value))
    # Short-circuits: the serializability check only runs for typed values.
    return has_known_type and json_format_compat.is_serializable_value(value)


def _protobuf_value_type(value):
"""Returns the type of the google.protobuf.Value message as an
api.DataType.
Expand Down
27 changes: 27 additions & 0 deletions tensorboard/plugins/hparams/backend_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,33 @@ def test_experiment_with_bool_types(self):
_canonicalize_experiment(actual_exp)
self.assertProtoEquals(expected_exp, actual_exp)

def test_experiment_with_string_domain_and_invalid_number_values(self):
# The same hparam key ('maybe_invalid') gets a string value in session 1 and
# NaN / Infinity number values in sessions 2 and 3.
self.session_1_start_info_ = """
hparams:[
{key: 'maybe_invalid' value: {string_value: 'force_to_string_type'}}
]
"""
self.session_2_start_info_ = """
hparams:[
{key: 'maybe_invalid' value: {number_value: NaN}}
]
"""
self.session_3_start_info_ = """
hparams:[
{key: 'maybe_invalid' value: {number_value: Infinity}}
]
"""
# Expected outcome: the hparam is typed as a string and its discrete domain
# lists only the string value; the NaN and Infinity values do not appear.
expected_hparam_info = """
name: 'maybe_invalid'
type: DATA_TYPE_STRING
domain_discrete: {
values: [{string_value: 'force_to_string_type'}]
}
"""
actual_exp = self._experiment_from_metadata()
# Exactly one HParamInfo is expected, and it must match the proto above.
self.assertLen(actual_exp.hparam_infos, 1)
self.assertProtoEquals(expected_hparam_info, actual_exp.hparam_infos[0])

def test_experiment_without_any_hparams(self):
request_ctx = context.RequestContext()
actual_exp = self._experiment_from_metadata()
Expand Down
38 changes: 38 additions & 0 deletions tensorboard/plugins/hparams/json_format_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import math


def is_serializable_value(value):
    """Reports whether MessageToJson can serialize a protobuf Value.

    Per the json_format documentation, "attempting to serialize NaN or
    Infinity results in error."

    https://protobuf.dev/reference/protobuf/google.protobuf/#value

    Args:
      value: A value of type protobuf.Value.

    Returns:
      False if `value` has a `number_value` set that is NaN or +/-Infinity.
      True otherwise, including when `number_value` is not set at all.
    """
    if value.HasField("number_value"):
        # Only finite numbers are treated as serializable.
        return math.isfinite(value.number_value)
    return True
64 changes: 64 additions & 0 deletions tensorboard/plugins/hparams/json_format_compat_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from google.protobuf import struct_pb2
from tensorboard.plugins.hparams import json_format_compat


class TestCase(absltest.TestCase):
    """Tests for json_format_compat.is_serializable_value."""

    def test_real_value_is_serializable(self):
        # A finite number, a string (even one spelled "nan") and a bool.
        serializable_values = (
            struct_pb2.Value(number_value=1.0),
            struct_pb2.Value(string_value="nan"),
            struct_pb2.Value(bool_value=False),
        )
        for candidate in serializable_values:
            is_ok = json_format_compat.is_serializable_value(candidate)
            self.assertTrue(is_ok)

    def test_empty_value_is_serializable(self):
        # A Value with no field set.
        unset_value = struct_pb2.Value()
        self.assertTrue(json_format_compat.is_serializable_value(unset_value))

    def test_nan_value_is_not_serializable(self):
        nan_value = struct_pb2.Value(number_value=float("nan"))
        self.assertFalse(json_format_compat.is_serializable_value(nan_value))

    def test_infinity_value_is_not_serializable(self):
        # Both signs of infinity must be rejected.
        for literal in ("inf", "-inf"):
            infinite_value = struct_pb2.Value(number_value=float(literal))
            is_ok = json_format_compat.is_serializable_value(infinite_value)
            self.assertFalse(is_ok)


# Run the tests via absltest when this file is executed directly.
if __name__ == "__main__":
absltest.main()
8 changes: 8 additions & 0 deletions tensorboard/plugins/hparams/list_session_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tensorboard.data import provider
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import error
from tensorboard.plugins.hparams import json_format_compat
from tensorboard.plugins.hparams import metadata
from tensorboard.plugins.hparams import metrics

Expand Down Expand Up @@ -246,6 +247,13 @@ def _add_session(self, session, start_info, groups_by_name):
# There doesn't seem to be a way to initialize a protobuffer map in the
# constructor.
for (key, value) in start_info.hparams.items():
if not json_format_compat.is_serializable_value(value):
# NaN and Infinity number_values cannot be serialized by higher
# level layers that use json_format.MessageToJson(). To work
# around the issue, we do not copy them to the session group and
# effectively treat them as "unset".
continue

group.hparams[key].CopyFrom(value)
groups_by_name[group_name] = group

Expand Down
131 changes: 131 additions & 0 deletions tensorboard/plugins/hparams/list_session_groups_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,137 @@ def test_sort_one_column_with_missing_values(self):
expected_total_size=3,
)

def _mock_list_tensors_invalid_number_values(
self, ctx, *, experiment_id, plugin_name, run_tag_filter
):
# Stand-in for data_provider.list_tensors (installed through side_effect by
# the tests that use it). The ctx, experiment_id, plugin_name and
# run_tag_filter arguments are accepted but not read.
#
# Four sessions are described, each in its own group, all setting the
# 'maybe_bad' hparam: 1, nan, -infinity and 4.0 respectively.
hparams_content = {
"session_1": {
metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data(
DATA_TYPE_SESSION_START_INFO,
"""
hparams:{ key: 'maybe_bad' value: { number_value: 1 } }
group_name: 'group_1'
""",
)
},
"session_2": {
metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data(
DATA_TYPE_SESSION_START_INFO,
"""
hparams:{ key: 'maybe_bad' value: { number_value: nan } }
group_name: 'group_2'
""",
),
},
"session_3": {
metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data(
DATA_TYPE_SESSION_START_INFO,
"""
hparams:{ key: 'maybe_bad' value: { number_value: -infinity } }
group_name: 'group_3'
""",
),
},
"session_4": {
metadata.SESSION_START_INFO_TAG: self._serialized_plugin_data(
DATA_TYPE_SESSION_START_INFO,
"""
hparams:{ key: 'maybe_bad' value: { number_value: 4.0 } }
group_name: 'group_4'
""",
),
},
}
# Build the returned mapping: run name -> tag -> TensorTimeSeries whose
# plugin_content is the serialized payload defined above.
result = {}
for (run, tag_to_content) in hparams_content.items():
result.setdefault(run, {})
for (tag, content) in tag_to_content.items():
t = provider.TensorTimeSeries(
max_step=0,
max_wall_time=0,
plugin_content=content,
description="",
display_name="",
)
result[run][tag] = t
return result

def test_hparams_with_invalid_number_values(self):
self._mock_tb_context.data_provider.list_tensors.side_effect = (
self._mock_list_tensors_invalid_number_values
)
request = """
start_index: 0
slice_size: 10
allowed_statuses: [STATUS_UNKNOWN]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just an arbitrary value that is irrelevant for the test? Can it be omitted, or is it a required field?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, unfortunately it is necessary for the query to work (`allowed_statuses: []` will return no results).

"""
groups = self._run_handler(request).session_groups
self.assertLen(groups, 4)
self.assertEqual(1, groups[0].hparams.get("maybe_bad").number_value)
self.assertEqual(None, groups[1].hparams.get("maybe_bad"))
self.assertEqual(None, groups[2].hparams.get("maybe_bad"))
self.assertEqual(4, groups[3].hparams.get("maybe_bad").number_value)

def test_sort_hparams_with_invalid_number_values(self):
# Sorts the four mocked session groups by 'maybe_bad' in descending order.
# The groups with finite values come first (group_4, then group_1); the
# groups whose value is nan / -infinity (group_2, group_3) are expected last.
self._mock_tb_context.data_provider.list_tensors.side_effect = (
self._mock_list_tensors_invalid_number_values
)
self._verify_handler(
request="""
start_index: 0
slice_size: 10
allowed_statuses: [STATUS_UNKNOWN]
col_params: {
hparam: 'maybe_bad'
order: ORDER_DESC
}
""",
expected_session_group_names=[
"group_4",
"group_1",
"group_2",
"group_3",
],
expected_total_size=4,
)

def test_filter_hparams_include_invalid_number_values(self):
# Applies a filter_interval (min_value 2.0, max_value 10.0) on 'maybe_bad'
# without setting exclude_missing_values. group_1 (value 1) is expected to
# be filtered out, while group_2 and group_3 (nan / -infinity) are still
# expected in the result, after group_4.
self._mock_tb_context.data_provider.list_tensors.side_effect = (
self._mock_list_tensors_invalid_number_values
)
self._verify_handler(
request="""
start_index: 0
slice_size: 10
allowed_statuses: [STATUS_UNKNOWN]
col_params: {
hparam: 'maybe_bad'
order: ORDER_DESC
filter_interval: { min_value: 2.0 max_value: 10.0 }
}
""",
expected_session_group_names=["group_4", "group_2", "group_3"],
expected_total_size=3,
)

# NOTE(review): "filer" in the method name looks like a typo for "filter" --
# confirm before renaming.
def test_filer_hparams_exclude_invalid_number_values(self):
# With exclude_missing_values set on 'maybe_bad', only the groups with
# finite values (group_1 and group_4) are expected; the nan / -infinity
# groups are expected to be dropped.
self._mock_tb_context.data_provider.list_tensors.side_effect = (
self._mock_list_tensors_invalid_number_values
)
self._verify_handler(
request="""
start_index: 0
slice_size: 10
allowed_statuses: [STATUS_UNKNOWN]
col_params: {
hparam: 'maybe_bad'
exclude_missing_values: true
}
""",
expected_session_group_names=["group_1", "group_4"],
expected_total_size=2,
)

def test_experiment_without_any_hparams(self):
self._mock_tb_context.data_provider.list_tensors.side_effect = None
self._hyperparameters = []
Expand Down