From 8419edf82c6cb5ba40c7b6577cfa0d70df5e44e0 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi Date: Mon, 26 Jun 2023 23:40:01 +0000 Subject: [PATCH] Cherrypick Sequential serialization bug fix for r2.13 --- keras/engine/sequential.py | 8 +++--- keras/saving/legacy/hdf5_format.py | 3 +++ keras/saving/legacy/save_test.py | 41 +++++++++++++++++++++++++++++- 3 files changed, 48 insertions(+), 4 deletions(-) diff --git a/keras/engine/sequential.py b/keras/engine/sequential.py index a04bca2f223..da81d5a13ec 100644 --- a/keras/engine/sequential.py +++ b/keras/engine/sequential.py @@ -26,6 +26,7 @@ from keras.engine import training from keras.engine import training_utils from keras.saving import serialization_lib +from keras.saving.legacy import serialization as legacy_serialization from keras.saving.legacy.saved_model import model_serialization from keras.utils import generic_utils from keras.utils import layer_utils @@ -441,14 +442,15 @@ def compute_mask(self, inputs, mask): def get_config(self): layer_configs = [] + serialize_obj_fn = serialization_lib.serialize_keras_object + if getattr(self, "use_legacy_config", None): + serialize_obj_fn = legacy_serialization.serialize_keras_object for layer in super().layers: # `super().layers` include the InputLayer if available (it is # filtered out of `self.layers`). Note that # `self._self_tracked_trackables` is managed by the tracking # infrastructure and should not be used. - layer_configs.append( - serialization_lib.serialize_keras_object(layer) - ) + layer_configs.append(serialize_obj_fn(layer)) config = training.Model.get_config(self) config["name"] = self.name config["layers"] = copy.deepcopy(layer_configs) diff --git a/keras/saving/legacy/hdf5_format.py b/keras/saving/legacy/hdf5_format.py index f739a0ec728..b4597655df4 100644 --- a/keras/saving/legacy/hdf5_format.py +++ b/keras/saving/legacy/hdf5_format.py @@ -81,6 +81,9 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True): "import h5py." ) + # Ensures that all models saved in HDF5 format follow the old serialization + model.use_legacy_config = True + # TODO(psv) Add warning when we save models that contain non-serializable # entities like metrics added using `add_metric` and losses added using # `add_loss.` diff --git a/keras/saving/legacy/save_test.py b/keras/saving/legacy/save_test.py index 7d7185baefb..b9ec7d5d749 100644 --- a/keras/saving/legacy/save_test.py +++ b/keras/saving/legacy/save_test.py @@ -1134,6 +1134,46 @@ def c(self): ) self.assertIsInstance(reloaded_model, new_cls) + @test_combinations.generate(test_combinations.combine(mode=["eager"])) + def test_custom_sequential_registered_no_scope(self): + @object_registration.register_keras_serializable(package="my_package") + class MyDense(keras.layers.Dense): + def __init__(self, units, **kwargs): + super().__init__(units, **kwargs) + + input_shape = [1] + inputs = keras.Input(shape=input_shape) + custom_layer = MyDense(1) + saved_model_dir = self._save_model_dir() + save_format = test_utils.get_save_format() + + model = keras.Sequential(layers=[inputs, custom_layer]) + model.save(saved_model_dir, save_format=save_format) + loaded_model = keras.models.load_model(saved_model_dir) + + x = tf.constant([5]) + self.assertAllEqual(model(x), loaded_model(x)) + + @test_combinations.generate(test_combinations.combine(mode=["eager"])) + def test_custom_functional_registered_no_scope(self): + @object_registration.register_keras_serializable(package="my_package") + class MyDense(keras.layers.Dense): + def __init__(self, units, **kwargs): + super().__init__(units, **kwargs) + + saved_model_dir = self._save_model_dir() + save_format = test_utils.get_save_format() + input_shape = [1] + inputs = keras.Input(shape=input_shape) + outputs = MyDense(1)(inputs) + model = keras.Model(inputs, outputs) + + model.save(saved_model_dir, save_format=save_format) + loaded_model = keras.models.load_model(saved_model_dir) + + x = tf.constant([5]) + self.assertAllEqual(model(x), loaded_model(x)) + @test_combinations.generate(test_combinations.combine(mode=["eager"])) def test_shared_objects(self): class OuterLayer(keras.layers.Layer): @@ -1222,7 +1262,6 @@ def _get_all_keys_recursive(dict_or_iterable): with object_registration.CustomObjectScope( {"OuterLayer": OuterLayer, "InnerLayer": InnerLayer} ): - # Test saving and loading to disk save_format = test_utils.get_save_format() saved_model_dir = self._save_model_dir()