Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVIDIA] Update the algorithm to compute fp8 scales #3441

Merged
merged 2 commits into from
Oct 31, 2023

Conversation

kaixih
Copy link
Contributor

@kaixih kaixih commented Oct 26, 2023

This pull request implements updates to the algorithm for computing new FP8 scales, in line with the design outlined in the following documentation: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#transformer_engine.jax.update_fp8_metas.

cc. @wenscarl @reedwm @lukaszlew @levskaya

@levskaya
Copy link
Collaborator

The change looks OK, but mypy is now yielding a type error - would you mind double-checking if that's finding something real, or is it spurious?

@codecov-commenter
Copy link

codecov-commenter commented Oct 26, 2023

Codecov Report

Merging #3441 (3a0e87a) into main (5557649) will decrease coverage by 0.03%.
Report is 2 commits behind head on main.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main    #3441      +/-   ##
==========================================
- Coverage   83.68%   83.65%   -0.03%     
==========================================
  Files          56       56              
  Lines        6808     6797      -11     
==========================================
- Hits         5697     5686      -11     
  Misses       1111     1111              
Files Coverage Δ
flax/linen/__init__.py 100.00% <100.00%> (ø)
flax/linen/fp8_ops.py 100.00% <100.00%> (ø)
flax/training/train_state.py 100.00% <100.00%> (ø)

@kaixih
Copy link
Contributor Author

kaixih commented Oct 26, 2023

@levskaya thank you for bringing this to my attention. I've removed the function return value annotation that was causing the previous mypy check to trigger. I don't believe it's necessary because this custom operation for FP8 is not configured to the layers, such as DenseGeneral, by default. There won't be any other type checking tracing back to this operation.

@kaixih
Copy link
Contributor Author

kaixih commented Oct 30, 2023

@levskaya Any updates? I see this PR has stayed in "pull ready" status for a couple of days.

@copybara-service copybara-service bot merged commit 8d09772 into google:main Oct 31, 2023
19 checks passed
@levskaya
Copy link
Collaborator

@kaixih - super sorry for the delay, not sure why it didn't go through earlier! It should have merged last night.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants