diff --git a/tensorboard/plugins/hparams/BUILD b/tensorboard/plugins/hparams/BUILD index 266663efd1..4d112f3d10 100644 --- a/tensorboard/plugins/hparams/BUILD +++ b/tensorboard/plugins/hparams/BUILD @@ -29,6 +29,7 @@ py_library( "download_data.py", "get_experiment.py", "hparams_plugin.py", + "json_format_compat.py", "list_metric_evals.py", "list_session_groups.py", "metrics.py", @@ -115,6 +116,15 @@ py_test( ], ) +py_test( + name = "json_format_compat_test", + size = "small", + srcs = [ + "json_format_compat_test.py", + ], + deps = [":hparams_plugin"], +) + py_binary( name = "hparams_demo", srcs = ["hparams_demo.py"], diff --git a/tensorboard/plugins/hparams/backend_context.py b/tensorboard/plugins/hparams/backend_context.py index efc5affd70..152454211e 100644 --- a/tensorboard/plugins/hparams/backend_context.py +++ b/tensorboard/plugins/hparams/backend_context.py @@ -22,6 +22,7 @@ from tensorboard.data import provider from tensorboard.plugins.hparams import api_pb2 +from tensorboard.plugins.hparams import json_format_compat from tensorboard.plugins.hparams import metadata from google.protobuf import json_format from tensorboard.plugins.scalar import metadata as scalar_metadata @@ -282,11 +283,6 @@ def _compute_hparam_info_from_values(self, name, values): # If all values have the same type, then that is the type used. # Otherwise, the returned type is DATA_TYPE_STRING. 
def _can_be_converted_to_string(value):
    """Returns whether `value` is eligible for `_protobuf_value_to_string`.

    A value is eligible when `_protobuf_value_type` reports a truthy type for
    it and `json_format_compat.is_serializable_value` accepts it.

    Args:
      value: A google.protobuf.Value message.

    Returns:
      False when the value has no recognized type; otherwise the result of
      `json_format_compat.is_serializable_value(value)`.
    """
    # The type check runs first, so the serializability check is only reached
    # for values that have a recognized type.
    has_known_type = bool(_protobuf_value_type(value))
    return has_known_type and json_format_compat.is_serializable_value(value)
def is_serializable_value(value):
    """Returns whether a protobuf Value will be serializable by MessageToJson.

    The json_format documentation states that "attempting to serialize NaN or
    Infinity results in error."

    https://protobuf.dev/reference/protobuf/google.protobuf/#value

    Nested values are inspected too: a `list_value` or `struct_value` is only
    considered serializable when every Value it contains is serializable.

    Args:
      value: A value of type protobuf.Value.

    Returns:
      True if the Value should be serializable without error by MessageToJson.
      False, otherwise.
    """
    if value.HasField("number_value"):
        # NaN and +/-Infinity are exactly the non-finite doubles.
        return math.isfinite(value.number_value)

    if value.HasField("list_value"):
        return all(
            is_serializable_value(item) for item in value.list_value.values
        )

    if value.HasField("struct_value"):
        return all(
            is_serializable_value(item)
            for item in value.struct_value.fields.values()
        )

    # Unset, null, string and bool values carry no floating point payload.
    return True
class TestCase(absltest.TestCase):
    """Exercises `json_format_compat.is_serializable_value`."""

    def test_real_value_is_serializable(self):
        # A finite number, a string (even one spelling "nan") and a bool are
        # all expected to be reported as serializable.
        samples = (
            struct_pb2.Value(number_value=1.0),
            struct_pb2.Value(string_value="nan"),
            struct_pb2.Value(bool_value=False),
        )
        for sample in samples:
            self.assertTrue(json_format_compat.is_serializable_value(sample))

    def test_empty_value_is_serializable(self):
        # A Value with no field set is expected to be serializable.
        unset = struct_pb2.Value()
        self.assertTrue(json_format_compat.is_serializable_value(unset))

    def test_nan_value_is_not_serializable(self):
        not_a_number = struct_pb2.Value(number_value=float("nan"))
        self.assertFalse(
            json_format_compat.is_serializable_value(not_a_number)
        )

    def test_infinity_value_is_not_serializable(self):
        # Both signs of infinity are expected to be rejected.
        for literal in ("inf", "-inf"):
            infinite = struct_pb2.Value(number_value=float(literal))
            self.assertFalse(
                json_format_compat.is_serializable_value(infinite)
            )
def _mock_list_tensors_invalid_number_values(
    self, ctx, *, experiment_id, plugin_name, run_tag_filter
):
    """Stand-in for `list_tensors` mixing finite, NaN and -infinity hparams.

    Returns a dict with four runs. Each run maps the session start info tag
    to a `provider.TensorTimeSeries` whose plugin content is a session start
    info with a single 'maybe_bad' hparam (1, nan, -infinity and 4.0) and a
    distinct group name.
    """

    def start_info(text_proto):
        return self._serialized_plugin_data(
            DATA_TYPE_SESSION_START_INFO, text_proto
        )

    start_info_by_run = {
        "session_1": start_info(
            """
                hparams:{ key: 'maybe_bad' value: { number_value: 1 } }
                group_name: 'group_1'
            """
        ),
        "session_2": start_info(
            """
                hparams:{ key: 'maybe_bad' value: { number_value: nan } }
                group_name: 'group_2'
            """
        ),
        "session_3": start_info(
            """
                hparams:{ key: 'maybe_bad' value: { number_value: -infinity } }
                group_name: 'group_3'
            """
        ),
        "session_4": start_info(
            """
                hparams:{ key: 'maybe_bad' value: { number_value: 4.0 } }
                group_name: 'group_4'
            """
        ),
    }
    return {
        run: {
            metadata.SESSION_START_INFO_TAG: provider.TensorTimeSeries(
                max_step=0,
                max_wall_time=0,
                plugin_content=content,
                description="",
                display_name="",
            )
        }
        for run, content in start_info_by_run.items()
    }


def test_hparams_with_invalid_number_values(self):
    """The NaN and -infinity hparams are absent from their session groups."""
    list_tensors = self._mock_tb_context.data_provider.list_tensors
    list_tensors.side_effect = self._mock_list_tensors_invalid_number_values
    request = """
        start_index: 0
        slice_size: 10
        allowed_statuses: [STATUS_UNKNOWN]
    """
    groups = self._run_handler(request).session_groups
    self.assertLen(groups, 4)
    first, second, third, fourth = (
        group.hparams.get("maybe_bad") for group in groups
    )
    self.assertEqual(1, first.number_value)
    self.assertEqual(None, second)
    self.assertEqual(None, third)
    self.assertEqual(4, fourth.number_value)


def test_sort_hparams_with_invalid_number_values(self):
    """Descending sort is expected to list the finite-valued groups first."""
    list_tensors = self._mock_tb_context.data_provider.list_tensors
    list_tensors.side_effect = self._mock_list_tensors_invalid_number_values
    sort_request = """
        start_index: 0
        slice_size: 10
        allowed_statuses: [STATUS_UNKNOWN]
        col_params: {
            hparam: 'maybe_bad'
            order: ORDER_DESC
        }
    """
    expected_names = ["group_4", "group_1", "group_2", "group_3"]
    self._verify_handler(
        request=sort_request,
        expected_session_group_names=expected_names,
        expected_total_size=4,
    )


def test_filter_hparams_include_invalid_number_values(self):
    """An interval filter is expected to keep the NaN/-infinity groups."""
    list_tensors = self._mock_tb_context.data_provider.list_tensors
    list_tensors.side_effect = self._mock_list_tensors_invalid_number_values
    filter_request = """
        start_index: 0
        slice_size: 10
        allowed_statuses: [STATUS_UNKNOWN]
        col_params: {
            hparam: 'maybe_bad'
            order: ORDER_DESC
            filter_interval: { min_value: 2.0 max_value: 10.0 }
        }
    """
    expected_names = ["group_4", "group_2", "group_3"]
    self._verify_handler(
        request=filter_request,
        expected_session_group_names=expected_names,
        expected_total_size=3,
    )


def test_filer_hparams_exclude_invalid_number_values(self):
    """`exclude_missing_values` is expected to drop the NaN/-infinity groups."""
    list_tensors = self._mock_tb_context.data_provider.list_tensors
    list_tensors.side_effect = self._mock_list_tensors_invalid_number_values
    exclude_request = """
        start_index: 0
        slice_size: 10
        allowed_statuses: [STATUS_UNKNOWN]
        col_params: {
            hparam: 'maybe_bad'
            exclude_missing_values: true
        }
    """
    expected_names = ["group_1", "group_4"]
    self._verify_handler(
        request=exclude_request,
        expected_session_group_names=expected_names,
        expected_total_size=2,
    )