Skip to content

Commit

Permalink
Dev/onevision (#148)
Browse files Browse the repository at this point in the history
* feat: Add multi-choice parsing and processing functions

This commit adds two new functions to the `egoschema/utils.py` file: `get_multi_choice_info` and `parse_multi_choice_response`. These functions are used to parse and process multi-choice responses in the Egoschema task.

The `get_multi_choice_info` function extracts information about the available choices from the input document, while the `parse_multi_choice_response` function parses the generated response and returns the predicted index.

These functions are essential for accurately processing multi-choice answers in the Egoschema task.

* feat: Add regex-based parsing for multi-choice predictions

This commit enhances the `perceptiontest_val_process_results_mc` function in the `utils.py` file. It introduces regex-based parsing to extract the predicted choice from the raw text prediction. If a match is found for A, B, C, or D, the matched letter is used as the prediction. Otherwise, an empty string is set as the prediction.

This improvement ensures accurate processing of multi-choice predictions in the perception test validation.

Co-authored-by: [Co-author Name] <[coauthor@example.com]>

* refactor: Improve accuracy calculation in perception test validation

This commit refactors the `perceptiontest_val_aggregate_accuracy` function in the `utils.py` file. Instead of comparing the string representations of `answer_id` and `pred_id`, it now directly checks the `correct` field in the `accuracy` dictionary. This change ensures more accurate calculation of the overall accuracy in the perception test validation.

Co-authored-by: [Co-author Name] <[coauthor@example.com]>

* Refactor accuracy calculation in perception test validation

* feat: Add SRT_API model to available models

This commit adds the SRT_API model to the list of available models in the `__init__.py` file. This model can now be used for evaluation and testing purposes.

Co-authored-by: [Co-author Name] <[coauthor@example.com]>

---------

Co-authored-by: [Co-author Name] <[coauthor@example.com]>
  • Loading branch information
Luodian and [Co-author Name] authored Jul 27, 2024
1 parent 9dd1672 commit b5ba906
Show file tree
Hide file tree
Showing 7 changed files with 351 additions and 73 deletions.
25 changes: 25 additions & 0 deletions docs/commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,28 @@ This mode supports a number of command-line arguments, the details of which can

* `--limit` : Accepts an integer, or a float between 0.0 and 1.0 . If passed, will limit the number of documents to evaluate to the first X documents (if an integer) per task or first X% of documents per task. Useful for debugging, especially on costly API models.

## Usage with SRT API

> install sglang
```bash
git clone https://github.com/EvolvingLMMs-Lab/sglang.git
cd sglang
pip install -e "python[srt]"
```

> run sglang backend service with the following command
```bash
python -m sglang.launch_server --model-path "\path\to\onevision" --tokenizer-path lmms-lab/llavanext-qwen-siglip-tokenizer --port=30000 --host=127.0.0.1 --tp-size=8 --chat-template=chatml-llava
```

You may need to install some dependencies for the above command to work (if you encounter some errors).

```bash
pip install httpx==0.23.3
pip install protobuf==3.20
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
```


1 change: 1 addition & 0 deletions lmms_eval/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"tinyllava": "TinyLlava",
"llava_onevision": "LlavaOneVision",
"llava_hf": "LlavaHf",
"srt_api": "SRT_API",
"longva": "LongVA",
"vila": "VILA",
"xcomposer2d5": "XComposer2D5",
Expand Down
189 changes: 189 additions & 0 deletions lmms_eval/models/srt_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
from accelerate import Accelerator, DistributedType
import base64
from io import BytesIO
from copy import deepcopy
from decord import VideoReader, cpu
import numpy as np
from openai import OpenAI
from PIL import Image
import os
import json
from typing import List, Tuple
from tqdm import tqdm
import time

from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model


@register_model("srt_api")
class SRT_API(lmms):
def __init__(
self,
api_key: str = "EMPTY",
model_version: str = "default",
modality: str = "video",
host: str = "127.0.0.1",
port: int = 30000,
max_frames_num: int = 10,
timeout: int = 120,
continual_mode: bool = False,
response_persistent_folder: str = None,
**kwargs,
) -> None:
super().__init__()
# Manually set a image token for GPT4V so that we can search for it
# and split the text and image
# Here we just use the same token as llava for convenient
self.model_version = model_version
self.modality = modality
self.max_frames_num = max_frames_num
self.image_token = "<image>"
self.timeout = timeout
self.continual_mode = continual_mode
if self.continual_mode:
if response_persistent_folder is None:
raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.")

os.makedirs(response_persistent_folder, exist_ok=True)
self.response_persistent_folder = response_persistent_folder
self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")

if os.path.exists(self.response_persistent_file):
with open(self.response_persistent_file, "r") as f:
self.response_cache = json.load(f)
self.cache_mode = "resume"
else:
self.response_cache = {}
self.cache_mode = "start"

accelerator = Accelerator()
self.client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
# assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
if accelerator.num_processes > 1:
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
else:
self.accelerator = accelerator
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes

self.device = self.accelerator.device

# Function to encode the image
def encode_image(self, image: Image):
output_buffer = BytesIO()
image.save(output_buffer, format="PNG")
byte_data = output_buffer.getvalue()
base64_str = base64.b64encode(byte_data).decode("utf-8")
return base64_str

# Function to encode the video
def encode_video(self, video_path, for_get_frames_num):
vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, for_get_frames_num, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).asnumpy()

