diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index e2ec36211b86..a0a71f18ed94 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -21,7 +21,7 @@
 
 
 # Aria
-def run_aria(question: str, modality: str):
+def run_aria(questions: list[str], modality: str):
     assert modality == "image"
 
     model_name = "rhymes-ai/Aria"
@@ -32,41 +32,42 @@ def run_aria(question: str, modality: str):
               dtype="bfloat16",
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 
-    prompt = (f"<|im_start|>user\n<|img|>{question}"
-              "<|im_end|>\n<|im_start|>assistant\n")
+    prompts = [(f"<|im_start|>user\n<|img|>{question}"
+                "<|im_end|>\n<|im_start|>assistant\n")
+               for question in questions]
 
     stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids
 
 
 # BLIP-2
-def run_blip2(question: str, modality: str):
+def run_blip2(questions: list[str], modality: str):
     assert modality == "image"
 
     # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
-    prompt = f"Question: {question} Answer:"
+    prompts = [f"Question: {question} Answer:" for question in questions]
     llm = LLM(model="Salesforce/blip2-opt-2.7b",
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids
 
 
 # Chameleon
-def run_chameleon(question: str, modality: str):
+def run_chameleon(questions: list[str], modality: str):
     assert modality == "image"
 
-    prompt = f"{question}<image>"
+    prompts = [f"{question}<image>" for question in questions]
     llm = LLM(model="facebook/chameleon-7b",
               max_model_len=4096,
               max_num_seqs=2,
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids
 
 
 # Deepseek-VL2
-def run_deepseek_vl2(question: str, modality: str):
+def run_deepseek_vl2(questions: list[str], modality: str):
     assert modality == "image"
 
     model_name = "deepseek-ai/deepseek-vl2-tiny"
@@ -77,9 +78,12 @@ def run_deepseek_vl2(question: str, modality: str):
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
               hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]})
 
-    prompt = f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
+    prompts = [
+        f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
+        for question in questions
+    ]
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids
 
 
 # Florence2
@@ -99,20 +103,20 @@ def run_florence2(question: str, modality: str):
 
 
 # Fuyu
-def run_fuyu(question: str, modality: str):
+def run_fuyu(questions: list[str], modality: str):
     assert modality == "image"
 
-    prompt = f"{question}\n"
+    prompts = [f"{question}\n" for question in questions]
     llm = LLM(model="adept/fuyu-8b",
               max_model_len=2048,
               max_num_seqs=2,
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids
 
 
 # GLM-4v
-def run_glm4v(question: str, modality: str):
+def run_glm4v(questions: list[str], modality: str):
     assert modality == "image"
 
     model_name = "THUDM/glm-4v-9b"
@@ -124,15 +128,17 @@ def run_glm4v(question: str, modality: str):
               hf_overrides={"architectures": ["GLM4VForCausalLM"]},
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 
prompt = f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ - {question}<|assistant|>" + prompts = [ + f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ + {question}<|assistant|>" for question in questions + ] stop_token_ids = [151329, 151336, 151338] - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # H2OVL-Mississippi -def run_h2ovl(question: str, modality: str): +def run_h2ovl(questions: list[str], modality: str): assert modality == "image" model_name = "h2oai/h2ovl-mississippi-800m" @@ -146,19 +152,24 @@ def run_h2ovl(question: str, modality: str): tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - messages = [{'role': 'user', 'content': f"\n{question}"}] - prompt = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + prompts = [ + tokenizer.apply_chat_template([{ + 'role': 'user', + 'content': f"\n{question}" + }], + tokenize=False, + add_generation_prompt=True) + for question in questions + ] # Stop tokens for H2OVL-Mississippi # https://huggingface.co/h2oai/h2ovl-mississippi-800m stop_token_ids = [tokenizer.eos_token_id] - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # Idefics3-8B-Llama3 -def run_idefics3(question: str, modality: str): +def run_idefics3(questions: list[str], modality: str): assert modality == "image" model_name = "HuggingFaceM4/Idefics3-8B-Llama3" @@ -176,15 +187,15 @@ def run_idefics3(question: str, modality: str): }, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, ) - prompt = ( + prompts = [( f"<|begin_of_text|>User:{question}\nAssistant:" - ) + ) for question in questions] stop_token_ids = None - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # InternVL -def run_internvl(question: str, modality: str): +def run_internvl(questions: list[str], modality: str): assert modality == "image" model_name = "OpenGVLab/InternVL2-2B" @@ -198,10 +209,15 @@ def run_internvl(question: str, modality: str): tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - messages = [{'role': 'user', 'content': f"\n{question}"}] - prompt = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + prompts = [ + tokenizer.apply_chat_template([{ + 'role': 'user', + 'content': f"\n{question}" + }], + tokenize=False, + add_generation_prompt=True) + for question in questions + ] # Stop tokens for InternVL # models variants may have different stop tokens @@ -209,71 +225,82 @@ def run_internvl(question: str, modality: str): # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # LLaVA-1.5 -def run_llava(question: str, modality: str): +def run_llava(questions: list[str], modality: str): assert modality == "image" - prompt = f"USER: \n{question}\nASSISTANT:" + prompts = [ + f"USER: \n{question}\nASSISTANT:" for question in questions + ] llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) stop_token_ids = None - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # LLaVA-1.6/LLaVA-NeXT -def run_llava_next(question: str, modality: str): +def run_llava_next(questions: list[str], modality: str): assert modality == "image" - prompt = 
f"[INST] \n{question} [/INST]" + prompts = [f"[INST] \n{question} [/INST]" for question in questions] llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) stop_token_ids = None - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # LlaVA-NeXT-Video # Currently only support for video input -def run_llava_next_video(question: str, modality: str): +def run_llava_next_video(questions: list[str], modality: str): assert modality == "video" - prompt = f"USER: