Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 14a7dce

Browse files
committedMay 17, 2023
ggml : rms_norm in chunks
1 parent add49f6 commit 14a7dce

File tree

1 file changed

+38
-23
lines changed

1 file changed

+38
-23
lines changed
 

‎ggml.c

+38-23
Original file line numberDiff line numberDiff line change
@@ -9018,18 +9018,20 @@ static void ggml_compute_forward_rms_norm_f32(
90189018
GGML_ASSERT(ggml_are_same_shape(src0, dst));
90199019

90209020
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9021+
atomic_store(params->aic, 0);
9022+
90219023
return;
90229024
}
90239025

90249026
GGML_ASSERT(src0->nb[0] == sizeof(float));
90259027

9026-
const int ith = params->ith;
9028+
const int ith = params->ith; UNUSED(ith);
90279029
const int nth = params->nth;
90289030

90299031
const int64_t ne00 = src0->ne[0];
90309032
const int64_t ne01 = src0->ne[1];
90319033
const int64_t ne02 = src0->ne[2];
9032-
const int64_t ne03 = src0->ne[3];
9034+
const int64_t ne03 = src0->ne[3]; UNUSED(ne03);
90339035

90349036
const size_t nb01 = src0->nb[1];
90359037
const size_t nb02 = src0->nb[2];
@@ -9041,30 +9043,45 @@ static void ggml_compute_forward_rms_norm_f32(
90419043

90429044
const float eps = 1e-6f; // TODO: make this a parameter
90439045

9044-
// TODO: optimize
9045-
for (int64_t i03 = 0; i03 < ne03; i03++) {
9046-
for (int64_t i02 = 0; i02 < ne02; i02++) {
9047-
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
9048-
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
9049-
9050-
ggml_float sum = 0.0;
9051-
for (int64_t i00 = 0; i00 < ne00; i00++) {
9052-
sum += (ggml_float)(x[i00] * x[i00]);
9053-
}
9046+
const int nr = ggml_nrows(src0);
9047+
const int dr = (nr + 8*nth - 1)/(8*nth);
90549048

9055-
float mean = sum/ne00;
9049+
while (true) {
9050+
const int ir0 = atomic_fetch_add(params->aic, dr);
90569051

9057-
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
9052+
for (int ir = ir0; ir < ir0 + dr; ++ir) {
9053+
if (ir >= nr) {
9054+
break;
9055+
}
90589056

9059-
memcpy(y, x, ne00 * sizeof(float));
9060-
// for (int i00 = 0; i00 < ne00; i00++) {
9061-
// y[i00] = x[i00];
9062-
// }
9057+
// src0 indices
9058+
const int i03 = ir/(ne02*ne01);
9059+
const int i02 = (ir - i03*ne02*ne01)/ne01;
9060+
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
90639061

9064-
const float scale = 1.0f/sqrtf(mean + eps);
9062+
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
90659063

9066-
ggml_vec_scale_f32(ne00, y, scale);
9064+
ggml_float sum = 0.0;
9065+
for (int64_t i00 = 0; i00 < ne00; i00++) {
9066+
sum += (ggml_float)(x[i00] * x[i00]);
90679067
}
9068+
9069+
float mean = sum/ne00;
9070+
9071+
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
9072+
9073+
memcpy(y, x, ne00 * sizeof(float));
9074+
// for (int i00 = 0; i00 < ne00; i00++) {
9075+
// y[i00] = x[i00];
9076+
// }
9077+
9078+
const float scale = 1.0f/sqrtf(mean + eps);
9079+
9080+
ggml_vec_scale_f32(ne00, y, scale);
9081+
}
9082+
9083+
if (ir0 + dr >= nr) {
9084+
break;
90689085
}
90699086
}
90709087
}
@@ -9739,11 +9756,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
97399756
const int nb2 = dst->nb[2];
97409757
const int nb3 = dst->nb[3];
97419758

9742-
const int ith = params->ith;
9759+
const int ith = params->ith; UNUSED(ith);
97439760
const int nth = params->nth;
97449761

9745-
UNUSED(ith);
9746-
97479762
GGML_ASSERT(ne02 == ne12);
97489763
GGML_ASSERT(ne03 == ne13);
97499764
GGML_ASSERT(ne2 == ne12);

0 commit comments

Comments
 (0)
Please sign in to comment.