rwkv : speed-up tokenization using trie
ggerganov committed Aug 30, 2024
commit 7004323 (1 parent: 7f2ef56)
Showing 1 changed file with 33 additions and 31 deletions.
src/llama-vocab.cpp (+33, -31)
@@ -58,17 +58,17 @@ struct naive_trie {
         auto res = children.find(c);
         if (res != children.end()) {
             return res->second.get_longest_prefix(key, len, offset + 1);
-        } else {
-            return std::make_pair(key, offset);
         }
+
+        return std::make_pair(key, offset);
     }
-    struct naive_trie * traverse(const char c) {
+    const struct naive_trie * traverse(const char c) const {
         auto res = children.find(c);
         if (res != children.end()) {
             return &res->second;
-        } else {
-            return NULL;
         }
+
+        return NULL;
     }
     std::map<char, struct naive_trie> children;
     bool has_value;
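
For context, the trie these methods belong to maps one byte per edge, and a node that terminates a token stores that token's id. A minimal, self-contained sketch of the pattern (assumed shape for illustration, not the exact llama.cpp code):

#include <cstdint>
#include <cstdio>
#include <map>
#include <string>

// one byte per edge; a node that ends a token carries the token id
struct sketch_trie {
    bool has_value = false;
    int32_t value = 0;
    std::map<char, sketch_trie> children;

    void insert(const std::string & key, int32_t val) {
        sketch_trie * node = this;
        for (char c : key) {
            node = &node->children[c]; // creates missing children on the way down
        }
        node->has_value = true;
        node->value = val;
    }

    // const-qualified, matching the patched traverse() above
    const sketch_trie * traverse(char c) const {
        auto it = children.find(c);
        if (it != children.end()) {
            return &it->second;
        }

        return nullptr;
    }
};

int main() {
    sketch_trie trie;
    trie.insert("ab",  1);
    trie.insert("abc", 2);

    // walking "abcd" reports ids 1 ("ab") and 2 ("abc"), then stops at 'd'
    const sketch_trie * node = &trie;
    for (char c : std::string("abcd")) {
        node = node->traverse(c);
        if (node == nullptr) {
            break;
        }
        if (node->has_value) {
            printf("matched token id %d\n", node->value);
        }
    }
    return 0;
}

The const qualifier on traverse() is what lets callers hold the current node as a const struct naive_trie *, as both tokenizers below now do.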
@@ -843,7 +843,7 @@ struct llm_tokenizer_ugm {
             // traverse the token matcher trie to find a matching token
             bool single_codepoint_token_found = false;
             const struct best_tokenization & current_best = tokenization_results[input_offset];
-            struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
+            const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
 
             while (prefix_offset <= input_len && node != NULL) {
                 // check if we found valid token in prefix
@@ -1103,6 +1103,7 @@ struct llm_tokenizer_ugm {
 
 static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
     std::vector<uint8_t> output;
+    output.reserve(escaped.size());
 
     // Parser state
     bool escaping = false;
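
The reserve() call is a safe upper bound because unescaping never grows the data: every escape sequence is at least as long as the single byte it decodes to. A toy unescape loop illustrating the invariant (the escape set here, \n, \t, \\ and \xHH, is an assumption for illustration; the authoritative rules are in llama_unescape_rwkv_token itself):

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

static std::vector<uint8_t> toy_unescape(const std::string & escaped) {
    std::vector<uint8_t> out;
    out.reserve(escaped.size()); // output is never longer than the input

    for (size_t i = 0; i < escaped.size(); ++i) {
        if (escaped[i] != '\\' || i + 1 == escaped.size()) {
            out.push_back((uint8_t) escaped[i]);
            continue;
        }
        switch (escaped[++i]) {
            case 'n':  out.push_back('\n'); break; // 2 input chars -> 1 byte
            case 't':  out.push_back('\t'); break;
            case '\\': out.push_back('\\'); break;
            case 'x': { // "\xHH": 4 input chars -> 1 byte
                const std::string hex = escaped.substr(i + 1, 2);
                out.push_back((uint8_t) strtol(hex.c_str(), nullptr, 16));
                i += 2;
                break;
            }
            default:   out.push_back((uint8_t) escaped[i]); break;
        }
    }
    return out;
}

int main() {
    const auto bytes = toy_unescape("a\\tb\\x41");
    printf("%zu bytes\n", bytes.size()); // 4: 'a', TAB, 'b', 0x41
    return 0;
}

This function runs once per vocabulary entry when the tokenizer is built, so pre-sizing the vector avoids a round of reallocations per token.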
@@ -1158,46 +1159,47 @@ struct llm_tokenizer_rwkv {
     llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
         // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
         // For now, we decode the vocab here into the lookup we'll use for tokenization.
-        for (const auto & token : vocab.id_to_token) {
-            auto data = llama_unescape_rwkv_token(token.text);
-            tokens.push_back(data);
+
+        // build trie
+        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+            const auto & token = vocab.id_to_token[id];
+            const auto data = llama_unescape_rwkv_token(token.text);
+            token_matcher.insert((const char *) data.data(), data.size(), id);
         }
     }
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         uint32_t position = 0;
 
         while (position < text.size()) {
-            // Iterate through possible tokens backwards, starting with the largest
-            for (int32_t i = (int32_t)tokens.size() - 1; i >= 0; i--) {
-                // Skip tokens that aren't normal type, we can't match on those
-                if (!(vocab.id_to_token[i].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
-                    continue;
-                }
-
-                uint32_t token_size = tokens[i].size();
-
-                // If there's not enough left for this token
-                if (text.size() - position < token_size) {
-                    continue;
-                }
+            const struct naive_trie * node = token_matcher.traverse(text[position]);
+            if (node == NULL) {
+                // no matching token found, add unknown token
+                output.push_back(vocab.special_unk_id);
+                position += 1;
+                continue;
+            }
 
-            // If the token doesn't match the data
-            if (std::memcmp(text.data() + position, tokens[i].data(), token_size) != 0) {
-                continue;
+            // traverse the trie to find the longest matching token
+            uint32_t token_id = 0;
+            uint32_t token_length = 0;
+            while (node != NULL) {
+                if (node->has_value) {
+                    token_id = node->value;
+                    token_length = position + 1;
                 }
-
-                // Add the token and advance
-                output.push_back(i);
-                position += token_size;
-                break;
+                node = node->traverse(text[++position]);
             }
+
+            // add the longest matching token
+            output.push_back(token_id);
+            position = token_length;
         }
     }
 
     const llama_vocab & vocab;
 
-    std::vector<std::vector<uint8_t>> tokens;
     struct naive_trie token_matcher;
 };
 
 //
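
The speed-up itself: the old loop scanned the entire vocabulary backwards at every input position and memcmp'd each surviving candidate, roughly O(vocab size) work per position, while the trie walk only follows bytes that actually continue some token, so each position costs at most the length of the longest match. A standalone sketch of the same greedy longest-match loop (the three-token vocab and ids are made up for illustration):

#include <cstdint>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct node_t {
    bool has_value = false;
    uint32_t value = 0;
    std::map<char, node_t> children;

    void insert(const std::string & key, uint32_t val) {
        node_t * n = this;
        for (char c : key) n = &n->children[c];
        n->has_value = true;
        n->value = val;
    }
    const node_t * traverse(char c) const {
        auto it = children.find(c);
        return it == children.end() ? nullptr : &it->second;
    }
};

int main() {
    node_t matcher;
    matcher.insert("a",   1); // hypothetical vocab
    matcher.insert("ab",  2);
    matcher.insert("abc", 3);

    const std::string text = "abax";
    std::vector<uint32_t> out;

    uint32_t position = 0;
    while (position < text.size()) {
        const node_t * node = matcher.traverse(text[position]);
        if (node == nullptr) { // no token starts with this byte
            out.push_back(0);  // stand-in for the unknown-token id
            position += 1;
            continue;
        }
        // remember the deepest node that ends a token
        uint32_t token_id = 0;
        uint32_t token_length = 0;
        while (node != nullptr) {
            if (node->has_value) {
                token_id = node->value;
                token_length = position + 1;
            }
            ++position;
            node = position < text.size() ? node->traverse(text[position]) : nullptr;
        }
        out.push_back(token_id);
        position = token_length; // rewind to just past the longest match
    }

    for (uint32_t id : out) printf("%u ", id); // expected: 2 1 0
    printf("\n");
    return 0;
}

Two details carried over from the committed code: token_length stores the absolute end offset of the best match rather than a length, which is why position = token_length advances past it; and the original's text[++position] read is safe at end-of-string because std::string::operator[] returns a null character at index size() (the sketch makes the bounds check explicit instead).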
