Commit 403dee8

llama : vocab cleanup
ggml-ci
1 parent 0f71186 commit 403dee8

File tree: 4 files changed (+26 -21 lines)

src/llama-sampling.cpp (+1 -1)

@@ -1663,7 +1663,7 @@ struct llama_sampler_dry {
 
 // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
 static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
-    for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
+    for (llama_token token_id = 0; token_id < (llama_token) vocab.n_vocab(); token_id++) {
         std::string word = vocab.detokenize({token_id}, true);
         if (word.find(str) != std::string::npos) {
             token_sequences.emplace(token_id, std::vector<llama_token>());

src/llama-vocab.cpp (+19 -12)

@@ -208,7 +208,7 @@ struct llm_tokenizer_spm_session {
             return;
         }
 
-        if (static_cast<uint32_t>(token) >= vocab.n_vocab) {
+        if (static_cast<uint32_t>(token) >= vocab.n_vocab()) {
             return;
         }
 
@@ -734,7 +734,7 @@ struct llm_tokenizer_ugm : llm_tokenizer {
             prefix_replacements_size = precompiled_charsmap.size() - charsmap_offset;
         }
 
-        for (uint32_t id = 0; id < vocab.n_vocab; ++id) {
+        for (uint32_t id = 0; id < vocab.n_vocab(); ++id) {
             const auto & token_data = vocab.get_token_data(id);
 
             if (vocab.is_normal(id)) {
@@ -1119,7 +1119,7 @@ struct llm_tokenizer_rwkv : llm_tokenizer {
         // For now, we decode the vocab here into the lookup we'll use for tokenization.
 
         // build trie
-        for (uint32_t id = 0; id < vocab.n_vocab; ++id) {
+        for (uint32_t id = 0; id < vocab.n_vocab(); ++id) {
             const auto & data = vocab.get_token_data(id);
             const auto text = llama_unescape_rwkv_token(data.text);
             token_matcher.insert((const char *) text.data(), text.size(), id);
@@ -1204,6 +1204,8 @@ struct fragment_buffer_variant {
 };
 
 struct llama_vocab::impl {
+    uint32_t n_vocab = 0;
+
     std::unordered_map<std::string, llama_token> token_to_id;
     std::vector<token_data> id_to_token;
 
@@ -1283,6 +1285,13 @@ llama_vocab::~llama_vocab() {
 void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
     struct gguf_context * ctx = ml.meta.get();
 
+    auto & n_vocab = pimpl->n_vocab;
+    auto & id_to_token = pimpl->id_to_token;
+    auto & token_to_id = pimpl->token_to_id;
+    auto & special_eog_ids = pimpl->special_eog_ids;
+    auto & cache_special_tokens = pimpl->cache_special_tokens;
+    auto & cache_token_to_piece = pimpl->cache_token_to_piece;
+
     // determine vocab type
     {
         std::string tokenizer_model;
@@ -1589,12 +1598,6 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
         toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
     }
 
-    auto & id_to_token = pimpl->id_to_token;
-    auto & token_to_id = pimpl->token_to_id;
-    auto & special_eog_ids = pimpl->special_eog_ids;
-    auto & cache_special_tokens = pimpl->cache_special_tokens;
-    auto & cache_token_to_piece = pimpl->cache_token_to_piece;
-
     n_vocab = gguf_get_arr_n(ctx, token_idx);
     id_to_token.resize(n_vocab);
 
@@ -1908,7 +1911,7 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
 
     // build special tokens cache
     {
-        for (llama_token id = 0; id < (llama_token)n_vocab; ++id) {
+        for (llama_token id = 0; id < (llama_token) n_vocab; ++id) {
             if (id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
                 cache_special_tokens.push_back(id);
             }
@@ -2002,6 +2005,10 @@ enum llama_vocab_pre_type llama_vocab::get_pre_type() const {
     return pre_type;
 }
 
+uint32_t llama_vocab::n_vocab() const {
+    return (uint32_t) pimpl->id_to_token.size();
+}
+
 std::string llama_vocab::type_name() const{
     switch (type) {
         case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
@@ -2366,8 +2373,8 @@ int llama_vocab::max_token_text_len() const {
 
 void llama_vocab::print_info() const {
     LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str());
-    LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, n_vocab);
-    LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) pimpl->bpe_ranks.size());
+    LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, pimpl->n_vocab);
+    LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) pimpl->bpe_ranks.size());
 
     auto & id_to_token = pimpl->id_to_token;
     auto & special_eog_ids = pimpl->special_eog_ids;
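The thrust of this file's changes: n_vocab moves into llama_vocab::impl and is reported through an accessor that derives the count from the id_to_token table, so the two can no longer drift apart. Below is a minimal, self-contained sketch of that pattern; the member types and the main() driver are simplified stand-ins for illustration, not the actual llama.cpp definitions.

// vocab_accessor_sketch.cpp -- illustrative only; token_data here is a placeholder type.
#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>
#include <vector>

struct token_data {
    std::string text;
    float       score = 0.0f;
};

struct vocab {
    struct impl {
        uint32_t                n_vocab = 0;   // written while loading metadata
        std::vector<token_data> id_to_token;   // one entry per token id
    };

    std::unique_ptr<impl> pimpl = std::make_unique<impl>();

    // callers no longer read a public field; the size is derived from the table
    uint32_t n_vocab() const {
        return (uint32_t) pimpl->id_to_token.size();
    }
};

int main() {
    vocab v;
    v.pimpl->id_to_token.resize(3);
    v.pimpl->n_vocab = 3;

    // iterate the same way the tokenizers above do after the change
    for (uint32_t id = 0; id < v.n_vocab(); ++id) {
        std::printf("token %u: %s\n", id, v.pimpl->id_to_token[id].text.c_str());
    }
    return 0;
}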

src/llama-vocab.h (+3 -5)

@@ -4,9 +4,6 @@
 
 #include <string>
 #include <vector>
-#include <unordered_map>
-#include <map>
-#include <set>
 #include <memory>
 
 struct LLM_KV;
@@ -19,8 +16,6 @@ struct llama_vocab {
         llama_token_attr attr;
     };
 
-    uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
-
     llama_vocab();
     ~llama_vocab();
 
@@ -29,6 +24,9 @@ struct llama_vocab {
     enum llama_vocab_type get_type() const;
     enum llama_vocab_pre_type get_pre_type() const;
 
+    // TODO: how to deduplicate with llama_hparams.n_vocab ?
+    uint32_t n_vocab() const;
+
    std::string type_name() const;
 
    bool is_normal (llama_token id) const;
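With the container members and the counter now behind the pimpl, the header can drop <unordered_map>, <map>, and <set> and only needs <memory> for the std::unique_ptr that owns impl. One detail this relies on, shown in a hedged sketch below under a hypothetical llama_vocab_sketch name rather than the real header: the destructor is declared in the header but defined in the .cpp, where impl is a complete type, so std::unique_ptr<impl> can delete it.

// vocab_firewall.h -- hypothetical header illustrating the compile firewall; not the real llama-vocab.h.
#pragma once
#include <cstdint>
#include <memory>

struct llama_vocab_sketch {
    llama_vocab_sketch();
    ~llama_vocab_sketch();      // defined out of line, where impl is complete

    uint32_t n_vocab() const;   // accessor replaces the old public n_vocab field

private:
    struct impl;                // forward declaration keeps heavy includes out of the header
    std::unique_ptr<impl> pimpl;
};

// vocab_firewall.cpp (shown inline here so the sketch builds as one translation unit)
#include <vector>

struct llama_vocab_sketch::impl {
    std::vector<int> id_to_token;   // stand-in for the real token table
};

llama_vocab_sketch::llama_vocab_sketch() : pimpl(std::make_unique<impl>()) {}
llama_vocab_sketch::~llama_vocab_sketch() = default;

uint32_t llama_vocab_sketch::n_vocab() const {
    return (uint32_t) pimpl->id_to_token.size();
}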

src/llama.cpp (+3 -3)

@@ -66,7 +66,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
         model.print_info();
 
         if (model.vocab.get_type() != LLAMA_VOCAB_TYPE_NONE &&
-            model.hparams.n_vocab != model.vocab.n_vocab) {
+            model.hparams.n_vocab != model.vocab.n_vocab()) {
             throw std::runtime_error("vocab size mismatch");
         }
 
@@ -8317,7 +8317,7 @@ static int llama_decode_impl(
 
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_vocab()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
@@ -8652,7 +8652,7 @@ static int llama_encode_impl(
 
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens; ++i) {
-            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_vocab()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
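At the call sites the change is mechanical (a member read becomes a call), but the load-time check above is also why the header keeps a TODO about llama_hparams.n_vocab: the vocab size now comes from two places that must agree. A hedged sketch of that consistency check and of the decode/encode token-id guard follows; Model, Vocab, HParams, and the function names are illustrative stand-ins, not the real llama.cpp entry points.

// validate_tokens.cpp -- illustrative only; the structs below are simplified stand-ins.
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

struct Vocab   { uint32_t n_vocab() const { return 32000; } };   // placeholder size
struct HParams { uint32_t n_vocab = 32000; };
struct Model   { Vocab vocab; HParams hparams; };

// mirrors the load-time check: the GGUF hyperparameter and the token table must agree
void check_vocab_size(const Model & model) {
    if (model.hparams.n_vocab != model.vocab.n_vocab()) {
        throw std::runtime_error("vocab size mismatch");
    }
}

// mirrors the decode/encode guards: every incoming id must lie in [0, n_vocab)
int validate_batch(const Model & model, const std::vector<int32_t> & tokens) {
    for (size_t i = 0; i < tokens.size(); ++i) {
        if (tokens[i] < 0 || (uint32_t) tokens[i] >= model.vocab.n_vocab()) {
            std::fprintf(stderr, "invalid token[%zu] = %d\n", i, tokens[i]);
            return -1;
        }
    }
    return 0;
}

int main() {
    Model model;
    check_vocab_size(model);
    return validate_batch(model, {1, 2, 3});
}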
