Add a mixed precision test and fix mixed precision errors for layers #1242
Conversation
/gcbrun
Ah, I think this needs a keras-core release to pass. Will sync with Francois.
Force-pushed from 44b1005 to fc4f0f5
```diff
  )
  intermediate_shape = list(decoder_sequence_shape)
  intermediate_shape[-1] = self.intermediate_dim
  self._feedforward_output_dense.build(tuple(intermediate_shape))
  self._feedforward_layer_norm = keras.layers.LayerNormalization(
      epsilon=self.layer_norm_epsilon,
-     name="output_layer_norm",
+     dtype=self.dtype_policy,
+     name="feedforward_layer_norm",
```
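For context on the `dtype=self.dtype_policy` line, here is a minimal sketch, assuming the Keras 3 dtype-policy API, of why the parent's policy has to be threaded into sublayers explicitly: a sublayer constructed without a `dtype` argument picks up the global policy, not the enclosing layer's.

```python
import keras

# Under a mixed global policy, a sublayer built without an explicit
# dtype follows the global policy rather than its parent layer's.
keras.mixed_precision.set_dtype_policy("mixed_float16")

explicit = keras.layers.LayerNormalization(dtype="float32")
implicit = keras.layers.LayerNormalization()
print(explicit.compute_dtype)  # float32: matches the dtype passed in
print(implicit.compute_dtype)  # float16: inherited from the global policy
```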
Just a random thought while looking at these: with the new distribution API, we're going to need to be careful about changing layer names moving forward!
Yep! Trying to get us in a nice consistent state before we do.
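To make the concern concrete, a hedged sketch assuming the Keras 3 distribution API: `LayoutMap` keys are regexes matched against variable paths, and those paths embed layer names, so a rename like `output_layer_norm` to `feedforward_layer_norm` silently orphans any sharding rule keyed on the old name. The regex key below is illustrative, not from this PR.

```python
import keras

# Sketch only: build a mesh over the available devices with a
# "model" axis for sharding weights.
devices = keras.distribution.list_devices()
mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh)
# Keys are regexes over variable paths, which include layer names.
# If the layer is renamed, this rule no longer matches anything.
layout_map["feedforward_intermediate_dense.*kernel"] = (None, "model")
```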
keras_nlp/samplers/sampler.py (Outdated)
```python
        This will always be done in full precision, regardless of dtype, and
        scale by `temperature`.
        """
        dtype = logits.dtype
```
Maybe `logits_dtype`? (For consistency of style with `inputs_dtype` a few files up.)
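For reference, a minimal sketch of the pattern the docstring describes, using the reviewer's suggested `logits_dtype` name; the function name and exact placement are illustrative, not the actual `sampler.py` code:

```python
from keras import ops

def compute_probabilities(logits, temperature=1.0):
    # Run softmax in float32 regardless of the compute dtype, then
    # cast back, so low-precision logits don't lose probability mass.
    logits_dtype = logits.dtype  # the reviewer's suggested name
    logits = ops.cast(logits, "float32")
    probabilities = ops.softmax(logits / temperature)
    return ops.cast(probabilities, logits_dtype)
```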
keras_nlp/tests/test_case.py (Outdated)
```python
        output_data = layer(input_data)
        for tensor in tree.flatten(output_data):
            dtype = standardize_dtype(tensor.dtype)
            if "float" in dtype:
```
Seems like `assertDType` should be good here, no?
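A hedged sketch of what the suggested helper could look like; the name `assertDType` comes from the comment above, but the body, base class, and import paths are assumptions:

```python
import unittest

import tree  # dm-tree, used by keras-nlp for nested structures
from keras import backend


class TestCase(unittest.TestCase):
    def assertDType(self, output_data, expected_dtype):
        # Check every floating-point tensor in a (possibly nested)
        # output structure against the expected compute dtype.
        for tensor in tree.flatten(output_data):
            dtype = backend.standardize_dtype(tensor.dtype)
            if "float" in dtype:
                self.assertEqual(dtype, expected_dtype)
```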
Force-pushed from 03c1c85 to 79c0219
Is mixed precision working in Keras Core now?
Force-pushed from aefe4ed to 926a737
Yes, or at least it should be, largely, with the latest release. We landed the loss scaling optimizer, which was the main piece we were missing, as well as a few other fixes.
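For anyone following along, a hedged sketch of what that unblocks, assuming the Keras 3 API names (earlier keras-core releases may differ). Under a `mixed_float16` policy, `compile()` is expected to apply loss scaling automatically; the wrapping is shown explicitly here for clarity.

```python
import numpy as np
import keras

keras.mixed_precision.set_dtype_policy("mixed_float16")

model = keras.Sequential([
    keras.layers.Dense(16, activation="relu"),
    # Keep the final layer float32 so outputs stay numerically stable.
    keras.layers.Dense(1, dtype="float32"),
])
# Loss scaling multiplies the loss before backprop and divides the
# gradients after, avoiding float16 gradient underflow.
optimizer = keras.optimizers.LossScaleOptimizer(keras.optimizers.Adam())
model.compile(optimizer=optimizer, loss="mse")
model.fit(np.random.rand(32, 8), np.random.rand(32, 1), epochs=1)
```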
/gcbrun
(5 similar comments)
/gcbrun
/gcbrun
/gcbrun
/gcbrun
/gcbrun
Merging; the last breakage is unrelated (#1251).
No description provided.