Add RMS norm and use it #187

Merged · 2 commits · Mar 15, 2023
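For context: this adds RMSNorm (Zhang & Sennrich, 2019), which LLaMA uses in place of the standard LayerNorm. Each row is divided by its root mean square; unlike ggml_norm, no mean is subtracted. Per row of length n, the new kernel computes

    y_i = x_i / sqrt((1/n) * sum_j x_j^2 + eps),   with eps hard-coded to 1e-5

The learned per-channel weight is not part of the op itself; main.cpp applies it afterwards with ggml_mul.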
128 changes: 126 additions & 2 deletions ggml.c
@@ -2069,6 +2069,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"GELU",
"SILU",
"NORM",
"RMS_NORM",

"MUL_MAT",

@@ -2089,7 +2090,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"FLASH_FF",
};

static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -2112,6 +2113,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"gelu(x)",
"silu(x)",
"norm(x)",
"rms_norm(x)",

"X*Y",

@@ -2132,7 +2134,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"flash_ff(x)",
};

static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");

//
// ggml object
@@ -3618,6 +3620,39 @@ struct ggml_tensor * ggml_norm_inplace(
return ggml_norm_impl(ctx, a, true);
}

struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
bool inplace) {
bool is_node = false;

if (!inplace && (a->grad)) {
GGML_ASSERT(false); // TODO: implement backward
is_node = true;
}

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

result->op = GGML_OP_RMS_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = NULL; // TODO: maybe store epsilon here?

return result;
}

struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_rms_norm_impl(ctx, a, false);
}

struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_rms_norm_impl(ctx, a, true);
}

// ggml_mul_mat

struct ggml_tensor * ggml_mul_mat(
@@ -5406,6 +5441,87 @@ static void ggml_compute_forward_norm(
}
}

static void ggml_compute_forward_rms_norm_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
GGML_ASSERT(ggml_are_same_shape(src0, dst));

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}

GGML_ASSERT(src0->nb[0] == sizeof(float));

const int ith = params->ith;
const int nth = params->nth;

const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];

const size_t nb01 = src0->nb[1];
const size_t nb02 = src0->nb[2];
const size_t nb03 = src0->nb[3];

const size_t nb1 = dst->nb[1];
const size_t nb2 = dst->nb[2];
const size_t nb3 = dst->nb[3];

const ggml_float eps = 1e-5f; // TODO: make this a parameter

// TODO: optimize
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

ggml_float mean = 0.0;
for (int i00 = 0; i00 < ne00; i00++) {
mean += x[i00] * x[i00];
}

mean /= ne00;

float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);

memcpy(y, x, ne00 * sizeof(float));
// for (int i00 = 0; i00 < ne00; i00++) {
// y[i00] = x[i00];
// }

const float scale = 1.0/sqrt(mean + eps);

ggml_vec_scale_f32(ne00, y, scale);
}
}
}
}

static void ggml_compute_forward_rms_norm(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_rms_norm_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_F16:
case GGML_TYPE_COUNT:
{
GGML_ASSERT(false);
} break;
}
}


// ggml_compute_forward_mul_mat

#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -8522,6 +8638,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_norm(params, tensor->src0, tensor);
} break;
case GGML_OP_RMS_NORM:
{
ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
} break;
case GGML_OP_MUL_MAT:
{
ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
@@ -8764,6 +8884,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_RMS_NORM:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_MUL_MAT:
{
if (src0->grad) {
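For anyone wanting to sanity-check the kernel outside of ggml, here is a minimal standalone sketch of the same per-row math (rms_norm_row and the test values are illustrative only, not part of this PR):

// Reference for the per-row computation in ggml_compute_forward_rms_norm_f32.
// rms_norm_row is a hypothetical helper, not part of ggml.
#include <math.h>
#include <stdio.h>

static void rms_norm_row(const float * x, float * y, int n) {
    const float eps = 1e-5f;            // same hard-coded epsilon as the kernel
    double mean = 0.0;                  // mean of squares, accumulated in double
    for (int i = 0; i < n; i++) {
        mean += (double) x[i] * x[i];
    }
    mean /= n;
    const float scale = 1.0f / sqrtf((float) mean + eps);
    for (int i = 0; i < n; i++) {
        y[i] = x[i] * scale;            // note: no mean subtraction, unlike LayerNorm
    }
}

int main(void) {
    const float x[4] = { 1.0f, -2.0f, 3.0f, -4.0f };
    float y[4];
    rms_norm_row(x, y, 4);
    for (int i = 0; i < 4; i++) {
        printf("%f\n", y[i]);
    }
    return 0;
}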
5 changes: 5 additions & 0 deletions ggml.h
@@ -230,6 +230,7 @@ enum ggml_op {
GGML_OP_GELU,
GGML_OP_SILU,
GGML_OP_NORM, // normalize
GGML_OP_RMS_NORM,

GGML_OP_MUL_MAT,

@@ -482,6 +483,10 @@ struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
struct ggml_tensor * a);

struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
struct ggml_tensor * a);

// A: m rows, n columns
// B: p rows, n columns (i.e. we transpose it internally)
// result is m columns, p rows
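A rough usage sketch of the new public entry point, assuming the ggml graph API as it exists at this point (ggml_build_forward / ggml_graph_compute); the buffer size and tensor shape are arbitrary and this is not code from the PR:

// Build and run a tiny graph that applies ggml_rms_norm to one row.
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        .mem_size   = 16 * 1024 * 1024,   // arbitrary scratch size for this example
        .mem_buffer = NULL,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_f32(x, 1.0f);                        // fill the row with a constant

    struct ggml_tensor * y = ggml_rms_norm(ctx, x);

    struct ggml_cgraph gf = ggml_build_forward(y);
    gf.n_threads = 1;
    ggml_graph_compute(ctx, &gf);

    // with a constant row, every element should be 1/sqrt(1 + 1e-5), i.e. just under 1.0
    printf("%f\n", ggml_get_f32_1d(y, 0));

    ggml_free(ctx);
    return 0;
}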
6 changes: 3 additions & 3 deletions main.cpp
@@ -588,7 +588,7 @@ bool llama_eval(

// norm
{
cur = ggml_norm(ctx0, inpL);
cur = ggml_rms_norm(ctx0, inpL);

// cur = attention_norm*cur
cur = ggml_mul(ctx0,
@@ -678,7 +678,7 @@ bool llama_eval(
{
// norm
{
cur = ggml_norm(ctx0, inpFF);
cur = ggml_rms_norm(ctx0, inpFF);

// cur = ffn_norm*cur
cur = ggml_mul(ctx0,
@@ -713,7 +713,7 @@

// norm
{
inpL = ggml_norm(ctx0, inpL);
inpL = ggml_rms_norm(ctx0, inpL);

// inpL = norm*inpL
inpL = ggml_mul(ctx0,