From 861ad749eeff1f222f941c1f96f0f56c58a97742 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi Date: Fri, 28 Apr 2023 14:46:49 -0700 Subject: [PATCH] Adds error for serializing metric using layer serialization. PiperOrigin-RevId: 527991285 --- keras/layers/serialization.py | 8 ++++++++ keras/layers/serialization_test.py | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/keras/layers/serialization.py b/keras/layers/serialization.py index fd0e6b0a6e5..e35761b5b27 100644 --- a/keras/layers/serialization.py +++ b/keras/layers/serialization.py @@ -50,6 +50,7 @@ from keras.layers.rnn import cell_wrappers from keras.layers.rnn import gru from keras.layers.rnn import lstm +from keras.metrics import base_metric from keras.saving import serialization_lib from keras.saving.legacy import serialization as legacy_serialization from keras.saving.legacy.saved_model import json_utils @@ -208,6 +209,13 @@ def serialize(layer, use_legacy_format=False): pprint(tf.keras.layers.serialize(model)) # prints the configuration of the model, as a dict. """ + if isinstance(layer, base_metric.Metric): + raise ValueError( + f"Cannot serialize {layer} since it is a metric. " + "Please use the `keras.metrics.serialize()` and " + "`keras.metrics.deserialize()` APIs to serialize " + "and deserialize metrics." + ) if use_legacy_format: return legacy_serialization.serialize_keras_object(layer) diff --git a/keras/layers/serialization_test.py b/keras/layers/serialization_test.py index c457ccd621e..688466be0b7 100644 --- a/keras/layers/serialization_test.py +++ b/keras/layers/serialization_test.py @@ -24,6 +24,7 @@ from keras.layers.rnn import gru_v1 from keras.layers.rnn import lstm from keras.layers.rnn import lstm_v1 +from keras.metrics import Mean from keras.testing_infra import test_combinations @@ -191,6 +192,11 @@ def test_serialize_deserialize_gru(self, layer): self.assertIsInstance(new_layer, gru_v1.GRU) self.assertNotIsInstance(new_layer, gru.GRU) + def test_serialize_metric_throws_error(self): + metric = Mean() + with self.assertRaisesRegex(ValueError, "since it is a metric."): + _ = keras.layers.serialize(metric) + if __name__ == "__main__": tf.test.main()