Skip to content

Commit 90cbb6e

Browse files
author
s50048389
committed
Fix indexing
1 parent c8a293c commit 90cbb6e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ void main() {
368368
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
369369
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
370370
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
371-
const uint sums_idx = (cr * WNITER + wsic) * (WMITER * TN) + cc * TN + wsir;
371+
const uint sums_idx = (cr * WNITER + wsic) * (WMITER * TN) + cc * WMITER + wsir;
372372
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx].x));
373373
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx].y));
374374
}
@@ -466,7 +466,7 @@ void main() {
466466
const u16vec2 row_idx = row_ids[row_i - ic * BN];
467467
#endif // MUL_MAT_ID
468468
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
469-
const uint sums_idx = (cr * WNITER + wsic) * (WMITER * TN) + cc * TN + wsir;
469+
const uint sums_idx = (cr * WNITER + wsic) * (WMITER * TN) + cc * WMITER + wsir;
470470
#ifdef MUL_MAT_ID
471471
if (dr_warp + 2 * cr < p.M) {
472472
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);

0 commit comments

Comments
 (0)