Fix dtype handling for Adam and LAMB optimizers in 64bit mode. #965

Merged 2 commits on Feb 2, 2021

@levskaya (Collaborator) commented Feb 1, 2021

This prevents accidental upcasting of parameters caused by JAX's wonky dtype promotion rules for integer arrays combined with floating-point Python scalars.
In JAX x64 mode, `jnp.array(1) + 1.` becomes a JAX array with dtype float64. That dtype then "invades" the params in the optimizer, causing dtype instability.
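
A minimal sketch of the promotion described above, for anyone who wants to reproduce it. The names `step`, `params`, `beta`, `t` and `t_safe` are illustrative and not taken from the Flax source, and the explicit cast is just one possible way to keep dtypes stable, not necessarily what this PR does. Whether the float64 intermediate goes on to upcast float32 params can depend on the JAX version's weak-type handling, so that dtype is printed rather than asserted.

```python
import jax

# Enable 64-bit mode; without it JAX caps default dtypes at 32 bits.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

step = jnp.array(1)                       # integer step counter (illustrative)
params = jnp.ones(3, dtype=jnp.float32)   # float32 parameters (illustrative)
beta = 0.9

# Integer array + Python float scalar: in x64 mode the result is float64.
t = step + 1.
print("t dtype:", t.dtype)

# Bias-correction-style term built from t; inspect what it does to params.
maybe_upcast = params / (1 - beta ** t)
print("params after naive update:", maybe_upcast.dtype)

# Assumed workaround: cast the step to the parameter dtype before mixing it
# with Python float scalars, so every intermediate stays float32.
t_safe = step.astype(params.dtype) + 1.
stable = params / (1 - beta ** t_safe)
print("params after cast-first update:", stable.dtype)
```

Casting the counter up front keeps the arithmetic in the parameter dtype regardless of whether x64 mode is on.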

Should fix #924

@levskaya levskaya requested a review from andsteing February 1, 2021 07:07
@marcvanzee marcvanzee added the Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) label Feb 1, 2021
@levskaya levskaya requested a review from avital February 1, 2021 23:09
@copybara-service copybara-service bot merged commit 61580b9 into google:master Feb 2, 2021
Labels
cla: yes Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) pull ready
Development

Successfully merging this pull request may close these issues.

When jax_enable_x64 is set Adam promotes everything to float64