-
Notifications
You must be signed in to change notification settings - Fork 192
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: Add multi-choice parsing and processing functions This commit adds two new functions to the `egoschema/utils.py` file: `get_multi_choice_info` and `parse_multi_choice_response`. These functions are used to parse and process multi-choice responses in the Egoschema task. The `get_multi_choice_info` function extracts information about the available choices from the input document, while the `parse_multi_choice_response` function parses the generated response and returns the predicted index. These functions are essential for accurately processing multi-choice answers in the Egoschema task. * feat: Add regex-based parsing for multi-choice predictions This commit enhances the `perceptiontest_val_process_results_mc` function in the `utils.py` file. It introduces regex-based parsing to extract the predicted choice from the raw text prediction. If a match is found for A, B, C, or D, the matched letter is used as the prediction. Otherwise, an empty string is set as the prediction. This improvement ensures accurate processing of multi-choice predictions in the perception test validation. Co-authored-by: [Co-author Name] <[coauthor@example.com]> * refactor: Improve accuracy calculation in perception test validation This commit refactors the `perceptiontest_val_aggregate_accuracy` function in the `utils.py` file. Instead of comparing the string representations of `answer_id` and `pred_id`, it now directly checks the `correct` field in the `accuracy` dictionary. This change ensures more accurate calculation of the overall accuracy in the perception test validation. Co-authored-by: [Co-author Name] <[coauthor@example.com]> * Refactor accuracy calculation in perception test validation * feat: Add SRT_API model to available models This commit adds the SRT_API model to the list of available models in the `__init__.py` file. This model can now be used for evaluation and testing purposes. Co-authored-by: [Co-author Name] <[coauthor@example.com]> --------- Co-authored-by: [Co-author Name] <[coauthor@example.com]>
- Loading branch information
Showing
7 changed files
with
351 additions
and
73 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
from accelerate import Accelerator, DistributedType | ||
import base64 | ||
from io import BytesIO | ||
from copy import deepcopy | ||
from decord import VideoReader, cpu | ||
import numpy as np | ||
from openai import OpenAI | ||
from PIL import Image | ||
import os | ||
import json | ||
from typing import List, Tuple | ||
from tqdm import tqdm | ||
import time | ||
|
||
from lmms_eval.api.instance import Instance | ||
from lmms_eval.api.model import lmms | ||
from lmms_eval.api.registry import register_model | ||
|
||
|
||
@register_model("srt_api") | ||
class SRT_API(lmms): | ||
def __init__( | ||
self, | ||
api_key: str = "EMPTY", | ||
model_version: str = "default", | ||
modality: str = "video", | ||
host: str = "127.0.0.1", | ||
port: int = 30000, | ||
max_frames_num: int = 10, | ||
timeout: int = 120, | ||
continual_mode: bool = False, | ||
response_persistent_folder: str = None, | ||
**kwargs, | ||
) -> None: | ||
super().__init__() | ||
# Manually set a image token for GPT4V so that we can search for it | ||
# and split the text and image | ||
# Here we just use the same token as llava for convenient | ||
self.model_version = model_version | ||
self.modality = modality | ||
self.max_frames_num = max_frames_num | ||
self.image_token = "<image>" | ||
self.timeout = timeout | ||
self.continual_mode = continual_mode | ||
if self.continual_mode: | ||
if response_persistent_folder is None: | ||
raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.") | ||
|
||
os.makedirs(response_persistent_folder, exist_ok=True) | ||
self.response_persistent_folder = response_persistent_folder | ||
self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json") | ||
|
||
if os.path.exists(self.response_persistent_file): | ||
with open(self.response_persistent_file, "r") as f: | ||
self.response_cache = json.load(f) | ||
self.cache_mode = "resume" | ||
else: | ||
self.response_cache = {} | ||
self.cache_mode = "start" | ||
|
||
accelerator = Accelerator() | ||
self.client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1") | ||
# assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue." | ||
if accelerator.num_processes > 1: | ||
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." | ||
self.accelerator = accelerator | ||
if self.accelerator.is_local_main_process: | ||
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") | ||
self._rank = self.accelerator.local_process_index | ||
self._world_size = self.accelerator.num_processes | ||
else: | ||
self.accelerator = accelerator | ||
self._rank = self.accelerator.local_process_index | ||
self._world_size = self.accelerator.num_processes | ||
|
||
self.device = self.accelerator.device | ||
|
||
# Function to encode the image | ||
def encode_image(self, image: Image): | ||
output_buffer = BytesIO() | ||
image.save(output_buffer, format="PNG") | ||
byte_data = output_buffer.getvalue() | ||
base64_str = base64.b64encode(byte_data).decode("utf-8") | ||
return base64_str | ||
|
||
# Function to encode the video | ||
def encode_video(self, video_path, for_get_frames_num): | ||
vr = VideoReader(video_path, ctx=cpu(0)) | ||
total_frame_num = len(vr) | ||
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, for_get_frames_num, dtype=int) | ||
frame_idx = uniform_sampled_frames.tolist() | ||
frames = vr.get_batch(frame_idx).asnumpy() | ||
|
||
base64_frames = [] | ||
for frame in frames: | ||
img = Image.fromarray(frame) | ||
output_buffer = BytesIO() | ||
img.save(output_buffer, format="PNG") | ||
byte_data = output_buffer.getvalue() | ||
base64_str = base64.b64encode(byte_data).decode("utf-8") | ||
base64_frames.append(base64_str) | ||
|
||
return base64_frames | ||
|
||
def flatten(self, input): | ||
new_list = [] | ||
for i in input: | ||
for j in i: | ||
new_list.append(j) | ||
return new_list | ||
|
||
def generate_until(self, requests) -> List[str]: | ||
res = [] | ||
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") | ||
|
||
for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]: | ||
if self.continual_mode is True and self.cache_mode == "resume": | ||
doc_uuid = f"{task}___{split}___{doc_id}" | ||
if doc_uuid in self.response_cache: | ||
response_text = self.response_cache[doc_uuid] | ||
if response_text: | ||
res.append(response_text) | ||
pbar.update(1) | ||
continue | ||
|
||
visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] | ||
visuals = self.flatten(visuals) | ||
imgs = [] # multiple images or frames for video | ||
for visual in visuals: | ||
if self.modality == "image": | ||
img = self.encode_image(visual) | ||
imgs.append(img) | ||
elif self.modality == "video": | ||
frames = self.encode_video(visual, self.max_frames_num) | ||
imgs.extend(frames) | ||
|
||
messages = [] | ||
if self.image_token not in contexts: # single image format | ||
content = [] | ||
for img in imgs: | ||
content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) | ||
|
||
content.append({"type": "text", "text": contexts}) | ||
messages.append({"role": "user", "content": content}) | ||
else: # interleaved format | ||
contexts = contexts.split(self.image_token) | ||
for idx, img in enumerate(imgs): | ||
content = [ | ||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}, | ||
{"type": "text", "text": contexts[idx]}, | ||
] | ||
messages.append({"role": "user", "content": content}) | ||
messages.append({"role": "user", "content": [{"type": "text", "text": contexts[-1]}]}) | ||
|
||
if "max_new_tokens" not in gen_kwargs: | ||
gen_kwargs["max_new_tokens"] = 1024 | ||
|
||
if "temperature" not in gen_kwargs: | ||
gen_kwargs["temperature"] = 0 | ||
|
||
for attempt in range(5): | ||
try: | ||
response = self.client.chat.completions.create(model=self.model_version, messages=messages, temperature=gen_kwargs["temperature"], max_tokens=gen_kwargs["max_new_tokens"], timeout=self.timeout) | ||
response_text = response.choices[0].message.content.strip() | ||
break # If successful, break out of the loop | ||
|
||
except Exception as e: | ||
eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}.") | ||
if attempt < 4: | ||
time.sleep(NUM_SECONDS_TO_SLEEP) | ||
else: # If this was the last attempt, log and return empty string | ||
eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}.") | ||
response_text = "" | ||
|
||
res.append(response_text) | ||
pbar.update(1) | ||
|
||
if self.continual_mode is True: # Cache the response | ||
doc_uuid = f"{task}___{split}___{doc_id}" | ||
self.response_cache[doc_uuid] = response_text | ||
with open(self.response_persistent_file, "w") as f: | ||
json.dump(self.response_cache, f) | ||
|
||
pbar.close() | ||
return res | ||
|
||
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: | ||
# TODO | ||
assert False, "GPT4V not support" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.