
Allow for fast accumulation selection for FP8 GEMM #3416

Merged: 5 commits into google:main, Nov 8, 2023

Conversation

@wenscarl (Collaborator) commented Oct 16, 2023

In this PR, issue #6168 (openxla/xla) is addressed by introducing a custom gradient (forward-mode autodiff) for the FP8 dot_general. This PR is closely related to XLA PR #6599 (openxla/xla#6599). @reedwm @kaixih @burmako
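For readers following along, here is a minimal sketch of the approach discussed in the review below (based on the snippets quoted in this thread, not the exact contents of flax/linen/fp8_ops.py): the forward pass pins precision to DEFAULT so the FP8 cublasLt GEMM can use fast accumulation, while the custom JVP computes tangents at HIGHEST precision.

from functools import partial

from jax import custom_jvp, lax


@partial(custom_jvp, nondiff_argnums=(2,))
def dot_general_with_precision(lhs, rhs, dimension_numbers):
  # Forward pass: DEFAULT precision makes the FP8 GEMM eligible for fast accumulation.
  return lax.dot_general(
      lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT
  )


@dot_general_with_precision.defjvp
def dot_general_with_precision_jvp(dimension_numbers, primals, tangents):
  lhs, rhs = primals
  lhs_dot, rhs_dot = tangents
  out = lax.dot_general(
      lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT
  )
  # Tangents (and hence gradients, obtained by transposing this linear rule)
  # use high-precision accumulation.
  grad_out = lax.dot_general(
      lhs_dot, rhs, dimension_numbers, precision=lax.Precision.HIGHEST
  ) + lax.dot_general(
      lhs, rhs_dot, dimension_numbers, precision=lax.Precision.HIGHEST
  )
  return out, grad_out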

@wenscarl wenscarl changed the title Allow for fast accumulation selection for FP8 GEMM [draft]Allow for fast accumulation selection for FP8 GEMM Oct 16, 2023
@wenscarl wenscarl changed the title [draft]Allow for fast accumulation selection for FP8 GEMM Allow for fast accumulation selection for FP8 GEMM Oct 17, 2023
@burmako commented Oct 17, 2023

Thank you for the heads up! Also cc @GleasonK

"""
FP8 helper to manage the FP8 meta
"""
FWD_DTYPE: DType = jnp.float8_e4m3fn
Review comment (Contributor):

I didn't observe the usage of FWD/BWD_DTYPE in this commit. Perhaps we should limit this commit to handling accumulation datatypes exclusively. For simplicity, I suggest hardcoding the accumulation datatypes directly within the dot general custom gradient functions as we have already done for the input/output dtypes.

@codecov-commenter commented Oct 27, 2023

Codecov Report

Merging #3416 (09f2ec9) into main (0b126b8) will increase coverage by 0.12%.
Report is 10 commits behind head on main.
The diff coverage is 93.33%.

@@            Coverage Diff             @@
##             main    #3416      +/-   ##
==========================================
+ Coverage   83.50%   83.63%   +0.12%     
==========================================
  Files          56       56              
  Lines        6725     6802      +77     
==========================================
+ Hits         5616     5689      +73     
- Misses       1109     1113       +4     
Files                   Coverage Δ
flax/linen/fp8_ops.py   98.92% <93.33%> (-1.08%) ⬇️

... and 4 files with indirect coverage changes

copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 2, 2023
Imported from GitHub PR #6599

FP8 cublasLt matmul uses fast accumulation when both operands' precisions are DEFAULT; otherwise it falls back to high-precision accumulation. Issue #6168

This PR is closely related to Flax PR [#3416](google/flax#3416).
Copybara import of the project:

--
a4140da by shuw <shuw@nvidia.com>:

Add FP8 fast accumulation support for cublasLt.

--
9684568 by shuw <shuw@nvidia.com>:

Improve based on review #1

--
e906d76 by shuw <shuw@nvidia.com>:

Improve based on review #2

Merging this change closes #6599

FUTURE_COPYBARA_INTEGRATE_REVIEW=#6599 from wenscarl:fp8_fast_accumulation e906d76
PiperOrigin-RevId: 578904075
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 2, 2023
Imported from GitHub PR #6599

FP8 cublasLt matmul uses fast accumulation when both operands' precisions are DEFAULT; otherwise it falls back to high-precision accumulation. Issue #6168

This PR is closely related to Flax PR [#3416](google/flax#3416).
Copybara import of the project:

--
a4140da by shuw <shuw@nvidia.com>:

Add FP8 fast accumulation support for cublasLt.

--
9684568 by shuw <shuw@nvidia.com>:

Improve based on review #1

--
e906d76 by shuw <shuw@nvidia.com>:

Improve based on review #2

Merging this change closes #6599

COPYBARA_INTEGRATE_REVIEW=#6599 from wenscarl:fp8_fast_accumulation e906d76
PiperOrigin-RevId: 578948593
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 2, 2023
Imported from GitHub PR openxla/xla#6599

FP8 cublasLt matmul uses fast accumulation when both operands' precisions are DEFAULT; otherwise it falls back to high-precision accumulation. Issue: openxla/xla#6168

This PR is closely related to Flax PR [#3416](google/flax#3416).
Copybara import of the project:

--
a4140da8ca08cd2d4796a7b8f032827867a361bc by shuw <shuw@nvidia.com>:

Add FP8 fast accumulation support for cublasLt.

--
96845683cc4b1e7b947bc919fbf97d8865abeac9 by shuw <shuw@nvidia.com>:

Improve based on review #1

--
e906d7620780d2cf1fe8433c933648dcb98dc61d by shuw <shuw@nvidia.com>:

Improve based on review #2

Merging this change closes #6599

PiperOrigin-RevId: 578948593
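As a rough illustration of the behavior described in the commit messages above (a sketch only; whether XLA actually dispatches to cublasLt fast accumulation is an implementation detail, and the preferred_element_type choice here is an assumption), the precision argument to lax.dot_general is what selects fast versus high-precision accumulation for FP8 operands:

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((16, 32), dtype=jnp.float8_e4m3fn)
rhs = jnp.ones((32, 64), dtype=jnp.float8_e4m3fn)
dims = (((1,), (0,)), ((), ()))  # standard matmul contraction

# DEFAULT precision on both operands -> eligible for fast accumulation.
fast = lax.dot_general(lhs, rhs, dims, precision=lax.Precision.DEFAULT,
                       preferred_element_type=jnp.bfloat16)

# Requesting higher precision -> falls back to high-precision accumulation.
accurate = lax.dot_general(lhs, rhs, dims, precision=lax.Precision.HIGHEST,
                           preferred_element_type=jnp.bfloat16)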
@levskaya (Collaborator) commented Nov 2, 2023

@wenscarl - sorry, I was about to merge but there's a fresh merge conflict. Would you mind resolving it so I can merge?

@@ -123,6 +123,24 @@ def out_qdq_bwd(compute_dtype, res, g):
out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd)


@partial(custom_jvp, nondiff_argnums=(2,))
def dot_general_with_precision(lhs, rhs, dimension_numbers):
@kaixih (Contributor) commented Nov 2, 2023:

Can you make this have the same signature as lax.dot_general() so that I can inject it into jnp.einsum?

@partial(custom_jvp, nondiff_argnums=(2, 3, 4))
def dot_general_with_precision(
    lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None
):
  return lax.dot_general(
      lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT
  )
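Once the signatures match, the wrapped op can be passed anywhere a lax.dot_general-compatible callable is expected. An illustrative usage (not from the PR), relying on jnp.einsum's private _dot_general hook and reusing the x_qdq/k_qdq operand names that appear in the traceback later in this thread; the same callable is what flax.linen's Dense layer invokes through its dot_general argument:

# Illustrative injection of the custom-precision dot into einsum.
y = jnp.einsum('bd,df->bf', x_qdq, k_qdq,
               _dot_general=dot_general_with_precision)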

@dot_general_with_precision.defjvp
def dot_general_with_precision_jvp(dimension_numbers, primals, tangents):
Review comment (Contributor):

I think this also needs to be changed accordingly, like:

@dot_general_with_precision.defjvp
def dot_general_with_precision_jvp(dimension_numbers, precision, preferred_element_type, primals, tangents):
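Combined with the extended dot_general_with_precision signature suggested above, the full rule might look like this (a sketch under that assumption, not the final PR code; the extra non-differentiable arguments are accepted only for signature compatibility and otherwise ignored):

@dot_general_with_precision.defjvp
def dot_general_with_precision_jvp(
    dimension_numbers, precision, preferred_element_type, primals, tangents
):
  del precision, preferred_element_type  # kept only for API compatibility
  lhs, rhs = primals
  lhs_dot, rhs_dot = tangents
  out = lax.dot_general(
      lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT
  )
  # Tangent computation stays at high precision.
  grad_out = lax.dot_general(
      lhs_dot, rhs, dimension_numbers, precision=lax.Precision.HIGHEST
  ) + lax.dot_general(
      lhs, rhs_dot, dimension_numbers, precision=lax.Precision.HIGHEST
  )
  return out, grad_out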

@kaixih (Contributor) commented Nov 3, 2023

As for #3416 (comment), can you also rebase your change against the latest? I think the upstream has done some format cleanup. @wenscarl

@wenscarl wenscarl force-pushed the expm_fp8_precision branch 2 times, most recently from d307996 to 2176aac on November 3, 2023 at 17:23
@kaixih (Contributor) commented Nov 6, 2023

@levskaya Can you take a look and see whether the merge can be triggered?

@levskaya (Collaborator) commented Nov 7, 2023

Sorry, could you re-push? There was a spurious warning being triggered that I've now silenced. But beyond that, I'm seeing this error in internal CI for Fp8Test.test_fp8_dot_general_injection:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/build/work/74f3597f42da87446d9a71de4bf7594fa9c1/google3/runfiles/google3/third_party/py/flax/tests/linen/linen_test.py", line 940, in test_fp8_dot_general_injection
    output2a, output2b = run(True, expected_shapes_new)
  File "/build/work/74f3597f42da87446d9a71de4bf7594fa9c1/google3/runfiles/google3/third_party/py/flax/tests/linen/linen_test.py", line 909, in run
    y, initial_vars = p.init_with_output(init_key, x)
  File "/build/work/74f3597f42da87446d9a71de4bf7594fa9c1/google3/runfiles/google3/third_party/py/flax/linen/linear.py", line 167, in __call__
    out = dot_general(
  File "/build/work/74f3597f42da87446d9a71de4bf7594fa9c1/google3/runfiles/google3/third_party/py/flax/linen/fp8_ops.py", line 196, in __call__
    y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers)  # type: ignore
AttributeError: No JVP defined for custom_jvp function dot_general_with_precision using defjvp.

@kaixih (Contributor) commented Nov 7, 2023

@levskaya Could you please review this once more? We have resolved the issue with the failed test, which occurred due to the accidental deletion of a decorator in the previous commit. All tests are now passing successfully.
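For anyone hitting the same AttributeError: jax.custom_jvp raises it at call time whenever no rule has been registered, so dropping the defjvp decorator is enough to reproduce it. A minimal illustrative sketch (hypothetical names, not the PR code):

from functools import partial

from jax import custom_jvp, lax
import jax.numpy as jnp


@partial(custom_jvp, nondiff_argnums=(2,))
def dot_general_missing_rule(lhs, rhs, dimension_numbers):
  return lax.dot_general(lhs, rhs, dimension_numbers)


# No @dot_general_missing_rule.defjvp registration here, so calling it fails:
#   dot_general_missing_rule(jnp.ones((2, 3)), jnp.ones((3, 4)),
#                            (((1,), (0,)), ((), ())))
#   -> AttributeError: No JVP defined for custom_jvp function
#      dot_general_missing_rule using defjvp.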

@copybara-service copybara-service bot merged commit ed98abb into google:main Nov 8, 2023
19 checks passed
@levskaya (Collaborator) commented Nov 8, 2023

@kaixih - it should be merged now! Sorry for the delay!
