
Can flashattention run on Jetson AGX Orin with compute capability of 8.7? #449

Closed
PeterBaelish opened this issue Aug 15, 2023 · 6 comments

Comments


PeterBaelish commented Aug 15, 2023

It seems the Jetson AGX Orin's compute capability is not supported by FlashAttention. Its compute capability is 8.7 and its GPU architecture is Ampere.

Can I just modify something to make it work? Thanks a lot in advance!

When I test flash_attn_func from flash-attn like this:

import math
import random
import time
from einops import rearrange
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

def custom_attention(q, k, v, causal=False):
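    # Naive reference attention: materializes the full (seq x seq) score matrix, so it serves as the time/memory baseline.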
    score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if causal:
        mask = torch.triu(torch.ones(score.shape[-2], score.shape[-1]), diagonal=1)
        mask = mask.masked_fill(mask==1, torch.finfo(q.dtype).min)
        mask = mask.to(q.device, q.dtype)
        score = score + mask
    attn = F.softmax(score, dim=-1)
    o = torch.matmul(attn, v)
    return o

def pytorch_func(q, k, v, causal=False):
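    # Restrict SDPA to the math backend so it does not dispatch to its own flash or memory-efficient kernels.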
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)

def flash_attention(q, k, v, causal=False):
    o = flash_attn_func(q, k, v, causal=causal)
    return o

def test(func_name, q, k, v, *args, **kwargs):
    if func_name in ["custom_attention", "pytorch_func"]:
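        # These two expect (batch, heads, seq, dim); flash_attn_func takes (batch, seq, heads, dim).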
        q = rearrange(q, "a b c d -> a c b d")
        k = rearrange(k, "a b c d -> a c b d")
        v = rearrange(v, "a b c d -> a c b d")

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
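    # Warm-up iterations so the timed call below excludes one-time CUDA/kernel setup costs.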
    for _ in range(5):
        o = globals()[func_name](q, k, v, *args, **kwargs)
    torch.cuda.synchronize()
    st = time.time()
    o = globals()[func_name](q, k, v, *args, **kwargs)
    torch.cuda.synchronize()
    tt = time.time() - st
    max_memory = torch.cuda.max_memory_allocated() // 2**20
    torch.cuda.empty_cache()
    print(o.size())
    if func_name in ["custom_attention", "pytorch_func"]:
        o = rearrange(o, "a c b d -> a b c d")

    return o, tt, max_memory

if __name__ == "__main__":
    test_num = 10
    #torch.backends.cuda.enable_flash_sdp(False)
    #torch.backends.cuda.enable_mem_efficient_sdp(False)
    for idx in range(test_num):
        print(f"test {idx} >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        bsz = random.randint(1, 64)
        sql = random.randint(1, 4096)
        nh = random.choice([8, 12, 16])
        hd = random.choice([64, 128])
        dtype = random.choice([torch.float16, torch.bfloat16])
        causal = random.choice([False, True])
        print(f"shape: ({bsz}, {sql}, {nh}, {hd}), dtype: {dtype}, causal: {causal}")
        q = torch.randn((bsz, sql, nh, hd)).to("cuda:0", dtype)
        k = torch.rand_like(q)
        v = torch.rand_like(q)

        o, t, m = test("custom_attention", q, k, v, causal=causal)
        print(f"custom pytorch time: {t:.6f}, peak memory: {m} MB")

        pf_o, pf_t, pf_m = test("pytorch_func", q, k, v, causal=causal)
        print(f"pytorch func time: {pf_t:.6f}, speedup: {t/pf_t:.2f}; peak memory: {pf_m} MB, save: {int((m-pf_m)/m*100)}%")
        assert torch.allclose(o, pf_o, rtol=1e-2, atol=1e-2)
        
        fa_o, fa_t, fa_m = test("flash_attention", q, k, v, causal=causal)
        print(f"flash attention time: {fa_t:.6f}, speedup: {t/fa_t:.2f}; peak memory: {fa_m} MB, save: {int((m-fa_m)/m*100)}%")
        assert torch.allclose(o, fa_o, rtol=1e-2, atol=1e-2)

The output of running the test is:

test 0 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (53, 1678, 12, 64), dtype: torch.bfloat16, causal: False
torch.Size([53, 12, 1678, 64])
custom pytorch time: 0.421059, peak memory: 7621 MB
torch.Size([53, 12, 1678, 64])
pytorch func time: 0.369576, speedup: 1.14; peak memory: 7621 MB, save: 0%
Traceback (most recent call last):
  File "flash_attention_test.py", line 75, in <module>
    fa_o, fa_t, fa_m = test("flash_attention", q, k, v, causal=causal)
  File "flash_attention_test.py", line 37, in test
    o = globals()[func_name](q, k, v, *args, **kwargs)
  File "flash_attention_test.py", line 25, in flash_attention
    o = flash_attn_func(q, k, v, causal=causal)
  File "/home/jane/.local/lib/python3.8/site-packages/flash_attn-2.0.6-py3.8-linux-aarch64.egg/flash_attn/flash_attn_interface.py", line 373, in flash_attn_func
    return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)
  File "/home/jane/.local/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/jane/.local/lib/python3.8/site-packages/flash_attn-2.0.6-py3.8-linux-aarch64.egg/flash_attn/flash_attn_interface.py", line 222, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
  File "/home/jane/.local/lib/python3.8/site-packages/flash_attn-2.0.6-py3.8-linux-aarch64.egg/flash_attn/flash_attn_interface.py", line 42, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: CUDA error: invalid device function
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

The environment is:
CUDA: 11.4
PyTorch: 2.0.0
flash-attn: 2.0.6

PeterBaelish changed the title from "Can flashattention run on Jetson AGX Orin?" to "Can flashattention run on Jetson AGX Orin with compute capability of 8.7?" on Aug 15, 2023
PeterBaelish (Author) commented:

The NVIDIA docs say the only difference between compute capabilities 8.7 and 8.6 is the size of shared memory: the unified data cache and shared memory total 192 KB on devices of compute capability 8.7 versus 128 KB on devices of compute capability 8.6.
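
For reference, a minimal way to confirm what PyTorch reports for the device (assuming device 0 is the Orin's integrated GPU) is:

import torch

# Compute capability as a (major, minor) tuple, e.g. (8, 7) on Jetson AGX Orin.
print(torch.cuda.get_device_capability(0))
# Device properties summary: name, major/minor capability, total memory, SM count.
print(torch.cuda.get_device_properties(0))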

tridao (Contributor) commented Aug 15, 2023

From the docs it seems like the code should just run, since 8.7 has more shared memory than 8.6. I don't know what the issue is, and I don't have the hardware to test or debug.
Can you try uncommenting this line to see how much shared memory the kernel is using?

You can also try running with the nvcr PyTorch 23.07 container so we're sure the environment isn't the issue.

PeterBaelish (Author) commented:

Thanks for your reply!

The nvcr PyTorch 23.07 container doesn't seem to support NVIDIA SoCs. I found that NVIDIA provides a PyTorch container specifically for SoCs, but flash-attention doesn't work in it either; the error output is the same as in issue #451. Maybe I will look into it in the future, so I'm closing this issue for now.

Anyway, thanks for your reply again!

tridao (Contributor) commented Aug 21, 2023

You can try compiling in that container with FLASH_ATTENTION_FORCE_BUILD=TRUE pip install flash-attn to compile locally instead of downloading the prebuilt wheel.

shubhendu-ranadive commented Oct 2, 2024

I'm running into a similar problem using this container by NVIDIA on a Jetson AGX Orin 64GB (compute capability 8.7) to install FlashAttention v2.1.1.

I tried FLASH_ATTENTION_FORCE_BUILD=TRUE pip install flash-attn==2.1.1 as well as building from source, but ran into the same error both times. I suppose FlashAttention doesn't support Jetson devices yet?

CUDA = 11.4
PyTorch = v2.0.0

shubhendu-ranadive commented:

Got it working after following this issue
