Skip to content

Commit

Permalink
Refactor code into VideoLLaVAApp class
Browse files Browse the repository at this point in the history
  • Loading branch information
lcolok committed Nov 22, 2023
1 parent b2662c4 commit 0b9ed83
Showing 1 changed file with 59 additions and 55 deletions.
114 changes: 59 additions & 55 deletions llava/serve/gradio_web_server.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import shutil
import subprocess
import argparse

import torch
import gradio as gr
from fastapi import FastAPI
import tempfile
import os
from PIL import Image
import tempfile
from decord import VideoReader, cpu
from transformers import TextStreamer
from fastapi import FastAPI
import gradio as gr
import torch

from llava.constants import DEFAULT_X_TOKEN, X_TOKEN_INDEX
from llava.conversation import conv_templates, SeparatorStyle, Conversation
Expand All @@ -22,21 +19,42 @@
)


if __name__ == "__main__":
class VideoLLaVAHandler:
def __init__(self, args):
    """Keep the parsed options, seed default session values, and load the model.

    Args:
        args: Options object. This method reads ``conv_mode`` and
            ``use_full_precision`` from it; ``setup_handler`` reads the
            model-loading fields.
    """
    self.args = args
    self.conv_mode = args.conv_mode

    # Two independent copies of the same conversation template.
    template = conv_templates[self.conv_mode]
    self.state = template.copy()
    self.state_ = template.copy()

    # Same two-list layout that generate() fills: tensors in [0],
    # their "image"/"video" tags in [1].
    self.images_tensor = [[], []]
    self.first_run = True

    # Half precision unless full precision was requested.
    if self.args.use_full_precision:
        self.dtype = torch.float32
    else:
        self.dtype = torch.float16

    # Placeholder until setup_handler() installs the real Chat object.
    self.handler = None
    self.setup_handler()

def setup_handler(self):
    """Build the ``Chat`` backend from the stored options and keep it on ``self.handler``."""
    opts = self.args
    self.handler = Chat(
        opts.model_path,
        conv_mode=self.conv_mode,
        load_8bit=opts.load_8bit,
        load_4bit=opts.load_4bit,
        device=opts.device,
    )

def save_image_to_local(image):
def save_image_to_local(self, image):
    """Re-save an image under ``temp/`` as a uniquely named ``.jpg`` file.

    Args:
        image: Anything ``PIL.Image.open`` accepts (the callers in this
            file pass a file path).

    Returns:
        The relative path of the written file (``temp/<random>.jpg``).
    """
    import uuid  # local import: only needed to build the unique name

    # Create the output directory on demand instead of relying on the
    # caller having made it beforehand.
    os.makedirs("temp", exist_ok=True)

    # uuid4 replaces the private ``tempfile._get_candidate_names`` helper,
    # which is not part of tempfile's public API.
    filename = os.path.join("temp", uuid.uuid4().hex + ".jpg")

    # The context manager closes the source file once the copy is written.
    with Image.open(image) as source:
        source.save(filename)
    return filename

def save_video_to_local(video_path):
def save_video_to_local(self, video_path):
    """Copy a video file into ``temp/`` under a unique ``.mp4`` name.

    Args:
        video_path: Path of the existing video file to copy.

    Returns:
        The relative path of the copy (``temp/<random>.mp4``).
    """
    import uuid  # local import: only needed to build the unique name

    # Create the output directory on demand; previously a missing ``temp``
    # directory surfaced as FileNotFoundError from the copy.
    os.makedirs("temp", exist_ok=True)

    # uuid4 replaces the private ``tempfile._get_candidate_names`` helper,
    # which is not part of tempfile's public API.
    filename = os.path.join("temp", uuid.uuid4().hex + ".mp4")
    shutil.copyfile(video_path, filename)
    return filename

def generate(image1, video, textbox_in, first_run, state, state_, images_tensor):
def generate(
self, image1, video, textbox_in, first_run, state, state_, images_tensor
):
flag = 1
if not textbox_in:
if len(state_.messages) > 0:
Expand All @@ -51,43 +69,43 @@ def generate(image1, video, textbox_in, first_run, state, state_, images_tensor)
# assert not (os.path.exists(image1) and os.path.exists(video))

if type(state) is not Conversation:
state = conv_templates[conv_mode].copy()
state_ = conv_templates[conv_mode].copy()
state = conv_templates[self.conv_mode].copy()
state_ = conv_templates[self.conv_mode].copy()
images_tensor = [[], []]

first_run = False if len(state.messages) > 0 else True

text_en_in = textbox_in.replace("picture", "image")

# images_tensor = [[], []]
image_processor = handler.image_processor
image_processor = self.handler.image_processor
if os.path.exists(image1) and not os.path.exists(video):
tensor = image_processor.preprocess(image1, return_tensors="pt")[
"pixel_values"
][0]
# print(tensor.shape)
tensor = tensor.to(handler.model.device, dtype=dtype)
tensor = tensor.to(self.handler.model.device, dtype=self.dtype)
images_tensor[0] = images_tensor[0] + [tensor]
images_tensor[1] = images_tensor[1] + ["image"]
video_processor = handler.video_processor
video_processor = self.handler.video_processor
if not os.path.exists(image1) and os.path.exists(video):
tensor = video_processor(video, return_tensors="pt")["pixel_values"][0]
# print(tensor.shape)
tensor = tensor.to(handler.model.device, dtype=dtype)
tensor = tensor.to(self.handler.model.device, dtype=self.dtype)
images_tensor[0] = images_tensor[0] + [tensor]
images_tensor[1] = images_tensor[1] + ["video"]
if os.path.exists(image1) and os.path.exists(video):
tensor = video_processor(video, return_tensors="pt")["pixel_values"][0]
# print(tensor.shape)
tensor = tensor.to(handler.model.device, dtype=dtype)
tensor = tensor.to(self.handler.model.device, dtype=self.dtype)
images_tensor[0] = images_tensor[0] + [tensor]
images_tensor[1] = images_tensor[1] + ["video"]

