FEAT: Chat vl web UI (#882)
codingl2k1 authored Jan 11, 2024
1 parent 1ba2ef7 commit 6bfd80f
Showing 13 changed files with 944 additions and 29 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -27,6 +27,7 @@ install_requires =
xoscar>=0.2.1
torch
gradio>=3.39.0
pillow
click
tqdm>=4.27
tabulate
@@ -67,7 +68,6 @@ dev =
flake8>=3.8.0
black
openai>1
pillow
opencv-python
langchain
orjson
7 changes: 4 additions & 3 deletions xinference/api/restful_api.py
@@ -615,7 +615,7 @@ async def build_gradio_interface(
but calling API in async function does not return
"""
assert self._app is not None
assert body.model_type == "LLM"
assert body.model_type in ["LLM", "multimodal"]

# asyncio.Lock() behaves differently in 3.9 than 3.10+
# A event loop is required in 3.9 but not 3.10+
@@ -629,16 +629,17 @@
)
asyncio.set_event_loop(asyncio.new_event_loop())

from ..core.chat_interface import LLMInterface
from ..core.chat_interface import GradioInterface

try:
access_token = request.headers.get("Authorization")
internal_host = "localhost" if self._host == "0.0.0.0" else self._host
interface = LLMInterface(
interface = GradioInterface(
endpoint=f"http://{internal_host}:{self._port}",
model_uid=model_uid,
model_name=body.model_name,
model_size_in_billions=body.model_size_in_billions,
model_type=body.model_type,
model_format=body.model_format,
quantization=body.quantization,
context_length=body.context_length,
137 changes: 135 additions & 2 deletions xinference/core/chat_interface.py
@@ -12,31 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import logging
import os
from io import BytesIO
from typing import Generator, List, Optional

import gradio as gr
import PIL.Image
from gradio.components import Markdown, Textbox
from gradio.layouts import Accordion, Column, Row

from ..client.restful.restful_client import (
RESTfulChatglmCppChatModelHandle,
RESTfulChatModelHandle,
RESTfulGenerateModelHandle,
RESTfulMultimodalModelHandle,
)
from ..types import ChatCompletionMessage

logger = logging.getLogger(__name__)


class LLMInterface:
class GradioInterface:
def __init__(
self,
endpoint: str,
model_uid: str,
model_name: str,
model_size_in_billions: int,
model_type: str,
model_format: str,
quantization: str,
context_length: int,
@@ -49,6 +54,7 @@ def __init__(
self.model_uid = model_uid
self.model_name = model_name
self.model_size_in_billions = model_size_in_billions
self.model_type = model_type
self.model_format = model_format
self.quantization = quantization
self.context_length = context_length
@@ -60,7 +66,9 @@ def __init__(
)

def build(self) -> "gr.Blocks":
if "chat" in self.model_ability:
if self.model_type == "multimodal":
interface = self.build_chat_vl_interface()
elif "chat" in self.model_ability:
interface = self.build_chat_interface()
else:
interface = self.build_generate_interface()
@@ -173,6 +181,131 @@ def generate_wrapper(
analytics_enabled=False,
)

def build_chat_vl_interface(
self,
) -> "gr.Blocks":
def predict(history, bot):
logger.debug("Predict model: %s, history: %s", self.model_uid, history)
from ..client import RESTfulClient

client = RESTfulClient(self.endpoint)
client._set_token(self._access_token)
model = client.get_model(self.model_uid)
assert isinstance(model, RESTfulMultimodalModelHandle)

prompt = history[-1]
assert prompt["role"] == "user"
prompt = prompt["content"]
# multimodal chat does not support stream.
response = model.chat(prompt=prompt, chat_history=history[:-1])
history.append(response["choices"][0]["message"])
bot[-1][1] = history[-1]["content"]
return history, bot

def add_text(history, bot, text, image):
logger.debug("Add text, text: %s, image: %s", text, image)
if image:
buffered = BytesIO()
with PIL.Image.open(image) as img:
img.thumbnail((500, 500))
img.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
display_content = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />\n{text}'
message = {
"role": "user",
"content": [
{"type": "text", "text": text},
{"type": "image_url", "image_url": {"url": image}},
],
}
else:
display_content = text
message = {"role": "user", "content": text}
history = history + [message]
bot = bot + [(display_content, None)]
return history, bot, "", None

def clear_history():
logger.debug("Clear history.")
return [], None, "", None

def update_button(text):
return gr.update(interactive=bool(text))

with gr.Blocks(
title=f"🚀 Xinference Chat Bot : {self.model_name} 🚀",
css="""
.center{
display: flex;
justify-content: center;
align-items: center;
padding: 0px;
color: #9ea4b0 !important;
}
""",
analytics_enabled=False,
) as chat_vl_interface:
Markdown(
f"""
<h1 style='text-align: center; margin-bottom: 1rem'>🚀 Xinference Chat Bot : {self.model_name} 🚀</h1>
"""
)
Markdown(
f"""
<div class="center">
Model ID: {self.model_uid}
</div>
<div class="center">
Model Size: {self.model_size_in_billions} Billion Parameters
</div>
<div class="center">
Model Format: {self.model_format}
</div>
<div class="center">
Model Quantization: {self.quantization}
</div>
"""
)

state = gr.State([])
with gr.Row():
chatbot = gr.Chatbot(
elem_id="chatbot", label=self.model_name, height=550, scale=7
)
with gr.Column(scale=3):
imagebox = gr.Image(type="filepath")
textbox = gr.Textbox(
show_label=False,
placeholder="Enter text and press ENTER",
container=False,
)
submit_btn = gr.Button(
value="Send", variant="primary", interactive=False
)
clear_btn = gr.Button(value="Clear")

textbox.change(update_button, [textbox], [submit_btn], queue=False)

textbox.submit(
add_text,
[state, chatbot, textbox, imagebox],
[state, chatbot, textbox, imagebox],
queue=False,
).then(predict, [state, chatbot], [state, chatbot])

submit_btn.click(
add_text,
[state, chatbot, textbox, imagebox],
[state, chatbot, textbox, imagebox],
queue=False,
).then(predict, [state, chatbot], [state, chatbot])

clear_btn.click(
clear_history, None, [state, chatbot, textbox, imagebox], queue=False
)

return chat_vl_interface

def build_generate_interface(
self,
):
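For reference, the history entry that add_text appends when an image is attached follows the OpenAI-style multimodal content format used elsewhere in this commit. The sketch below is illustrative only; the prompt text and file path are placeholders, not values from the commit:

# Sketch of a user turn as stored in `state` by add_text. With an image,
# "content" becomes a list of typed parts; text-only turns stay a plain string.
message_with_image = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        # gr.Image(type="filepath") supplies a local file path for the upload.
        {"type": "image_url", "image_url": {"url": "/tmp/gradio/upload.jpg"}},
    ],
}
message_text_only = {"role": "user", "content": "Describe the previous image."}

predict then pops the latest user turn off the history as the prompt, sends the remaining history to model.chat in a single non-streaming call, and appends the assistant reply back onto both the history state and the chatbot display.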
9 changes: 8 additions & 1 deletion xinference/model/core.py
@@ -78,7 +78,14 @@ def create_model_instance(
elif model_type == "multimodal":
kwargs.pop("trust_remote_code", None)
return create_multimodal_model_instance(
subpool_addr, devices, model_uid, model_name, **kwargs
subpool_addr,
devices,
model_uid,
model_name,
model_format,
model_size_in_billions,
quantization,
**kwargs,
)
else:
raise ValueError(f"Unsupported model type: {model_type}.")
9 changes: 8 additions & 1 deletion xinference/model/multimodal/core.py
@@ -203,6 +203,8 @@ def _apply_format_to_model_id(spec: LVLMSpecV1, q: str) -> LVLMSpecV1:
and matched_quantization is None
):
continue
# Copy spec to avoid _apply_format_to_model_id modify the original spec.
spec = spec.copy()
if quantization:
return (
family,
@@ -328,6 +330,11 @@ def _skip_download(
logger.warning(f"Cache {cache_dir} exists, but it was from {hub}")
return True
return False
elif model_format in ["ggmlv3", "ggufv2", "gptq"]:
assert quantization is not None
return os.path.exists(
_get_meta_path(cache_dir, model_format, model_hub, quantization)
)
else:
raise ValueError(f"Unsupported format: {model_format}")

@@ -414,7 +421,7 @@ def cache_from_huggingface(
):
return cache_dir

if model_spec.model_format in ["pytorch"]:
if model_spec.model_format in ["pytorch", "gptq"]:
assert isinstance(model_spec, LVLMSpecV1)
retry_download(
huggingface_hub.snapshot_download,
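For the quantized formats, the new _skip_download branch keys the check on a per-quantization marker ("meta") file rather than on the cache directory alone, which lets different quantizations of the same model be validated independently. A rough sketch of that idea with hypothetical names; the real _get_meta_path layout may differ:

import os

def meta_path(cache_dir: str, model_format: str, model_hub: str, quantization: str) -> str:
    # Hypothetical marker-file layout: one file per (hub, format, quantization).
    return os.path.join(cache_dir, f"__valid_download_{model_hub}_{model_format}_{quantization}")

def should_skip_download(cache_dir, model_format, model_hub, quantization) -> bool:
    # Mirrors the new branch above: quantized formats require a quantization,
    # and the download is skipped only if its marker file already exists.
    if model_format in ("ggmlv3", "ggufv2", "gptq"):
        assert quantization is not None
        return os.path.exists(meta_path(cache_dir, model_format, model_hub, quantization))
    return False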
9 changes: 9 additions & 0 deletions xinference/model/multimodal/model_spec.json
@@ -20,6 +20,15 @@
],
"model_id": "Qwen/Qwen-VL-Chat",
"model_revision": "6665c780ade5ff3f08853b4262dcb9c8f9598d42"
},
{
"model_format": "gptq",
"model_size_in_billions": 7,
"quantizations": [
"Int4"
],
"model_id": "Qwen/Qwen-VL-Chat-{quantization}",
"model_revision": "5d3a5aa033ed2c502300d426c81cc5b13bcd1409"
}
],
"prompt_style": {
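With this entry in place, the Int4 GPTQ build of Qwen-VL-Chat can be launched the same way the updated test later in this diff does. An illustrative client call against a locally running server; the endpoint URL is a placeholder and the model_size_in_billions argument is an assumption not shown in this commit:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # placeholder endpoint of a running Xinference server

model_uid = client.launch_model(
    model_name="qwen-vl-chat",
    model_type="multimodal",
    model_format="gptq",
    model_size_in_billions=7,  # assumed; the test in this commit omits it
    quantization="Int4",
)
model = client.get_model(model_uid)  # multimodal handle exposing chat()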
14 changes: 5 additions & 9 deletions xinference/model/multimodal/qwen_vl.py
@@ -18,7 +18,6 @@
import time
import uuid
from typing import Dict, Iterator, List, Optional, Union
from urllib.parse import urlparse

from ...types import (
ChatCompletion,
@@ -73,14 +72,7 @@ def load(self):

def _message_content_to_qwen(self, content) -> str:
def _ensure_url(_url):
try:
if _url.startswith("data:"):
raise "Not a valid url."
parsed = urlparse(_url)
if not parsed.scheme:
raise "Not a valid url."
return _url
except Exception:
if _url.startswith("data:"):
logging.info("Parse url by base64 decoder.")
# https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
# e.g. f"data:image/jpeg;base64,{base64_image}"
@@ -93,6 +85,10 @@ def _ensure_url(_url):
f.write(data)
logging.info("Dump base64 data to %s", f.name)
return f.name
else:
if len(_url) > 2048:
raise Exception(f"Image url is too long, {len(_url)} > 2048.")
return _url

if not isinstance(content, str):
# TODO(codingl2k1): Optimize _ensure_url
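The reworked _ensure_url accepts an OpenAI-style base64 data URL directly, decoding it to a temporary file, and only length-checks plain URLs or file paths. A hedged sketch of how a caller could build such a content part; the file name and prompt text are illustrative:

import base64

# Encode a local image as a data URL, following the OpenAI vision format
# referenced in the comment above; "example.jpg" is a placeholder path.
with open("example.jpg", "rb") as image_file:
    b64_image = base64.b64encode(image_file.read()).decode()

content = [
    {"type": "text", "text": "What is shown in this image?"},
    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"}},
]
# Per the diff, the model side decodes the base64 payload, writes it to a
# temporary file, and hands that file path on to Qwen-VL.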
7 changes: 6 additions & 1 deletion xinference/model/multimodal/tests/test_multimodal.py
@@ -18,7 +18,10 @@


@pytest.mark.skip(reason="Cost too many resources.")
def test_restful_api_for_qwen_vl(setup):
@pytest.mark.parametrize(
"model_format, quantization", [("pytorch", None), ("gptq", "Int4")]
)
def test_restful_api_for_qwen_vl(setup, model_format, quantization):
endpoint, _ = setup
from ....client import Client

@@ -28,6 +31,8 @@ def test_restful_api_for_qwen_vl(setup):
model_uid="my_controlnet",
model_name="qwen-vl-chat",
model_type="multimodal",
model_format=model_format,
quantization=quantization,
)
model = client.get_model(model_uid)
prompt = [
15 changes: 10 additions & 5 deletions xinference/web/ui/src/scenes/launch_model/index.js
@@ -10,6 +10,7 @@ import Title from '../../components/Title'
import LaunchCustom from './launchCustom'
import LaunchEmbedding from './launchEmbedding'
import LaunchLLM from './launchLLM'
import LaunchMultimodal from './launchMultimodal'
import LaunchRerank from './launchRerank'

const LaunchModel = () => {
@@ -67,21 +68,25 @@ const LaunchModel = () => {
<Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
<TabList value={value} onChange={handleTabChange} aria-label="tabs">
<Tab label="Language Models" value="1" />
<Tab label="Embedding Models" value="2" />
<Tab label="Rerank Models" value="3" />
<Tab label="Custom Models" value="4" />
<Tab label="Multimodal Models" value="2" />
<Tab label="Embedding Models" value="3" />
<Tab label="Rerank Models" value="4" />
<Tab label="Custom Models" value="5" />
</TabList>
</Box>
<TabPanel value="1" sx={{ padding: 0 }}>
<LaunchLLM gpuAvailable={gpuAvailable} />
</TabPanel>
<TabPanel value="2" sx={{ padding: 0 }}>
<LaunchEmbedding />
<LaunchMultimodal gpuAvailable={gpuAvailable} />
</TabPanel>
<TabPanel value="3" sx={{ padding: 0 }}>
<LaunchRerank />
<LaunchEmbedding />
</TabPanel>
<TabPanel value="4" sx={{ padding: 0 }}>
<LaunchRerank />
</TabPanel>
<TabPanel value="5" sx={{ padding: 0 }}>
<LaunchCustom gpuAvailable={gpuAvailable} />
</TabPanel>
</TabContext>