diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 3cb24412d577e..30c0dcb0dcb04 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -313,12 +313,12 @@ void main() { sums[i] = coopmat(0.0f); } #else - ACC_TYPE sums[WMITER * TM * WNITER * TN]; + ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2]; FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; - FLOAT_TYPE_VEC2 cache_b[TN]; + FLOAT_TYPE_VEC2 cache_b; - [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = ACC_TYPE(0.0f); + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { + sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f); } #endif @@ -360,20 +360,22 @@ void main() { cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; } } - [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (uint j = 0; j < TN; j++) { - cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; - } - [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (uint cc = 0; cc < TN; cc++) { - [[unroll]] for (uint cr = 0; cr < TM; cr++) { - const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; - sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx])); + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i]; + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { + // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr] + const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; + sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x)); + sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y)); } } } } + } #endif @@ -388,8 +390,9 @@ void main() { } } #else - [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { + sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX); + sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX); } #endif #endif @@ -463,14 +466,21 @@ void main() { const u16vec2 row_idx = row_ids[row_i - ic * BN]; #endif // MUL_MAT_ID - [[unroll]] for (uint cr = 0; cr < TM; cr++) { + [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { + const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; #ifdef MUL_MAT_ID - if (dr_warp + cr < p.M) { - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + 2 * cr < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x); + } + if (dr_warp + 2 * cr + 1 < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y); } #else - if (dr_warp + cr < p.M && dc_warp + cc < p.N) { - data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x); + } + if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y); } #endif // MUL_MAT_ID }