Layout conversion error on H100 #4418

Closed
calebthomas259 opened this issue Jul 29, 2024 · 2 comments · Fixed by #4492

@calebthomas259

Hello,

My modified flash attention kernel gives the following error when I run it on an H100 GPU, even though the kernel works fine on an A100 and an RTX 3060:

python: ../../../lib/Analysis/Allocation.cpp:43: std::pair<llvm::SmallVector<unsigned int>, llvm::SmallVector<unsigned int> > mlir::triton::getCvtOrder(mlir::Attribute, mlir::Attribute): Assertion `!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) && "mma -> mma layout conversion is only supported on Ampere"' failed.
Aborted (core dumped)

I've reduced the kernel to the following minimal example, which crashes on the H100 with the same error but runs successfully on my RTX 3060:

import torch
import triton
import triton.language as tl

@triton.jit
def fwd_kernel(
    out_LHD,
    stride_o_l,
    stride_o_h,
    stride_o_d,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
):

    # current thread block index
    cur_head = tl.program_id(0)

    # data
    x = tl.zeros([BLOCK_N, BLOCK_D], dtype=tl.float16)
    y = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)

    # calculate
    p = tl.dot(y, x)
    alpha = tl.zeros([BLOCK_M], dtype=tl.float16)
    out = alpha[:, None] * p + p

    # store output
    offs_o = (
        tl.arange(0, BLOCK_M)[:, None] * stride_o_l
        + cur_head * stride_o_h
        + tl.arange(0, BLOCK_D)[None, :] * stride_o_d
    )
    tl.store(out_LHD + offs_o, out)


def main():

    # parameters
    X_len = 84
    H = 8
    D = 32
    BLOCK_M = 64
    BLOCK_N = 64
    device = torch.device("cuda:0")
    torch.manual_seed(999)

    # setup and launch kernel
    out = torch.empty((X_len, H, D), device=device, dtype=torch.float16)
    fwd_kernel[(H,)](
        out,
        out.stride(0),
        out.stride(1),
        out.stride(2),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_D=D,
        num_warps=8,
    )

if __name__ == "__main__":
    main()

The H100 system is a LambdaLabs Ubuntu instance, with these software versions installed via conda:

  • torchtriton 2.3.1
  • python 3.11.6
  • pytorch 2.3.1
  • pytorch-cuda 12.1

Thanks for the help!

@calebthomas259
Author

Apparently this is a known issue with the Hopper architecture; see #2627.

@calebthomas259
Author

In case anyone else hits a similar problem: I was able to work around the issue by removing all num_warps=8 autotune configurations from my flash attention kernel.
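
For anyone adapting this workaround, here is a minimal sketch of what a restricted autotune setup might look like. The block sizes, num_stages value, and the X_len key argument are illustrative assumptions, not taken from my original kernel; the only point is that no config requests num_warps=8.

import triton
import triton.language as tl

# Illustrative config list: num_warps is capped at 4, since the num_warps=8
# configs were what triggered the Hopper assertion in my kernel.
_configs = [
    triton.Config({"BLOCK_M": bm, "BLOCK_N": bn}, num_warps=w, num_stages=2)
    for bm in (64, 128)
    for bn in (32, 64)
    for w in (2, 4)
]

@triton.autotune(configs=_configs, key=["X_len"])  # "X_len" is a placeholder key
@triton.jit
def fwd_kernel(
    out_LHD,
    stride_o_l,
    stride_o_h,
    stride_o_d,
    X_len,  # included only so the autotuner has a key argument
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # kernel body omitted here; use the same body as the reproducer above
    pass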
