#include "sampling.h"

3- llama_sampling_context llama_sampling_context_init (
4- const struct gpt_params & params,
5- llama_grammar * grammar) {
6- llama_sampling_context result;
3+ struct llama_sampling_context * llama_sampling_init (const struct gpt_params & params) {
4+ struct llama_sampling_context * result =
5+ (struct llama_sampling_context *) malloc (sizeof (struct llama_sampling_context ));
76
8- result.params = params.sampling_params ;
9- result.grammar = grammar;
7+ result->params = params.sampling_params ;
8+ result->grammar = nullptr ;
9+
10+ // if there is a grammar, parse it
11+ if (!params.grammar .empty ()) {
12+ result->parsed_grammar = grammar_parser::parse (params.grammar .c_str ());
13+
14+ // will be empty (default) if there are parse errors
15+ if (result->parsed_grammar .rules .empty ()) {
16+ fprintf (stderr, " %s: failed to parse grammar\n " , __func__);
17+ return nullptr ;
18+ }
19+
20+ std::vector<const llama_grammar_element *> grammar_rules (result->parsed_grammar .c_rules ());
21+
22+ result->grammar = llama_grammar_init (
23+ grammar_rules.data (),
24+ grammar_rules.size (), result->parsed_grammar .symbol_ids .at (" root" ));
25+ }
26+
27+ result->prev .resize (params.n_ctx );
1028
1129 return result;
1230}
1331
32+ void llama_sampling_free (struct llama_sampling_context * ctx) {
33+ if (ctx->grammar != NULL ) {
34+ llama_grammar_free (ctx->grammar );
35+ }
36+
37+ free (ctx);
38+ }
39+
40+ void llama_sampling_reset (llama_sampling_context * ctx) {
41+ if (ctx->grammar != NULL ) {
42+ llama_grammar_free (ctx->grammar );
43+ }
44+
45+ if (!ctx->parsed_grammar .rules .empty ()) {
46+ std::vector<const llama_grammar_element *> grammar_rules (ctx->parsed_grammar .c_rules ());
47+
48+ ctx->grammar = llama_grammar_init (
49+ grammar_rules.data (),
50+ grammar_rules.size (), ctx->parsed_grammar .symbol_ids .at (" root" ));
51+ }
52+
53+ std::fill (ctx->prev .begin (), ctx->prev .end (), 0 );
54+ ctx->cur .clear ();
55+ }
56+
1457llama_token llama_sampling_sample (
15- struct llama_context * ctx,
58+ struct llama_sampling_context * ctx_sampling,
59+ struct llama_context * ctx_main,
1660 struct llama_context * ctx_guidance,
17- struct llama_sampling_context & ctx_sampling,
18- const std::vector<llama_token> & last_tokens,
19- std::vector<llama_token_data> & candidates,
20- const int idx) {
21- const int n_ctx = llama_n_ctx (ctx);
22- const int n_vocab = llama_n_vocab (llama_get_model (ctx));
23-
24- const llama_sampling_params & params = ctx_sampling.params ;
61+ const int idx) {
62+ const int n_ctx = llama_n_ctx (ctx_main);
63+ const int n_vocab = llama_n_vocab (llama_get_model (ctx_main));
64+
65+ const llama_sampling_params & params = ctx_sampling->params ;
66+
2567 const float temp = params.temp ;
2668 const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k ;
2769 const float top_p = params.top_p ;
@@ -36,92 +78,104 @@ llama_token llama_sampling_sample(
3678 const float mirostat_eta = params.mirostat_eta ;
3779 const bool penalize_nl = params.penalize_nl ;
3880
81+ auto & prev = ctx_sampling->prev ;
82+ auto & cur = ctx_sampling->cur ;
83+
3984 llama_token id = 0 ;
4085
41- float * logits = llama_get_logits_ith (ctx , idx);
86+ float * logits = llama_get_logits_ith (ctx_main , idx);
4287
4388 // Apply params.logit_bias map
4489 for (auto it = params.logit_bias .begin (); it != params.logit_bias .end (); it++) {
4590 logits[it->first ] += it->second ;
4691 }
4792
48- candidates.clear ();
93+ cur.clear ();
94+
4995 for (llama_token token_id = 0 ; token_id < n_vocab; token_id++) {
50- candidates .emplace_back (llama_token_data{token_id, logits[token_id], 0 .0f });
96+ cur .emplace_back (llama_token_data{token_id, logits[token_id], 0 .0f });
5197 }
5298
53- llama_token_data_array cur_p = { candidates .data (), candidates .size (), false };
99+ llama_token_data_array cur_p = { cur .data (), cur .size (), false };
54100
55101 if (ctx_guidance) {
56- llama_sample_classifier_free_guidance (ctx , &cur_p, ctx_guidance, params.cfg_scale );
102+ llama_sample_classifier_free_guidance (ctx_main , &cur_p, ctx_guidance, params.cfg_scale );
57103 }
58104
59105 // apply penalties
60- if (!last_tokens .empty ()) {
61- const float nl_logit = logits[llama_token_nl (ctx )];
62- const int last_n_repeat = std::min (std::min ((int )last_tokens .size (), repeat_last_n), n_ctx);
106+ if (!prev .empty ()) {
107+ const float nl_logit = logits[llama_token_nl (ctx_main )];
108+ const int last_n_repeat = std::min (std::min ((int )prev .size (), repeat_last_n), n_ctx);
63109
64- llama_sample_repetition_penalty (ctx , &cur_p,
65- last_tokens .data () + last_tokens .size () - last_n_repeat,
110+ llama_sample_repetition_penalty (ctx_main , &cur_p,
111+ prev .data () + prev .size () - last_n_repeat,
66112 last_n_repeat, repeat_penalty);
67- llama_sample_frequency_and_presence_penalties (ctx , &cur_p,
68- last_tokens .data () + last_tokens .size () - last_n_repeat,
113+ llama_sample_frequency_and_presence_penalties (ctx_main , &cur_p,
114+ prev .data () + prev .size () - last_n_repeat,
69115 last_n_repeat, alpha_frequency, alpha_presence);
70116
71117 if (!penalize_nl) {
72118 for (size_t idx = 0 ; idx < cur_p.size ; idx++) {
73- if (cur_p.data [idx].id == llama_token_nl (ctx )) {
119+ if (cur_p.data [idx].id == llama_token_nl (ctx_main )) {
74120 cur_p.data [idx].logit = nl_logit;
75121 break ;
76122 }
77123 }
78124 }
79125 }
80126
81- if (ctx_sampling. grammar != NULL ) {
82- llama_sample_grammar (ctx , &cur_p, ctx_sampling. grammar );
127+ if (ctx_sampling-> grammar != NULL ) {
128+ llama_sample_grammar (ctx_main , &cur_p, ctx_sampling-> grammar );
83129 }
84130
85131 if (temp <= 0 ) {
86132 // Greedy sampling
87- id = llama_sample_token_greedy (ctx , &cur_p);
133+ id = llama_sample_token_greedy (ctx_main , &cur_p);
88134 } else {
89135 if (mirostat == 1 ) {
90136 const int mirostat_m = 100 ;
91- llama_sample_temp (ctx , &cur_p, temp);
92- id = llama_sample_token_mirostat (ctx , &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling. mirostat_mu );
137+ llama_sample_temp (ctx_main , &cur_p, temp);
138+ id = llama_sample_token_mirostat (ctx_main , &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling-> mirostat_mu );
93139 } else if (mirostat == 2 ) {
94- llama_sample_temp (ctx , &cur_p, temp);
95- id = llama_sample_token_mirostat_v2 (ctx , &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling. mirostat_mu );
140+ llama_sample_temp (ctx_main , &cur_p, temp);
141+ id = llama_sample_token_mirostat_v2 (ctx_main , &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling-> mirostat_mu );
96142 } else {
97143 // Temperature sampling
98144 size_t min_keep = std::max (1 , params.n_probs );
99- llama_sample_top_k (ctx , &cur_p, top_k, min_keep);
100- llama_sample_tail_free (ctx , &cur_p, tfs_z, min_keep);
101- llama_sample_typical (ctx , &cur_p, typical_p, min_keep);
102- llama_sample_top_p (ctx , &cur_p, top_p, min_keep);
103- llama_sample_temp (ctx , &cur_p, temp);
104-
105- {
106- const int n_top = 10 ;
107- LOG ( " top %d candidates: \n " , n_top);
108-
109- for ( int i = 0 ; i < n_top; i++) {
110- const llama_token id = cur_p. data [i]. id ;
111- ( void )id; // To avoid a warning that id is unused when logging is disabled.
112- LOG ( " - %5d: '%12s' (%.3f) \n " , id, llama_token_to_piece (ctx, id). c_str (), cur_p.data [i].p ) ;
113- }
114- }
115-
116- id = llama_sample_token (ctx, &cur_p);
117-
118- LOG (" sampled token: %5d: '%s'\n " , id, llama_token_to_piece (ctx , id).c_str ());
145+ llama_sample_top_k (ctx_main , &cur_p, top_k, min_keep);
146+ llama_sample_tail_free (ctx_main , &cur_p, tfs_z, min_keep);
147+ llama_sample_typical (ctx_main , &cur_p, typical_p, min_keep);
148+ llama_sample_top_p (ctx_main , &cur_p, top_p, min_keep);
149+ llama_sample_temp (ctx_main , &cur_p, temp);
150+
151+ id = llama_sample_token (ctx_main, &cur_p);
152+
153+ // {
154+ // const int n_top = 10;
155+ // LOG("top %d candidates:\n", n_top);
156+
157+ // for (int i = 0; i < n_top; i++) {
158+ // const llama_token id = cur_p.data[i].id ;
159+ // (void)id; // To avoid a warning that id is unused when logging is disabled.
160+ // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
161+ // }
162+ // }
163+
164+ LOG (" sampled token: %5d: '%s'\n " , id, llama_token_to_piece (ctx_main , id).c_str ());
119165 }
120166 }
121167
122- if (ctx_sampling.grammar != NULL ) {
123- llama_grammar_accept_token (ctx, ctx_sampling.grammar , id);
124- }
125-
126168 return id;
127169}
170+
171+ void llama_sampling_accept (
172+ struct llama_sampling_context * ctx_sampling,
173+ struct llama_context * ctx_main,
174+ llama_token id) {
175+ ctx_sampling->prev .erase (ctx_sampling->prev .begin ());
176+ ctx_sampling->prev .push_back (id);
177+
178+ if (ctx_sampling->grammar != NULL ) {
179+ llama_grammar_accept_token (ctx_main, ctx_sampling->grammar , id);
180+ }
181+ }