
Commit 41c6741

make rms_norm_eps a parameter (#2374)
* make rms_norm_eps a parameter
* add rms_norm_eps to command line
* fix baby llama, test-grad0
* use scientific notation for eps param in the help

ggml-ci
1 parent b3f138d commit 41c6741
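
For context, the change turns the epsilon that was hard-coded as 1e-6f inside ggml into an explicit argument of ggml_rms_norm(). A minimal caller-side sketch of the updated API follows; the context size, tensor shape, and the 1e-5 eps value are illustrative and not part of this commit:

#include "ggml.h"

int main(void) {
    // small scratch context; the size is arbitrary for this sketch
    struct ggml_init_params ip = { 16*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_f32(x, 1.0f);

    // eps is now passed explicitly; the new help text suggests 1e-5 for LLaMAv2
    const float rms_norm_eps = 1e-5f;
    struct ggml_tensor * y = ggml_rms_norm(ctx, x, rms_norm_eps);

    struct ggml_cgraph gf = ggml_build_forward(y);
    ggml_graph_compute_with_ctx(ctx, &gf, 1); // single thread is enough here

    ggml_free(ctx);
    return 0;
}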

File tree

11 files changed (+89, -56 lines)


examples/baby-llama/baby-llama.cpp (+11, -9)

@@ -8,6 +8,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static const float rms_norm_eps = 1e-6f;
+
 float frand() {
 return (float)rand()/(float)RAND_MAX;
 }
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
 // norm
 {
 // cur shape [n_embd,N,1,1]
-cur = ggml_rms_norm(ctx0, inpL);
+cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
 // cur = attention_norm*cur
 cur = ggml_mul(ctx0,
@@ -685,7 +687,7 @@ struct ggml_tensor * forward(
 // norm
 {
 // cur shape [n_embd,N,1,1]
-cur = ggml_rms_norm(ctx0, inpFF);
+cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 
 // cur = ffn_norm*cur
 // cur shape [n_embd,N,1,1]
@@ -729,7 +731,7 @@ struct ggml_tensor * forward(
 {
 
 // inpL shape [n_embd,N,1,1]
-inpL = ggml_rms_norm(ctx0, inpL);
+inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
 // inpL = norm*inpL
 // inpL shape [n_embd,N,1,1]
@@ -817,7 +819,7 @@ struct ggml_tensor * forward_batch(
 // norm
 {
 // cur shape [n_embd,N*n_batch,1,1]
-cur = ggml_rms_norm(ctx0, inpL);
+cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 assert_shape_2d(cur, n_embd, N*n_batch);
 
 // cur = attention_norm*cur
@@ -981,7 +983,7 @@ struct ggml_tensor * forward_batch(
 // norm
 {
 // cur shape [n_embd,N*n_batch,1,1]
-cur = ggml_rms_norm(ctx0, inpFF);
+cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 assert_shape_2d(cur, n_embd, N*n_batch);
 
 // cur = ffn_norm*cur
@@ -1034,7 +1036,7 @@ struct ggml_tensor * forward_batch(
 {
 
 // inpL shape [n_embd,N*n_batch,1,1]
-inpL = ggml_rms_norm(ctx0, inpL);
+inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 assert_shape_2d(inpL, n_embd, N*n_batch);
 
 // inpL = norm*inpL
@@ -1104,7 +1106,7 @@ struct ggml_tensor * forward_lora(
 // norm
 {
 // cur shape [n_embd,N,1,1]
-cur = ggml_rms_norm(ctx0, inpL);
+cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
 // cur = attention_norm*cur
 cur = ggml_mul(ctx0,
@@ -1251,7 +1253,7 @@ struct ggml_tensor * forward_lora(
 // norm
 {
 // cur shape [n_embd,N,1,1]
-cur = ggml_rms_norm(ctx0, inpFF);
+cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 
 // cur = ffn_norm*cur
 // cur shape [n_embd,N,1,1]
@@ -1295,7 +1297,7 @@ struct ggml_tensor * forward_lora(
 {
 
 // inpL shape [n_embd,N,1,1]
-inpL = ggml_rms_norm(ctx0, inpL);
+inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
 // inpL = norm*inpL
 // inpL shape [n_embd,N,1,1]

examples/common.cpp (+8)

@@ -177,6 +177,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 break;
 }
 params.n_gqa = std::stoi(argv[i]);
+} else if (arg == "-eps" || arg == "--rms-norm-eps") {
+if (++i >= argc) {
+invalid_param = true;
+break;
+}
+params.rms_norm_eps = std::stof(argv[i]);
 } else if (arg == "--rope-freq-base") {
 if (++i >= argc) {
 invalid_param = true;
@@ -519,6 +525,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
 fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
 fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
 fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
+fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
 fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
 fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
 fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -615,6 +622,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
 lparams.n_ctx = params.n_ctx;
 lparams.n_batch = params.n_batch;
 lparams.n_gqa = params.n_gqa;
+lparams.rms_norm_eps = params.rms_norm_eps;
 lparams.n_gpu_layers = params.n_gpu_layers;
 lparams.main_gpu = params.main_gpu;
 lparams.tensor_split = params.tensor_split;
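
The new flag feeds params.rms_norm_eps, and llama_context_params_from_gpt_params() copies it into the context parameters. A rough sketch of that flow, assuming only the functions shown in this diff (the program body and the example invocation are illustrative):

#include <cstdio>

#include "common.h"
#include "llama.h"

int main(int argc, char ** argv) {
    gpt_params params; // rms_norm_eps defaults to 1e-6

    // e.g. ./main -m <model> -gqa 8 -eps 1e-5
    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    // the parsed value is carried over into llama_context_params
    llama_context_params lparams = llama_context_params_from_gpt_params(params);
    fprintf(stderr, "rms_norm_eps = %.1e\n", lparams.rms_norm_eps);
    return 0;
}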

examples/common.h (+12, -11)

@@ -22,18 +22,19 @@
 int32_t get_num_physical_cores();
 
 struct gpt_params {
-uint32_t seed = -1; // RNG seed
+uint32_t seed = -1; // RNG seed
 int32_t n_threads = get_num_physical_cores();
-int32_t n_predict = -1; // new tokens to predict
-int32_t n_ctx = 512; // context size
-int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
-int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
-int32_t n_keep = 0; // number of tokens to keep from initial prompt
-int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
-int32_t n_gpu_layers = 0; // number of layers to store in VRAM
-int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
-float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
-int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+int32_t n_predict = -1; // new tokens to predict
+int32_t n_ctx = 512; // context size
+int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
+int32_t n_keep = 0; // number of tokens to keep from initial prompt
+int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
+int32_t n_gpu_layers = 0; // number of layers to store in VRAM
+int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
+float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
+int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+float rms_norm_eps = 1e-6; // rms norm epsilon
 float rope_freq_base = 10000.0f; // RoPE base frequency
 float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
 

examples/train-text-from-scratch/train-text-from-scratch.cpp (+17, -15)

@@ -16,6 +16,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static const float rms_norm_eps = 1e-6f;
+
 struct random_normal_distribution {
 std::mt19937 gen;
 std::normal_distribution<float> rd;
@@ -439,7 +441,7 @@ struct ggml_tensor * forward(
 // norm
 {
 // cur shape [n_embd,N,1,1]
-cur = ggml_rms_norm(ctx0, inpL);
+cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
 // cur = attention_norm*cur
 cur = ggml_mul(ctx0,
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
 // norm
 {
 // cur shape [n_embd,N,1,1]
-cur = ggml_rms_norm(ctx0, inpFF);
+cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 
 // cur = ffn_norm*cur
 // cur shape [n_embd,N,1,1]
@@ -606,7 +608,7 @@ struct ggml_tensor * forward(
 {
 
 // inpL shape [n_embd,N,1,1]
-inpL = ggml_rms_norm(ctx0, inpL);
+inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
 // inpL = norm*inpL
 // inpL shape [n_embd,N,1,1]
@@ -694,7 +696,7 @@ struct ggml_tensor * forward_batch(
 // norm
 {
 // cur shape [n_embd,N*n_batch,1,1]
-cur = ggml_rms_norm(ctx0, inpL);
+cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 assert_shape_2d(cur, n_embd, N*n_batch);
 
 // cur = attention_norm*cur
@@ -857,7 +859,7 @@ struct ggml_tensor * forward_batch(
 // norm
 {
 // cur shape [n_embd,N*n_batch,1,1]
-cur = ggml_rms_norm(ctx0, inpFF);
+cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 assert_shape_2d(cur, n_embd, N*n_batch);
 
 // cur = ffn_norm*cur
@@ -910,7 +912,7 @@ struct ggml_tensor * forward_batch(
 {
 
 // inpL shape [n_embd,N*n_batch,1,1]
-inpL = ggml_rms_norm(ctx0, inpL);
+inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 assert_shape_2d(inpL, n_embd, N*n_batch);
 
 // inpL = norm*inpL
@@ -979,7 +981,7 @@ struct ggml_tensor * forward_batch_wo_cache(
 // norm
 {
 // cur shape [n_embd,N*n_batch,1,1]
-cur = ggml_rms_norm(ctx0, inpL);
+cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 assert_shape_2d(cur, n_embd, N*n_batch);
 
 // cur = attention_norm*cur
@@ -1085,7 +1087,7 @@ struct ggml_tensor * forward_batch_wo_cache(
 // norm
 {
 // cur shape [n_embd,N*n_batch,1,1]
-cur = ggml_rms_norm(ctx0, inpFF);
+cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 assert_shape_2d(cur, n_embd, N*n_batch);
 
 // cur = ffn_norm*cur
@@ -1138,7 +1140,7 @@ struct ggml_tensor * forward_batch_wo_cache(
 {
 
 // inpL shape [n_embd,N*n_batch,1,1]
-inpL = ggml_rms_norm(ctx0, inpL);
+inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 assert_shape_2d(inpL, n_embd, N*n_batch);
 
 // inpL = norm*inpL
@@ -1203,7 +1205,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
 
 // norm
 {
-cur = ggml_rms_norm(ctx0, inpL);
+cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 assert_shape_2d(cur, n_embd, N*n_batch);
 
 // cur = attention_norm*cur
@@ -1267,7 +1269,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
 {
 // norm
 {
-cur = ggml_rms_norm(ctx0, inpFF);
+cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 assert_shape_2d(cur, n_embd, N*n_batch);
 
 // cur = ffn_norm*cur
@@ -1311,7 +1313,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
 // norm
 {
 
-inpL = ggml_rms_norm(ctx0, inpL);
+inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 assert_shape_2d(inpL, n_embd, N*n_batch);
 
 // inpL = norm*inpL
@@ -1603,7 +1605,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
 struct my_llama_layer & layer = model->layers[il];
 // tensors with values necessary for backward pass are in persistent buf(-1)
 // other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
-use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch);
+use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch);
 use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
 use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
 use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
@@ -1623,7 +1625,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
 use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
 use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
 use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
-use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch);
+use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch);
 use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
 use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
 use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
@@ -1666,7 +1668,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
 }
 clr_buf(0);
 use_buf(0);
-struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t31, n_embd, N*n_batch);
+struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t31, n_embd, N*n_batch);
 struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch);
 struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch);
 use_buf(-1);

ggml-cuda.cu (+7, -6)

@@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
 }
 }
 
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
 const int row = blockIdx.x*blockDim.y + threadIdx.y;
 const int tid = threadIdx.x;
 
-const float eps = 1e-6f;
-
 float tmp = 0.0f; // partial sum for thread in warp
 
 for (int col = tid; col < ncols; col += WARP_SIZE) {
@@ -2122,10 +2120,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
 norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
 }
 
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
 GGML_ASSERT(ncols % WARP_SIZE == 0);
 const dim3 block_dims(WARP_SIZE, 1, 1);
-rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
 }
 
 static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
@@ -2876,8 +2874,11 @@ inline void ggml_cuda_op_rms_norm(
 const int64_t ne00 = src0->ne[0];
 const int64_t i01_diff = i01_high - i01_low;
 
+float eps;
+memcpy(&eps, dst->op_params, sizeof(float));
+
 // compute
-rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);
 
 (void) src1;
 (void) dst;
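
For reference, this is what the kernel computes with the passed-in eps, written as a naive single-threaded C++ sketch; the warp-level reduction of the real kernel is omitted:

#include <cmath>
#include <cstdio>

// naive per-row RMS norm: dst = x / sqrt(mean(x*x) + eps)
static void rms_norm_ref(const float * x, float * dst, int nrows, int ncols, float eps) {
    for (int row = 0; row < nrows; ++row) {
        float sum = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            const float v = x[row*ncols + col];
            sum += v*v;
        }
        const float scale = 1.0f/std::sqrt(sum/ncols + eps);
        for (int col = 0; col < ncols; ++col) {
            dst[row*ncols + col] = x[row*ncols + col]*scale;
        }
    }
}

int main() {
    const float x[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
    float y[4];
    rms_norm_ref(x, y, 1, 4, 1e-5f);
    std::printf("%f %f %f %f\n", y[0], y[1], y[2], y[3]);
    return 0;
}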

ggml-metal.m (+2, -1)

@@ -812,7 +812,8 @@ void ggml_metal_graph_compute(
 encoder = [command_buffer computeCommandEncoder];
 }
 
-const float eps = 1e-6f;
+float eps;
+memcpy(&eps, dst->op_params, sizeof(float));
 
 const int nth = 512;
 

ggml.c (+10, -6)

@@ -5781,6 +5781,7 @@ struct ggml_tensor * ggml_norm_inplace(
 static struct ggml_tensor * ggml_rms_norm_impl(
 struct ggml_context * ctx,
 struct ggml_tensor * a,
+float eps,
 bool inplace) {
 bool is_node = false;
 
@@ -5790,7 +5791,7 @@ static struct ggml_tensor * ggml_rms_norm_impl(
 
 struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-// TODO: maybe store epsilon here?
+ggml_set_op_params(result, &eps, sizeof(eps));
 
 result->op = GGML_OP_RMS_NORM;
 result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5801,14 +5802,16 @@ static struct ggml_tensor * ggml_rms_norm_impl(
 
 struct ggml_tensor * ggml_rms_norm(
 struct ggml_context * ctx,
-struct ggml_tensor * a) {
-return ggml_rms_norm_impl(ctx, a, false);
+struct ggml_tensor * a,
+float eps) {
+return ggml_rms_norm_impl(ctx, a, eps, false);
 }
 
 struct ggml_tensor * ggml_rms_norm_inplace(
 struct ggml_context * ctx,
-struct ggml_tensor * a) {
-return ggml_rms_norm_impl(ctx, a, true);
+struct ggml_tensor * a,
+float eps) {
+return ggml_rms_norm_impl(ctx, a, eps, true);
 }
 
 struct ggml_tensor * ggml_rms_norm_back(
@@ -10131,7 +10134,8 @@ static void ggml_compute_forward_rms_norm_f32(
 
 GGML_TENSOR_UNARY_OP_LOCALS;
 
-const float eps = 1e-6f; // TODO: make this a parameter
+float eps;
+memcpy(&eps, dst->op_params, sizeof(float));
 
 // TODO: optimize
 for (int64_t i03 = 0; i03 < ne03; i03++) {
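
With this change the epsilon travels through the op's parameters instead of being hard-coded per backend: ggml_rms_norm_impl() packs it with ggml_set_op_params(), and the CPU, CUDA and Metal paths memcpy it back out of dst->op_params. A self-contained sketch of that round trip (the local op_params array merely stands in for the tensor field):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    int32_t op_params[8] = { 0 }; // stands in for result->op_params

    // producer side, as in ggml_rms_norm_impl(): serialize the float into op_params
    const float eps = 1e-5f;
    std::memcpy(op_params, &eps, sizeof(eps));

    // consumer side, as in the CPU/CUDA/Metal code: read it back from dst->op_params
    float eps_out = 0.0f;
    std::memcpy(&eps_out, op_params, sizeof(float));

    std::printf("eps = %.1e\n", eps_out);
    return 0;
}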
