Commit 49006c6

llama : move random seed generation to the samplers (#9398)
* llama_sampler_penalties : clamp penalty_last_n to zero
1 parent 00ba2ff commit 49006c6

10 files changed: +92 -34 lines changed

Diff for: common/arg.cpp (+1, -6)

@@ -173,7 +173,6 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
     std::string arg;
     const std::string arg_prefix = "--";
     gpt_params & params = ctx_arg.params;
-    gpt_sampler_params & sparams = params.sparams;
 
     std::unordered_map<std::string, llama_arg *> arg_to_options;
     for (auto & opt : ctx_arg.options) {
@@ -283,10 +282,6 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
             params.kv_overrides.back().key[0] = 0;
         }
 
-    if (sparams.seed == LLAMA_DEFAULT_SEED) {
-        sparams.seed = time(NULL);
-    }
-
     return true;
 }
 
@@ -909,7 +904,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
     ).set_sparam());
     add_opt(llama_arg(
         {"-s", "--seed"}, "SEED",
-        format("RNG seed (default: %d, use random seed for < 0)", params.sparams.seed),
+        format("RNG seed (default: %u, use random seed for %u)", params.sparams.seed, LLAMA_DEFAULT_SEED),
         [](gpt_params & params, const std::string & value) {
            params.sparams.seed = std::stoul(value);
        }
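
Note on the help-string change above: seed resolution no longer happens at argument-parse time, so the option handler simply stores whatever value the user passes and the sampler decides later whether to randomize. Below is a minimal sketch (not from the commit) of why "-s -1" still requests a random seed, assuming LLAMA_DEFAULT_SEED is 0xFFFFFFFF as defined in include/llama.h:

// Hedged sketch: std::stoul("-1") wraps around to the all-ones value, which,
// truncated to uint32_t, equals the LLAMA_DEFAULT_SEED sentinel (assumed 0xFFFFFFFF).
#include <cstdint>
#include <cstdio>
#include <string>

int main() {
    const uint32_t default_seed_sentinel = 0xFFFFFFFF; // stand-in for LLAMA_DEFAULT_SEED

    const uint32_t seed = (uint32_t) std::stoul("-1"); // same conversion the option handler uses
    std::printf("seed = %u, random requested = %s\n",
                seed, seed == default_seed_sentinel ? "yes" : "no");
    return 0;
}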

Diff for: common/sampling.cpp (+4)

@@ -310,6 +310,10 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     return cur_p.data[cur_p.selected].id;
 }
 
+uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl) {
+    return llama_sampler_get_seed(gsmpl->chain);
+}
+
 // helpers
 
 llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {

Diff for: common/sampling.h (+2)

@@ -60,6 +60,8 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
 //
 llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
+uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl);
+
 // helpers
 
 // access the internal list of current candidate tokens

Diff for: examples/embedding/embedding.cpp (-2)

@@ -90,8 +90,6 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
-
     llama_backend_init();
     llama_numa_init(params.numa);

Diff for: examples/infill/infill.cpp (+3, -4)

@@ -159,8 +159,6 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
-
     LOG("%s: llama backend init\n", __func__);
     llama_backend_init();
     llama_numa_init(params.numa);
@@ -301,6 +299,9 @@ int main(int argc, char ** argv) {
             LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
         }
     }
+    smpl = gpt_sampler_init(model, sparams);
+
+    LOG_TEE("sampling seed: %u\n", gpt_sampler_get_seed(smpl));
     LOG_TEE("sampling: \n%s\n", sparams.print().c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
@@ -340,8 +341,6 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd;
 
-    smpl = gpt_sampler_init(model, sparams);
-
     while (n_remain != 0 || params.interactive) {
         // predict
         if (!embd.empty()) {

Diff for: examples/main/main.cpp (+3, -3)

@@ -191,8 +191,6 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
-
     LOG("%s: llama backend init\n", __func__);
     llama_backend_init();
     llama_numa_init(params.numa);
@@ -470,8 +468,10 @@ int main(int argc, char ** argv) {
         exit(1);
     }
 
+    LOG_TEE("sampling seed: %u\n", gpt_sampler_get_seed(smpl));
     LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());
-    LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
+    LOG_TEE("sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
+
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
 
     // group-attention state

Diff for: examples/perplexity/perplexity.cpp (-2)

@@ -2007,8 +2007,6 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
-
     llama_backend_init();
     llama_numa_init(params.numa);

Diff for: examples/server/server.cpp (+1)

@@ -1266,6 +1266,7 @@ struct server_context {
             {"n_predict",         slot.n_predict},     // Server configured n_predict
             {"model",             params.model_alias},
             {"seed",              slot.sparams.seed},
+            {"seed_cur",          slot.smpl ? gpt_sampler_get_seed(slot.smpl) : 0},
             {"temperature",       slot.sparams.temp},
             {"dynatemp_range",    slot.sparams.dynatemp_range},
             {"dynatemp_exponent", slot.sparams.dynatemp_exponent},

Diff for: include/llama.h (+4)

@@ -1127,6 +1127,10 @@ extern "C" {
                              int32_t   n_logit_bias,
                      const llama_logit_bias * logit_bias);
 
+
+    // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
+    LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
+
     /// @details Sample and accept a token from the idx-th output of the last evaluation
     //
     // Shorthand for:
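
For context, a minimal usage sketch of the new llama_sampler_get_seed API (not part of this commit); it assumes the sampler-chain helpers already declared in llama.h (llama_sampler_chain_default_params, llama_sampler_chain_init, llama_sampler_chain_add, llama_sampler_init_dist, llama_sampler_free):

// Hedged sketch: query the seed actually used by a sampler chain.
#include <cstdio>
#include "llama.h"

int main() {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    // LLAMA_DEFAULT_SEED asks the dist sampler to pick its own random seed.
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // Reports the resolved seed rather than the LLAMA_DEFAULT_SEED sentinel.
    std::printf("effective seed: %u\n", llama_sampler_get_seed(chain));

    llama_sampler_free(chain);
    return 0;
}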

Diff for: src/llama-sampling.cpp (+74, -17)

@@ -8,6 +8,7 @@
 #include <cstring>
 #include <ctime>
 #include <cfloat>
+#include <chrono>
 #include <cmath>
 #include <numeric>
 #include <random>
@@ -162,6 +163,19 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
     cur_p->size = k;
 }
 
+static uint32_t get_rng_seed(uint32_t seed) {
+    if (seed == LLAMA_DEFAULT_SEED) {
+        // use system clock if std::random_device is not a true RNG
+        static bool is_rd_prng = std::random_device().entropy() == 0;
+        if (is_rd_prng) {
+            return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
+        }
+        std::random_device rd;
+        return rd();
+    }
+    return seed;
+}
+
 // llama_sampler API
 
 const char * llama_sampler_name(const struct llama_sampler * smpl) {
@@ -387,6 +401,7 @@ struct llama_sampler * llama_sampler_init_greedy() {
 
 struct llama_sampler_dist {
     const uint32_t seed;
+          uint32_t seed_cur;
 
     std::mt19937 rng;
 };
@@ -416,7 +431,8 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
 
 static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
-    ctx->rng = std::mt19937(ctx->seed);
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
 }
 
 static void llama_sampler_dist_free(struct llama_sampler * smpl) {
@@ -433,11 +449,13 @@ static struct llama_sampler_i llama_sampler_dist_i = {
 };
 
 struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
+    auto seed_cur = get_rng_seed(seed);
     return new llama_sampler {
         /* .iface = */ &llama_sampler_dist_i,
         /* .ctx   = */ new llama_sampler_dist {
-            /* .seed = */ seed,
-            /* .rng  = */ std::mt19937(seed),
+            /* .seed     = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .rng      = */ std::mt19937(seed_cur),
         },
     };
 }
@@ -1032,6 +1050,7 @@ struct llama_sampler_mirostat {
     const int32_t n_vocab;
 
     const uint32_t seed;
+          uint32_t seed_cur;
 
     const float tau;
     const float eta;
@@ -1100,7 +1119,8 @@ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sa
 static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
     ctx->mu = 2.0f*ctx->tau;
-    ctx->rng = std::mt19937(ctx->seed);
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
 }
 
 static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
@@ -1117,16 +1137,18 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
 };
 
 struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
+    auto seed_cur = get_rng_seed(seed);
     return new llama_sampler {
         /* .iface = */ &llama_sampler_mirostat_i,
         /* .ctx   = */ new llama_sampler_mirostat {
-            /* .n_vocab = */ n_vocab,
-            /* .seed    = */ seed,
-            /* .tau     = */ tau,
-            /* .eta     = */ eta,
-            /* .m       = */ m,
-            /* .mu      = */ 2.0f*tau,
-            /* .rng     = */ std::mt19937(seed),
+            /* .n_vocab  = */ n_vocab,
+            /* .seed     = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .tau      = */ tau,
+            /* .eta      = */ eta,
+            /* .m        = */ m,
+            /* .mu       = */ 2.0f*tau,
+            /* .rng      = */ std::mt19937(seed_cur),
         },
     };
 }
@@ -1135,6 +1157,7 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
 
 struct llama_sampler_mirostat_v2 {
     const uint32_t seed;
+          uint32_t seed_cur;
 
     const float tau;
     const float eta;
@@ -1179,7 +1202,8 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
 static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
     ctx->mu = 2.0f*ctx->tau;
-    ctx->rng = std::mt19937(ctx->seed);
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
 }
 
 static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
@@ -1212,14 +1236,16 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
 };
 
 struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
+    auto seed_cur = get_rng_seed(seed);
     return new llama_sampler {
         /* .iface = */ &llama_sampler_mirostat_v2_i,
         /* .ctx   = */ new llama_sampler_mirostat_v2 {
-            /* .seed = */ seed,
-            /* .tau  = */ tau,
-            /* .eta  = */ eta,
-            /* .mu   = */ 2.0f*tau,
-            /* .rng  = */ std::mt19937(seed),
+            /* .seed     = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .tau      = */ tau,
+            /* .eta      = */ eta,
+            /* .mu       = */ 2.0f*tau,
+            /* .rng      = */ std::mt19937(seed_cur),
         },
     };
 }
@@ -1505,6 +1531,8 @@ struct llama_sampler * llama_sampler_init_penalties(
         ignore_eos = false;
     }
 
+    penalty_last_n = std::max(penalty_last_n, 0);
+
     return new llama_sampler {
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
@@ -1568,6 +1596,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
         }
     }
 }
+
 static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
     return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
@@ -1599,3 +1628,31 @@ struct llama_sampler * llama_sampler_init_logit_bias(
         },
     };
 }
+
+// utils
+
+uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
+    if (smpl->iface == &llama_sampler_dist_i) {
+        return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
+    }
+
+    if (smpl->iface == &llama_sampler_mirostat_i) {
+        return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
+    }
+
+    if (smpl->iface == &llama_sampler_mirostat_v2_i) {
+        return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
+    }
+
+    if (smpl->iface == &llama_sampler_chain_i) {
+        const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
+        for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
+            const uint32_t seed = llama_sampler_get_seed(*it);
+            if (seed != LLAMA_DEFAULT_SEED) {
+                return seed;
+            }
+        }
+    }
+
+    return LLAMA_DEFAULT_SEED;
+}
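
Behavior of the seed handling added above, in summary: an explicit seed passes through get_rng_seed unchanged, while LLAMA_DEFAULT_SEED is replaced by a value from std::random_device, falling back to the system clock when random_device reports zero entropy (i.e. it is only a deterministic PRNG). A standalone, hedged sketch that mirrors this logic with local stand-in names (not the library's symbols):

// Standalone sketch of the seed-resolution behavior; names are local stand-ins.
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <random>

static const uint32_t kDefaultSeedSentinel = 0xFFFFFFFF; // plays the role of LLAMA_DEFAULT_SEED

static uint32_t resolve_seed(uint32_t seed) {
    if (seed == kDefaultSeedSentinel) {
        // entropy() == 0 means std::random_device is not backed by a real entropy source
        static bool is_rd_prng = std::random_device().entropy() == 0;
        if (is_rd_prng) {
            return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
        }
        std::random_device rd;
        return rd();
    }
    return seed; // explicit seeds stay untouched, keeping runs reproducible
}

int main() {
    std::printf("explicit seed: %u\n", resolve_seed(42));                   // always 42
    std::printf("random seed:   %u\n", resolve_seed(kDefaultSeedSentinel)); // varies per run
    return 0;
}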
