CUDA: fix Volta FlashAttention logic (#11615)
JohannesGaessler authored Feb 3, 2025
1 parent d92cb67 commit 21c84b5
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -561,7 +561,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
             ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
             break;
         // case 256:
-        //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+        //     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
         //     break;
         default:
             GGML_ABORT("fatal error");
3 changes: 2 additions & 1 deletion ggml/src/ggml-cuda/fattn.cu
@@ -235,7 +235,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }

-    if (!new_mma_available(cc)) {
+    if (!fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT) {
             if (Q->ne[1] <= 8) {
                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
@@ -265,6 +265,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
     if (cc == GGML_CUDA_CC_VOLTA) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+        return;
     }

     ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
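
Taken together, the two changes route Volta back to the WMMA FlashAttention kernel: the availability check now asks for fp16 MMA support rather than the newer MMA instructions (which Volta lacks), and the Volta branch now returns instead of falling through into the MMA path that requires Turing or newer. Below is a minimal, self-contained C++ sketch of that dispatch decision; the numeric compute-capability constants, the helper predicates, and the path labels are simplified stand-ins for illustration, not the real ggml-cuda implementations.

// Minimal sketch of the FlashAttention dispatch logic before/after this commit.
// Assumption: the CC values and helper predicates below are simplified stand-ins;
// the real checks live elsewhere in ggml-cuda.
#include <cstdio>

// Assumed numeric compute capabilities, for illustration only.
constexpr int GGML_CUDA_CC_VOLTA  = 700;
constexpr int GGML_CUDA_CC_TURING = 750;

// Volta has fp16 tensor cores usable through WMMA, but the newer MMA
// instructions used by the mma_f16 kernels need Turing or newer.
static bool fp16_mma_available(int cc) { return cc >= GGML_CUDA_CC_VOLTA; }
static bool new_mma_available (int cc) { return cc >= GGML_CUDA_CC_TURING; }

// Dispatch as it reads after the fix.
static const char * dispatch_after(int cc) {
    if (!fp16_mma_available(cc)) {
        return "fallback (no tensor cores)";
    }
    if (cc == GGML_CUDA_CC_VOLTA) {
        return "wmma_f16";  // the added return keeps Volta from falling through
    }
    return "mma_f16";       // Turing or newer
}

// Dispatch as it read before the fix: Volta failed the new_mma check and was
// sent down the fallback path instead of the WMMA kernel.
static const char * dispatch_before(int cc) {
    if (!new_mma_available(cc)) {
        return "fallback (no tensor cores)";
    }
    return "mma_f16";
}

int main() {
    const int ccs[] = {610, 700, 750, 890};
    for (int cc : ccs) {
        std::printf("cc %d: before=%-26s after=%s\n",
                    cc, dispatch_before(cc), dispatch_after(cc));
    }
    return 0;
}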
