diff --git a/keras/api/_tf_keras/keras/dtype_policies/__init__.py b/keras/api/_tf_keras/keras/dtype_policies/__init__.py index da8364263a2..2abb181f5df 100644 --- a/keras/api/_tf_keras/keras/dtype_policies/__init__.py +++ b/keras/api/_tf_keras/keras/dtype_policies/__init__.py @@ -4,6 +4,9 @@ since your modifications would be overwritten. """ +from keras.src.dtype_policies import deserialize +from keras.src.dtype_policies import get +from keras.src.dtype_policies import serialize from keras.src.dtype_policies.dtype_policy import DTypePolicy from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy diff --git a/keras/api/dtype_policies/__init__.py b/keras/api/dtype_policies/__init__.py index da8364263a2..2abb181f5df 100644 --- a/keras/api/dtype_policies/__init__.py +++ b/keras/api/dtype_policies/__init__.py @@ -4,6 +4,9 @@ since your modifications would be overwritten. """ +from keras.src.dtype_policies import deserialize +from keras.src.dtype_policies import get +from keras.src.dtype_policies import serialize from keras.src.dtype_policies.dtype_policy import DTypePolicy from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy diff --git a/keras/src/dtype_policies/__init__.py b/keras/src/dtype_policies/__init__.py index ec84c266041..03cff8015b9 100644 --- a/keras/src/dtype_policies/__init__.py +++ b/keras/src/dtype_policies/__init__.py @@ -1,23 +1,96 @@ from keras.src import backend +from keras.src.api_export import keras_export from keras.src.dtype_policies import dtype_policy from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES +from keras.src.dtype_policies.dtype_policy import DTypePolicy from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy +ALL_OBJECTS = { + DTypePolicy, + FloatDTypePolicy, + QuantizedDTypePolicy, + QuantizedFloat8DTypePolicy, +} +ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS} + +@keras_export("keras.dtype_policies.serialize") +def serialize(dtype_policy): + """Serializes `DTypePolicy` instance. + + Args: + dtype_policy: A Keras `DTypePolicy` instance. + + Returns: + `DTypePolicy` configuration dictionary. + """ + from keras.src.saving import serialization_lib + + return serialization_lib.serialize_keras_object(dtype_policy) + + +@keras_export("keras.dtype_policies.deserialize") +def deserialize(config, custom_objects=None): + """Deserializes a serialized `DTypePolicy` instance. + + Args: + config: `DTypePolicy` configuration. + custom_objects: Optional dictionary mapping names (strings) to custom + objects (classes and functions) to be considered during + deserialization. + + Returns: + A Keras `DTypePolicy` instance. + """ + from keras.src.saving import serialization_lib + + return serialization_lib.deserialize_keras_object( + config, + module_objects=ALL_OBJECTS_DICT, + custom_objects=custom_objects, + ) + + +@keras_export("keras.dtype_policies.get") def get(identifier): + """Retrieves a Keras `DTypePolicy` instance. + + The `identifier` may be the string name of a `DTypePolicy` class. + + >>> policy = dtype_policies.get("mixed_bfloat16") + >>> type(loss) + + + You can also specify `config` of the dtype policy to this function by + passing dict containing `class_name` and `config` as an identifier. Also + note that the `class_name` must map to a `DTypePolicy` class + + >>> identifier = {"class_name": "FloatDTypePolicy", + ... "config": {"name": "float32"}} + >>> policy = dtype_policies.get(identifier) + >>> type(loss) + + + Args: + identifier: A dtype policy identifier. One of `None` or string name of a + `DTypePolicy` or `DTypePolicy` configuration dictionary or a + `DTypePolicy` instance. + + Returns: + A Keras `DTypePolicy` instance. + """ from keras.src.dtype_policies.dtype_policy import ( _get_quantized_dtype_policy_by_str, ) - from keras.src.saving import serialization_lib if identifier is None: return dtype_policy.dtype_policy() if isinstance(identifier, (FloatDTypePolicy, QuantizedDTypePolicy)): return identifier if isinstance(identifier, dict): - return serialization_lib.deserialize_keras_object(identifier) + return deserialize(identifier) if isinstance(identifier, str): if identifier.startswith(QUANTIZATION_MODES): return _get_quantized_dtype_policy_by_str(identifier) diff --git a/keras/src/dtype_policies/dtype_policy.py b/keras/src/dtype_policies/dtype_policy.py index 2618e118e2b..a55eaa4c065 100644 --- a/keras/src/dtype_policies/dtype_policy.py +++ b/keras/src/dtype_policies/dtype_policy.py @@ -293,6 +293,11 @@ def _get_all_valid_policies(self): ] return valid_policies + def get_config(self): + config = super().get_config() + config.update({"amax_history_length": self.amax_history_length}) + return config + @keras_export( [ diff --git a/keras/src/dtype_policies/dtype_policy_test.py b/keras/src/dtype_policies/dtype_policy_test.py index b040663781a..b66df0779f3 100644 --- a/keras/src/dtype_policies/dtype_policy_test.py +++ b/keras/src/dtype_policies/dtype_policy_test.py @@ -1,5 +1,8 @@ from absl.testing import parameterized +from keras.src.dtype_policies import deserialize +from keras.src.dtype_policies import get +from keras.src.dtype_policies import serialize from keras.src.dtype_policies.dtype_policy import DTypePolicy from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy @@ -64,7 +67,7 @@ def test_get_config_from_config(self): new_policy = DTypePolicy.from_config(config) self.assertEqual(new_policy.name, "mixed_float16") - def test_serialization(self): + def test_python_serialization(self): """Test builtin serialization methods.""" import copy import pickle @@ -91,6 +94,16 @@ def test_serialization(self): repr(copied_policy), '' ) + def test_serialization(self): + policy = DTypePolicy("mixed_float16") + config = serialize(policy) + reloaded_policy = deserialize(config) + self.assertEqual(policy.name, reloaded_policy.name) + + # Test `dtype_policies.get` + reloaded_policy = get(config) + self.assertEqual(policy.name, reloaded_policy.name) + class FloatDTypePolicyTest(test_case.TestCase): def test_initialization_valid_name(self): @@ -154,6 +167,16 @@ def test_get_config_from_config(self): new_policy = FloatDTypePolicy.from_config(config) self.assertEqual(new_policy.name, "mixed_float16") + def test_serialization(self): + policy = FloatDTypePolicy("mixed_float16") + config = serialize(policy) + reloaded_policy = deserialize(config) + self.assertEqual(policy.name, reloaded_policy.name) + + # Test `dtype_policies.get` + reloaded_policy = get(config) + self.assertEqual(policy.name, reloaded_policy.name) + class QuantizedDTypePolicyTest(test_case.TestCase, parameterized.TestCase): @parameterized.named_parameters( @@ -224,7 +247,7 @@ def test_get_config_from_config(self): '', ), ) - def test_serialization(self, name, repr_str): + def test_python_serialization(self, name, repr_str): import copy import pickle @@ -244,6 +267,16 @@ def test_serialization(self, name, repr_str): copied_policy = pickle.load(f) self.assertEqual(repr(copied_policy), repr_str) + def test_serialization(self): + policy = QuantizedDTypePolicy("int8_from_float32") + config = serialize(policy) + reloaded_policy = deserialize(config) + self.assertEqual(policy.name, reloaded_policy.name) + + # Test `dtype_policies.get` + reloaded_policy = get(config) + self.assertEqual(policy.name, reloaded_policy.name) + def test_properties_for_float8(self): policy = QuantizedFloat8DTypePolicy("float8_from_mixed_bfloat16") self.assertEqual(policy.amax_history_length, 1024) @@ -256,7 +289,7 @@ def test_invalid_properties_for_float8(self): with self.assertRaisesRegex(TypeError, "must be an integer."): QuantizedFloat8DTypePolicy("float8_from_float32", 512.0) - def test_serialization_for_float8(self): + def test_python_serialization_for_float8(self): import copy import pickle @@ -288,6 +321,22 @@ def test_serialization_for_float8(self): ) self.assertEqual(copied_policy.amax_history_length, 123) + def test_serialization_for_float8(self): + policy = QuantizedFloat8DTypePolicy("float8_from_mixed_float16") + config = serialize(policy) + reloaded_policy = deserialize(config) + self.assertEqual(policy.name, reloaded_policy.name) + self.assertEqual( + policy.amax_history_length, reloaded_policy.amax_history_length + ) + + # Test `dtype_policies.get` + reloaded_policy = get(config) + self.assertEqual(policy.name, reloaded_policy.name) + self.assertEqual( + policy.amax_history_length, reloaded_policy.amax_history_length + ) + @parameterized.named_parameters( ("int8_from_mixed_bfloat16", "int8_from_mixed_bfloat16"), ("float8_from_mixed_bfloat16", "float8_from_mixed_bfloat16"), diff --git a/keras/src/losses/__init__.py b/keras/src/losses/__init__.py index 9652ceb057b..3f4ef8d0f69 100644 --- a/keras/src/losses/__init__.py +++ b/keras/src/losses/__init__.py @@ -135,8 +135,8 @@ def deserialize(name, custom_objects=None): Args: name: Loss configuration. custom_objects: Optional dictionary mapping names (strings) to custom - objects (classes and functions) to be considered during - deserialization. + objects (classes and functions) to be considered during + deserialization. Returns: A Keras `Loss` instance or a loss function.