base64_frames = []
for frame in frames:
img = Image.fromarray(frame)
output_buffer = BytesIO()
img.save(output_buffer, format="PNG")
byte_data = output_buffer.getvalue()
base64_str = base64.b64encode(byte_data).decode("utf-8")
base64_frames.append(base64_str)

return base64_frames

def flatten(self, input):
new_list = []
for i in input:
for j in i:
new_list.append(j)
return new_list

def generate_until(self, requests) -> List[str]:
res = []
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
if self.continual_mode is True and self.cache_mode == "resume":
doc_uuid = f"{task}___{split}___{doc_id}"
if doc_uuid in self.response_cache:
response_text = self.response_cache[doc_uuid]
if response_text:
res.append(response_text)
pbar.update(1)
continue

visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
visuals = self.flatten(visuals)
imgs = [] # multiple images or frames for video
for visual in visuals:
if self.modality == "image":
img = self.encode_image(visual)
imgs.append(img)
elif self.modality == "video":
frames = self.encode_video(visual, self.max_frames_num)
imgs.extend(frames)

messages = []
if self.image_token not in contexts: # single image format
content = []
for img in imgs:
content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})

content.append({"type": "text", "text": contexts})
messages.append({"role": "user", "content": content})
else: # interleaved format
contexts = contexts.split(self.image_token)
for idx, img in enumerate(imgs):
content = [
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}},
{"type": "text", "text": contexts[idx]},
]
messages.append({"role": "user", "content": content})
messages.append({"role": "user", "content": [{"type": "text", "text": contexts[-1]}]})

if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 1024

if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0

for attempt in range(5):
try:
response = self.client.chat.completions.create(model=self.model_version, messages=messages, temperature=gen_kwargs["temperature"], max_tokens=gen_kwargs["max_new_tokens"], timeout=self.timeout)
response_text = response.choices[0].message.content.strip()
break # If successful, break out of the loop

except Exception as e:
eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}.")
if attempt < 4:
time.sleep(NUM_SECONDS_TO_SLEEP)
else: # If this was the last attempt, log and return empty string
eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}.")
response_text = ""

res.append(response_text)
pbar.update(1)

if self.continual_mode is True: # Cache the response
doc_uuid = f"{task}___{split}___{doc_id}"
self.response_cache[doc_uuid] = response_text
with open(self.response_persistent_file, "w") as f:
json.dump(self.response_cache, f)

pbar.close()
return res

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
# TODO
assert False, "GPT4V not support"
127 changes: 96 additions & 31 deletions lmms_eval/tasks/egoschema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,41 +84,106 @@ def egoschema_process_results(doc, result):
return {"submission": {doc["video_idx"]: min_index}, "score": {"pred": min_index, "ground_truth": doc["answer"]}}


def get_multi_choice_info(doc):
all_choices = []
index2ans = {}
OPTIONS = ["A", "B", "C", "D", "E"]
for i in range(5):
# import pdb;pdb.set_trace()
index2ans[OPTIONS[i]] = doc["option"][i].strip()
all_choices.append(OPTIONS[i])

return index2ans, all_choices


def parse_multi_choice_response(response, all_choices, index2ans):
"""
Parse the prediction from the generated response.
Return the predicted index e.g., A, B, C, D.
https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L10
"""
for char in [",", ".", "!", "?", ";", ":", "'"]:
response = response.strip(char)
response = " " + response + " " # add space to avoid partial match

