From cc2516790d87e74fa0cc27154b66752389951769 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 11 Nov 2024 12:09:27 -0500 Subject: [PATCH 01/15] Model: Add support for chat_template.json HuggingFace separated the chat template in the newest transformers versions. Signed-off-by: kingbri --- backends/exllamav2/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 6f17570..c7d2069 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -389,6 +389,10 @@ async def find_prompt_template(self, prompt_template_name, model_directory): logger.info("Attempting to load a prompt template if present.") find_template_functions = [ + lambda: PromptTemplate.from_model_json( + pathlib.Path(self.config.model_dir) / "chat_template.json", + key="chat_template", + ), lambda: PromptTemplate.from_model_json( pathlib.Path(self.config.model_dir) / "tokenizer_config.json", key="chat_template", From 69ac0eb8aad783eec9581e5f76224b2e1df58b69 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 11 Nov 2024 12:04:40 -0500 Subject: [PATCH 02/15] Model: Add vision loading support Adds the ability to load vision parts of text + image models. Requires an explicit flag in config because there isn't a way to automatically determine whether the vision tower should be used. Signed-off-by: kingbri --- backends/exllamav2/model.py | 17 +++++++++++++++++ common/config_models.py | 6 ++++++ common/model.py | 20 +++++++++++++++----- config_sample.yml | 3 +++ endpoints/core/types/model.py | 1 + 5 files changed, 42 insertions(+), 5 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index c7d2069..df8cacf 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -20,6 +20,7 @@ ExLlamaV2Cache_TP, ExLlamaV2Tokenizer, ExLlamaV2Lora, + ExLlamaV2VisionTower, ) from exllamav2.generator import ( ExLlamaV2Sampler, @@ -28,6 +29,7 @@ ) from itertools import zip_longest from loguru import logger +from PIL import Image from typing import List, Optional, Union from ruamel.yaml import YAML @@ -91,6 +93,10 @@ class ExllamaV2Container: autosplit_reserve: List[float] = [96 * 1024**2] use_tp: bool = False + # Vision vars + use_vision: bool = False + vision_model: Optional[ExLlamaV2VisionTower] = None + # Load state model_is_loading: bool = False model_loaded: bool = False @@ -144,6 +150,9 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): # Apply a model's config overrides while respecting user settings kwargs = await self.set_model_overrides(**kwargs) + # Set vision state + self.use_vision = unwrap(kwargs.get("vision"), True) + # Prepare the draft model config if necessary draft_args = unwrap(kwargs.get("draft_model"), {}) draft_model_name = draft_args.get("draft_model_name") @@ -608,6 +617,14 @@ def progress(loaded_modules: int, total_modules: int) input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long) self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True) + # Load vision tower if it exists + if self.use_vision: + self.vision_model = ExLlamaV2VisionTower(self.config) + + for value in self.vision_model.load_gen(callback_gen=progress_callback): + if value: + yield value + self.model = ExLlamaV2(self.config) if not self.quiet: logger.info("Loading model: " + self.config.model_dir) diff --git a/common/config_models.py b/common/config_models.py index 40b4109..8102333 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -270,6 
+270,12 @@ class ModelConfig(BaseConfigModel): "NOTE: Only works with chat completion message lists!" ), ) + vision: Optional[bool] = Field( + False, + description=( + "Enables vision support if the model supports it. (default: False)" + ), + ) num_experts_per_token: Optional[int] = Field( None, description=( diff --git a/common/model.py b/common/model.py index 87b06ad..d30d11b 100644 --- a/common/model.py +++ b/common/model.py @@ -33,6 +33,7 @@ class ModelType(Enum): MODEL = "model" DRAFT = "draft" EMBEDDING = "embedding" + VISION = "vision" def load_progress(module, modules): @@ -70,17 +71,26 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): # Create a new container container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs) - model_type = "draft" if container.draft_config else "model" + # Add possible types of models that can be loaded + model_type = [ModelType.MODEL] + + if container.use_vision: + model_type.insert(0, ModelType.VISION) + + if container.draft_config: + model_type.insert(0, ModelType.DRAFT) + load_status = container.load_gen(load_progress, **kwargs) progress = get_loading_progress_bar() progress.start() try: + index = 0 async for module, modules in load_status: if module == 0: loading_task = progress.add_task( - f"[cyan]Loading {model_type} modules", total=modules + f"[cyan]Loading {model_type[index].value} modules", total=modules ) else: progress.advance(loading_task) @@ -89,10 +99,10 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): if module == modules: # Switch to model progress if the draft model is loaded - if model_type == "draft": - model_type = "model" - else: + if index == len(model_type): progress.stop() + else: + index += 1 finally: progress.stop() diff --git a/config_sample.yml b/config_sample.yml index 83f2fc7..388dcf4 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -124,6 +124,9 @@ model: # NOTE: Only works with chat completion message lists! prompt_template: + # Enables vision support if the model supports it. (default: False) + vision: false + # Number of experts to use per token. # Fetched from the model's config.json if empty. # NOTE: For MoE models only. 
diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index f2817f0..17fa0a7 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -107,6 +107,7 @@ class ModelLoadRequest(BaseModel): cache_mode: Optional[str] = None chunk_size: Optional[int] = None prompt_template: Optional[str] = None + vision: Optional[bool] = None num_experts_per_token: Optional[int] = None # Non-config arguments From 5fa298e601de4eaa2e45968444cb8d01ce177bd5 Mon Sep 17 00:00:00 2001 From: DocShotgun <126566557+DocShotgun@users.noreply.github.com> Date: Sat, 16 Nov 2024 23:25:22 -0800 Subject: [PATCH 03/15] Vision: Define basic utils for ExLlamaV2 vision --- backends/exllamav2/vision.py | 63 ++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 backends/exllamav2/vision.py diff --git a/backends/exllamav2/vision.py b/backends/exllamav2/vision.py new file mode 100644 index 0000000..c49bf33 --- /dev/null +++ b/backends/exllamav2/vision.py @@ -0,0 +1,63 @@ +"""Vision utilities for ExLlamaV2.""" + +import io +import base64 +import re +from PIL import Image +import aiohttp +from common.networking import ( + handle_request_error, +) +from fastapi import HTTPException +from exllamav2 import ( + ExLlamaV2, + ExLlamaV2Tokenizer, + ExLlamaV2VisionTower, + ExLlamaV2MMEmbedding, +) +from functools import lru_cache + + +async def get_image(url: str) -> Image: + if url.startswith("data:image"): + # Handle base64 image + match = re.match(r"^data:image\/[a-zA-Z0-9]+;base64,(.*)$", url) + if match: + base64_image = match.group(1) + bytes_image = base64.b64decode(base64_image) + else: + error_message = handle_request_error( + "Failed to read base64 image input.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + else: + # Handle image URL + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + bytes_image = await response.read() + else: + error_message = handle_request_error( + f"Failed to fetch image from {url}.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + return Image.open(io.BytesIO(bytes_image)) + + +@lru_cache(20) +async def get_image_embedding( + model: ExLlamaV2, + tokenizer: ExLlamaV2Tokenizer, + vision_model: ExLlamaV2VisionTower, + url: str, +) -> ExLlamaV2MMEmbedding: + image = await get_image(url) + return vision_model.get_image_embeddings( + model=model, tokenizer=tokenizer, image=image + ) From dd41eec8a4d6e8b284326bc98986295e5167c92a Mon Sep 17 00:00:00 2001 From: DocShotgun <126566557+DocShotgun@users.noreply.github.com> Date: Sun, 17 Nov 2024 21:23:09 -0800 Subject: [PATCH 04/15] OAI: Initial vision support in OAI chat completions * Support image_url inputs containing URLs or base64 strings following OAI vision spec * Use async lru cache for image embeddings * Add generic wrapper class for multimodal embeddings --- backends/exllamav2/model.py | 21 ++++++++++++-- backends/exllamav2/vision.py | 30 +++++++++---------- common/multimodal.py | 36 +++++++++++++++++++++++ endpoints/OAI/router.py | 9 ++++-- endpoints/OAI/utils/chat_completion.py | 40 +++++++++++++++++++++++--- endpoints/OAI/utils/completion.py | 4 ++- pyproject.toml | 1 + 7 files changed, 115 insertions(+), 26 deletions(-) create mode 100644 common/multimodal.py diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index df8cacf..3c6634f 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -6,6 +6,8 @@ 
import math import pathlib import traceback +from backends.exllamav2.vision import clear_image_embedding_cache +from common.multimodal import MultimodalEmbeddingWrapper import torch import uuid from copy import deepcopy @@ -816,6 +818,9 @@ async def unload(self, loras_only: bool = False, **kwargs): # Delete references held in the grammar module clear_grammar_func_cache() + # Clear the image embedding cache + clear_image_embedding_cache() + # Unload LoRAs if self.generator and self.generator.generator.current_loras: for lora in self.generator.generator.current_loras: @@ -908,12 +913,17 @@ def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor): return dict(zip_longest(top_tokens, cleaned_values)) async def generate( - self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs + self, + prompt: str, + embeddings: MultimodalEmbeddingWrapper, + request_id: str, + abort_event: asyncio.Event = None, + **kwargs, ): """Generate a response to a prompt.""" generations = [] async for generation in self.generate_gen( - prompt, request_id, abort_event, **kwargs + prompt, embeddings, request_id, abort_event, **kwargs ): generations.append(generation) @@ -979,6 +989,7 @@ def check_unsupported_settings(self, **kwargs): async def generate_gen( self, prompt: str, + embeddings: MultimodalEmbeddingWrapper, request_id: str, abort_event: Optional[asyncio.Event] = None, **kwargs, @@ -1246,7 +1257,10 @@ async def generate_gen( # Encode both positive and negative prompts input_ids = [ self.tokenizer.encode( - prompt, add_bos=add_bos_token, encode_special_tokens=True + prompt, + add_bos=add_bos_token, + encode_special_tokens=True, + embeddings=embeddings.content, ) for prompt in prompts ] @@ -1297,6 +1311,7 @@ async def generate_gen( banned_strings=banned_strings, token_healing=token_healing, identifier=job_id, + embeddings=embeddings.content, ) # Save generated tokens and full response diff --git a/backends/exllamav2/vision.py b/backends/exllamav2/vision.py index c49bf33..d207d3e 100644 --- a/backends/exllamav2/vision.py +++ b/backends/exllamav2/vision.py @@ -4,18 +4,14 @@ import base64 import re from PIL import Image +from common import model import aiohttp from common.networking import ( handle_request_error, ) from fastapi import HTTPException -from exllamav2 import ( - ExLlamaV2, - ExLlamaV2Tokenizer, - ExLlamaV2VisionTower, - ExLlamaV2MMEmbedding, -) -from functools import lru_cache +from exllamav2.generator import ExLlamaV2MMEmbedding +from async_lru import alru_cache async def get_image(url: str) -> Image: @@ -50,14 +46,16 @@ async def get_image(url: str) -> Image: return Image.open(io.BytesIO(bytes_image)) -@lru_cache(20) -async def get_image_embedding( - model: ExLlamaV2, - tokenizer: ExLlamaV2Tokenizer, - vision_model: ExLlamaV2VisionTower, - url: str, -) -> ExLlamaV2MMEmbedding: +@alru_cache(20) +async def get_image_embedding(url: str) -> ExLlamaV2MMEmbedding: image = await get_image(url) - return vision_model.get_image_embeddings( - model=model, tokenizer=tokenizer, image=image + return model.container.vision_model.get_image_embeddings( + model=model.container.model, + tokenizer=model.container.tokenizer, + image=image, + text_alias=None, ) + + +def clear_image_embedding_cache(): + get_image_embedding.cache_clear() diff --git a/common/multimodal.py b/common/multimodal.py new file mode 100644 index 0000000..74d4964 --- /dev/null +++ b/common/multimodal.py @@ -0,0 +1,36 @@ +from typing import List +from backends.exllamav2.vision import get_image_embedding +from 
common import model +from pydantic import BaseModel +from loguru import logger + +from common.optional_dependencies import dependencies + +if dependencies.exllamav2: + from exllamav2 import ExLlamaV2VisionTower + + +class MultimodalEmbeddingWrapper(BaseModel): + """Common multimodal embedding wrapper""" + + type: str = None + content: List = [] + text_alias: List[str] = [] + + +async def add_image_embedding( + embeddings: MultimodalEmbeddingWrapper, url: str +) -> MultimodalEmbeddingWrapper: + # Determine the type of vision embedding to use + if not embeddings.type: + if isinstance(model.container.vision_model, ExLlamaV2VisionTower): + embeddings.type = "ExLlamaV2MMEmbedding" + + if embeddings.type == "ExLlamaV2MMEmbedding": + embedding = await get_image_embedding(url) + embeddings.content.append(embedding) + embeddings.text_alias.append(embedding.text_alias) + else: + logger.error("No valid vision model to create embedding") + + return embeddings diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index b6a44c9..c018f5c 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -17,6 +17,7 @@ from endpoints.OAI.utils.chat_completion import ( format_prompt_with_template, generate_chat_completion, + preprocess_vision_request, stream_generate_chat_completion, ) from endpoints.OAI.utils.completion import ( @@ -126,6 +127,8 @@ async def chat_completion_request( if isinstance(data.messages, str): prompt = data.messages else: + if model.container.use_vision: + data.messages, embeddings = await preprocess_vision_request(data.messages) prompt = await format_prompt_with_template(data) # Set an empty JSON schema if the request wants a JSON response @@ -136,12 +139,14 @@ async def chat_completion_request( if data.stream and not disable_request_streaming: return EventSourceResponse( - stream_generate_chat_completion(prompt, data, request, model_path), + stream_generate_chat_completion( + prompt, embeddings, data, request, model_path + ), ping=maxsize, ) else: generate_task = asyncio.create_task( - generate_chat_completion(prompt, data, request, model_path) + generate_chat_completion(prompt, embeddings, data, request, model_path) ) response = await run_with_request_disconnect( diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 3b5c07f..a59f425 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -3,9 +3,10 @@ import asyncio import pathlib from asyncio import CancelledError -from typing import List, Optional +from typing import Dict, List, Optional import json +from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding from fastapi import HTTPException, Request from jinja2 import TemplateError from loguru import logger @@ -279,7 +280,11 @@ async def format_prompt_with_template( async def stream_generate_chat_completion( - prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path + prompt: str, + embeddings: MultimodalEmbeddingWrapper, + data: ChatCompletionRequest, + request: Request, + model_path: pathlib.Path, ): """Generator for the generation process.""" abort_event = asyncio.Event() @@ -298,6 +303,7 @@ async def stream_generate_chat_completion( n, gen_queue, prompt, + embeddings, request.state.id, abort_event, **task_gen_params.model_dump(exclude={"prompt"}), @@ -372,7 +378,11 @@ async def stream_generate_chat_completion( async def generate_chat_completion( - prompt: str, data: ChatCompletionRequest, request: Request, model_path: 
pathlib.Path + prompt: str, + embeddings: MultimodalEmbeddingWrapper, + data: ChatCompletionRequest, + request: Request, + model_path: pathlib.Path, ): gen_tasks: List[asyncio.Task] = [] @@ -381,7 +391,10 @@ async def generate_chat_completion( gen_tasks.append( asyncio.create_task( model.container.generate( - prompt, request.state.id, **data.model_dump(exclude={"prompt"}) + prompt, + embeddings, + request.state.id, + **data.model_dump(exclude={"prompt"}), ) ) ) @@ -459,3 +472,22 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]: tool_call["function"]["arguments"] ) return [ToolCall(**tool_call) for tool_call in tool_calls] + + +async def preprocess_vision_request(messages: List[Dict]): + embeddings = MultimodalEmbeddingWrapper() + for message in messages: + if isinstance(message["content"], list): + concatenated_content = "" + for content in message["content"]: + if content["type"] == "text": + concatenated_content += content["text"] + elif content["type"] == "image_url": + embeddings = await add_image_embedding( + embeddings, content["image_url"]["url"] + ) + concatenated_content += embeddings.text_alias[-1] + + message["content"] = concatenated_content + + return messages, embeddings diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index e939525..65ff0d3 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -7,6 +7,7 @@ import asyncio import pathlib from asyncio import CancelledError +from common.multimodal import MultimodalEmbeddingWrapper from fastapi import HTTPException, Request from typing import List, Union @@ -87,6 +88,7 @@ async def _stream_collector( task_idx: int, gen_queue: asyncio.Queue, prompt: str, + embeddings: MultimodalEmbeddingWrapper, request_id: str, abort_event: asyncio.Event, **kwargs, @@ -95,7 +97,7 @@ async def _stream_collector( try: new_generation = model.container.generate_gen( - prompt, request_id, abort_event, **kwargs + prompt, embeddings, request_id, abort_event, **kwargs ) async for generation in new_generation: generation["index"] = task_idx diff --git a/pyproject.toml b/pyproject.toml index 81f8bf2..ca4b511 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "lm-format-enforcer >= 0.9.6", "aiofiles", "aiohttp", + "async_lru", "huggingface_hub", "psutil", "httptools>=0.5.0", From c42655336be23e20baa52aa3ed7d8e84572eb04c Mon Sep 17 00:00:00 2001 From: DocShotgun <126566557+DocShotgun@users.noreply.github.com> Date: Sun, 17 Nov 2024 23:05:17 -0800 Subject: [PATCH 05/15] Config: Add option to disable fetching content from URLs --- backends/exllamav2/vision.py | 9 +++++++++ common/config_models.py | 7 +++++++ config_sample.yml | 3 +++ 3 files changed, 19 insertions(+) diff --git a/backends/exllamav2/vision.py b/backends/exllamav2/vision.py index d207d3e..168c80c 100644 --- a/backends/exllamav2/vision.py +++ b/backends/exllamav2/vision.py @@ -9,6 +9,7 @@ from common.networking import ( handle_request_error, ) +from common.tabby_config import config from fastapi import HTTPException from exllamav2.generator import ExLlamaV2MMEmbedding from async_lru import alru_cache @@ -31,6 +32,14 @@ async def get_image(url: str) -> Image: else: # Handle image URL + if config.network.disable_fetch_requests: + error_message = handle_request_error( + f"Failed to fetch image from {url} as fetch requests are disabled.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + async with aiohttp.ClientSession() as session: async with 
session.get(url) as response: if response.status == 200: diff --git a/common/config_models.py b/common/config_models.py index 8102333..b8e7606 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -78,6 +78,13 @@ class NetworkConfig(BaseConfigModel): "Turn on this option if you are ONLY connecting from localhost." ), ) + disable_fetch_requests: Optional[bool] = Field( + False, + description=( + "Disable fetching external content in response to requests," + "such as images from URLs." + ), + ) send_tracebacks: Optional[bool] = Field( False, description=( diff --git a/config_sample.yml b/config_sample.yml index 388dcf4..48e58d9 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -20,6 +20,9 @@ network: # Turn on this option if you are ONLY connecting from localhost. disable_auth: false + # Disable fetching external content in response to requests, such as images from URLs. + disable_fetch_requests: false + # Send tracebacks over the API (default: False). # NOTE: Only enable this for debug purposes. send_tracebacks: false From 5611365c0753b98d2bf64f04abc9ca2a3a336132 Mon Sep 17 00:00:00 2001 From: DocShotgun <126566557+DocShotgun@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:14:37 -0800 Subject: [PATCH 06/15] OAI: Allow /v1/encode endpoint to handle vision requests * More robust checks for OAI chat completion message lists on /v1/encode endpoint * Added TODO to support other aspects of chat completions * Fix oversight where embeddings was not defined in advance on /v1/chat/completions endpoint --- backends/exllamav2/model.py | 5 ++++- endpoints/OAI/router.py | 3 +++ endpoints/core/router.py | 31 ++++++++++++++++++++++++++++--- endpoints/core/types/token.py | 2 +- 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 3c6634f..bc9142a 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -862,7 +862,9 @@ async def unload(self, loras_only: bool = False, **kwargs): async with self.load_condition: self.load_condition.notify_all() - def encode_tokens(self, text: str, **kwargs): + def encode_tokens( + self, text: str, embeddings: MultimodalEmbeddingWrapper, **kwargs + ): """Wrapper to encode tokens from a text string.""" return ( @@ -870,6 +872,7 @@ def encode_tokens(self, text: str, **kwargs): text, add_bos=unwrap(kwargs.get("add_bos_token"), True), encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True), + embeddings=embeddings.content, ) .flatten() .tolist() diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index c018f5c..acb35f9 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,4 +1,5 @@ import asyncio +from common.multimodal import MultimodalEmbeddingWrapper from fastapi import APIRouter, Depends, HTTPException, Request from sse_starlette import EventSourceResponse from sys import maxsize @@ -124,6 +125,8 @@ async def chat_completion_request( model_path = model.container.model_dir + embeddings = MultimodalEmbeddingWrapper() + if isinstance(data.messages, str): prompt = data.messages else: diff --git a/endpoints/core/router.py b/endpoints/core/router.py index f2b4247..0a48a2e 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -1,6 +1,7 @@ import asyncio import pathlib from sys import maxsize +from common.multimodal import MultimodalEmbeddingWrapper from fastapi import APIRouter, Depends, HTTPException, Request, Response from sse_starlette import EventSourceResponse @@ -357,10 +358,27 @@ async def 
unload_embedding_model(): ) async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: """Encodes a string or chat completion messages into tokens.""" + embeddings = MultimodalEmbeddingWrapper() if isinstance(data.text, str): text = data.text - else: + elif isinstance(data.text, list) and "oai" in config.network.api_servers: + # TODO: Support additional chat completion args for encode + # i.e. add_generation_prompt, template selection, tool args, template kwargs + if model.container.prompt_template is None: + error_message = handle_request_error( + "Tokenization of chat completion requests is disabled " + "because a prompt template is not set.", + exc_info=False, + ).error.message + + raise HTTPException(422, error_message) + + from endpoints.OAI.utils.chat_completion import preprocess_vision_request + + if model.container.use_vision: + data.text, embeddings = await preprocess_vision_request(data.text) + special_tokens_dict = model.container.get_special_tokens( unwrap(data.add_bos_token, True) ) @@ -371,9 +389,16 @@ async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: **special_tokens_dict, } - text, _ = model.container.prompt_template.render(template_vars) + text = await model.container.prompt_template.render(template_vars) + else: + error_message = handle_request_error( + "OAI API server must be enabled to handle chat completion message inputs.", + exc_info=False, + ).error.message + + raise HTTPException(422, error_message) - raw_tokens = model.container.encode_tokens(text, **data.get_params()) + raw_tokens = model.container.encode_tokens(text, embeddings, **data.get_params()) tokens = unwrap(raw_tokens, []) response = TokenEncodeResponse(tokens=tokens, length=len(tokens)) diff --git a/endpoints/core/types/token.py b/endpoints/core/types/token.py index 945adbf..2c205ab 100644 --- a/endpoints/core/types/token.py +++ b/endpoints/core/types/token.py @@ -23,7 +23,7 @@ def get_params(self): class TokenEncodeRequest(CommonTokenRequest): """Represents a tokenization request.""" - text: Union[str, List[Dict[str, str]]] + text: Union[str, List[Dict]] class TokenEncodeResponse(BaseModel): From 27d9af50a8c739f4243b1458c35f6e6f1fb142c0 Mon Sep 17 00:00:00 2001 From: DocShotgun <126566557+DocShotgun@users.noreply.github.com> Date: Tue, 19 Nov 2024 12:29:25 -0800 Subject: [PATCH 07/15] API: Report whether vision is enabled --- backends/exllamav2/model.py | 1 + endpoints/core/types/model.py | 1 + 2 files changed, 2 insertions(+) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index bc9142a..a0a7f0e 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -476,6 +476,7 @@ def get_model_parameters(self): "prompt_template": self.prompt_template.name if self.prompt_template else None, + "use_vision": self.use_vision, } if self.draft_config: diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 17fa0a7..ddf1cc2 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -21,6 +21,7 @@ class ModelCardParameters(BaseModel): chunk_size: Optional[int] = 2048 prompt_template: Optional[str] = None num_experts_per_token: Optional[int] = None + use_vision: Optional[bool] = False # Draft is another model, so include it in the card params draft: Optional["ModelCard"] = None From 731a345cfc0ba503d05a95630dc86e6a01d35b70 Mon Sep 17 00:00:00 2001 From: DocShotgun <126566557+DocShotgun@users.noreply.github.com> Date: Tue, 19 Nov 2024 12:40:00 -0800 Subject: [PATCH 08/15] OAI: Keep behavior 
consistent between chat completion and encode * When vision is not enabled, only the first text block is kept in message.content if it is a list --- endpoints/core/router.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 0a48a2e..d7837f8 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -379,6 +379,20 @@ async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: if model.container.use_vision: data.text, embeddings = await preprocess_vision_request(data.text) + # Keeping behavior consistent with format_prompt_with_template + # Deal with list in messages.content + # Just replace the content list with the very first text message + for message in data.text: + if isinstance(message["content"], list): + message["content"] = next( + ( + content["text"] + for content in message["content"] + if content["type"] == "text" + ), + "", + ) + special_tokens_dict = model.container.get_special_tokens( unwrap(data.add_bos_token, True) ) From 8ffc636dce3b5ca418661914cf8d2135e259e77c Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 19 Nov 2024 23:15:47 -0500 Subject: [PATCH 09/15] OAI: Strictly type chat completions Previously, the messages were a list of dicts. These are untyped and don't provide strict hinting. Add types for chat completion messages and reformat existing code. Signed-off-by: kingbri --- endpoints/OAI/router.py | 1 + endpoints/OAI/types/chat_completion.py | 19 +++++++++--- endpoints/OAI/utils/chat_completion.py | 42 +++++++++++++------------- 3 files changed, 37 insertions(+), 25 deletions(-) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index acb35f9..8403a87 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -132,6 +132,7 @@ async def chat_completion_request( else: if model.container.use_vision: data.messages, embeddings = await preprocess_vision_request(data.messages) + prompt = await format_prompt_with_template(data) # Set an empty JSON schema if the request wants a JSON response diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index 30ec769..86a2247 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field from pydantic.json_schema import SkipJsonSchema from time import time -from typing import Union, List, Optional, Dict +from typing import Literal, Union, List, Optional, Dict from uuid import uuid4 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest @@ -18,10 +18,21 @@ class ChatCompletionLogprobs(BaseModel): content: List[ChatCompletionLogprob] = Field(default_factory=list) +class ChatCompletionImageUrl(BaseModel): + url: str + + +class ChatCompletionMessagePart(BaseModel): + type: Literal["text", "image_url"] = "text" + text: Optional[str] = None + image_url: Optional[ChatCompletionImageUrl] = None + + class ChatCompletionMessage(BaseModel): - role: Optional[str] = None - content: Optional[str] = None + role: str = "user" + content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None tool_calls: Optional[List[ToolCall]] = None + tool_calls_json: SkipJsonSchema[Optional[str]] = None class ChatCompletionRespChoice(BaseModel): @@ -51,7 +62,7 @@ class ChatCompletionRequest(CommonCompletionRequest): # WIP this can probably be tightened, or maybe match the OAI lib type # in openai\types\chat\chat_completion_message_param.py - messages: Union[str, List[Dict]] + messages: 
List[ChatCompletionMessage] = Field(default_factory=list) prompt_template: Optional[str] = None add_generation_prompt: Optional[bool] = True template_vars: Optional[dict] = {} diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index a59f425..c14a8dc 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -1,17 +1,16 @@ """Chat completion utilities for OAI server.""" import asyncio +import json import pathlib from asyncio import CancelledError -from typing import Dict, List, Optional -import json - -from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding +from typing import List, Optional from fastapi import HTTPException, Request from jinja2 import TemplateError from loguru import logger from common import model +from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding from common.networking import ( get_generator_error, handle_request_disconnect, @@ -214,21 +213,21 @@ async def format_prompt_with_template( unwrap(data.ban_eos_token, False), ) - # Deal with list in messages.content - # Just replace the content list with the very first text message + # Convert list to text-based content + # Use the first instance of text inside the part list for message in data.messages: - if isinstance(message["content"], list): - message["content"] = next( + if isinstance(message.content, list): + message.content = next( ( - content["text"] - for content in message["content"] - if content["type"] == "text" + content.text + for content in message.content + if content.type == "text" ), "", ) - if "tool_calls" in message: - message["tool_calls_json"] = json.dumps(message["tool_calls"], indent=2) + if message.tool_calls: + message.tool_calls_json = json.dumps(message.tool_calls, indent=2) # Overwrite any protected vars with their values data.template_vars.update( @@ -474,20 +473,21 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]: return [ToolCall(**tool_call) for tool_call in tool_calls] -async def preprocess_vision_request(messages: List[Dict]): +# TODO: Combine this with the existing preprocessor in format_prompt_with_template +async def preprocess_vision_request(messages: List[ChatCompletionMessage]): embeddings = MultimodalEmbeddingWrapper() for message in messages: - if isinstance(message["content"], list): + if isinstance(message.content, list): concatenated_content = "" - for content in message["content"]: - if content["type"] == "text": - concatenated_content += content["text"] - elif content["type"] == "image_url": + for content in message.content: + if content.type == "text": + concatenated_content += content.text + elif content.type == "image_url": embeddings = await add_image_embedding( - embeddings, content["image_url"]["url"] + embeddings, content.image_url.url ) concatenated_content += embeddings.text_alias[-1] - message["content"] = concatenated_content + message.content = concatenated_content return messages, embeddings From c652a6e0301c45e30f7e709eb907dad79b525b99 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 20 Nov 2024 00:05:15 -0500 Subject: [PATCH 10/15] API: Transform multimodal into an actual class Migrate the add method into the class itself. Also, a BaseModel isn't needed here since this isn't a serialized class. 
Signed-off-by: kingbri --- common/multimodal.py | 32 +++++++++++--------------- endpoints/OAI/utils/chat_completion.py | 6 ++--- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/common/multimodal.py b/common/multimodal.py index 74d4964..5b93f23 100644 --- a/common/multimodal.py +++ b/common/multimodal.py @@ -1,7 +1,6 @@ from typing import List from backends.exllamav2.vision import get_image_embedding from common import model -from pydantic import BaseModel from loguru import logger from common.optional_dependencies import dependencies @@ -10,27 +9,22 @@ from exllamav2 import ExLlamaV2VisionTower -class MultimodalEmbeddingWrapper(BaseModel): +class MultimodalEmbeddingWrapper: """Common multimodal embedding wrapper""" type: str = None content: List = [] text_alias: List[str] = [] - -async def add_image_embedding( - embeddings: MultimodalEmbeddingWrapper, url: str -) -> MultimodalEmbeddingWrapper: - # Determine the type of vision embedding to use - if not embeddings.type: - if isinstance(model.container.vision_model, ExLlamaV2VisionTower): - embeddings.type = "ExLlamaV2MMEmbedding" - - if embeddings.type == "ExLlamaV2MMEmbedding": - embedding = await get_image_embedding(url) - embeddings.content.append(embedding) - embeddings.text_alias.append(embedding.text_alias) - else: - logger.error("No valid vision model to create embedding") - - return embeddings + async def add(self, url: str): + # Determine the type of vision embedding to use + if not self.type: + if isinstance(model.container.vision_model, ExLlamaV2VisionTower): + self.type = "ExLlamaV2MMEmbedding" + + if self.type == "ExLlamaV2MMEmbedding": + embedding = await get_image_embedding(url) + self.content.append(embedding) + self.text_alias.append(embedding.text_alias) + else: + logger.error("No valid vision model to create embedding") diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index c14a8dc..7a31f39 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -10,7 +10,7 @@ from loguru import logger from common import model -from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding +from common.multimodal import MultimodalEmbeddingWrapper from common.networking import ( get_generator_error, handle_request_disconnect, @@ -483,9 +483,7 @@ async def preprocess_vision_request(messages: List[ChatCompletionMessage]): if content.type == "text": concatenated_content += content.text elif content.type == "image_url": - embeddings = await add_image_embedding( - embeddings, content.image_url.url - ) + await embeddings.add(content.image_url.url) concatenated_content += embeddings.text_alias[-1] message.content = concatenated_content From 902045edbbe6ba4f91c5efdd7a1ae6b308be9e1e Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 21 Nov 2024 17:51:14 -0500 Subject: [PATCH 11/15] API: Fix chat completion formatting flow Previously, the flow for parsing chat completion messages and rendering from the prompt template was disconnected between endpoints. Now, create a common function to render and handle everything appropriately afterwards. 
Signed-off-by: kingbri --- backends/exllamav2/model.py | 24 +++--- endpoints/OAI/router.py | 14 +--- endpoints/OAI/utils/chat_completion.py | 104 ++++++++++++------------- endpoints/OAI/utils/completion.py | 4 +- endpoints/core/router.py | 57 ++++++-------- endpoints/core/types/token.py | 6 +- 6 files changed, 93 insertions(+), 116 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index eaa431c..e82c94e 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -31,7 +31,6 @@ ) from itertools import zip_longest from loguru import logger -from PIL import Image from typing import List, Optional, Union from ruamel.yaml import YAML @@ -374,6 +373,8 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): self.draft_config.max_input_len = chunk_size self.draft_config.max_attention_size = chunk_size**2 + self.prompt_template = None + # Return the created instance return self @@ -875,17 +876,18 @@ async def unload(self, loras_only: bool = False, **kwargs): async with self.load_condition: self.load_condition.notify_all() - def encode_tokens( - self, text: str, embeddings: MultimodalEmbeddingWrapper, **kwargs - ): + def encode_tokens(self, text: str, **kwargs): """Wrapper to encode tokens from a text string.""" + mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings") + mm_embeddings_content = mm_embeddings.content if mm_embeddings else [] + return ( self.tokenizer.encode( text, add_bos=unwrap(kwargs.get("add_bos_token"), True), encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True), - embeddings=embeddings.content, + embeddings=mm_embeddings_content, ) .flatten() .tolist() @@ -931,7 +933,6 @@ def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor): async def generate( self, prompt: str, - embeddings: MultimodalEmbeddingWrapper, request_id: str, abort_event: asyncio.Event = None, **kwargs, @@ -939,7 +940,7 @@ async def generate( """Generate a response to a prompt.""" generations = [] async for generation in self.generate_gen( - prompt, embeddings, request_id, abort_event, **kwargs + prompt, request_id, abort_event, **kwargs ): generations.append(generation) @@ -1005,7 +1006,6 @@ def check_unsupported_settings(self, **kwargs): async def generate_gen( self, prompt: str, - embeddings: MultimodalEmbeddingWrapper, request_id: str, abort_event: Optional[asyncio.Event] = None, **kwargs, @@ -1270,13 +1270,17 @@ async def generate_gen( else: stop_conditions += eos_tokens + # Get multimodal embeddings if present + mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings") + mm_embeddings_content = mm_embeddings.content if mm_embeddings else [] + # Encode both positive and negative prompts input_ids = [ self.tokenizer.encode( prompt, add_bos=add_bos_token, encode_special_tokens=True, - embeddings=embeddings.content, + embeddings=mm_embeddings_content, ) for prompt in prompts ] @@ -1327,7 +1331,7 @@ async def generate_gen( banned_strings=banned_strings, token_healing=token_healing, identifier=job_id, - embeddings=embeddings.content, + embeddings=mm_embeddings_content, ) # Save generated tokens and full response diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 8403a87..8f4e7a4 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,5 +1,4 @@ import asyncio -from common.multimodal import MultimodalEmbeddingWrapper from fastapi import APIRouter, Depends, HTTPException, Request from sse_starlette import EventSourceResponse from sys import 
maxsize @@ -16,9 +15,8 @@ ) from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse from endpoints.OAI.utils.chat_completion import ( - format_prompt_with_template, + apply_chat_template, generate_chat_completion, - preprocess_vision_request, stream_generate_chat_completion, ) from endpoints.OAI.utils.completion import ( @@ -125,15 +123,7 @@ async def chat_completion_request( model_path = model.container.model_dir - embeddings = MultimodalEmbeddingWrapper() - - if isinstance(data.messages, str): - prompt = data.messages - else: - if model.container.use_vision: - data.messages, embeddings = await preprocess_vision_request(data.messages) - - prompt = await format_prompt_with_template(data) + prompt, embeddings = await apply_chat_template(data) # Set an empty JSON schema if the request wants a JSON response if data.response_format.type == "json": diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 7a31f39..84905db 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -177,11 +177,11 @@ def _create_stream_chunk( return chunk -async def _append_template_metadata(data: ChatCompletionRequest): +async def _append_template_metadata(data: ChatCompletionRequest, template_vars: dict): """Adding metadata is a one-time process.""" template_metadata = await model.container.prompt_template.extract_metadata( - data.template_vars + template_vars ) # Stop strings @@ -199,7 +199,43 @@ async def _append_template_metadata(data: ChatCompletionRequest): data.stop.extend(template_metadata.tool_starts) -async def format_prompt_with_template( +async def format_messages_with_template( + messages: List[ChatCompletionMessage], + existing_template_vars: Optional[dict] = None, + add_bos_token: bool = True, + ban_eos_token: bool = False, +): + """Barebones function to format chat completion messages into a prompt.""" + + template_vars = unwrap(existing_template_vars, {}) + mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None + + for message in messages: + if isinstance(message.content, list): + concatenated_content = "" + for content in message.content: + if content.type == "text": + concatenated_content += content.text + elif content.type == "image_url" and mm_embeddings: + await mm_embeddings.add(content.image_url.url) + concatenated_content += mm_embeddings.text_alias[-1] + + if message.tool_calls: + message.tool_calls_json = json.dumps(message.tool_calls, indent=2) + + message.content = concatenated_content + + special_tokens_dict = model.container.get_special_tokens( + add_bos_token, ban_eos_token + ) + + template_vars.update({"messages": messages, **special_tokens_dict}) + + prompt = await model.container.prompt_template.render(template_vars) + return prompt, mm_embeddings, template_vars + + +async def apply_chat_template( data: ChatCompletionRequest, tool_precursor: Optional[str] = None ): """ @@ -208,40 +244,18 @@ async def format_prompt_with_template( """ try: - special_tokens_dict = model.container.get_special_tokens( - unwrap(data.add_bos_token, True), - unwrap(data.ban_eos_token, False), - ) - - # Convert list to text-based content - # Use the first instance of text inside the part list - for message in data.messages: - if isinstance(message.content, list): - message.content = next( - ( - content.text - for content in message.content - if content.type == "text" - ), - "", - ) - - if message.tool_calls: - message.tool_calls_json = json.dumps(message.tool_calls, 
indent=2) - - # Overwrite any protected vars with their values data.template_vars.update( { - "messages": data.messages, "add_generation_prompt": data.add_generation_prompt, "tools_json": json.dumps(data.model_dump()["tools"], indent=2), "functions_json": json.dumps(data.functions, indent=2), "tool_precursor": tool_precursor, - **special_tokens_dict, } ) - prompt = await model.container.prompt_template.render(data.template_vars) + prompt, mm_embeddings, template_vars = await format_messages_with_template( + data.messages, data.template_vars, data.add_bos_token, data.ban_eos_token + ) # Append response prefix if present if data.response_prefix: @@ -255,14 +269,14 @@ async def format_prompt_with_template( # Removes the starting BOS token if present # This is to prevent add_bos_token from adding multiple bos tokens - bos_token = special_tokens_dict.get("bos_token") + bos_token = template_vars.get("bos_token") if bos_token and prompt.startswith(bos_token): prompt = prompt.removeprefix(bos_token) # Add template metadata - await _append_template_metadata(data) + await _append_template_metadata(data, template_vars) - return prompt + return prompt, mm_embeddings except KeyError as exc: error_message = handle_request_error( @@ -302,9 +316,9 @@ async def stream_generate_chat_completion( n, gen_queue, prompt, - embeddings, request.state.id, abort_event, + embeddings=embeddings, **task_gen_params.model_dump(exclude={"prompt"}), ) ) @@ -391,8 +405,8 @@ async def generate_chat_completion( asyncio.create_task( model.container.generate( prompt, - embeddings, request.state.id, + embeddings=embeddings, **data.model_dump(exclude={"prompt"}), ) ) @@ -439,13 +453,11 @@ async def generate_tool_calls( if gen["stop_str"] in tool_data.tool_call_start: if "text" in gen: # non streaming, all generations will have the text they generated - pre_tool_prompt = await format_prompt_with_template(data, gen["text"]) + pre_tool_prompt = await apply_chat_template(data, gen["text"]) elif current_generations is not None: # streaming, we wont have text in the generation, # we'll have to use the current_generations - pre_tool_prompt = await format_prompt_with_template( - data, current_generations - ) + pre_tool_prompt = await apply_chat_template(data, current_generations) gen_tasks.append( asyncio.create_task( @@ -471,21 +483,3 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]: tool_call["function"]["arguments"] ) return [ToolCall(**tool_call) for tool_call in tool_calls] - - -# TODO: Combine this with the existing preprocessor in format_prompt_with_template -async def preprocess_vision_request(messages: List[ChatCompletionMessage]): - embeddings = MultimodalEmbeddingWrapper() - for message in messages: - if isinstance(message.content, list): - concatenated_content = "" - for content in message.content: - if content.type == "text": - concatenated_content += content.text - elif content.type == "image_url": - await embeddings.add(content.image_url.url) - concatenated_content += embeddings.text_alias[-1] - - message.content = concatenated_content - - return messages, embeddings diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index e798176..9fd8b90 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -7,7 +7,6 @@ import asyncio import pathlib from asyncio import CancelledError -from common.multimodal import MultimodalEmbeddingWrapper from fastapi import HTTPException, Request from typing import List, Union @@ -88,7 +87,6 @@ async def 
_stream_collector( task_idx: int, gen_queue: asyncio.Queue, prompt: str, - embeddings: MultimodalEmbeddingWrapper, request_id: str, abort_event: asyncio.Event, **kwargs, @@ -97,7 +95,7 @@ async def _stream_collector( try: new_generation = model.container.generate_gen( - prompt, embeddings, request_id, abort_event, **kwargs + prompt, request_id, abort_event, **kwargs ) async for generation in new_generation: generation["index"] = task_idx diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 64450f4..ccb26d9 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -1,6 +1,7 @@ import asyncio import pathlib from sys import maxsize +from typing import Optional from common.multimodal import MultimodalEmbeddingWrapper from fastapi import APIRouter, Depends, HTTPException, Request, Response from sse_starlette import EventSourceResponse @@ -14,6 +15,7 @@ from common.templating import PromptTemplate, get_all_templates from common.utils import unwrap from common.health import HealthManager +from endpoints.OAI.utils.chat_completion import format_messages_with_template from endpoints.core.types.auth import AuthPermissionResponse from endpoints.core.types.download import DownloadRequest, DownloadResponse from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse @@ -359,61 +361,48 @@ async def unload_embedding_model(): ) async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: """Encodes a string or chat completion messages into tokens.""" - embeddings = MultimodalEmbeddingWrapper() + + mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None if isinstance(data.text, str): text = data.text - elif isinstance(data.text, list) and "oai" in config.network.api_servers: - # TODO: Support additional chat completion args for encode - # i.e. 
add_generation_prompt, template selection, tool args, template kwargs - if model.container.prompt_template is None: + elif isinstance(data.text, list): + if "oai" not in config.network.api_servers: error_message = handle_request_error( - "Tokenization of chat completion requests is disabled " - "because a prompt template is not set.", + "Enable the OAI server to handle chat completion messages.", exc_info=False, ).error.message raise HTTPException(422, error_message) - from endpoints.OAI.utils.chat_completion import preprocess_vision_request - - if model.container.use_vision: - data.text, embeddings = await preprocess_vision_request(data.text) - - # Keeping behavior consistent with format_prompt_with_template - # Deal with list in messages.content - # Just replace the content list with the very first text message - for message in data.text: - if isinstance(message["content"], list): - message["content"] = next( - ( - content["text"] - for content in message["content"] - if content["type"] == "text" - ), - "", - ) - - special_tokens_dict = model.container.get_special_tokens( - unwrap(data.add_bos_token, True) - ) + if not model.container.prompt_template: + error_message = handle_request_error( + "Cannot tokenize chat completion message because " + + "a prompt template is not set.", + exc_info=False, + ).error.message + + raise HTTPException(422, error_message) template_vars = { - "messages": data.text, "add_generation_prompt": False, - **special_tokens_dict, } - text = await model.container.prompt_template.render(template_vars) + # Don't need template vars again + text, mm_embeddings, _ = await format_messages_with_template( + data.text, template_vars, data.add_bos_token + ) else: error_message = handle_request_error( - "OAI API server must be enabled to handle chat completion message inputs.", + "Unable to tokenize the provided text. Check your formatting?", exc_info=False, ).error.message raise HTTPException(422, error_message) - raw_tokens = model.container.encode_tokens(text, embeddings, **data.get_params()) + raw_tokens = model.container.encode_tokens( + text, embeddings=mm_embeddings, **data.get_params() + ) tokens = unwrap(raw_tokens, []) response = TokenEncodeResponse(tokens=tokens, length=len(tokens)) diff --git a/endpoints/core/types/token.py b/endpoints/core/types/token.py index 2c205ab..d43e65e 100644 --- a/endpoints/core/types/token.py +++ b/endpoints/core/types/token.py @@ -1,7 +1,9 @@ """Tokenization types""" from pydantic import BaseModel -from typing import Dict, List, Union +from typing import List, Union + +from endpoints.OAI.types.chat_completion import ChatCompletionMessage class CommonTokenRequest(BaseModel): @@ -23,7 +25,7 @@ def get_params(self): class TokenEncodeRequest(CommonTokenRequest): """Represents a tokenization request.""" - text: Union[str, List[Dict]] + text: Union[str, List[ChatCompletionMessage]] class TokenEncodeResponse(BaseModel): From 0ab393f09c7b022be02b590059d7ae368b97984d Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 21 Nov 2024 17:54:42 -0500 Subject: [PATCH 12/15] Model: Set vision load to False by default Mistake in unwrapping. Vision should be false to allow normal model loading when the flag isn't provided. 
Signed-off-by: kingbri --- backends/exllamav2/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index e82c94e..beb4512 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -152,7 +152,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): kwargs = await self.set_model_overrides(**kwargs) # Set vision state - self.use_vision = unwrap(kwargs.get("vision"), True) + self.use_vision = unwrap(kwargs.get("vision"), False) # Prepare the draft model config if necessary draft_args = unwrap(kwargs.get("draft_model"), {}) From c49047eea1192bb728173f2d8bacc3246e057b39 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 21 Nov 2024 18:06:47 -0500 Subject: [PATCH 13/15] Model: Fix load packets The model_type internal reference was changed to an enum for a more extendable loading process. Return the current model type when loading a new model. Signed-off-by: kingbri --- common/model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/model.py b/common/model.py index d30d11b..3eac446 100644 --- a/common/model.py +++ b/common/model.py @@ -88,14 +88,15 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): try: index = 0 async for module, modules in load_status: + current_model_type = model_type[index].value if module == 0: loading_task = progress.add_task( - f"[cyan]Loading {model_type[index].value} modules", total=modules + f"[cyan]Loading {current_model_type} modules", total=modules ) else: progress.advance(loading_task) - yield module, modules, model_type + yield module, modules, current_model_type if module == modules: # Switch to model progress if the draft model is loaded From eadc71a4c3e889c28d8466035e48ac41f00e87c5 Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 22 Nov 2024 14:25:03 -0500 Subject: [PATCH 14/15] Model: Add unload and error messages for vision If vision is enabled and the model doesn't support it, send an error asking the user to reload. Also, add a method to unload the vision tower. Signed-off-by: kingbri --- backends/exllamav2/model.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index beb4512..ff11531 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -151,8 +151,14 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): # Apply a model's config overrides while respecting user settings kwargs = await self.set_model_overrides(**kwargs) - # Set vision state + # Set vision state and error if vision isn't supported on the current model self.use_vision = unwrap(kwargs.get("vision"), False) + if self.use_vision and not self.config.vision_model_type: + raise ValueError( + "The provided model does not have vision capabilities that are " + "supported by ExllamaV2. " + "Please reload with vision disabled." 
+ ) # Prepare the draft model config if necessary draft_args = unwrap(kwargs.get("draft_model"), {}) @@ -373,8 +379,6 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): self.draft_config.max_input_len = chunk_size self.draft_config.max_attention_size = chunk_size**2 - self.prompt_template = None - # Return the created instance return self @@ -848,6 +852,16 @@ async def unload(self, loras_only: bool = False, **kwargs): self.model.unload() self.model = None + if self.vision_model: + # TODO: Remove this with newer exl2 versions + # Required otherwise unload function won't finish + try: + self.vision_model.unload() + except AttributeError: + pass + + self.vision_model = None + if self.draft_model: self.draft_model.unload() self.draft_model = None From 388d36e6bda2b3a2c30a3e3520b744ccf011bf17 Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 22 Nov 2024 17:30:29 -0500 Subject: [PATCH 15/15] OAI: Fix chat completion list parsing The strings weren't being concatenated properly. Only add the combined text if the chat completion type is a List. Signed-off-by: kingbri --- endpoints/OAI/utils/chat_completion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 84905db..14a2243 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -220,11 +220,12 @@ async def format_messages_with_template( await mm_embeddings.add(content.image_url.url) concatenated_content += mm_embeddings.text_alias[-1] + # Convert the message content into a concatenated string + message.content = concatenated_content + if message.tool_calls: message.tool_calls_json = json.dumps(message.tool_calls, indent=2) - message.content = concatenated_content - special_tokens_dict = model.container.get_special_tokens( add_bos_token, ban_eos_token )
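
Taken together, this series lets the OAI-compatible /v1/chat/completions endpoint accept image_url content parts alongside text, following the OpenAI vision spec referenced in patch 04. Below is a minimal client-side sketch of how such a request might look. It is illustrative only: it assumes a TabbyAPI instance reachable at http://127.0.0.1:5000 with an API key, a vision-capable model loaded with vision: true, and it uses the message-part shape defined in patch 09 (type, text, image_url.url). The host, port, key, prompt, and file name are placeholders, not values taken from the patches.

import asyncio
import base64

import aiohttp

# Placeholders -- adjust for your deployment; not taken from the patches.
API_URL = "http://127.0.0.1:5000/v1/chat/completions"
API_KEY = "your-api-key"


async def describe_image(image_path: str) -> str:
    # Send the image as a base64 data URI. A plain https:// URL also works,
    # unless network.disable_fetch_requests is enabled (patch 05).
    with open(image_path, "rb") as file:
        data_uri = "data:image/png;base64," + base64.b64encode(file.read()).decode()

    payload = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image."},
                    {"type": "image_url", "image_url": {"url": data_uri}},
                ],
            }
        ],
        "max_tokens": 200,
    }

    async with aiohttp.ClientSession() as session:
        async with session.post(
            API_URL,
            json=payload,
            headers={"Authorization": f"Bearer {API_KEY}"},
        ) as response:
            body = await response.json()
            return body["choices"][0]["message"]["content"]


if __name__ == "__main__":
    print(asyncio.run(describe_image("example.png")))

The same message list can be sent to the /v1/encode endpoint (patches 06 and 11) to count tokens for a vision prompt, since both paths now share format_messages_with_template.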