metal : fix shared memory calc + reduce smem + comments
ggerganov committed Nov 5, 2024
1 parent 1e12961 commit 624485f
Showing 2 changed files with 40 additions and 14 deletions.
47 changes: 35 additions & 12 deletions ggml/src/ggml-metal.m
@@ -3145,12 +3145,16 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);

// 16*32*nsgmax
// the shared memory needed for the simdgroups to load the KV cache
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(nsg)*(ncpsg + nqptg)) + 16*32*(nsg))*(sizeof(float)/2), 16))

int64_t nsgmax = 2;

while (true) {
// 16*32*nsgmax - the shared memory needed for the simdgroups to load the KV cache
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
const size_t smem = (nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg)) + 16*32*nsgmax)*(sizeof(float)/2);
const size_t smem = FATTN_SMEM(nsgmax);
if (smem > device.maxThreadgroupMemoryLength) {
break;
}
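
The collapsed lines presumably continue with the same grow-and-back-off pattern that appears in full in the half4x4 hunk further down: nsgmax is doubled while the requirement still fits, then halved once. A minimal standalone sketch of that search in C++, with an assumed 32 KB maxThreadgroupMemoryLength and made-up head/tile sizes (not the actual ggml code):

    #include <cstdint>
    #include <cstdio>

    // pad x up to a multiple of 16, same idea as GGML_PAD
    static size_t pad16(size_t x) { return (x + 15) & ~((size_t) 15); }

    int main() {
        const int64_t ne00  = 128; // head size (illustrative)
        const int64_t nqptg = 8;   // queries per threadgroup
        const int64_t ncpsg = 32;  // cache values per simdgroup
        const size_t  max_tg_mem = 32*1024; // assumed device.maxThreadgroupMemoryLength

        // same shape as FATTN_SMEM(nsg): (Q*(D + 2*nsg*(C + Q)) + 16*32*nsg) halves -> bytes
        auto fattn_smem = [&](int64_t nsg) {
            return pad16((size_t) (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 16*32*nsg)*(sizeof(float)/2));
        };

        // grow nsgmax until the requirement no longer fits, then back off one step
        int64_t nsgmax = 2;
        while (fattn_smem(nsgmax) <= max_tg_mem) {
            nsgmax *= 2;
        }
        nsgmax /= 2;

        // with these numbers: nsgmax = 8, fattn_smem(8) = 20480 bytes
        printf("nsgmax = %lld, smem = %zu bytes\n", (long long) nsgmax, fattn_smem(nsgmax));
        return 0;
    }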
@@ -3161,13 +3165,12 @@ static void ggml_metal_encode_node(
// simdgroups per threadgroup (a.k.a. warps)
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;

const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 16*32*nsg)*(sizeof(float)/2);
const size_t smem = FATTN_SMEM(nsg);

//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);

[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];

[encoder setThreadgroupMemoryLength:smem atIndex:0];
#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
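
For concreteness, a back-of-the-envelope evaluation of the same formula together with the dispatch geometry, under assumed tensor sizes (512 queries, the outer Q dimensions 32 and 1, and nsg = 4 are made up for illustration):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne00 = 128, ne01 = 512, ne02 = 32, ne03 = 1; // head size, queries, two outer Q dims
        const int64_t nqptg = 8, ncpsg = 32, nsg = 4;

        // FATTN_SMEM(4) = (8*(128 + 2*4*(32 + 8)) + 16*32*4)*2 = (3584 + 2048)*2 = 11264 bytes
        // (already a multiple of 16, so GGML_PAD leaves it unchanged)
        const size_t smem = (size_t) (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 16*32*nsg)*(sizeof(float)/2);

        // one threadgroup per nqptg queries, for each of the two outer dimensions,
        // with nsg simdgroups of 32 threads per threadgroup
        const int64_t tg_x = (ne01 + nqptg - 1)/nqptg; // 64
        printf("smem = %zu bytes, grid = %lld x %lld x %lld, threads/tg = %lld\n",
               smem, (long long) tg_x, (long long) ne02, (long long) ne03, (long long) (32*nsg));
        return 0;
    }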
} else {
// half4x4 kernel
@@ -3178,21 +3181,41 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 1 == 0);
GGML_ASSERT(ncpsg % 32 == 0);

// ne00 + 2*ncpsg*(nsg)
// for each query, we load it as f16 in shared memory (ne00)
// and store the attention scores (nqptg x ncpsg) as f32
//
// 2*ne00*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))

int64_t nsgmax = 2;

while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
if (smem > device.maxThreadgroupMemoryLength) {
break;
}
nsgmax *= 2;
}
nsgmax /= 2;

// simdgroups per threadgroup (a.k.a. warps)
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));

int64_t nsg = 1;
while (nsg <= nsgt) {
nsg *= 2;
}
nsg /= 2;

const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 2*nsg*ne00)*(sizeof(float)/2);
const size_t smem = FATTN_SMEM(nsg);

//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];

[encoder setThreadgroupMemoryLength:smem atIndex:0];
#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
}
} break;
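
The "reduce smem" part of the commit shows up in this half4x4 path: the per-simdgroup score scratch shrinks from 2*(ncpsg + nqptg) to 2*ncpsg halves per query, matching the SH change from C + Q to C in the kernel below (the removed comment indicates the extra Q slots held the diagonal matrix). A rough before/after comparison under assumed sizes, as a sketch rather than the real encoder:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // assumed sizes for the vector kernel: head size 128, 1 query per threadgroup,
        // 32 cache values per simdgroup, 4 simdgroups
        const int64_t ne00 = 128, nqptg = 1, ncpsg = 32, nsg = 4;

        // before: Q*(D + 2*nsg*(C + Q)) + 2*nsg*D halves
        const size_t smem_old = (size_t) (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 2*nsg*ne00)*(sizeof(float)/2);
        // after:  Q*(D + 2*C*nsg)       + 2*D*nsg halves -- the "+ Q" term per simdgroup is gone
        const size_t smem_new = (size_t) (nqptg*(ne00 + 2*ncpsg*nsg) + 2*ne00*nsg)*(sizeof(float)/2);

        // 2832 vs 2816 bytes here: a small saving per threadgroup (2*nsg*Q*Q halves), but it keeps
        // the kernel-side layout (SH = C) and the host-side reservation in sync
        printf("old: %zu bytes, new: %zu bytes\n", smem_old, smem_new);
        return 0;
    }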
7 changes: 5 additions & 2 deletions ggml/src/ggml-metal.metal
@@ -2876,6 +2876,7 @@ kernel void kernel_flash_attn_ext(
for (short cc = 0; cc < C/8; ++cc) {
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);

// this is a compile-time check, so it does not have runtime overhead
if (is_same<block_q, half4x4>::value) {
// we can read directly from global memory
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
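
The new comment notes that the is_same<block_q, half4x4>::value test is resolved at compile time, so the branch itself costs nothing at runtime. A generic C++ illustration of the idiom using std::is_same (not the Metal source; the Metal kernel applies the same idea to its block_q template parameter):

    #include <cstdio>
    #include <type_traits>

    struct half4x4    { float v[16]; }; // stand-in for the Metal f16 tile type
    struct block_q4_0 { };              // stand-in for a quantized block type

    template <typename block_q>
    void load_block() {
        if (std::is_same<block_q, half4x4>::value) {
            // data is already f16: the condition folds to a constant and the
            // other branch is dead code for this instantiation
            printf("read directly from global memory\n");
        } else {
            printf("dequantize into shared memory first\n");
        }
    }

    int main() {
        load_block<half4x4>();    // prints "read directly from global memory"
        load_block<block_q4_0>(); // prints "dequantize into shared memory first"
        return 0;
    }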
@@ -2891,6 +2892,7 @@ kernel void kernel_flash_attn_ext(
device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));

if (D16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
half4x4 tmp;
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
skv4[4*ty + tx] = tmp;
@@ -3009,6 +3011,7 @@ kernel void kernel_flash_attn_ext(
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));

if (D16%4 == 0) {
// no need for bound checks
half4x4 tmp;
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
skv4[4*ty + tx] = tmp;
@@ -3233,14 +3236,14 @@ kernel void kernel_flash_attn_ext_vec(
const short D16 = D/16;
const short NW = N_SIMDWIDTH;
const short NW4 = NW/4;
const short SH = (C + Q); // shared memory per simdgroup in (half)
const short SH = C; // shared memory per simdgroup in (half)

const short T = D + 2*nsg*SH; // shared memory size per query in (half)

//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in float4
threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
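
With SH reduced to C, the kernel's per-query stride T = D + 2*nsg*SH and the three buffer offsets add up to the bytes reserved by the host-side FATTN_SMEM (up to the 16-byte padding). A small sketch of the layout, all offsets in half-sized units, for assumed D = 128, Q = 1, C = 32, nsg = 4:

    #include <cstdio>

    int main() {
        // assumed kernel parameters, matching the vector-kernel example above
        const int D = 128, Q = 1, C = 32, nsg = 4;

        const int SH = C;            // per-simdgroup scratch per query, in halves
        const int T  = D + 2*nsg*SH; // shared memory per query, in halves: 128 + 256 = 384

        // offsets into threadgroup memory for simdgroup sgitg (all in halves):
        //   sq4  at 0               -- the query, stored as f16
        //   ss   at D + 2*sgitg*SH  -- attention-score scratch, viewed as f32 (hence the 2x)
        //   sr44 at Q*T + 2*sgitg*D -- per-simdgroup f32 accumulator for the results
        const int total_halves = Q*T + 2*nsg*D; // 384 + 1024 = 1408
        printf("T = %d halves, total = %d halves = %d bytes\n", T, total_halves, 2*total_halves);
        // 2816 bytes -- the same value FATTN_SMEM(4) produces on the host side
        return 0;
    }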
