
Commit 4a7f43f

speculative : refactor sampling
1 parent 32a67cb commit 4a7f43f

3 files changed: +213 −208 lines changed

common/sampling.cpp

+113 −59
@@ -1,27 +1,69 @@
 #include "sampling.h"
 
-llama_sampling_context llama_sampling_context_init(
-        const struct gpt_params & params,
-        llama_grammar * grammar) {
-    llama_sampling_context result;
+struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) {
+    struct llama_sampling_context * result =
+        (struct llama_sampling_context *) malloc(sizeof(struct llama_sampling_context));
 
-    result.params = params.sampling_params;
-    result.grammar = grammar;
+    result->params = params.sampling_params;
+    result->grammar = nullptr;
+
+    // if there is a grammar, parse it
+    if (!params.grammar.empty()) {
+        result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+
+        // will be empty (default) if there are parse errors
+        if (result->parsed_grammar.rules.empty()) {
+            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+            return nullptr;
+        }
+
+        std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
+
+        result->grammar = llama_grammar_init(
+                grammar_rules.data(),
+                grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
+    }
+
+    result->prev.resize(params.n_ctx);
 
     return result;
 }
 
+void llama_sampling_free(struct llama_sampling_context * ctx) {
+    if (ctx->grammar != NULL) {
+        llama_grammar_free(ctx->grammar);
+    }
+
+    free(ctx);
+}
+
+void llama_sampling_reset(llama_sampling_context * ctx) {
+    if (ctx->grammar != NULL) {
+        llama_grammar_free(ctx->grammar);
+    }
+
+    if (!ctx->parsed_grammar.rules.empty()) {
+        std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
+
+        ctx->grammar = llama_grammar_init(
+                grammar_rules.data(),
+                grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
+    }
+
+    std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
+    ctx->cur.clear();
+}
+
 llama_token llama_sampling_sample(
-        struct llama_context * ctx,
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
         struct llama_context * ctx_guidance,
-        struct llama_sampling_context & ctx_sampling,
-        const std::vector<llama_token> & last_tokens,
-        std::vector<llama_token_data> & candidates,
-        const int idx) {
-    const int n_ctx   = llama_n_ctx(ctx);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
-
-    const llama_sampling_params & params = ctx_sampling.params;
+        const int idx) {
+    const int n_ctx   = llama_n_ctx(ctx_main);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
+    const llama_sampling_params & params = ctx_sampling->params;
+
     const float temp = params.temp;
     const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
     const float top_p = params.top_p;
@@ -36,92 +78,104 @@ llama_token llama_sampling_sample(
     const float mirostat_eta = params.mirostat_eta;
     const bool penalize_nl = params.penalize_nl;
 
+    auto & prev = ctx_sampling->prev;
+    auto & cur  = ctx_sampling->cur;
+
     llama_token id = 0;
 
-    float * logits = llama_get_logits_ith(ctx, idx);
+    float * logits = llama_get_logits_ith(ctx_main, idx);
 
     // Apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
         logits[it->first] += it->second;
     }
 
-    candidates.clear();
+    cur.clear();
+
     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
     }
 
-    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
     if (ctx_guidance) {
-        llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+        llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_guidance, params.cfg_scale);
     }
 
     // apply penalties
-    if (!last_tokens.empty()) {
-        const float nl_logit = logits[llama_token_nl(ctx)];
-        const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+    if (!prev.empty()) {
+        const float nl_logit = logits[llama_token_nl(ctx_main)];
+        const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);
 
-        llama_sample_repetition_penalty(ctx, &cur_p,
-                last_tokens.data() + last_tokens.size() - last_n_repeat,
+        llama_sample_repetition_penalty(ctx_main, &cur_p,
+                prev.data() + prev.size() - last_n_repeat,
                 last_n_repeat, repeat_penalty);
-        llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
-                last_tokens.data() + last_tokens.size() - last_n_repeat,
+        llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
+                prev.data() + prev.size() - last_n_repeat,
                 last_n_repeat, alpha_frequency, alpha_presence);
 
         if (!penalize_nl) {
             for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                if (cur_p.data[idx].id == llama_token_nl(ctx_main)) {
                     cur_p.data[idx].logit = nl_logit;
                     break;
                 }
             }
         }
     }
 
-    if (ctx_sampling.grammar != NULL) {
-        llama_sample_grammar(ctx, &cur_p, ctx_sampling.grammar);
+    if (ctx_sampling->grammar != NULL) {
+        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
     if (temp <= 0) {
         // Greedy sampling
-        id = llama_sample_token_greedy(ctx, &cur_p);
+        id = llama_sample_token_greedy(ctx_main, &cur_p);
     } else {
         if (mirostat == 1) {
             const int mirostat_m = 100;
-            llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling.mirostat_mu);
+            llama_sample_temp(ctx_main, &cur_p, temp);
+            id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
         } else if (mirostat == 2) {
-            llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling.mirostat_mu);
+            llama_sample_temp(ctx_main, &cur_p, temp);
+            id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
         } else {
             // Temperature sampling
             size_t min_keep = std::max(1, params.n_probs);
-            llama_sample_top_k    (ctx, &cur_p, top_k, min_keep);
-            llama_sample_tail_free(ctx, &cur_p, tfs_z, min_keep);
-            llama_sample_typical  (ctx, &cur_p, typical_p, min_keep);
-            llama_sample_top_p    (ctx, &cur_p, top_p, min_keep);
-            llama_sample_temp(ctx, &cur_p, temp);
-
-            {
-                const int n_top = 10;
-                LOG("top %d candidates:\n", n_top);
-
-                for (int i = 0; i < n_top; i++) {
-                    const llama_token id = cur_p.data[i].id;
-                    (void)id; // To avoid a warning that id is unused when logging is disabled.
-                    LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
-                }
-            }
-
-            id = llama_sample_token(ctx, &cur_p);
-
-            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+            llama_sample_top_k    (ctx_main, &cur_p, top_k, min_keep);
+            llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
+            llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep);
+            llama_sample_top_p    (ctx_main, &cur_p, top_p, min_keep);
+            llama_sample_temp     (ctx_main, &cur_p, temp);
+
+            id = llama_sample_token(ctx_main, &cur_p);
+
+            //{
+            //    const int n_top = 10;
+            //    LOG("top %d candidates:\n", n_top);
+
+            //    for (int i = 0; i < n_top; i++) {
+            //        const llama_token id = cur_p.data[i].id;
+            //        (void)id; // To avoid a warning that id is unused when logging is disabled.
+            //        LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
+            //    }
+            //}
+
+            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
         }
     }
 
-    if (ctx_sampling.grammar != NULL) {
-        llama_grammar_accept_token(ctx, ctx_sampling.grammar, id);
-    }
-
     return id;
 }
+
+void llama_sampling_accept(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        llama_token id) {
+    ctx_sampling->prev.erase(ctx_sampling->prev.begin());
+    ctx_sampling->prev.push_back(id);
+
+    if (ctx_sampling->grammar != NULL) {
+        llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
+    }
+}
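
Note: after this refactor, llama_sampling_sample no longer advances the sampler state itself; accepting the sampled token is a separate, explicit step. A minimal sketch of the intended call sequence in an example's decode loop (the ctx/params variables and the has_next_token loop condition are assumptions for illustration, not part of this commit):

    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);

    while (has_next_token) {
        // sample from the latest logits; the candidate buffer now lives in ctx_sampling->cur
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL);

        // explicit accept step: updates ctx_sampling->prev (used by the repetition
        // penalties) and advances the grammar state, if a grammar is set
        llama_sampling_accept(ctx_sampling, ctx, id);

        // ... feed `id` back to the model via llama_decode(...) as usual ...
    }

    llama_sampling_free(ctx_sampling);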

common/sampling.h

+31 −23
@@ -2,6 +2,8 @@
 
 #include "llama.h"
 
+#include "grammar-parser.h"
+
 #include <string>
 #include <vector>
 #include <unordered_map>
@@ -35,53 +37,59 @@ typedef struct llama_sampling_params {
 } llama_sampling_params;
 
 // general sampler context
-typedef struct llama_sampling_context {
+// TODO: move to llama.h
+struct llama_sampling_context {
     // parameters that will be used for sampling
     llama_sampling_params params;
 
     // mirostat sampler state
     float mirostat_mu;
 
     llama_grammar * grammar;
-} llama_sampling_context;
+
+    // internal
+    grammar_parser::parse_state parsed_grammar;
+
+    std::vector<llama_token>      prev;
+    std::vector<llama_token_data> cur;
+};
 
 #include "common.h"
 
 // Create a new sampling context instance.
-llama_sampling_context llama_sampling_context_init(
-        const struct gpt_params & params,
-        llama_grammar * grammar = NULL);
+struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);
 
-// Reset the sampler context for the supplied sequence id (defaults to 0).
-// This is necessary to reuse a sequence id or free memory used by sequences
-// that are no longer required.
-bool llama_sampling_context_reset(
-        llama_sampling_context & ctx_sampling,
-        const llama_seq_id seq = 0);
+void llama_sampling_free(struct llama_sampling_context * ctx);
+
+// Reset the sampler context
+// - clear prev tokens
+// - reset grammar
+void llama_sampling_reset(llama_sampling_context * ctx);
 
 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
 // Note: When using multiple sequences, it is the caller's responsibility to call
-// llama_sampling_context_reset when a sequence ends
+// llama_sampling_reset when a sequence ends
 //
 // required:
-// - ctx:          context to use for sampling
+// - ctx_main:     context to use for sampling
 // - ctx_sampling: sampling-specific context
 //
 // optional:
-// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
-// - last_tokens:  needed for repetition penalty, ignore if empty
-// - idx:          sample from llama_get_logits_ith(ctx, idx)
-// - seq:          sequence id to associate sampler state with
+// - ctx_guidance: context to use for guidance
+// - idx:          sample from llama_get_logits_ith(ctx, idx)
 //
 // returns:
 // - token:      sampled token
 // - candidates: vector of candidate tokens
 //
 llama_token llama_sampling_sample(
-        struct llama_context * ctx,
-        struct llama_context * ctx_guidance,
-        struct llama_sampling_context & ctx_sampling,
-        const std::vector<llama_token> & last_tokens,
-        std::vector<llama_token_data> & candidates,
-        const int idx = 0);
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_guidance,
+        int idx = 0);
+
+void llama_sampling_accept(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        llama_token id);
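
Note: the sampler state (prev tokens, candidate buffer, mirostat mu, parsed grammar) is now self-contained behind a heap-allocated context, so each decoding branch can own its own instance. A hedged sketch of how the speculative example (the third changed file, not shown in this view) can keep separate draft and target state; the _tgt/_dft names are illustrative assumptions:

    // one sampling context per model, each with independent prev/grammar state
    struct llama_sampling_context * ctx_sampling_tgt = llama_sampling_init(params);
    struct llama_sampling_context * ctx_sampling_dft = llama_sampling_init(params);

    // when a drafted branch is discarded, llama_sampling_reset clears it cheaply:
    // the grammar is re-created from the already-parsed rules (no re-parsing of
    // the grammar text) and the prev-token window is zeroed
    llama_sampling_reset(ctx_sampling_dft);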
