Skip to content

Commit

Permalink
server : add n_indent parameter for line indentation requirement (gge…
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov authored and arthw committed Nov 18, 2024
1 parent e2d2676 commit aa9a7ef
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 7 deletions.
2 changes: 2 additions & 0 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ node index.js

`n_predict`: Set the maximum number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. Default: `-1`, where `-1` is infinity.

`n_indent`: Specify the minimum line indentation for the generated text in number of whitespace characters. Useful for code completion tasks. Default: `0`

`n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded. The number excludes the BOS token.
By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to retain all tokens from the prompt.

Expand Down
54 changes: 47 additions & 7 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ struct slot_params {
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters

int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
Expand Down Expand Up @@ -173,6 +174,8 @@ struct server_slot {
std::vector<llama_token> prompt_tokens;
std::vector<llama_token> extra_tokens;

size_t last_nl_pos = 0;

std::string generated_text;
std::vector<llama_token> cache_tokens;
std::vector<completion_token_output> generated_token_probs;
Expand Down Expand Up @@ -215,6 +218,7 @@ struct server_slot {
SLT_DBG(*this, "%s", "\n");

n_prompt_tokens = 0;
last_nl_pos = 0;
generated_text = "";
has_new_line = false;
truncated = false;
Expand Down Expand Up @@ -860,6 +864,7 @@ struct server_context {
slot.params.stream = json_value(data, "stream", false);
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
Expand All @@ -878,7 +883,7 @@ struct server_context {
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
Expand Down Expand Up @@ -1129,13 +1134,48 @@ struct server_context {
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
}

// if we have already seen a new line, we stop after a certain time limit
if (slot.has_new_line && slot.params.t_max_predict_ms > 0 &&
(ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
slot.stopped_limit = true;
slot.has_next_token = false;
if (slot.has_new_line) {
// if we have already seen a new line, we stop after a certain time limit
if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
slot.stopped_limit = true;
slot.has_next_token = false;

SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
}

// require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
if (slot.params.n_indent > 0) {
// check the current indentation
// TODO: improve by not doing it more than once for each new line
if (slot.last_nl_pos > 0) {
size_t pos = slot.last_nl_pos;

int n_indent = 0;
while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
n_indent++;
pos++;
}

if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
slot.stopped_limit = true;
slot.has_next_token = false;

// cut the last line
slot.generated_text.erase(pos, std::string::npos);

SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
}
}

// find the next new line
{
const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);

if (pos != std::string::npos) {
slot.last_nl_pos = pos + 1;
}
}
}
}

// check if there is a new line in the generated text
Expand Down

0 comments on commit aa9a7ef

Please sign in to comment.