diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index aeaa233..ff11531 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -6,6 +6,8 @@
 import math
 import pathlib
 import traceback
+from backends.exllamav2.vision import clear_image_embedding_cache
+from common.multimodal import MultimodalEmbeddingWrapper
 import torch
 import uuid
 from copy import deepcopy
@@ -20,6 +22,7 @@
     ExLlamaV2Cache_TP,
     ExLlamaV2Tokenizer,
     ExLlamaV2Lora,
+    ExLlamaV2VisionTower,
 )
 from exllamav2.generator import (
     ExLlamaV2Sampler,
@@ -91,6 +94,10 @@ class ExllamaV2Container:
     autosplit_reserve: List[float] = [96 * 1024**2]
     use_tp: bool = False
 
+    # Vision vars
+    use_vision: bool = False
+    vision_model: Optional[ExLlamaV2VisionTower] = None
+
     # Load state
     model_is_loading: bool = False
     model_loaded: bool = False
@@ -144,6 +151,15 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
         # Apply a model's config overrides while respecting user settings
         kwargs = await self.set_model_overrides(**kwargs)
 
+        # Set vision state and error if vision isn't supported on the current model
+        self.use_vision = unwrap(kwargs.get("vision"), False)
+        if self.use_vision and not self.config.vision_model_type:
+            raise ValueError(
+                "The provided model does not have vision capabilities that are "
+                "supported by ExllamaV2. "
+                "Please reload with vision disabled."
+            )
+
         # Prepare the draft model config if necessary
         draft_args = unwrap(kwargs.get("draft_model"), {})
         draft_model_name = draft_args.get("draft_model_name")
@@ -477,6 +493,7 @@ def get_model_parameters(self):
             "prompt_template": self.prompt_template.name
             if self.prompt_template
             else None,
+            "use_vision": self.use_vision,
         }
 
         if self.draft_config:
@@ -620,6 +637,14 @@ def progress(loaded_modules: int, total_modules: int)
             input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
             self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)
 
+        # Load vision tower if it exists
+        if self.use_vision:
+            self.vision_model = ExLlamaV2VisionTower(self.config)
+
+            for value in self.vision_model.load_gen(callback_gen=progress_callback):
+                if value:
+                    yield value
+
         self.model = ExLlamaV2(self.config)
         if not self.quiet:
             logger.info("Loading model: " + self.config.model_dir)
@@ -811,6 +836,9 @@ async def unload(self, loras_only: bool = False, **kwargs):
         # Delete references held in the grammar module
         clear_grammar_func_cache()
 
+        # Clear the image embedding cache
+        clear_image_embedding_cache()
+
         # Unload LoRAs
         if self.generator and self.generator.generator.current_loras:
             for lora in self.generator.generator.current_loras:
@@ -824,6 +852,16 @@ async def unload(self, loras_only: bool = False, **kwargs):
             self.model.unload()
         self.model = None
 
+        if self.vision_model:
+            # TODO: Remove this with newer exl2 versions
+            # Required otherwise unload function won't finish
+            try:
+                self.vision_model.unload()
+            except AttributeError:
+                pass
+
+        self.vision_model = None
+
         if self.draft_model:
             self.draft_model.unload()
         self.draft_model = None
@@ -855,11 +893,15 @@ async def unload(self, loras_only: bool = False, **kwargs):
     def encode_tokens(self, text: str, **kwargs):
         """Wrapper to encode tokens from a text string."""
 
+        mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
+        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
+
         return (
             self.tokenizer.encode(
                 text,
                 add_bos=unwrap(kwargs.get("add_bos_token"), True),
                 encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
+                embeddings=mm_embeddings_content,
             )
             .flatten()
             .tolist()
@@ -903,7 +945,11 @@ def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
         return dict(zip_longest(top_tokens, cleaned_values))
 
     async def generate(
-        self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
+        self,
+        prompt: str,
+        request_id: str,
+        abort_event: asyncio.Event = None,
+        **kwargs,
     ):
         """Generate a response to a prompt."""
         generations = []
@@ -1238,10 +1284,17 @@ async def generate_gen(
             else:
                 stop_conditions += eos_tokens
 
+        # Get multimodal embeddings if present
+        mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
+        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
+
         # Encode both positive and negative prompts
         input_ids = [
             self.tokenizer.encode(
-                prompt, add_bos=add_bos_token, encode_special_tokens=True
+                prompt,
+                add_bos=add_bos_token,
+                encode_special_tokens=True,
+                embeddings=mm_embeddings_content,
             )
             for prompt in prompts
         ]
@@ -1292,6 +1345,7 @@ async def generate_gen(
             banned_strings=banned_strings,
             token_healing=token_healing,
             identifier=job_id,
+            embeddings=mm_embeddings_content,
         )
 
         # Save generated tokens and full response
diff --git a/backends/exllamav2/vision.py b/backends/exllamav2/vision.py
new file mode 100644
index 0000000..168c80c
--- /dev/null
+++ b/backends/exllamav2/vision.py
@@ -0,0 +1,70 @@
+"""Vision utilities for ExLlamaV2."""
+
+import io
+import base64
+import re
+from PIL import Image
+from common import model
+import aiohttp
+from common.networking import (
+    handle_request_error,
+)
+from common.tabby_config import config
+from fastapi import HTTPException
+from exllamav2.generator import ExLlamaV2MMEmbedding
+from async_lru import alru_cache
+
+
+async def get_image(url: str) -> Image:
+    if url.startswith("data:image"):
+        # Handle base64 image
+        match = re.match(r"^data:image\/[a-zA-Z0-9]+;base64,(.*)$", url)
+        if match:
+            base64_image = match.group(1)
+            bytes_image = base64.b64decode(base64_image)
+        else:
+            error_message = handle_request_error(
+                "Failed to read base64 image input.",
+                exc_info=False,
+            ).error.message
+
+            raise HTTPException(400, error_message)
+
+    else:
+        # Handle image URL
+        if config.network.disable_fetch_requests:
+            error_message = handle_request_error(
+                f"Failed to fetch image from {url} as fetch requests are disabled.",
+                exc_info=False,
+            ).error.message
+
+            raise HTTPException(400, error_message)
+
+        async with aiohttp.ClientSession() as session:
+            async with session.get(url) as response:
+                if response.status == 200:
+                    bytes_image = await response.read()
+                else:
+                    error_message = handle_request_error(
+                        f"Failed to fetch image from {url}.",
+                        exc_info=False,
+                    ).error.message
+
+                    raise HTTPException(400, error_message)
+
+    return Image.open(io.BytesIO(bytes_image))
+
+
+@alru_cache(20)
+async def get_image_embedding(url: str) -> ExLlamaV2MMEmbedding:
+    image = await get_image(url)
+    return model.container.vision_model.get_image_embeddings(
+        model=model.container.model,
+        tokenizer=model.container.tokenizer,
+        image=image,
+        text_alias=None,
+    )
+
+
+def clear_image_embedding_cache():
+    get_image_embedding.cache_clear()
diff --git a/common/config_models.py b/common/config_models.py
index b113194..f7f0add 100644
--- a/common/config_models.py
+++ b/common/config_models.py
@@ -78,6 +78,13 @@ class NetworkConfig(BaseConfigModel):
             "Turn on this option if you are ONLY connecting from localhost."
         ),
     )
+    disable_fetch_requests: Optional[bool] = Field(
+        False,
+        description=(
+            "Disable fetching external content in response to requests,"
+            "such as images from URLs."
+        ),
+    )
     send_tracebacks: Optional[bool] = Field(
         False,
         description=(
@@ -281,6 +288,12 @@ class ModelConfig(BaseConfigModel):
             "NOTE: Only works with chat completion message lists!"
         ),
     )
+    vision: Optional[bool] = Field(
+        False,
+        description=(
+            "Enables vision support if the model supports it. (default: False)"
+        ),
+    )
     num_experts_per_token: Optional[int] = Field(
         None,
         description=(
diff --git a/common/model.py b/common/model.py
index 87b06ad..3eac446 100644
--- a/common/model.py
+++ b/common/model.py
@@ -33,6 +33,7 @@ class ModelType(Enum):
     MODEL = "model"
     DRAFT = "draft"
     EMBEDDING = "embedding"
+    VISION = "vision"
 
 
 def load_progress(module, modules):
@@ -70,29 +71,39 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     # Create a new container
     container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
 
-    model_type = "draft" if container.draft_config else "model"
+    # Add possible types of models that can be loaded
+    model_type = [ModelType.MODEL]
+
+    if container.use_vision:
+        model_type.insert(0, ModelType.VISION)
+
+    if container.draft_config:
+        model_type.insert(0, ModelType.DRAFT)
+
     load_status = container.load_gen(load_progress, **kwargs)
 
     progress = get_loading_progress_bar()
     progress.start()
 
     try:
+        index = 0
         async for module, modules in load_status:
+            current_model_type = model_type[index].value
             if module == 0:
                 loading_task = progress.add_task(
-                    f"[cyan]Loading {model_type} modules", total=modules
+                    f"[cyan]Loading {current_model_type} modules", total=modules
                 )
             else:
                 progress.advance(loading_task)
 
-            yield module, modules, model_type
+            yield module, modules, current_model_type
 
             if module == modules:
                 # Switch to model progress if the draft model is loaded
-                if model_type == "draft":
-                    model_type = "model"
-                else:
+                if index == len(model_type):
                     progress.stop()
+                else:
+                    index += 1
 
     finally:
         progress.stop()
diff --git a/common/multimodal.py b/common/multimodal.py
new file mode 100644
index 0000000..5b93f23
--- /dev/null
+++ b/common/multimodal.py
@@ -0,0 +1,30 @@
+from typing import List
+from backends.exllamav2.vision import get_image_embedding
+from common import model
+from loguru import logger
+
+from common.optional_dependencies import dependencies
+
+if dependencies.exllamav2:
+    from exllamav2 import ExLlamaV2VisionTower
+
+
+class MultimodalEmbeddingWrapper:
+    """Common multimodal embedding wrapper"""
+
+    type: str = None
+    content: List = []
+    text_alias: List[str] = []
+
+    async def add(self, url: str):
+        # Determine the type of vision embedding to use
+        if not self.type:
+            if isinstance(model.container.vision_model, ExLlamaV2VisionTower):
+                self.type = "ExLlamaV2MMEmbedding"
+
+        if self.type == "ExLlamaV2MMEmbedding":
+            embedding = await get_image_embedding(url)
+            self.content.append(embedding)
+            self.text_alias.append(embedding.text_alias)
+        else:
+            logger.error("No valid vision model to create embedding")
diff --git a/config_sample.yml b/config_sample.yml
index 39593db..ebea5a1 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -20,6 +20,9 @@ network:
   # Turn on this option if you are ONLY connecting from localhost.
   disable_auth: false
 
+  # Disable fetching external content in response to requests, such as images from URLs.
+  disable_fetch_requests: false
+
   # Send tracebacks over the API (default: False).
   # NOTE: Only enable this for debug purposes.
   send_tracebacks: false
@@ -130,6 +133,9 @@ model:
   # NOTE: Only works with chat completion message lists!
   prompt_template:
 
+  # Enables vision support if the model supports it. (default: False)
+  vision: false
+
   # Number of experts to use per token.
   # Fetched from the model's config.json if empty.
   # NOTE: For MoE models only.
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index b6a44c9..8f4e7a4 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -15,7 +15,7 @@
 )
 from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
 from endpoints.OAI.utils.chat_completion import (
-    format_prompt_with_template,
+    apply_chat_template,
     generate_chat_completion,
     stream_generate_chat_completion,
 )
@@ -123,10 +123,7 @@ async def chat_completion_request(
 
     model_path = model.container.model_dir
 
-    if isinstance(data.messages, str):
-        prompt = data.messages
-    else:
-        prompt = await format_prompt_with_template(data)
+    prompt, embeddings = await apply_chat_template(data)
 
     # Set an empty JSON schema if the request wants a JSON response
     if data.response_format.type == "json":
@@ -136,12 +133,14 @@ async def chat_completion_request(
 
     if data.stream and not disable_request_streaming:
         return EventSourceResponse(
-            stream_generate_chat_completion(prompt, data, request, model_path),
+            stream_generate_chat_completion(
+                prompt, embeddings, data, request, model_path
+            ),
             ping=maxsize,
         )
     else:
         generate_task = asyncio.create_task(
-            generate_chat_completion(prompt, data, request, model_path)
+            generate_chat_completion(prompt, embeddings, data, request, model_path)
         )
 
         response = await run_with_request_disconnect(
diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 30ec769..86a2247 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -1,7 +1,7 @@
 from pydantic import BaseModel, Field
 from pydantic.json_schema import SkipJsonSchema
 from time import time
-from typing import Union, List, Optional, Dict
+from typing import Literal, Union, List, Optional, Dict
 from uuid import uuid4
 
 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
@@ -18,10 +18,21 @@ class ChatCompletionLogprobs(BaseModel):
     content: List[ChatCompletionLogprob] = Field(default_factory=list)
 
 
+class ChatCompletionImageUrl(BaseModel):
+    url: str
+
+
+class ChatCompletionMessagePart(BaseModel):
+    type: Literal["text", "image_url"] = "text"
+    text: Optional[str] = None
+    image_url: Optional[ChatCompletionImageUrl] = None
+
+
 class ChatCompletionMessage(BaseModel):
-    role: Optional[str] = None
-    content: Optional[str] = None
+    role: str = "user"
+    content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None
     tool_calls: Optional[List[ToolCall]] = None
+    tool_calls_json: SkipJsonSchema[Optional[str]] = None
 
 
 class ChatCompletionRespChoice(BaseModel):
@@ -51,7 +62,7 @@ class ChatCompletionRespChoice(BaseModel):
 
     # WIP this can probably be tightened, or maybe match the OAI lib type
     # in openai\types\chat\chat_completion_message_param.py
-    messages: Union[str, List[Dict]]
+    messages: List[ChatCompletionMessage] = Field(default_factory=list)
     prompt_template: Optional[str] = None
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = {}
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 3b5c07f..14a2243 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -1,16 +1,16 @@
 """Chat completion utilities for OAI server."""
 
 import asyncio
+import json
 import pathlib
 from asyncio import CancelledError
 from typing import List, Optional
 
-import json
-
 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
 from loguru import logger
 
 from common import model
+from common.multimodal import MultimodalEmbeddingWrapper
 from common.networking import (
     get_generator_error,
     handle_request_disconnect,
@@ -177,11 +177,11 @@ def _create_stream_chunk(
     return chunk
 
 
-async def _append_template_metadata(data: ChatCompletionRequest):
+async def _append_template_metadata(data: ChatCompletionRequest, template_vars: dict):
     """Adding metadata is a one-time process."""
 
     template_metadata = await model.container.prompt_template.extract_metadata(
-        data.template_vars
+        template_vars
     )
 
     # Stop strings
@@ -199,7 +199,44 @@ async def _append_template_metadata(data: ChatCompletionRequest):
         data.stop.extend(template_metadata.tool_starts)
 
 
-async def format_prompt_with_template(
+async def format_messages_with_template(
+    messages: List[ChatCompletionMessage],
+    existing_template_vars: Optional[dict] = None,
+    add_bos_token: bool = True,
+    ban_eos_token: bool = False,
+):
+    """Barebones function to format chat completion messages into a prompt."""
+
+    template_vars = unwrap(existing_template_vars, {})
+    mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None
+
+    for message in messages:
+        if isinstance(message.content, list):
+            concatenated_content = ""
+            for content in message.content:
+                if content.type == "text":
+                    concatenated_content += content.text
+                elif content.type == "image_url" and mm_embeddings:
+                    await mm_embeddings.add(content.image_url.url)
+                    concatenated_content += mm_embeddings.text_alias[-1]
+
+            # Convert the message content into a concatenated string
+            message.content = concatenated_content
+
+        if message.tool_calls:
+            message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
+
+    special_tokens_dict = model.container.get_special_tokens(
+        add_bos_token, ban_eos_token
+    )
+
+    template_vars.update({"messages": messages, **special_tokens_dict})
+
+    prompt = await model.container.prompt_template.render(template_vars)
+    return prompt, mm_embeddings, template_vars
+
+
+async def apply_chat_template(
     data: ChatCompletionRequest, tool_precursor: Optional[str] = None
 ):
     """
@@ -208,40 +245,18 @@ async def format_prompt_with_template(
     """
 
     try:
-        special_tokens_dict = model.container.get_special_tokens(
-            unwrap(data.add_bos_token, True),
-            unwrap(data.ban_eos_token, False),
-        )
-
-        # Deal with list in messages.content
-        # Just replace the content list with the very first text message
-        for message in data.messages:
-            if isinstance(message["content"], list):
-                message["content"] = next(
-                    (
-                        content["text"]
-                        for content in message["content"]
-                        if content["type"] == "text"
-                    ),
-                    "",
-                )
-
-            if "tool_calls" in message:
-                message["tool_calls_json"] = json.dumps(message["tool_calls"], indent=2)
-
-        # Overwrite any protected vars with their values
         data.template_vars.update(
             {
-                "messages": data.messages,
                 "add_generation_prompt": data.add_generation_prompt,
                 "tools_json": json.dumps(data.model_dump()["tools"], indent=2),
                 "functions_json": json.dumps(data.functions, indent=2),
                 "tool_precursor": tool_precursor,
-                **special_tokens_dict,
             }
        )
 
-        prompt = await model.container.prompt_template.render(data.template_vars)
+        prompt, mm_embeddings, template_vars = await format_messages_with_template(
+            data.messages, data.template_vars, data.add_bos_token, data.ban_eos_token
+        )
 
         # Append response prefix if present
         if data.response_prefix:
@@ -255,14 +270,14 @@ async def format_prompt_with_template(
 
         # Removes the starting BOS token if present
         # This is to prevent add_bos_token from adding multiple bos tokens
-        bos_token = special_tokens_dict.get("bos_token")
+        bos_token = template_vars.get("bos_token")
         if bos_token and prompt.startswith(bos_token):
             prompt = prompt.removeprefix(bos_token)
 
         # Add template metadata
-        await _append_template_metadata(data)
+        await _append_template_metadata(data, template_vars)
 
-        return prompt
+        return prompt, mm_embeddings
 
     except KeyError as exc:
         error_message = handle_request_error(
@@ -279,7 +294,11 @@ async def format_prompt_with_template(
 
 
 async def stream_generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     """Generator for the generation process."""
     abort_event = asyncio.Event()
@@ -300,6 +319,7 @@ async def stream_generate_chat_completion(
                 prompt,
                 request.state.id,
                 abort_event,
+                embeddings=embeddings,
                 **task_gen_params.model_dump(exclude={"prompt"}),
             )
         )
@@ -372,7 +392,11 @@ async def stream_generate_chat_completion(
 
 
 async def generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     gen_tasks: List[asyncio.Task] = []
 
@@ -381,7 +405,10 @@ async def generate_chat_completion(
             gen_tasks.append(
                 asyncio.create_task(
                     model.container.generate(
-                        prompt, request.state.id, **data.model_dump(exclude={"prompt"})
+                        prompt,
+                        request.state.id,
+                        embeddings=embeddings,
+                        **data.model_dump(exclude={"prompt"}),
                     )
                 )
             )
@@ -427,13 +454,11 @@ async def generate_tool_calls(
         if gen["stop_str"] in tool_data.tool_call_start:
             if "text" in gen:
                 # non streaming, all generations will have the text they generated
-                pre_tool_prompt = await format_prompt_with_template(data, gen["text"])
+                pre_tool_prompt = await apply_chat_template(data, gen["text"])
             elif current_generations is not None:
                 # streaming, we wont have text in the generation,
                 # we'll have to use the current_generations
-                pre_tool_prompt = await format_prompt_with_template(
-                    data, current_generations
-                )
+                pre_tool_prompt = await apply_chat_template(data, current_generations)
 
             gen_tasks.append(
                 asyncio.create_task(
diff --git a/endpoints/core/router.py b/endpoints/core/router.py
index 597930b..ccb26d9 100644
--- a/endpoints/core/router.py
+++ b/endpoints/core/router.py
@@ -1,6 +1,8 @@
 import asyncio
 import pathlib
 from sys import maxsize
+from typing import Optional
+from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import APIRouter, Depends, HTTPException, Request, Response
 from sse_starlette import EventSourceResponse
 
@@ -13,6 +15,7 @@
 from common.templating import PromptTemplate, get_all_templates
 from common.utils import unwrap
 from common.health import HealthManager
+from endpoints.OAI.utils.chat_completion import format_messages_with_template
 from endpoints.core.types.auth import AuthPermissionResponse
 from endpoints.core.types.download import DownloadRequest, DownloadResponse
 from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
@@ -359,22 +362,47 @@ async def unload_embedding_model():
 async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
     """Encodes a string or chat completion messages into tokens."""
 
+    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None
+
     if isinstance(data.text, str):
         text = data.text
-    else:
-        special_tokens_dict = model.container.get_special_tokens(
-            unwrap(data.add_bos_token, True)
-        )
+    elif isinstance(data.text, list):
+        if "oai" not in config.network.api_servers:
+            error_message = handle_request_error(
+                "Enable the OAI server to handle chat completion messages.",
+                exc_info=False,
+            ).error.message
+
+            raise HTTPException(422, error_message)
+
+        if not model.container.prompt_template:
+            error_message = handle_request_error(
+                "Cannot tokenize chat completion message because "
+                + "a prompt template is not set.",
+                exc_info=False,
+            ).error.message
+
+            raise HTTPException(422, error_message)
 
         template_vars = {
-            "messages": data.text,
             "add_generation_prompt": False,
-            **special_tokens_dict,
         }
 
-        text, _ = model.container.prompt_template.render(template_vars)
+        # Don't need template vars again
+        text, mm_embeddings, _ = await format_messages_with_template(
+            data.text, template_vars, data.add_bos_token
+        )
+    else:
+        error_message = handle_request_error(
+            "Unable to tokenize the provided text. Check your formatting?",
+            exc_info=False,
+        ).error.message
 
-    raw_tokens = model.container.encode_tokens(text, **data.get_params())
+        raise HTTPException(422, error_message)
+
+    raw_tokens = model.container.encode_tokens(
+        text, embeddings=mm_embeddings, **data.get_params()
+    )
     tokens = unwrap(raw_tokens, [])
 
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py
index f2817f0..ddf1cc2 100644
--- a/endpoints/core/types/model.py
+++ b/endpoints/core/types/model.py
@@ -21,6 +21,7 @@ class ModelCardParameters(BaseModel):
     chunk_size: Optional[int] = 2048
     prompt_template: Optional[str] = None
     num_experts_per_token: Optional[int] = None
+    use_vision: Optional[bool] = False
 
     # Draft is another model, so include it in the card params
     draft: Optional["ModelCard"] = None
@@ -107,6 +108,7 @@ class ModelLoadRequest(BaseModel):
     cache_mode: Optional[str] = None
     chunk_size: Optional[int] = None
     prompt_template: Optional[str] = None
+    vision: Optional[bool] = None
     num_experts_per_token: Optional[int] = None
 
     # Non-config arguments
diff --git a/endpoints/core/types/token.py b/endpoints/core/types/token.py
index 945adbf..d43e65e 100644
--- a/endpoints/core/types/token.py
+++ b/endpoints/core/types/token.py
@@ -1,7 +1,9 @@
 """Tokenization types"""
 
 from pydantic import BaseModel
-from typing import Dict, List, Union
+from typing import List, Union
+
+from endpoints.OAI.types.chat_completion import ChatCompletionMessage
 
 
 class CommonTokenRequest(BaseModel):
@@ -23,7 +25,7 @@ def get_params(self):
 class TokenEncodeRequest(CommonTokenRequest):
     """Represents a tokenization request."""
 
-    text: Union[str, List[Dict[str, str]]]
+    text: Union[str, List[ChatCompletionMessage]]
 
 
 class TokenEncodeResponse(BaseModel):
diff --git a/pyproject.toml b/pyproject.toml
index dc54ebe..d09129b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
    "lm-format-enforcer >= 0.9.6",
    "aiofiles",
    "aiohttp",
+    "async_lru",
    "huggingface_hub",
    "psutil",
    "httptools>=0.5.0",
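
Usage sketch (not part of the patch): with the schema changes above, a chat completion request may mix `text` and `image_url` parts inside a message's `content` list once the model is loaded with `vision: true`. Image URLs are fetched server-side unless `network.disable_fetch_requests` is set, and `data:image/...;base64,` URIs are also accepted by `get_image`. The endpoint URL, port, API key, and auth header below are placeholder assumptions; the request body follows the `ChatCompletionMessagePart` / `ChatCompletionImageUrl` models added in this diff.

```python
# Hypothetical client-side example using aiohttp (already a project dependency).
import asyncio

import aiohttp

API_URL = "http://127.0.0.1:5000/v1/chat/completions"  # assumed local instance
API_KEY = "example-key"  # assumed; use whatever key the server is configured with

payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/photo.png"},
                },
            ],
        }
    ],
    "max_tokens": 200,
}


async def main():
    # The OAI router validates `messages` against ChatCompletionMessage, resolves
    # image_url parts into multimodal embeddings, then generates as usual.
    async with aiohttp.ClientSession() as session:
        async with session.post(
            API_URL, json=payload, headers={"Authorization": f"Bearer {API_KEY}"}
        ) as response:
            data = await response.json()
            print(data["choices"][0]["message"]["content"])


asyncio.run(main())
```

Because `get_image_embedding` is wrapped in `alru_cache(20)`, repeated requests that reference the same image URL reuse the cached embedding until the model is unloaded, which clears the cache via `clear_image_embedding_cache`.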