add av_odyssey bench #461

Merged
4 commits merged on Dec 22, 2024
94 changes: 92 additions & 2 deletions lmms_eval/models/gemini_api.py
@@ -1,6 +1,8 @@
import io
import json
import os
import pathlib
import re
import time
from typing import List, Tuple

@@ -39,8 +41,9 @@ def __init__(
        model_version: str = "gemini-1.5-pro",
        # modality: str = "image",
        timeout: int = 120,
        continual_mode: bool = False,
        continual_mode: bool = True,
        response_persistent_folder: str = "./logs/gemini_persistent_folder",
        interleave: bool = False,
        # We will cache the Gemini API response in this path and use it for future requests
        **kwargs,
    ) -> None:
@@ -49,6 +52,7 @@ def __init__(
        self.timeout = timeout
        self.model = genai.GenerativeModel(model_version)
        self.continual_mode = continual_mode
        self.interleave = interleave
        # if self.continual_mode and response_persistent_folder is None:
        #     raise ValueError("Continual mode requires a persistent path for the response. We will cache the Gemini API response in this path and use it for future requests. Please provide a valid path.")
        if self.continual_mode:
@@ -132,6 +136,20 @@ def convert_modality(self, images):
eval_logger.error(f"Error converting video: {str(e)}")
return images

    def construct_interleaved_input(self, content, media):
        # split the prompt on <media_N> placeholders and re-insert the
        # corresponding media objects in order
        pattern = r"<media_(\d+)>"
        parts = re.split(pattern, content)
        result = []
        for i, part in enumerate(parts):
            if i % 2 == 0:  # even indices are text segments
                if part == "":
                    continue
                result.append(part)
            else:  # odd indices are the captured media indices
                result.append(media[int(part)])

        return result
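For illustration, a hypothetical call on the model instance (img_a and img_b stand in for already-converted visual objects; they are assumed placeholders, not names from this diff):

    construct_interleaved_input("Compare <media_0> and <media_1>.", [img_a, img_b])
    # -> ["Compare ", img_a, " and ", img_b, "."]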

    def generate_until(self, requests) -> List[str]:
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
@@ -163,7 +181,10 @@ def get_uuid(task, split, doc_id):
            visuals = self.flatten(visuals)
            visuals = self.convert_modality(visuals)

            message = [contexts] + visuals
            if self.interleave:
                message = self.construct_interleaved_input(contexts, visuals)
            else:
                message = [contexts] + visuals

            for attempt in range(5):
                try:
@@ -213,3 +234,72 @@ def generate_until_multi_round(self, requests) -> List[str]:
    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        # TODO
        assert False, "Gemini API does not support loglikelihood"

    def get_image_audio_text_interleaved_messsage(self, image_path, audio_path, question):
        # image_path: list of image paths
        # audio_path: list of audio paths
        # question: the question text

        # replace the [imgN] / [audioN] placeholders with generic <image> / <audio> tags
        for index in range(1, 1 + len(image_path)):
            question = question.replace(f"[img{index}]", "<image>")
        for index in range(1, 1 + len(audio_path)):
            question = question.replace(f"[audio{index}]", "<audio>")

        text = question

        info_list = []
        image_counter = 0
        audio_counter = 0
        for part in re.split(r"(<image>|<audio>)", text):
            if part == "<image>":
                info_list.append(Image.open(image_path[image_counter]))
                image_counter += 1
            elif part == "<audio>":
                info_list.append({"mime_type": "audio/wav", "data": pathlib.Path(audio_path[audio_counter]).read_bytes()})
                audio_counter += 1
            else:
                if part == " ":
                    continue
                info_list.append(part)

        return info_list
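As a sketch of the expected shape of the result (the file names here are assumed examples, not from the PR):

    get_image_audio_text_interleaved_messsage(["piano.png"], ["clip.wav"], "Match [audio1] to [img1].")
    # -> ["Match ", {"mime_type": "audio/wav", "data": b"..."}, " to ", <PIL.Image.Image>, "."]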

    def get_video_audio_text_interleaved_message(self, video_path, audio_path, question):
        # video_path: list of video paths
        # audio_path: list of audio paths
        # question: the question text

        # replace the [videoN] / [audioN] placeholders with generic <video> / <audio> tags
        for index in range(1, 1 + len(video_path)):
            question = question.replace(f"[video{index}]", "<video>")
        for index in range(1, 1 + len(audio_path)):
            question = question.replace(f"[audio{index}]", "<audio>")

        text = question

        info_list = []
        video_counter = 0
        audio_counter = 0
        for part in re.split(r"(<video>|<audio>)", text):
            if part == "<video>":
                current_video_file_name = video_path[video_counter]
                current_video_file = genai.upload_file(path=current_video_file_name)
                # poll until the uploaded file leaves the PROCESSING state
                # (File.state.name values are uppercase: PROCESSING / ACTIVE / FAILED)
                while current_video_file.state.name == "PROCESSING":
                    print("waiting for video upload to finish processing")
                    time.sleep(5)
                    current_video_file = genai.get_file(current_video_file.name)
                if current_video_file.state.name == "FAILED":
                    print("video upload failed, skipping this question")
                    return 0
                info_list.append(current_video_file)
                video_counter += 1
            elif part == "<audio>":
                info_list.append({"mime_type": "audio/wav", "data": pathlib.Path(audio_path[audio_counter]).read_bytes()})
                audio_counter += 1
            else:
                if part == " ":
                    continue
                info_list.append(part)

        return info_list
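With interleave exposed as a constructor argument, interleaved prompting can be enabled from the model arguments. A hypothetical invocation, assuming the usual lmms-eval entry point and that the model is registered as gemini_api:

    python -m lmms_eval \
        --model gemini_api \
        --model_args model_version=gemini-1.5-pro,interleave=True \
        --tasks av_odyssey \
        --batch_size 1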
156 changes: 156 additions & 0 deletions lmms_eval/models/text.ipynb
@@ -0,0 +1,156 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"def split_media_tags(content):\n",
" # 用正则表达式匹配 <media数字>\n",
" # (\\d+) 捕获组用来提取数字\n",
" pattern = r'<media(\\d+)>'\n",
" \n",
" # 用 split 方法分割文本,同时保留匹配到的数字\n",
" # re.split 会返回一个列表,包含分割后的文本和匹配到的捕获组\n",
" parts = re.split(pattern, content)\n",
" \n",
" # 处理结果列表,将数字转换为整型\n",
" result = []\n",
" for i, part in enumerate(parts):\n",
" if i % 2 == 0: # 偶数索引是文本\n",
" result.append(part)\n",
" else: # 奇数索引是数字\n",
" result.append(int(part))\n",
" \n",
" return result"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['', 1, 'world', 2, '!']"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"split_media_tags('<media1>world<media2>!') # ['hello', 1, 'world', 2, '!']"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input: Select the instrument represented in images that corresponds to the audio [audio1] from [img1] [img2] [img3] [img4].\n",
"Output: ['Select the instrument represented in images that corresponds to the audio ', ('audio', 1), ' from ', ('img', 1), ' ', ('img', 2), ' ', ('img', 3), ' ', ('img', 4), '.']\n",
"\n",
"Input: [video1] is a great video\n",
"Output: [('video', 1), ' is a great video']\n",
"\n",
"Input: Compare [img1] and [img2]\n",
"Output: ['Compare ', ('img', 1), ' and ', ('img', 2)]\n",
"\n",
"Input: Listen to [audio1] and watch [video1]\n",
"Output: ['Listen to ', ('audio', 1), ' and watch ', ('video', 1)]\n",
"\n",
"Input: [img1] at the beginning and [img2] at the end\n",
"Output: [('img', 1), ' at the beginning and ', ('img', 2), ' at the end']\n",
"\n"
]
}
],
"source": [
"import re\n",
"\n",
"def split_media_tags(content):\n",
" # 匹配 [类型数字] 格式的标签\n",
" # 捕获组 1 捕获类型 (audio|video|img)\n",
" # 捕获组 2 捕获数字\n",
" pattern = r'\\[(audio|video|img)(\\d+)\\]'\n",
" \n",
" # 用 finditer 找到所有匹配\n",
" matches = list(re.finditer(pattern, content))\n",
" if not matches:\n",
" return [content]\n",
" \n",
" result = []\n",
" last_end = 0\n",
" \n",
" for match in matches:\n",
" # 添加标签之前的文本(如果有)\n",
" if match.start() > last_end:\n",
" result.append(content[last_end:match.start()])\n",
" \n",
" # 添加标签信息为元组 (类型, 数字)\n",
" media_type = match.group(1) # audio, video 或 img\n",
" media_num = int(match.group(2)) # 数字\n",
" result.append((media_type, media_num))\n",
" \n",
" last_end = match.end()\n",
" \n",
" # 添加最后一个标签之后的文本(如果有)\n",
" if last_end < len(content):\n",
" result.append(content[last_end:])\n",
" \n",
" return result\n",
"\n",
"# 测试\n",
"test_cases = [\n",
" \"Select the instrument represented in images that corresponds to the audio [audio1] from [img1] [img2] [img3] [img4].\",\n",
" \"[video1] is a great video\",\n",
" \"Compare [img1] and [img2]\",\n",
" \"Listen to [audio1] and watch [video1]\",\n",
" \"[img1] at the beginning and [img2] at the end\",\n",
"]\n",
"\n",
"for test in test_cases:\n",
" print(f\"Input: {test}\")\n",
" print(f\"Output: {split_media_tags(test)}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "av-odyssey",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
24 changes: 24 additions & 0 deletions lmms_eval/tasks/av_odyssey/av_odyssey.yaml
@@ -0,0 +1,24 @@
dataset_path: AV-Odyssey/AV_Odyssey_Bench_LMMs_Eval
dataset_kwargs:
token: True
cache_dir: AV_Odyssey
video: True
task: "av_odyssey"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.av_odyssey_doc_to_visual
doc_to_text: !function utils.av_odyssey_doc_to_text
doc_to_target: "answer"
generation_kwargs:
max_new_tokens: 1024
temperature: 0
top_p: 1.0
num_beams: 1
do_sample: false
# The return value of process_results will be used by metrics
process_results: !function utils.av_odyssey_process_results
# Note that the metric name can be either a registered metric function (as is the case for GQA) or a key name returned by process_results
metric_list:
- metric: av_odyssey_score
aggregation: !function utils.av_odyssey_aggregate_results
higher_is_better: true
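The referenced utils functions are not part of this diff; as a rough sketch, lmms-eval task helpers of this kind typically follow the shape below (signatures follow the usual lmms-eval conventions, but the bodies are assumptions, not the PR's actual implementation):

    def av_odyssey_doc_to_visual(doc):
        # return the list of media (image / audio / video paths) for one document
        ...

    def av_odyssey_doc_to_text(doc, lmms_eval_specific_kwargs=None):
        # build the prompt string, including any [imgN]/[audioN]/[videoN] placeholders
        ...

    def av_odyssey_process_results(doc, results):
        # parse the model response and return {"av_odyssey_score": <per-sample result>}
        ...

    def av_odyssey_aggregate_results(results):
        # reduce the per-sample results to a single score
        ...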