From 94493c2a78884e8e7d068b87d8fca1904877d863 Mon Sep 17 00:00:00 2001 From: Jonathan Jordan Date: Thu, 25 Apr 2024 09:51:28 +0200 Subject: [PATCH 1/2] Add Llamacpp backend (#81) * Add context limit check function to backends/util.py * Add model entries to registry * Add handling of optional model loading flags for CPU/GPU usage and GPU layer offload * Add openchat_3.5-GGUF-q5 to model registry * Add llama.cpp backend howto --- .gitignore | 1 + backends/llamacpp_api.py | 198 ++++++++++++++++++++++++++ backends/model_registry.json | 73 ++++++++++ backends/utils.py | 32 ++++- docs/howto_use_llama-cpp_backend.md | 48 +++++++ docs/model_backend_registry_readme.md | 26 ++++ setup_llamacpp_cuda122.sh | 6 + 7 files changed, 382 insertions(+), 2 deletions(-) create mode 100644 backends/llamacpp_api.py create mode 100644 docs/howto_use_llama-cpp_backend.md create mode 100644 setup_llamacpp_cuda122.sh diff --git a/.gitignore b/.gitignore index 591c973539..301cc6eb39 100644 --- a/.gitignore +++ b/.gitignore @@ -192,3 +192,4 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ /venv_hf/ +/venv_llamacpp/ diff --git a/backends/llamacpp_api.py b/backends/llamacpp_api.py new file mode 100644 index 0000000000..2cbafc9b3a --- /dev/null +++ b/backends/llamacpp_api.py @@ -0,0 +1,198 @@ +""" + Backend using llama.cpp for GGUF/GGML models. +""" + +from typing import List, Dict, Tuple, Any + +import backends +from backends.utils import check_context_limit_generic + +import llama_cpp +from llama_cpp import Llama + +logger = backends.get_logger(__name__) + + +def load_model(model_spec: backends.ModelSpec) -> Any: + """ + Load GGUF/GGML model weights from HuggingFace, into VRAM if available. Weights are distributed over all available + GPUs for maximum speed - make sure to limit the available GPUs using environment variables if only a subset is to be + used. + :param model_spec: The ModelSpec for the model. + :return: The llama_cpp model class instance of the loaded model. 
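+    Note: to restrict loading to a subset of GPUs, set the CUDA_VISIBLE_DEVICES environment variable (for example,
+    CUDA_VISIBLE_DEVICES=0 to use only the first GPU) before running clembench.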
+ """ + logger.info(f'Start loading llama.cpp model weights from HuggingFace: {model_spec.model_name}') + + hf_repo_id = model_spec['huggingface_id'] + hf_model_file = model_spec['filename'] + + # default to GPU offload: + gpu_layers_offloaded = -1 # -1 = offload all model layers to GPU + # check for optional execute_on flag: + if hasattr(model_spec, 'execute_on'): + if model_spec.execute_on == "gpu": + gpu_layers_offloaded = -1 + elif model_spec.execute_on == "cpu": + gpu_layers_offloaded = 0 + # check for optional gpu_layers_offloaded value: + elif hasattr(model_spec, 'gpu_layers_offloaded'): + gpu_layers_offloaded = model_spec.gpu_layers_offloaded + + if 'requires_api_key' in model_spec and model_spec['requires_api_key']: + # load HF API key: + creds = backends.load_credentials("huggingface") + api_key = creds["huggingface"]["api_key"] + model = Llama.from_pretrained(hf_repo_id, hf_model_file, token=api_key, verbose=False, + n_gpu_layers=gpu_layers_offloaded, n_ctx=0) + else: + model = Llama.from_pretrained(hf_repo_id, hf_model_file, verbose=False, n_gpu_layers=gpu_layers_offloaded, + n_ctx=0) + + logger.info(f"Finished loading llama.cpp model: {model_spec.model_name}") + + return model + + +def get_chat_formatter(model: Llama, model_spec: backends.ModelSpec) -> llama_cpp.llama_chat_format.Jinja2ChatFormatter: + # placeholders for BOS/EOS: + bos_string = None + eos_string = None + + # check chat template: + if model_spec.premade_chat_template: + # jinja chat template available in metadata + chat_template = model.metadata['tokenizer.chat_template'] + else: + chat_template = model_spec.custom_chat_template + + if hasattr(model, 'chat_format'): + if not model.chat_format: + # no guessed chat format + pass + else: + if model.chat_format == "chatml": + # get BOS/EOS strings for chatml from llama.cpp: + bos_string = llama_cpp.llama_chat_format.CHATML_BOS_TOKEN + eos_string = llama_cpp.llama_chat_format.CHATML_EOS_TOKEN + elif model.chat_format == "mistral-instruct": + # get BOS/EOS strings for mistral-instruct from llama.cpp: + bos_string = llama_cpp.llama_chat_format.MISTRAL_INSTRUCT_BOS_TOKEN + eos_string = llama_cpp.llama_chat_format.MISTRAL_INSTRUCT_EOS_TOKEN + + # get BOS/EOS token string from model file: + # NOTE: These may not be the expected tokens, checking these when model is added is likely necessary! + if "tokenizer.ggml.bos_token_id" in model.metadata: + bos_string = model._model.token_get_text(int(model.metadata.get("tokenizer.ggml.bos_token_id"))) + if "tokenizer.ggml.eos_token_id" in model.metadata: + eos_string = model._model.token_get_text(int(model.metadata.get("tokenizer.ggml.eos_token_id"))) + + # get BOS/EOS strings for template from registry if not available from model file: + if not bos_string: + bos_string = model_spec.bos_string + if not eos_string: + eos_string = model_spec.eos_string + + # init llama-cpp-python jinja chat formatter: + chat_formatter = llama_cpp.llama_chat_format.Jinja2ChatFormatter( + template=chat_template, + bos_token=bos_string, + eos_token=eos_string + ) + + return chat_formatter + + +class LlamaCPPLocal(backends.Backend): + """ + Model/backend handler class for locally-run GGUF/GGML models. + """ + def __init__(self): + super().__init__() + + def get_model_for(self, model_spec: backends.ModelSpec) -> backends.Model: + """ + Get a LlamaCPPLocalModel instance with the passed model and settings. Will load all required data for using + the model upon initialization. + :param model_spec: The ModelSpec for the model. 
+ :return: The Model class instance of the model. + """ + return LlamaCPPLocalModel(model_spec) + + +class LlamaCPPLocalModel(backends.Model): + """ + Class for loaded llama.cpp models ready for generation. + """ + def __init__(self, model_spec: backends.ModelSpec): + super().__init__(model_spec) + self.model = load_model(model_spec) + + self.chat_formatter = get_chat_formatter(self.model, model_spec) + + if hasattr(self.model, 'chat_handler'): + if not self.model.chat_handler: + # no custom chat handler + pass + else: + # specific chat handlers may be needed for multimodal models + # see https://llama-cpp-python.readthedocs.io/en/latest/#multi-modal-models + pass + + # get context size from model instance: + self.context_size = self.model._n_ctx + + def generate_response(self, messages: List[Dict], return_full_text: bool = False) -> Tuple[Any, Any, str]: + """ + :param messages: for example + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}, + {"role": "user", "content": "Where was it played?"} + ] + :param return_full_text: If True, whole input context is returned. + :return: the continuation + """ + # use llama.cpp jinja to apply chat template for prompt: + prompt_text = self.chat_formatter(messages=messages).prompt + + prompt = {"inputs": prompt_text, "max_new_tokens": self.get_max_tokens(), + "temperature": self.get_temperature(), "return_full_text": return_full_text} + + prompt_tokens = self.model.tokenize(prompt_text.encode(), add_bos=False) # BOS expected in template + + # check context limit: + check_context_limit_generic(self.context_size, prompt_tokens, self.model_spec.model_name, + max_new_tokens=self.get_max_tokens()) + + # NOTE: HF transformers models come with their own generation configs, but llama.cpp doesn't seem to have a + # feature like that. There are default sampling parameters, and clembench only handles two of them so far, which + # are set accordingly. Other parameters use the llama-cpp-python default values for now. + + # NOTE: llama.cpp has a set sampling order, which differs from that of HF transformers. The latter allows + # individual sampling orders defined in the generation config that comes with HF models. 
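+        # Illustrative sketch only (not used by clembench): further llama-cpp-python sampling parameters could be
+        # passed explicitly in the call below instead of relying on the library defaults, e.g.:
+        # model_output = self.model(prompt_text, temperature=self.get_temperature(),
+        #                           max_tokens=self.get_max_tokens(), top_p=0.95, top_k=40)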
+ + model_output = self.model( + prompt_text, + temperature=self.get_temperature(), + max_tokens=self.get_max_tokens() + ) + + response = {'response': model_output} + + # cull input context: + if not return_full_text: + response_text = model_output['choices'][0]['text'].strip() + + if 'output_split_prefix' in self.model_spec: + response_text = response_text.rsplit(self.model_spec['output_split_prefix'], maxsplit=1)[1] + + eos_len = len(self.model_spec['eos_to_cull']) + + if response_text.endswith(self.model_spec['eos_to_cull']): + response_text = response_text[:-eos_len] + + else: + response_text = prompt_text + model_output['choices'][0]['text'].strip() + + return prompt, response, response_text diff --git a/backends/model_registry.json b/backends/model_registry.json index 11f7aca12c..a742ac337b 100644 --- a/backends/model_registry.json +++ b/backends/model_registry.json @@ -462,5 +462,78 @@ "huggingface_id": "google/gemma-7b-it", "premade_chat_template": true, "eos_to_cull": "" + }, + { + "model_name": "Qwen1.5-0.5B-Chat-GGUF-q8", + "backend": "llamacpp", + "huggingface_id": "Qwen/Qwen1.5-0.5B-Chat-GGUF", + "filename": "*q8_0.gguf", + "premade_chat_template": true, + "bos_string": "", + "eos_string": "<|im_end|>", + "eos_to_cull": "<|im_end|>" + }, + { + "model_name": "CapybaraHermes-2.5-Mistral-7B-GGUF-q4", + "backend": "llamacpp", + "huggingface_id": "TheBloke/CapybaraHermes-2.5-Mistral-7B-GGUF", + "filename": "*q4_0.gguf", + "premade_chat_template": true, + "bos_string": "", + "eos_string": "<|im_end|>", + "eos_to_cull": "<|im_end|>" + }, + { + "model_name": "CapybaraHermes-2.5-Mistral-7B-GGUF-q5", + "backend": "llamacpp", + "huggingface_id": "TheBloke/CapybaraHermes-2.5-Mistral-7B-GGUF", + "filename": "*q5_0.gguf", + "premade_chat_template": true, + "bos_string": "", + "eos_string": "<|im_end|>", + "eos_to_cull": "<|im_end|>" + }, + { + "model_name": "CapybaraHermes-2.5-Mistral-7B-GGUF-q5-k-s", + "backend": "llamacpp", + "huggingface_id": "TheBloke/CapybaraHermes-2.5-Mistral-7B-GGUF", + "filename": "*q5_k_s.gguf", + "premade_chat_template": true, + "bos_string": "", + "eos_string": "<|im_end|>", + "eos_to_cull": "<|im_end|>" + }, + { + "model_name": "EstopianMaid-13B-GGUF-q2-k", + "backend": "llamacpp", + "huggingface_id": "TheBloke/EstopianMaid-13B-GGUF", + "filename": "*q2_k.gguf", + "premade_chat_template": false, + "custom_chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'].strip() + '\\n\\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{% if system_message %}{{ bos_token + system_message }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{bos_token + '### Instruction:\\n' + message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response:\\n' + message['content'].strip() + eos_token + '\\n\\n' }}{% endif %}{% if loop.last and message['role'] == 'user' and add_generation_prompt %}{{ '### Response:\\n' }}{% endif %}{% endfor %}", + "bos_string": "", + "eos_string": "", + "eos_to_cull": "" + }, + { + "model_name": "EstopianMaid-13B-GGUF-q3-k-s", + "backend": "llamacpp", + "huggingface_id": "TheBloke/EstopianMaid-13B-GGUF", + "filename": "*q3_k_s.gguf", + "premade_chat_template": false, + "custom_chat_template": "{% if 
messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'].strip() + '\\n\\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{% if system_message %}{{ bos_token + system_message }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{bos_token + '### Instruction:\\n' + message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response:\\n' + message['content'].strip() + eos_token + '\\n\\n' }}{% endif %}{% if loop.last and message['role'] == 'user' and add_generation_prompt %}{{ '### Response:\\n' }}{% endif %}{% endfor %}",
+        "bos_string": "",
+        "eos_string": "",
+        "eos_to_cull": ""
+    },
+    {
+        "model_name": "openchat_3.5-GGUF-q5",
+        "backend": "llamacpp",
+        "huggingface_id": "TheBloke/openchat_3.5-GGUF",
+        "filename": "*q5_0.gguf",
+        "premade_chat_template": false,
+        "custom_chat_template": "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
+        "bos_string": "",
+        "eos_string": "<|end_of_turn|>",
+        "eos_to_cull": "<|end_of_turn|>"
     }
 ]
\ No newline at end of file
diff --git a/backends/utils.py b/backends/utils.py
index cf38a8f046..7dbf66f51f 100644
--- a/backends/utils.py
+++ b/backends/utils.py
@@ -1,8 +1,8 @@
 import copy
 from functools import wraps
-from typing import List, Dict
+from typing import List, Dict, Tuple
 
-from backends import get_logger
+from backends import get_logger, ContextExceededError
 
 logger = get_logger(__name__)
 
@@ -60,3 +60,31 @@ def wrapped_fn(self, messages):
         return generate_response_fn(self, _messages)
 
     return wrapped_fn
+
+
+def check_context_limit_generic(context_size: int, prompt_tokens: List, model_name: str, max_new_tokens: int = 100) \
+        -> Tuple[bool, int, int, int]:
+    """
+    Internal context limit check to run in generate_response.
+    :param context_size: The context size (maximum number of tokens) of the model.
+    :param prompt_tokens: List of prompt token IDs.
+    :param model_name: Name of the model checked for.
+    :param max_new_tokens: The maximum number of tokens that may be generated (no stop sequence is assumed).
+    :return: Tuple with
+        Bool: True if context limit is not exceeded, False if too many tokens
+        Number of tokens for the given messages and maximum new tokens
+        Number of tokens of 'context space left'
+        Total context token limit
+    """
+    prompt_size = len(prompt_tokens)
+    tokens_used = prompt_size + max_new_tokens  # context includes tokens to be generated
+    tokens_left = context_size - tokens_used
+    fits = tokens_used <= context_size
+
+    if not fits:
+        logger.info(f"Context token limit for {model_name} exceeded: {tokens_used}/{context_size}")
+        # fail gracefully:
+        raise ContextExceededError(f"Context token limit for {model_name} exceeded",
+                                   tokens_used=tokens_used, tokens_left=tokens_left, context_size=context_size)
+
+    return fits, tokens_used, tokens_left, context_size
diff --git a/docs/howto_use_llama-cpp_backend.md b/docs/howto_use_llama-cpp_backend.md
new file mode 100644
index 0000000000..d6475a67b1
--- /dev/null
+++ b/docs/howto_use_llama-cpp_backend.md
@@ -0,0 +1,48 @@
+# Setup and usage of llama.cpp clembench backend
+This guide covers the installation and usage of the llama.cpp-based backend for clembench. This backend allows the use
+of models in the GGUF format, supporting pre-quantized model versions and merged models. The setup varies by available
+hardware backend and operating system, and models may need to be loaded with specific arguments depending on the setup.
+## Content
+[Setup](#setup)
+[Model loading](#model-loading)
+## Setup
+The clembench llama.cpp backend relies on the llama-cpp-python library, which wraps C++ llama.cpp. To allow the usage of
+specific hardware, especially GPUs, the installation must include a fitting version of llama.cpp. This may entail
+compiling llama.cpp, but pre-compiled versions for specific hardware are available.
+Since this is specific to the available hardware, please refer to the [llama-cpp-python installation instructions](https://llama-cpp-python.readthedocs.io/en/latest/#installation)
+to install the library. It is recommended to use one of the pre-built wheels for the available hardware, as this requires
+neither a C++ compiler nor compiling llama.cpp during the installation.
+### Sample setup script
+The following example shell script installs the clembench llama.cpp backend with support for CUDA 12.2 GPUs:
+```shell
+# create separate venv for running the llama.cpp backend:
+python3 -m venv venv_llamacpp
+source venv_llamacpp/bin/activate
+# install basic clembench requirements:
+pip3 install -r requirements.txt
+# install llama-cpp-python using pre-built wheel with CUDA 12.2 support:
+pip3 install llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu122
+```
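+### CPU-only setup (sketch)
+If no supported GPU is available, a CPU-only environment can be set up instead. This is a minimal sketch assuming the
+default source build of llama-cpp-python, which compiles llama.cpp with CPU support and therefore needs a C/C++
+toolchain; see the llama-cpp-python installation instructions linked above for pre-built alternatives:
+```shell
+python3 -m venv venv_llamacpp
+source venv_llamacpp/bin/activate
+pip3 install -r requirements.txt
+# install llama-cpp-python without GPU support (builds llama.cpp for CPU by default):
+pip3 install llama-cpp-python
+```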
+## Model loading
+The clembench llama.cpp backend downloads model files from HuggingFace model repositories. See the [model registry readme](model_backend_registry_readme.md).
+By default, the clembench llama.cpp backend loads all model layers onto the available GPU(s). This requires that during
+setup, proper llama.cpp GPU support fitting the system hardware was installed.
+Optionally, models can be loaded to run on CPU (using RAM instead of GPU VRAM). This is required if llama-cpp-python was
+installed without GPU support. This can be done by passing a JSON object to the clembench CLI scripts, or a Python `dict`
+to the model loading function of the clembench `backends` (a programmatic loading sketch follows at the end of this section).
+The JSON object/`dict` has to contain the model name as defined in the [model registry](model_backend_registry_readme.md)
+and the key `execute_on` with string value `gpu` or `cpu`:
+```python
+model_on_gpu = {'model_name': "openchat_3.5-GGUF-q5", 'execute_on': "gpu"}
+model_on_cpu = {'model_name': "openchat_3.5-GGUF-q5", 'execute_on': "cpu"}
+```
+For clembench CLI scripts, the JSON object is given as a string delimited by double quotes:
+```shell
+# run the taboo clemgame with openchat_3.5-GGUF-q5 on CPU:
+python3 scripts/cli.py run -g taboo -m "{'model_name': 'openchat_3.5-GGUF-q5', 'execute_on': 'cpu'}"
+```
+Alternatively, the number of model layers to offload to GPU can be set by using the `gpu_layers_offloaded` key with an
+integer value:
+```python
+model_15_layers_on_gpu = {'model_name': "openchat_3.5-GGUF-q5", 'gpu_layers_offloaded': 15}
+```
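+### Programmatic loading (sketch)
+The same options can be used when loading a model directly from Python. This is a minimal sketch, assuming that
+`backends.get_model_for` accepts the same `dict` that the CLI scripts take as a JSON string:
+```python
+from backends import load_model_registry, get_model_for
+
+load_model_registry()  # read the model registry so the model name can be resolved
+# load the GGUF model on CPU; use 'execute_on': "gpu" or 'gpu_layers_offloaded' to offload layers to GPU:
+model = get_model_for({'model_name': "openchat_3.5-GGUF-q5", 'execute_on': "cpu"})
+```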
\ No newline at end of file
diff --git a/docs/model_backend_registry_readme.md b/docs/model_backend_registry_readme.md
index fbb2931ed7..8b0fe2795a 100644
--- a/docs/model_backend_registry_readme.md
+++ b/docs/model_backend_registry_readme.md
@@ -17,6 +17,32 @@ The following key/values are **optional**, but should be defined for models that
 `custom_chat_template`(string): A jinja2 template string of the chat template to be applied for this model. This should be set if `premade_chat_template` is `false` for the model, as the generic fallback chat template that will be used if this is not defined is likely to lead to bad model performance.
 `slow_tokenizer`(bool): If `true`, the backend will load the model's tokenizer with `use_fast=False`. Some models require the use of a 'slow' tokenizer class to assure proper tokenization.
 `output_split_prefix`(string): The model's raw output will be rsplit using this string, and the remaining output following this string will be considered the model output. This is necessary for some models that decode tokens differently than they encode them, to assure that the prompt is properly removed from model responses. Example: `assistant\n`
+### llama.cpp Backend
+This backend requires these **mandatory** key/values:
+`huggingface_id`(string): The full huggingface model ID: huggingface user name / model name. Example: `TheBloke/openchat_3.5-GGUF`
+`filename`(string): This is a wildcard (glob) pattern used to select and download the specific model file for a specific
+quantization/version of the model on the HuggingFace repository. It is case-sensitive. Please check the repository
+defined in `huggingface_id` for the proper file name. Example: `*Q5_0.gguf` for the q5 version of `openchat_3.5-GGUF` on
+the `TheBloke/openchat_3.5-GGUF` repository.
+`premade_chat_template`(bool): If `true`, the chat template that is applied for generation is loaded from the model
+repository on huggingface. If `false`, the value of `custom_chat_template` will be used if defined, otherwise a generic
+chat template is applied (highly discouraged).
+`eos_to_cull`(string): This is the string representation of the model's EOS token. It needs to be removed by the backend to assure proper processing by clembench. Example: `<|im_end|>` (This is mandatory as there are models that do not define this in their tokenizer configuration.)
+
+The following key/values are **optional**, but should be defined for models that require them for proper functioning:
+`requires_api_key`(bool): If `true`, the backend will load a huggingface api access key/token from `key.json`, which is required to access 'gated' models like Meta's Llama2.
+`custom_chat_template`(string): A jinja2 template string of the chat template to be applied for this model. This should be set if `premade_chat_template` is `false` for the model, as the generic fallback chat template that will be used if this is not defined is likely to lead to bad model performance.
+`bos_string` (string): In case the model file does not contain a predefined BOS token, this string will be used to
+create the logged input prompt.
+`eos_string` (string): In case the model file does not contain a predefined EOS token, this string will be used to
+create the logged input prompt.
+`output_split_prefix`(string): The model's raw output will be rsplit using this string, and the remaining output following this string will be considered the model output. This is necessary for some models that decode tokens differently than they encode them, to assure that the prompt is properly removed from model responses. Example: `assistant\n`
+#### Advanced
+These key/values are recommended to only be used with a custom registry file:
+`execute_on` (string): Either `gpu`, to run the model with all layers loaded to GPU using VRAM, or `cpu` to run the model on CPU
+only, using main RAM. `gpu` requires a llama.cpp installation with GPU support, `cpu` one with CPU support.
+`gpu_layers_offloaded` (integer): The number of model layers to offload to GPU/VRAM. This requires a llama.cpp +installation with GPU support. This key is only used if there is no `execute_on` key in the model entry. # Backend Classes Model registry entries are mainly used for two classes: `backends.ModelSpec` and `backends.Model`. ## ModelSpec diff --git a/setup_llamacpp_cuda122.sh b/setup_llamacpp_cuda122.sh new file mode 100644 index 0000000000..94d27c9dc3 --- /dev/null +++ b/setup_llamacpp_cuda122.sh @@ -0,0 +1,6 @@ +#!/bin/bash +python3 -m venv venv_llamacpp +source venv_llamacpp/bin/activate +pip3 install -r requirements.txt +# install using pre-built wheel with CUDA 12.2 support: +pip3 install llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu122 \ No newline at end of file From afccf94fece3e57735b21d7b56635cbec2b15b1e Mon Sep 17 00:00:00 2001 From: Philipp Sadler Date: Fri, 26 Apr 2024 16:26:24 +0200 Subject: [PATCH 2/2] remove double file_handler for backend logger's --- logging.yaml | 2 +- tests/test_logging.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/logging.yaml b/logging.yaml index 339906fcd4..ab0fdc4ee0 100644 --- a/logging.yaml +++ b/logging.yaml @@ -17,7 +17,7 @@ loggers: handlers: [ console ] backends: level: DEBUG - handlers: [ console, file_handler ] + handlers: [ console ] benchmark: level: DEBUG games: diff --git a/tests/test_logging.py b/tests/test_logging.py index 47369ccfa4..9514ec9d3b 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -7,4 +7,5 @@ class TabooTestCase(unittest.TestCase): def test_get_model_for_huggingface_local_logs_infos(self): load_model_registry() - get_model_for("llama-2-7b-chat-hf") + model = get_model_for("vicuna-7b-v1.5") + assert model is not None