@@ -7238,7 +7238,6 @@ static void ggml_compute_forward_rope_f32(
7238
7238
const struct ggml_tensor * src0 ,
7239
7239
const struct ggml_tensor * src1 ,
7240
7240
struct ggml_tensor * dst ) {
7241
- assert (params -> ith == 0 );
7242
7241
assert (src1 -> type == GGML_TYPE_I32 );
7243
7242
assert (ggml_nelements (src1 ) == 3 );
7244
7243
@@ -7265,11 +7264,28 @@ static void ggml_compute_forward_rope_f32(
7265
7264
7266
7265
assert (nb0 == sizeof (float ));
7267
7266
7268
- // TODO: optimize
7267
+ const int ith = params -> ith ;
7268
+ const int nth = params -> nth ;
7269
+
7270
+ const int nr = ggml_nrows (src0 );
7271
+
7272
+ // rows per thread
7273
+ const int dr = (nr + nth - 1 )/nth ;
7274
+
7275
+ // row range for this thread
7276
+ const int ir0 = dr * ith ;
7277
+ const int ir1 = MIN (ir0 + dr , nr );
7278
+
7279
+ // row index used to determine which thread to use
7280
+ int ir = 0 ;
7281
+
7269
7282
for (int64_t i3 = 0 ; i3 < ne3 ; i3 ++ ) {
7270
7283
for (int64_t i2 = (mode == 0 ? 0 : n_past ); i2 < ne2 ; i2 ++ ) {
7271
7284
const int p = (mode == 0 ? n_past + i2 : i2 );
7272
7285
for (int64_t i1 = 0 ; i1 < ne1 ; i1 ++ ) {
7286
+ if (ir ++ < ir0 ) continue ;
7287
+ if (ir > ir1 ) break ;
7288
+
7273
7289
for (int i0 = 0 ; i0 < n_dims ; i0 += 2 ) {
7274
7290
const float theta = powf (10000.0 , ((float )- i0 )/n_dims );
7275
7291
@@ -7295,7 +7311,6 @@ static void ggml_compute_forward_rope_f16(
7295
7311
const struct ggml_tensor * src0 ,
7296
7312
const struct ggml_tensor * src1 ,
7297
7313
struct ggml_tensor * dst ) {
7298
- assert (params -> ith == 0 );
7299
7314
assert (src1 -> type == GGML_TYPE_I32 );
7300
7315
assert (ggml_nelements (src1 ) == 3 );
7301
7316
@@ -7322,10 +7337,28 @@ static void ggml_compute_forward_rope_f16(
7322
7337
7323
7338
assert (nb0 == sizeof (ggml_fp16_t ));
7324
7339
7340
+ const int ith = params -> ith ;
7341
+ const int nth = params -> nth ;
7342
+
7343
+ const int nr = ggml_nrows (src0 );
7344
+
7345
+ // rows per thread
7346
+ const int dr = (nr + nth - 1 )/nth ;
7347
+
7348
+ // row range for this thread
7349
+ const int ir0 = dr * ith ;
7350
+ const int ir1 = MIN (ir0 + dr , nr );
7351
+
7352
+ // row index used to determine which thread to use
7353
+ int ir = 0 ;
7354
+
7325
7355
for (int64_t i3 = 0 ; i3 < ne3 ; i3 ++ ) {
7326
7356
for (int64_t i2 = (mode == 0 ? 0 : n_past ); i2 < ne2 ; i2 ++ ) {
7327
7357
const int p = (mode == 0 ? n_past + i2 : i2 );
7328
7358
for (int64_t i1 = 0 ; i1 < ne1 ; i1 ++ ) {
7359
+ if (ir ++ < ir0 ) continue ;
7360
+ if (ir > ir1 ) break ;
7361
+
7329
7362
for (int i0 = 0 ; i0 < n_dims ; i0 += 2 ) {
7330
7363
const float theta = powf (10000.0 , ((float )- i0 )/n_dims );
7331
7364
@@ -9424,7 +9457,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9424
9457
} break ;
9425
9458
case GGML_OP_ROPE :
9426
9459
{
9427
- node -> n_tasks = 1 ;
9460
+ node -> n_tasks = n_threads ;
9428
9461
} break ;
9429
9462
case GGML_OP_CONV_1D_1S :
9430
9463
case GGML_OP_CONV_1D_2S :
0 commit comments