[XLA GPU] Support for mix type gemm bias addition fusion #2859
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
xla/service/gpu/matmul_utils.cc
Outdated
// TYPED_GEMM(F16, F16, F16, F16)
TYPED_GEMM(F32, BF16, BF16, BF16)
TYPED_GEMM(F32, F16, F16, F16)
// TYPED_GEMM(F32, S8, S8, F32)
Why are some things commented out?
Removed. Those cases are not supported.
It seems you removed it here, but not in the TypesAreSupportedByLegacyCublas function. In some internal tests, we see this error being triggered:
"Unexpected GEMM dtype: s8 s8 f32".
Could you please double-check all supported cases again and confirm that each one has a corresponding TYPED_GEMM or TYPED_GEMM_COMPLEX macro call?
It is added back.
@cheshire Isn't this PR a bit moot if Triton already handles all of our mixed-type gemms?
I think this is for a different mixed-precision case, where A and C have different types. We're mostly interested in cases where A and B have different data types; are those supported? The bias fusion isn't super useful unless it can also fuse the broadcast.
Could you point me to a reference for which mixed-type cases Triton supports now?
This PR is for cases where A and C have different types. Does Triton support this? This comes up a lot in our internal models' weight-gradient accumulation: the weight gradient is fp16 while the accumulation is fp32. As for the broadcast bias, cuBLASLt can fuse that case too. Consider a case where A is [8, 16] fp16, B is [16, 32] fp16, and C is [32] fp32: we can fuse A @ B + C into a single gemm call (see the sketch below).
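A minimal HLO sketch of that broadcast-bias case. The shapes, instruction names, and the exact custom-call operand layout are illustrative assumptions, not copied from the PR's tests:

// Before the rewrite: dot, convert, broadcast, and add are separate HLO ops.
dot.0  = f16[8,32] dot(f16[8,16] %a, f16[16,32] %b),
         lhs_contracting_dims={1}, rhs_contracting_dims={0}
cvt.0  = f32[8,32] convert(dot.0)
bias.0 = f32[8,32] broadcast(f32[32] %c), dimensions={1}
add.0  = f32[8,32] add(cvt.0, bias.0)

// After the rewrite (sketch): one cuBLASLt custom call; the f32 bias is
// handled by the library epilogue instead of a separate add kernel.
gemm.0 = f32[8,32] custom-call(f16[8,16] %a, f16[16,32] %b, f32[32] %c),
         custom_call_target="__cublas$lt$matmul"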
No, this only supports cases where A (and B) have a different type from C. However, C does not necessarily correspond to a bias; the add in matmul + C may come from arbitrary add operations. For instance, we found cases of weight-gradient accumulation in backward graphs (for T5X and GPT models) where the gradient matmul is in bf16 while the accumulation into the gradient buffer is done in fp32 (sketched below). This fix fuses those cases, avoiding separate kernels for the add, which gives us a 2% perf improvement on T5X and an estimated 6% perf improvement on the GPT model.
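For reference, the weight-gradient accumulation pattern looks roughly like this in HLO (shapes and names are made up for illustration; the real graphs come from T5X/GPT backward passes):

grad     = bf16[1024,1024] dot(bf16[128,1024] %x, bf16[128,1024] %dy),
           lhs_contracting_dims={0}, rhs_contracting_dims={0}
grad_f32 = f32[1024,1024] convert(grad)
acc      = f32[1024,1024] add(f32[1024,1024] %grad_buffer, grad_f32)

With this change, the convert and add fold into the gemm custom call, with %grad_buffer supplied as the fp32 C operand.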
xla/service/gpu/gemm_rewriter.cc
Outdated
@@ -1516,7 +1617,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
     const Shape &output_shape = instr.shape();

     TF_ASSIGN_OR_RETURN(bool types_are_supported_by_cublas_lt,
-                        TypesAreSupportedByCublasLt(instr));
+                        AreTypesSupportedByCublasLt(instr));
nit: any reason to change the function name from TypesAre* to AreTypes*? I think the original reads better, especially now with the inconsistent naming of the bool.
Not really. Just renamed to TypesAre*.
The internal XLA Linux CPU and GPU checks failed:
@@ -1184,6 +1275,7 @@ class LegacyCublasGemmRewriteTest : public GemmRewriteTest {
   DebugOptions GetDebugOptionsForTest() override {
     DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
     debug_options.set_xla_gpu_enable_cublaslt(false);
+    debug_options.set_xla_gpu_simplify_all_fp_conversions(true);
Why enable this for all tests?
Because the mixed-type gemm tests are guarded by this flag, and it is somehow not enabled in the Google-internal tests. It does seem to be enabled by default when I run the tests on my local machine.
Removed. It turns out the failure was related to FileCheck.
Imported from GitHub PR openxla/xla#2859

Consider a gemm of the form `D = alpha * A * B + beta * C`. We observed that when `A` and `C` have different types, it appears in XLA HLO as three instructions: `Add(Convert(Gemm(A, B)), C)`. However, [cuBLAS](https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex) and [cuBLASLt](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul) both support gemms where A and C have different types; to be more specific, both support 14 different type combinations `(ComputeType, alpha/beta Type, A/BType, C/DType)`. We can therefore create a single custom call `Gemm(A, B, C)` that fuses these three instructions and calls the corresponding gemm routine. This fusion opportunity is observed in weight-gradient accumulation scenarios.

This PR adds support for this mixed-type gemm bias addition fusion. Even though cuBLAS and cuBLASLt currently support 14 combinations, XLA does not have support for choosing a different compute type and scale type per A/C type, and we are not planning to add that here because we are not sure how it might affect precision. Therefore, we add support for three cases of the form `(ComputeType, alpha/beta Type, A/BType, C/DType)`:

1. (fp32, fp32, fp16, fp32)
2. (fp32, fp32, bf16, fp32)
3. (fp32, fp32, s8, fp32)

This works for both legacy cuBLAS and cuBLASLt, and it should support both regular and batched gemms.

Copybara import of the project:

-- 5caa8b0d57c1f0263dde45229ecd708b27290cfb by Shanbin Ke <ske@nvidia.com>: init upstream
-- 390044468fdd05fafe853640d4a526c5a6fc9da6 by Shanbin Ke <ske@nvidia.com>: remove commented code
-- adb1e0e4553b899f82dc74ca07fdfd20dba19601 by Shanbin Ke <ske@nvidia.com>: rename AreTypes* to TypesAre*
-- 7f04569f06bd2dab4e49e003a951457e274fb860 by Shanbin Ke <ske@nvidia.com>: remove some empty lines
-- d836d4c2283f26ab1eafb2d45dc297be70dcdd22 by Shanbin Ke <ske@nvidia.com>: fix hlo_verifier failing issue
-- 7e9f1d63dde6007c9aa59e3d104dc077206e0c20 by Shanbin Ke <ske@nvidia.com>: add print in AsBlasDataType
-- 108caef42b5f94daf92c9600285e2f3746a3be3a by Shanbin Ke <ske@nvidia.com>: fix unsupported type issue
-- c7d4862daa9653cd1b9de629af11f2fe6037b8a4 by Shanbin Ke <ske@nvidia.com>: add tests to OSS
-- 6480ecb89e3beaac1f83b34fbdb57cdb325ffaf2 by Shanbin Ke <ske@nvidia.com>: add s8 f32 support
-- 1dee25f44e438542928a0b88f81ea775759e1161 by Shanbin Ke <ske@nvidia.com>: enhance compute type choice
-- d5f6fb1d300e82fd15a5d34c1fe216bb64abe1a7 by Shanbin Ke <ske@nvidia.com>: guard mix type gemm by xla_gpu_simplify_all_fp_conversions
-- 5c853fa52c0c7d619ad46e9ee887fecf5d934126 by Shanbin Ke <ske@nvidia.com>: explicitly set xla_gpu_simplify_all_fp_conversions=true
-- d91c05b1b1a05b127ade3cd3fd8f5de1133ecfd5 by Shanbin Ke <ske@nvidia.com>: fix file check

Merging this change closes #2859

PiperOrigin-RevId: 539837473
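To make the rewrite above concrete, here is a hedged before/after HLO sketch for the legacy cuBLAS bf16 case; the instruction names, shapes, and config details are illustrative rather than copied from the PR:

// Before: the accumulation add runs as a separate fp32 kernel.
dot  = bf16[64,64] dot(bf16[64,64] %a, bf16[64,64] %b),
       lhs_contracting_dims={1}, rhs_contracting_dims={0}
cvt  = f32[64,64] convert(dot)
add  = f32[64,64] add(cvt, f32[64,64] %c)

// After (sketch): a single custom call with C as the third operand and beta
// set to 1, so the library computes D = alpha * A * B + beta * C with fp32 C/D.
gemm = f32[64,64] custom-call(bf16[64,64] %a, bf16[64,64] %b, f32[64,64] %c),
       custom_call_target="__cublas$gemm"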