Fp8 Fast Accumulation support for cublasLt #6599
Conversation
Compare: 5eacd0f to a4140da
// accumulation is enabled. When Precision is set to HIGHEST, indicative of
// scenarios in backward propagation, a higher precision accumulation method
// is utilized.
bool fast_accum = (xla::primitive_util::IsF8Type(lhs_layout.dtype) ||
Can we call it use_fast_accum or enable_fast_accum to imply it is a bool? Also, shouldn't it be is_fp8(lhs) && is_fp8(rhs) && cfg.compute_precision == 0?
Based on the current rewrite rule, FP8 matmuls have to take both inputs as FP8 types, so either one of them being an FP8 type is enough to indicate an FP8 matmul.
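For concreteness, a minimal standalone sketch of the condition the reviewer suggests; the enum and IsF8Type below are stand-ins for xla::primitive_util, and compute_precision == 0 stands for PrecisionConfig::DEFAULT (all names here are illustrative, not the PR's actual code):

// Stand-ins for XLA's primitive types; only what the condition needs.
enum class DType { kF8E4M3FN, kF8E5M2, kBF16, kF32 };

bool IsF8Type(DType t) {
  return t == DType::kF8E4M3FN || t == DType::kF8E5M2;
}

// compute_precision == 0 corresponds to PrecisionConfig::DEFAULT.
bool UseFastAccum(DType lhs, DType rhs, int compute_precision) {
  return IsF8Type(lhs) && IsF8Type(rhs) && compute_precision == 0;
}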
// encountered during forward propagation with E4M3 operands, fast
// accumulation is enabled. When Precision is set to HIGHEST, indicative of
// scenarios in backward propagation, a higher precision accumulation method
// is utilized.
There's no need to specify the particular cases for using different precisions. Let's simply state that for FP8 matmuls, there are two options available: fast accumulation (PrecisionConfig.Precision.DEFAULT) and higher precision accumulation (PrecisionConfig.Precision.HIGHEST).
@@ -210,6 +210,8 @@ cudaDataType_t BlasLt::MatrixLayout::type() const {
                         AsCublasOperation(trans_b)));
  TF_ASSIGN_OR_RETURN(cublasLtEpilogue_t epi, AsCublasLtEpilogue(epilogue));
  TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, epi));
  TF_RETURN_IF_ERROR(
      SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, int8_t(fast_accum)));
Can we use static_cast<int8_t>?
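For reference, a hedged sketch of what the static_cast version could look like if written directly against the cuBLAS Lt C API (the SetAttr helper in the diff wraps a call like this; error handling is left to the caller, and the function name is illustrative):

#include <cstdint>
#include <cublasLt.h>

// Sets the FP8 fast-accumulation attribute on an existing matmul descriptor.
cublasStatus_t SetFastAccum(cublasLtMatmulDesc_t desc, bool fast_accum) {
  const int8_t value = static_cast<int8_t>(fast_accum);
  return cublasLtMatmulDescSetAttribute(
      desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &value, sizeof(value));
}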
Is it possible to add a test checking that an FP8 matmul output is accurate enough when the PrecisionConfig is HIGHEST? I'm OK with having no test if this is not easy to test.
// For FP8 matmuls, there are two options available: fast
// accumulation (PrecisionConfig.Precision.DEFAULT) and
// higher precision accumulation (PrecisionConfig.Precision.HIGHEST).
You don't mention the HIGH case. I would phrase this as:
FP8 matmuls have a fast accumulation mode that is less precise than the default accumulation mode. Use the fast accumulation mode if the compute precision is DEFAULT.
Added a test in gemm_rewrite_test.cc
  replacements["<<precision>>"] = "default";
  const auto hlo_text_default = absl::StrReplaceAll(hlo_template, replacements);
  EXPECT_TRUE(RunAndCompare(hlo_text_default, ErrorSpec{1e-3, 1e-3}));
  EXPECT_FALSE(RunAndCompare(hlo_text_default, ErrorSpec{1e-4, 1e-4}));
This expectation seems to fail: it looks like, if we are lucky, the result already has enough precision to pass with a tolerance of 1e-4.
Right. Do you suggest removing this test or replacing it with some file check?
This should be guarded by a check for Ada/Hopper, because it only affects those GPUs and our tests are running on pre-Ada GPUs. But it's possible new GPUs will treat the fast-accumulation flag differently, so we should not do this check anyway.
I'll remove this line when merging.
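For illustration only, a minimal sketch of the Ada/Hopper guard mentioned above (ultimately the strict-tolerance line was removed instead); GetCudaComputeCapability and the test name below are placeholders, not the helpers gemm_rewrite_test.cc actually uses:

#include <gtest/gtest.h>

// Placeholder capability type; XLA's tests query this from the GPU backend.
struct ComputeCapability {
  int major = 0, minor = 0;
  bool IsAtLeast(int maj, int min) const {
    return major > maj || (major == maj && minor >= min);
  }
};

// Placeholder: returns a fixed pre-Ada capability so the sketch is runnable.
ComputeCapability GetCudaComputeCapability() { return {8, 0}; }

TEST(Fp8FastAccumTest, DefaultPrecisionLosesAccuracy) {
  // Fast accumulation only changes results on Ada (SM 8.9) and Hopper (SM 9.0).
  if (!GetCudaComputeCapability().IsAtLeast(8, 9)) {
    GTEST_SKIP() << "FP8 fast accumulation only affects Ada/Hopper GPUs.";
  }
  // ... run the FP8 matmul with PrecisionConfig DEFAULT here and compare
  // against a tightened ErrorSpec, as in the snippet above ...
}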
Imported from GitHub PR openxla/xla#6599

FP8 cublasLt matmul uses fast accumulation when both operands' precisions are DEFAULT; otherwise it falls back to high-precision accumulation. Issue: openxla/xla#6168

This PR is closely related to Flax PR google/flax#3416.

Copybara import of the project:

-- a4140da8ca08cd2d4796a7b8f032827867a361bc by shuw <shuw@nvidia.com>: Add FP8 fast accumulation support for cublasLt.

-- 96845683cc4b1e7b947bc919fbf97d8865abeac9 by shuw <shuw@nvidia.com>: Improve based on review #1

-- e906d7620780d2cf1fe8433c933648dcb98dc61d by shuw <shuw@nvidia.com>: Improve based on review #2

Merging this change closes #6599

PiperOrigin-RevId: 578948593
cuBLAS LT has a flag, CUBLASLT_MATMUL_DESC_FAST_ACCUM, that can be set for FP8 gemms. This flag causes the matmul to run faster but with lower accumulation precision. NVIDIA recommends using this flag on the forward pass of FP8 models but not the backward pass, since the backward pass needs more accumulation precision.

PR #6599 enabled fast accumulation on FP8 dots whose PrecisionConfig is DEFAULT (but not HIGH or HIGHEST). This allows layers in frameworks to use fast accumulation on the forward pass but not the backward pass, by setting the PrecisionConfig of the backward pass to HIGH or HIGHEST.

The issue is that Flax and Praxis do not yet set the PrecisionConfig to HIGH or HIGHEST on the backward pass, so the PR will cause poor FP8 training quality. The PR should not have been merged until Flax and Praxis set the PrecisionConfig, but I didn't realize this and merged it anyway. Reverting the PR is a pain, so instead this CL just removes the line that sets CUBLASLT_MATMUL_DESC_FAST_ACCUM, while keeping most of the plumbing around it. This CL will be rolled back once Flax and Praxis set the PrecisionConfig.

PiperOrigin-RevId: 579018421
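As a rough, hedged illustration of the framework-side fix described above (not code from this PR or from Flax/Praxis): a backward-pass dot can request high-precision accumulation through the XLA client API by passing a PrecisionConfig with HIGHEST operand precisions, which keeps the fast-accumulation path off for that dot. The header path, dimension numbers, and function name below are assumptions for illustration.

// Sketch only: shapes, dimension numbers, and the surrounding builder setup
// are placeholders, not the real Flax/Praxis lowering.
#include "xla/client/xla_builder.h"

xla::XlaOp BackwardDotHighPrecision(xla::XlaOp grad_out, xla::XlaOp weights) {
  xla::DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);

  // HIGHEST (or HIGH) opts this dot out of FP8 fast accumulation.
  xla::PrecisionConfig precision;
  precision.add_operand_precision(xla::PrecisionConfig::HIGHEST);
  precision.add_operand_precision(xla::PrecisionConfig::HIGHEST);

  return xla::DotGeneral(grad_out, weights, dnums, &precision);
}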