Skip to content

Commit 625f212

Browse files
committed
ggml : multi-thread ggml_rope() (~3-4 times faster on M1)
1 parent 1868f6c commit 625f212

File tree

1 file changed

+37
-4
lines changed

1 file changed

+37
-4
lines changed

Diff for: ggml.c

+37-4
Original file line numberDiff line numberDiff line change
@@ -7238,7 +7238,6 @@ static void ggml_compute_forward_rope_f32(
72387238
const struct ggml_tensor * src0,
72397239
const struct ggml_tensor * src1,
72407240
struct ggml_tensor * dst) {
7241-
assert(params->ith == 0);
72427241
assert(src1->type == GGML_TYPE_I32);
72437242
assert(ggml_nelements(src1) == 3);
72447243

@@ -7265,11 +7264,28 @@ static void ggml_compute_forward_rope_f32(
72657264

72667265
assert(nb0 == sizeof(float));
72677266

7268-
// TODO: optimize
7267+
const int ith = params->ith;
7268+
const int nth = params->nth;
7269+
7270+
const int nr = ggml_nrows(src0);
7271+
7272+
// rows per thread
7273+
const int dr = (nr + nth - 1)/nth;
7274+
7275+
// row range for this thread
7276+
const int ir0 = dr*ith;
7277+
const int ir1 = MIN(ir0 + dr, nr);
7278+
7279+
// row index used to determine which thread to use
7280+
int ir = 0;
7281+
72697282
for (int64_t i3 = 0; i3 < ne3; i3++) {
72707283
for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
72717284
const int p = (mode == 0 ? n_past + i2 : i2);
72727285
for (int64_t i1 = 0; i1 < ne1; i1++) {
7286+
if (ir++ < ir0) continue;
7287+
if (ir > ir1) break;
7288+
72737289
for (int i0 = 0; i0 < n_dims; i0 += 2) {
72747290
const float theta = powf(10000.0, ((float)-i0)/n_dims);
72757291

@@ -7295,7 +7311,6 @@ static void ggml_compute_forward_rope_f16(
72957311
const struct ggml_tensor * src0,
72967312
const struct ggml_tensor * src1,
72977313
struct ggml_tensor * dst) {
7298-
assert(params->ith == 0);
72997314
assert(src1->type == GGML_TYPE_I32);
73007315
assert(ggml_nelements(src1) == 3);
73017316

@@ -7322,10 +7337,28 @@ static void ggml_compute_forward_rope_f16(
73227337

73237338
assert(nb0 == sizeof(ggml_fp16_t));
73247339

7340+
const int ith = params->ith;
7341+
const int nth = params->nth;
7342+
7343+
const int nr = ggml_nrows(src0);
7344+
7345+
// rows per thread
7346+
const int dr = (nr + nth - 1)/nth;
7347+
7348+
// row range for this thread
7349+
const int ir0 = dr*ith;
7350+
const int ir1 = MIN(ir0 + dr, nr);
7351+
7352+
// row index used to determine which thread to use
7353+
int ir = 0;
7354+
73257355
for (int64_t i3 = 0; i3 < ne3; i3++) {
73267356
for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
73277357
const int p = (mode == 0 ? n_past + i2 : i2);
73287358
for (int64_t i1 = 0; i1 < ne1; i1++) {
7359+
if (ir++ < ir0) continue;
7360+
if (ir > ir1) break;
7361+
73297362
for (int i0 = 0; i0 < n_dims; i0 += 2) {
73307363
const float theta = powf(10000.0, ((float)-i0)/n_dims);
73317364

@@ -9424,7 +9457,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
94249457
} break;
94259458
case GGML_OP_ROPE:
94269459
{
9427-
node->n_tasks = 1;
9460+
node->n_tasks = n_threads;
94289461
} break;
94299462
case GGML_OP_CONV_1D_1S:
94309463
case GGML_OP_CONV_1D_2S:

0 commit comments

Comments
 (0)