Skip to content

Commit 347c113

Browse files
committed
metal : remove mask padding requirement
1 parent 9b21358 commit 347c113

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -982,19 +982,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
982982
const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
983983
const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
984984

985+
// do bounds checks for the mask?
986+
const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
987+
985988
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
986989
"flash_attn_ext",
987990
ggml_type_name(op->src[1]->type),
988991
dk,
989992
dv);
990993

991-
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d",
994+
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
992995
base,
993996
has_mask,
994997
has_sinks,
995998
has_bias,
996999
has_scap,
9971000
has_kvpad,
1001+
bc_mask,
9981002
ns10,
9991003
ns20,
10001004
nsg);
@@ -1012,6 +1016,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
10121016
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
10131017
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
10141018

1019+
ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
1020+
10151021
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
10161022
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
10171023
ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22);

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1981,8 +1981,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
19811981
GGML_ASSERT(ne12 == ne22);
19821982

19831983
GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
1984-
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) &&
1985-
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
1984+
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
1985+
"the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
19861986

19871987
float scale;
19881988
float max_bias;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4525,6 +4525,8 @@ constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT
45254525
constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
45264526
constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];
45274527

4528+
constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
4529+
45284530
//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
45294531
//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
45304532
//constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]];
@@ -4711,7 +4713,7 @@ void kernel_flash_attn_ext_impl(
47114713
v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
47124714

47134715
if (!FC_flash_attn_ext_has_mask) {
4714-
threadgroup half * sm = (threadgroup half *) (sm2);
4716+
threadgroup half * sm = (threadgroup half *) (sm2);
47154717

47164718
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
47174719
const short j = jj*NSG + sgitg;
@@ -4741,7 +4743,12 @@ void kernel_flash_attn_ext_impl(
47414743
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
47424744
const short j = jj*NSG + sgitg;
47434745

4744-
sm2[j*SH + tiisg] = pm2[jj][tiisg];
4746+
if (FC_flash_attn_ext_bc_mask) {
4747+
sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
4748+
} else {
4749+
sm2[j*SH + tiisg] = pm2[jj][tiisg];
4750+
}
4751+
47454752
pm2[jj] += NW;
47464753
}
47474754

0 commit comments

Comments
 (0)