Skip to content

Commit 405c8e9

Browse files
committed
PR suggestion
1 parent a68e1a5 commit 405c8e9

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

ggml-metal.metal

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -700,6 +700,7 @@ kernel void kernel_rope(
700700
constant float & freq_base,
701701
constant float & freq_scale,
702702
uint tiitg[[thread_index_in_threadgroup]],
703+
uint3 tptg[[threads_per_threadgroup]],
703704
uint3 tgpig[[threadgroup_position_in_grid]]) {
704705
const int64_t i3 = tgpig[2];
705706
const int64_t i2 = tgpig[1];
@@ -713,7 +714,7 @@ kernel void kernel_rope(
713714
const float inv_ndims = -1.f/n_dims;
714715

715716
if (!is_neox) {
716-
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 64) {
717+
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
717718

718719
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
719720
const float cos_theta = cos(theta);
@@ -730,7 +731,7 @@ kernel void kernel_rope(
730731
}
731732
} else {
732733
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
733-
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 64) {
734+
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
734735

735736
const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
736737
const float cos_theta = cos(theta);

0 commit comments

Comments (0)