
[AMD] inThreadTranspose: Transpose between global load and local store for non-TN layouts: part 1 of 4 #5148

Draft

wants to merge 1 commit into main

Conversation

@jtang10 (Contributor) commented on Nov 14, 2024

inThreadTranspose: part 1 of 4

Introduction

This PR introduces the AMD-specific inThreadTranspose feature to improve shared memory access efficiency for non-TN GEMM and the 2nd dotOp in Flash Attention.
The entire feature has been broken into 4 pieces for more reliable integration; this PR is the 1st of the 4.

  • Add the TTGIR pass for inThreadTranspose without activating it.
  • Add the sharedEncodingAttr update and the associated ToLLVM change. This should fully enable inThreadTranspose.
  • Add the new flag in sharedEncodingAttr without using it.
  • Change the shared linearLayout conversion and the shared->dot conversion in the AMD path to enable it. This should improve the performance of inThreadTranspose.

Feature description

Currently on AMD hardware, if the dot operand is K-major, we use the same vectorization for ds_write as for global_load, but ds_read is not coalesced, resulting in poor shared memory (LDS) read efficiency prior to the MFMA operation.
This feature, inThreadTranspose, groups multiple global_loads together and packs vectors across the grain so that the matrix is written to LDS with vectorization and is already contiguous along the K dimension once it lands in LDS, which in turn enables vectorized ds_read. This is achieved with the v_perm_b32 assembly instruction in AMDGCN, which allows independent registers to be made contiguous in VGPR space so that they can be written to LDS together.
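
To make the idea concrete, here is a minimal standalone C++ sketch (an emulation for illustration only, not the compiler's code) of the repacking that v_perm_b32 enables: the same byte lane is gathered from four registers, each holding data from a different K row, so the resulting 32-bit value is contiguous along K and can be stored with a single write.

#include <cstdint>
#include <cstdio>

// Emulate gathering the same byte lane from four 32-bit registers (one per
// loaded K row) into one 32-bit value; on AMDGCN the equivalent shuffling is
// done with v_perm_b32 so the packed value can be written to LDS in one go.
static uint32_t packKContiguous(uint32_t r0, uint32_t r1, uint32_t r2,
                                uint32_t r3, int byteLane) {
  auto byteAt = [](uint32_t reg, int lane) -> uint32_t {
    return (reg >> (8 * lane)) & 0xFF;
  };
  return byteAt(r0, byteLane) | (byteAt(r1, byteLane) << 8) |
         (byteAt(r2, byteLane) << 16) | (byteAt(r3, byteLane) << 24);
}

int main() {
  // Each register holds 4 consecutive int8 elements of one row (N-contiguous).
  uint32_t rows[4] = {0x03020100u, 0x13121110u, 0x23222120u, 0x33323130u};
  for (int lane = 0; lane < 4; ++lane)
    printf("packed K-column %d: 0x%08x\n", lane,
           (unsigned)packKContiguous(rows[0], rows[1], rows[2], rows[3], lane));
  return 0;
}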

PR description

This PR, the 1st of the 4, introduces the TTGIR pass that packs multiple global_loads together, so that in C = A @ B, where B is K-major, B's blocked layout is changed from

#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>

to

#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
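
For intuition only, the following plain C++ sketch (not Triton code; element values are made up) shows what the new layout means for one thread: a 4x8 tile loaded as four contiguous 8-element rows, then regrouped so that each column of 4 elements, i.e. the K direction, forms one packed vector for the LDS store.

#include <cstdio>

int main() {
  const int rows = 4, cols = 8; // sizePerThread = [4, 8]
  int tile[4][8];
  // Four "global_load"s, each fetching 8 consecutive elements of one row.
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      tile[r][c] = r * 100 + c; // stand-in for loaded data
  // In-register regrouping: each column of 4 becomes one packed vector, so
  // the subsequent LDS store is contiguous along K.
  for (int c = 0; c < cols; ++c) {
    printf("packed vector %d:", c);
    for (int r = 0; r < rows; ++r)
      printf(" %3d", tile[r][c]);
    printf("\n");
  }
  return 0;
}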

To accompany this update, the following changes are also added:

  1. A new way to interpret blockedEncodingAttr. Because sizePerThread has effectively changed from 1D to 2D, there are two ways to convert it to a linear layout. We want to keep the original way, i.e. along the dimension of 8 in the example above, when reading from global memory, but convert along the dimension of 4 when transferring from blocked to shared (see the sketch after this list).
  2. Associated lit tests and linear layout conversion tests.
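
Here is a small standalone sketch (assumed values, not Triton's implementation) of the two register orderings described in point 1: the original ordering varies fastest along the 8-wide contiguous dimension, matching how global_load fills registers, while the new interpretation varies fastest along the 4-wide dimension when lowering the blocked-to-shared copy.

#include <cstdio>

int main() {
  const int dimOuter = 4, dimInner = 8; // sizePerThread = [4, 8], order = [1, 0]
  printf("reg -> (global-load order) vs (blocked->shared order)\n");
  for (int r = 0; r < dimOuter * dimInner; ++r)
    printf("%2d -> (%d,%d) vs (%d,%d)\n", r,
           r / dimInner, r % dimInner,  // walk the 8-wide dim fastest
           r % dimOuter, r / dimOuter); // walk the 4-wide dim fastest
  return 0;
}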

Note that this feature is tested but will not be activated in compiler.py until the 2nd PR lands.

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@jtang10 jtang10 changed the title [AMD] inThreadTranspose: efficient blocked->shared for K-major dotOperand: part 1 of 4 [AMD] inThreadTranspose: Transpose between global load and local store for non-TN layouts: part 1 of 4 Nov 14, 2024
@jtang10 jtang10 marked this pull request as ready for review November 14, 2024 19:45
@antiagainst antiagainst marked this pull request as draft November 14, 2024 19:57
@jtang10 jtang10 force-pushed the jingtang/inThreadTranspose_blocked branch from 8b9bb50 to 7f24463 on November 14, 2024 20:12
Comment on lines 106 to 125
// traverse 2nd dimension (K-dim in GEMM case)
int dim = order[1];
for (int basis = 1; basis < shape[dim]; basis <<= 1) {
  bases.push_back({0, basis});
}
// traverse 1st dimension (N-dim in GEMM K-major B-tensor)
// this is the consecutive dimension loaded from global memory
dim = order[0];
for (int basis = 1; basis < shape[dim]; basis <<= 1) {
  bases.push_back({basis, 0});
}
auto dimMinor = "dim" + std::to_string(order[0]);
auto dimMajor = "dim" + std::to_string(order[1]);
StringAttr kDimMinor = S(dimMinor);
StringAttr kDimMajor = S(dimMajor);
ret = LinearLayout(
    {{kRegister, bases}},
    {{kDimMinor, shape[order[0]]}, {kDimMajor, shape[order[1]]}}, false);
Collaborator

You don't need to "unroll" the loop. Try to do it in a similar way as in identityND

Contributor Author

I think the problem here is that if I fuse the loops, the basis I push_back into the bases vector becomes awkward. Notice that it is {0, basis} in the first loop and {basis, 0} in the second loop, so technically I can still fuse the loops, but keeping them as-is, instead of adding a conditional inside a single for-loop, is probably more readable.
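
For comparison, a fused version might look roughly like the standalone sketch below (hypothetical; the names order, shape, and bases mirror the snippet above, and the values are made up). The conditional inside the loop is the readability cost being weighed here.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> order = {1, 0};
  std::vector<int> shape = {4, 8};
  std::vector<std::vector<int32_t>> bases;
  // One loop over both dimensions; the conditional decides which coordinate
  // carries the basis value.
  for (int idx : {1, 0}) {
    int dim = order[idx];
    for (int basis = 1; basis < shape[dim]; basis <<= 1)
      bases.push_back(idx == 1 ? std::vector<int32_t>{0, basis}
                               : std::vector<int32_t>{basis, 0});
  }
  for (const auto &b : bases)
    printf("{%d, %d}\n", b[0], b[1]);
  return 0;
}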

auto numMaxIters = elemsTotal / elemsPerIter;
auto bitwidth = tensorTy.getElementType().getIntOrFloatBitWidth();
// Currently the widest is set to ds_write_b64
auto newKOuterDim = std::min(numMaxIters, 64 / bitwidth);
Collaborator

Why can't we do ds_write_b128?

Contributor Author

tbh I just fixed this at 64 bits initially and never changed it afterwards. I can experiment with this later, once all the PRs are in, to see whether relaxing it to 128 bits is better, and change it accordingly.
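
As a rough illustration of the trade-off being discussed (assumed numbers, not the PR's code path), the cap on ds_write width bounds how many elements a thread packs per write:

#include <algorithm>
#include <cstdio>

int main() {
  const int numMaxIters = 16; // assumed upper bound for this example
  for (int cap : {64, 128})   // ds_write_b64 vs. a hypothetical b128 cap
    for (int bitwidth : {8, 16, 32})
      printf("cap=%3d bits, elem=%2d bits -> newKOuterDim=%d\n", cap, bitwidth,
             std::min(numMaxIters, cap / bitwidth));
  return 0;
}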

  for (auto loadOp : loadOps)
    convertLayout(newBlockedEnc, (Operation *)loadOp);
} else {
  LDBG("opB is K-inner and nothing to be done");
Collaborator

The code for opA and opB is duplicated. Can you try to merge them?

Collaborator

@zhanglx13 left a comment

Please see my inline comments.
When you address review comments, make sure to submit new commits rather than rebasing; otherwise the comments are lost.

@jtang10 jtang10 force-pushed the jingtang/inThreadTranspose_blocked branch 3 times, most recently from 1da0b5f to 2875461 on November 20, 2024 17:37
// in-thread, we'd want to iterate over the 2nd order, i.e. dim of 4, so that we
// can pack the element of 4 into a single vector, and AMD backend LLVM compiler
// will pack elements into consecutive VGPR and therefore achieve high
// vectorization LDS write and also keep data in K-minor inside LDS.
Collaborator

We should not say "high vectorization LDS write" because the old way, i.e. traversing the dim of 8, has a higher vector size for ds_write.
So the ds_write can have a short or long vector size, but we are guaranteed to have vectorized ds_read.

Contributor Author

How about changing this
"and AMD backend LLVM compiler will pack elements into consecutive VGPR and therefore achieve high vectorization LDS write and also keep data in K-minor inside LDS."
to
"and AMD backend LLVM compiler will pack elements into consecutive VGPR to write data contiguous in K dimension into LDS. In this way we guarantee vectorized ds_read, and ds_write can be vectorized to 64bit or 32bit depending on the block size and number of warps"

Collaborator

Sounds good

@jtang10 jtang10 force-pushed the jingtang/inThreadTranspose_blocked branch from b623a9d to e275dcc on November 21, 2024 19:44
@jtang10 jtang10 force-pushed the jingtang/inThreadTranspose_blocked branch from e275dcc to 3ce4134 on November 21, 2024 20:37
Collaborator

@zhanglx13 left a comment

LGTM.
@antiagainst please take a look
