18 changes: 12 additions & 6 deletions common/common.cpp
@@ -1005,15 +1005,21 @@ struct common_init_result common_init_from_params(common_params & params) {
         params.sampling.ignore_eos = false;
     }
 
-    if (params.sampling.ignore_eos) {
-        for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-            if (llama_vocab_is_eog(vocab, i)) {
-                LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
-                params.sampling.logit_bias.push_back({i, -INFINITY});
-            }
+    // initialize once
+    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+        if (llama_vocab_is_eog(vocab, i)) {
+            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
+            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
         }
     }
 
+    if (params.sampling.ignore_eos) {
+        // add EOG biases to the active set of logit biases
+        params.sampling.logit_bias.insert(
+            params.sampling.logit_bias.end(),
+            params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
+    }
+
     if (params.sampling.penalty_last_n == -1) {
         LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
         params.sampling.penalty_last_n = llama_n_ctx(lctx);
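
The common.cpp change boils down to a precompute-once, splice-on-demand pattern: the EOG (end-of-generation) logit biases are collected a single time at model initialization and appended to the active logit_bias list only when ignore_eos is requested, instead of rescanning the whole vocabulary inside the ignore_eos branch. A minimal self-contained sketch of that pattern follows; the logit_bias struct and the is_eog stub are simplified stand-ins for llama.h's llama_logit_bias and llama_vocab_is_eog, so the example compiles on its own and is not the project's actual code.

// sketch: precompute EOG biases once, splice them in only when ignore_eos is set
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

struct logit_bias { int32_t token; float bias; }; // stand-in for llama_logit_bias

// stub: pretend tokens 2 and 7 are end-of-generation tokens
static bool is_eog(int32_t token) { return token == 2 || token == 7; }

int main() {
    const int32_t n_vocab = 16;

    // initialize once (mirrors the new loop in common_init_from_params)
    std::vector<logit_bias> logit_bias_eog;
    for (int32_t i = 0; i < n_vocab; i++) {
        if (is_eog(i)) {
            logit_bias_eog.push_back({i, -INFINITY});
        }
    }

    std::vector<logit_bias> active_bias; // per-run set of logit biases
    const bool ignore_eos = true;        // e.g. from the CLI or a request

    if (ignore_eos) {
        // add EOG biases to the active set of logit biases
        active_bias.insert(active_bias.end(), logit_bias_eog.begin(), logit_bias_eog.end());
    }

    for (const auto & lb : active_bias) {
        std::printf("token %d -> bias %f\n", lb.token, lb.bias);
    }
    return 0;
}
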
3 changes: 2 additions & 1 deletion common/common.h
@@ -177,7 +177,8 @@ struct common_params_sampling {
     std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
     std::set<llama_token> preserved_tokens;
 
-    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
+    std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
+    std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
 
     // print the parameters into a string
     std::string print() const;
11 changes: 3 additions & 8 deletions tools/server/server.cpp
@@ -473,12 +473,9 @@ struct server_task {
 
             params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
             if (params.sampling.ignore_eos) {
-                for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-                    if (llama_vocab_is_eog(vocab, i)) {
-                        //SRV_DBG("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(ctx, i).c_str(), -INFINITY);
-                        params.sampling.logit_bias.push_back({i, -INFINITY});
-                    }
-                }
+                params.sampling.logit_bias.insert(
+                    params.sampling.logit_bias.end(),
+                    defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
             }
         }
 
@@ -1906,7 +1903,6 @@ struct server_context {
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
-    bool has_eos_token = false;
 
     int32_t n_ctx; // total context for all clients / slots
 
@@ -1965,7 +1961,6 @@
         n_ctx = llama_n_ctx(ctx);
 
         add_bos_token = llama_vocab_get_add_bos(vocab);
-        has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
 
        if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) {
            SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
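
On the server side, a request that sets ignore_eos now splices in the EOG bias list precomputed at startup (defaults.sampling.logit_bias_eog) instead of walking the vocabulary on every request, and the has_eos_token member and its initialization are removed in the same change. The reason a -INFINITY bias implements "ignore EOG" is that once it is added to a token's logit, that token can never be selected. A tiny sketch with made-up toy logits (greedy pick shown for simplicity; not the server's actual sampling code):

// sketch: a -INFINITY logit bias makes the biased token unselectable
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    std::vector<float> logits = {1.0f, 0.5f, 3.0f, 0.1f}; // toy values; token 2 plays the EOG role

    const std::vector<std::pair<int, float>> logit_bias = {{2, -INFINITY}};
    for (const auto & lb : logit_bias) {
        logits[lb.first] += lb.second; // 3.0f + (-inf) == -inf
    }

    const int best = (int) (std::max_element(logits.begin(), logits.end()) - logits.begin());
    std::printf("picked token %d\n", best); // never token 2, regardless of its original logit
    return 0;
}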