[NVIDIA] Use custom grad accumulation for FP8 params #3623
Conversation
Also, cc. @mingxu1067
Note, the failed tests on …
@kaixih it looks good, but we will have to wait for JAX to push a new release to PyPI for the tests to pass (according to your comment).
@kaixih thanks! Seems that …
@kaixih do you mind fixing the CI errors?
It seems the CI is still on jax 0.4.23 (the failed test). I re-tested on my machine with jax 0.4.24, and the tests pass.
@zhangqiaorjc @cgarciae Can you help check whether jax has already been updated to 0.4.24 or later?
Based on the output of the test run, it seems not.
As a quick fix, maybe add `venv/bin/python3 -m pip install -U jax jaxlib` after flax/.github/workflows/build.yml line 120 (at commit daf06ea).
@cgarciae Do you mean I should add this line in this PR?
Created a PR so you can rebase when merged. |
Force-pushed from 3e31661 to dd004c2.
@zhangqiaorjc It seems all tests pass now. Can you take another look or reassign? Thanks.
@cgarciae Is this PR's merge blocked on any internal error?
This pull request introduces a custom data type rule for FP8 parameters to implement custom gradient accumulation. Specifically, when an FP8 parameter is reused, autograd accumulates the gradients from each use; in this case we want that accumulation to be a maximum operation rather than the default addition.
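As a rough illustration (a minimal sketch, not this PR's actual implementation; the toy loss and the helper `max_accumulated_grad` are hypothetical), the following JAX snippet contrasts the default sum-accumulation of a reused parameter's gradient with a max-style combination:

```python
import jax
import jax.numpy as jnp

# Default autodiff behavior: when a parameter (`scale`) is used in two
# places, the cotangents from each use are accumulated by addition.
def loss_with_reuse(scale, x):
    y1 = x * scale          # first use of `scale`
    y2 = (x + 1.0) * scale  # second use of `scale`
    return jnp.sum(y1) + jnp.sum(y2)

# Illustrative max-style accumulation: take the gradient of each use
# separately and combine the results with a maximum. This only mimics
# the accumulation behavior described above; it is not the PR's API.
def max_accumulated_grad(scale, x):
    g1 = jax.grad(lambda s: jnp.sum(x * s))(scale)
    g2 = jax.grad(lambda s: jnp.sum((x + 1.0) * s))(scale)
    return jnp.maximum(g1, g2)

x = jnp.arange(4.0)
print(jax.grad(loss_with_reuse)(1.0, x))  # 16.0 = sum(x) + sum(x + 1): gradients summed
print(max_accumulated_grad(1.0, x))       # 10.0 = max(sum(x), sum(x + 1)): gradients maxed
```

Max accumulation is the natural choice here because the values carried through these FP8 parameter gradients are typically amax-style scaling statistics, for which summing contributions from different uses would not be meaningful.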