-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Refactor keras.dtype_policies
#19711
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor keras.dtype_policies
#19711
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #19711 +/- ##
=======================================
Coverage 78.52% 78.53%
=======================================
Files 498 498
Lines 45769 45756 -13
Branches 8456 8454 -2
=======================================
- Hits 35942 35936 -6
+ Misses 8091 8087 -4
+ Partials 1736 1733 -3
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
fchollet
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
| return f'<FloatDTypePolicy "{self._name}">' | ||
|
|
||
|
|
||
| GLOBAL_DEFAULT_PLACEHOLDER = "global_default" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use a more explicit name, e.g. "DEFAULT_DTYPE_POLICY". Why use this string as the initial value, instead of e.g. None?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why use this string as the initial value, instead of e.g. None?
Currently, DTypePolicy and its subclasses rely on string value for parsing.
It is not clear for me how we can pass None in combination with the quantization mode.
Should we refactor QuantizedDTypePolicy to support a signature for both the quantization mode and the source dtype policy?
Ex:
policy = QuantizedDTypePolicy(mode="int8", source_dtype_policy="mixed_bfloat16")There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, DTypePolicy and its subclasses rely on string value for parsing.
It is not clear for me how we can pass None in combination with the quantization mode.
We could just modify DTypePolicy to support None, meaning "default".
Should we refactor QuantizedDTypePolicy to support a signature for both the quantization mode and the source dtype policy?
Yes, that's a great idea!
QuantizedDTypePolicykeras.dtype_policies
|
I've significantly refactored the Some notes:
Imcompatible warning:
To add flexibility to quantized dtype policy: Detailsimport keras
from keras import dtype_policies
from keras import layers
from keras import models
@keras.saving.register_keras_serializable("MyPackage")
class MySubclass(layers.Layer):
def __init__(self, **kwargs):
dtypes = kwargs.pop("dtypes", {})
super().__init__(**kwargs)
self.layer = layers.Dense(8, dtype=dtypes.pop("layer", None))
def call(self, inputs, training=None):
return self.layer(inputs)
def get_config(self):
config = super().get_config()
config.pop("dtype")
if self.layer.dtype_policy.is_quantized:
_config = dtype_policies.serialize(self.layer.dtype_policy)
_config["config"]["source_name"] = None
config.update({"dtypes": {"layer": _config}})
return config
inputs = layers.Input(shape=[None, 4])
outputs = MySubclass()(inputs)
model = models.Model(inputs, outputs)
"""global dtype policy (float32)"""
model.quantize("int8")
for layer in model._flatten_layers(include_self=False, recursive=True):
print(layer.name, layer.dtype_policy)
model.save("model.keras")
"""global dtype policy (bfloat16)"""
keras.config.set_dtype_policy("bfloat16")
new_model = models.load_model("model.keras")
for layer in new_model._flatten_layers(include_self=False, recursive=True):
print(layer.name, layer.dtype_policy)The outputs: # global dtype policy: float32
input_layer <FloatDTypePolicy "float32">
my_subclass <FloatDTypePolicy "float32">
dense <QuantizedDTypePolicy "int8_from_float32">
# global dtype policy: bfloat16
input_layer <FloatDTypePolicy "bfloat16">
my_subclass <FloatDTypePolicy "bfloat16">
dense_1 <QuantizedDTypePolicy "int8_from_bfloat16"> |
fchollet
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work -- it's definitely cleaner this way! LGTM
Keras' output format was slightly changed in keras-team/keras#19711; in some cases dtypes will now be exported as a config map instead of just a string. This fixes test breakages when using ToT keras.
Keras' output format was slightly changed in keras-team/keras#19711; for non-input layers dtypes will now be exported as a config map instead of just a string. This fixes test breakages when using ToT keras.
Keras' output format was slightly changed in keras-team/keras#19711; for non-input layers dtypes will now be exported as a config map instead of just a string. This fixes test breakages when using ToT keras.
Keras' output format was slightly changed in keras-team/keras#19711; for non-input layers dtypes will now be exported as a config map instead of just a string. This fixes test breakages when using ToT keras. Alternative to #6855
Original PR #19711 by james77777778 Original: keras-team/keras#19711
Merged from original PR #19711 Original: keras-team/keras#19711
Original PR #19711 by james77777778 Original: keras-team/keras#19711
Merged from original PR #19711 Original: keras-team/keras#19711
EDITED:
Please refer to #19711 (comment) for the new updates.
I think it would be beneficial to provide some flexibility to
QuantizedDTypePolicyregarding the global dtype policykeras.config.dtype_policy()Additionally, there is a new property in
DTypePolicy:is_quantizedthat should be useful for these quantization-related methods.With this PR, we can do the following:
Outputs:
@mattdangerw has pointed out that currently the dtype policies of the quantized saves are immutable regarding the global dtype policy. keras-team/keras-hub#1612 (comment)
With this PR, we can make a slight modification in
get_configto support that feature.