Skip to content

Commit

Permalink
Fix the apis of dtype_polices (#19580)
Browse files Browse the repository at this point in the history
* Fix api of `dtype_polices`

* Update docstring

* Increase test coverage

* Fix format
  • Loading branch information
james77777778 authored Apr 22, 2024
1 parent 5021ab7 commit 3afc089
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 7 deletions.
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/dtype_policies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
since your modifications would be overwritten.
"""

from keras.src.dtype_policies import deserialize
from keras.src.dtype_policies import get
from keras.src.dtype_policies import serialize
from keras.src.dtype_policies.dtype_policy import DTypePolicy
from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy
Expand Down
3 changes: 3 additions & 0 deletions keras/api/dtype_policies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
since your modifications would be overwritten.
"""

from keras.src.dtype_policies import deserialize
from keras.src.dtype_policies import get
from keras.src.dtype_policies import serialize
from keras.src.dtype_policies.dtype_policy import DTypePolicy
from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy
Expand Down
77 changes: 75 additions & 2 deletions keras/src/dtype_policies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,96 @@
from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.dtype_policies import dtype_policy
from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES
from keras.src.dtype_policies.dtype_policy import DTypePolicy
from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy
from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy

ALL_OBJECTS = {
DTypePolicy,
FloatDTypePolicy,
QuantizedDTypePolicy,
QuantizedFloat8DTypePolicy,
}
ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}


@keras_export("keras.dtype_policies.serialize")
def serialize(dtype_policy):
"""Serializes `DTypePolicy` instance.
Args:
dtype_policy: A Keras `DTypePolicy` instance.
Returns:
`DTypePolicy` configuration dictionary.
"""
from keras.src.saving import serialization_lib

return serialization_lib.serialize_keras_object(dtype_policy)


@keras_export("keras.dtype_policies.deserialize")
def deserialize(config, custom_objects=None):
"""Deserializes a serialized `DTypePolicy` instance.
Args:
config: `DTypePolicy` configuration.
custom_objects: Optional dictionary mapping names (strings) to custom
objects (classes and functions) to be considered during
deserialization.
Returns:
A Keras `DTypePolicy` instance.
"""
from keras.src.saving import serialization_lib

return serialization_lib.deserialize_keras_object(
config,
module_objects=ALL_OBJECTS_DICT,
custom_objects=custom_objects,
)


@keras_export("keras.dtype_policies.get")
def get(identifier):
"""Retrieves a Keras `DTypePolicy` instance.
The `identifier` may be the string name of a `DTypePolicy` class.
>>> policy = dtype_policies.get("mixed_bfloat16")
>>> type(loss)
<class '...FloatDTypePolicy'>
You can also specify `config` of the dtype policy to this function by
passing dict containing `class_name` and `config` as an identifier. Also
note that the `class_name` must map to a `DTypePolicy` class
>>> identifier = {"class_name": "FloatDTypePolicy",
... "config": {"name": "float32"}}
>>> policy = dtype_policies.get(identifier)
>>> type(loss)
<class '...FloatDTypePolicy'>
Args:
identifier: A dtype policy identifier. One of `None` or string name of a
`DTypePolicy` or `DTypePolicy` configuration dictionary or a
`DTypePolicy` instance.
Returns:
A Keras `DTypePolicy` instance.
"""
from keras.src.dtype_policies.dtype_policy import (
_get_quantized_dtype_policy_by_str,
)
from keras.src.saving import serialization_lib

if identifier is None:
return dtype_policy.dtype_policy()
if isinstance(identifier, (FloatDTypePolicy, QuantizedDTypePolicy)):
return identifier
if isinstance(identifier, dict):
return serialization_lib.deserialize_keras_object(identifier)
return deserialize(identifier)
if isinstance(identifier, str):
if identifier.startswith(QUANTIZATION_MODES):
return _get_quantized_dtype_policy_by_str(identifier)
Expand Down
5 changes: 5 additions & 0 deletions keras/src/dtype_policies/dtype_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,11 @@ def _get_all_valid_policies(self):
]
return valid_policies

def get_config(self):
config = super().get_config()
config.update({"amax_history_length": self.amax_history_length})
return config


