
[XLA GPU] Support for mix type gemm bias addition fusion #2859

Closed
wants to merge 14 commits into from

Conversation

Cjkkkk
Contributor

@Cjkkkk Cjkkkk commented May 8, 2023

Consider a gemm of the form D = alpha * A * B + beta * C. We observed that when A and C have different types, it appears in XLA HLO as three instructions: Add(Convert(Gemm(A, B)), C).

However, cuBLAS and cuBLASLt both support gemm with different types for A and C. To be more specific, both support 14 different type combinations (ComputeType, alpha/beta Type, A/B Type, C/D Type). Therefore, we can create a single custom call Gemm(A, B, C) that fuses these three instructions and calls the corresponding gemm routine. This fusion opportunity is observed in the weight gradient accumulation scenario.

This PR adds support for this mixed-type gemm bias addition fusion. Even though cuBLAS and cuBLASLt currently support 14 combinations, XLA does not support choosing a different compute type and scale type per A/C type, and we are not planning to add that here because we are not sure how it might affect precision. Therefore, we plan to support the following 3 cases of the form (ComputeType, alpha/beta Type, A/B Type, C/D Type):

  1. (fp32, fp32, fp16, fp32)
  2. (fp32, fp32, bf16, fp32)
  3. (fp32, fp32, s8, fp32)

This works for both legacy cuBLAS and cuBLASLt, and it should provide support for both regular gemm and batched gemm.
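For illustration, here is a minimal host-side sketch of combination 1 above, written directly against the cuBLAS cublasGemmEx API that the fused custom call would ultimately dispatch to. This is not code from the PR; the helper name and leading dimensions are assumptions, and beta = 1 is chosen to show the accumulate-into-fp32 pattern.

```cpp
// Hedged sketch: a direct cublasGemmEx call for combination 1,
// (ComputeType=fp32, alpha/beta=fp32, A/B=fp16, C/D=fp32).
// Column-major, no transposes; D aliases C, and beta = 1 accumulates
// the fp16 x fp16 product into the fp32 buffer.
#include <cublas_v2.h>
#include <cuda_fp16.h>

cublasStatus_t MixedTypeGemmAccumulate(cublasHandle_t handle, int m, int n,
                                       int k, const __half* A,
                                       const __half* B, float* C) {
  const float alpha = 1.0f;
  const float beta = 1.0f;  // Accumulate into the existing fp32 values of C.
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                      &alpha,
                      A, CUDA_R_16F, /*lda=*/m,
                      B, CUDA_R_16F, /*ldb=*/k,
                      &beta,
                      C, CUDA_R_32F, /*ldc=*/m,
                      CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}
```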

@google-cla

google-cla bot commented May 8, 2023

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label May 8, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label May 8, 2023
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label May 8, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label May 8, 2023
// TYPED_GEMM(F16, F16, F16, F16)
TYPED_GEMM(F32, BF16, BF16, BF16)
TYPED_GEMM(F32, F16, F16, F16)
// TYPED_GEMM(F32, S8, S8, F32)
Contributor


Why are some things commented out?

Contributor Author


Removed; these cases are not supported.

Member


It seems you removed it here, but not in the TypesAreSupportedByLegacyCublas function. In some internal tests, we see this error being triggered:
"Unexpected GEMM dtype: s8 s8 f32".
Can you please double-check all supported cases again, to make sure each one has a corresponding TYPED_GEMM or TYPED_GEMM_COMPLEX macro call?

Contributor Author


It has been added back.
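For context, here is a hypothetical, self-contained sketch of the pattern this thread is discussing: a supported-types predicate and a TYPED_GEMM-style dispatch must list exactly the same combinations, or a dtype pair that passes the predicate falls through to the "Unexpected GEMM dtype" error at run time. The names and structure below only mirror the discussion and are not taken from the XLA sources.

```cpp
// Hypothetical illustration (not the actual XLA code) of why the predicate
// and the dispatch macros must stay in sync.
#include <cstdio>

enum DType { S8, F16, BF16, F32 };

// Rewriter-side check: which (A/B type, C/D type) pairs may be fused.
bool TypesAreSupportedByLegacyCublas(DType ab, DType cd) {
  return (ab == F16 && cd == F32) || (ab == BF16 && cd == F32) ||
         (ab == S8 && cd == F32);
}

// Runtime-side dispatch: every pair accepted above needs an entry here.
#define TYPED_GEMM(ABTYPE, CDTYPE)                      \
  if (ab == ABTYPE && cd == CDTYPE) {                   \
    std::printf("dispatching gemm %d -> %d\n", ab, cd); \
    return true;                                        \
  }

bool RunGemm(DType ab, DType cd) {
  TYPED_GEMM(F16, F32)
  TYPED_GEMM(BF16, F32)
  TYPED_GEMM(S8, F32)  // If this entry were missing, s8/f32 would hit the error.
  std::printf("Unexpected GEMM dtype: %d %d\n", ab, cd);
  return false;
}
```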

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label May 12, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label May 12, 2023
@SandSnip3r
Contributor

@cheshire Isn't this pr a bit moot if Triton already handles all of our mixed type gemms?

@Cjkkkk Cjkkkk changed the title [XLA GPU] mix type gemm bias addition fusion support [XLA GPU] Support for mix type gemm bias addition fusion May 12, 2023
@cheshire
Contributor

I think this is for different mixed precision case, where A and C have different types. We're mostly interested in cases where A and B have different datatypes, are those supported? The bias fusion isn't super-useful unless it can also fuse the broadcast.

@Cjkkkk
Contributor Author

Cjkkkk commented May 15, 2023

@cheshire Isn't this pr a bit moot if Triton already handles all of our mixed type gemms?

Could you point me to a reference for which mixed-type cases Triton supports now?

@Cjkkkk
Contributor Author

Cjkkkk commented May 15, 2023

I think this is for different mixed precision case, where A and C have different types. We're mostly interested in cases where A and B have different datatypes, are those supported? The bias fusion isn't super-useful unless it can also fuse the broadcast.

This PR is for cases where A and C have different types. Does Triton support this? It happens a lot in our internal model weight-grad accumulation: the weight grad is fp16 and the accumulation is fp32.

As for a bias that is broadcast, cuBLASLt can fuse such a case. Consider a case where A is [8, 16] fp16, B is [16, 32] fp16, and C is [32] fp32: we can fuse A @ B + C into one gemm call.
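To make the broadcast-bias case concrete, here is a sketch against the plain cuBLASLt host API showing fp16 A/B with an fp32 bias vector fused into a single matmul via the bias epilogue. This is not code from the PR; error handling and descriptor cleanup are omitted, the helper name is made up, and the mapping of XLA's row-major shapes (e.g. the [8, 16] x [16, 32] + [32] example above) onto cuBLASLt's column-major layouts is elided.

```cpp
// Hedged sketch: one cublasLtMatmul call that computes D = A * B and adds a
// broadcast bias through the BIAS epilogue, with A/B in fp16 and bias/D in
// fp32 (compute and scale types fp32).
#include <cublasLt.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

cublasStatus_t GemmWithFusedBias(cublasLtHandle_t handle, int m, int n, int k,
                                 const __half* A, const __half* B,
                                 const float* bias,  // length m (rows of D)
                                 float* D, void* workspace,
                                 size_t workspace_size, cudaStream_t stream) {
  const float alpha = 1.0f;
  const float beta = 0.0f;  // The bias is added by the epilogue, not via C.

  cublasLtMatmulDesc_t op_desc;
  cublasLtMatmulDescCreate(&op_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
  cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                 &epilogue, sizeof(epilogue));
  cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                                 &bias, sizeof(bias));

  // Column-major layouts: A is m x k, B is k x n, D is m x n.
  cublasLtMatrixLayout_t a_desc, b_desc, d_desc;
  cublasLtMatrixLayoutCreate(&a_desc, CUDA_R_16F, m, k, /*ld=*/m);
  cublasLtMatrixLayoutCreate(&b_desc, CUDA_R_16F, k, n, /*ld=*/k);
  cublasLtMatrixLayoutCreate(&d_desc, CUDA_R_32F, m, n, /*ld=*/m);

  // With beta == 0 the C operand is unused, so D and its layout are reused.
  return cublasLtMatmul(handle, op_desc, &alpha, A, a_desc, B, b_desc, &beta,
                        /*C=*/D, d_desc, D, d_desc, /*algo=*/nullptr,
                        workspace, workspace_size, stream);
}
```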

@jprabhas
Contributor

jprabhas commented May 16, 2023

I think this is for different mixed precision case, where A and C have different types. We're mostly interested in cases where A and B have different datatypes, are those supported? The bias fusion isn't super-useful unless it can also fuse the broadcast.

No, this only supports cases where A (and B) have a different type from C. However, C does not only correspond to a bias; the add in matmul + C may come from arbitrary add operations. For instance, we found cases in weight-gradient accumulation in backward graphs (for T5X and GPT models) where the gradient matmul is in bf16 while the accumulation into the gradient buffer is done in fp32. This fix fuses those cases, avoiding separate add kernels and giving us a 2% perf improvement with T5X and an estimated 6% perf improvement with the GPT model.

@@ -1516,7 +1617,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
const Shape &output_shape = instr.shape();

TF_ASSIGN_OR_RETURN(bool types_are_supported_by_cublas_lt,
TypesAreSupportedByCublasLt(instr));
AreTypesSupportedByCublasLt(instr));
Contributor


nit: any reason to change the function name from TypesAre to AreTypes? I think the original reads better, especially now with the inconsistent naming of the bool.

Contributor Author

@Cjkkkk Cjkkkk May 22, 2023


Not really. I just renamed it back to TypesAre*.

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label May 22, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label May 22, 2023
@Cjkkkk
Contributor Author

Cjkkkk commented May 22, 2023

The internal XLA Linux CPU and GPU builds failed with:
ERROR: /root/.cache/bazel/_bazel_root/217377b0e928b171b843eb11ea7bc36e/external/llvm-project/llvm/BUILD.bazel:200:11: no such package '@llvm_zlib//': The repository '@llvm_zlib' could not be resolved: Repository '@llvm_zlib' is not defined and referenced by '@llvm-project//llvm:Support'
How do we resolve this? I did not encounter this before.

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jun 5, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jun 5, 2023
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jun 6, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jun 6, 2023
@SandSnip3r SandSnip3r added ready to pull PR ready for merge process and removed ready to pull PR ready for merge process labels Jun 6, 2023
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jun 7, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jun 7, 2023
@SandSnip3r SandSnip3r added ready to pull PR ready for merge process and removed ready to pull PR ready for merge process labels Jun 7, 2023
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jun 7, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jun 7, 2023
@@ -1184,6 +1275,7 @@ class LegacyCublasGemmRewriteTest : public GemmRewriteTest {
DebugOptions GetDebugOptionsForTest() override {
DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest();
debug_options.set_xla_gpu_enable_cublaslt(false);
debug_options.set_xla_gpu_simplify_all_fp_conversions(true);
Contributor


Why enable this for all tests?

Contributor Author


Because the mixed-type gemm test is guarded by this flag, and somehow it is not enabled in the Google-internal tests. It also seems to be enabled by default when I ran it on my local machine.

Contributor Author


Removed. It turns out it is related to FileCheck.

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jun 7, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jun 7, 2023
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jun 7, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jun 7, 2023
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jun 13, 2023
Imported from GitHub PR openxla/xla#2859

Consider a gemm of the form `D = alpha * A * B + beta * C`. We observed that when `A` and `C` have different types, it appears in XLA HLO as three instructions: `Add(Convert(Gemm(A, B)), C)`.

However, [cuBLAS](https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex) and [cuBLASLt](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul) both support gemm with different types for A and C. To be more specific, both support 14 different type combinations `(ComputeType, alpha/beta Type, A/B Type, C/D Type)`. Therefore, we can create a single custom call `Gemm(A, B, C)` that fuses these three instructions and calls the corresponding gemm routine. This fusion opportunity is observed in the weight gradient accumulation scenario.

This PR adds support for this mixed-type gemm bias addition fusion. Even though cuBLAS and cuBLASLt currently support 14 combinations, XLA does not support choosing a different compute type and scale type per A/C type, and we are not planning to add that here because we are not sure how it might affect precision. Therefore, we plan to support the following 3 cases of the form `(ComputeType, alpha/beta Type, A/B Type, C/D Type)`:
1. (fp32, fp32, fp16, fp32)
2. (fp32, fp32, bf16, fp32)
3. (fp32, fp32, s8, fp32)

This works for both legacy cuBLAS and cuBLASLt, and it should provide support for both regular gemm and batched gemm.
Copybara import of the project:

--
5caa8b0d57c1f0263dde45229ecd708b27290cfb by Shanbin Ke <ske@nvidia.com>:

init upstream

--
390044468fdd05fafe853640d4a526c5a6fc9da6 by Shanbin Ke <ske@nvidia.com>:

remove commented code

--
adb1e0e4553b899f82dc74ca07fdfd20dba19601 by Shanbin Ke <ske@nvidia.com>:

rename AreTypes* to TypesAre*

--
7f04569f06bd2dab4e49e003a951457e274fb860 by Shanbin Ke <ske@nvidia.com>:

remove some empty lines

--
d836d4c2283f26ab1eafb2d45dc297be70dcdd22 by Shanbin Ke <ske@nvidia.com>:

fix hlo_verifier failing issue

--
7e9f1d63dde6007c9aa59e3d104dc077206e0c20 by Shanbin Ke <ske@nvidia.com>:

add print in AsBlasDataType

--
108caef42b5f94daf92c9600285e2f3746a3be3a by Shanbin Ke <ske@nvidia.com>:

fix unsupported type issue

--
c7d4862daa9653cd1b9de629af11f2fe6037b8a4 by Shanbin Ke <ske@nvidia.com>:

add tests to OSS

--
6480ecb89e3beaac1f83b34fbdb57cdb325ffaf2 by Shanbin Ke <ske@nvidia.com>:

add s8 f32 support

--
1dee25f44e438542928a0b88f81ea775759e1161 by Shanbin Ke <ske@nvidia.com>:

enhance compute type choice

--
d5f6fb1d300e82fd15a5d34c1fe216bb64abe1a7 by Shanbin Ke <ske@nvidia.com>:

guard mix type gemm by xla_gpu_simplify_all_fp_conversions

--
5c853fa52c0c7d619ad46e9ee887fecf5d934126 by Shanbin Ke <ske@nvidia.com>:

explicitly set xla_gpu_simplify_all_fp_conversions=true

--
d91c05b1b1a05b127ade3cd3fd8f5de1133ecfd5 by Shanbin Ke <ske@nvidia.com>:

fix file check

Merging this change closes #2859

PiperOrigin-RevId: 539837473
Labels
ready to pull PR ready for merge process
9 participants