
Commit 772f88f

Committed Oct 8, 2023
Fix mirostat state when using multiple sequences
1 parent b0ec521 commit 772f88f

File tree

3 files changed: +31 -10 lines

common/common.cpp (+15 -6)
@@ -940,10 +940,11 @@ llama_token llama_sample_token(
                   struct llama_context * ctx,
                   struct llama_context * ctx_guidance,
                   struct llama_grammar * grammar,
-                  const struct gpt_params & params,
+                  struct gpt_params & params,
                   const std::vector<llama_token> & last_tokens,
                   std::vector<llama_token_data> & candidates,
-                  int idx) {
+                  const int idx,
+                  llama_seq_id seq) {
     const int n_ctx = llama_n_ctx(ctx);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
 
@@ -1011,15 +1012,23 @@ llama_token llama_sample_token(
         // Greedy sampling
         id = llama_sample_token_greedy(ctx, &cur_p);
     } else {
+        float * mirostat_mu = NULL;
+        if (mirostat > 0) {
+            seq = std::max(0, seq); // Deal with people passing -1 or something.
+            auto mu_it = params.sampler_state.find(seq);
+            if (mu_it == params.sampler_state.end()) {
+                const llama_sampler_state new_state = { 2.0f * mirostat_tau };
+                mu_it = params.sampler_state.insert({seq, new_state}).first;
+            }
+            mirostat_mu = &mu_it->second.mirostat_mu;
+        }
         if (mirostat == 1) {
-            static float mirostat_mu = 2.0f * mirostat_tau;
             const int mirostat_m = 100;
             llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, mirostat_mu);
         } else if (mirostat == 2) {
-            static float mirostat_mu = 2.0f * mirostat_tau;
             llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_mu);
         } else {
             // Temperature sampling
             size_t min_keep = std::max(1, params.n_probs);

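The substance of the fix is the find-or-insert block above: mirostat's mu now lives in a per-sequence map entry, seeded to 2.0f * mirostat_tau on first use, instead of a function-local static that every sequence shared and that therefore bled sampler state between parallel generations. Below is a minimal standalone sketch of that pattern; the llama_seq_id and llama_sampler_state definitions are simplified stand-ins so the snippet compiles on its own, not the real llama.cpp declarations.

#include <algorithm>
#include <cstdio>
#include <unordered_map>

// Simplified stand-ins for the llama.cpp types used by the patch.
typedef int llama_seq_id;
typedef struct llama_sampler_state {
    float mirostat_mu; // mirostat sampler state
} llama_sampler_state;

// Find-or-insert the mirostat state for a sequence, mirroring the hunk above:
// the first lookup for a sequence seeds mu with 2 * tau; later lookups reuse
// whatever value the sampler left behind on previous tokens.
float * get_mirostat_mu(std::unordered_map<llama_seq_id, llama_sampler_state> & state,
                        llama_seq_id seq, float mirostat_tau) {
    seq = std::max(0, seq); // deal with callers passing -1
    auto it = state.find(seq);
    if (it == state.end()) {
        const llama_sampler_state new_state = { 2.0f * mirostat_tau };
        it = state.insert({seq, new_state}).first;
    }
    return &it->second.mirostat_mu;
}

int main() {
    std::unordered_map<llama_seq_id, llama_sampler_state> sampler_state;
    float * mu0 = get_mirostat_mu(sampler_state, 0, 5.0f);
    float * mu1 = get_mirostat_mu(sampler_state, 1, 5.0f);
    *mu0 -= 1.0f; // pretend the sampler adjusted sequence 0's mu
    // prints "seq 0 mu: 9.0, seq 1 mu: 10.0" -- the sequences no longer share state
    printf("seq 0 mu: %.1f, seq 1 mu: %.1f\n", *mu0, *mu1);
    return 0;
}

Note that std::unordered_map is node-based, so inserting the entry for sequence 1 does not invalidate the pointer into sequence 0's entry; the diff relies on the same guarantee when it hands &mu_it->second.mirostat_mu to the sampler.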
common/common.h (+14 -2)
@@ -33,6 +33,10 @@
 //
 int32_t get_num_physical_cores();
 
+typedef struct llama_sampler_state {
+    float mirostat_mu; // mirostat sampler state
+} llama_sampler_state;
+
 struct gpt_params {
     uint32_t seed = -1; // RNG seed
     int32_t n_threads = get_num_physical_cores();
@@ -54,6 +58,9 @@ struct gpt_params {
     float rope_freq_base = 0.0f; // RoPE base frequency
     float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
 
+    // per sequence sampler state
+    std::unordered_map<llama_seq_id, llama_sampler_state> sampler_state;
+
     // sampling parameters
     int32_t top_k = 40; // <= 0 to use vocab size
     float top_p = 0.95f; // 1.0 = disabled
@@ -186,6 +193,9 @@ std::string llama_detokenize_bpe(
 
 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
+// Note: When using multiple sequences, it is the caller's responsibility to delete
+// the item in params.sampler_state when a sequence ends and samplers that rely on
+// state are being used.
 //
 // required:
 //  - ctx: context to use for sampling
@@ -196,6 +206,7 @@ std::string llama_detokenize_bpe(
 //  - grammar: grammar to use for sampling, ignore if NULL
 //  - last_tokens: needed for repetition penalty, ignore if empty
 //  - idx: sample from llama_get_logits_ith(ctx, idx)
+//  - seq: sequence id to associate sampler state with (currently only used by mirostat)
 //
 // returns:
 //  - token: sampled token
@@ -205,10 +216,11 @@ llama_token llama_sample_token(
                   struct llama_context * ctx,
                   struct llama_context * ctx_guidance,
                   struct llama_grammar * grammar,
-                  const struct gpt_params & params,
+                  struct gpt_params & params,
                   const std::vector<llama_token> & last_tokens,
                   std::vector<llama_token_data> & candidates,
-                  int idx = 0);
+                  const int idx = 0,
+                  llama_seq_id seq = 0);
 
 //
 // YAML utils

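The new header comment makes the ownership rule explicit: llama_sample_token creates a sampler_state entry on demand but never removes one, so the caller must erase it when a sequence finishes. A hedged sketch of that lifecycle follows; run_sequence is a hypothetical driver written for illustration, not part of the patch, and reuses the simplified stand-in types from the sketch above.

#include <unordered_map>

typedef int llama_seq_id; // simplified stand-in
typedef struct llama_sampler_state {
    float mirostat_mu; // mirostat sampler state
} llama_sampler_state;

// Hypothetical driver illustrating the ownership rule from the header comment:
// the map entry for a sequence lives exactly as long as the sequence itself.
void run_sequence(std::unordered_map<llama_seq_id, llama_sampler_state> & sampler_state,
                  llama_seq_id seq) {
    // ... each sampling call for `seq` would create and then update
    //     sampler_state[seq], as in the common.cpp hunk above ...

    // Once the sequence ends, drop its state. Without this, a later sequence
    // that happens to reuse the same id would inherit a stale mirostat mu.
    sampler_state.erase(seq);
}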
examples/parallel/parallel.cpp (+2 -2)
@@ -339,7 +339,7 @@ int main(int argc, char ** argv) {
             //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
             //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
-            const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i);
+            const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
 
             if (client.n_decoded == 1) {
                 // start measuring generation time after the first token to make sure all concurrent clients
@@ -384,7 +384,7 @@ int main(int argc, char ** argv) {
 
                 n_total_prompt += client.n_prompt;
                 n_total_gen    += client.n_decoded;
-
+                params.sampler_state.erase(client.seq_id);
                 client.seq_id = -1;
             }
 
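This erase is exactly the cleanup the header comment asks for: the parallel example recycles its client slots, and with them sequence ids, for new requests, so dropping the entry when a client finishes keeps a reused id from resuming with the previous sequence's mirostat mu instead of a fresh 2.0f * mirostat_tau.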