
Gradio Web Server for Multimodal Models #2960

Merged · 16 commits · Feb 10, 2024
44 changes: 44 additions & 0 deletions fastchat/conversation.py
@@ -47,6 +47,7 @@ class Conversation:
    # The names of two roles
    roles: Tuple[str] = ("USER", "ASSISTANT")
    # All messages. Each item is (role, message).
    # Each message is either a string or a tuple of (string, List[image_url]).
    messages: List[List[str]] = ()
    # The number of few shot examples
    offset: int = 0
@@ -289,11 +290,54 @@ def update_last_message(self, message: str):
        """
        self.messages[-1][1] = message

    def convert_image_to_base64(self, image):
        """Given an image, return the base64 encoded image string."""
        import base64
        from io import BytesIO
        from PIL import Image
        import requests

        # Load image if it has not been loaded in yet
        if type(image) == str:
            if image.startswith("http://") or image.startswith("https://"):
                response = requests.get(image)
                image = Image.open(BytesIO(response.content)).convert("RGB")
            elif "base64" in image:
                # OpenAI format is: data:image/jpeg;base64,{base64_encoded_image_str}
                return image.split(",")[1]
            else:
                image = Image.open(image).convert("RGB")

        max_hw, min_hw = max(image.size), min(image.size)
        aspect_ratio = max_hw / min_hw
        max_len, min_len = 800, 400
        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
        longest_edge = int(shortest_edge * aspect_ratio)
        W, H = image.size
        if longest_edge != max(image.size):
            if H > W:
                H, W = longest_edge, shortest_edge
            else:
                H, W = shortest_edge, longest_edge
            image = image.resize((W, H))

        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_b64_str = base64.b64encode(buffered.getvalue()).decode()

        return img_b64_str

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image = msg
                    img_b64_str = image[0]  # Only one image on gradio at one time
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace("<image>\n", "").strip()

                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
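For orientation, here is a minimal sketch of how the two new Conversation helpers compose; the template name and image path are illustrative placeholders. Note that the resize logic caps the shorter edge at 400 px and the longer edge at 800 px: a 1600×1200 input has aspect ratio 4/3, so shortest_edge = min(800 / (4/3), 400, 1200) = 400 and longest_edge = 533.

```python
from fastchat.conversation import get_conv_template

# "vicuna_v1.1" and the image path are placeholders for illustration.
conv = get_conv_template("vicuna_v1.1")
img_b64_str = conv.convert_image_to_base64("example_images/city.jpeg")

# Attach the image to the first user turn, leave room for the reply,
# then render in the [[user_msg, assistant_msg], ...] format gradio expects.
conv.append_message(conv.roles[0], ("<image>\nExplain this photo.", [img_b64_str]))
conv.append_message(conv.roles[1], None)
rows = conv.to_gradio_chatbot()
```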
3 changes: 3 additions & 0 deletions fastchat/serve/base_model_worker.py
@@ -34,6 +34,7 @@ def __init__(
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,
        multimodal: bool = False,
    ):
        global logger, worker

@@ -46,6 +47,7 @@ def __init__(
        self.limit_worker_concurrency = limit_worker_concurrency
        self.conv = self.make_conv_template(conv_template, model_path)
        self.conv.sep_style = int(self.conv.sep_style)
        self.multimodal = multimodal
        self.tokenizer = None
        self.context_len = None
        self.call_ct = 0
@@ -92,6 +94,7 @@ def register_to_controller(self):
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status(),
            "multimodal": self.multimodal,
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200
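With this change, the payload a worker POSTs to the controller at startup gains one field. A hand-rolled sketch of the equivalent request, with placeholder addresses and model names:

```python
import requests

controller_addr = "http://localhost:21001"  # placeholder controller address
payload = {
    "worker_name": "http://localhost:21002",  # placeholder worker address
    "check_heart_beat": True,
    "worker_status": {
        "model_names": ["llava-v1.5-13b"],  # placeholder model name
        "speed": 1,
        "queue_length": 0,
    },
    "multimodal": True,  # the new flag; the controller defaults it to False
}
r = requests.post(controller_addr + "/register_worker", json=payload)
assert r.status_code == 200
```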
47 changes: 44 additions & 3 deletions fastchat/serve/controller.py
@@ -52,6 +52,7 @@ class WorkerInfo:
    queue_length: int
    check_heart_beat: bool
    last_heart_beat: str
    multimodal: bool


def heart_beat_controller(controller):
@@ -72,7 +73,11 @@ def __init__(self, dispatch_method: str):
        self.heart_beat_thread.start()

    def register_worker(
        self,
        worker_name: str,
        check_heart_beat: bool,
        worker_status: dict,
        multimodal: bool,
    ):
        if worker_name not in self.worker_info:
            logger.info(f"Register a new worker: {worker_name}")
@@ -88,6 +93,7 @@ def register_worker(
                worker_status["model_names"],
                worker_status["speed"],
                worker_status["queue_length"],
                check_heart_beat,
                time.time(),
                multimodal,
            )
@@ -116,7 +122,9 @@ def refresh_all_workers(self):
        self.worker_info = {}

        for w_name, w_info in old_info.items():
            if not self.register_worker(
                w_name, w_info.check_heart_beat, None, w_info.multimodal
            ):
                logger.info(f"Remove stale worker: {w_name}")

    def list_models(self):
@@ -127,6 +135,24 @@ def list_models(self):

        return list(model_names)

    def list_multimodal_models(self):
        model_names = set()

        for w_name, w_info in self.worker_info.items():
            if w_info.multimodal:
                model_names.update(w_info.model_names)

        return list(model_names)

    def list_language_models(self):
        model_names = set()

        for w_name, w_info in self.worker_info.items():
            if not w_info.multimodal:
                model_names.update(w_info.model_names)

        return list(model_names)

    def get_worker_address(self, model_name: str):
        if self.dispatch_method == DispatchMethod.LOTTERY:
            worker_names = []
@@ -263,7 +289,10 @@ def worker_api_generate_stream(self, params):
async def register_worker(request: Request):
    data = await request.json()
    controller.register_worker(
        data["worker_name"],
        data["check_heart_beat"],
        data.get("worker_status", None),
        data.get("multimodal", False),
    )


@@ -278,6 +307,18 @@ async def list_models():
    return {"models": models}


@app.post("/list_multimodal_models")
async def list_multimodal_models():
    models = controller.list_multimodal_models()
    return {"models": models}


@app.post("/list_language_models")
async def list_language_models():
    models = controller.list_language_models()
    return {"models": models}


@app.post("/get_worker_address")
async def get_worker_address(request: Request):
    data = await request.json()
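Clients can then filter models by modality the same way they already call /list_models. A quick sketch, assuming the controller runs at its usual default address:

```python
import requests

controller_addr = "http://localhost:21001"  # adjust to your deployment

vision = requests.post(controller_addr + "/list_multimodal_models").json()["models"]
text = requests.post(controller_addr + "/list_language_models").json()["models"]
print("vision-language models:", vision)
print("language-only models:", text)
```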
Binary file added fastchat/serve/example_images/city.jpeg
Binary file added fastchat/serve/example_images/fridge.jpeg
187 changes: 187 additions & 0 deletions fastchat/serve/gradio_block_arena_vision.py
@@ -0,0 +1,187 @@
"""
The gradio demo server for chatting with a single model.
"""

import os

import gradio as gr

from fastchat.serve.gradio_web_server import (
    upvote_last_response,
    downvote_last_response,
    flag_last_response,
    get_model_description_md,
    acknowledgment_md,
    bot_response,
    add_text,
    clear_history,
    regenerate,
)
from fastchat.utils import (
    build_logger,
)

logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
enable_moderation = False


def build_single_vision_language_model_ui(models, add_promotion_links=False):
    promotion = (
        """
- | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
- Introducing Llama 2: The Next Generation Open Source Large Language Model. [[Website]](https://ai.meta.com/llama/)
- Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog]](https://lmsys.org/blog/2023-03-30-vicuna/)
"""
        if add_promotion_links
        else ""
    )

    notice_markdown = f"""
# 🏔️ Chat with Open Large Vision-Language Models
{promotion}
### Choose a model to chat with
"""

    state = gr.State()

    with gr.Box():
        with gr.Row(elem_id="model_selector_row"):
            model_selector = gr.Dropdown(
                choices=models,
                value=models[0] if len(models) > 0 else "",
                interactive=True,
                show_label=False,
                container=False,
            )

        with gr.Accordion(
            "🔍 Expand to see 20+ model descriptions",
            open=False,
            elem_id="model_description_accordion",
        ):
            model_description_md = get_model_description_md(models)
            gr.Markdown(model_description_md, elem_id="model_description_markdown")

    with gr.Row():
        with gr.Column(scale=3):
            textbox = gr.Textbox(
                show_label=False,
                placeholder="Enter your prompt here and press ENTER",
                container=False,
                render=False,
                elem_id="input_box",
            )
            imagebox = gr.Image(type="pil")

            cur_dir = os.path.dirname(os.path.abspath(__file__))

            with gr.Accordion("Parameters", open=False) as parameter_row:
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.2,
                    step=0.1,
                    interactive=True,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    interactive=True,
                    label="Top P",
                )
                max_output_tokens = gr.Slider(
                    minimum=0,
                    maximum=1024,
                    value=512,
                    step=64,
                    interactive=True,
                    label="Max output tokens",
                )

            gr.Examples(
                examples=[
                    [
                        f"{cur_dir}/example_images/city.jpeg",
                        "Explain this photo.",
                    ],
                    [
                        f"{cur_dir}/example_images/fridge.jpeg",
                        "What is in this fridge?",
                    ],
                ],
                inputs=[imagebox, textbox],
            )

        with gr.Column(scale=8):
            chatbot = gr.Chatbot(
                elem_id="chatbot", label="Scroll down and start chatting", height=550
            )

            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    send_btn = gr.Button(value="Send", variant="primary")
            with gr.Row(elem_id="buttons"):
                upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

    if add_promotion_links:
        gr.Markdown(acknowledgment_md)

    # Register listeners
    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
    upvote_btn.click(
        upvote_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    downvote_btn.click(
        downvote_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    flag_btn.click(
        flag_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    regenerate_btn.click(
        regenerate, state, [state, chatbot, textbox, imagebox] + btn_list
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)

    model_selector.change(
        clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
    )

    textbox.submit(
        add_text,
        [state, model_selector, textbox, imagebox],
        [state, chatbot, textbox, imagebox] + btn_list,
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    send_btn.click(
        add_text,
        [state, model_selector, textbox, imagebox],
        [state, chatbot, textbox, imagebox] + btn_list,
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    return [state, model_selector]
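To exercise the new block in isolation, something like the following sketch should work; the model list and launch arguments are placeholders, and in the PR the builder is presumably wired into the multi-tab gradio server rather than launched standalone:

```python
import gradio as gr

from fastchat.serve.gradio_block_arena_vision import (
    build_single_vision_language_model_ui,
)

models = ["llava-v1.5-13b"]  # placeholder; normally fetched from the controller

with gr.Blocks(title="Vision-Language Demo") as demo:
    state, model_selector = build_single_vision_language_model_ui(
        models, add_promotion_links=True
    )

demo.queue().launch(server_name="0.0.0.0", server_port=7860)
```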