Commit 4d54281
llama_sampling_sample with default args is more naively usable
* Batches populated by either llama_batch_get_one or llama_batch_add now work with default args.
* Previously, only get_one worked with the default argument; with add, callers usually had to pass the last index where logits[idx] == true.
* This hopefully encourages the use of llama_batch_add by giving expected results with default arguments.
* Believed to work with any currently well-behaved program.
* The default arg now works for both cases (previously it gave strange results in the add case).
* Any non-negative index is unaffected and behaves as before; negative arguments were previously invalid.
1 parent c6a1f52 commit 4d54281
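
For context, a minimal sketch of the scenario this commit targets: a batch built with llama_batch_add that requests logits only for its last token, then sampled with default arguments. This assumes the common-library helpers (llama_batch_init, llama_batch_add, llama_decode, llama_batch_free) behave as in this tree; error handling is omitted.

#include "common.h"
#include "sampling.h"

// Build a batch via llama_batch_add, requesting logits only for the last
// token, then sample with default arguments.
static llama_token sample_next(llama_context * ctx,
                               llama_sampling_context * ctx_sampling,
                               const std::vector<llama_token> & prompt) {
    llama_batch batch = llama_batch_init((int32_t) prompt.size(), 0, 1);
    for (size_t i = 0; i < prompt.size(); ++i) {
        // logits requested only at the final position
        llama_batch_add(batch, prompt[i], (llama_pos) i, { 0 }, i == prompt.size() - 1);
    }
    llama_decode(ctx, batch);

    // Previously the default idx (0) pointed at a position with no logits;
    // callers had to pass batch.n_tokens - 1 explicitly. The new default
    // idx = -1 resolves to the last logit automatically.
    llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);

    llama_batch_free(batch);
    return id;
}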

File tree

3 files changed: +10 -4 lines changed

common/sampling.cpp

Lines changed: 7 additions & 1 deletion
@@ -164,7 +164,7 @@ static llama_token llama_sampling_sample_impl(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx,
+                  int idx,
                   bool is_resampling) { // Add a parameter to indicate if we are resampling
     const llama_sampling_params & params = ctx_sampling->params;
 
@@ -173,6 +173,12 @@ static llama_token llama_sampling_sample_impl(
     const float mirostat_tau = params.mirostat_tau;
     const float mirostat_eta = params.mirostat_eta;
 
+    if (idx == -1) {
+        const int32_t last_idx = llama_get_logits_last_idx(ctx_main);
+        GGML_ASSERT(last_idx >= 0);
+        idx = last_idx;
+    }
+
     std::vector<float> original_logits;
     auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
     if (!is_resampling) {
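
The added block resolves the -1 sentinel to the last valid logits index before sampling proceeds. Roughly, it moves the following caller-side bookkeeping into the library; a sketch, assuming llama_batch's logits field marks which positions produced logits, as in this tree:

// Before this change, llama_batch_add users had to locate the last index
// where logits[i] == true and pass it explicitly.
int last_logits_idx = -1;
for (int i = 0; i < batch.n_tokens; ++i) {
    if (batch.logits[i]) {
        last_logits_idx = i;
    }
}
llama_token id = llama_sampling_sample(ctx_sampling, ctx_main, nullptr, last_logits_idx);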

common/sampling.h

Lines changed: 2 additions & 2 deletions
@@ -119,7 +119,7 @@ std::string llama_sampling_order_print(const llama_sampling_params & params);
 //
 // optional:
 // - ctx_cfg: context to use for classifier-free guidance
-// - idx: sample from llama_get_logits_ith(ctx, idx)
+// - idx: sample from llama_get_logits_ith(ctx, idx), -1 for last logit
 //
 // returns:
 // - token: sampled token
@@ -129,7 +129,7 @@ llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
         struct llama_context * ctx_cfg,
-        int idx = 0);
+        int idx = -1);
 
 // Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
 llama_token_data_array llama_sampling_prepare(
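
Since any non-negative index behaves exactly as before, existing call sites are unaffected; only the default changes. A sketch of the equivalence, assuming the batch's only requested logits sit at its final position:

// Explicit index, as before:
llama_token a = llama_sampling_sample(ctx_sampling, ctx_main, nullptr, batch.n_tokens - 1);
// New default: idx = -1 resolves to the same last logit.
llama_token b = llama_sampling_sample(ctx_sampling, ctx_main, nullptr);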

llama.cpp

Lines changed: 1 addition & 1 deletion
@@ -15543,7 +15543,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
 int32_t llama_get_logits_last_idx(struct llama_context * ctx) {
     llama_synchronize(ctx);
 
-    if (ctx->cparams.n_batch == 1 && ctx->n_outputs == 1) {
+    if ((ctx->cparams.n_batch == 1) && (ctx->n_outputs == 1)) {
         // trivial case, one input/output, return it.
         return 0;
     }
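
The hunk itself only adds parentheses for readability; behavior is unchanged. In the trivial case shown (n_batch == 1 with a single output) the helper returns index 0. A sketch of the contract implied by this hunk, pairing it with the existing llama_get_logits_ith:

// Resolve the last logits index, then fetch those logits.
const int32_t last = llama_get_logits_last_idx(ctx);  // 0 in the trivial case
GGML_ASSERT(last >= 0);
float * logits = llama_get_logits_ith(ctx, last);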
