
memory_efficient_attention is slower than scaled_dot_product_attention of PyTorch? #1107

QinlongHuang opened this issue Sep 19, 2024 · 2 comments


@QinlongHuang

❓ Questions and Help

I am new to xformers, and I want to speed up my Transformer models w/ it. But I found that xformers gives no speedup compared w/ scaled_dot_product_attention from PyTorch. Here is my code snippet for training a vanilla GPT-2. Is there anything wrong with how I use xformers?

import torch
import torch.nn.functional as F
from xformers.ops import memory_efficient_attention, LowerTriangularMask

# inside the attention module's forward(); q, k, v are the per-head query/key/value tensors
self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
if memory_efficient_attention is not None:
    y = memory_efficient_attention(
        q, k, v,
        p=self.dropout if self.training else 0,
        attn_bias=LowerTriangularMask(),
    )
elif self.flash:
    y = F.scaled_dot_product_attention(
        q, k, v,
        dropout_p=self.dropout if self.training else 0,
        is_causal=True,
    )

Environment: Ubuntu 20.04, CUDA 11.8, NVIDIA RTX 4090, PyTorch 2.4.1, xformers 0.0.28.post1

python -m xformers.info

xFormers 0.0.28.post1
memory_efficient_attention.ckF:                    unavailable
memory_efficient_attention.ckB:                    unavailable
memory_efficient_attention.ck_decoderF:            unavailable
memory_efficient_attention.ck_splitKF:             unavailable
memory_efficient_attention.cutlassF-pt:            available
memory_efficient_attention.cutlassB-pt:            available
memory_efficient_attention.fa2F@v2.5.6-pt:         available
memory_efficient_attention.fa2B@v2.5.6-pt:         available
memory_efficient_attention.fa3F@0.0.0:             unavailable
memory_efficient_attention.fa3B@0.0.0:             unavailable
memory_efficient_attention.triton_splitKF:         available
indexing.scaled_index_addF:                        available
indexing.scaled_index_addB:                        available
indexing.index_select:                             available
sequence_parallel_fused.write_values:              available
sequence_parallel_fused.wait_values:               available
sequence_parallel_fused.cuda_memset_32b_async:     available
sp24.sparse24_sparsify_both_ways:                  available
sp24.sparse24_apply:                               available
sp24.sparse24_apply_dense_output:                  available
sp24._sparse24_gemm:                               available
sp24._cslt_sparse_mm@0.4.0:                        available
swiglu.dual_gemm_silu:                             available
swiglu.gemm_fused_operand_sum:                     available
swiglu.fused.p.cpp:                                available
is_triton_available:                               True
pytorch.version:                                   2.4.1+cu118
pytorch.cuda:                                      available
gpu.compute_capability:                            8.9
gpu.name:                                          NVIDIA GeForce RTX 4090
dcgm_profiler:                                     unavailable
build.info:                                        available
build.cuda_version:                                1108
build.hip_version:                                 None
build.python_version:                              3.9.20
build.torch_version:                               2.4.1+cu118
build.env.TORCH_CUDA_ARCH_LIST:                    6.0+PTX 7.0 7.5 8.0+PTX
build.env.PYTORCH_ROCM_ARCH:                       None
build.env.XFORMERS_BUILD_TYPE:                     Release
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS:        None
build.env.NVCC_FLAGS:                              -allow-unsupported-compiler
build.env.XFORMERS_PACKAGE_FROM:                   wheel-v0.0.28.post1
build.nvcc_version:                                11.8.89
source.privacy:                                    open source

When I trained a standard GPT-2 (~89M parameters) using scaled_dot_product_attention, I got ~9 it/s, but only ~7 it/s w/ memory_efficient_attention.

And I cannot train a GPT-2-medium (~300M parameters) when using memory_efficient_attention, but I can train it w/ scaled_dot_product_attention.

All experiments are trained using fp16 and w/ torch.compile.

@danthe3rd
Contributor

danthe3rd commented Sep 19, 2024

Hi,
memory_efficient_attention used to be faster than PyTorch's SDPA because xFormers was using Flash-Attention. Now SDPA also uses Flash-Attention, so it's expected that the two have the same speed.
Also, note that the dimensions need to be transposed between SDPA (format BHMK) and memory_efficient_attention (format BMHK).
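To make the layout point concrete, here is a minimal sketch (the shapes and tolerance below are illustrative, not taken from this issue): SDPA consumes tensors as (batch, heads, seq_len, head_dim), while memory_efficient_attention consumes (batch, seq_len, heads, head_dim), so the inputs and the output must be transposed to compare the two directly.

import torch
import torch.nn.functional as F
from xformers.ops import memory_efficient_attention

B, H, M, K = 2, 8, 128, 64  # illustrative sizes
# SDPA expects BHMK: (batch, heads, seq_len, head_dim)
q = torch.randn(B, H, M, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, M, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, M, K, device="cuda", dtype=torch.float16)
out_sdpa = F.scaled_dot_product_attention(q, k, v)

# memory_efficient_attention expects BMHK: (batch, seq_len, heads, head_dim)
out_xf = memory_efficient_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))

# Transpose the output back to BHMK; the results should agree up to fp16 tolerance
torch.testing.assert_close(out_sdpa, out_xf.transpose(1, 2), atol=1e-2, rtol=1e-2)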

@QinlongHuang
Author

Thank you so much for the quick reply!

So is there any gain in using memory_efficient_attention from xformers instead of PyTorch's SDPA now?

And it seems that memory_efficient_attention is not compatible w/ torch.compile, which can speed up training and use less memory.

Besides, you've mentioned that "the dimensions need to be transposed between SDPA (format BHMK) and memory_efficient_attention (format BMHK)". So I have to transpose Q/K/V to get the CORRECT results?

I did a toy test w/ the following snippet.

import torch
import torch.nn.functional as F
from xformers.ops import memory_efficient_attention

batch_size, num_heads, seq_length, head_dim = 32, 128, 512, 256
q = torch.randn(batch_size, num_heads, seq_length, head_dim, device='cuda')  # BHMK
k = torch.randn(batch_size, num_heads, seq_length, head_dim, device='cuda')
v = torch.randn(batch_size, num_heads, seq_length, head_dim, device='cuda')

# PyTorch scaled_dot_product_attention
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()

for i in range(10):
    output_torch = F.scaled_dot_product_attention(q, k, v)

end.record()
torch.cuda.synchronize()
print(f"PyTorch scaled_dot_product_attention: {start.elapsed_time(end)} ms")

# xformers memory_efficient_attention (expects BMHK layout)
q = q.transpose(1, 2)  # BMHK
k = k.transpose(1, 2)
v = v.transpose(1, 2)
torch.cuda.synchronize()
start.record()

for i in range(10):
    output_xformers = memory_efficient_attention(q, k, v)

end.record()
torch.cuda.synchronize()
print(f"xformers memory_efficient_attention: {start.elapsed_time(end)} ms")

And now I get similar speed w/ these two implementations.
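One caveat with the toy benchmark above (my own note, not something stated in the thread): neither timing loop has a warmup, so one-time setup cost lands in the first measurement, and the inputs are fp32 while the training runs use fp16. Below is a sketch of the same comparison with a small warmup and fp16 inputs; the bench() helper is hypothetical, introduced here only for illustration.

import torch
import torch.nn.functional as F
from xformers.ops import memory_efficient_attention

def bench(fn, iters=10, warmup=5):
    # Warm up so one-time setup (kernel selection, caching) is not timed
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per call

batch_size, num_heads, seq_length, head_dim = 32, 128, 512, 256
q = torch.randn(batch_size, num_heads, seq_length, head_dim, device="cuda", dtype=torch.float16)  # BHMK
k = torch.randn(batch_size, num_heads, seq_length, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch_size, num_heads, seq_length, head_dim, device="cuda", dtype=torch.float16)
qx, kx, vx = (t.transpose(1, 2) for t in (q, k, v))  # BMHK views for xformers

print("SDPA:    ", bench(lambda: F.scaled_dot_product_attention(q, k, v)), "ms/iter")
print("xformers:", bench(lambda: memory_efficient_attention(qx, kx, vx)), "ms/iter")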
