From 61a98bc30aca7d7a839539577be4931250c881e3 Mon Sep 17 00:00:00 2001 From: Igor Pissolati <igo08an@hotmail.com> Date: Sun, 18 Jun 2023 20:11:01 -0300 Subject: [PATCH 01/13] Improve support for special tokens --- convert.py | 89 +++++++++++++++++++--------- llama-util.h | 164 +++++++++++++++++++++++++++++++++++++++++++++++++++ llama.cpp | 76 +++++++++++++++++++++--- llama.h | 4 +- 4 files changed, 297 insertions(+), 36 deletions(-) diff --git a/convert.py b/convert.py index f3bf1798089cc..8bc06120dc84e 100644 --- a/convert.py +++ b/convert.py @@ -142,6 +142,7 @@ def find_n_mult(n_ff: int, n_embd: int) -> int: @dataclass class Params: n_vocab: int + n_vocab_sp:int n_embd: int n_mult: int n_head: int @@ -169,6 +170,7 @@ def guessed(model: 'LazyModel') -> 'Params': return Params( n_vocab = n_vocab, + n_vocab_sp= n_vocab, n_embd = n_embd, n_mult = 256, n_head = n_head, @@ -191,6 +193,7 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params': return Params( n_vocab = n_vocab, + n_vocab_sp= n_vocab, n_embd = n_embd, n_mult = n_mult, n_head = n_head, @@ -215,6 +218,7 @@ def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params': return Params( n_vocab = n_vocab, + n_vocab_sp= n_vocab n_embd = n_embd, n_mult = n_mult, n_head = n_head, @@ -239,7 +243,7 @@ def load(model_plus: 'ModelPlus') -> 'Params': class SentencePieceVocab: - def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vocabtype: Optional[str]) -> None: + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fname_special_tokens: Optional[Path], vocabtype: Optional[str]) -> None: self.vocabtype = vocabtype if self.vocabtype == "bpe": self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read()) @@ -264,35 +268,46 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vo self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list) self.fname_tokenizer = fname_tokenizer self.fname_added_tokens = fname_added_tokens + special_tokens: Dict[str, Dict[str, Any]] + if fname_special_tokens is not None: + special_tokens = json.load(open(fname_special_tokens)) + else: + special_tokens = {} + token_name_to_id = {"unk_token": self.sentencepiece_tokenizer.unk_id(), "bos_token": self.sentencepiece_tokenizer.bos_id(), "eos_token": self.sentencepiece_tokenizer.eos_id(), "pad_token": self.sentencepiece_tokenizer.pad_id()} + self.special_tokens_map = {token_name_to_id[token_name]: info["content"] if isinstance(info, dict) else info for token_name, info in special_tokens.items() if token_name in token_name_to_id and token_name_to_id[token_name] != -1} + self.vocab_special_size: int = len(self.added_tokens_list) + len(self.special_tokens_map) def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]: tokenizer = self.sentencepiece_tokenizer if self.vocabtype == "bpe": - from transformers.models.gpt2 import tokenization_gpt2 - byte_encoder = tokenization_gpt2.bytes_to_unicode() - byte_decoder = {v: k for k, v in byte_encoder.items()} - for i, item in enumerate(tokenizer): - text: bytes - text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]]) - score: float = -i - yield text, score + from transformers.models.gpt2 import tokenization_gpt2 + byte_encoder = tokenization_gpt2.bytes_to_unicode() + byte_decoder = {v: k for k, v in byte_encoder.items()} + for i, item in enumerate(tokenizer): + text: bytes + text = b''.join([x.to_bytes(1, byteorder='big') for x in 
[byte_decoder[y] for y in item]]) + score: float = -i + yield text, score else: - for i in range(tokenizer.vocab_size()): - text: bytes - if tokenizer.is_unknown(i): - text = " \u2047 ".encode("utf-8") - elif tokenizer.is_control(i): - text = b"" - elif tokenizer.is_byte(i): - piece = tokenizer.id_to_piece(i) - if len(piece) != 6: - raise Exception(f"Invalid token: {piece}") - byte_value = int(piece[3:-1], 16) - text = struct.pack("B", byte_value) - else: - text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8") - score: float = tokenizer.get_score(i) - yield text, score + special_tokens = [tokenizer.bos_id(), tokenizer.eos_id(), tokenizer.pad_id()] + for i in range(tokenizer.vocab_size()): + text: bytes + if tokenizer.is_unknown(i): + text = self.special_tokens_map.get(i, " \u2047 ").encode("utf-8") + elif i in special_tokens: + text = self.special_tokens_map.get(i, "").encode("utf-8") + elif tokenizer.is_control(i): + text = b"" + elif tokenizer.is_byte(i): + piece = tokenizer.id_to_piece(i) + if len(piece) != 6: + raise Exception(f"Invalid token: {piece}") + byte_value = int(piece[3:-1], 16) + text = struct.pack("B", byte_value) + else: + text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8") + score: float = tokenizer.get_score(i) + yield text, score def added_tokens(self) -> Iterable[Tuple[bytes, float]]: for text in self.added_tokens_list: @@ -303,6 +318,12 @@ def all_tokens(self) -> Iterable[Tuple[bytes, float]]: yield from self.sentencepiece_tokens() yield from self.added_tokens() + def all_special_tokens(self) -> Iterable[int]: + for token_id in self.special_tokens_map.keys(): + yield token_id + for i in range(len(self.added_tokens_list)): + yield self.vocab_size_base + i + def __repr__(self) -> str: return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>" @@ -310,11 +331,16 @@ def __repr__(self) -> str: class GGMLVocab: def __init__(self, tokens: List[Tuple[bytes, float]]): self.tokens = tokens + self.special_tokens = [] self.vocab_size = len(tokens) + self.vocab_special_size = 0 def all_tokens(self) -> Iterable[Tuple[bytes, float]]: return self.tokens + def all_special_tokens(self) -> Iterable[int]: + return self.special_tokens + def __repr__(self) -> str: return f"<GGMLVocab with {self.vocab_size} tokens>" @@ -1066,8 +1092,9 @@ def __init__(self, fname_out: Path) -> None: def write_file_header(self, params: Params, file_type: GGMLFileType) -> None: self.fout.write(b"ggjt"[::-1]) # magic values = [ - 1, # file version + 4, # file version params.n_vocab, + params.n_vocab_sp, params.n_embd, params.n_mult, params.n_head, @@ -1089,11 +1116,14 @@ def write_vocab(self, vocab: Vocab) -> None: self.fout.write(struct.pack("i", len(text))) self.fout.write(text) self.fout.write(struct.pack("f", score)) + for token_id in vocab.all_special_tokens(): + self.fout.write(struct.pack("i", token_id)) @staticmethod def write_vocab_only(fname_out: Path, vocab: Vocab) -> None: of = OutputFile(fname_out) - params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0) + params = Params(n_vocab=vocab.vocab_size, n_vocab_sp=vocab.vocab_special_size, n_embd=0, n_mult=0, + n_head=1, n_layer=0) of = OutputFile(fname_out) of.write_file_header(params, file_type=GGMLFileType.AllF32) of.write_vocab(vocab) @@ -1249,8 +1279,10 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab: f"Could not find tokenizer.model in {path} or its parent; " "if it's in another directory, pass the 
directory as --vocab-dir") added_tokens_path = path.parent / "added_tokens.json" + special_tokens_path = path.parent / "special_tokens_map.json" + tokenizer_config_path = path.parent / "tokenizer_config.json" print(f"Loading vocab file {path}") - return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, + return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, special_tokens_path if special_tokens_path.exists() else tokenizer_config_path if tokenizer_config_path.exists() else None, vocabtype) @@ -1313,6 +1345,7 @@ def main(args_in: Optional[List[str]] = None) -> None: vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent vocab = load_vocab(vocab_dir, args.vocabtype) params = Params.load(model_plus) + params.n_vocab_sp = vocab.vocab_special_size model = model_plus.model model = do_necessary_conversions(model, params) output_type = pick_output_type(model, args.outtype) diff --git a/llama-util.h b/llama-util.h index 6e9e39ddb6f58..a3ec8a501f413 100644 --- a/llama-util.h +++ b/llama-util.h @@ -14,6 +14,8 @@ #include <string> #include <vector> +#include <map> +#include <unordered_map> #include <stdexcept> #ifdef __has_include @@ -541,4 +543,166 @@ struct llama_ctx_buffer { typedef llama_buffer llama_ctx_buffer; #endif +struct llama_trie_node { + llama_trie_node(): is_terminator(false) {} + + std::unordered_map<char, llama_trie_node*> children; + bool is_terminator; +}; + +// Trie in C++. Creates a Trie out of a list of words. The trie is used to split on multiple delimiters in one pass +// Ported from: https://github.com/huggingface/transformers/blob/ee88ae59940fd4b2c8fc119373143d7a1175c651/src/transformers/tokenization_utils.py#L52 +struct llama_trie { +public: + llama_trie(): root_(new llama_trie_node()) {} + + void add(const std::string & word) { + if (word.empty()) { + return; + } + + llama_trie_node *ref = root_; + for (char c : word) { + if (ref->children.find(c) == ref->children.end()) { + ref->children[c] = new llama_trie_node(); + } + ref = ref->children[c]; + } + ref->is_terminator = true; + } + + // Will look for the words added to the trie within `text`. Output is the boundaries of the words found. + // Note that this trie will match the longest possible word first! + std::vector<int> split(const std::string & text) const { + std::map<int, llama_trie_node*> states; + std::vector<int> offsets{0}; + + int skip = 0; + for (int current = 0; current < text.size(); current++) { + char current_char = text[current]; + if (skip > 0 && current < skip) { + // Prevents the lookahead for matching twice + // like extra_id_100 and id_100 + continue; + } + + // Whenever we found a match, we need to drop everything + // this is a greedy algorithm, it will match on the first found token + bool reset = false; + + // In this case, we already have partial matches (But unfinished) + for (auto state = states.begin(); state != states.end(); ) { + int start = state->first; + llama_trie_node *trie_pointer = state->second; + if (trie_pointer->is_terminator) { + // This is a final match, we need to reset and + // store the results in `offsets`. 
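(A reading aid for the trie added above: split() returns a flat offset list. It starts at 0, each matched word contributes a (start, end) pair, and the final element is text.size(). The sketch below is a hypothetical standalone driver, not part of the patch; it assumes llama-util.h from this patch is on the include path, and at this point in the series split() still returns std::vector<int>.)

    // trie_demo.cpp - illustrative sketch only
    #include "llama-util.h"
    #include <cstdio>

    int main() {
        llama_trie trie;
        trie.add("[CLS]");
        trie.add("extra_id_1");
        trie.add("extra_id_100");

        std::string text = "[CLS] a extra_id_100 b";
        std::vector<int> offsets = trie.split(text);
        // offsets == {0, 0, 5, 8, 20, 22}, i.e. "[CLS]", " a ", "extra_id_100", " b"
        int start = 0;
        for (int end : offsets) {
            if (start < end) {
                printf("segment: \"%.*s\"\n", end - start, text.c_str() + start);
            }
            start = end;
        }
        return 0;
    }

(Walking consecutive offsets this way, skipping empty spans, is exactly how llama_tokenize() consumes the result further down in this patch.)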
+ + // Lookahead to match longest first + // Important in case of extra_id_1 vs extra_id_100 + // Here we are also actively looking for other earlier partial + // matches + // "[CLS]", "L", we need to match CLS even if L is special + int end = 0; + for (const auto & look : states) { + int lookstart = look.first; + llama_trie_node *looktrie_pointer = look.second; + int lookahead_index = 0; + if (lookstart > start) { + // This partial match is later, we can stop looking + break; + } + if (lookstart < start) { + // This partial match is earlier, the trie pointer + // was already updated, so index is + 1 + lookahead_index = current + 1; + end = current + 1; + } else { + // Here lookstart == start and + // looktrie_pointer == trie_pointer + // It wasn't updated yet so indices are current ones + lookahead_index = current; + end = current; + } + char next_char = lookahead_index < text.size() ? text[lookahead_index] : '\0'; + if (looktrie_pointer->is_terminator) { + start = lookstart; + end = lookahead_index; + skip = lookahead_index; + } + + auto looktrie_pointer_it = looktrie_pointer->children.find(next_char); + while (looktrie_pointer_it != looktrie_pointer->children.end()) { + looktrie_pointer = looktrie_pointer_it->second; + lookahead_index++; + if (looktrie_pointer->is_terminator) { + start = lookstart; + end = lookahead_index; + skip = lookahead_index; + } + + if (lookahead_index == text.size()) { + // End of string + break; + } + next_char = text[lookahead_index]; + looktrie_pointer_it = looktrie_pointer->children.find(next_char); + } + } + + offsets.push_back(start); + offsets.push_back(end); + reset = true; + break; + } + + auto trie_pointer_it = trie_pointer->children.find(current_char); + if (trie_pointer_it != trie_pointer->children.end()) { + // The current character being looked at has a match within the trie + // update the pointer (it will be stored back into states later). + trie_pointer = trie_pointer_it->second; + states[start] = trie_pointer; + ++state; + } else { + // The new character has not match in the trie, we need + // to stop keeping track of this partial match. + state = states.erase(state); + } + } + + if (reset) { + // Clear the full start (we found a real match) + states.clear(); + } + + // If this character is a starting character within the trie + // start keeping track of this partial match. + auto children_it = root_->children.find(current_char); + if (current >= skip && children_it != root_->children.end()) { + states[current] = children_it->second; + } + } + + // We have a cut at the end with states. + for (const auto & state : states) { + int start = state.first; + llama_trie_node *trie_pointer = state.second; + if (trie_pointer->is_terminator) { + // This is a final match, we need to reset and + // store the results in `offsets`. + int end = text.size(); + offsets.push_back(start); + offsets.push_back(end); + break; + } + } + + offsets.push_back(text.size()); + return offsets; + } + +private: + llama_trie_node *root_; +}; + #endif diff --git a/llama.cpp b/llama.cpp index 39aefd499dd0c..9908065ee8ce6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -181,6 +181,7 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT() // default hparams (LLaMA 7B) struct llama_hparams { uint32_t n_vocab = 32000; + uint32_t n_vocab_sp = 0; uint32_t n_ctx = 512; // this is provided as user input? 
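(On-disk effect of the new field: a ggjt v4 header is the magic, the version, then eight little-endian uint32 hyperparameters in the order read_hparams() reads them, with n_vocab_sp directly after n_vocab; write_vocab() then appends n_vocab_sp extra uint32 token ids after the scored vocabulary entries. Below is a minimal standalone reader, assuming the packed little-endian layout used by llama_file; hypothetical code, not part of the patch.)

    // v4_header.cpp - sketch of the header layout introduced by this patch
    #include <cstdint>
    #include <cstdio>

    struct ggjt_v4_header {
        uint32_t n_vocab, n_vocab_sp, n_embd, n_mult,
                 n_head, n_layer, n_rot, ftype;
    };

    static bool read_v4_header(FILE * f, ggjt_v4_header & h) {
        uint32_t magic = 0, version = 0;
        if (fread(&magic, 4, 1, f) != 1 || magic != 0x67676a74u) return false; // 'ggjt'
        if (fread(&version, 4, 1, f) != 1 || version != 4)       return false;
        // eight consecutive u32 fields; n_vocab_sp sits right after n_vocab
        return fread(&h, sizeof(h), 1, f) == 1;
    }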
uint32_t n_embd = 4096; uint32_t n_mult = 256; @@ -277,6 +278,11 @@ struct llama_vocab { std::unordered_map<token, id> token_to_id; std::vector<token_score> id_to_token; + + llama_trie special_token_trie; + std::unordered_map<token, id> special_token_to_id; + std::vector<id> special_tokens; + size_t max_special_token_length; }; struct llama_model { @@ -494,6 +500,7 @@ enum llama_file_version { LLAMA_FILE_VERSION_GGJT_V1, // added padding LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format LLAMA_FILE_VERSION_GGJT_V3, // changed Q4 and Q8 quantization format + LLAMA_FILE_VERSION_GGJT_V4, // improved support for added/special tokens }; struct llama_file_loader { @@ -531,6 +538,7 @@ struct llama_file_loader { case 1: file_version = LLAMA_FILE_VERSION_GGJT_V1; return; case 2: file_version = LLAMA_FILE_VERSION_GGJT_V2; return; case 3: file_version = LLAMA_FILE_VERSION_GGJT_V3; return; + case 4: file_version = LLAMA_FILE_VERSION_GGJT_V4; return; } } @@ -539,6 +547,7 @@ struct llama_file_loader { } void read_hparams() { hparams.n_vocab = file.read_u32(); + hparams.n_vocab_sp = file_version >= LLAMA_FILE_VERSION_GGJT_V4 ? file.read_u32() : 0; hparams.n_embd = file.read_u32(); hparams.n_mult = file.read_u32(); hparams.n_head = file.read_u32(); @@ -566,6 +575,21 @@ struct llama_file_loader { tok_score.tok = std::move(word); tok_score.score = score; } + + vocab.special_token_to_id.reserve(hparams.n_vocab_sp); + + for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) { + uint32_t token_id = file.read_u32(); + const auto & token = vocab.id_to_token[token_id].tok; + + vocab.special_token_trie.add(token); + vocab.special_tokens.push_back(token_id); + vocab.special_token_to_id[token] = token_id; + + if (vocab.max_special_token_length < token.size()) { + vocab.max_special_token_length = token.size(); + } + } } void read_tensor_metadata(llama_load_tensors_map & tensors_map) { while (file.tell() < file.size) { @@ -631,6 +655,7 @@ struct llama_file_saver { void write_hparams(enum llama_ftype new_ftype) { const llama_hparams & hparams = any_file_loader->hparams; file.write_u32(hparams.n_vocab); + file.write_u32(hparams.n_vocab_sp); file.write_u32(hparams.n_embd); file.write_u32(hparams.n_mult); file.write_u32(hparams.n_head); @@ -649,6 +674,10 @@ struct llama_file_saver { file.write_raw(token_score.tok.data(), token_score.tok.size()); file.write_raw(&token_score.score, sizeof(token_score.score)); } + uint32_t n_vocab_sp = any_file_loader->hparams.n_vocab_sp; + for (uint32_t i = 0; i < n_vocab; i++) { + file.write_u32(any_file_loader->vocab.special_tokens[i]); + } } void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) { switch (new_type) { @@ -975,7 +1004,8 @@ static const char *llama_file_version_name(llama_file_version version) { case LLAMA_FILE_VERSION_GGMF_V1: return "ggmf v1 (old version with no mmap support)"; case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (pre #1405)"; case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (pre #1508)"; - case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (latest)"; + case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (pre #1931)"; + case LLAMA_FILE_VERSION_GGJT_V4: return "ggjt v4 (latest)"; } return "unknown"; @@ -1960,18 +1990,20 @@ struct llama_sp_bigram { struct llama_tokenizer { llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {} - void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) { + void tokenize(const char * text, size_t len, std::vector<llama_vocab::id> & output) { + 
symbols_.clear(); + // split string into utf8 chars int index = 0; size_t offs = 0; - while (offs < text.size()) { + while (offs < len) { llama_sp_symbol sym; - size_t char_len = std::min(text.size() - offs, utf8_len(text[offs])); - sym.text = text.c_str() + offs; + size_t char_len = std::min(len - offs, utf8_len(text[offs])); + sym.text = text + offs; sym.n = char_len; offs += char_len; sym.prev = index - 1; - sym.next = offs == text.size() ? -1 : index + 1; + sym.next = offs == len ? -1 : index + 1; index++; symbols_.emplace_back(sym); } @@ -2074,7 +2106,33 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co output.push_back(llama_token_bos()); } - tokenizer.tokenize(text, output); + if (vocab.special_token_to_id.empty()) { + tokenizer.tokenize(text.c_str(), text.size(), output); + return output; + } + + auto offsets = vocab.special_token_trie.split(text); + int start = 0; + for (int end : offsets) { + if (start >= end) { + continue; + } + + size_t part_length = end - start; + //printf("\"%.*s\"\n", (int) part_length, text.c_str() + start); + + if (vocab.max_special_token_length < part_length) { + tokenizer.tokenize(text.c_str() + start, part_length, output); + } else { + auto token_it = vocab.special_token_to_id.find(std::string(text.c_str() + start, part_length)); + if (token_it != vocab.special_token_to_id.end()) { + output.push_back(token_it->second); + } else { + tokenizer.tokenize(text.c_str() + start, part_length, output); + } + } + start = end; + } return output; } @@ -4212,6 +4270,10 @@ llama_token llama_token_nl() { return 13; } +bool llama_is_special_token(const struct llama_context *ctx, llama_token token) { + return std::find(ctx->vocab.special_tokens.begin(), ctx->vocab.special_tokens.end(), token) != ctx->vocab.special_tokens.end(); +} + struct llama_timings llama_get_timings(struct llama_context * ctx) { struct llama_timings result = { /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, diff --git a/llama.h b/llama.h index fa1977f2d9492..9ece944d9c566 100644 --- a/llama.h +++ b/llama.h @@ -40,7 +40,7 @@ #define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml' #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' -#define LLAMA_FILE_VERSION 3 +#define LLAMA_FILE_VERSION 4 #define LLAMA_FILE_MAGIC LLAMA_FILE_MAGIC_GGJT #define LLAMA_FILE_MAGIC_UNVERSIONED LLAMA_FILE_MAGIC_GGML #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN @@ -373,6 +373,8 @@ extern "C" { LLAMA_API llama_token llama_token_eos(); // end-of-sentence LLAMA_API llama_token llama_token_nl(); // next-line + LLAMA_API bool llama_is_special_token(const struct llama_context * ctx, llama_token token); + // Grammar // LLAMA_API struct llama_grammar * llama_grammar_init( From 0c14627438073e85dbc5d7e41cb8203b61b37b35 Mon Sep 17 00:00:00 2001 From: Igor Pissolati <igo08an@hotmail.com> Date: Mon, 19 Jun 2023 14:52:57 -0300 Subject: [PATCH 02/13] Code cleanup --- llama.cpp | 36 ++++++++++++++---------------------- llama.h | 2 -- 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/llama.cpp b/llama.cpp index 9908065ee8ce6..d7e0b3174f445 100644 --- a/llama.cpp +++ b/llama.cpp @@ -281,7 +281,6 @@ struct llama_vocab { llama_trie special_token_trie; std::unordered_map<token, id> special_token_to_id; - std::vector<id> special_tokens; size_t max_special_token_length; }; @@ -580,14 +579,13 @@ struct llama_file_loader { for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) { uint32_t token_id = file.read_u32(); - const auto & token = vocab.id_to_token[token_id].tok; + const auto & word = 
vocab.id_to_token[token_id].tok; - vocab.special_token_trie.add(token); - vocab.special_tokens.push_back(token_id); - vocab.special_token_to_id[token] = token_id; + vocab.special_token_trie.add(word); + vocab.special_token_to_id[word] = token_id; - if (vocab.max_special_token_length < token.size()) { - vocab.max_special_token_length = token.size(); + if (vocab.max_special_token_length < word.size()) { + vocab.max_special_token_length = word.size(); } } } @@ -674,9 +672,8 @@ struct llama_file_saver { file.write_raw(token_score.tok.data(), token_score.tok.size()); file.write_raw(&token_score.score, sizeof(token_score.score)); } - uint32_t n_vocab_sp = any_file_loader->hparams.n_vocab_sp; - for (uint32_t i = 0; i < n_vocab; i++) { - file.write_u32(any_file_loader->vocab.special_tokens[i]); + for (const auto & pair : any_file_loader->vocab.special_token_to_id) { + file.write_u32(pair.second); } } void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) { @@ -2111,24 +2108,23 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co return output; } - auto offsets = vocab.special_token_trie.split(text); + std::vector<int> offsets = vocab.special_token_trie.split(text); int start = 0; for (int end : offsets) { if (start >= end) { continue; } - size_t part_length = end - start; - //printf("\"%.*s\"\n", (int) part_length, text.c_str() + start); - - if (vocab.max_special_token_length < part_length) { - tokenizer.tokenize(text.c_str() + start, part_length, output); + const char *part = text.c_str() + start; + size_t part_len = end - start; + if (vocab.max_special_token_length < part_len) { + tokenizer.tokenize(part, part_len, output); } else { - auto token_it = vocab.special_token_to_id.find(std::string(text.c_str() + start, part_length)); + auto token_it = vocab.special_token_to_id.find(std::string(part, part_len)); if (token_it != vocab.special_token_to_id.end()) { output.push_back(token_it->second); } else { - tokenizer.tokenize(text.c_str() + start, part_length, output); + tokenizer.tokenize(part, part_len, output); } } start = end; @@ -4270,10 +4266,6 @@ llama_token llama_token_nl() { return 13; } -bool llama_is_special_token(const struct llama_context *ctx, llama_token token) { - return std::find(ctx->vocab.special_tokens.begin(), ctx->vocab.special_tokens.end(), token) != ctx->vocab.special_tokens.end(); -} - struct llama_timings llama_get_timings(struct llama_context * ctx) { struct llama_timings result = { /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, diff --git a/llama.h b/llama.h index 9ece944d9c566..40d0737a2a6d8 100644 --- a/llama.h +++ b/llama.h @@ -373,8 +373,6 @@ extern "C" { LLAMA_API llama_token llama_token_eos(); // end-of-sentence LLAMA_API llama_token llama_token_nl(); // next-line - LLAMA_API bool llama_is_special_token(const struct llama_context * ctx, llama_token token); - // Grammar // LLAMA_API struct llama_grammar * llama_grammar_init( From 7f9d720105270eb362b3120b39d3ffb2bf41ce11 Mon Sep 17 00:00:00 2001 From: Igor Pissolati <igo08an@hotmail.com> Date: Mon, 19 Jun 2023 16:00:13 -0300 Subject: [PATCH 03/13] Better loading of special tokens from jsons --- convert.py | 37 ++++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/convert.py b/convert.py index 8bc06120dc84e..7a35a99a2ae65 100644 --- a/convert.py +++ b/convert.py @@ -243,7 +243,7 @@ def load(model_plus: 'ModelPlus') -> 'Params': class SentencePieceVocab: - def __init__(self, fname_tokenizer: 
Path, fname_added_tokens: Optional[Path], fname_special_tokens: Optional[Path], vocabtype: Optional[str]) -> None: + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fname_special_tokens: Optional[Path], fname_tokenizer_config: Optional[Path], vocabtype: Optional[str]) -> None: self.vocabtype = vocabtype if self.vocabtype == "bpe": self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read()) @@ -268,13 +268,40 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fn self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list) self.fname_tokenizer = fname_tokenizer self.fname_added_tokens = fname_added_tokens - special_tokens: Dict[str, Dict[str, Any]] + self.special_tokens_map: Dict[int, str] = {} + + TOKEN_NAME_TO_ID: Dict[str, int] = { + "unk_token": self.sentencepiece_tokenizer.unk_id(), + "bos_token": self.sentencepiece_tokenizer.bos_id(), + "eos_token": self.sentencepiece_tokenizer.eos_id(), + "pad_token": self.sentencepiece_tokenizer.pad_id() + } + + tokenizer_config: Dict[str, Any] + if fname_tokenizer_config is not None: + tokenizer_config = json.load(open(fname_tokenizer_config)) + else: + tokenizer_config = {} + for key, value in tokenizer_config.items(): + assert isinstance(value, dict) or isinstance(value, str) + if key not in TOKEN_NAME_TO_ID or TOKEN_NAME_TO_ID[key] == -1: + continue + self.special_tokens_map[TOKEN_NAME_TO_ID[key]] = value["content"] if isinstance(value, dict) else value + + special_tokens: Dict[str, Any] if fname_special_tokens is not None: special_tokens = json.load(open(fname_special_tokens)) else: special_tokens = {} - token_name_to_id = {"unk_token": self.sentencepiece_tokenizer.unk_id(), "bos_token": self.sentencepiece_tokenizer.bos_id(), "eos_token": self.sentencepiece_tokenizer.eos_id(), "pad_token": self.sentencepiece_tokenizer.pad_id()} - self.special_tokens_map = {token_name_to_id[token_name]: info["content"] if isinstance(info, dict) else info for token_name, info in special_tokens.items() if token_name in token_name_to_id and token_name_to_id[token_name] != -1} + for key, value in special_tokens.items(): + assert isinstance(value, dict) or isinstance(value, str) + if key not in TOKEN_NAME_TO_ID: + continue + token_id = TOKEN_NAME_TO_ID[key] + if token_id == -1 or token_id in self.special_tokens_map: + continue + self.special_tokens_map[token_id] = value["content"] if isinstance(value, dict) else value + self.vocab_special_size: int = len(self.added_tokens_list) + len(self.special_tokens_map) def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]: @@ -1282,7 +1309,7 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab: special_tokens_path = path.parent / "special_tokens_map.json" tokenizer_config_path = path.parent / "tokenizer_config.json" print(f"Loading vocab file {path}") - return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, special_tokens_path if special_tokens_path.exists() else tokenizer_config_path if tokenizer_config_path.exists() else None, + return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, special_tokens_path if special_tokens_path.exists() else None, tokenizer_config_path if tokenizer_config_path.exists() else None, vocabtype) From e468e7551563120fa448fa40c558e31b5170b091 Mon Sep 17 00:00:00 2001 From: Igor Pissolati <igo08an@hotmail.com> Date: Mon, 19 Jun 2023 23:03:58 -0300 Subject: [PATCH 04/13] Remove trailing whitespaces --- 
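(Context for the JSON loading in patch 03 above: both side files sit next to tokenizer.model, and only the unk_token, bos_token, eos_token and pad_token keys are consulted. A value may be a bare string or a dict carrying a "content" field, which is why both shapes are handled. Representative, hypothetical contents:)

    special_tokens_map.json:
        {"bos_token": "<s>", "eos_token": {"content": "</s>"}, "unk_token": "<unk>"}

    tokenizer_config.json:
        {"pad_token": "[PAD]", "model_max_length": 2048}

(Keys other than those four, such as model_max_length here, are ignored. tokenizer_config.json is applied first and special_tokens_map.json only fills ids still missing; anything resolving to sentencepiece id -1 is skipped.)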
llama-util.h | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/llama-util.h b/llama-util.h index a3ec8a501f413..9c38bddd0aacb 100644 --- a/llama-util.h +++ b/llama-util.h @@ -545,7 +545,7 @@ typedef llama_buffer llama_ctx_buffer; struct llama_trie_node { llama_trie_node(): is_terminator(false) {} - + std::unordered_map<char, llama_trie_node*> children; bool is_terminator; }; @@ -560,7 +560,7 @@ struct llama_trie { if (word.empty()) { return; } - + llama_trie_node *ref = root_; for (char c : word) { if (ref->children.find(c) == ref->children.end()) { @@ -630,7 +630,7 @@ struct llama_trie { end = lookahead_index; skip = lookahead_index; } - + auto looktrie_pointer_it = looktrie_pointer->children.find(next_char); while (looktrie_pointer_it != looktrie_pointer->children.end()) { looktrie_pointer = looktrie_pointer_it->second; @@ -640,7 +640,7 @@ struct llama_trie { end = lookahead_index; skip = lookahead_index; } - + if (lookahead_index == text.size()) { // End of string break; @@ -649,13 +649,13 @@ struct llama_trie { looktrie_pointer_it = looktrie_pointer->children.find(next_char); } } - + offsets.push_back(start); offsets.push_back(end); reset = true; break; - } - + } + auto trie_pointer_it = trie_pointer->children.find(current_char); if (trie_pointer_it != trie_pointer->children.end()) { // The current character being looked at has a match within the trie @@ -669,12 +669,12 @@ struct llama_trie { state = states.erase(state); } } - + if (reset) { // Clear the full start (we found a real match) states.clear(); } - + // If this character is a starting character within the trie // start keeping track of this partial match. auto children_it = root_->children.find(current_char); @@ -682,7 +682,7 @@ struct llama_trie { states[current] = children_it->second; } } - + // We have a cut at the end with states. for (const auto & state : states) { int start = state.first; @@ -696,7 +696,7 @@ struct llama_trie { break; } } - + offsets.push_back(text.size()); return offsets; } From ca1fc20508c761e680fd22b2742077055a1de79e Mon Sep 17 00:00:00 2001 From: Igor Pissolati <igo08an@hotmail.com> Date: Tue, 20 Jun 2023 01:27:36 -0300 Subject: [PATCH 05/13] Fix issues revealed by CI --- llama-util.h | 39 ++++++++++++++++++++------------------- llama.cpp | 10 +++++----- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/llama-util.h b/llama-util.h index 9c38bddd0aacb..30a6c0eb53eac 100644 --- a/llama-util.h +++ b/llama-util.h @@ -16,6 +16,7 @@ #include <vector> #include <map> #include <unordered_map> +#include <memory> #include <stdexcept> #ifdef __has_include @@ -546,7 +547,7 @@ typedef llama_buffer llama_ctx_buffer; struct llama_trie_node { llama_trie_node(): is_terminator(false) {} - std::unordered_map<char, llama_trie_node*> children; + std::unordered_map<char, std::unique_ptr<llama_trie_node>> children; bool is_terminator; }; @@ -561,24 +562,24 @@ struct llama_trie { return; } - llama_trie_node *ref = root_; + llama_trie_node *ref = root_.get(); for (char c : word) { if (ref->children.find(c) == ref->children.end()) { - ref->children[c] = new llama_trie_node(); + ref->children[c].reset(new llama_trie_node()); } - ref = ref->children[c]; + ref = ref->children[c].get(); } ref->is_terminator = true; } // Will look for the words added to the trie within `text`. Output is the boundaries of the words found. // Note that this trie will match the longest possible word first! 
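(The int to size_t moves in this hunk look like -Wsign-compare fixes: comparing a signed index against std::string::size() warns, and a CI build with warnings-as-errors turns that into a failure. The exact CI flags are an assumption; a minimal reproducer:)

    // sign_compare.cpp - rejected by: g++ -Wall -Werror -c sign_compare.cpp
    #include <string>

    int count_bytes(const std::string & text) {
        int n = 0;
        for (int i = 0; i < text.size(); i++) { // int vs size_t comparison warns here
            n++;
        }
        return n;
    }

(Switching the loop counters, the offsets and skip to size_t silences this without changing behavior.)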
- std::vector<int> split(const std::string & text) const { - std::map<int, llama_trie_node*> states; - std::vector<int> offsets{0}; + std::vector<size_t> split(const std::string & text) const { + std::map<size_t, llama_trie_node*> states; + std::vector<size_t> offsets{0}; - int skip = 0; - for (int current = 0; current < text.size(); current++) { + size_t skip = 0; + for (size_t current = 0; current < text.size(); current++) { char current_char = text[current]; if (skip > 0 && current < skip) { // Prevents the lookahead for matching twice @@ -592,7 +593,7 @@ struct llama_trie { // In this case, we already have partial matches (But unfinished) for (auto state = states.begin(); state != states.end(); ) { - int start = state->first; + size_t start = state->first; llama_trie_node *trie_pointer = state->second; if (trie_pointer->is_terminator) { // This is a final match, we need to reset and @@ -603,11 +604,11 @@ struct llama_trie { // Here we are also actively looking for other earlier partial // matches // "[CLS]", "L", we need to match CLS even if L is special - int end = 0; + size_t end = 0; for (const auto & look : states) { - int lookstart = look.first; + size_t lookstart = look.first; llama_trie_node *looktrie_pointer = look.second; - int lookahead_index = 0; + size_t lookahead_index = 0; if (lookstart > start) { // This partial match is later, we can stop looking break; @@ -633,7 +634,7 @@ struct llama_trie { auto looktrie_pointer_it = looktrie_pointer->children.find(next_char); while (looktrie_pointer_it != looktrie_pointer->children.end()) { - looktrie_pointer = looktrie_pointer_it->second; + looktrie_pointer = looktrie_pointer_it->second.get(); lookahead_index++; if (looktrie_pointer->is_terminator) { start = lookstart; @@ -660,7 +661,7 @@ struct llama_trie { if (trie_pointer_it != trie_pointer->children.end()) { // The current character being looked at has a match within the trie // update the pointer (it will be stored back into states later). - trie_pointer = trie_pointer_it->second; + trie_pointer = trie_pointer_it->second.get(); states[start] = trie_pointer; ++state; } else { @@ -679,18 +680,18 @@ struct llama_trie { // start keeping track of this partial match. auto children_it = root_->children.find(current_char); if (current >= skip && children_it != root_->children.end()) { - states[current] = children_it->second; + states[current] = children_it->second.get(); } } // We have a cut at the end with states. for (const auto & state : states) { - int start = state.first; + size_t start = state.first; llama_trie_node *trie_pointer = state.second; if (trie_pointer->is_terminator) { // This is a final match, we need to reset and // store the results in `offsets`. 
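(The unique_ptr changes threaded through this hunk address ownership: the nodes were previously allocated with bare new and never freed, so the trie leaked its whole subtree, presumably what CI's leak checking flagged. The idiom in isolation, as a sketch rather than the patch's code:)

    // owned_children.cpp - subtree lifetime tied to the parent map
    #include <memory>
    #include <unordered_map>

    struct node {
        std::unordered_map<char, std::unique_ptr<node>> children;
        bool is_terminator = false;
    };

    int main() {
        node root;
        root.children['a'].reset(new node());  // same idiom as llama_trie_node
        root.children['a']->is_terminator = true;
        return 0;
    }  // whole subtree destroyed here, no manual delete needed

(The raw llama_trie_node* values that split() keeps in its states map remain non-owning views, which is safe because the trie outlives every call.)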
- int end = text.size(); + size_t end = text.size(); offsets.push_back(start); offsets.push_back(end); break; @@ -702,7 +703,7 @@ struct llama_trie { } private: - llama_trie_node *root_; + std::unique_ptr<llama_trie_node> root_; }; #endif diff --git a/llama.cpp b/llama.cpp index d7e0b3174f445..af12931e097da 100644 --- a/llama.cpp +++ b/llama.cpp @@ -281,7 +281,7 @@ struct llama_vocab { llama_trie special_token_trie; std::unordered_map<token, id> special_token_to_id; - size_t max_special_token_length; + size_t max_special_token_length = 0; }; struct llama_model { @@ -578,7 +578,7 @@ struct llama_file_loader { vocab.special_token_to_id.reserve(hparams.n_vocab_sp); for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) { - uint32_t token_id = file.read_u32(); + llama_vocab::id token_id = file.read_u32(); const auto & word = vocab.id_to_token[token_id].tok; vocab.special_token_trie.add(word); @@ -2108,9 +2108,9 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co return output; } - std::vector<int> offsets = vocab.special_token_trie.split(text); - int start = 0; - for (int end : offsets) { + std::vector<size_t> offsets = vocab.special_token_trie.split(text); + size_t start = 0; + for (size_t end : offsets) { if (start >= end) { continue; } From 41a2ed03e7c335f3c573cac118fc07e8e306b6f2 Mon Sep 17 00:00:00 2001 From: Igor Pissolati <igo08an@hotmail.com> Date: Tue, 20 Jun 2023 19:20:53 -0300 Subject: [PATCH 06/13] Ignore unusable json values --- convert.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/convert.py b/convert.py index 7a35a99a2ae65..250659248df9a 100644 --- a/convert.py +++ b/convert.py @@ -283,10 +283,12 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fn else: tokenizer_config = {} for key, value in tokenizer_config.items(): - assert isinstance(value, dict) or isinstance(value, str) - if key not in TOKEN_NAME_TO_ID or TOKEN_NAME_TO_ID[key] == -1: + if not isinstance(value, dict) or not isinstance(value, str): continue - self.special_tokens_map[TOKEN_NAME_TO_ID[key]] = value["content"] if isinstance(value, dict) else value + token_id = TOKEN_NAME_TO_ID.get(key, -1) + if token_id == -1: + continue + self.special_tokens_map[token_id] = value["content"] if isinstance(value, dict) else value special_tokens: Dict[str, Any] if fname_special_tokens is not None: @@ -294,10 +296,9 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fn else: special_tokens = {} for key, value in special_tokens.items(): - assert isinstance(value, dict) or isinstance(value, str) - if key not in TOKEN_NAME_TO_ID: + if not isinstance(value, dict) or not isinstance(value, str): continue - token_id = TOKEN_NAME_TO_ID[key] + token_id = TOKEN_NAME_TO_ID.get(key, -1) if token_id == -1 or token_id in self.special_tokens_map: continue self.special_tokens_map[token_id] = value["content"] if isinstance(value, dict) else value From f6d5fe3afc1ff5992f1b0b27ecec6b2bb91515e5 Mon Sep 17 00:00:00 2001 From: Igor Pissolati <igo08an@hotmail.com> Date: Thu, 22 Jun 2023 11:29:51 -0300 Subject: [PATCH 07/13] Use some tricks to eliminate the necessity for a new format --- convert.py | 32 +++++++++++++++---------------- llama.cpp | 55 ++++++++++++++++++++++++++---------------------------- llama.h | 2 +- 3 files changed, 43 insertions(+), 46 deletions(-) diff --git a/convert.py b/convert.py index 250659248df9a..4748e262b8d04 100644 --- a/convert.py +++ b/convert.py @@ -142,7 +142,7 @@ def find_n_mult(n_ff: int, 
n_embd: int) -> int: @dataclass class Params: n_vocab: int - n_vocab_sp:int + n_vocab_base: int n_embd: int n_mult: int n_head: int @@ -170,7 +170,7 @@ def guessed(model: 'LazyModel') -> 'Params': return Params( n_vocab = n_vocab, - n_vocab_sp= n_vocab, + n_vocab_base=n_vocab, n_embd = n_embd, n_mult = 256, n_head = n_head, @@ -193,7 +193,7 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params': return Params( n_vocab = n_vocab, - n_vocab_sp= n_vocab, + n_vocab_base=n_vocab, n_embd = n_embd, n_mult = n_mult, n_head = n_head, @@ -218,7 +218,7 @@ def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params': return Params( n_vocab = n_vocab, - n_vocab_sp= n_vocab + n_vocab_base=n_vocab, n_embd = n_embd, n_mult = n_mult, n_head = n_head, @@ -283,7 +283,7 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fn else: tokenizer_config = {} for key, value in tokenizer_config.items(): - if not isinstance(value, dict) or not isinstance(value, str): + if not isinstance(value, dict) and not isinstance(value, str): continue token_id = TOKEN_NAME_TO_ID.get(key, -1) if token_id == -1: @@ -296,15 +296,13 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fn else: special_tokens = {} for key, value in special_tokens.items(): - if not isinstance(value, dict) or not isinstance(value, str): + if not isinstance(value, dict) and not isinstance(value, str): continue token_id = TOKEN_NAME_TO_ID.get(key, -1) if token_id == -1 or token_id in self.special_tokens_map: continue self.special_tokens_map[token_id] = value["content"] if isinstance(value, dict) else value - self.vocab_special_size: int = len(self.added_tokens_list) + len(self.special_tokens_map) - def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]: tokenizer = self.sentencepiece_tokenizer if self.vocabtype == "bpe": @@ -361,7 +359,7 @@ def __init__(self, tokens: List[Tuple[bytes, float]]): self.tokens = tokens self.special_tokens = [] self.vocab_size = len(tokens) - self.vocab_special_size = 0 + self.vocab_size_base = 0 def all_tokens(self) -> Iterable[Tuple[bytes, float]]: return self.tokens @@ -1120,17 +1118,21 @@ def __init__(self, fname_out: Path) -> None: def write_file_header(self, params: Params, file_type: GGMLFileType) -> None: self.fout.write(b"ggjt"[::-1]) # magic values = [ - 4, # file version + 1, # file version params.n_vocab, - params.n_vocab_sp, params.n_embd, params.n_mult, params.n_head, params.n_layer, +<<<<<<< HEAD params.n_embd // params.n_head, # rot (obsolete) file_type.value, +======= + params.n_vocab_base | 0xF0000000, # reuse obsolete rot value to store vocab_base + params.file_type.value, +>>>>>>> bfccc62 (Use some tricks to eliminate the necessity for a new format) ] - self.fout.write(struct.pack("i" * len(values), *values)) + self.fout.write(struct.pack("I" * len(values), *values)) def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None: sname = name.encode('utf-8') @@ -1144,13 +1146,11 @@ def write_vocab(self, vocab: Vocab) -> None: self.fout.write(struct.pack("i", len(text))) self.fout.write(text) self.fout.write(struct.pack("f", score)) - for token_id in vocab.all_special_tokens(): - self.fout.write(struct.pack("i", token_id)) @staticmethod def write_vocab_only(fname_out: Path, vocab: Vocab) -> None: of = OutputFile(fname_out) - params = Params(n_vocab=vocab.vocab_size, n_vocab_sp=vocab.vocab_special_size, n_embd=0, n_mult=0, + params = Params(n_vocab=vocab.vocab_size, 
n_vocab_base=vocab.vocab_size_base, n_embd=0, n_mult=0, n_head=1, n_layer=0) of = OutputFile(fname_out) of.write_file_header(params, file_type=GGMLFileType.AllF32) @@ -1373,7 +1373,7 @@ def main(args_in: Optional[List[str]] = None) -> None: vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent vocab = load_vocab(vocab_dir, args.vocabtype) params = Params.load(model_plus) - params.n_vocab_sp = vocab.vocab_special_size + params.n_vocab_base = vocab.vocab_size_base model = model_plus.model model = do_necessary_conversions(model, params) output_type = pick_output_type(model, args.outtype) diff --git a/llama.cpp b/llama.cpp index af12931e097da..8bbe51009b039 100644 --- a/llama.cpp +++ b/llama.cpp @@ -181,14 +181,13 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT() // default hparams (LLaMA 7B) struct llama_hparams { uint32_t n_vocab = 32000; - uint32_t n_vocab_sp = 0; + uint32_t n_vocab_base = 32000; uint32_t n_ctx = 512; // this is provided as user input? uint32_t n_embd = 4096; uint32_t n_mult = 256; uint32_t n_head = 32; uint32_t n_head_kv = 32; uint32_t n_layer = 32; - uint32_t n_rot = 64; // LLaMAv2 // TODO: load from model data hparams @@ -499,7 +498,6 @@ enum llama_file_version { LLAMA_FILE_VERSION_GGJT_V1, // added padding LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format LLAMA_FILE_VERSION_GGJT_V3, // changed Q4 and Q8 quantization format - LLAMA_FILE_VERSION_GGJT_V4, // improved support for added/special tokens }; struct llama_file_loader { @@ -515,6 +513,7 @@ struct llama_file_loader { read_hparams(); read_vocab(); read_tensor_metadata(tensors_map); + set_vocab_sp(); } void read_magic() { uint32_t magic = file.read_u32(); @@ -537,7 +536,6 @@ struct llama_file_loader { case 1: file_version = LLAMA_FILE_VERSION_GGJT_V1; return; case 2: file_version = LLAMA_FILE_VERSION_GGJT_V2; return; case 3: file_version = LLAMA_FILE_VERSION_GGJT_V3; return; - case 4: file_version = LLAMA_FILE_VERSION_GGJT_V4; return; } } @@ -546,18 +544,18 @@ struct llama_file_loader { } void read_hparams() { hparams.n_vocab = file.read_u32(); - hparams.n_vocab_sp = file_version >= LLAMA_FILE_VERSION_GGJT_V4 ? file.read_u32() : 0; hparams.n_embd = file.read_u32(); hparams.n_mult = file.read_u32(); hparams.n_head = file.read_u32(); hparams.n_layer = file.read_u32(); - hparams.n_rot = file.read_u32(); + hparams.n_vocab_base = file.read_u32(); + hparams.n_vocab_base = (hparams.n_vocab_base & 0xF0000000) == 0 ? 
hparams.n_vocab : (hparams.n_vocab_base & ~0xF0000000); // this bitwise operation is necessary for compatibility with older models hparams.ftype = (enum llama_ftype) file.read_u32(); // LLaMAv2 // TODO: read from header hparams.n_head_kv = hparams.n_head; - } +======= void read_vocab() { vocab.id_to_token.resize(hparams.n_vocab); @@ -574,20 +572,6 @@ struct llama_file_loader { tok_score.tok = std::move(word); tok_score.score = score; } - - vocab.special_token_to_id.reserve(hparams.n_vocab_sp); - - for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) { - llama_vocab::id token_id = file.read_u32(); - const auto & word = vocab.id_to_token[token_id].tok; - - vocab.special_token_trie.add(word); - vocab.special_token_to_id[word] = token_id; - - if (vocab.max_special_token_length < word.size()) { - vocab.max_special_token_length = word.size(); - } - } } void read_tensor_metadata(llama_load_tensors_map & tensors_map) { while (file.tell() < file.size) { @@ -634,6 +618,24 @@ struct llama_file_loader { tensors_map.name_to_idx[name] = tensors_map.tensors.size() - 1; } } + void set_vocab_sp() { + uint32_t vocab_sp = 3 + hparams.n_vocab - hparams.n_vocab_base; + vocab.special_token_to_id.reserve(vocab_sp); + for (uint32_t i = 0; i < vocab_sp; i++) { + llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i : i; + const auto & word = vocab.id_to_token[token_id].tok; + if (word.empty()) { + continue; + } + + vocab.special_token_trie.add(word); + vocab.special_token_to_id[word] = token_id; + + if (vocab.max_special_token_length < word.size()) { + vocab.max_special_token_length = word.size(); + } + } + } }; struct llama_file_saver { @@ -653,12 +655,11 @@ struct llama_file_saver { void write_hparams(enum llama_ftype new_ftype) { const llama_hparams & hparams = any_file_loader->hparams; file.write_u32(hparams.n_vocab); - file.write_u32(hparams.n_vocab_sp); file.write_u32(hparams.n_embd); file.write_u32(hparams.n_mult); file.write_u32(hparams.n_head); file.write_u32(hparams.n_layer); - file.write_u32(hparams.n_rot); + file.write_u32(hparams.n_vocab_base | 0xF0000000); // this bitwise operation is necessary for compatibility with older models file.write_u32(new_ftype); } void write_vocab() { @@ -672,9 +673,6 @@ struct llama_file_saver { file.write_raw(token_score.tok.data(), token_score.tok.size()); file.write_raw(&token_score.score, sizeof(token_score.score)); } - for (const auto & pair : any_file_loader->vocab.special_token_to_id) { - file.write_u32(pair.second); - } } void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) { switch (new_type) { @@ -1001,8 +999,7 @@ static const char *llama_file_version_name(llama_file_version version) { case LLAMA_FILE_VERSION_GGMF_V1: return "ggmf v1 (old version with no mmap support)"; case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (pre #1405)"; case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (pre #1508)"; - case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (pre #1931)"; - case LLAMA_FILE_VERSION_GGJT_V4: return "ggjt v4 (latest)"; + case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (latest)"; } return "unknown"; @@ -1127,7 +1124,7 @@ static void llama_model_load_internal( fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); - fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. 
n_embd_head, n_head_dim + fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_embd/hparams.n_head); // a.k.a. n_embd_head, n_head_dim fprintf(stderr, "%s: n_gqa = %u\n", __func__, hparams.n_gqa()); fprintf(stderr, "%s: rnorm_eps = %.1e\n", __func__, hparams.f_rms_norm_eps); fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); diff --git a/llama.h b/llama.h index 40d0737a2a6d8..fa1977f2d9492 100644 --- a/llama.h +++ b/llama.h @@ -40,7 +40,7 @@ #define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml' #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' -#define LLAMA_FILE_VERSION 4 +#define LLAMA_FILE_VERSION 3 #define LLAMA_FILE_MAGIC LLAMA_FILE_MAGIC_GGJT #define LLAMA_FILE_MAGIC_UNVERSIONED LLAMA_FILE_MAGIC_GGML #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN From 099119f5329a32fb97a9980cbbae826cc9f9348c Mon Sep 17 00:00:00 2001 From: Igor Pissolati <igo08an@hotmail.com> Date: Mon, 7 Aug 2023 12:59:11 -0300 Subject: [PATCH 08/13] Fixes to rebase --- convert.py | 7 +------ llama.cpp | 4 ++-- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/convert.py b/convert.py index 4748e262b8d04..54dba5979cb38 100644 --- a/convert.py +++ b/convert.py @@ -1124,13 +1124,8 @@ def write_file_header(self, params: Params, file_type: GGMLFileType) -> None: params.n_mult, params.n_head, params.n_layer, -<<<<<<< HEAD - params.n_embd // params.n_head, # rot (obsolete) - file_type.value, -======= params.n_vocab_base | 0xF0000000, # reuse obsolete rot value to store vocab_base - params.file_type.value, ->>>>>>> bfccc62 (Use some tricks to eliminate the necessity for a new format) + file_type.value, ] self.fout.write(struct.pack("I" * len(values), *values)) diff --git a/llama.cpp b/llama.cpp index 8bbe51009b039..c620d98971df9 100644 --- a/llama.cpp +++ b/llama.cpp @@ -555,7 +555,7 @@ struct llama_file_loader { // LLaMAv2 // TODO: read from header hparams.n_head_kv = hparams.n_head; -======= + } void read_vocab() { vocab.id_to_token.resize(hparams.n_vocab); @@ -1442,7 +1442,7 @@ static struct ggml_cgraph * llama_build_graph( const int64_t n_embd_head = hparams.n_embd_head(); const int64_t n_embd_gqa = hparams.n_embd_gqa(); - LLAMA_ASSERT(n_embd_head == hparams.n_rot); + LLAMA_ASSERT(n_embd_head == hparams.n_embd/hparams.n_head); const float freq_base = hparams.rope_freq_base; const float freq_scale = hparams.rope_freq_scale; From d9791bb48b5b5ba819f6c97f0a8c0a2d646961dd Mon Sep 17 00:00:00 2001 From: Igor Pissolati <igo08an@hotmail.com> Date: Mon, 7 Aug 2023 17:30:12 -0300 Subject: [PATCH 09/13] Add C API for adding special tokens --- llama.cpp | 24 +++++++++++++++--------- llama.h | 5 +++++ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/llama.cpp b/llama.cpp index c620d98971df9..3a50090f82881 100644 --- a/llama.cpp +++ b/llama.cpp @@ -281,6 +281,15 @@ struct llama_vocab { llama_trie special_token_trie; std::unordered_map<token, id> special_token_to_id; size_t max_special_token_length = 0; + + void add_special_token(const token & word, id token_id) { + special_token_trie.add(word); + special_token_to_id[word] = token_id; + + if (max_special_token_length < word.size()) { + max_special_token_length = word.size(); + } + } }; struct llama_model { @@ -624,15 +633,8 @@ struct llama_file_loader { for (uint32_t i = 0; i < vocab_sp; i++) { llama_vocab::id token_id = i > 2 ? 
hparams.n_vocab_base + i : i; const auto & word = vocab.id_to_token[token_id].tok; - if (word.empty()) { - continue; - } - - vocab.special_token_trie.add(word); - vocab.special_token_to_id[word] = token_id; - - if (vocab.max_special_token_length < word.size()) { - vocab.max_special_token_length = word.size(); + if (!word.empty()) { + vocab.add_special_token(word, token_id); } } } @@ -4263,6 +4265,10 @@ llama_token llama_token_nl() { return 13; } +void llama_add_special_token(struct llama_model * model, const char * token, llama_token token_id) { + model->vocab.add_special_token(token, token_id); +} + struct llama_timings llama_get_timings(struct llama_context * ctx) { struct llama_timings result = { /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, diff --git a/llama.h b/llama.h index fa1977f2d9492..519ee716d0e63 100644 --- a/llama.h +++ b/llama.h @@ -373,6 +373,11 @@ extern "C" { LLAMA_API llama_token llama_token_eos(); // end-of-sentence LLAMA_API llama_token llama_token_nl(); // next-line + LLAMA_API void llama_add_special_token( + struct llama_model * model, + const char * token, + llama_token token_id); + // Grammar // LLAMA_API struct llama_grammar * llama_grammar_init( From 6f7dabab441566078446ef868e573cd309fe62be Mon Sep 17 00:00:00 2001 From: Igor Pissolati <igo08an@hotmail.com> Date: Mon, 7 Aug 2023 17:31:13 -0300 Subject: [PATCH 10/13] Add simple test for special tokens --- tests/test-tokenizer-0.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index 87fde16453d25..c7aeb31a5689e 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -14,6 +14,8 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests() { " this is 🦙.cpp", { 1, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223, }, }, { "w048 7tuijk dsdfhu", { 1, 29893, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, }, { "нещо на Български", { 1, 821, 4851, 665, 1386, 29713, 1305, }, }, + { "<🦙>test extra_id_1 test", { 1, 32003, 1688, 29871, 32001, 259, 1243, }, }, + { "<🦙>test extra_id_100 test", { 1, 32003, 1688, 29871, 32002, 1243, }, }, }; return _k_tests; }; @@ -46,6 +48,10 @@ int main(int argc, char **argv) { return 1; } + llama_add_special_token(model, "extra_id_1", 32001); + llama_add_special_token(model, "extra_id_100", 32002); + llama_add_special_token(model, "<🦙>", 32003); + ctx = llama_new_context_with_model(model, lparams); if (ctx == NULL) { From 4fc3776ceb283f86a08dd5ea9ecd0353df8a7db3 Mon Sep 17 00:00:00 2001 From: Igor Pissolati <igo08an@hotmail.com> Date: Mon, 7 Aug 2023 18:30:24 -0300 Subject: [PATCH 11/13] Add another test case --- tests/test-tokenizer-0.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index c7aeb31a5689e..3472180343c24 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -14,8 +14,9 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests() { " this is 🦙.cpp", { 1, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223, }, }, { "w048 7tuijk dsdfhu", { 1, 29893, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, }, { "нещо на Български", { 1, 821, 4851, 665, 1386, 29713, 1305, }, }, - { "<🦙>test extra_id_1 test", { 1, 32003, 1688, 29871, 32001, 259, 1243, }, }, - { "<🦙>test extra_id_100 test", { 1, 32003, 1688, 29871, 32002, 1243, }, }, + { "<🦙>test extra_id_1 test", { 1, 32004, 1688, 29871, 32001, 259, 1243, }, }, + { "<🦙>test extra_id_100 
test", { 1, 32004, 1688, 29871, 32002, 1243, }, }, + { "<🦙>test extra_id_200 test", { 1, 32004, 1688, 321, 32003, 1243, }, }, }; return _k_tests; }; @@ -50,7 +51,8 @@ int main(int argc, char **argv) { llama_add_special_token(model, "extra_id_1", 32001); llama_add_special_token(model, "extra_id_100", 32002); - llama_add_special_token(model, "<🦙>", 32003); + llama_add_special_token(model, "xtra_id_200", 32003); + llama_add_special_token(model, "<🦙>", 32004); ctx = llama_new_context_with_model(model, lparams); From ada6cce40fb82b90d6d860653623718082760d6f Mon Sep 17 00:00:00 2001 From: Igor Pissolati <igo08an@hotmail.com> Date: Tue, 8 Aug 2023 11:43:29 -0300 Subject: [PATCH 12/13] Replace trie with linear search --- llama-util.h | 165 --------------------------------------------------- llama.cpp | 36 ++++++++++- 2 files changed, 33 insertions(+), 168 deletions(-) diff --git a/llama-util.h b/llama-util.h index 30a6c0eb53eac..6e9e39ddb6f58 100644 --- a/llama-util.h +++ b/llama-util.h @@ -14,9 +14,6 @@ #include <string> #include <vector> -#include <map> -#include <unordered_map> -#include <memory> #include <stdexcept> #ifdef __has_include @@ -544,166 +541,4 @@ struct llama_ctx_buffer { typedef llama_buffer llama_ctx_buffer; #endif -struct llama_trie_node { - llama_trie_node(): is_terminator(false) {} - - std::unordered_map<char, std::unique_ptr<llama_trie_node>> children; - bool is_terminator; -}; - -// Trie in C++. Creates a Trie out of a list of words. The trie is used to split on multiple delimiters in one pass -// Ported from: https://github.com/huggingface/transformers/blob/ee88ae59940fd4b2c8fc119373143d7a1175c651/src/transformers/tokenization_utils.py#L52 -struct llama_trie { -public: - llama_trie(): root_(new llama_trie_node()) {} - - void add(const std::string & word) { - if (word.empty()) { - return; - } - - llama_trie_node *ref = root_.get(); - for (char c : word) { - if (ref->children.find(c) == ref->children.end()) { - ref->children[c].reset(new llama_trie_node()); - } - ref = ref->children[c].get(); - } - ref->is_terminator = true; - } - - // Will look for the words added to the trie within `text`. Output is the boundaries of the words found. - // Note that this trie will match the longest possible word first! - std::vector<size_t> split(const std::string & text) const { - std::map<size_t, llama_trie_node*> states; - std::vector<size_t> offsets{0}; - - size_t skip = 0; - for (size_t current = 0; current < text.size(); current++) { - char current_char = text[current]; - if (skip > 0 && current < skip) { - // Prevents the lookahead for matching twice - // like extra_id_100 and id_100 - continue; - } - - // Whenever we found a match, we need to drop everything - // this is a greedy algorithm, it will match on the first found token - bool reset = false; - - // In this case, we already have partial matches (But unfinished) - for (auto state = states.begin(); state != states.end(); ) { - size_t start = state->first; - llama_trie_node *trie_pointer = state->second; - if (trie_pointer->is_terminator) { - // This is a final match, we need to reset and - // store the results in `offsets`. 
- - // Lookahead to match longest first - // Important in case of extra_id_1 vs extra_id_100 - // Here we are also actively looking for other earlier partial - // matches - // "[CLS]", "L", we need to match CLS even if L is special - size_t end = 0; - for (const auto & look : states) { - size_t lookstart = look.first; - llama_trie_node *looktrie_pointer = look.second; - size_t lookahead_index = 0; - if (lookstart > start) { - // This partial match is later, we can stop looking - break; - } - if (lookstart < start) { - // This partial match is earlier, the trie pointer - // was already updated, so index is + 1 - lookahead_index = current + 1; - end = current + 1; - } else { - // Here lookstart == start and - // looktrie_pointer == trie_pointer - // It wasn't updated yet so indices are current ones - lookahead_index = current; - end = current; - } - char next_char = lookahead_index < text.size() ? text[lookahead_index] : '\0'; - if (looktrie_pointer->is_terminator) { - start = lookstart; - end = lookahead_index; - skip = lookahead_index; - } - - auto looktrie_pointer_it = looktrie_pointer->children.find(next_char); - while (looktrie_pointer_it != looktrie_pointer->children.end()) { - looktrie_pointer = looktrie_pointer_it->second.get(); - lookahead_index++; - if (looktrie_pointer->is_terminator) { - start = lookstart; - end = lookahead_index; - skip = lookahead_index; - } - - if (lookahead_index == text.size()) { - // End of string - break; - } - next_char = text[lookahead_index]; - looktrie_pointer_it = looktrie_pointer->children.find(next_char); - } - } - - offsets.push_back(start); - offsets.push_back(end); - reset = true; - break; - } - - auto trie_pointer_it = trie_pointer->children.find(current_char); - if (trie_pointer_it != trie_pointer->children.end()) { - // The current character being looked at has a match within the trie - // update the pointer (it will be stored back into states later). - trie_pointer = trie_pointer_it->second.get(); - states[start] = trie_pointer; - ++state; - } else { - // The new character has not match in the trie, we need - // to stop keeping track of this partial match. - state = states.erase(state); - } - } - - if (reset) { - // Clear the full start (we found a real match) - states.clear(); - } - - // If this character is a starting character within the trie - // start keeping track of this partial match. - auto children_it = root_->children.find(current_char); - if (current >= skip && children_it != root_->children.end()) { - states[current] = children_it->second.get(); - } - } - - // We have a cut at the end with states. - for (const auto & state : states) { - size_t start = state.first; - llama_trie_node *trie_pointer = state.second; - if (trie_pointer->is_terminator) { - // This is a final match, we need to reset and - // store the results in `offsets`. 
- size_t end = text.size(); - offsets.push_back(start); - offsets.push_back(end); - break; - } - } - - offsets.push_back(text.size()); - return offsets; - } - -private: - std::unique_ptr<llama_trie_node> root_; -}; - #endif diff --git a/llama.cpp b/llama.cpp index 3a50090f82881..3b6d23eac572c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -278,12 +278,10 @@ struct llama_vocab { std::unordered_map<token, id> token_to_id; std::vector<token_score> id_to_token; - llama_trie special_token_trie; std::unordered_map<token, id> special_token_to_id; size_t max_special_token_length = 0; void add_special_token(const token & word, id token_id) { - special_token_trie.add(word); special_token_to_id[word] = token_id; if (max_special_token_length < word.size()) { @@ -2090,6 +2088,38 @@ struct llama_tokenizer { llama_sp_bigram::queue work_queue_; }; +static std::vector<size_t> llama_split_special_tokens(const llama_vocab & vocab, const std::string & text) { + std::vector<size_t> offsets{0}; + size_t start = 0; + + while (start < text.size()) { + size_t max_end = start; + const std::string * max_delimiter = nullptr; + + for (const auto & mit : vocab.special_token_to_id) { + const std::string & delimiter = mit.first; + size_t end = start + delimiter.size(); + if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) { + if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) { + max_end = end; + max_delimiter = &delimiter; + } + } + } + + if (max_delimiter != nullptr) { + offsets.push_back(start); + offsets.push_back(max_end); + start = max_end; + } else { + start++; + } + } + + offsets.push_back(text.size()); + return offsets; +} + static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) { llama_tokenizer tokenizer(vocab); std::vector<llama_vocab::id> output; @@ -2107,7 +2137,7 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co return output; } - std::vector<size_t> offsets = vocab.special_token_trie.split(text); + std::vector<size_t> offsets = llama_split_special_tokens(vocab, text); size_t start = 0; for (size_t end : offsets) { if (start >= end) { From 465cadd44c59bad2d3037d896ae36f5396272c72 Mon Sep 17 00:00:00 2001 From: Igor Pissolati <igo08an@hotmail.com> Date: Tue, 8 Aug 2023 12:46:18 -0300 Subject: [PATCH 13/13] Refactor special tokens tokenization --- llama.cpp | 82 ++++++++++++++++++++----------------------------------- 1 file changed, 29 insertions(+), 53 deletions(-) diff --git a/llama.cpp b/llama.cpp index 3b6d23eac572c..44104be66d710 100644 --- a/llama.cpp +++ b/llama.cpp @@ -279,14 +279,9 @@ struct llama_vocab { std::vector<token_score> id_to_token; std::unordered_map<token, id> special_token_to_id; - size_t max_special_token_length = 0; void add_special_token(const token & word, id token_id) { special_token_to_id[word] = token_id; - - if (max_special_token_length < word.size()) { - max_special_token_length = word.size(); - } } }; @@ -2088,38 +2083,6 @@ struct llama_tokenizer { llama_sp_bigram::queue work_queue_; }; -static std::vector<size_t> llama_split_special_tokens(const llama_vocab & vocab, const std::string & text) { - std::vector<size_t> offsets{0}; - size_t start = 0; - - while (start < text.size()) { - size_t max_end = start; - const std::string * max_delimiter = nullptr; - - for (const auto & mit : vocab.special_token_to_id) { - const std::string & delimiter = mit.first; - size_t end = start + delimiter.size(); - if (end <= text.size() && text.compare(start, 
delimiter.size(), delimiter) == 0) { - if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) { - max_end = end; - max_delimiter = &delimiter; - } - } - } - - if (max_delimiter != nullptr) { - offsets.push_back(start); - offsets.push_back(max_end); - start = max_end; - } else { - start++; - } - } - - offsets.push_back(text.size()); - return offsets; -} - static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) { llama_tokenizer tokenizer(vocab); std::vector<llama_vocab::id> output; @@ -2137,27 +2100,40 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co return output; } - std::vector<size_t> offsets = llama_split_special_tokens(vocab, text); - size_t start = 0; - for (size_t end : offsets) { - if (start >= end) { - continue; + size_t delim_start = 0; + size_t last_delim_end = 0; + + while (delim_start < text.size()) { + size_t delim_end = 0; + llama_vocab::id token_id = -1; + + for (const auto & mit : vocab.special_token_to_id) { + const std::string & delimiter = mit.first; + size_t end = delim_start + delimiter.size(); + if (end <= text.size() && text.compare(delim_start, delimiter.size(), delimiter) == 0) { + if (token_id == -1 || end > delim_end) { + token_id = mit.second; + delim_end = end; + } + } } - const char *part = text.c_str() + start; - size_t part_len = end - start; - if (vocab.max_special_token_length < part_len) { - tokenizer.tokenize(part, part_len, output); - } else { - auto token_it = vocab.special_token_to_id.find(std::string(part, part_len)); - if (token_it != vocab.special_token_to_id.end()) { - output.push_back(token_it->second); - } else { - tokenizer.tokenize(part, part_len, output); + if (token_id != -1) { + if (last_delim_end < delim_start) { + tokenizer.tokenize(text.c_str() + last_delim_end, delim_start - last_delim_end, output); } + output.push_back(token_id); + delim_start = delim_end; + last_delim_end = delim_end; + } else { + delim_start++; } - start = end; } + + if (last_delim_end < text.size()) { + tokenizer.tokenize(text.c_str() + last_delim_end, text.size() - last_delim_end, output); + } + return output; }
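

Taken together, the last two patches trade the ported Hugging Face trie for a linear scan: a model registers at most a handful of special tokens, so comparing every special token at each byte position stays cheap and is far simpler to maintain, as long as the longest candidate always wins (otherwise "extra_id_100" would be tokenized as "extra_id_1" followed by "00"). The following self-contained C++ sketch mirrors that final single-pass strategy; it is illustrative only: `tokenize_with_specials` and `tokenize_base` are stand-ins invented for this example (the latter plays the role of llama.cpp's internal SentencePiece tokenizer), not llama.cpp symbols.

#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for llama.cpp's internal SentencePiece tokenizer: it only reports
// the plain-text span it was handed and emits a dummy id.
static void tokenize_base(const std::string & text, size_t pos, size_t len,
                          std::vector<int> & out) {
    printf("base-tokenize \"%s\"\n", text.substr(pos, len).c_str());
    out.push_back(0);
}

// Single pass over `text`: at each position, linearly scan all special tokens
// and take the longest match; everything between matches goes to the base
// tokenizer. Mirrors the logic of the final llama_tokenize() above.
static std::vector<int> tokenize_with_specials(
        const std::unordered_map<std::string, int> & special_token_to_id,
        const std::string & text) {
    std::vector<int> output;
    size_t delim_start = 0;    // current scan position
    size_t last_delim_end = 0; // end of the previous special-token match

    while (delim_start < text.size()) {
        size_t delim_end = 0;
        int token_id = -1;

        for (const auto & mit : special_token_to_id) {
            const std::string & delimiter = mit.first;
            size_t end = delim_start + delimiter.size();
            if (end <= text.size() && end > delim_end &&
                text.compare(delim_start, delimiter.size(), delimiter) == 0) {
                token_id = mit.second; // longest match so far wins
                delim_end = end;
            }
        }

        if (token_id != -1) {
            // Flush the plain text before the match, then emit the special id.
            if (last_delim_end < delim_start) {
                tokenize_base(text, last_delim_end, delim_start - last_delim_end, output);
            }
            output.push_back(token_id);
            delim_start = last_delim_end = delim_end;
        } else {
            delim_start++;
        }
    }
    if (last_delim_end < text.size()) {
        tokenize_base(text, last_delim_end, text.size() - last_delim_end, output);
    }
    return output;
}

int main() {
    const std::unordered_map<std::string, int> special_token_to_id = {
        {"extra_id_1", 32001}, {"extra_id_100", 32002}, {"<🦙>", 32003},
    };
    for (int id : tokenize_with_specials(special_token_to_id, "<🦙>test extra_id_100 test")) {
        printf("%d ", id);
    }
    printf("\n");
    return 0;
}

On "<🦙>test extra_id_100 test" this should emit 32003 and 32002 for the two special tokens and hand the spans "test " and " test" to the base tokenizer, which is the same shape encoded by the expectations in tests/test-tokenizer-0.cpp.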