index_ans = True
ans_with_brack = False
ans_with_space = False
ans_with_dot = False
candidates = []
# import pdb; pdb.set_trace()
for choice in all_choices: # e.g., (A) (B) (C) (D)
if f"({choice})" in response:
candidates.append(f"({choice})")
ans_with_brack = True

# if len(candidates) == 0:
for choice in all_choices: # e.g., A B C D
if f"{choice} " in response:
candidates.append(f"{choice} ")
ans_with_space = True

# if len(candidates) == 0:
for choice in all_choices: # e.g., A. B. C. D.
if f"{choice}." in response:
candidates.append(f"{choice}.")
ans_with_dot = True

# if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
if len(candidates) == 0 and len(response.split()) > 5:
for index, ans in index2ans.items():
if ans.lower() in response.lower():
candidates.append(index)
index_ans = False # it's content ans.

if len(candidates) == 0: # still not get answer, randomly choose one.
# import pdb; pdb.set_trace()
pred_index = random.choice(all_choices)
elif len(candidates) > 1:
# candidates = list(set(candidates))
start_indexes = []
if index_ans:
# if ans_with_brack:
for can in candidates:
index = response.rfind(can)
start_indexes.append(index) # -1 will be ignored anyway
# start_indexes = [generated_response.index(f'({can})') for can in candidates]
# if ans_with_space:
# for can in candidates:
# index = response.rfind(f"{can} ")
# start_indexes.append(index)
# if ans_with_dot:
# for can in candidates:
# index = response.rfind(f"{can}.")
# start_indexes.append(index)
# if not ans_with_brack and not ans_with_space and not ans_with_dot:
# for can in candidates:
# index = response.rfind(f" {can} ")
# start_indexes.append(index)
else:
for can in candidates:
index = response.lower().rfind(index2ans[can].lower())
start_indexes.append(index)
# get the first one
pred_index = candidates[np.argmin(start_indexes)]
pred_index = pred_index.replace("(", "").replace(")", "").replace(".", "").strip()
else: # if only one candidate, use it.
pred_index = candidates[0]
pred_index = pred_index.replace("(", "").replace(")", "").replace(".", "").strip()

return pred_index, len(candidates) > 0


# Process result for mcq answer generation
def egoschema_process_results_generation(doc, result):
# import pdb;pdb.set_trace()
pred = result[0]

# Determine whether the video LLM output is correct, based on word matching rules
# Ensure each option string ends with a period
option_strs = [opt if opt.endswith(".") else opt + "." for opt in doc["option"]] # Complete option strings
option_sents = [opt.split(". ")[1] if ". " in opt else opt for opt in option_strs] # Option sentence
option_inds = [opt.split(". ")[0] if ". " in opt else opt for opt in option_strs] # Option letter, e.g., A, B, C, D, E

video_llm_pred = None
index = -1

# Check if the prediction matches any of the complete option strings
for idx, option_str in enumerate(option_strs):
if pred == option_str:
video_llm_pred = option_str
index = idx
break

# Check if the prediction matches any of the option sentences
if not video_llm_pred:
for idx, option_sent in enumerate(option_sents):
if pred == option_sent:
video_llm_pred = option_sent
index = idx
break

# Check if the prediction matches any of the option letters
if not video_llm_pred:
for idx, option_ind in enumerate(option_inds):
if pred == option_ind or pred == option_ind.replace(".", ""):
video_llm_pred = option_ind
index = idx
break
index2ans, all_choices = get_multi_choice_info(doc)
parsed_pred, matched_tag = parse_multi_choice_response(pred, all_choices, index2ans)

pred_to_index = {"A": 0, "B": 1, "C": 2, "D": 3, "E": 4}
index = pred_to_index.get(parsed_pred, -1) # Default to -1 if the prediction is not found

return {"submission": {doc["video_idx"]: index}, "score": {"pred": index, "ground_truth": doc["answer"]}}

Expand Down
3 changes: 3 additions & 0 deletions lmms_eval/tasks/mmstar/mmstar.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ metric_list:
- metric: math
aggregation: !function utils.mmstar_aggregate_results
higher_is_better: true
- metric: average
aggregation: !function utils.mmstar_aggregate_results
higher_is_better: true
model_specific_prompt_kwargs:
default:
pre_prompt: ""
Expand Down
Loading

0 comments on commit b5ba906

Please sign in to comment.