Vision #249

Merged 16 commits on Nov 22, 2024
58 changes: 56 additions & 2 deletions backends/exllamav2/model.py
@@ -6,6 +6,8 @@
import math
import pathlib
import traceback
from backends.exllamav2.vision import clear_image_embedding_cache
from common.multimodal import MultimodalEmbeddingWrapper
import torch
import uuid
from copy import deepcopy
@@ -20,6 +22,7 @@
ExLlamaV2Cache_TP,
ExLlamaV2Tokenizer,
ExLlamaV2Lora,
ExLlamaV2VisionTower,
)
from exllamav2.generator import (
ExLlamaV2Sampler,
@@ -91,6 +94,10 @@ class ExllamaV2Container:
autosplit_reserve: List[float] = [96 * 1024**2]
use_tp: bool = False

# Vision vars
use_vision: bool = False
vision_model: Optional[ExLlamaV2VisionTower] = None

# Load state
model_is_loading: bool = False
model_loaded: bool = False
@@ -144,6 +151,15 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)

# Set the vision state and raise an error if the current model doesn't support vision
self.use_vision = unwrap(kwargs.get("vision"), False)
if self.use_vision and not self.config.vision_model_type:
raise ValueError(
"The provided model does not have vision capabilities that are "
"supported by ExllamaV2. "
"Please reload with vision disabled."
)

# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
@@ -477,6 +493,7 @@ def get_model_parameters(self):
"prompt_template": self.prompt_template.name
if self.prompt_template
else None,
"use_vision": self.use_vision,
}

if self.draft_config:
@@ -620,6 +637,14 @@ def progress(loaded_modules: int, total_modules: int)
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)

# Load vision tower if it exists
if self.use_vision:
self.vision_model = ExLlamaV2VisionTower(self.config)

for value in self.vision_model.load_gen(callback_gen=progress_callback):
if value:
yield value

self.model = ExLlamaV2(self.config)
if not self.quiet:
logger.info("Loading model: " + self.config.model_dir)
@@ -811,6 +836,9 @@ async def unload(self, loras_only: bool = False, **kwargs):
# Delete references held in the grammar module
clear_grammar_func_cache()

# Clear the image embedding cache
clear_image_embedding_cache()

# Unload LoRAs
if self.generator and self.generator.generator.current_loras:
for lora in self.generator.generator.current_loras:
@@ -824,6 +852,16 @@ async def unload(self, loras_only: bool = False, **kwargs):
self.model.unload()
self.model = None

if self.vision_model:
# TODO: Remove this with newer exl2 versions
# Required, otherwise the unload function won't finish
try:
self.vision_model.unload()
except AttributeError:
pass

self.vision_model = None

if self.draft_model:
self.draft_model.unload()
self.draft_model = None
@@ -855,11 +893,15 @@ async def unload(self, loras_only: bool = False, **kwargs):
def encode_tokens(self, text: str, **kwargs):
"""Wrapper to encode tokens from a text string."""

mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

return (
self.tokenizer.encode(
text,
add_bos=unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
embeddings=mm_embeddings_content,
)
.flatten()
.tolist()
@@ -903,7 +945,11 @@ def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
return dict(zip_longest(top_tokens, cleaned_values))

async def generate(
self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
self,
prompt: str,
request_id: str,
abort_event: asyncio.Event = None,
**kwargs,
):
"""Generate a response to a prompt."""
generations = []
@@ -1238,10 +1284,17 @@ async def generate_gen(
else:
stop_conditions += eos_tokens

# Get multimodal embeddings if present
mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

# Encode both positive and negative prompts
input_ids = [
self.tokenizer.encode(
prompt, add_bos=add_bos_token, encode_special_tokens=True
prompt,
add_bos=add_bos_token,
encode_special_tokens=True,
embeddings=mm_embeddings_content,
)
for prompt in prompts
]
@@ -1292,6 +1345,7 @@ async def generate_gen(
banned_strings=banned_strings,
token_healing=token_healing,
identifier=job_id,
embeddings=mm_embeddings_content,
)

# Save generated tokens and full response
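The changes above thread an optional `embeddings` kwarg from the API layer down to `tokenizer.encode()` and into the generator job. A minimal sketch of that flow is below, assuming `container` is the globally loaded `ExllamaV2Container` (i.e. `model.container`) with vision enabled; the helper name `encode_with_image` is illustrative and not part of the diff.

```python
from common.multimodal import MultimodalEmbeddingWrapper


async def encode_with_image(container, text: str, image_url: str):
    # Build the wrapper; add() resolves the URL (or base64 data URI) into an
    # ExLlamaV2MMEmbedding via the cached get_image_embedding() helper.
    embeddings = MultimodalEmbeddingWrapper()
    await embeddings.add(image_url)

    # encode_tokens() forwards embeddings.content to tokenizer.encode(), so the
    # image embedding is attached to the token stream alongside the text.
    return container.encode_tokens(text, embeddings=embeddings)
```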
70 changes: 70 additions & 0 deletions backends/exllamav2/vision.py
@@ -0,0 +1,70 @@
"""Vision utilities for ExLlamaV2."""

import io
import base64
import re
from PIL import Image
from common import model
import aiohttp
from common.networking import (
handle_request_error,
)
from common.tabby_config import config
from fastapi import HTTPException
from exllamav2.generator import ExLlamaV2MMEmbedding
from async_lru import alru_cache


async def get_image(url: str) -> Image.Image:
if url.startswith("data:image"):
# Handle base64 image
match = re.match(r"^data:image\/[a-zA-Z0-9]+;base64,(.*)$", url)
if match:
base64_image = match.group(1)
bytes_image = base64.b64decode(base64_image)
else:
error_message = handle_request_error(
"Failed to read base64 image input.",
exc_info=False,
).error.message

raise HTTPException(400, error_message)

else:
# Handle image URL
if config.network.disable_fetch_requests:
error_message = handle_request_error(
f"Failed to fetch image from {url} as fetch requests are disabled.",
exc_info=False,
).error.message

raise HTTPException(400, error_message)

async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
bytes_image = await response.read()
else:
error_message = handle_request_error(
f"Failed to fetch image from {url}.",
exc_info=False,
).error.message

raise HTTPException(400, error_message)

return Image.open(io.BytesIO(bytes_image))


@alru_cache(20)
async def get_image_embedding(url: str) -> ExLlamaV2MMEmbedding:
image = await get_image(url)
return model.container.vision_model.get_image_embeddings(
model=model.container.model,
tokenizer=model.container.tokenizer,
image=image,
text_alias=None,
)


def clear_image_embedding_cache():
get_image_embedding.cache_clear()
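The embedding helper is wrapped in an `alru_cache`, so repeated references to the same image URL reuse one `ExLlamaV2MMEmbedding` while it stays in the cache, and `clear_image_embedding_cache()` (called from `unload()` above) drops entries tied to the previous model. A rough usage sketch, assuming a vision-capable model is already loaded:

```python
from backends.exllamav2.vision import (
    clear_image_embedding_cache,
    get_image_embedding,
)


async def demo(url: str) -> None:
    first = await get_image_embedding(url)   # fetches and embeds the image
    second = await get_image_embedding(url)  # served from the async LRU cache
    assert first is second                   # same cached embedding object

    # On model unload the cache is cleared so embeddings built against the
    # old vision tower are not reused by the next model.
    clear_image_embedding_cache()
```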
13 changes: 13 additions & 0 deletions common/config_models.py
@@ -78,6 +78,13 @@ class NetworkConfig(BaseConfigModel):
"Turn on this option if you are ONLY connecting from localhost."
),
)
disable_fetch_requests: Optional[bool] = Field(
False,
description=(
"Disable fetching external content in response to requests,"
"such as images from URLs."
),
)
send_tracebacks: Optional[bool] = Field(
False,
description=(
@@ -281,6 +288,12 @@ class ModelConfig(BaseConfigModel):
"NOTE: Only works with chat completion message lists!"
),
)
vision: Optional[bool] = Field(
False,
description=(
"Enables vision support if the model supports it. (default: False)"
),
)
num_experts_per_token: Optional[int] = Field(
None,
description=(
23 changes: 17 additions & 6 deletions common/model.py
@@ -33,6 +33,7 @@ class ModelType(Enum):
MODEL = "model"
DRAFT = "draft"
EMBEDDING = "embedding"
VISION = "vision"


def load_progress(module, modules):
@@ -70,29 +71,39 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
# Create a new container
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)

model_type = "draft" if container.draft_config else "model"
# Add possible types of models that can be loaded
model_type = [ModelType.MODEL]

if container.use_vision:
model_type.insert(0, ModelType.VISION)

if container.draft_config:
model_type.insert(0, ModelType.DRAFT)

load_status = container.load_gen(load_progress, **kwargs)

progress = get_loading_progress_bar()
progress.start()

try:
index = 0
async for module, modules in load_status:
current_model_type = model_type[index].value
if module == 0:
loading_task = progress.add_task(
f"[cyan]Loading {model_type} modules", total=modules
f"[cyan]Loading {current_model_type} modules", total=modules
)
else:
progress.advance(loading_task)

yield module, modules, model_type
yield module, modules, current_model_type

if module == modules:
# Switch to model progress if the draft model is loaded
if model_type == "draft":
model_type = "model"
else:
if index == len(model_type):
progress.stop()
else:
index += 1
finally:
progress.stop()

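With this change the loader tracks an ordered list of `ModelType` entries (vision tower, then draft, then the main model, when present) and advances through it as each module group finishes, so the progress bar and the yielded type name reflect the component currently loading. Roughly, a caller sees it like this; `model_path` and the load options are placeholders.

```python
import pathlib

from common.model import load_model_gen


async def report_load_progress(model_path: pathlib.Path, **kwargs) -> None:
    # Each yield carries the component name ("vision", "draft", or "model")
    # alongside the module counts, in the order the components are loaded.
    async for module, modules, model_type in load_model_gen(model_path, **kwargs):
        print(f"Loading {model_type}: {module}/{modules} modules")
```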
30 changes: 30 additions & 0 deletions common/multimodal.py
@@ -0,0 +1,30 @@
from typing import List
from backends.exllamav2.vision import get_image_embedding
from common import model
from loguru import logger

from common.optional_dependencies import dependencies

if dependencies.exllamav2:
from exllamav2 import ExLlamaV2VisionTower


class MultimodalEmbeddingWrapper:
"""Common multimodal embedding wrapper"""

type: str = None
content: List = []
text_alias: List[str] = []

async def add(self, url: str):
# Determine the type of vision embedding to use
if not self.type:
if isinstance(model.container.vision_model, ExLlamaV2VisionTower):
self.type = "ExLlamaV2MMEmbedding"

if self.type == "ExLlamaV2MMEmbedding":
embedding = await get_image_embedding(url)
self.content.append(embedding)
self.text_alias.append(embedding.text_alias)
else:
logger.error("No valid vision model to create embedding")
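The wrapper accumulates one embedding and one text alias per image, which the chat-template layer can splice into the rendered prompt. A hypothetical helper (not part of the diff) that collects a request's image URLs might look like this:

```python
from typing import List

from common.multimodal import MultimodalEmbeddingWrapper


async def collect_image_embeddings(image_urls: List[str]) -> MultimodalEmbeddingWrapper:
    wrapper = MultimodalEmbeddingWrapper()
    for url in image_urls:
        # add() appends an ExLlamaV2MMEmbedding to wrapper.content and its
        # placeholder string to wrapper.text_alias for prompt insertion.
        await wrapper.add(url)
    return wrapper
```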
6 changes: 6 additions & 0 deletions config_sample.yml
@@ -20,6 +20,9 @@ network:
# Turn on this option if you are ONLY connecting from localhost.
disable_auth: false

# Disable fetching external content in response to requests, such as images from URLs.
disable_fetch_requests: false

# Send tracebacks over the API (default: False).
# NOTE: Only enable this for debug purposes.
send_tracebacks: false
@@ -130,6 +133,9 @@ model:
# NOTE: Only works with chat completion message lists!
prompt_template:

# Enables vision support if the model supports it. (default: False)
vision: false

# Number of experts to use per token.
# Fetched from the model's config.json if empty.
# NOTE: For MoE models only.
13 changes: 6 additions & 7 deletions endpoints/OAI/router.py
@@ -15,7 +15,7 @@
)
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
from endpoints.OAI.utils.chat_completion import (
format_prompt_with_template,
apply_chat_template,
generate_chat_completion,
stream_generate_chat_completion,
)
@@ -123,10 +123,7 @@ async def chat_completion_request(

model_path = model.container.model_dir

if isinstance(data.messages, str):
prompt = data.messages
else:
prompt = await format_prompt_with_template(data)
prompt, embeddings = await apply_chat_template(data)

# Set an empty JSON schema if the request wants a JSON response
if data.response_format.type == "json":
@@ -136,12 +133,14 @@

if data.stream and not disable_request_streaming:
return EventSourceResponse(
stream_generate_chat_completion(prompt, data, request, model_path),
stream_generate_chat_completion(
prompt, embeddings, data, request, model_path
),
ping=maxsize,
)
else:
generate_task = asyncio.create_task(
generate_chat_completion(prompt, data, request, model_path)
generate_chat_completion(prompt, embeddings, data, request, model_path)
)

response = await run_with_request_disconnect(
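The router change replaces `format_prompt_with_template` with `apply_chat_template`, which returns both the rendered prompt and the multimodal embeddings derived from the message list; both are then forwarded to the streaming and non-streaming generators. A simplified sketch of the call shape implied by the diff (the bodies of these helpers are not part of this changeset):

```python
from endpoints.OAI.utils.chat_completion import (
    apply_chat_template,
    generate_chat_completion,
)


async def handle_chat_completion(data, request, model_path):
    # apply_chat_template renders the prompt and, per the wrapper module above,
    # returns the MultimodalEmbeddingWrapper built from any image inputs.
    prompt, embeddings = await apply_chat_template(data)

    # The embeddings ride along with the prompt into generation.
    return await generate_chat_completion(prompt, embeddings, data, request, model_path)
```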