tensor = image_processor.preprocess(image1, return_tensors="pt")[
"pixel_values"
][0]
# print(tensor.shape)
tensor = tensor.to(handler.model.device, dtype=dtype)
tensor = tensor.to(self.handler.model.device, dtype=self.dtype)
images_tensor[0] = images_tensor[0] + [tensor]
images_tensor[1] = images_tensor[1] + ["image"]

Expand All @@ -104,7 +122,7 @@ def generate(image1, video, textbox_in, first_run, state, state_, images_tensor)
+ DEFAULT_X_TOKEN["IMAGE"]
)

text_en_out, state_ = handler.generate(
text_en_out, state_ = self.handler.generate(
images_tensor, text_en_in, first_run=first_run, state=state_
)
state_.messages[-1] = (state_.roles[1], text_en_out)
Expand All @@ -114,10 +132,10 @@ def generate(image1, video, textbox_in, first_run, state, state_, images_tensor)

show_images = ""
if os.path.exists(image1):
filename = save_image_to_local(image1)
filename = self.save_image_to_local(image1)
show_images += f'<img src="./file={filename}" style="display: inline-block;width: 250px;max-height: 400px;">'
if os.path.exists(video):
filename = save_video_to_local(video)
filename = self.save_video_to_local(video)
show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'

if flag:
Expand All @@ -137,16 +155,16 @@ def generate(image1, video, textbox_in, first_run, state, state_, images_tensor)
gr.update(value=video if os.path.exists(video) else None, interactive=True),
)

def regenerate(state, state_):
def regenerate(self, state, state_):
    """Drop the most recent message from both conversations before a re-run.

    Args:
        state: Conversation object exposing a ``messages`` list and a
            ``to_gradio_chatbot()`` method.
        state_: Second conversation object exposing a ``messages`` list.

    Returns:
        A 4-tuple ``(state, state_, chatbot, first_run)``, where ``chatbot``
        is ``state.to_gradio_chatbot()`` and ``first_run`` is ``True`` only
        when ``state`` has no messages left.
    """
    # Guard the pops: with an empty history the unguarded pop(-1) raised
    # IndexError instead of leaving the conversation untouched.
    if state.messages:
        state.messages.pop()
    if state_.messages:
        state_.messages.pop()

    first_run = not state.messages
    return state, state_, state.to_gradio_chatbot(), first_run

def clear_history(state, state_):
state = conv_templates[conv_mode].copy()
state_ = conv_templates[conv_mode].copy()
def clear_history(self, state, state_):
state = conv_templates[self.conv_mode].copy()
state_ = conv_templates[self.conv_mode].copy()
return (
gr.update(value=None, interactive=True),
gr.update(value=None, interactive=True),
Expand All @@ -158,7 +176,7 @@ def clear_history(state, state_):
[[], []],
)

def create_gradio_ui(handler, args):
def create_gradio_ui(self):
textbox = gr.Textbox(
show_label=False, placeholder="Enter text and press ENTER", container=False
)
Expand Down Expand Up @@ -212,9 +230,13 @@ def create_gradio_ui(handler, args):
flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
# stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
regenerate_btn = gr.Button(
value="🔄 Regenerate", interactive=True)
value="🔄 Regenerate",
interactive=True,
)
clear_btn = gr.Button(
value="🗑️ Clear history", interactive=True)
value="🗑️ Clear history",
interactive=True,
)

with gr.Row():
gr.Examples(
Expand Down Expand Up @@ -262,7 +284,7 @@ def create_gradio_ui(handler, args):
gr.Markdown(learn_more_markdown)

submit_btn.click(
generate,
self.generate,
[
image1,
video,
Expand All @@ -285,9 +307,9 @@ def create_gradio_ui(handler, args):
)

regenerate_btn.click(
regenerate, [state, state_], [state, state_, chatbot, first_run]
self.regenerate, [state, state_], [state, state_, chatbot, first_run]
).then(
generate,
self.generate,
[
image1,
video,
Expand All @@ -310,7 +332,7 @@ def create_gradio_ui(handler, args):
)

clear_btn.click(
clear_history,
self.clear_history,
[state, state_],
[
image1,
Expand All @@ -326,6 +348,8 @@ def create_gradio_ui(handler, args):

return demo


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run Gradio app with custom settings")
parser.add_argument(
"--server_name", type=str, default="0.0.0.0", help="Server name"
Expand Down Expand Up @@ -376,34 +400,14 @@ def create_gradio_ui(handler, args):

args = parser.parse_args()

conv_mode = args.conv_mode
model_path = args.model_path
device = args.device
load_8bit = args.load_8bit
load_4bit = args.load_4bit
dtype = torch.float32 if args.use_full_precision else torch.float16

handler = Chat(
model_path,
conv_mode=conv_mode,
load_8bit=load_8bit,
load_4bit=load_8bit,
device=device,
)
# handler.model.to(dtype=dtype)

if not os.path.exists("temp"):
os.makedirs("temp")

video_llava_handler = VideoLLaVAHandler(args)
app = FastAPI()

demo = create_gradio_ui(handler, args)

# app = gr.mount_gradio_app(app, demo, path="/")
demo = video_llava_handler.create_gradio_ui()
demo.launch(
server_name=args.server_name,
server_port=args.server_port,
share=args.share,
)

# uvicorn llava.serve.gradio_web_server:app

0 comments on commit 0b9ed83

Please sign in to comment.