llama : cache llama_token_to_piece #7587

Merged
merged 4 commits on May 30, 2024
Changes from 3 commits
192 changes: 110 additions & 82 deletions llama.cpp
@@ -1702,12 +1702,13 @@ struct llama_mlock {
};
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;

static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
// NOTE: avoid ever using this except for building the token_to_piece caches
static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
int check = llama_token_to_piece(model, token, result.data(), result.size(), special);
GGML_ASSERT(check == -n_tokens);
}
else {
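The static helper above, like the public llama_token_to_piece API it wraps, uses a grow-on-negative-return convention: if the destination buffer is too small, the call writes nothing and returns the required size negated. A minimal sketch of that convention against the public C API; the wrapper name token_to_string is invented for this illustration:

#include <string>
#include <vector>

#include "llama.h"

// Illustrative helper, not part of the patch: convert one token to its piece using
// the grow-on-negative-return convention. token_to_string is a hypothetical name.
static std::string token_to_string(const llama_model * model, llama_token token, bool special) {
    std::vector<char> buf(8, 0);
    int32_t n = llama_token_to_piece(model, token, buf.data(), (int32_t) buf.size(), special);
    if (n < 0) {
        buf.resize(-n); // the first call reported the required size as -n
        n = llama_token_to_piece(model, token, buf.data(), (int32_t) buf.size(), special);
    }
    return std::string(buf.data(), n); // n bytes of piece data, no null terminator is written
}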
@@ -2162,7 +2163,9 @@ struct llama_vocab {
std::unordered_map<token, id> token_to_id;
std::vector<token_data> id_to_token;

std::vector<id> special_tokens_cache;
std::vector<id> cache_special_tokens;
std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = false);
std::vector<token> cache_token_to_piece_special; // llama_token_to_piece(special = true);

std::map<std::pair<std::string, std::string>, int> bpe_ranks;

@@ -4592,20 +4595,14 @@ static void llm_load_vocab(
vocab.special_cls_id = 101;
vocab.special_mask_id = 103;
vocab.add_space_prefix = false;
} else {
if (tokenizer_model == "gpt2") {
vocab.type = LLAMA_VOCAB_TYPE_BPE;
} else if (tokenizer_model == "gpt2") {
vocab.type = LLAMA_VOCAB_TYPE_BPE;

const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
if (add_space_prefix_keyidx != -1) {
vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
}
} else {
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_model.c_str());
LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
vocab.type = LLAMA_VOCAB_TYPE_SPM;
return;
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
if (add_space_prefix_keyidx != -1) {
vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
}

// read bpe merges and populate bpe ranks
const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
if (merges_keyidx == -1) {
@@ -4639,6 +4636,8 @@ static void llm_load_vocab(
vocab.special_pad_id = -1;
vocab.special_cls_id = -1;
vocab.special_mask_id = -1;
} else {
throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
}

// for now, only BPE models have pre-tokenizers
@@ -4833,17 +4832,31 @@ static void llm_load_vocab(
{
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
vocab.special_tokens_cache.push_back(id);
vocab.cache_special_tokens.push_back(id);
}
}

std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(),
std::sort( vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
[&] (const llama_vocab::id a, const llama_vocab::id b) {
return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
}
);

LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size());
LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
}

// build token to piece caches
{
std::vector<llama_vocab::token> cache_token_to_piece (n_vocab);
std::vector<llama_vocab::token> cache_token_to_piece_special(n_vocab);

for (uint32_t id = 0; id < n_vocab; ++id) {
cache_token_to_piece[id] = llama_token_to_piece(&model, id, false);
cache_token_to_piece_special[id] = llama_token_to_piece(&model, id, true);
}

std::swap(vocab.cache_token_to_piece, cache_token_to_piece);
std::swap(vocab.cache_token_to_piece_special, cache_token_to_piece_special);
}
}
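The block above builds, at model load time, the textual piece for every token id in both the special = false and special = true renderings, so later conversions become a plain vector lookup. A standalone sketch of the same idea, assuming the hypothetical token_to_string wrapper from the earlier sketch:

#include <string>
#include <vector>

#include "llama.h"

// Illustrative sketch, not the actual patch: precompute both piece variants for every
// token id once. The type and function names here are invented for this example.
struct token_piece_cache {
    std::vector<std::string> plain;   // rendered with special = false
    std::vector<std::string> special; // rendered with special = true
};

static token_piece_cache build_token_piece_cache(const llama_model * model) {
    const int32_t n_vocab = llama_n_vocab(model);

    token_piece_cache cache;
    cache.plain  .resize(n_vocab);
    cache.special.resize(n_vocab);

    for (llama_token id = 0; id < n_vocab; ++id) {
        cache.plain  [id] = token_to_string(model, id, false);
        cache.special[id] = token_to_string(model, id, true);
    }
    return cache;
}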

@@ -13233,7 +13246,7 @@ struct fragment_buffer_variant {

static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
// for each special token
for (const llama_vocab::id special_id : vocab.special_tokens_cache) {
for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
const auto & special_token = vocab.id_to_token[special_id].text;

// for each text fragment
@@ -14392,7 +14405,7 @@ void llama_sample_repetition_penalties(

void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
GGML_ASSERT(ctx);
const int64_t t_start_sample_us = ggml_time_us();
int64_t t_start_sample_us = ggml_time_us();

bool allow_eog = false;
for (const auto & stack : grammar->stacks) {
@@ -14404,12 +14417,13 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c

std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
candidates_decoded.reserve(candidates->size);
std::vector<llama_grammar_candidate> candidates_grammar;

std::vector<llama_grammar_candidate> candidates_grammar;
candidates_grammar.reserve(candidates->size);

for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string piece = llama_token_to_piece(ctx, id, false);
const llama_token id = candidates->data[i].id;
const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);

if (llama_token_is_eog(&ctx->model, id)) {
if (!allow_eog) {
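For context: llama_sample_grammar needs the textual piece of every candidate token at every sampling step, and previously each lookup re-ran llama_token_to_piece and constructed a fresh std::string. With the cache, the loop takes a reference into a prebuilt vector instead. A minimal sketch of that lookup pattern, reusing the hypothetical token_piece_cache from the previous sketch:

// Illustrative sketch only: the lookup shape the patch uses inside llama_sample_grammar.
// Returning a reference avoids a per-candidate string allocation on every sampling step.
static const std::string & piece_for(const token_piece_cache & cache, llama_token id) {
    return cache.plain.at(id); // the grammar path uses the special = false rendering
}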
@@ -14609,7 +14623,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
GGML_ASSERT(false);
}

const std::string piece = llama_token_to_piece(ctx, token, false);
const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);

// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
@@ -18292,69 +18306,83 @@ static std::string llama_decode_text(const std::string & text) {

// does not write null-terminator to buf
int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
// if we have a cache - use it
{
const auto & cache = special ? model->vocab.cache_token_to_piece_special : model->vocab.cache_token_to_piece;
A collaborator commented:
nit: Maybe we could get away w/ a single cache (built w/ special=true) and early-exit in special case at the top of the function?

int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
    if (!special && llama_is_control_token(model->vocab, token)) {
        return 0;
    }
    // if we have a cache - use it
    if (!model->vocab.cache_token_to_piece.empty()) {
         ....
    }
    ...


if (!cache.empty()) {
const auto & res = cache.at(token);
if (length < (int) res.size()) {
return -(int) res.size();
}
memcpy(buf, res.c_str(), res.size());
return res.size();
}
}

if (0 <= token && token < llama_n_vocab(model)) {
switch (llama_vocab_get_type(model->vocab)) {
case LLAMA_VOCAB_TYPE_WPM:
case LLAMA_VOCAB_TYPE_SPM: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
llama_unescape_whitespace(result);
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
if (length < 3) {
return -3;
}
memcpy(buf, "\xe2\x96\x85", 3);
return 3;
} else if (llama_is_byte_token(model->vocab, token)) {
if (length < 1) {
return -1;
case LLAMA_VOCAB_TYPE_WPM:
case LLAMA_VOCAB_TYPE_SPM: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
llama_unescape_whitespace(result);
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
if (length < 3) {
return -3;
}
memcpy(buf, "\xe2\x96\x85", 3);
return 3;
} else if (llama_is_byte_token(model->vocab, token)) {
if (length < 1) {
return -1;
}
buf[0] = llama_token_to_byte(model->vocab, token);
return 1;
}
buf[0] = llama_token_to_byte(model->vocab, token);
return 1;
break;
}
break;
}
case LLAMA_VOCAB_TYPE_BPE: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
result = llama_decode_text(result);
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
case LLAMA_VOCAB_TYPE_BPE: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
result = llama_decode_text(result);
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
break;
}
break;
}
default:
GGML_ASSERT(false);
default:
GGML_ASSERT(false);
}
}
return 0;
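As a usage example: callers that detokenize whole sequences hit llama_token_to_piece once per token, so they benefit directly from the cached fast path above. A small sketch, reusing the hypothetical token_to_string wrapper from the first sketch:

#include <string>
#include <vector>

#include "llama.h"

// Illustrative usage, not part of the patch: detokenize by concatenating pieces.
// After this change, each lookup is served from the precomputed cache (an .at()
// plus a memcpy) instead of a per-token re-decode.
static std::string detokenize(const llama_model * model, const std::vector<llama_token> & tokens, bool special) {
    std::string text;
    for (const llama_token tok : tokens) {
        text += token_to_string(model, tok, special);
    }
    return text;
}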
4 changes: 2 additions & 2 deletions llama.h
@@ -424,8 +424,8 @@ extern "C" {

LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);

LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);

LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);