File tree 1 file changed +3
-2
lines changed 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -700,6 +700,7 @@ kernel void kernel_rope(
700
700
constant float & freq_base,
701
701
constant float & freq_scale,
702
702
uint tiitg[[thread_index_in_threadgroup]],
703
+ uint3 tptg[[threads_per_threadgroup]],
703
704
uint3 tgpig[[threadgroup_position_in_grid]]) {
704
705
const int64_t i3 = tgpig[2 ];
705
706
const int64_t i2 = tgpig[1 ];
@@ -713,7 +714,7 @@ kernel void kernel_rope(
713
714
const float inv_ndims = -1 .f /n_dims;
714
715
715
716
if (!is_neox) {
716
- for (int64_t i0 = 2 *tiitg; i0 < ne0; i0 += 64 ) {
717
+ for (int64_t i0 = 2 *tiitg; i0 < ne0; i0 += 2 *tptg. x ) {
717
718
718
719
const float theta = theta_0 * pow (freq_base, inv_ndims*i0);
719
720
const float cos_theta = cos (theta);
@@ -730,7 +731,7 @@ kernel void kernel_rope(
730
731
}
731
732
} else {
732
733
for (int64_t ib = 0 ; ib < ne0/n_dims; ++ib) {
733
- for (int64_t ic = 2 *tiitg; ic < n_dims; ic += 64 ) {
734
+ for (int64_t ic = 2 *tiitg; ic < n_dims; ic += 2 *tptg. x ) {
734
735
735
736
const float theta = theta_0 * pow (freq_base, inv_ndims*ic - ib);
736
737
const float cos_theta = cos (theta);
You can’t perform that action at this time.
0 commit comments