llama_sampling_sample with default args is more naively usable
* Batches populated by either llama_batch_get_one or llama_batch_add now work with the default argument (see the usage sketch after the change summary below)
  * Previously only batches from get_one worked with the default argument
  * Previously batches from add usually had to pass the last index where logits[idx] == true
* This hopefully encourages the use of llama_batch_add, by giving expected results with default arguments
* Believed to be safe for any currently well-behaved program
  * The default argument now works for both cases (previously it gave strange results in the add case)
  * Any non-negative index is unaffected and behaves as before
  * Negative arguments were previously invalid, so no existing caller changes behavior
TheFlipbook committed Apr 7, 2024
1 parent c6a1f52 commit 4d54281
Showing 3 changed files with 10 additions and 4 deletions.
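
For orientation, here is a minimal usage sketch of the add pattern the commit message describes. It is illustrative only: it assumes an initialized llama_context (ctx), a llama_sampling_context (ctx_sampling), and a tokenized prompt (tokens), none of which appear in this diff.

    // populate a batch with llama_batch_add, requesting logits only for
    // the final prompt token
    llama_batch batch = llama_batch_init(512, 0, 1);
    for (size_t i = 0; i < tokens.size(); ++i) {
        llama_batch_add(batch, tokens[i], (llama_pos) i, { 0 }, i == tokens.size() - 1);
    }
    llama_decode(ctx, batch);

    // idx now defaults to -1, which resolves to the last logit; the caller
    // no longer needs to track which batch index had logits == true
    llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);

With llama_batch_get_one the same final call works unchanged, since that path produces a single output and the old default of 0 already happened to be correct there.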
8 changes: 7 additions & 1 deletion common/sampling.cpp
@@ -164,7 +164,7 @@ static llama_token llama_sampling_sample_impl(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx,
+                  int idx,
                   bool is_resampling) {  // Add a parameter to indicate if we are resampling
     const llama_sampling_params & params = ctx_sampling->params;

@@ -173,6 +173,12 @@ static llama_token llama_sampling_sample_impl(
     const float mirostat_tau = params.mirostat_tau;
     const float mirostat_eta = params.mirostat_eta;

+    if (idx == -1) {
+        const int32_t last_idx = llama_get_logits_last_idx(ctx_main);
+        GGML_ASSERT(last_idx >= 0);
+        idx = last_idx;
+    }
+
     std::vector<float> original_logits;
     auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
     if (!is_resampling) {
4 changes: 2 additions & 2 deletions common/sampling.h
@@ -119,7 +119,7 @@ std::string llama_sampling_order_print(const llama_sampling_params & params);
 //
 // optional:
 // - ctx_cfg: context to use for classifier-free guidance
-// - idx:     sample from llama_get_logits_ith(ctx, idx)
+// - idx:     sample from llama_get_logits_ith(ctx, idx), -1 for last logit
 //
 // returns:
 // - token: sampled token
@@ -129,7 +129,7 @@ llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
         struct llama_context * ctx_cfg,
-        int idx = 0);
+        int idx = -1);

 // Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
 llama_token_data_array llama_sampling_prepare(
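
To make the updated header comment concrete, a hedged sketch of the two call styles, assuming the same hypothetical ctx and ctx_sampling as in the earlier example:

    // explicit non-negative index: unchanged behavior, samples from
    // llama_get_logits_ith(ctx, 3)
    llama_token id_at_3 = llama_sampling_sample(ctx_sampling, ctx, nullptr, 3);

    // index omitted: idx defaults to -1 and the last logit is used
    llama_token id_last = llama_sampling_sample(ctx_sampling, ctx, nullptr);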
2 changes: 1 addition & 1 deletion llama.cpp
@@ -15543,7 +15543,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
 int32_t llama_get_logits_last_idx(struct llama_context * ctx) {
     llama_synchronize(ctx);

-    if (ctx->cparams.n_batch == 1 && ctx->n_outputs == 1) {
+    if ((ctx->cparams.n_batch == 1) && (ctx->n_outputs == 1)) {
         // trivial case, one input/output, return it.
         return 0;
     }
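
The helper touched above can also be called directly by application code. A small sketch (same hypothetical ctx as before; note the non-trivial branch of the function is collapsed in this hunk):

    // resolve the index of the last logits row written by the last decode
    const int32_t last_idx = llama_get_logits_last_idx(ctx);
    if (last_idx >= 0) {
        // this is the same row llama_sampling_sample reads when idx == -1
        const float * logits = llama_get_logits_ith(ctx, last_idx);
        // ... inspect or sample from logits here
    }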
