1
1
#include " sampling.h"
2
2
3
- llama_sampling_context llama_sampling_context_init (
4
- const struct gpt_params & params,
5
- llama_grammar * grammar) {
6
- llama_sampling_context result;
3
+ struct llama_sampling_context * llama_sampling_init (const struct gpt_params & params) {
4
+ struct llama_sampling_context * result =
5
+ (struct llama_sampling_context *) malloc (sizeof (struct llama_sampling_context ));
7
6
8
- result.params = params.sampling_params ;
9
- result.grammar = grammar;
7
+ result->params = params.sampling_params ;
8
+ result->grammar = nullptr ;
9
+
10
+ // if there is a grammar, parse it
11
+ if (!params.grammar .empty ()) {
12
+ result->parsed_grammar = grammar_parser::parse (params.grammar .c_str ());
13
+
14
+ // will be empty (default) if there are parse errors
15
+ if (result->parsed_grammar .rules .empty ()) {
16
+ fprintf (stderr, " %s: failed to parse grammar\n " , __func__);
17
+ return nullptr ;
18
+ }
19
+
20
+ std::vector<const llama_grammar_element *> grammar_rules (result->parsed_grammar .c_rules ());
21
+
22
+ result->grammar = llama_grammar_init (
23
+ grammar_rules.data (),
24
+ grammar_rules.size (), result->parsed_grammar .symbol_ids .at (" root" ));
25
+ }
26
+
27
+ result->prev .resize (params.n_ctx );
10
28
11
29
return result;
12
30
}
13
31
32
+ void llama_sampling_free (struct llama_sampling_context * ctx) {
33
+ if (ctx->grammar != NULL ) {
34
+ llama_grammar_free (ctx->grammar );
35
+ }
36
+
37
+ free (ctx);
38
+ }
39
+
40
+ void llama_sampling_reset (llama_sampling_context * ctx) {
41
+ if (ctx->grammar != NULL ) {
42
+ llama_grammar_free (ctx->grammar );
43
+ }
44
+
45
+ if (!ctx->parsed_grammar .rules .empty ()) {
46
+ std::vector<const llama_grammar_element *> grammar_rules (ctx->parsed_grammar .c_rules ());
47
+
48
+ ctx->grammar = llama_grammar_init (
49
+ grammar_rules.data (),
50
+ grammar_rules.size (), ctx->parsed_grammar .symbol_ids .at (" root" ));
51
+ }
52
+
53
+ std::fill (ctx->prev .begin (), ctx->prev .end (), 0 );
54
+ ctx->cur .clear ();
55
+ }
56
+
14
57
// Sample the next token from the logits at position `idx` of `ctx_main`,
// applying (in order): logit biases, optional classifier-free guidance,
// repetition/frequency/presence penalties over the recent history, an
// optional grammar constraint, and finally greedy / mirostat / temperature
// sampling according to the params stored in `ctx_sampling`.
// Returns the sampled token id. Does NOT update the history or the grammar
// state — callers do that via llama_sampling_accept.
llama_token llama_sampling_sample(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_guidance,
        const int idx) {
    const int n_ctx   = llama_n_ctx(ctx_main);
    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

    const llama_sampling_params & params = ctx_sampling->params;

    const float   temp  = params.temp;
    const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; // <=0 means "disabled" -> full vocab

    const float   top_p           = params.top_p;
    // NOTE(review): the parameter reads between top_p and mirostat_eta fall in
    // a diff hunk gap and were reconstructed from their uses further down —
    // verify against the full file.
    const float   tfs_z           = params.tfs_z;
    const float   typical_p       = params.typical_p;
    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
    const float   repeat_penalty  = params.repeat_penalty;
    const float   alpha_presence  = params.presence_penalty;
    const float   alpha_frequency = params.frequency_penalty;
    const int     mirostat        = params.mirostat;
    const float   mirostat_tau    = params.mirostat_tau;
    const float   mirostat_eta    = params.mirostat_eta;
    const bool    penalize_nl     = params.penalize_nl;

    auto & prev = ctx_sampling->prev; // recent-token history (penalty window)
    auto & cur  = ctx_sampling->cur;  // scratch candidate buffer, reused across calls

    llama_token id = 0;

    float * logits = llama_get_logits_ith(ctx_main, idx);

    // Apply params.logit_bias map
    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
        logits[it->first] += it->second;
    }

    // rebuild the candidate list from the (biased) logits
    cur.clear();

    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }

    llama_token_data_array cur_p = { cur.data(), cur.size(), false };

    if (ctx_guidance) {
        llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_guidance, params.cfg_scale);
    }

    // apply penalties
    if (!prev.empty()) {
        // saved so the newline logit can be restored when penalize_nl is off
        const float nl_logit = logits[llama_token_nl(ctx_main)];
        const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);

        llama_sample_repetition_penalty(ctx_main, &cur_p,
                prev.data() + prev.size() - last_n_repeat,
                last_n_repeat, repeat_penalty);
        llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
                prev.data() + prev.size() - last_n_repeat,
                last_n_repeat, alpha_frequency, alpha_presence);

        if (!penalize_nl) {
            // undo any penalty applied to the newline token
            for (size_t idx = 0; idx < cur_p.size; idx++) {
                if (cur_p.data[idx].id == llama_token_nl(ctx_main)) {
                    cur_p.data[idx].logit = nl_logit;
                    break;
                }
            }
        }
    }

    // constrain candidates to what the grammar currently allows
    if (ctx_sampling->grammar != NULL) {
        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
    }

    if (temp <= 0) {
        // Greedy sampling
        id = llama_sample_token_greedy(ctx_main, &cur_p);
    } else {
        if (mirostat == 1) {
            const int mirostat_m = 100;
            llama_sample_temp(ctx_main, &cur_p, temp);
            id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
        } else if (mirostat == 2) {
            llama_sample_temp(ctx_main, &cur_p, temp);
            id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
        } else {
            // Temperature sampling: k/tail-free/typical/p filters, then temp scale
            size_t min_keep = std::max(1, params.n_probs); // keep enough for n_probs reporting
            llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep);
            llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep);
            llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep);
            llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep);
            llama_sample_temp     (ctx_main, &cur_p, temp);

            id = llama_sample_token(ctx_main, &cur_p);

            // debug dump of the top candidates, kept disabled in the diff
            // {
            //     const int n_top = 10;
            //     LOG("top %d candidates:\n", n_top);

            //     for (int i = 0; i < n_top; i++) {
            //         const llama_token id = cur_p.data[i].id;
            //         (void)id; // To avoid a warning that id is unused when logging is disabled.
            //         LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
            //     }
            // }

            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
        }
    }

    return id;
}
170
+
171
+ void llama_sampling_accept (
172
+ struct llama_sampling_context * ctx_sampling,
173
+ struct llama_context * ctx_main,
174
+ llama_token id) {
175
+ ctx_sampling->prev .erase (ctx_sampling->prev .begin ());
176
+ ctx_sampling->prev .push_back (id);
177
+
178
+ if (ctx_sampling->grammar != NULL ) {
179
+ llama_grammar_accept_token (ctx_main, ctx_sampling->grammar , id);
180
+ }
181
+ }
0 commit comments