@keras_export(
[
Expand Down
55 changes: 52 additions & 3 deletions keras/src/dtype_policies/dtype_policy_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from absl.testing import parameterized

from keras.src.dtype_policies import deserialize
from keras.src.dtype_policies import get
from keras.src.dtype_policies import serialize
from keras.src.dtype_policies.dtype_policy import DTypePolicy
from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy
Expand Down Expand Up @@ -64,7 +67,7 @@ def test_get_config_from_config(self):
new_policy = DTypePolicy.from_config(config)
self.assertEqual(new_policy.name, "mixed_float16")

def test_serialization(self):
def test_python_serialization(self):
"""Test builtin serialization methods."""
import copy
import pickle
Expand All @@ -91,6 +94,16 @@ def test_serialization(self):
repr(copied_policy), '<FloatDTypePolicy "mixed_float16">'
)

def test_serialization(self):
policy = DTypePolicy("mixed_float16")
config = serialize(policy)
reloaded_policy = deserialize(config)
self.assertEqual(policy.name, reloaded_policy.name)

# Test `dtype_policies.get`
reloaded_policy = get(config)
self.assertEqual(policy.name, reloaded_policy.name)


class FloatDTypePolicyTest(test_case.TestCase):
def test_initialization_valid_name(self):
Expand Down Expand Up @@ -154,6 +167,16 @@ def test_get_config_from_config(self):
new_policy = FloatDTypePolicy.from_config(config)
self.assertEqual(new_policy.name, "mixed_float16")

def test_serialization(self):
policy = FloatDTypePolicy("mixed_float16")
config = serialize(policy)
reloaded_policy = deserialize(config)
self.assertEqual(policy.name, reloaded_policy.name)

# Test `dtype_policies.get`
reloaded_policy = get(config)
self.assertEqual(policy.name, reloaded_policy.name)


class QuantizedDTypePolicyTest(test_case.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
Expand Down Expand Up @@ -224,7 +247,7 @@ def test_get_config_from_config(self):
'<QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">',
),
)
def test_serialization(self, name, repr_str):
def test_python_serialization(self, name, repr_str):
import copy
import pickle

Expand All @@ -244,6 +267,16 @@ def test_serialization(self, name, repr_str):
copied_policy = pickle.load(f)
self.assertEqual(repr(copied_policy), repr_str)

def test_serialization(self):
policy = QuantizedDTypePolicy("int8_from_float32")
config = serialize(policy)
reloaded_policy = deserialize(config)
self.assertEqual(policy.name, reloaded_policy.name)

# Test `dtype_policies.get`
reloaded_policy = get(config)
self.assertEqual(policy.name, reloaded_policy.name)

def test_properties_for_float8(self):
policy = QuantizedFloat8DTypePolicy("float8_from_mixed_bfloat16")
self.assertEqual(policy.amax_history_length, 1024)
Expand All @@ -256,7 +289,7 @@ def test_invalid_properties_for_float8(self):
with self.assertRaisesRegex(TypeError, "must be an integer."):
QuantizedFloat8DTypePolicy("float8_from_float32", 512.0)

def test_serialization_for_float8(self):
def test_python_serialization_for_float8(self):
import copy
import pickle

Expand Down Expand Up @@ -288,6 +321,22 @@ def test_serialization_for_float8(self):
)
self.assertEqual(copied_policy.amax_history_length, 123)

def test_serialization_for_float8(self):
policy = QuantizedFloat8DTypePolicy("float8_from_mixed_float16")
config = serialize(policy)
reloaded_policy = deserialize(config)
self.assertEqual(policy.name, reloaded_policy.name)
self.assertEqual(
policy.amax_history_length, reloaded_policy.amax_history_length
)

# Test `dtype_policies.get`
reloaded_policy = get(config)
self.assertEqual(policy.name, reloaded_policy.name)
self.assertEqual(
policy.amax_history_length, reloaded_policy.amax_history_length
)

@parameterized.named_parameters(
("int8_from_mixed_bfloat16", "int8_from_mixed_bfloat16"),
("float8_from_mixed_bfloat16", "float8_from_mixed_bfloat16"),
Expand Down
4 changes: 2 additions & 2 deletions keras/src/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ def deserialize(name, custom_objects=None):
Args:
name: Loss configuration.
custom_objects: Optional dictionary mapping names (strings) to custom
objects (classes and functions) to be considered during
deserialization.
objects (classes and functions) to be considered during
deserialization.
Returns:
A Keras `Loss` instance or a loss function.
Expand Down

0 comments on commit 3afc089

Please sign in to comment.