diff --git a/llama.cpp b/llama.cpp
index 7419b03b61dc3..c9488dc6fac4e 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1771,21 +1771,35 @@ struct llama_tokenizer {
         for (int i = 0; i != -1; i = symbols_[i].next) {
             auto & symbol = symbols_[i];
-            auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n));
+            resegment(symbol, output);
+        }
+    }
 
-            if (token == vocab_.token_to_id.end()) {
-                // output any symbols that did not form tokens as bytes.
-                for (int j = 0; j < (int) symbol.n; ++j) {
-                    llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
-                    output.push_back(token_id);
-                }
-            } else {
-                output.push_back((*token).second);
+private:
+    void resegment(llama_sp_symbol & symbol, std::vector<llama_vocab::id> & output) {
+        auto text = std::string(symbol.text, symbol.n);
+        auto token = vocab_.token_to_id.find(text);
+
+        if (token != vocab_.token_to_id.end()) {
+            output.push_back((*token).second);
+            return;
+        }
+
+        const auto p = rev_merge.find(text);
+
+        if (p == rev_merge.end()) {
+            // output any symbols that did not form tokens as bytes.
+            for (int j = 0; j < (int) symbol.n; ++j) {
+                llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
+                output.push_back(token_id);
             }
+            return;
         }
+
+        resegment(symbols_[p->second.first], output);
+        resegment(symbols_[p->second.second], output);
     }
 
-private:
     void try_add_bigram(int left, int right) {
         if (left == -1 || right == -1) {
             return;
         }
@@ -1810,11 +1824,14 @@
         bigram.score = tok_score.score;
         bigram.size = text.size();
         work_queue_.push(bigram);
+
+        rev_merge[text] = std::make_pair(left, right);
     }
 
     const llama_vocab & vocab_;
     std::vector<llama_sp_symbol> symbols_;
     llama_sp_bigram::queue work_queue_;
+    std::map<std::string, std::pair<int, int> > rev_merge;
 };
 
 static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp
index 20abe710018ee..a5f0770bf5896 100644
--- a/tests/test-tokenizer-0.cpp
+++ b/tests/test-tokenizer-0.cpp
@@ -14,6 +14,8 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests()
         { " this is 🦙.cpp",              { 1, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223, }, },
         { "w048 7tuijk dsdfhu",          { 1, 29893, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, },
         { "нещо на Български",           { 1, 821, 4851, 665, 1386, 29713, 1305, }, },
+        { "\xe6\x88\x91\xe4\xbb\xac\xe5\xa4\xa7\xe5\xae\xb6\xe4\xb8\x80\xe8\xb5\xb7", { 1, 30672, 31381, 30257, 30613, 30287, 31558, }, },
+        { " >>>>ANSWER<<",               {1, 5099, 6778, 2190, 23066, 1001, 9314}, },
     };
     return _k_tests;
 };
@@ -82,11 +84,19 @@ int main(int argc, char **argv) {
                 fprintf(stderr, "%6d, ", t);
             }
             fprintf(stderr, "\n");
+            for (const auto & t : test_kv.second) {
+                fprintf(stderr, "%7s ", llama_token_to_str(ctx, t));
+            }
+            fprintf(stderr, "\n");
             fprintf(stderr, "%s : got tokens:      ", __func__);
             for (const auto & t : res) {
                 fprintf(stderr, "%6d, ", t);
             }
             fprintf(stderr, "\n");
+            for (const auto & t : res) {
+                fprintf(stderr, "%7s ", llama_token_to_str(ctx, t));
+            }
+            fprintf(stderr, "\n");
 
             llama_free_model(model);
             llama_free(ctx);
@@ -94,6 +104,38 @@
         }
     }
 
+#if 0
+    // how many tokens would not tokenize to themselves
+    for (llama_token i = 1; i < llama_n_vocab(ctx); i++)
+    {
+        const char* str = llama_token_to_str(ctx, i);
+        std::vector<llama_token> res(100);
+
+        const int n = llama_tokenize(ctx, str, res.data(), int(res.size()), false);
+        res.resize(n);
+
+        for (const auto & t : res)
+        {
+            //if (t == 1) continue;
+
+            if (t != i) {
+                fprintf(stderr, "%s : failed test: '%s'\n", __func__, str);
+                fprintf(stderr, "%s : expected tokens: %d\n", __func__, i);
+                fprintf(stderr, "%s : got tokens: ", __func__);
+                for (const auto & t : res) {
+                    fprintf(stderr, "%6d, ", t);
+                }
+                for (const auto & t : res) {
+                    fprintf(stderr, "%s|", llama_token_to_str(ctx, t));
+                }
+
+                fprintf(stderr, "\n");
+            }
+        }
+
+    }
+#endif
+
     llama_free_model(model);
     llama_free(ctx);
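For context on what the llama.cpp hunks do: try_add_bigram now records, for every candidate merge, which two symbols produced the merged string (rev_merge), and the final output pass calls resegment, which tries a direct vocab lookup first, then recursively undoes the recorded merge, and only falls back to emitting byte tokens (id = byte value + 3) when neither applies. Below is a minimal standalone sketch of that reverse-merge idea; the toy vocab, the token ids, and the string-valued rev_merge here are hypothetical simplifications, not llama.cpp's API (the patch keys merges by text but stores symbol indices, not substrings).

// Illustrative sketch only -- not llama.cpp code.
#include <cstdio>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy vocabulary: "abc" itself is not a token, but its pieces are.
static const std::map<std::string, int> vocab = { { "ab", 10 }, { "c", 11 } };

// Maps a merged string back to the two pieces that formed it, recorded at the
// moment the bigram was queued -- the role rev_merge plays in the patch.
static const std::map<std::string, std::pair<std::string, std::string>> rev_merge = {
    { "abc", { "ab", "c" } },
};

// Emit token ids for `text`: direct vocab hit first, then undo a recorded
// merge and resegment both halves, and only fall back to byte tokens
// (id = byte value + 3) as a last resort.
static void resegment(const std::string & text, std::vector<int> & output) {
    const auto token = vocab.find(text);
    if (token != vocab.end()) {
        output.push_back(token->second);
        return;
    }

    const auto p = rev_merge.find(text);
    if (p == rev_merge.end()) {
        for (unsigned char c : text) {
            output.push_back((int) c + 3); // byte fallback
        }
        return;
    }

    resegment(p->second.first,  output);
    resegment(p->second.second, output);
}

int main() {
    std::vector<int> output;
    resegment("abc", output); // not a token itself, so it re-splits into "ab" + "c"
    for (const int id : output) {
        printf("%d, ", id);   // prints: 10, 11,
    }
    printf("\n");
    return 0;
}

The intended effect is what the new test cases check: a symbol whose merged form is absent from the vocabulary is now re-split into the valid subword tokens it was built from, instead of being degraded to raw byte tokens.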