Skip to content

Commit

Permalink
Fixes Sequential serialization with custom object registration for HD…
Browse files Browse the repository at this point in the history
…F5 format.

PiperOrigin-RevId: 543569141
  • Loading branch information
nkovela1 authored and tensorflower-gardener committed Jun 26, 2023
1 parent 0454a40 commit a78f714
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 4 deletions.
8 changes: 5 additions & 3 deletions keras/engine/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from keras.engine import training
from keras.engine import training_utils
from keras.saving import serialization_lib
from keras.saving.legacy import serialization as legacy_serialization
from keras.saving.legacy.saved_model import model_serialization
from keras.utils import generic_utils
from keras.utils import layer_utils
Expand Down Expand Up @@ -434,14 +435,15 @@ def compute_mask(self, inputs, mask):

def get_config(self):
layer_configs = []
serialize_obj_fn = serialization_lib.serialize_keras_object
if getattr(self, "use_legacy_config", None):
serialize_obj_fn = legacy_serialization.serialize_keras_object
for layer in super().layers:
# `super().layers` include the InputLayer if available (it is
# filtered out of `self.layers`). Note that
# `self._self_tracked_trackables` is managed by the tracking
# infrastructure and should not be used.
layer_configs.append(
serialization_lib.serialize_keras_object(layer)
)
layer_configs.append(serialize_obj_fn(layer))
config = training.Model.get_config(self)
config["name"] = self.name
config["layers"] = copy.deepcopy(layer_configs)
Expand Down
3 changes: 3 additions & 0 deletions keras/saving/legacy/hdf5_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
"import h5py."
)

# Ensures that all models saved in HDF5 format follow the old serialization
model.use_legacy_config = True

# TODO(psv) Add warning when we save models that contain non-serializable
# entities like metrics added using `add_metric` and losses added using
# `add_loss.`
Expand Down
41 changes: 40 additions & 1 deletion keras/saving/legacy/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,46 @@ def c(self):
)
self.assertIsInstance(reloaded_model, new_cls)

@test_combinations.generate(test_combinations.combine(mode=["eager"]))
def test_custom_sequential_registered_no_scope(self):
@object_registration.register_keras_serializable(package="my_package")
class MyDense(keras.layers.Dense):
def __init__(self, units, **kwargs):
super().__init__(units, **kwargs)

input_shape = [1]
inputs = keras.Input(shape=input_shape)
custom_layer = MyDense(1)
saved_model_dir = self._save_model_dir()
save_format = test_utils.get_save_format()

model = keras.Sequential(layers=[inputs, custom_layer])
model.save(saved_model_dir, save_format=save_format)
loaded_model = keras.models.load_model(saved_model_dir)

x = tf.constant([5])
self.assertAllEqual(model(x), loaded_model(x))

@test_combinations.generate(test_combinations.combine(mode=["eager"]))
def test_custom_functional_registered_no_scope(self):
@object_registration.register_keras_serializable(package="my_package")
class MyDense(keras.layers.Dense):
def __init__(self, units, **kwargs):
super().__init__(units, **kwargs)

saved_model_dir = self._save_model_dir()
save_format = test_utils.get_save_format()
input_shape = [1]
inputs = keras.Input(shape=input_shape)
outputs = MyDense(1)(inputs)
model = keras.Model(inputs, outputs)

model.save(saved_model_dir, save_format=save_format)
loaded_model = keras.models.load_model(saved_model_dir)

x = tf.constant([5])
self.assertAllEqual(model(x), loaded_model(x))

@test_combinations.generate(test_combinations.combine(mode=["eager"]))
def test_shared_objects(self):
class OuterLayer(keras.layers.Layer):
Expand Down Expand Up @@ -1222,7 +1262,6 @@ def _get_all_keys_recursive(dict_or_iterable):
with object_registration.CustomObjectScope(
{"OuterLayer": OuterLayer, "InnerLayer": InnerLayer}
):

# Test saving and loading to disk
save_format = test_utils.get_save_format()
saved_model_dir = self._save_model_dir()
Expand Down

0 comments on commit a78f714

Please sign in to comment.