diff --git a/.gitignore b/.gitignore index 50ae0973ae3b3..049efd703bf80 100644 --- a/.gitignore +++ b/.gitignore @@ -105,6 +105,7 @@ examples/jeopardy/results.txt examples/server/*.html.hpp examples/server/*.js.hpp examples/server/*.mjs.hpp +examples/server/*.css.hpp poetry.lock poetry.toml diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index ad071b97404f7..0a52d577f2aa8 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2,19 +2,28 @@ from __future__ import annotations -import logging import argparse import contextlib import json +import logging +import math import os import re import sys -from enum import IntEnum -from pathlib import Path from hashlib import sha256 -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ContextManager, + Iterable, + Iterator, + Sequence, + TypeVar, + cast, +) -import math import numpy as np import torch @@ -22,23 +31,13 @@ from torch import Tensor if 'NO_LOCAL_GGUF' not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) + sys.path.insert(1, str(Path('gguf-py'))) import gguf logger = logging.getLogger("hf-to-gguf") ###### MODEL DEFINITIONS ###### - -class SentencePieceTokenTypes(IntEnum): - NORMAL = 1 - UNKNOWN = 2 - CONTROL = 3 - USER_DEFINED = 4 - UNUSED = 5 - BYTE = 6 - - AnyModel = TypeVar("AnyModel", bound="type[Model]") @@ -406,94 +405,43 @@ def get_vocab_base_pre(self, tokenizer) -> str: # is specific for the BPE pre-tokenizer used by the model # we will use this unique identifier to write a "tokenizer.ggml.pre" entry in the GGUF file which we can # use in llama.cpp to implement the same pre-tokenizer - - chktxt = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶\u200d🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````""""......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL' - - chktok = tokenizer.encode(chktxt) - chkhsh = sha256(str(chktok).encode()).hexdigest() - - logger.debug(f"chktok: {chktok}") - logger.debug(f"chkhsh: {chkhsh}") - - res = None - - # NOTE: if you get an error here, you need to update the convert-hf-to-gguf-update.py script - # or pull the latest version of the model from Huggingface - # don't edit the hashes manually! 
- if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": - # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B - res = "llama-bpe" - if chkhsh == "049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754": - # ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base - res = "deepseek-llm" - if chkhsh == "347715f544604f9118bb75ed199f68779f423cabb20db6de6f31b908d04d7821": - # ref: https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base - res = "deepseek-coder" - if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed": - # ref: https://huggingface.co/tiiuae/falcon-7b - res = "falcon" - if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": - # ref: https://huggingface.co/BAAI/bge-small-en-v1.5 - res = "bert-bge" - if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166": - # ref: https://huggingface.co/mosaicml/mpt-7b - res = "mpt" - if chkhsh == "35d91631860c815f952d711435f48d356ebac988362536bed955d43bfa436e34": - # ref: https://huggingface.co/bigcode/starcoder2-3b - res = "starcoder" - if chkhsh == "3ce83efda5659b07b1ad37ca97ca5797ea4285d9b9ab0dc679e4a720c9da7454": - # ref: https://huggingface.co/openai-community/gpt2 - res = "gpt-2" - if chkhsh == "32d85c31273f8019248f2559fed492d929ea28b17e51d81d3bb36fff23ca72b3": - # ref: https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b - res = "stablelm2" - if chkhsh == "6221ad2852e85ce96f791f476e0b390cf9b474c9e3d1362f53a24a06dc8220ff": - # ref: https://huggingface.co/smallcloudai/Refact-1_6-base - res = "refact" - if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8": - # ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01 - res = "command-r" - if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea": - # ref: https://huggingface.co/Qwen/Qwen1.5-7B - res = "qwen2" - if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166": - # ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf - res = "olmo" - if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e": - # ref: https://huggingface.co/databricks/dbrx-base - res = "dbrx" - if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": - # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en - res = "jina-v2-en" - if chkhsh == "171aeeedd6fb548d418a7461d053f11b6f1f1fc9b387bd66640d28a4b9f5c643": - # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-es - res = "jina-v2-es" - if chkhsh == "27949a2493fc4a9f53f5b9b029c82689cfbe5d3a1929bb25e043089e28466de6": - # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-de - res = "jina-v2-de" - if chkhsh == "c136ed14d01c2745d4f60a9596ae66800e2b61fa45643e72436041855ad4089d": - # ref: https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct - res = "smaug-bpe" - - if res is None: - logger.warning("\n") - logger.warning("**************************************************************************************") - logger.warning("** WARNING: The BPE pre-tokenizer was not recognized!") - logger.warning("** There are 2 possible reasons for this:") - logger.warning("** - the model has not been added to convert-hf-to-gguf-update.py yet") - logger.warning("** - the pre-tokenization config has changed upstream") - logger.warning("** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.") - logger.warning("** ref: https://github.com/ggerganov/llama.cpp/pull/6920") - 
logger.warning("**") - logger.warning(f"** chkhsh: {chkhsh}") - logger.warning("**************************************************************************************") - logger.warning("\n") - raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()") - - logger.debug(f"tokenizer.ggml.pre: {repr(res)}") - logger.debug(f"chkhsh: {chkhsh}") - - return res + checksum = sha256(str(tokenizer.vocab).encode()).hexdigest() + logger.debug(f"checksum: {checksum}") + + # NOTE: IF you get an error here: + # Update the huggingface_hub.py module and add the vocab, model, and repo. + # Run the `gguf-py/scripts/gguf-gen-pre.py` script to generate the checksums. + # This script should ideally pull in the latest version of the model from HuggingFace. + # DO NOT MANUALLY EDIT THIS METHOD! + models = json.load(f"{tokenizer.name_or_path}/checksums.json") + for model in models: + if checksum == model["checksum"]: + pre = None + if model["tokt"] == gguf.TokenizerType.BPE: + pre = "bpe" + elif model["tokt"] == gguf.TokenizerType.SPM: + pre = "spm" + elif model["tokt"] == gguf.TokenizerType.WPM: + pre = "wpm" + else: + raise KeyError() + logger.debug(f"tokenizer checksum: {checksum}") + logger.debug(f"tokenizer.ggml.pre: {pre}") + return pre # NOTE: Use the enum to id the vocab + + logger.warning("\n") + logger.warning("**************************************************************************************") + logger.warning("** WARNING: The BPE pre-tokenizer was not recognized!") + logger.warning("** There are 2 possible reasons for this:") + logger.warning("** - the model has not been added to convert-hf-to-gguf-update.py yet") + logger.warning("** - the pre-tokenization config has changed upstream") + logger.warning("** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.") + logger.warning("** ref: https://github.com/ggerganov/llama.cpp/pull/6920") + logger.warning("**") + logger.warning(f"** tokenizer checksum: {checksum}") + logger.warning("**************************************************************************************") + logger.warning("\n") + raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()") # Marker: End get_vocab_base_pre def _set_vocab_gpt2(self) -> None: @@ -579,22 +527,22 @@ def _set_vocab_sentencepiece(self): tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size - toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + toktypes: list[int] = [gguf.TokenType.UNKNOWN] * vocab_size for token_id in range(tokenizer.vocab_size()): piece = tokenizer.IdToPiece(token_id) text = piece.encode("utf-8") score = tokenizer.GetScore(token_id) - toktype = SentencePieceTokenTypes.NORMAL + toktype = gguf.TokenType.NORMAL if tokenizer.IsUnknown(token_id): - toktype = SentencePieceTokenTypes.UNKNOWN + toktype = gguf.TokenType.UNKNOWN elif tokenizer.IsControl(token_id): - toktype = SentencePieceTokenTypes.CONTROL + toktype = gguf.TokenType.CONTROL elif tokenizer.IsUnused(token_id): - toktype = SentencePieceTokenTypes.UNUSED + toktype = gguf.TokenType.UNUSED elif tokenizer.IsByte(token_id): - toktype = SentencePieceTokenTypes.BYTE + toktype = gguf.TokenType.BYTE tokens[token_id] = text scores[token_id] = score @@ -612,7 +560,7 @@ def _set_vocab_sentencepiece(self): tokens[token_id] = key.encode("utf-8") scores[token_id] = -1000.0 - toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + toktypes[token_id] = 
gguf.TokenType.USER_DEFINED if vocab_size > len(tokens): pad_count = vocab_size - len(tokens) @@ -620,7 +568,7 @@ def _set_vocab_sentencepiece(self): for i in range(1, pad_count + 1): tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) scores.append(-1000.0) - toktypes.append(SentencePieceTokenTypes.UNUSED) + toktypes.append(gguf.TokenType.UNUSED) self.gguf_writer.add_tokenizer_model("llama") self.gguf_writer.add_tokenizer_pre("default") @@ -1753,7 +1701,7 @@ def set_vocab(self): tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size - toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + toktypes: list[int] = [gguf.TokenType.UNKNOWN] * vocab_size for token_id in range(tokenizer.vocab_size()): @@ -1761,15 +1709,15 @@ def set_vocab(self): text = piece.encode("utf-8") score = tokenizer.GetScore(token_id) - toktype = SentencePieceTokenTypes.NORMAL + toktype = gguf.TokenType.NORMAL if tokenizer.IsUnknown(token_id): - toktype = SentencePieceTokenTypes.UNKNOWN + toktype = gguf.TokenType.UNKNOWN elif tokenizer.IsControl(token_id): - toktype = SentencePieceTokenTypes.CONTROL + toktype = gguf.TokenType.CONTROL elif tokenizer.IsUnused(token_id): - toktype = SentencePieceTokenTypes.UNUSED + toktype = gguf.TokenType.UNUSED elif tokenizer.IsByte(token_id): - toktype = SentencePieceTokenTypes.BYTE + toktype = gguf.TokenType.BYTE tokens[token_id] = text scores[token_id] = score @@ -1788,7 +1736,7 @@ def set_vocab(self): tokens[token_id] = key.encode("utf-8") scores[token_id] = -1000.0 - toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + toktypes[token_id] = gguf.TokenType.USER_DEFINED tokenizer_config_file = self.dir_model / 'tokenizer_config.json' if tokenizer_config_file.is_file(): @@ -1798,13 +1746,13 @@ def set_vocab(self): for token_id, foken_data in added_tokens_decoder.items(): token_id = int(token_id) token = foken_data["content"].encode("utf-8") - if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + if toktypes[token_id] != gguf.TokenType.UNKNOWN: assert tokens[token_id] == token tokens[token_id] = token scores[token_id] = -1000.0 - toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + toktypes[token_id] = gguf.TokenType.USER_DEFINED if foken_data.get("special"): - toktypes[token_id] = SentencePieceTokenTypes.CONTROL + toktypes[token_id] = gguf.TokenType.CONTROL tokenizer_file = self.dir_model / 'tokenizer.json' if tokenizer_file.is_file(): @@ -1814,13 +1762,13 @@ def set_vocab(self): for foken_data in added_tokens: token_id = int(foken_data["id"]) token = foken_data["content"].encode("utf-8") - if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN: + if toktypes[token_id] != gguf.TokenType.UNKNOWN: assert tokens[token_id] == token tokens[token_id] = token scores[token_id] = -1000.0 - toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + toktypes[token_id] = gguf.TokenType.USER_DEFINED if foken_data.get("special"): - toktypes[token_id] = SentencePieceTokenTypes.CONTROL + toktypes[token_id] = gguf.TokenType.CONTROL self.gguf_writer.add_tokenizer_model("llama") self.gguf_writer.add_tokenizer_pre("default") @@ -2015,15 +1963,15 @@ def set_vocab(self): logger.warning(f"InternLM2 convert token '{text}' to '🐉'!") text = "🐉".encode("utf-8") - toktype = SentencePieceTokenTypes.NORMAL + toktype = gguf.TokenType.NORMAL if tokenizer.IsUnknown(token_id): - toktype = SentencePieceTokenTypes.UNKNOWN + toktype = gguf.TokenType.UNKNOWN elif tokenizer.IsControl(token_id): - toktype = 
SentencePieceTokenTypes.CONTROL + toktype = gguf.TokenType.CONTROL elif tokenizer.IsUnused(token_id): - toktype = SentencePieceTokenTypes.UNUSED + toktype = gguf.TokenType.UNUSED elif tokenizer.IsByte(token_id): - toktype = SentencePieceTokenTypes.BYTE + toktype = gguf.TokenType.BYTE tokens.append(text) scores.append(score) @@ -2037,7 +1985,7 @@ def set_vocab(self): for key in added_tokens_json: tokens.append(key.encode("utf-8")) scores.append(-1000.0) - toktypes.append(SentencePieceTokenTypes.USER_DEFINED) + toktypes.append(gguf.TokenType.USER_DEFINED) self.gguf_writer.add_tokenizer_model("llama") self.gguf_writer.add_tokenizer_pre("default") @@ -2502,7 +2450,7 @@ def set_vocab(self): tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size - toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + toktypes: list[int] = [gguf.TokenType.UNKNOWN] * vocab_size for token_id in range(tokenizer.vocab_size()): @@ -2510,15 +2458,15 @@ def set_vocab(self): text = piece.encode("utf-8") score = tokenizer.GetScore(token_id) - toktype = SentencePieceTokenTypes.NORMAL + toktype = gguf.TokenType.NORMAL if tokenizer.IsUnknown(token_id): - toktype = SentencePieceTokenTypes.UNKNOWN + toktype = gguf.TokenType.UNKNOWN elif tokenizer.IsControl(token_id): - toktype = SentencePieceTokenTypes.CONTROL + toktype = gguf.TokenType.CONTROL elif tokenizer.IsUnused(token_id): - toktype = SentencePieceTokenTypes.UNUSED + toktype = gguf.TokenType.UNUSED elif tokenizer.IsByte(token_id): - toktype = SentencePieceTokenTypes.BYTE + toktype = gguf.TokenType.BYTE tokens[token_id] = text scores[token_id] = score @@ -2540,16 +2488,16 @@ def set_vocab(self): continue token_content = token_json["content"] - token_type = SentencePieceTokenTypes.USER_DEFINED + token_type = gguf.TokenType.USER_DEFINED token_score = -10000.0 # Map unk_token to UNKNOWN, other special tokens to CONTROL # Set the score to 0.0 as in the original tokenizer.model if ("special" in token_json) and token_json["special"]: if token_content == tokenizer_config_json["unk_token"]: - token_type = SentencePieceTokenTypes.UNKNOWN + token_type = gguf.TokenType.UNKNOWN else: - token_type = SentencePieceTokenTypes.CONTROL + token_type = gguf.TokenType.CONTROL token_score = 0.0 logger.info(f"Setting added token {token_id} to '{token_content}' (type: {token_type}, score: {token_score:.2f})") diff --git a/generate-vocab.sh b/generate-vocab.sh new file mode 100755 index 0000000000000..22186c186e841 --- /dev/null +++ b/generate-vocab.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +python3 convert-hf-to-gguf.py models/meta-llama/Llama-2-7b-hf --outfile models/meta-llama/Llama-2-7b-hf/ggml-vocab-llama-2-7b-hf.gguf --vocab-only +python3 convert-hf-to-gguf.py models/meta-llama/Meta-Llama-3-8B --outfile models/meta-llama/Meta-Llama-3-8B/ggml-vocab-meta-llama-3-8b.gguf --vocab-only +python3 convert-hf-to-gguf.py models/microsoft/Phi-3-mini-4k-instruct --outfile models/microsoft/Phi-3-mini-4k-instruct/ggml-vocab-phi-3-mini-4k-instruct.gguf --vocab-only +python3 convert-hf-to-gguf.py models/deepseek-ai/deepseek-llm-7b-base --outfile models/deepseek-ai/deepseek-llm-7b-base/ggml-vocab-deepseek-llm-7b-base.gguf --vocab-only +python3 convert-hf-to-gguf.py models/deepseek-ai/deepseek-coder-6.7b-base --outfile models/deepseek-ai/deepseek-coder-6.7b-base/ggml-vocab-deepseek-coder-6.gguf --vocab-only +python3 convert-hf-to-gguf.py models/tiiuae/falcon-7b --outfile models/tiiuae/falcon-7b/ggml-vocab-falcon-7b.gguf 
--vocab-only +python3 convert-hf-to-gguf.py models/BAAI/bge-small-en-v1.5 --outfile models/BAAI/bge-small-en-v1.5/ggml-vocab-bge-small-en-v1.gguf --vocab-only +python3 convert-hf-to-gguf.py models/mosaicml/mpt-7b --outfile models/mosaicml/mpt-7b/ggml-vocab-mpt-7b.gguf --vocab-only +python3 convert-hf-to-gguf.py models/bigcode/starcoder2-3b --outfile models/bigcode/starcoder2-3b/ggml-vocab-starcoder2-3b.gguf --vocab-only +python3 convert-hf-to-gguf.py models/openai-community/gpt2 --outfile models/openai-community/gpt2/ggml-vocab-gpt2.gguf --vocab-only +python3 convert-hf-to-gguf.py models/smallcloudai/Refact-1_6-base --outfile models/smallcloudai/Refact-1_6-base/ggml-vocab-refact-1_6-base.gguf --vocab-only +python3 convert-hf-to-gguf.py models/CohereForAI/c4ai-command-r-v01 --outfile models/CohereForAI/c4ai-command-r-v01/ggml-vocab-c4ai-command-r-v01.gguf --vocab-only +python3 convert-hf-to-gguf.py models/Qwen/Qwen1.5-7B --outfile models/Qwen/Qwen1.5-7B/ggml-vocab-qwen1.gguf --vocab-only +python3 convert-hf-to-gguf.py models/allenai/OLMo-1.7-7B-hf --outfile models/allenai/OLMo-1.7-7B-hf/ggml-vocab-olmo-1.gguf --vocab-only +# python3 convert-hf-to-gguf.py models/databricks/dbrx-base --outfile models/databricks/dbrx-base/ggml-vocab-dbrx-base.gguf --vocab-only +python3 convert-hf-to-gguf.py models/jinaai/jina-embeddings-v2-base-en --outfile models/jinaai/jina-embeddings-v2-base-en/ggml-vocab-jina-embeddings-v2-base-en.gguf --vocab-only +python3 convert-hf-to-gguf.py models/jinaai/jina-embeddings-v2-base-es --outfile models/jinaai/jina-embeddings-v2-base-es/ggml-vocab-jina-embeddings-v2-base-es.gguf --vocab-only +python3 convert-hf-to-gguf.py models/jinaai/jina-embeddings-v2-base-de --outfile models/jinaai/jina-embeddings-v2-base-de/ggml-vocab-jina-embeddings-v2-base-de.gguf --vocab-only +python3 convert-hf-to-gguf.py models/microsoft/phi-1 --outfile models/microsoft/phi-1/ggml-vocab-phi-1.gguf --vocab-only +python3 convert-hf-to-gguf.py models/stabilityai/stablelm-2-zephyr-1_6b --outfile models/stabilityai/stablelm-2-zephyr-1_6b/ggml-vocab-stablelm-2-zephyr-1_6b.gguf --vocab-only +python3 convert-hf-to-gguf.py models/mistralai/Mistral-7B-Instruct-v0.2 --outfile models/mistralai/Mistral-7B-Instruct-v0.2/ggml-vocab-mistral-7b-instruct-v0.gguf --vocab-only +python3 convert-hf-to-gguf.py models/mistralai/Mixtral-8x7B-Instruct-v0.1 --outfile models/mistralai/Mixtral-8x7B-Instruct-v0.1/ggml-vocab-mixtral-8x7b-instruct-v0.gguf --vocab-only diff --git a/gguf-py/gguf/__init__.py b/gguf-py/gguf/__init__.py index ea5146b161bc8..4b7534f3b0ceb 100644 --- a/gguf-py/gguf/__init__.py +++ b/gguf-py/gguf/__init__.py @@ -1,7 +1,8 @@ from .constants import * -from .lazy import * from .gguf_reader import * from .gguf_writer import * +from .huggingface_hub import * +from .lazy import * from .quants import * from .tensor_mapping import * from .vocab import * diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 55ec2cb5c848a..29358eb121021 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -6,18 +6,16 @@ # # constants # - GGUF_MAGIC = 0x46554747 # "GGUF" GGUF_VERSION = 3 GGUF_DEFAULT_ALIGNMENT = 32 GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h + # -# metadata keys +# model metadata keys # - - -class Keys: +class GGUFMetadataKeys: class General: ARCHITECTURE = "general.architecture" QUANTIZATION_VERSION = "general.quantization_version" @@ -29,8 +27,9 @@ class General: DESCRIPTION = "general.description" LICENSE = "general.license" SOURCE_URL = "general.source.url" - 
SOURCE_HF_REPO = "general.source.huggingface.repository" + SOURCE_REPO = "general.source.repository" FILE_TYPE = "general.file_type" + ENDIANESS = "general.endianess" class LLM: VOCAB_SIZE = "{arch}.vocab_size" @@ -79,33 +78,35 @@ class SSM: TIME_STEP_RANK = "{arch}.ssm.time_step_rank" class Tokenizer: - MODEL = "tokenizer.ggml.model" - PRE = "tokenizer.ggml.pre" - LIST = "tokenizer.ggml.tokens" - TOKEN_TYPE = "tokenizer.ggml.token_type" - TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types - SCORES = "tokenizer.ggml.scores" - MERGES = "tokenizer.ggml.merges" - BOS_ID = "tokenizer.ggml.bos_token_id" - EOS_ID = "tokenizer.ggml.eos_token_id" - UNK_ID = "tokenizer.ggml.unknown_token_id" - SEP_ID = "tokenizer.ggml.seperator_token_id" - PAD_ID = "tokenizer.ggml.padding_token_id" - CLS_ID = "tokenizer.ggml.cls_token_id" - MASK_ID = "tokenizer.ggml.mask_token_id" - ADD_BOS = "tokenizer.ggml.add_bos_token" - ADD_EOS = "tokenizer.ggml.add_eos_token" - ADD_PREFIX = "tokenizer.ggml.add_space_prefix" - HF_JSON = "tokenizer.huggingface.json" + MODEL = "tokenizer.model" # STRING: e.g. llama, gpt2, etc... + TYPE = "tokenizer.type" # STRING: BPE, SPM, WPM, etc. + NORM = "tokenizer.norm" # OBJECT {"type": "ByteLevel", ...} + PRE = "tokenizer.pre" # OBJECT {"type": "ByteLevel", ...} + ADDED = "tokenizer.added" # ARRAY of OBJECTs: [{"id": 1, ...}, ...] + VOCAB = "tokenizer.vocab" # ARRAY of STRINGs: ["[BOS]", ...] + MERGES = "tokenizer.merges" # ARRAY of STRINGs: ["▁ t", ...] + TOKEN_TYPE = "tokenizer.token_type" # ARRAY of INT [2, ...] + TOKEN_TYPE_COUNT = "tokenizer.token_type_count" # BERT token types + SCORES = "tokenizer.scores" # WPM only + BOS_ID = "tokenizer.bos_token_id" + EOS_ID = "tokenizer.eos_token_id" + UNK_ID = "tokenizer.unknown_token_id" + SEP_ID = "tokenizer.separator_token_id" # Fixed typo + PAD_ID = "tokenizer.padding_token_id" + CLS_ID = "tokenizer.cls_token_id" + MASK_ID = "tokenizer.mask_token_id" + ADD_BOS = "tokenizer.add_bos_token" + ADD_EOS = "tokenizer.add_eos_token" + ADD_PREFIX = "tokenizer.add_space_prefix" RWKV = "tokenizer.rwkv.world" CHAT_TEMPLATE = "tokenizer.chat_template" CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}" CHAT_TEMPLATES = "tokenizer.chat_templates" # FIM/Infill special tokens constants - PREFIX_ID = "tokenizer.ggml.prefix_token_id" - SUFFIX_ID = "tokenizer.ggml.suffix_token_id" - MIDDLE_ID = "tokenizer.ggml.middle_token_id" - EOT_ID = "tokenizer.ggml.eot_token_id" + PREFIX_ID = "tokenizer.prefix_token_id" + SUFFIX_ID = "tokenizer.suffix_token_id" + MIDDLE_ID = "tokenizer.middle_token_id" + EOT_ID = "tokenizer.eot_token_id" # @@ -844,27 +845,17 @@ class MODEL_TENSOR(IntEnum): ], } + # # types # - - -class TokenType(IntEnum): - NORMAL = 1 - UNKNOWN = 2 - CONTROL = 3 - USER_DEFINED = 4 - UNUSED = 5 - BYTE = 6 - - -class RopeScalingType(Enum): +class GGMLRopeScalingType(Enum): NONE = 'none' LINEAR = 'linear' YARN = 'yarn' -class PoolingType(IntEnum): +class GGMLPoolingType(IntEnum): NONE = 0 MEAN = 1 CLS = 2 @@ -907,7 +898,7 @@ class GGMLQuantizationType(IntEnum): # from llama_ftype in llama.h # ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE. 
-class LlamaFileType(IntEnum):
+class GGUFFileType(IntEnum):
     ALL_F32     = 0
     MOSTLY_F16  = 1  # except 1d tensors
     MOSTLY_Q4_0 = 2  # except 1d tensors
@@ -945,39 +936,70 @@ class LlamaFileType(IntEnum):
     GUESSED = 1024  # not specified in the model file


+GGUF_FILE_TYPE_MAP: dict[str, GGUFFileType] = {
+    "F32"  : GGUFFileType.ALL_F32,
+    "F16"  : GGUFFileType.MOSTLY_F16,
+    "BF16" : GGUFFileType.MOSTLY_BF16,
+    "Q8_0" : GGUFFileType.MOSTLY_Q8_0,
+}
+
+
+GGUF_FILE_TYPE_NAMES: dict[GGUFFileType, str] = {
+    GGUFFileType.ALL_F32     : "F32",
+    GGUFFileType.MOSTLY_F16  : "F16",
+    GGUFFileType.MOSTLY_BF16 : "BF16",
+    GGUFFileType.MOSTLY_Q8_0 : "Q8_0",
+}
+
+
 class GGUFEndian(IntEnum):
     LITTLE = 0
-    BIG = 1
+    BIG    = 1


 class GGUFValueType(IntEnum):
+    # NOTE: These values are part of the on-disk GGUF format; do not renumber them.
     UINT8   = 0
     INT8    = 1
     UINT16  = 2
     INT16   = 3
     UINT32  = 4
     INT32   = 5
     FLOAT32 = 6
     BOOL    = 7
     STRING  = 8
     ARRAY   = 9
     UINT64  = 10
     INT64   = 11
     FLOAT64 = 12
+    OBJECT  = 13  # extension beyond the GGUF v3 value types (nested key/value metadata)

     @staticmethod
     def get_type(val: Any) -> GGUFValueType:
         if isinstance(val, (str, bytes, bytearray)):
             return GGUFValueType.STRING
-        elif isinstance(val, list):
-            return GGUFValueType.ARRAY
-        elif isinstance(val, float):
-            return GGUFValueType.FLOAT32
-        elif isinstance(val, bool):
-            return GGUFValueType.BOOL
-        elif isinstance(val, int):
-            return GGUFValueType.INT32
-        # TODO: need help with 64-bit types in Python
+        elif isinstance(val, bool):
+            # NOTE: bool must be checked before int (bool is a subclass of int)
+            return GGUFValueType.BOOL
+        elif isinstance(val, int):
+            # NOTE: Python ints are arbitrary precision, so the bit width cannot be
+            # inferred from the value alone; 64-bit types must be requested
+            # explicitly (e.g. via numpy dtypes).
+            return GGUFValueType.INT32
+        elif isinstance(val, float):
+            # NOTE: Python floats are C doubles; FLOAT32 is assumed by default
+            return GGUFValueType.FLOAT32
+        elif isinstance(val, list):
+            return GGUFValueType.ARRAY
+        elif isinstance(val, dict):
+            # NOTE: JSON Object, Dict, or Mapping are valid types
+            return GGUFValueType.OBJECT
         else:
             raise ValueError(f"Unknown type: {type(val)}")
@@ -1017,69 +1039,155 @@ def get_type(val: Any) -> GGUFValueType:
 }


+#
+# Model File Types
+#
+class ModelFileExtension(Enum):
+    PT          = ".pt"           # torch
+    PTH         = ".pth"          # torch
+    BIN         = ".bin"          # torch
+    SAFETENSORS = ".safetensors"  # safetensors
+    JSON        = ".json"         # transformers/tokenizers
+    MODEL       = ".model"        # sentencepiece
+    GGUF        = ".gguf"         # ggml/llama.cpp
+
+
+#
+# Tokenizer Types
+#
+class GGUFTokenType(IntEnum):
+    NORMAL       = 1
+    UNKNOWN      = 2
+    CONTROL      = 3
+    USER_DEFINED = 4
+    UNUSED       = 5
+    BYTE         = 6
+
+
+class HFTokenizerType(Enum):
+    SPM = "SPM"  # SentencePiece LLaMA tokenizer
+    BPE = "BPE"  # BytePair GPT-2 tokenizer
+    WPM = "WPM"  # WordPiece BERT tokenizer
+
+
+#
+# Normalizer Types
+#
+class HFNormalizerType(Enum):
+    SEQUENCE = "Sequence"
+    NFC      = "NFC"
+    NFD      = "NFD"
+    NFKC     = "NFKC"
+    NFKD     = "NFKD"
+
+
+#
+# Pre-tokenizer Types
+#
+class HFPreTokenizerType(Enum):
+    WHITESPACE         = "Whitespace"
+    METASPACE          = "Metaspace"
+    BYTE_LEVEL         = "ByteLevel"
+    BERT_PRE_TOKENIZER = "BertPreTokenizer"
+    SEQUENCE           = "Sequence"
+
+
+#
+# HF Vocab Files
+#
+HF_TOKENIZER_BPE_FILES: tuple[str, ...] = (
+    "config.json",
+    "tokenizer_config.json",
+    "tokenizer.json",
+)
+
+HF_TOKENIZER_SPM_FILES: tuple[str, ...] = HF_TOKENIZER_BPE_FILES + ("tokenizer.model",)
+
+#
+# Pre-tokenization Regular Expressions
+#
+
+# NOTE: `tokenizers` defaults to the OpenAI GPT-2 `ByteLevel` RegEx.
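+# As an illustrative sketch (assuming the third-party `regex` package, since the
+# stdlib `re` module lacks the \p{L}/\p{N} classes), the fallback pattern below
+# splits text like so:
+#     import regex
+#     regex.findall(GPT_PRE_TOKENIZER_DEFAULT[0], "I've 33 cats")
+#     # -> ['I', "'ve", ' 33', ' cats']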
+# The pattern is a Perl-style regular expression; its formatting is arbitrary.
+# https://github.com/openai/gpt-2/blob/master/src/encoder.py#L53
+# https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L40-L42

+# These are fallback values if the pre-tokenizer cannot be dynamically discovered at runtime.
+GPT_PRE_TOKENIZER_DEFAULT = ("'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+",)
+

 # Aliases for backward compatibility.

+# classes (renamed in this change; gguf_writer.py, vocab.py, and the convert
+# scripts still reference the old names)
+Keys             = GGUFMetadataKeys
+TokenType        = GGUFTokenType
+RopeScalingType  = GGMLRopeScalingType
+PoolingType      = GGMLPoolingType
+LlamaFileType    = GGUFFileType
+VocabType        = HFTokenizerType
+PreTokenizerType = HFPreTokenizerType
+
 # general
-KEY_GENERAL_ARCHITECTURE         = Keys.General.ARCHITECTURE
-KEY_GENERAL_QUANTIZATION_VERSION = Keys.General.QUANTIZATION_VERSION
-KEY_GENERAL_ALIGNMENT            = Keys.General.ALIGNMENT
-KEY_GENERAL_NAME                 = Keys.General.NAME
-KEY_GENERAL_AUTHOR               = Keys.General.AUTHOR
-KEY_GENERAL_URL                  = Keys.General.URL
-KEY_GENERAL_DESCRIPTION          = Keys.General.DESCRIPTION
-KEY_GENERAL_LICENSE              = Keys.General.LICENSE
-KEY_GENERAL_SOURCE_URL           = Keys.General.SOURCE_URL
-KEY_GENERAL_SOURCE_HF_REPO       = Keys.General.SOURCE_HF_REPO
-KEY_GENERAL_FILE_TYPE            = Keys.General.FILE_TYPE
+KEY_GENERAL_ARCHITECTURE         = GGUFMetadataKeys.General.ARCHITECTURE
+KEY_GENERAL_QUANTIZATION_VERSION = GGUFMetadataKeys.General.QUANTIZATION_VERSION
+KEY_GENERAL_ALIGNMENT            = GGUFMetadataKeys.General.ALIGNMENT
+KEY_GENERAL_NAME                 = GGUFMetadataKeys.General.NAME
+KEY_GENERAL_AUTHOR               = GGUFMetadataKeys.General.AUTHOR
+KEY_GENERAL_URL                  = GGUFMetadataKeys.General.URL
+KEY_GENERAL_DESCRIPTION          = GGUFMetadataKeys.General.DESCRIPTION
+KEY_GENERAL_LICENSE              = GGUFMetadataKeys.General.LICENSE
+KEY_GENERAL_SOURCE_URL           = GGUFMetadataKeys.General.SOURCE_URL
+KEY_GENERAL_SOURCE_REPO          = GGUFMetadataKeys.General.SOURCE_REPO
+KEY_GENERAL_FILE_TYPE            = GGUFMetadataKeys.General.FILE_TYPE
+KEY_GENERAL_ENDIANESS            = GGUFMetadataKeys.General.ENDIANESS

 # LLM
-KEY_VOCAB_SIZE            = Keys.LLM.VOCAB_SIZE
-KEY_CONTEXT_LENGTH        = Keys.LLM.CONTEXT_LENGTH
-KEY_EMBEDDING_LENGTH      = Keys.LLM.EMBEDDING_LENGTH
-KEY_BLOCK_COUNT           = Keys.LLM.BLOCK_COUNT
-KEY_FEED_FORWARD_LENGTH   = Keys.LLM.FEED_FORWARD_LENGTH
-KEY_USE_PARALLEL_RESIDUAL = Keys.LLM.USE_PARALLEL_RESIDUAL
-KEY_TENSOR_DATA_LAYOUT    = Keys.LLM.TENSOR_DATA_LAYOUT
+KEY_VOCAB_SIZE            = GGUFMetadataKeys.LLM.VOCAB_SIZE
+KEY_CONTEXT_LENGTH        = GGUFMetadataKeys.LLM.CONTEXT_LENGTH
+KEY_EMBEDDING_LENGTH      = GGUFMetadataKeys.LLM.EMBEDDING_LENGTH
+KEY_BLOCK_COUNT           = GGUFMetadataKeys.LLM.BLOCK_COUNT
+KEY_FEED_FORWARD_LENGTH   = GGUFMetadataKeys.LLM.FEED_FORWARD_LENGTH
+KEY_USE_PARALLEL_RESIDUAL = GGUFMetadataKeys.LLM.USE_PARALLEL_RESIDUAL
+KEY_TENSOR_DATA_LAYOUT    = GGUFMetadataKeys.LLM.TENSOR_DATA_LAYOUT

 # attention
-KEY_ATTENTION_HEAD_COUNT        = Keys.Attention.HEAD_COUNT
-KEY_ATTENTION_HEAD_COUNT_KV     = Keys.Attention.HEAD_COUNT_KV
-KEY_ATTENTION_MAX_ALIBI_BIAS    = Keys.Attention.MAX_ALIBI_BIAS
-KEY_ATTENTION_CLAMP_KQV         = Keys.Attention.CLAMP_KQV
-KEY_ATTENTION_LAYERNORM_EPS     = Keys.Attention.LAYERNORM_EPS
-KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS
+KEY_ATTENTION_HEAD_COUNT        = GGUFMetadataKeys.Attention.HEAD_COUNT
+KEY_ATTENTION_HEAD_COUNT_KV     = GGUFMetadataKeys.Attention.HEAD_COUNT_KV
+KEY_ATTENTION_MAX_ALIBI_BIAS    = GGUFMetadataKeys.Attention.MAX_ALIBI_BIAS
+KEY_ATTENTION_CLAMP_KQV         = GGUFMetadataKeys.Attention.CLAMP_KQV
+KEY_ATTENTION_LAYERNORM_EPS     = GGUFMetadataKeys.Attention.LAYERNORM_EPS
+KEY_ATTENTION_LAYERNORM_RMS_EPS = GGUFMetadataKeys.Attention.LAYERNORM_RMS_EPS

 # RoPE
-KEY_ROPE_DIMENSION_COUNT      = Keys.Rope.DIMENSION_COUNT
-KEY_ROPE_FREQ_BASE            = Keys.Rope.FREQ_BASE
-KEY_ROPE_SCALING_TYPE         = Keys.Rope.SCALING_TYPE
-KEY_ROPE_SCALING_FACTOR       = Keys.Rope.SCALING_FACTOR
-KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN -KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED +KEY_ROPE_DIMENSION_COUNT = GGUFMetadataKeys.Rope.DIMENSION_COUNT +KEY_ROPE_FREQ_BASE = GGUFMetadataKeys.Rope.FREQ_BASE +KEY_ROPE_SCALING_TYPE = GGUFMetadataKeys.Rope.SCALING_TYPE +KEY_ROPE_SCALING_FACTOR = GGUFMetadataKeys.Rope.SCALING_FACTOR +KEY_ROPE_SCALING_ORIG_CTX_LEN = GGUFMetadataKeys.Rope.SCALING_ORIG_CTX_LEN +KEY_ROPE_SCALING_FINETUNED = GGUFMetadataKeys.Rope.SCALING_FINETUNED # SSM -KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL -KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE -KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE -KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK +KEY_SSM_CONV_KERNEL = GGUFMetadataKeys.SSM.CONV_KERNEL +KEY_SSM_INNER_SIZE = GGUFMetadataKeys.SSM.INNER_SIZE +KEY_SSM_STATE_SIZE = GGUFMetadataKeys.SSM.STATE_SIZE +KEY_SSM_TIME_STEP_RANK = GGUFMetadataKeys.SSM.TIME_STEP_RANK # tokenization -KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL -KEY_TOKENIZER_PRE = Keys.Tokenizer.PRE -KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST -KEY_TOKENIZER_TOKEN_TYPE = Keys.Tokenizer.TOKEN_TYPE -KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES -KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES -KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID -KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID -KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID -KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID -KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID -KEY_TOKENIZER_CLS_ID = Keys.Tokenizer.CLS_ID -KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID -KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON -KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV -KEY_TOKENIZER_PRIFIX_ID = Keys.Tokenizer.PREFIX_ID -KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID -KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID -KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID +KEY_TOKENIZER_MODEL = GGUFMetadataKeys.Tokenizer.MODEL +KEY_TOKENIZER_TYPE = GGUFMetadataKeys.Tokenizer.TYPE +KEY_TOKENIZER_NORM = GGUFMetadataKeys.Tokenizer.NORM +KEY_TOKENIZER_PRE = GGUFMetadataKeys.Tokenizer.PRE +KEY_TOKENIZER_ADDED = GGUFMetadataKeys.Tokenizer.ADDED +KEY_TOKENIZER_VOCAB = GGUFMetadataKeys.Tokenizer.VOCAB +KEY_TOKENIZER_MERGES = GGUFMetadataKeys.Tokenizer.MERGES +KEY_TOKENIZER_TOKEN_TYPE = GGUFMetadataKeys.Tokenizer.TOKEN_TYPE +KEY_TOKENIZER_TOKEN_TYPE_COUNT = GGUFMetadataKeys.Tokenizer.TOKEN_TYPE_COUNT +KEY_TOKENIZER_SCORES = GGUFMetadataKeys.Tokenizer.SCORES +KEY_TOKENIZER_BOS_ID = GGUFMetadataKeys.Tokenizer.BOS_ID +KEY_TOKENIZER_EOS_ID = GGUFMetadataKeys.Tokenizer.EOS_ID +KEY_TOKENIZER_UNK_ID = GGUFMetadataKeys.Tokenizer.UNK_ID +KEY_TOKENIZER_SEP_ID = GGUFMetadataKeys.Tokenizer.SEP_ID +KEY_TOKENIZER_PAD_ID = GGUFMetadataKeys.Tokenizer.PAD_ID +KEY_TOKENIZER_CLS_ID = GGUFMetadataKeys.Tokenizer.CLS_ID +KEY_TOKENIZER_MASK_ID = GGUFMetadataKeys.Tokenizer.MASK_ID +KEY_TOKENIZER_ADD_BOS = GGUFMetadataKeys.Tokenizer.ADD_BOS +KEY_TOKENIZER_ADD_EOS = GGUFMetadataKeys.Tokenizer.ADD_EOS +KEY_TOKENIZER_ADD_PREFIX = GGUFMetadataKeys.Tokenizer.ADD_PREFIX +KEY_TOKENIZER_RWKV = GGUFMetadataKeys.Tokenizer.RWKV +KEY_TOKENIZER_CHAT_TEMPLATE = GGUFMetadataKeys.Tokenizer.CHAT_TEMPLATE +KEY_TOKENIZER_CHAT_TEMPLATE_N = GGUFMetadataKeys.Tokenizer.CHAT_TEMPLATE_N +KEY_TOKENIZER_CHAT_TEMPLATES = GGUFMetadataKeys.Tokenizer.CHAT_TEMPLATES +KEY_TOKENIZER_PRIFIX_ID = GGUFMetadataKeys.Tokenizer.PREFIX_ID +KEY_TOKENIZER_SUFFIX_ID = GGUFMetadataKeys.Tokenizer.SUFFIX_ID +KEY_TOKENIZER_MIDDLE_ID = GGUFMetadataKeys.Tokenizer.MIDDLE_ID +KEY_TOKENIZER_EOT_ID = GGUFMetadataKeys.Tokenizer.EOT_ID diff --git 
a/gguf-py/gguf/huggingface_hub.py b/gguf-py/gguf/huggingface_hub.py new file mode 100644 index 0000000000000..5f72e183ad9e1 --- /dev/null +++ b/gguf-py/gguf/huggingface_hub.py @@ -0,0 +1,321 @@ +import json +import logging +import os +import pathlib +from hashlib import sha256 +from typing import Protocol + +import requests +from sentencepiece import SentencePieceProcessor +from tqdm import tqdm + +from .constants import HF_TOKENIZER_SPM_FILES + + +class HFHubBase(Protocol): + def __init__( + self, + model_path: None | str | pathlib.Path, + logger: None | logging.Logger, + ): + # Set the model path + if model_path is None: + model_path = "models" + self._model_path = model_path + + # Set the logger + if logger is None: + logger = logging.getLogger(__name__) + self.logger = logger + + @property + def model_path(self) -> pathlib.Path: + return pathlib.Path(self._model_path) + + @model_path.setter + def model_path(self, value: pathlib.Path): + self._model_path = value + + def write_file(self, content: bytes, file_path: pathlib.Path) -> None: + with open(file_path, "wb") as file: + file.write(content) + self.logger.debug(f"Wrote {len(content)} bytes to {file_path} successfully") + + +class HFHubRequest(HFHubBase): + def __init__( + self, + auth_token: None | str, + model_path: None | str | pathlib.Path, + logger: None | logging.Logger, + ): + super().__init__(model_path, logger) + + # Set headers if authentication is available + if auth_token is None: + self._headers = None + else: + # headers = { + # "Authorization": f"Bearer {auth_token}", + # "securityStatus": True, + # "blobs": True, + # } + self._headers = {"Authorization": f"Bearer {auth_token}"} + + # Persist across requests + self._session = requests.Session() + + # This is read-only + self._base_url = "https://huggingface.co" + + # NOTE: Cache repeat calls + self._model_repo = None + self._model_files = None + + @property + def headers(self) -> None | dict[str, str]: + return self._headers + + @property + def session(self) -> requests.Session: + return self._session + + @property + def base_url(self) -> str: + return self._base_url + + def resolve_url(self, repo: str, filename: str) -> str: + return f"{self._base_url}/{repo}/resolve/main/{filename}" + + def get_response(self, url: str) -> requests.Response: + # TODO: Stream requests and use tqdm to output the progress live + response = self._session.get(url, headers=self.headers) + self.logger.debug(f"Response status was {response.status_code}") + response.raise_for_status() + return response + + def model_info(self, model_repo: str) -> dict[str, object]: + url = f"{self._base_url}/api/models/{model_repo}" + return self.get_response(url).json() + + def list_remote_files(self, model_repo: str) -> list[str]: + # NOTE: Reset the cache if the repo changed + if self._model_repo != model_repo: + self._model_repo = model_repo + self._model_files = [] + for f in self.model_info(self._model_repo)["siblings"]: + self._model_files.append(f["rfilename"]) + dump = json.dumps(self._model_files, indent=4) + self.logger.debug(f"Cached remote files: {dump}") + # Return the cached file listing + return self._model_files + + def list_filtered_remote_files( + self, model_repo: str, file_suffix: str + ) -> list[str]: + model_files = [] + self.logger.debug(f"Model Repo:{model_repo}") + self.logger.debug(f"File Suffix:{file_suffix}") + # NOTE: Valuable files are typically in the root path + for filename in self.list_remote_files(model_repo): + path = pathlib.Path(filename) + if len(path.parents) > 1: + 
continue  # skip nested paths
+            self.logger.debug(f"Path Suffix: {path.suffix}")
+            if path.suffix == file_suffix:
+                self.logger.debug(f"File Name: {filename}")
+                model_files.append(filename)
+        return model_files
+
+    def list_remote_safetensors(self, model_repo: str) -> list[str]:
+        # NOTE: HuggingFace recommends safetensors over pickled checkpoints to
+        # mitigate arbitrary code execution on load
+        return [
+            part
+            for part in self.list_filtered_remote_files(model_repo, ".safetensors")
+            if part.startswith("model")
+        ]
+
+    def list_remote_bin(self, model_repo: str) -> list[str]:
+        # NOTE: legacy PyTorch checkpoints are published as "pytorch_model*.bin" shards
+        return [
+            part
+            for part in self.list_filtered_remote_files(model_repo, ".bin")
+            if part.startswith("pytorch_model")
+        ]
+
+    def list_remote_weights(self, model_repo: str) -> list[str]:
+        model_parts = self.list_remote_safetensors(model_repo)
+        if not model_parts:
+            model_parts = self.list_remote_bin(model_repo)
+        self.logger.debug(f"Remote model parts: {model_parts}")
+        return model_parts
+
+    def list_remote_tokenizers(self, model_repo: str) -> list[str]:
+        return [
+            tok
+            for tok in self.list_remote_files(model_repo)
+            if tok in HF_TOKENIZER_SPM_FILES
+        ]
+
+
+class HFHubTokenizer(HFHubBase):
+    def __init__(
+        self, model_path: None | str | pathlib.Path, logger: None | logging.Logger
+    ):
+        super().__init__(model_path, logger)
+
+    @staticmethod
+    def list_vocab_files() -> tuple[str, ...]:
+        return HF_TOKENIZER_SPM_FILES
+
+    def model(self, model_repo: str) -> SentencePieceProcessor:
+        path = self.model_path / model_repo / "tokenizer.model"
+        processor = SentencePieceProcessor()
+        # NOTE: LoadFromFile() expects a file name; raw bytes would require
+        # LoadFromSerializedProto()
+        processor.LoadFromFile(str(path))
+        return processor
+
+    def config(self, model_repo: str) -> dict[str, object]:
+        path = self.model_path / model_repo / "tokenizer_config.json"
+        return json.loads(path.read_text(encoding="utf-8"))
+
+    def json(self, model_repo: str) -> dict[str, object]:
+        path = self.model_path / model_repo / "tokenizer.json"
+        return json.loads(path.read_text(encoding="utf-8"))
+
+    def get_normalizer(self, model_repo: str) -> None | dict[str, object]:
+        normalizer = self.json(model_repo).get("normalizer", dict())
+        if normalizer:
+            self.logger.info(f"JSON:Normalizer: {json.dumps(normalizer, indent=2)}")
+        else:
+            self.logger.warning(f"WARN:Normalizer: {normalizer}")
+        return normalizer
+
+    def get_pre_tokenizer(self, model_repo: str) -> None | dict[str, object]:
+        pre_tokenizer = self.json(model_repo).get("pre_tokenizer")
+        if pre_tokenizer:
+            self.logger.info(
+                f"JSON:PreTokenizer: {json.dumps(pre_tokenizer, indent=2)}"
+            )
+        else:
+            self.logger.warning(f"WARN:PreTokenizer: {pre_tokenizer}")
+        return pre_tokenizer
+
+    def get_added_tokens(self, model_repo: str) -> None | list[dict[str, object]]:
+        added_tokens = self.json(model_repo).get("added_tokens", list())
+        if added_tokens:
+            self.logger.info(f"JSON:AddedTokens: {json.dumps(added_tokens, indent=2)}")
+        else:
+            self.logger.warning(f"WARN:AddedTokens: {added_tokens}")
+        return added_tokens
+
+    def get_pre_tokenizer_json_hash(self, model_repo: str) -> None | str:
+        tokenizer = self.json(model_repo)
+        tokenizer_path = self.model_path / model_repo / "tokenizer.json"
+        if not tokenizer.get("pre_tokenizer"):
+            return None
+        sha256sum = sha256(str(tokenizer.get("pre_tokenizer")).encode()).hexdigest()
+        self.logger.info(f"Hashed '{tokenizer_path}' as {sha256sum}")
+        return sha256sum
+
+    def get_tokenizer_json_hash(self, model_repo: str) -> str:
+        tokenizer = self.json(model_repo)
+        tokenizer_path = 
self.model_path / model_repo / "tokenizer.json"
+        sha256sum = sha256(str(tokenizer).encode()).hexdigest()
+        self.logger.info(f"Hashed '{tokenizer_path}' as {sha256sum}")
+        return sha256sum
+
+    def log_tokenizer_json_info(self, model_repo: str) -> None:
+        self.logger.info(f"{model_repo}")
+        tokenizer = self.json(model_repo)
+        for k, v in tokenizer.items():
+            if k not in ["added_tokens", "model"]:
+                self.logger.info(f"{k}:{json.dumps(v, indent=2)}")
+            if k == "model":
+                for x, y in v.items():
+                    if x not in ["vocab", "merges"]:
+                        self.logger.info(f"{k}:{x}:{json.dumps(y, indent=2)}")
+
+
+class HFHubModel(HFHubBase):
+    def __init__(
+        self,
+        auth_token: None | str,
+        model_path: None | str | pathlib.Path,
+        logger: None | logging.Logger,
+    ):
+        super().__init__(model_path, logger)
+
+        self._request = HFHubRequest(auth_token, model_path, logger)
+        self._tokenizer = HFHubTokenizer(model_path, logger)
+
+    @property
+    def request(self) -> HFHubRequest:
+        return self._request
+
+    @property
+    def tokenizer(self) -> HFHubTokenizer:
+        return self._tokenizer
+
+    def _request_single_file(
+        self, model_repo: str, file_name: str, file_path: pathlib.Path
+    ) -> None:
+        # NOTE: Do not use bare exceptions! They mask issues!
+        # Allow the exception to occur or explicitly handle it.
+        try:
+            resolved_url = self.request.resolve_url(model_repo, file_name)
+            response = self.request.get_response(resolved_url)
+            self.write_file(response.content, file_path)
+        except requests.exceptions.HTTPError as e:
+            self.logger.error(f"Error while downloading '{file_name}': {str(e)}")
+
+    def _request_listed_files(
+        self, model_repo: str, remote_files: list[str]
+    ) -> None:
+        for file_name in tqdm(remote_files, total=len(remote_files)):
+            dir_path = self.model_path / model_repo
+            os.makedirs(dir_path, exist_ok=True)
+
+            # NOTE: Consider an optional `force` parameter if files need to be updated,
+            # e.g. the model creator updated the vocabulary to fix an issue or add a feature.
+            file_path = dir_path / file_name
+            if file_path.exists():
+                self.logger.debug(f"skipped - downloaded {file_path} exists already.")
+                continue  # skip existing files
+
+            self.logger.debug(f"Downloading '{file_name}' from {model_repo}")
+            self._request_single_file(model_repo, file_name, file_path)
+            self.logger.debug(f"Model file successfully saved to {file_path}")
+
+    def config(self, model_repo: str) -> dict[str, object]:
+        path = self.model_path / model_repo / "config.json"
+        return json.loads(path.read_text(encoding="utf-8"))
+
+    def architecture(self, model_repo: str) -> str:
+        # NOTE: A config's "architectures" list usually holds a single entry, but
+        # merged models may declare several and some configs omit the field, so a
+        # result is not guaranteed; an empty string is returned in that case.
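+        # e.g. meta-llama/Llama-2-7b-hf ships '"architectures": ["LlamaForCausalLM"]'
+        # in its config.json, so this returns "LlamaForCausalLM".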
+        try:
+            return self.config(model_repo).get("architectures", [])[0]
+        except IndexError:
+            self.logger.debug(f"Failed to get {model_repo} architecture")
+            return ""
+
+    def download_model_weights(self, model_repo: str) -> None:
+        remote_files = self.request.list_remote_weights(model_repo)
+        self._request_listed_files(model_repo, remote_files)
+
+    def download_model_tokenizers(self, model_repo: str) -> None:
+        remote_files = self.request.list_remote_tokenizers(model_repo)
+        self._request_listed_files(model_repo, remote_files)
+
+    def download_model_weights_and_tokenizers(self, model_repo: str) -> None:
+        # fetch the weights first, then the tokenizer files
+        self.download_model_weights(model_repo)
+        self.download_model_tokenizers(model_repo)
+
+    def download_all_repository_files(self, model_repo: str) -> None:
+        all_files = self.request.list_remote_files(model_repo)
+        self._request_listed_files(model_repo, all_files)
diff --git a/gguf-py/scripts/gguf-gen-pre.py b/gguf-py/scripts/gguf-gen-pre.py
new file mode 100644
index 0000000000000..72995a50bc3ab
--- /dev/null
+++ b/gguf-py/scripts/gguf-gen-pre.py
@@ -0,0 +1,213 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+import logging
+import os
+import sys
+from pathlib import Path
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from gguf.huggingface_hub import HFVocabRequest
+
+logger = logging.getLogger("gguf-gen-pre")
+
+
+# NOTE: It's impossible to catch all edge cases.
+# The most naive way to handle this is to have a pre-compiled Unicode list of all
+# ~1.1 million code points; the set is finite and ISO-standardized, so the upper
+# bound is known and well-understood time-complexity trade-offs apply when
+# searching for the best way to resolve a given sequence.
+def test_pre_tok_params() -> list[str]:
+    return [
+        "ü, ǖ, ǘ, ǚ, ǜ",  # diaeresis (pinyin ü with tone marks)
+        "綠, 女, 怒, 玉, 句",  # hanzi
+        "ied 4 ½ months",  # ordinal
+        "¡Hola Mundo!",  # spanish
+        "Olá Mundo!",  # portuguese
+        "Selam Dünya!",  # turkish
+        "Salam, dünýä!",  # turkmen
+        "Γειά σου Κόσμε!",  # greek
+        "हैलो वर्ल्ड!",  # hindi
+        "สวัสดีชาวโลก!",  # thai
+        "こんにちは世界!",  # japanese
+        "你好世界!",  # chinese
+        "Hàlo a Shaoghail!",  # gaelic
+        "Chào thế giới!",  # vietnamese
+        "Привет, мир!",  # russian
+        "Здравей свят!",  # bulgarian
+        "សួស្តី​ពិភពលោក!",  # khmer
+        "The quick brown fox jumps over the lazy dog.",  # english pangram (uses every letter)
+        "Le rapide renard brun sauta par dessus le chien paresseux.",  # french
+        "\tWil je een kopje thee?\n",  # dutch
+        " Te gustaría algo de té ? ",  # spanish
+        # NOTE: I expect right-to-left languages to fail
+        "העלא וועלט!",  # yiddish (r-to-l)
+        "سلام دنیا!",  # persian (r-to-l)
+        "",  # empty string; a falsy value in python, no symbols
+        " ",
+        "  ",
+        "   ",
+        "\t",
+        "\n",
+        "\n\n",
+        "\n\n\n",
+        "\t\n",
+        "Hello world",
+        " Hello world",
+        "Hello World",
+        " Hello World",
+        " Hello World!",
+        "Hello, world!",
+        " Hello, world!",
+        " this is 🦙.cpp",
+        "w048 7tuijk dsdfhu",
+        "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)",
+        "Hello",
+        " Hello",
+        "  Hello",
+        "   Hello",
+        "    Hello",
+        "    Hello\n    Hello",
+        " (",
+        "\n =",
+        "' era",
+        "Hello, y'all! 
How are you 😁 局外人?苹果apple工作work3.14159天God~",
+        "3",
+        "33",
+        "333",
+        "3333",
+        "33333",
+        "333333",
+        "3333333",
+    ]
+
+
+def test_pre_tok(hf_voc_req: HFVocabRequest) -> None:
+    # NOTE: aggregate all models to their respective paths
+    from transformers import AutoTokenizer
+
+    params = test_pre_tok_params()
+    for model in hf_voc_req.models:
+        # set the model path, e.g. 'models/meta-llama/Llama-2-7b-hf'
+        path = Path(f"{hf_voc_req.model_path}/{model['repo']}")
+        # set the model name, e.g. 'llama-2-7b-hf'
+        # NOTE: use path.name, not path.stem; stem truncates at the last dot
+        # (e.g. 'Qwen1.5-7B' -> 'qwen1')
+        name = path.name.lower()
+        # model input encodings, e.g. 'models/meta-llama/Llama-2-7b-hf/ggml-vocab-llama-2-7b-hf.inp'
+        inp = path / f"ggml-vocab-{name}.inp"
+        # model output encodings, e.g. 'models/meta-llama/Llama-2-7b-hf/ggml-vocab-llama-2-7b-hf.out'
+        out = path / f"ggml-vocab-{name}.out"
+        # extracted tokenizer model
+        final = path / f"ggml-vocab-{name}.gguf"
+
+        # skip tokenizer folder if unavailable
+        if not path.exists():
+            logger.warning(f"skipped - {model['repo']} not found.")
+            continue
+
+        try:  # create the tokenizer
+            tokenizer = AutoTokenizer.from_pretrained(path)
+        except OSError as e:
+            logger.error(f"{model['repo']} not found: {e}")
+            continue  # skip this tokenizer model
+
+        with open(inp, "w", encoding="utf-8") as f:
+            for test in params:
+                f.write(f"{test}")
+                f.write("\n__ggml_vocab_test__\n")
+
+        with open(out, "w", encoding="utf-8") as f:
+            for test in params:
+                encodings = tokenizer.encode(test, add_special_tokens=False)
+                for encoding in encodings:
+                    f.write(f" {encoding}")
+                f.write("\n")
+
+        logger.info(f"Tests for {model['repo']} written in {final}.*")
+
+
+def generate_vocab_script(hf_voc_req: HFVocabRequest) -> None:
+    # generate commands for creating vocab files
+    shscript = "#!/usr/bin/env bash\n\n"
+
+    for model in hf_voc_req.models:
+        # get the repo path
+        path = Path(f"{hf_voc_req.model_path}/{model['repo']}")
+        # set the vocab path (path.name, not path.stem; see the note above)
+        vocab = path / f"ggml-vocab-{path.name.lower()}.gguf"
+        # set the command line
+        tmpline = f"python3 convert-hf-to-gguf.py {path} --outfile {vocab} --vocab-only\n"
+        shscript += tmpline
+        logger.info(tmpline.strip())
+
+    with open("generate-vocab.sh", "w", encoding="utf-8") as f:
+        f.write(shscript)
+        logger.info(f"Wrote {len(shscript)} bytes to generate-vocab.sh")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("hf_auth_token", help="A huggingface read auth token")
+    parser.add_argument(
+        "-v", "--verbose", action="store_true", help="Increase output verbosity."
+    )
+    parser.add_argument(
+        "-r", "--model-repo", default="meta-llama/Llama-2-7b-hf",
+        help="The model repository. Default is 'meta-llama/Llama-2-7b-hf'."
+    )
+    parser.add_argument(
+        "-m", "--model-path", default="models/",
+        help="The model storage path. Default is 'models/'."
+    )
+    parser.add_argument(
+        "-a", "--model-arch", default="llama",
+        help="The supported model architecture. Default is 'llama'."
+    )
+    parser.add_argument(
+        "-p", "--model-parts", default=2, type=int,
+        help="The number of model shards encompassing the model. Default is 2."
+    )
+    parser.add_argument(
+        "-t", "--model-type", default="safetensors",
+        help="The model file type. Default is 'safetensors'."
+    )
+    parser.add_argument(
+        "-b", "--vocab-type",
+        default="SPM", const="SPM", nargs="?", choices=["SPM", "BPE", "WPM"],
+        help="The model tokenizer type. Default is 'SPM'."
+    )
+    parser.add_argument(
+        # NOTE: "-g" avoids a short-option clash with "-t"/"--model-type" above
+        "-g", "--gen-vocab-tests", action="store_true",
+        help="Generate the tokenizer tests. Default is False."
+ ) + parser.add_argument( + "-s", "--gen-vocab-script", action="store_true", + help="Generate the gguf vocab files. Default is False." + ) + args = parser.parse_args() + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + + hf_vocab_req = HFVocabRequest( + args.model_path, args.hf_auth_token, logger + ) + + hf_vocab_req.download_models() + hf_vocab_req.generate_checksums() + hf_vocab_req.log_pre_tokenizer_info() + + if args.gen_vocab_tests: + test_pre_tok(hf_vocab_req) + + if args.gen_vocab_script: + generate_vocab_script(hf_vocab_req) + + +if __name__ == '__main__': + main() diff --git a/gguf-py/scripts/gguf-registry.py b/gguf-py/scripts/gguf-registry.py new file mode 100644 index 0000000000000..44bc3b79e2e78 --- /dev/null +++ b/gguf-py/scripts/gguf-registry.py @@ -0,0 +1,372 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import logging +import os +import sys +from pathlib import Path + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent)) + +from gguf.constants import ( + HF_TOKENIZER_BPE_FILES, + HF_TOKENIZER_SPM_FILES, + MODEL_ARCH, + MODEL_ARCH_NAMES, + ModelFileExtension, + PreTokenizerType, + VocabType, +) +from gguf.huggingface_hub import HFHubModel, HFHubTokenizer + +logger = logging.getLogger(__file__) + +# +# HuggingFace Model Map +# +# NOTE: All prerequisite model metadata must be defined here. +# +# Defines metadata for each Hugging Face model required during conversion to GGUF +# +# Field Descriptions +# - `model_repo` (str): The HuggingFace endpoint or local path to the models repository +# - `model_arch` (MODEL_ARCH): Model architecture type +# - `model_parts` (int): Number of parts required to join the model during conversion +# - `model_type` (FileFormatType): File format for the Hugging Face model files +# - `vocab_type` (VocabType): Vocabulary type used by the tokenizer +# - `vocab_pre` (Optional[Tuple[str]]): Tuple of pre-tokenizer pattern strings for this model +# - `vocab_files` (Tuple[str]): Tuple of file names required to extract vocabulary and other metadata +# +# NOTES +# - Possible algorithms are WordLevel, BPE, WordPiece, or Unigram +# - Possible LLaMa tokenizer model types are: None, SPM, BPE, or WPM +HF_MODEL_MAP = ( + # SPM (Sentence Piece Models): Default to Byte Level Pre-tokenization. 
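+    # An illustrative lookup (not part of the conversion pipeline itself):
+    # consumers can select an entry by repo and read its fields, e.g.
+    #     entry = next(m for m in HF_MODEL_MAP if m["model_repo"] == "meta-llama/Llama-2-7b-hf")
+    #     entry["vocab_type"]   # -> "SPM"
+    #     entry["vocab_files"]  # -> HF_TOKENIZER_SPM_FILES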
+ { + "model_repo": "meta-llama/Llama-2-7b-hf", + "model_arch": MODEL_ARCH.LLAMA, + "model_parts": 2, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.SPM.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_SPM_FILES, + }, + { + "model_repo": "mistralai/Mistral-7B-Instruct-v0.1", + "model_arch": MODEL_ARCH.LLAMA, + "model_parts": 2, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.SPM.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_SPM_FILES, + }, + { + "model_repo": "mistralai/Mistral-7B-Instruct-v0.2", + "model_arch": MODEL_ARCH.LLAMA, + "model_parts": 3, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.SPM.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_SPM_FILES, + }, + { # NOTE: Mistral v0.3 has a 'tokenizer.model.v3' file + "model_repo": "mistralai/Mistral-7B-Instruct-v0.3", + "model_arch": MODEL_ARCH.LLAMA, + "model_parts": 3, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.SPM.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_SPM_FILES, + }, + { + "model_repo": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "model_arch": MODEL_ARCH.LLAMA, + "model_parts": 8, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.SPM.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_SPM_FILES, + }, + { + "model_repo": "microsoft/Phi-3-mini-4k-instruct", + "model_arch": MODEL_ARCH.PHI3, + "model_parts": 2, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.SPM.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_SPM_FILES, + }, + # WPM (Word Piece Models): Default to Byte Level Pre-tokenization. + # NOTE: BERT Normalization and Pre-tokenization rules differ from Byte Level Pre-tokenization. 
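+    # e.g. BERT-style tokenizers typically lowercase, strip accents, and split on
+    # punctuation (a BertNormalizer + BertPreTokenizer pair in tokenizer.json),
+    # which is why they cannot reuse the GPT-2 ByteLevel defaults.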
+ { + "model_repo": "BAAI/bge-small-en-v1.5", + "model_arch": MODEL_ARCH.BERT, + "model_parts": 1, + "model_type": ModelFileExtension.BIN.value, + "vocab_type": VocabType.WPM.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + { + "model_repo": "jinaai/jina-embeddings-v2-base-en", + "model_arch": MODEL_ARCH.JINA_BERT_V2, + "model_parts": 1, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.WPM.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + # BPE (Byte Pair Encoding Models): Default is Byte Level Pre-tokenization + { + "model_repo": "meta-llama/Meta-Llama-3-8B", + "model_arch": MODEL_ARCH.LLAMA, + "model_parts": 4, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + { + "model_repo": "tiiuae/falcon-7b", + "model_arch": MODEL_ARCH.FALCON, + "model_parts": 2, + "model_type": ModelFileExtension.BIN.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + { + "model_repo": "deepseek-ai/deepseek-llm-7b-base", + "model_arch": MODEL_ARCH.LLAMA, + "model_parts": 2, + "model_type": ModelFileExtension.BIN.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + { + "model_repo": "deepseek-ai/deepseek-coder-6.7b-base", + "model_arch": MODEL_ARCH.LLAMA, + "model_parts": 2, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + { + "model_repo": "mosaicml/mpt-7b", + "model_arch": MODEL_ARCH.MPT, + "model_parts": 2, + "model_type": ModelFileExtension.BIN.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + # + # BPE: STARCODER + # + { + "model_repo": "bigcode/starcoder2-3b", + "model_arch": MODEL_ARCH.STARCODER2, + "model_parts": 1, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + { + "model_repo": "smallcloudai/Refact-1_6-base", + "model_arch": MODEL_ARCH.REFACT, + "model_parts": 1, + "model_type": ModelFileExtension.BIN.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + { + "model_repo": "CohereForAI/c4ai-command-r-v01", + "model_arch": MODEL_ARCH.COMMAND_R, + "model_parts": 15, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + # + # BPE: QWEN + # + { + "model_repo": "Qwen/Qwen1.5-7B", + "model_arch": MODEL_ARCH.QWEN2, + "model_parts": 4, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + { + "model_repo": "stabilityai/stablelm-2-zephyr-1_6b", + "model_arch": MODEL_ARCH.STABLELM, + "model_parts": 1, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + # + # BPE: GPT-2 + # + { + "model_repo": "openai-community/gpt2", + "model_arch": MODEL_ARCH.GPT2, + "model_parts": 1, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + { + "model_repo": 
"allenai/OLMo-1.7-7B-hf", + "model_arch": MODEL_ARCH.OLMO, + "model_parts": 6, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + # { # NOTE: I don't have access to this model + # "model_repo": "databricks/dbrx-base", + # "model_arch": MODEL_ARCH.DBRX, + # "model_parts": 0, + # "model_type": ModelFileExtension.SAFETENSORS.value, + # "vocab_type": VocabType.BPE.value, + # "vocab_pre": None, + # "vocab_files": HF_TOKENIZER_BPE_FILES, + # }, + { # NOTE: RoBERTa post processor + "model_repo": "jinaai/jina-embeddings-v2-base-es", + "model_arch": MODEL_ARCH.JINA_BERT_V2, + "model_parts": 1, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + { # NOTE: RoBERTa post processor + "model_repo": "jinaai/jina-embeddings-v2-base-de", + "model_arch": MODEL_ARCH.JINA_BERT_V2, + "model_parts": 1, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + { # NOTE: Phi-1 is compatible with GPT-2 arch and vocab + "model_repo": "microsoft/phi-1", + "model_arch": MODEL_ARCH.PHI2, + "model_parts": 1, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + { + "model_repo": "microsoft/phi-1_5", + "model_arch": MODEL_ARCH.PHI2, + "model_parts": 1, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, + { + "model_repo": "microsoft/phi-2", + "model_arch": MODEL_ARCH.PHI2, + "model_parts": 2, + "model_type": ModelFileExtension.SAFETENSORS.value, + "vocab_type": VocabType.BPE.value, + "vocab_pre": None, + "vocab_files": HF_TOKENIZER_BPE_FILES, + }, +) + + +def get_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("auth_token", help="A huggingface read auth token") + parser.add_argument( + "-v", "--verbose", action="store_true", help="Increase output verbosity." + ) + parser.add_argument( + "--model-path", + default="models", + help="The models storage path. Default is 'models'.", + ) + return parser.parse_args() + + +args = get_arguments() + +if args.verbose: + logging.basicConfig(level=logging.DEBUG) +else: + logging.basicConfig(level=logging.INFO) + +hub_model = HFHubModel( + auth_token=args.auth_token, + model_path=args.model_path, + logger=logger, +) + +hub_tokenizer = HFHubTokenizer( + model_path=args.model_path, + logger=logger, +) + + +metadata = [] +for model in HF_MODEL_MAP: + model_repo = model["model_repo"] + model_arch = model["model_arch"] + vocab_type = model["vocab_type"] + + print("HUB_REPO:", model_repo, "LLAMA_ARCH:", MODEL_ARCH_NAMES[model_arch]) + + hub_model.download_all_vocab_files( + model_repo=model_repo, + vocab_type=vocab_type, + ) + # log the downloaded results + hub_tokenizer.log_tokenizer_json_info(model_repo) + + model["model_arch"] = MODEL_ARCH_NAMES[model_arch] + + normalizer = hub_tokenizer.get_normalizer(model_repo) + # NOTE: Normalizer may be one of null, Sequence, NFC, NFD, NFKC, NFKD... 
+
+
+args = get_arguments()
+
+if args.verbose:
+    logging.basicConfig(level=logging.DEBUG)
+else:
+    logging.basicConfig(level=logging.INFO)
+
+hub_model = HFHubModel(
+    auth_token=args.auth_token,
+    model_path=args.model_path,
+    logger=logger,
+)
+
+hub_tokenizer = HFHubTokenizer(
+    model_path=args.model_path,
+    logger=logger,
+)
+
+
+metadata = []
+for model in HF_MODEL_MAP:
+    model_repo = model["model_repo"]
+    model_arch = model["model_arch"]
+    vocab_type = model["vocab_type"]
+
+    print("HUB_REPO:", model_repo, "LLAMA_ARCH:", MODEL_ARCH_NAMES[model_arch])
+
+    hub_model.download_all_vocab_files(
+        model_repo=model_repo,
+        vocab_type=vocab_type,
+    )
+    # log the downloaded results
+    hub_tokenizer.log_tokenizer_json_info(model_repo)
+
+    model["model_arch"] = MODEL_ARCH_NAMES[model_arch]
+
+    normalizer = hub_tokenizer.get_normalizer(model_repo)
+    # NOTE: The normalizer may be one of null, Sequence, NFC, NFD, NFKC, NFKD...
+    # Most observed configs use null, Sequence, or NFC; default to NFD for now.
+    # TODO: Extract the normalizer metadata
+    model["normalizer"] = normalizer
+
+    # Seems safe to assume most basic pre-tokenizers are of type "Sequence".
+    # I expect this to cause issues in the future. Needs more research.
+    pre_tokenizer = hub_tokenizer.get_pre_tokenizer(model_repo)
+    # extract the pre-tokenizer metadata
+    model["pre_tokenizer"] = pre_tokenizer
+
+    added_tokens = hub_tokenizer.get_added_tokens(model_repo)
+    # extract the added tokens metadata
+    model["added_tokens"] = added_tokens
+
+    sha256sum = hub_tokenizer.get_tokenizer_json_hash(model_repo)
+    # use the hash to validate the model's vocabulary
+    model["vocab_hash"] = sha256sum
+
+    metadata.append(model)
+
+with open(f"{args.model_path}/registry.json", mode="w") as file:
+    json.dump(metadata, file, indent=2)
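+
+# Illustrative shape of one registry.json entry after this script runs; the
+# concrete values are assumptions for a single model, and the normalizer,
+# pre_tokenizer, and added_tokens payloads mirror whatever that model's
+# tokenizer.json contains:
+#
+#   {
+#     "model_repo": "meta-llama/Meta-Llama-3-8B",
+#     "model_arch": "llama",
+#     ...
+#     "normalizer": null,
+#     "pre_tokenizer": {"type": "Sequence", "pretokenizers": [...]},
+#     "added_tokens": [...],
+#     "vocab_hash": "<sha256 hex digest of tokenizer.json>"
+#   }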
diff --git a/gguf-py/scripts/hub-model.py b/gguf-py/scripts/hub-model.py
new file mode 100644
index 0000000000000..fc0e88546ae09
--- /dev/null
+++ b/gguf-py/scripts/hub-model.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python3
+
+from __future__ import annotations
+
+import argparse
+import logging
+import os
+import sys
+from pathlib import Path
+
+# Necessary to load the local gguf package
+if (
+    "NO_LOCAL_GGUF" not in os.environ
+    and (Path(__file__).parent.parent.parent / "gguf-py").exists()
+):
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from gguf.huggingface_hub import HFHubModel, HFHubTokenizer
+
+logger = logging.getLogger(Path(__file__).stem)
+
+
+def get_arguments() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("auth_token", help="A huggingface read auth token")
+    parser.add_argument(
+        "model_repo", help="A huggingface model repository, e.g. org/model"
+    )
+    parser.add_argument(
+        "-v", "--verbose", action="store_true", help="Increase output verbosity."
+    )
+    parser.add_argument(
+        "-m", "--model-path", default="models",
+        help="The models storage path. Default is 'models'.",
+    )
+    parser.add_argument(
+        "-a", "--model-arch", default="llama",
+        help="The supported llama.cpp model architecture. Default is 'llama'."
+    )
+    parser.add_argument(
+        "-p", "--model-parts", type=int, default=2,
+        help="The number of shard files that make up the model. Default is 2."
+    )
+    parser.add_argument(
+        "-f", "--model-name",
+        default=".safetensors", const=".safetensors", nargs="?",
+        choices=[".pt", ".pth", ".bin", ".safetensors", ".gguf"],
+        help="The model's file name extension. Default is '.safetensors'."
+    )
+    parser.add_argument(
+        "-t", "--vocab-type",
+        default="SPM", const="SPM", nargs="?",
+        choices=["SPM", "BPE", "WPM"],
+        help="The model's tokenizer type. Default is 'SPM'."
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = get_arguments()
+
+    if args.verbose:
+        logging.basicConfig(level=logging.DEBUG)
+    else:
+        logging.basicConfig(level=logging.INFO)
+
+    hub_model = HFHubModel(
+        auth_token=args.auth_token,
+        model_path=args.model_path,
+        logger=logger,
+    )
+
+    hub_tokenizer = HFHubTokenizer(
+        model_path=args.model_path,
+        logger=logger,
+    )
+
+    vocab_type = HFHubTokenizer.get_vocab_type(args.vocab_type)
+    hub_model.download_all_vocab_files(
+        model_repo=args.model_repo,
+        vocab_type=vocab_type,
+    )
+    hub_tokenizer.log_tokenizer_json_info(args.model_repo)
+
+    model_type = HFHubModel.get_model_type(args.model_name)
+    hub_model.download_model_files(args.model_repo, model_type)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/gguf-py/scripts/hub-vocab.py b/gguf-py/scripts/hub-vocab.py
new file mode 100644
index 0000000000000..2e5225aff65a4
--- /dev/null
+++ b/gguf-py/scripts/hub-vocab.py
@@ -0,0 +1,99 @@
+#!/usr/bin/env python3
+"""
+Tokenizers Vocabulary Notes:
+
+Normalizers:
+Normalizers are a set of operations applied to raw string input data to make it less random or “cleaner”. Common normalization operations include stripping whitespace, removing accented characters, or lowercasing all text. The Hugging Face `tokenizers` library provides various Normalizer classes that can be combined using a `normalizers.Sequence` to apply multiple normalization operations in sequence to the input data before tokenization takes place.
+
+Pre-Tokenization:
+Pre-tokenization identifies characters and their types (letters, numbers, whitespace, etc.) and splits the input into preliminary units before the actual tokenization model is applied or the data is fed into a machine learning model. The Hugging Face `tokenizers` library provides several Pre-tokenizer classes for different purposes, such as Byte Level pre-tokenization (which uses the openai/gpt-2 RegEx by default) and BERT pre-tokenization, which inherits from Byte Level tokenization but differs in some of its behavior.
+
+Pre-Tokenization Types:
+
+1. Byte Level Pre-tokenization:
+   - The default regular expression used for pattern matching is taken from openai/gpt-2 `encoder.py`.
+
+2. BERT pre-tokenization (inherits from Byte Level):
+   - Behaves differently from the default Byte Level tokenizer, but the default RegEx is identical in either case.
+
+Pre-Tokenization Character Types:
+
+1. Sequence: Matches a sequence of characters that should be treated as a single unit during preprocessing or tokenization.
+2. Letters and Numbers (Alphabetic/Alphanumeric): Characters belonging to the alphabet or mixed combinations of letters and numbers, respectively.
+3. Whitespace: Spaces, tabs, newlines, etc., that separate words or other units in the text data.
+"""
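+# Illustrative tokenizer.json fragments for the concepts above; these shapes
+# are assumptions based on typical Hugging Face tokenizer configs, not taken
+# from any specific model:
+#
+#   "normalizer": {"type": "Sequence", "normalizers": [{"type": "NFC"}]}
+#   "pre_tokenizer": {"type": "ByteLevel", "add_prefix_space": false}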
org/model" + ) + parser.add_argument( + "-v", "--verbose", action="store_true", help="Increase output verbosity." + ) + parser.add_argument( + "--model-path", + default="models", + help="The models storage path. Default is 'models/'.", + ) + parser.add_argument( + "--vocab-name", + const="BPE", + nargs="?", + choices=["SPM", "BPE", "WPM"], + help="The name of the vocab type. Default is 'BPE'.", + ) + return parser.parse_args() + + +def main(): + args = get_arguments() + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + + hub_model = HFHubModel( + auth_token=args.auth_token, + model_path=args.model_path, + logger=logger, + ) + + hub_tokenizer = HFHubTokenizer( + model_path=args.model_path, + logger=logger, + ) + + vocab_type = HFHubTokenizer.get_vocab_type(args.vocab_name) + hub_model.download_all_vocab_files( + model_repo=args.model_repo, + vocab_type=vocab_type, + ) + + hub_model.download_all_vocab_files(args.model_repo, vocab_type) + hub_tokenizer.log_tokenizer_json_info(args.model_repo) + + +if __name__ == "__main__": + main() diff --git a/llama.cpp b/llama.cpp index 841be1de7291e..2b6dcd7b6b26d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -325,7 +325,9 @@ enum llm_kv { LLM_KV_SSM_TIME_STEP_RANK, LLM_KV_TOKENIZER_MODEL, + LLM_KV_TOKENIZER_TYPE, LLM_KV_TOKENIZER_PRE, + LLM_KV_TOKENIZER_HASH, LLM_KV_TOKENIZER_LIST, LLM_KV_TOKENIZER_TOKEN_TYPE, LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, @@ -410,7 +412,9 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_TYPE, "tokenizer.ggml.type" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, + { LLM_KV_TOKENIZER_HASH, "tokenizer.ggml.hash" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, diff --git a/models/ggml-vocab-aquila.gguf b/models/ggml-vocab-aquila.gguf deleted file mode 100644 index 7a9abb122ddd1..0000000000000 Binary files a/models/ggml-vocab-aquila.gguf and /dev/null differ diff --git a/models/ggml-vocab-baichuan.gguf b/models/ggml-vocab-baichuan.gguf deleted file mode 100644 index 7caaf8239b052..0000000000000 Binary files a/models/ggml-vocab-baichuan.gguf and /dev/null differ diff --git a/models/ggml-vocab-bert-bge.gguf b/models/ggml-vocab-bert-bge.gguf deleted file mode 100644 index b2cbd5df6882d..0000000000000 Binary files a/models/ggml-vocab-bert-bge.gguf and /dev/null differ diff --git a/models/ggml-vocab-bert-bge.gguf.inp b/models/ggml-vocab-bert-bge.gguf.inp deleted file mode 100644 index 0a89107c60d7f..0000000000000 --- a/models/ggml-vocab-bert-bge.gguf.inp +++ /dev/null @@ -1,106 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! 
-__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-bert-bge.gguf.out b/models/ggml-vocab-bert-bge.gguf.out deleted file mode 100644 index e4a76cdb07d3f..0000000000000 --- a/models/ggml-vocab-bert-bge.gguf.out +++ /dev/null @@ -1,43 +0,0 @@ - 29464 2094 1018 1092 2706 - 11865 17875 - - - - - - - - - - 7592 2088 - 7592 2088 - 7592 2088 - 7592 2088 - 7592 2088 999 - 7592 1010 2088 999 - 7592 1010 2088 999 - 2023 2003 100 1012 18133 2361 - 1059 2692 18139 1021 8525 28418 2243 16233 20952 6979 - 1192 15290 29754 14150 1192 10260 1181 29755 29436 29741 10260 16856 29747 23925 10325 - 100 - 100 1006 3671 1007 100 1006 3674 7861 29147 2483 9530 16280 23854 1007 100 1006 2069 7861 29147 2072 2008 2038 2049 2219 19204 1007 - 7592 - 7592 - 7592 - 7592 - 7592 - 7592 7592 - 1006 - 1027 - 1005 3690 - 7592 1010 1061 1005 2035 999 2129 2024 2017 100 1029 1855 100 100 6207 100 100 14677 23632 22203 1811 1995 - 1017 - 3943 - 21211 - 21211 2509 - 21211 22394 - 21211 22394 2509 - 21211 22394 22394 - 21211 22394 22394 2509 - 21211 22394 22394 22394 - 100 1006 3671 1007 100 1006 3674 7861 29147 2483 9530 16280 23854 1007 100 100 1017 3943 21211 21211 2509 21211 22394 21211 22394 2509 21211 22394 22394 21211 22394 22394 2509 1017 1012 1017 1017 1012 1012 1017 1017 1012 1012 1012 1017 100 1029 1855 100 100 6207 100 100 14677 23632 22203 1811 1995 1011 1011 1011 1011 1011 1011 1027 1027 1027 1027 1027 1027 1027 1192 15290 29754 14150 1192 10260 1181 29755 29436 29741 10260 16856 29747 23925 10325 1005 1005 1005 1005 1005 1005 1036 1036 1036 1036 1036 1036 1036 1000 1000 1000 1000 1012 1012 1012 1012 1012 1012 999 999 999 999 999 999 1029 1029 1029 1029 1029 1029 1045 1005 2310 2042 1005 2409 2002 1005 1055 2045 1010 1005 2128 2017 2469 1029 1005 1049 2025 2469 1045 1005 2222 2191 2009 1010 1005 1040 2017 2066 2070 5572 1029 2057 1005 2310 1037 1005 2222 diff --git a/models/ggml-vocab-command-r.gguf b/models/ggml-vocab-command-r.gguf deleted file mode 100644 index b553eab330591..0000000000000 Binary files a/models/ggml-vocab-command-r.gguf and /dev/null differ diff --git a/models/ggml-vocab-command-r.gguf.inp b/models/ggml-vocab-command-r.gguf.inp deleted file mode 100644 index 0a89107c60d7f..0000000000000 --- a/models/ggml-vocab-command-r.gguf.inp +++ /dev/null @@ -1,106 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer 
-__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-command-r.gguf.out b/models/ggml-vocab-command-r.gguf.out deleted file mode 100644 index cc4277daa1d25..0000000000000 --- a/models/ggml-vocab-command-r.gguf.out +++ /dev/null @@ -1,43 +0,0 @@ - 2536 228 27 228 22957 6983 - 45 193433 - - 228 - 1667 - 1742 - 205 - 206 - 2126 - 11516 - 34777 - 28339 3845 - 46609 3845 - 28339 3930 - 46609 3930 - 46609 3930 8 - 28339 19 3845 8 - 46609 19 3845 8 - 2075 1801 11254 107 255 21 19317 - 94 23 27 31 228 30 21213 20752 39267 6405 9980 - 4929 40071 2196 3236 8750 1764 37097 41168 - 38111 230 174833 38111 249 86325 241 38111 245 86325 232 38111 252 38111 123 38111 261 165 24629 38111 261 38111 103 174833 38111 235 38111 231 38111 257 38111 235 165 24629 38111 239 - 2226 256 230 1737 18258 16 80503 122 35927 2226 242 112 57462 1737 54457 223165 106230 2096 16 48389 1737 10203 109160 1875 2222 2517 3342 12523 16 - 28339 - 46609 - 228 46609 - 1667 46609 - 1742 46609 - 1742 46609 1856 46609 - 1737 - 206 1857 - 14 4515 - 28339 19 1770 14 1954 8 4070 1955 1933 80503 231 5691 12081 13336 2648 29325 14315 24 26 24 27 24 28 24 5123 18372 - 26 - 26 26 - 26 26 26 - 26 26 26 26 - 26 26 26 26 26 - 26 26 26 26 26 26 - 26 26 26 26 26 26 26 - 26 26 26 26 26 26 26 26 - 26 26 26 26 26 26 26 26 26 - 127731 51628 205 57788 18494 97469 126134 206 2226 256 230 1737 18258 16 80503 122 35927 2226 242 112 57462 1737 54457 223165 106230 2096 16 48389 11254 107 255 2226 107 255 228 26 228 26 26 228 26 26 26 228 26 26 26 26 228 26 26 26 26 26 228 26 26 26 26 26 26 228 26 26 26 26 26 26 26 228 26 26 26 26 26 26 26 26 228 26 21 26 228 26 2271 26 228 26 3834 26 182018 230 174833 38111 249 86325 241 38111 245 86325 232 38111 252 38111 123 38111 261 165 24629 38111 261 38111 103 174833 38111 235 188568 231 5691 12081 13336 2648 29325 14315 24 26 24 27 24 28 24 5123 18372 8391 158343 3512 
40071 2196 3236 8750 1764 37097 41168 29721 32797 25646 3802 4975 4975 116167 57178 10251 154048 27292 1767 5125 2632 2155 91 2378 1919 1914 2782 19 2155 3354 1933 5470 38 2155 52 2068 5470 1767 4961 3059 1894 19 2155 43 1933 3026 2725 23186 38 2930 14 20676 1671 14 83 51 diff --git a/models/ggml-vocab-deepseek-coder.gguf b/models/ggml-vocab-deepseek-coder.gguf deleted file mode 100644 index 6728cd747249e..0000000000000 Binary files a/models/ggml-vocab-deepseek-coder.gguf and /dev/null differ diff --git a/models/ggml-vocab-deepseek-coder.gguf.inp b/models/ggml-vocab-deepseek-coder.gguf.inp deleted file mode 100644 index 0a89107c60d7f..0000000000000 --- a/models/ggml-vocab-deepseek-coder.gguf.inp +++ /dev/null @@ -1,106 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? 
We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-deepseek-coder.gguf.out b/models/ggml-vocab-deepseek-coder.gguf.out deleted file mode 100644 index 9ccc560d694ca..0000000000000 --- a/models/ggml-vocab-deepseek-coder.gguf.out +++ /dev/null @@ -1,43 +0,0 @@ - 1050 207 19 207 19192 4217 - 37 32009 71 6247 - - 207 - 243 - 315 - 184 - 185 - 185 185 - 185 185 185 - 184 185 - 17535 1835 - 414 9489 1835 - 17535 5414 - 414 9489 5414 - 414 9489 5414 0 - 17535 11 1835 0 - 414 9489 11 1835 0 - 437 317 12394 99 234 13 14789 - 86 15 19 23 207 22 83 3963 27659 26078 3934 14072 - 1593 6478 616 2251 14994 - 155 239 209 155 239 114 155 239 228 155 240 220 155 239 224 155 240 211 155 239 231 155 239 115 155 239 240 155 240 210 155 239 240 155 239 95 155 239 114 155 239 214 155 239 210 155 239 236 155 239 214 155 240 210 155 239 218 - 10047 235 209 334 8760 8 12394 233 114 350 222 10047 221 104 169 116 224 334 4684 3909 992 24330 262 29651 612 8 207 156 237 214 334 5950 992 78 12896 344 638 891 1372 10736 8 - 17535 - 414 9489 - 207 414 9489 - 243 414 9489 - 315 414 9489 - 315 414 9489 185 315 414 9489 - 334 - 185 405 - 6 2895 - 17535 11 320 6 435 0 1717 417 340 12394 233 210 3015 19100 608 9413 2668 16 18 16 19 16 20 16 1393 169 121 239 - 18 - 18 18 - 18 18 18 - 18 18 18 18 - 18 18 18 18 18 - 18 18 18 18 18 18 - 18 18 18 18 18 18 18 - 18 18 18 18 18 18 18 18 - 18 18 18 18 18 18 18 18 18 - 185 207 185 185 207 185 185 185 207 12405 459 22758 185 243 185 315 185 251 185 730 185 10047 235 209 334 8760 8 12394 233 114 350 222 10047 221 104 169 116 224 334 4684 3909 992 24330 262 29651 612 8 207 156 237 214 12394 99 234 10047 99 234 207 18 207 18 18 207 18 18 18 207 18 18 18 18 207 18 18 18 18 18 207 18 18 18 18 18 18 207 18 18 18 18 18 18 18 207 18 18 18 18 18 18 18 18 207 18 13 18 207 18 524 18 207 18 1202 18 207 155 239 209 155 239 114 155 239 228 155 240 220 155 239 224 155 240 211 155 239 231 155 239 115 155 239 240 155 240 210 155 239 240 155 239 95 155 239 114 155 239 214 10047 233 210 3015 19100 608 9413 2668 16 18 16 19 16 20 16 1393 169 121 239 18155 374 17194 28 2861 6478 616 2251 14994 31269 4191 6 4686 4686 10252 3358 3358 3409 524 15330 3023 15031 5668 303 6 312 798 651 83 839 362 6 82 741 11 651 1369 340 2037 30 651 44 441 2037 303 6 642 1098 359 11 651 35 340 833 738 10860 30 998 6 10709 245 6 75 43 diff --git a/models/ggml-vocab-deepseek-llm.gguf b/models/ggml-vocab-deepseek-llm.gguf deleted file mode 100644 index 5d66091c44b6f..0000000000000 Binary files a/models/ggml-vocab-deepseek-llm.gguf and /dev/null differ diff --git a/models/ggml-vocab-deepseek-llm.gguf.inp b/models/ggml-vocab-deepseek-llm.gguf.inp deleted file mode 100644 index 0a89107c60d7f..0000000000000 --- a/models/ggml-vocab-deepseek-llm.gguf.inp +++ /dev/null @@ -1,106 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! 
-__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-deepseek-llm.gguf.out b/models/ggml-vocab-deepseek-llm.gguf.out deleted file mode 100644 index fd94b896d24e7..0000000000000 --- a/models/ggml-vocab-deepseek-llm.gguf.out +++ /dev/null @@ -1,43 +0,0 @@ - 1052 207 19 207 19109 4223 - 37 100014 71 6245 - - 207 - 243 - 300 - 184 - 185 - 185 185 - 185 185 185 - 184 185 - 17464 1843 - 37727 1843 - 17464 5427 - 37727 5427 - 37727 5427 0 - 17464 11 1843 0 - 37727 11 1843 0 - 437 317 12356 99 234 13 14743 - 86 15 19 23 207 22 83 3970 27519 26016 3944 14025 - 1603 6476 620 91754 - 71374 209 71374 114 71374 228 155 240 220 71374 224 155 240 211 71374 231 71374 115 71374 240 155 240 210 71374 240 71374 95 71374 114 71374 214 71374 210 71374 236 71374 214 155 240 210 71374 218 - 10044 95300 334 8754 8 33701 114 350 222 10044 221 104 46713 334 34732 996 24250 262 80923 8 207 37103 214 334 5956 89213 344 643 895 1377 10728 8 - 17464 - 37727 - 207 37727 - 243 37727 - 300 37727 - 300 37727 185 300 37727 - 334 - 185 403 - 6 2906 - 17464 11 320 6 436 0 1724 418 340 33701 210 3025 19017 612 9407 2681 16 18 16 19 16 20 16 1398 68940 239 - 18 - 18 18 - 18 18 18 - 18 18 18 18 - 18 18 18 18 18 - 18 18 18 18 18 18 - 18 18 18 18 18 18 18 - 18 18 18 18 18 18 18 18 - 18 18 18 18 18 18 18 18 18 - 185 207 185 185 207 185 185 185 207 11969 486 22504 185 243 185 300 185 251 185 663 185 10044 95300 334 8754 8 33701 114 350 222 10044 221 104 46713 334 34732 996 24250 262 80923 8 207 37103 214 12356 99 234 10044 99 234 207 18 207 18 18 207 18 18 18 207 18 18 18 18 207 18 18 18 18 18 207 18 18 18 18 18 18 207 18 18 18 18 18 18 18 207 18 18 18 18 18 18 18 18 207 18 13 18 207 18 526 18 207 18 1204 18 207 71374 209 71374 114 71374 228 155 240 220 71374 224 155 240 211 71374 231 71374 115 71374 240 155 240 210 71374 240 71374 95 71374 114 71374 214 71899 210 3025 19017 612 9407 2681 16 18 16 19 16 20 16 1398 68940 239 78827 55170 76659 620 91754 31116 36804 4885 4885 10897 4390 4390 41047 15278 3033 14986 5675 304 6 313 803 655 33326 362 6 82 745 11 655 1374 340 2049 30 655 44 441 2049 304 6 647 1099 359 11 655 35 340 837 742 10842 30 1003 6 10699 245 6 75 43 diff --git a/models/ggml-vocab-falcon.gguf b/models/ggml-vocab-falcon.gguf deleted file mode 100644 index 334d50da51ba5..0000000000000 Binary files a/models/ggml-vocab-falcon.gguf and 
/dev/null differ diff --git a/models/ggml-vocab-falcon.gguf.inp b/models/ggml-vocab-falcon.gguf.inp deleted file mode 100644 index 0a89107c60d7f..0000000000000 --- a/models/ggml-vocab-falcon.gguf.inp +++ /dev/null @@ -1,106 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? 
We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-falcon.gguf.out b/models/ggml-vocab-falcon.gguf.out deleted file mode 100644 index 209b04cdaf330..0000000000000 --- a/models/ggml-vocab-falcon.gguf.out +++ /dev/null @@ -1,43 +0,0 @@ - 878 204 31 3068 133 2137 - 28611 132 30042 - - 204 - 258 - 466 - 192 - 193 - 1001 - 11331 - 19125 - 9856 1079 - 23090 1079 - 9856 2889 - 23090 2889 - 23090 2889 12 - 9856 23 1079 12 - 23090 23 1079 12 - 414 304 3346 111 231 25 29247 - 98 55866 204 34 16682 7149 36190 6869 11481 - 150 133 6207 151 215 150 134 5052 133 6279 5052 223 151 216 49679 123 53110 47043 7795 - 38154 206 38154 126 38154 225 167 237 217 38154 221 167 237 208 38154 228 38154 127 38154 237 167 237 207 38154 237 38154 107 38154 126 38154 211 38154 207 38154 233 38154 211 167 237 207 38154 215 - 2571 232 206 204 19 11003 20 8196 126 283 219 48778 116 13392 204 19 51831 732 63209 1741 7955 522 20 22438 211 204 19 7927 53360 325 504 701 946 10930 20 - 9856 - 23090 - 204 23090 - 258 23090 - 466 23090 - 466 23090 742 23090 - 204 19 - 1212 40 - 18 4932 - 9856 23 291 18 436 12 1265 362 299 8196 207 204 42 50087 123 2727 20300 32022 133 234 17419 30137 28 7858 181 133 236 - 30 - 3138 - 22287 - 22287 30 - 22287 3138 - 22287 22287 - 22287 22287 30 - 22287 22287 3138 - 22287 22287 22287 - 1212 4824 1001 1212 192 204 663 49453 2069 742 561 1501 193 2571 232 206 204 19 11003 20 8196 126 283 219 48778 116 13392 204 19 51831 732 63209 1741 7955 522 20 22438 211 3346 111 231 2571 111 231 204 30 204 3138 204 22287 204 22287 30 204 22287 3138 204 22287 22287 204 22287 22287 30 204 22287 22287 3138 204 30 25 30 204 30 513 30 204 30 951 30 27171 236 206 38154 126 38154 225 167 237 217 38154 221 167 237 208 38154 228 38154 127 38154 237 167 237 207 38154 237 38154 107 38154 126 38154 211 20589 207 204 42 50087 123 2727 20300 32022 133 234 17419 30137 28 7858 181 133 236 204 37057 2228 10666 5052 133 6207 151 215 150 134 5052 133 6279 5052 223 151 216 49679 123 53110 47043 7795 204 7544 7544 7544 8543 8543 17593 3513 3513 12844 51520 17664 4247 295 18 298 650 204 18 95 693 332 18 94 629 23 204 18 1553 299 1310 42 204 18 56 416 1310 295 18 567 717 334 23 204 18 47 299 606 596 6696 42 703 18 16139 241 18 87 55 diff --git a/models/ggml-vocab-gpt-2.gguf b/models/ggml-vocab-gpt-2.gguf deleted file mode 100644 index 5ea85cf52e7de..0000000000000 Binary files a/models/ggml-vocab-gpt-2.gguf and /dev/null differ diff --git a/models/ggml-vocab-gpt-2.gguf.inp b/models/ggml-vocab-gpt-2.gguf.inp deleted file mode 100644 index 0a89107c60d7f..0000000000000 --- a/models/ggml-vocab-gpt-2.gguf.inp +++ /dev/null @@ -1,106 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! 
-__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-gpt-2.gguf.out b/models/ggml-vocab-gpt-2.gguf.out deleted file mode 100644 index 78430f0d31fdc..0000000000000 --- a/models/ggml-vocab-gpt-2.gguf.out +++ /dev/null @@ -1,43 +0,0 @@ - 798 604 25208 1933 - 37 9116 71 11751 - - 220 - 220 220 - 220 220 220 - 197 - 198 - 628 - 628 198 - 197 198 - 15496 995 - 18435 995 - 15496 2159 - 18435 2159 - 18435 2159 0 - 15496 11 995 0 - 18435 11 995 0 - 428 318 12520 99 247 13 20322 - 86 47202 767 28047 45961 288 82 7568 13415 - 22177 16843 141 231 15166 12466 121 16142 12466 239 141 232 30143 140 111 16142 21169 21727 31583 18849 - 157 252 222 157 252 114 157 252 241 157 253 233 157 252 237 157 253 224 157 252 244 157 252 115 157 252 253 157 253 223 157 252 253 157 252 95 157 252 114 157 252 227 157 252 223 157 252 249 157 252 227 157 253 223 157 252 231 - 8582 248 222 357 11265 8 30325 114 447 235 8582 234 104 37929 357 48101 795 13210 271 1673 36686 515 8 14519 227 357 8807 44805 326 468 663 898 11241 8 - 15496 - 18435 - 220 18435 - 220 220 18435 - 220 220 220 18435 - 220 220 220 18435 198 220 220 220 18435 - 357 - 198 796 - 6 6980 - 15496 11 331 6 439 0 1374 389 345 30325 223 5633 22755 239 46349 111 28839 101 18040 32432 98 43291 1485 1415 24309 25465 171 121 252 - 18 - 2091 - 20370 - 24840 - 2091 20370 - 24840 2091 - 24840 20370 - 24840 24840 - 24840 2091 20370 - 198 220 628 220 628 198 220 197 220 197 197 220 197 198 220 220 198 220 220 220 198 220 220 220 220 198 220 220 220 220 220 198 8582 248 222 357 11265 8 30325 114 447 235 8582 234 104 37929 357 48101 795 13210 271 1673 36686 515 8 14519 227 12520 99 247 8582 99 247 513 4747 23460 513 20370 23460 2091 23460 20370 23460 24840 23460 2091 20370 513 13 18 513 492 18 513 986 18 28053 252 222 157 252 114 157 252 241 157 253 233 157 252 237 157 253 224 157 252 244 157 252 115 157 252 253 157 253 223 157 252 253 157 252 95 157 252 114 157 252 227 47249 223 5633 22755 239 46349 111 28839 101 18040 32432 98 43291 1485 1415 24309 25465 171 121 252 40103 1421 18604 12466 121 16843 141 231 15166 12466 121 16142 12466 239 141 232 30143 140 111 16142 21169 21727 31583 18849 705 39115 6 33153 15506 63 15931 15931 16317 13896 3228 9805 3548 314 1053 587 705 44040 339 338 612 11 705 2200 345 1654 30 705 44 407 1654 314 1183 787 340 11 705 35 345 588 617 8887 30 775 6 
26979 257 6 75 43 diff --git a/models/ggml-vocab-gpt-neox.gguf b/models/ggml-vocab-gpt-neox.gguf deleted file mode 100644 index b9af16845ccb4..0000000000000 Binary files a/models/ggml-vocab-gpt-neox.gguf and /dev/null differ diff --git a/models/ggml-vocab-gpt2.gguf b/models/ggml-vocab-gpt2.gguf deleted file mode 100644 index 1fbc72c1e4d9e..0000000000000 Binary files a/models/ggml-vocab-gpt2.gguf and /dev/null differ diff --git a/models/ggml-vocab-llama-bpe.gguf b/models/ggml-vocab-llama-bpe.gguf deleted file mode 100644 index e51a99118bc43..0000000000000 Binary files a/models/ggml-vocab-llama-bpe.gguf and /dev/null differ diff --git a/models/ggml-vocab-llama-bpe.gguf.inp b/models/ggml-vocab-llama-bpe.gguf.inp deleted file mode 100644 index 9380bf355202a..0000000000000 --- a/models/ggml-vocab-llama-bpe.gguf.inp +++ /dev/null @@ -1,108 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? 
We'Ve a'lL -__ggml_vocab_test__ - Việt -__ggml_vocab_test__ diff --git a/models/ggml-vocab-llama-bpe.gguf.out b/models/ggml-vocab-llama-bpe.gguf.out deleted file mode 100644 index 1f3607fb6a378..0000000000000 --- a/models/ggml-vocab-llama-bpe.gguf.out +++ /dev/null @@ -1,44 +0,0 @@ - 1142 220 19 220 27154 4038 - 37 51853 261 - - 220 - 256 - 262 - 197 - 198 - 271 - 1432 - 1602 - 9906 1917 - 22691 1917 - 9906 4435 - 22691 4435 - 22691 4435 0 - 9906 11 1917 0 - 22691 11 1917 0 - 420 374 11410 99 247 13 11055 - 86 23904 220 22 83 2005 42908 11729 3013 17156 - 79862 102118 13373 64571 34694 3114 112203 80112 - 21549 222 98629 241 45358 233 21549 237 45358 224 21549 244 21549 115 21549 253 45358 223 21549 253 21549 95 98629 227 21549 223 21549 249 21549 227 45358 223 21549 231 - 9468 248 222 320 8416 8 27623 114 102470 9468 234 104 31643 320 36773 100166 98634 8 26602 227 320 3323 43465 430 706 1202 1866 4037 8 - 9906 - 22691 - 220 22691 - 256 22691 - 262 22691 - 262 22691 198 262 22691 - 320 - 198 284 - 6 11639 - 9906 11 379 65948 0 2650 527 499 27623 223 949 37046 101067 19000 23182 102301 9263 18136 16 36827 21909 - 18 - 1644 - 8765 - 8765 18 - 8765 1644 - 8765 8765 - 8765 8765 18 - 8765 8765 1644 - 8765 8765 8765 - 198 4815 15073 66597 8004 1602 2355 79772 11187 9468 248 222 320 8416 8 27623 114 102470 9468 234 104 31643 320 36773 100166 98634 8 26602 227 11410 99 247 9468 99 247 220 18 220 1644 220 8765 220 8765 18 220 8765 1644 220 8765 8765 220 8765 8765 18 220 8765 8765 1644 220 18 13 18 220 18 497 18 220 18 1131 18 220 21549 222 98629 241 45358 233 21549 237 45358 224 21549 244 21549 115 21549 253 45358 223 21549 253 21549 95 98629 227 76460 223 949 37046 101067 19000 23182 102301 9263 18136 16 36827 21909 56560 54337 19175 102118 13373 64571 34694 3114 112203 80112 3436 106451 14196 14196 74694 3089 3089 29249 17523 3001 27708 7801 358 3077 1027 364 83 820 568 596 1070 11 364 793 499 2771 30 364 44 539 2771 358 3358 1304 433 11 364 35 499 1093 1063 15600 30 1226 6 43712 264 64966 43 - 101798 diff --git a/models/ggml-vocab-llama-spm.gguf b/models/ggml-vocab-llama-spm.gguf deleted file mode 100644 index 658295a5df741..0000000000000 Binary files a/models/ggml-vocab-llama-spm.gguf and /dev/null differ diff --git a/models/ggml-vocab-llama-spm.gguf.inp b/models/ggml-vocab-llama-spm.gguf.inp deleted file mode 100644 index 0a89107c60d7f..0000000000000 --- a/models/ggml-vocab-llama-spm.gguf.inp +++ /dev/null @@ -1,106 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! 
How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-llama-spm.gguf.out b/models/ggml-vocab-llama-spm.gguf.out deleted file mode 100644 index 9c3327cb54380..0000000000000 --- a/models/ggml-vocab-llama-spm.gguf.out +++ /dev/null @@ -1,43 +0,0 @@ - 474 287 29871 29946 29871 30226 7378 - 383 4000 261 - - 259 - 1678 - 268 - 29871 12 - 29871 13 - 29871 13 13 - 29871 13 13 13 - 29871 12 13 - 15043 3186 - 29871 15043 3186 - 15043 2787 - 29871 15043 2787 - 29871 15043 2787 29991 - 15043 29892 3186 29991 - 29871 15043 29892 3186 29991 - 29871 445 338 29871 243 162 169 156 29889 8223 - 281 29900 29946 29947 29871 29955 9161 13535 18031 2176 6905 - 1538 4851 665 1386 29713 1305 - 29871 31849 31324 31934 228 162 142 228 161 146 228 162 133 228 161 153 228 161 186 31708 228 162 132 31708 228 161 165 31324 228 161 136 228 161 132 228 161 158 228 161 136 228 162 132 228 161 140 - 29871 243 162 157 131 313 8945 29897 29871 243 162 155 185 30722 243 162 143 174 30598 313 20787 953 3848 275 16125 630 29897 29871 31681 313 6194 953 29877 2397 393 756 967 1914 5993 29897 - 15043 - 29871 15043 - 259 15043 - 1678 15043 - 268 15043 - 268 15043 13 1678 15043 - 29871 313 - 29871 13 353 - 525 3152 - 15043 29892 343 29915 497 29991 1128 526 366 29871 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739 - 29871 29941 - 29871 29941 29941 - 29871 29941 29941 29941 - 29871 29941 29941 29941 29941 - 29871 29941 29941 29941 29941 29941 - 29871 29941 29941 29941 29941 29941 29941 - 29871 29941 29941 29941 29941 29941 29941 29941 - 29871 29941 29941 29941 29941 29941 29941 29941 29941 - 29871 29941 29941 29941 29941 29941 29941 29941 29941 29941 - 29871 13 29871 13 13 29871 13 13 13 29871 12 29871 12 12 29871 12 13 259 13 1678 13 268 13 418 13 243 162 157 131 313 8945 29897 29871 243 162 155 185 30722 243 162 143 174 30598 313 20787 953 3848 275 16125 630 29897 29871 31681 29871 243 162 169 156 243 162 169 156 29871 29941 29871 29941 29941 29871 29941 29941 29941 29871 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29941 29941 29941 29871 29941 29889 29941 29871 29941 636 29941 29871 29941 856 29941 29871 31849 31324 31934 228 162 142 228 161 146 228 162 133 228 161 153 228 161 186 31708 228 162 132 31708 228 161 165 31324 228 161 136 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739 448 23648 2751 25512 1538 4851 665 1386 29713 1305 14550 4907 11120 16159 16159 16159 15945 15945 3045 636 6824 6824 6824 8773 8773 8773 306 29915 345 1063 525 29873 1025 540 29915 29879 727 29892 525 1525 366 1854 29973 525 29924 451 1854 306 29915 645 1207 372 29892 525 29928 366 763 777 23429 29973 1334 29915 29963 29872 263 29915 29880 29931 diff --git 
a/models/ggml-vocab-mpt.gguf b/models/ggml-vocab-mpt.gguf deleted file mode 100644 index f42f56dec9294..0000000000000 Binary files a/models/ggml-vocab-mpt.gguf and /dev/null differ diff --git a/models/ggml-vocab-mpt.gguf.inp b/models/ggml-vocab-mpt.gguf.inp deleted file mode 100644 index 0a89107c60d7f..0000000000000 --- a/models/ggml-vocab-mpt.gguf.inp +++ /dev/null @@ -1,106 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? 
We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-mpt.gguf.out b/models/ggml-vocab-mpt.gguf.out deleted file mode 100644 index d8d0fe90900bb..0000000000000 --- a/models/ggml-vocab-mpt.gguf.out +++ /dev/null @@ -1,43 +0,0 @@ - 728 577 24142 2607 - 39 26288 6554 - - 209 - 50276 - 50275 - 186 - 187 - 535 - 2756 - 186 187 - 12092 1533 - 24387 1533 - 12092 3645 - 24387 3645 - 24387 3645 2 - 12092 13 1533 2 - 24387 13 1533 2 - 436 310 22692 101 236 15 14161 - 88 27244 818 16853 16392 20505 4989 11917 - 32520 11514 1068 8713 38177 13396 3415 9925 12559 10453 1389 - 18081 211 18081 116 18081 230 39936 222 18081 226 39936 213 18081 233 18081 117 18081 242 39936 212 18081 242 18081 97 18081 116 18081 216 18081 212 18081 238 18081 216 39936 212 18081 220 - 14931 237 211 313 6320 10 49042 116 325 224 14931 223 106 171 118 226 313 34263 802 13511 261 32147 456 10 3384 239 216 313 7483 802 80 8020 326 556 697 1211 10669 10 - 12092 - 24387 - 50276 12092 - 50275 12092 - 50274 12092 - 50274 12092 187 50274 12092 - 313 - 187 426 - 8 8685 - 12092 13 340 8 455 2 1359 403 368 49042 212 3736 15367 41197 13610 19934 41869 21275 1012 1047 18795 40120 20422 241 - 20 - 1610 - 20084 - 26409 - 1610 20084 - 26409 1610 - 26409 20084 - 26409 26409 - 26409 1610 20084 - 586 1744 33525 186 209 623 28910 187 50276 187 50275 187 50274 187 50273 187 14931 237 211 313 6320 10 49042 116 325 224 14931 223 106 171 118 226 313 34263 802 13511 261 32147 456 10 3384 239 216 22692 101 236 14931 101 236 495 5922 30057 495 20084 495 26409 30057 20084 495 26409 1610 495 26409 20084 495 15 20 495 537 20 495 1051 20 209 18081 211 18081 116 18081 230 39936 222 18081 226 39936 213 18081 233 18081 117 18081 242 39936 212 18081 242 18081 97 18081 116 18081 216 14931 235 212 3736 15367 41197 13610 19934 41869 21275 1012 1047 18795 40120 20422 241 16081 6877 12880 11514 1068 8713 38177 13396 3415 9925 12559 10453 1389 42011 35033 34842 11202 9739 9739 33021 18963 4672 25561 8220 309 1849 644 686 42618 344 434 627 13 686 1848 368 2119 32 686 46 417 2119 309 1833 1056 352 13 686 37 368 751 690 10331 32 844 8 31516 247 8 77 45 diff --git a/models/ggml-vocab-phi-3.gguf b/models/ggml-vocab-phi-3.gguf deleted file mode 100644 index f8022a385e4aa..0000000000000 Binary files a/models/ggml-vocab-phi-3.gguf and /dev/null differ diff --git a/models/ggml-vocab-phi-3.gguf.inp b/models/ggml-vocab-phi-3.gguf.inp deleted file mode 100644 index 0a89107c60d7f..0000000000000 --- a/models/ggml-vocab-phi-3.gguf.inp +++ /dev/null @@ -1,106 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! 
-__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-phi-3.gguf.out b/models/ggml-vocab-phi-3.gguf.out deleted file mode 100644 index 9c3327cb54380..0000000000000 --- a/models/ggml-vocab-phi-3.gguf.out +++ /dev/null @@ -1,43 +0,0 @@ - 474 287 29871 29946 29871 30226 7378 - 383 4000 261 - - 259 - 1678 - 268 - 29871 12 - 29871 13 - 29871 13 13 - 29871 13 13 13 - 29871 12 13 - 15043 3186 - 29871 15043 3186 - 15043 2787 - 29871 15043 2787 - 29871 15043 2787 29991 - 15043 29892 3186 29991 - 29871 15043 29892 3186 29991 - 29871 445 338 29871 243 162 169 156 29889 8223 - 281 29900 29946 29947 29871 29955 9161 13535 18031 2176 6905 - 1538 4851 665 1386 29713 1305 - 29871 31849 31324 31934 228 162 142 228 161 146 228 162 133 228 161 153 228 161 186 31708 228 162 132 31708 228 161 165 31324 228 161 136 228 161 132 228 161 158 228 161 136 228 162 132 228 161 140 - 29871 243 162 157 131 313 8945 29897 29871 243 162 155 185 30722 243 162 143 174 30598 313 20787 953 3848 275 16125 630 29897 29871 31681 313 6194 953 29877 2397 393 756 967 1914 5993 29897 - 15043 - 29871 15043 - 259 15043 - 1678 15043 - 268 15043 - 268 15043 13 1678 15043 - 29871 313 - 29871 13 353 - 525 3152 - 15043 29892 343 29915 497 29991 1128 526 366 29871 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739 - 29871 29941 - 29871 29941 29941 - 29871 29941 29941 29941 - 29871 29941 29941 29941 29941 - 29871 29941 29941 29941 29941 29941 - 29871 29941 29941 29941 29941 29941 29941 - 29871 29941 29941 29941 29941 29941 29941 29941 - 29871 29941 29941 29941 29941 29941 29941 29941 29941 - 29871 29941 29941 29941 29941 29941 29941 29941 29941 29941 - 29871 13 29871 13 13 29871 13 13 13 29871 12 29871 12 12 29871 12 13 259 13 1678 13 268 13 418 13 243 162 157 131 313 8945 29897 29871 243 162 155 185 30722 243 162 143 174 30598 313 20787 953 3848 275 16125 630 29897 29871 31681 29871 243 162 169 156 243 162 169 156 29871 29941 29871 29941 29941 29871 29941 29941 29941 29871 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29941 29941 29871 29941 29941 29941 29941 29941 29941 29941 29941 29871 29941 29889 29941 29871 29941 636 29941 29871 29941 856 29941 29871 31849 31324 31934 228 162 142 228 161 146 228 162 
133 228 161 153 228 161 186 31708 228 162 132 31708 228 161 165 31324 228 161 136 243 162 155 132 1577 30672 31522 30505 11548 31041 30732 29896 29941 29896 29946 29896 29945 29896 30408 30739 448 23648 2751 25512 1538 4851 665 1386 29713 1305 14550 4907 11120 16159 16159 16159 15945 15945 3045 636 6824 6824 6824 8773 8773 8773 306 29915 345 1063 525 29873 1025 540 29915 29879 727 29892 525 1525 366 1854 29973 525 29924 451 1854 306 29915 645 1207 372 29892 525 29928 366 763 777 23429 29973 1334 29915 29963 29872 263 29915 29880 29931 diff --git a/models/ggml-vocab-qwen2.gguf b/models/ggml-vocab-qwen2.gguf deleted file mode 100644 index 541e475bc9453..0000000000000 Binary files a/models/ggml-vocab-qwen2.gguf and /dev/null differ diff --git a/models/ggml-vocab-qwen2.gguf.inp b/models/ggml-vocab-qwen2.gguf.inp deleted file mode 100644 index 0a89107c60d7f..0000000000000 --- a/models/ggml-vocab-qwen2.gguf.inp +++ /dev/null @@ -1,106 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? 
We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-qwen2.gguf.out b/models/ggml-vocab-qwen2.gguf.out deleted file mode 100644 index 401a510e86f3a..0000000000000 --- a/models/ggml-vocab-qwen2.gguf.out +++ /dev/null @@ -1,43 +0,0 @@ - 1122 220 19 220 26062 3951 - 37 50753 261 - - 220 - 256 - 262 - 197 - 198 - 271 - 1406 - 1572 - 9707 1879 - 21927 1879 - 9707 4337 - 21927 4337 - 21927 4337 0 - 9707 11 1879 0 - 21927 11 1879 0 - 419 374 11162 99 247 13 10821 - 86 15 19 23 220 22 83 1963 41808 11472 2940 16739 - 78762 14144 1456 13073 63471 33594 3038 133178 79012 - 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 147805 148301 147270 44258 223 146848 - 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 320 3243 42365 429 702 1181 1828 3950 8 - 9707 - 21927 - 220 21927 - 256 21927 - 262 21927 - 262 21927 198 262 21927 - 320 - 198 284 - 6 11385 - 9707 11 379 64848 0 2585 525 498 26525 223 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 - 18 - 18 18 - 18 18 18 - 18 18 18 18 - 18 18 18 18 18 - 18 18 18 18 18 18 - 18 18 18 18 18 18 18 - 18 18 18 18 18 18 18 18 - 18 18 18 18 18 18 18 18 18 - 198 4710 14731 65497 7847 1572 2303 78672 10947 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 11162 99 247 149955 220 18 220 18 18 220 18 18 18 220 18 18 18 18 220 18 18 18 18 18 220 18 18 18 18 18 18 220 18 18 18 18 18 18 18 220 18 18 18 18 18 18 18 18 220 18 13 18 220 18 496 18 220 18 1112 18 220 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 144534 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 55460 53237 18658 14144 1456 13073 63471 33594 3038 133178 79012 3355 4605 4605 13874 13874 73594 3014 3014 28149 17085 2928 26610 7646 358 3003 1012 364 83 813 566 594 1052 11 364 787 498 2704 30 364 44 537 2704 358 3278 1281 432 11 364 35 498 1075 1045 15243 30 1205 6 42612 264 63866 43 diff --git a/models/ggml-vocab-refact.gguf b/models/ggml-vocab-refact.gguf deleted file mode 100644 index 52afcf01aeb73..0000000000000 Binary files a/models/ggml-vocab-refact.gguf and /dev/null differ diff --git a/models/ggml-vocab-refact.gguf.inp b/models/ggml-vocab-refact.gguf.inp deleted file mode 100644 index 0a89107c60d7f..0000000000000 --- a/models/ggml-vocab-refact.gguf.inp +++ /dev/null @@ -1,106 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! 
How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-refact.gguf.out b/models/ggml-vocab-refact.gguf.out deleted file mode 100644 index 06b15c090c0f8..0000000000000 --- a/models/ggml-vocab-refact.gguf.out +++ /dev/null @@ -1,43 +0,0 @@ - 4833 225 38 225 143 140 17723 - 56 2006 3935 265 - - 225 - 261 - 264 - 202 - 203 - 478 - 2831 - 15773 - 8279 5788 - 12000 5788 - 8279 10896 - 12000 10896 - 12000 10896 19 - 8279 30 5788 19 - 12000 30 5788 19 - 458 438 5945 118 252 32 3766 - 105 34 38 42 225 41 102 1707 12530 10180 1479 8278 - 39862 8372 1039 9446 40242 13852 2053 8949 12531 1520 10700 - 14574 227 14574 133 14574 246 30457 238 14574 242 30457 229 14574 249 14574 134 14574 258 30457 228 14574 258 14574 114 14574 133 14574 232 14574 228 14574 254 14574 232 30457 228 14574 236 - 3807 253 227 308 4382 27 18458 133 46113 44967 123 13868 308 12565 19775 33071 40824 733 27 41889 308 2585 22680 688 1401 2819 4369 2404 27 - 8279 - 12000 - 225 12000 - 261 12000 - 264 12000 - 264 12000 284 12000 - 308 - 203 280 - 25 34666 - 8279 30 533 25 464 19 4971 884 844 18458 228 1018 4982 13368 2909 9513 17827 35 37 35 38 35 39 35 11873 47838 - 37 - 37 37 - 37 37 37 - 37 37 37 37 - 37 37 37 37 37 - 37 37 37 37 37 37 - 37 37 37 37 37 37 37 - 37 37 37 37 37 37 37 37 - 37 37 37 37 37 37 37 37 37 - 334 719 8878 202 10885 4222 16104 28570 203 3807 253 227 308 4382 27 18458 133 46113 44967 123 13868 308 12565 19775 33071 40824 733 27 41889 5945 118 252 3807 118 252 225 37 225 37 37 225 37 37 37 225 37 37 37 37 225 37 37 37 37 37 225 37 37 37 37 37 37 225 37 37 37 37 37 37 37 225 37 37 37 37 37 37 37 37 225 37 32 37 225 37 497 37 225 37 1179 37 225 14574 227 14574 133 14574 246 30457 238 14574 242 30457 229 14574 249 14574 134 14574 258 30457 228 14574 258 14574 114 14574 133 14574 232 36628 228 1018 4982 13368 2909 9513 17827 35 37 35 38 35 39 35 11873 47838 20921 16623 13028 8372 1039 9446 40242 13852 2053 8949 12531 1520 10700 5881 9592 13299 914 31753 31359 9163 3202 35472 10397 439 4763 2583 330 102 1455 938 1182 2017 30 330 613 844 3654 49 330 63 646 3654 439 4621 1930 561 30 330 54 844 2124 1629 35993 49 2688 25 7709 312 25 94 62 diff --git a/models/ggml-vocab-stablelm.gguf b/models/ggml-vocab-stablelm.gguf deleted file mode 100644 index ebb0cdb7d6a4a..0000000000000 Binary files a/models/ggml-vocab-stablelm.gguf and /dev/null differ diff --git a/models/ggml-vocab-starcoder.gguf b/models/ggml-vocab-starcoder.gguf deleted file mode 100644 index 7a7e7742ab1fc..0000000000000 Binary files a/models/ggml-vocab-starcoder.gguf and /dev/null differ diff --git a/models/ggml-vocab-starcoder.gguf.inp b/models/ggml-vocab-starcoder.gguf.inp deleted file mode 100644 index 0a89107c60d7f..0000000000000 --- a/models/ggml-vocab-starcoder.gguf.inp +++ /dev/null @@ -1,106 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - 
-__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-starcoder.gguf.out b/models/ggml-vocab-starcoder.gguf.out deleted file mode 100644 index ccb55c7feeef8..0000000000000 --- a/models/ggml-vocab-starcoder.gguf.out +++ /dev/null @@ -1,43 +0,0 @@ - 4850 244 57 244 162 159 17722 - 75 2022 3943 284 - - 244 - 280 - 283 - 221 - 222 - 499 - 3067 - 15767 - 8302 5810 - 12009 5810 - 8302 10914 - 12009 10914 - 12009 10914 38 - 8302 49 5810 38 - 12009 49 5810 38 - 477 458 5954 137 271 51 3779 - 124 53 57 61 244 60 121 1726 12568 10240 1519 8290 - 39916 8389 1059 9504 40216 13858 2073 8983 12571 1539 10721 - 14566 246 14566 152 14566 265 30428 257 14566 261 30428 248 14566 268 14566 153 14566 277 30428 247 14566 277 14566 133 14566 152 14566 251 14566 247 14566 273 14566 251 30428 247 14566 255 - 3822 272 246 327 4434 46 18445 152 46030 45022 142 13878 327 12585 19884 33773 40920 751 46 41839 327 2605 22716 708 1421 2840 4387 2421 46 - 8302 - 12009 - 244 12009 - 280 12009 - 283 12009 - 283 12009 303 12009 - 327 - 222 299 - 44 34719 - 8302 49 553 44 483 38 4998 904 863 18445 247 1037 4995 13379 2924 9515 17823 54 56 54 57 54 58 54 11904 47892 - 56 - 56 56 - 56 56 56 - 56 56 56 56 - 56 56 56 56 56 - 56 56 56 56 56 56 - 56 56 56 56 56 56 56 - 56 56 56 56 56 56 56 56 - 56 56 56 56 56 56 56 56 56 - 353 736 8886 221 10883 4238 16101 28540 222 3822 272 246 327 4434 46 18445 152 46030 45022 142 13878 327 12585 19884 33773 40920 751 46 41839 5954 137 271 3822 137 271 244 56 244 56 56 244 56 56 56 244 56 56 56 56 244 56 56 56 56 56 244 56 56 56 56 56 56 244 56 56 56 56 56 56 56 244 56 56 56 56 56 56 56 56 244 56 51 56 244 56 516 56 244 56 1198 56 244 14566 246 14566 152 14566 265 30428 257 14566 261 30428 248 14566 268 14566 153 14566 277 30428 247 14566 277 14566 133 14566 152 14566 251 36570 247 1037 4995 13379 2924 9515 17823 54 56 54 57 54 58 54 11904 47892 20895 16625 13047 8389 1059 9504 40216 13858 2073 8983 12571 1539 10721 
5918 9643 13298 932 31723 31330 9221 3226 35426 10400 457 4783 2602 349 121 1477 957 1200 2038 49 349 632 863 3673 68 349 82 666 3673 457 4650 1949 580 49 349 73 863 2144 1649 35941 68 2726 44 7728 331 44 113 81
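
Note on the generated fixtures removed above: each models/ggml-vocab-*.gguf.inp file stores tokenizer test strings separated by lines containing only the marker __ggml_vocab_test__, and the matching *.gguf.out file stores one line of space-separated token IDs per test string, in the same order. Below is a minimal sketch of how such a pair could be parsed and checked; the tokenize callable is a hypothetical stand-in for whatever encoder is under test (the repository's own tokenizer test programs do the real comparison), and the exact delimiter handling is inferred from the fixture layout shown in this diff.

    # vocab_fixture_check.py -- sketch only; `tokenize` is a hypothetical stand-in,
    # not the repository's actual test harness.
    from pathlib import Path
    from typing import Callable

    MARKER = "__ggml_vocab_test__"

    def load_cases(inp_path: Path) -> list[str]:
        # Each test string occupies the lines before a delimiter line;
        # the delimiter always sits on a line of its own, so rebuilding
        # the raw case text (including embedded newlines) is a join.
        cases: list[str] = []
        buf: list[str] = []
        for line in inp_path.read_text(encoding="utf-8").split("\n"):
            if line == MARKER:
                cases.append("\n".join(buf))
                buf = []
            else:
                buf.append(line)
        return cases

    def load_expected(out_path: Path) -> list[list[int]]:
        # One line of space-separated token IDs per test string;
        # leading whitespace on each line is harmless to split().
        return [[int(tok) for tok in line.split()]
                for line in out_path.read_text(encoding="utf-8").splitlines()]

    def check(inp_path: Path, out_path: Path,
              tokenize: Callable[[str], list[int]]) -> int:
        # Returns the number of mismatching cases; zip(strict=True)
        # catches .inp/.out files that have drifted out of sync.
        failures = 0
        for text, expected in zip(load_cases(inp_path),
                                  load_expected(out_path), strict=True):
            got = tokenize(text)
            if got != expected:
                failures += 1
                print(f"mismatch for {text!r}: got {got}, expected {expected}")
        return failures

Usage would pair, for example, ggml-vocab-qwen2.gguf.inp with ggml-vocab-qwen2.gguf.out and pass the encoder under test as tokenize; the point is only to illustrate the fixture format these deletions remove, not to reproduce the test suite.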