Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Update Aria from extraction #466

Merged
merged 2 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions lmms_eval/models/aria.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import re
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import requests
import torch
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
Expand All @@ -16,6 +16,7 @@
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.models.model_utils.load_video import read_video_pyav_pil

warnings.filterwarnings("ignore")

Expand Down Expand Up @@ -179,16 +180,19 @@ def flatten(self, input):
return new_list

def load_video(self, video_path, max_frames_num):
if type(video_path) == str:
vr = VideoReader(video_path, ctx=cpu(0))
else:
vr = VideoReader(video_path[0], ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
spare_frames = vr.get_batch(frame_idx).asnumpy()
spare_frames = [Image.fromarray(x) for x in spare_frames]
return spare_frames # (frames, height, width, channels)
if isinstance(video_path, list):
video_path = video_path[0]
return read_video_pyav_pil(video_path, num_frm=max_frames_num)
# if type(video_path) == str:
# vr = VideoReader(video_path, ctx=cpu(0))
# else:
# vr = VideoReader(video_path[0], ctx=cpu(0))
# total_frame_num = len(vr)
# uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
# frame_idx = uniform_sampled_frames.tolist()
# spare_frames = vr.get_batch(frame_idx).asnumpy()
# spare_frames = [Image.fromarray(x) for x in spare_frames]
# return spare_frames # (frames, height, width, channels)

def generate_until(self, requests: List[Instance]) -> List[str]:
res = []
Expand Down
75 changes: 47 additions & 28 deletions lmms_eval/tasks/videofinal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ def videoperception_doc_to_visual(doc):

return [video_path]


def videoperception_doc_to_visual_question_only(doc):
video_path = doc["id"] + "_image" + ".mp4"
question_only_cache_dir = os.path.join(cache_dir, "question_only")
question_only_cache_dir = os.path.join(cache_dir, "question_only")
video_path = os.path.join(question_only_cache_dir, video_path)

if os.path.exists(video_path):
Expand All @@ -101,7 +102,7 @@ def videoperception_doc_to_text_adaptation(doc, lmms_eval_specific_kwargs=None):
question += "\n" + parsed_options
else:
pre_prompt += lmms_eval_specific_kwargs["open_ended_prompt"]

return f"{pre_prompt}{question}"


Expand Down Expand Up @@ -147,12 +148,12 @@ def videoperception_doc_to_text_with_transcript_perception_comprehension(doc, lm
transcript = f.read().strip()
else:
transcript = "[Transcript not available]"

post_prompt = ""
post_prompt += lmms_eval_specific_kwargs["perception_and_comprehension_prompt"]

formatted_output = f"\nTranscript for the Video:\n{transcript}\n\nQuestion for the video:\n{question}{post_prompt}"

return formatted_output


Expand Down Expand Up @@ -216,9 +217,7 @@ def videoperception_process_results(doc, results):
parsed_pred = parse_open_response(pred)

mmmu_acc = {"id": doc["id"], "subdomain": extract_subset_name(doc["id"]), "question_type": question_type, "answer": doc["answer"], "parsed_pred": parsed_pred}
return {
"mmmu_acc": mmmu_acc
}
return {"mmmu_acc": mmmu_acc}


# return subset name
Expand Down Expand Up @@ -275,7 +274,6 @@ def videoperception_aggregate_results(results):
return printable_results["Overall"]["acc"]



##################
# Helper functions written by official MMMU repo.
##################
Expand Down Expand Up @@ -414,12 +412,12 @@ def parse_multi_choice_response(response, all_choices, index2ans):
"""
if response == "API Error" or response == "":
return "API Error"

# Step 1: Clean up punctuation from the response
for char in [",", ".", "!", "?", ";", ":", "'"]:
response = response.strip(char)
response = " " + response + " " # Add space to avoid partial match
#print(response)
# print(response)

index_ans = True
ans_with_brack = False
Expand All @@ -430,13 +428,13 @@ def parse_multi_choice_response(response, all_choices, index2ans):
# Step 2: If no candidates, look for choices with a period after (A. B. C. D.)
for choice in all_choices: # e.g., A. B. C. D.
if f"{choice}." in response:
#print(f"Found choice with period after: {choice}")
# print(f"Found choice with period after: {choice}")
candidates.append(choice)
ans_with_period = True
# Step 2.1: If no candidates, look for choices with a colon after (A: B: C: D:)
for choice in all_choices: # e.g., A: B: C: D:
if f"{choice}:" in response:
#print(f"Found choice with semicolon after: {choice}")
# print(f"Found choice with semicolon after: {choice}")
candidates.append(choice)
ans_with_colon = True

Expand All @@ -452,57 +450,57 @@ def parse_multi_choice_response(response, all_choices, index2ans):
if len(candidates) == 0:
for choice in all_choices: # e.g., A B C D
if f"{choice} " in response:
#print(f"Found choice without parentheses (space after): {choice}")
# print(f"Found choice without parentheses (space after): {choice}")
candidates.append(choice)

# Step 5: If no candidates and response has more than 5 tokens, try parsing based on content
if len(candidates) == 0 and len(response.split()) > 5:
for index, ans in index2ans.items():
if ans.lower() in response.lower():
#print(f"Found answer content match: {ans}")
# print(f"Found answer content match: {ans}")
candidates.append(index)
index_ans = False # It's content answer, not an index

# Step 6: If still no candidates, randomly choose one
if len(candidates) == 0:
pred_index = "No Answere Found"
#print(f"No candidates found.")
# print(f"No candidates found.")
# Step 7: If multiple candidates found, use the one appearing last
elif len(candidates) > 1:
start_indexes = []
if index_ans:
if ans_with_period:
for can in candidates:
index = response.rfind(f"{can}.")
#print(f"Checking position of choice: {can} at {index}")
# print(f"Checking position of choice: {can} at {index}")
start_indexes.append(index)
elif ans_with_colon:
for can in candidates:
index = response.rfind(f"{can}:")
#print(f"Checking position of choice: {can} at {index}")
# print(f"Checking position of choice: {can} at {index}")
start_indexes.append(index)
elif ans_with_brack:
for can in candidates:
index = response.rfind(f"({can})")
#print(f"Checking position of choice with parentheses: {can} at {index}")
# print(f"Checking position of choice with parentheses: {can} at {index}")
start_indexes.append(index)
else:
for can in candidates:
index = response.rfind(f" {can} ")
#print(f"Checking position of choice: {can} at {index}")
# print(f"Checking position of choice: {can} at {index}")
start_indexes.append(index)
else:
for can in candidates:
index = response.lower().rfind(index2ans[can].lower())
#print(f"Checking position of content match: {can} at {index}")
# print(f"Checking position of content match: {can} at {index}")
start_indexes.append(index)
# Get the last one (max index)
pred_index = candidates[np.argmax(start_indexes)]
#print(f"Multiple candidates, selected based on last occurrence: {pred_index}")
# print(f"Multiple candidates, selected based on last occurrence: {pred_index}")
else:
# If only one candidate, use it
pred_index = candidates[0]
#print(f"Only one candidate found, selected: {pred_index}")
# print(f"Only one candidate found, selected: {pred_index}")

return pred_index

Expand Down Expand Up @@ -586,11 +584,32 @@ def get_key_subresponses(response):
sub_responses = re.split(r"\.\s(?=[A-Z])|\n", response)
indicators_of_keys = [
# Common explanation or conclusion phrases
"could be ", "so ", "is ", "thus ", "therefore ", "final ", "answer ",
"result ", "are ", "in total ", "total ", "identify ", "recognize ",
"calculated as ", "counted as ", "measured as ", "observed as ",
"concluded as ", "found to be ", "equals ", "determined to be ",
"number of ", "value is ", "adds up to ", "have ", "has "
"could be ",
"so ",
"is ",
"thus ",
"therefore ",
"final ",
"answer ",
"result ",
"are ",
"in total ",
"total ",
"identify ",
"recognize ",
"calculated as ",
"counted as ",
"measured as ",
"observed as ",
"concluded as ",
"found to be ",
"equals ",
"determined to be ",
"number of ",
"value is ",
"adds up to ",
"have ",
"has ",
]

key_responses = []
Expand Down
41 changes: 22 additions & 19 deletions lmms_eval/tasks/videommmu_engineering_1206/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ def videosearch_doc_to_visual_business(doc):
elif os.path.exists(video_path.replace("mp4", "MP4")):
video_path = video_path.replace("mp4", "MP4")
else:
video_path = os.path.join(cache_dir_business, "validation_Accounting_15.mp4")
#sys.exit(f"video path:{video_path} does not exist, please check")
video_path = os.path.join(cache_dir_business, "validation_Accounting_15.mp4")
# sys.exit(f"video path:{video_path} does not exist, please check")
return [video_path]


def videosearch_doc_to_visual_science(doc):
video_path = doc["id"] + ".mp4"
video_path = os.path.join(cache_dir_science, video_path)
Expand All @@ -60,12 +61,13 @@ def videosearch_doc_to_visual_science(doc):
elif os.path.exists(video_path.replace("mp4", "MP4")):
video_path = video_path.replace("mp4", "MP4")
else:
video_path = os.path.join(cache_dir_science, "validation_Math_14.mp4")
video_path = os.path.join(cache_dir_science, "validation_Math_14.mp4")
print(video_path)
print("Not found")
#sys.exit(f"video path:{video_path} does not exist, please check")
# sys.exit(f"video path:{video_path} does not exist, please check")
return [video_path]


def videosearch_doc_to_visual_engineering(doc):
video_path = doc["id"] + ".mp4"
video_path = os.path.join(cache_dir_engineering, video_path)
Expand All @@ -74,8 +76,8 @@ def videosearch_doc_to_visual_engineering(doc):
elif os.path.exists(video_path.replace("mp4", "MP4")):
video_path = video_path.replace("mp4", "MP4")
else:
video_path = os.path.join(cache_dir_engineering, "validation_Agriculture_1.mp4")
#sys.exit(f"video path:{video_path} does not exist, please check")
video_path = os.path.join(cache_dir_engineering, "validation_Agriculture_1.mp4")
# sys.exit(f"video path:{video_path} does not exist, please check")
return [video_path]


Expand All @@ -93,7 +95,7 @@ def videosearch_doc_to_text(doc, lmms_eval_specific_kwargs=None):
parsed_options = parse_options(doc["options"])
question += "\n" + parsed_options

#print(f"{pre_prompt}{question}")
# print(f"{pre_prompt}{question}")
return f"{pre_prompt}{question}"


Expand Down Expand Up @@ -373,6 +375,7 @@ def evaluate_mmmu(samples):
return {"acc": 0}
return judge_dict, {"acc": pred_correct / len(samples)}


def parse_multi_choice_response(response, all_choices, index2ans):
"""
Parse the prediction from the generated response.
Expand All @@ -382,7 +385,7 @@ def parse_multi_choice_response(response, all_choices, index2ans):
for char in [",", ".", "!", "?", ";", ":", "'"]:
response = response.strip(char)
response = " " + response + " " # Add space to avoid partial match
#print(response)
# print(response)

index_ans = True
ans_with_brack = False
Expand All @@ -392,13 +395,13 @@ def parse_multi_choice_response(response, all_choices, index2ans):
# Step 2: If no candidates, look for choices with a period after (A. B. C. D.)
for choice in all_choices: # e.g., A. B. C. D.
if f"{choice}." in response:
#print(f"Found choice with period after: {choice}")
# print(f"Found choice with period after: {choice}")
candidates.append(choice)
ans_with_period = True
# Step 2.1: If no candidates, look for choices with a period after (A. B. C. D.)
for choice in all_choices: # e.g., A. B. C. D.
if f"{choice}:" in response:
#print(f"Found choice with semicolon after: {choice}")
# print(f"Found choice with semicolon after: {choice}")
candidates.append(choice)
ans_with_colon = True

Expand All @@ -414,14 +417,14 @@ def parse_multi_choice_response(response, all_choices, index2ans):
if len(candidates) == 0:
for choice in all_choices: # e.g., A B C D
if f"{choice} " in response:
#print(f"Found choice without parentheses (space after): {choice}")
# print(f"Found choice without parentheses (space after): {choice}")
candidates.append(choice)

# Step 5: If no candidates and response has more than 5 tokens, try parsing based on content
if len(candidates) == 0 and len(response.split()) > 5:
for index, ans in index2ans.items():
if ans.lower() in response.lower():
#print(f"Found answer content match: {ans}")
# print(f"Found answer content match: {ans}")
candidates.append(index)
index_ans = False # It's content answer, not an index

Expand All @@ -436,35 +439,35 @@ def parse_multi_choice_response(response, all_choices, index2ans):
if ans_with_period:
for can in candidates:
index = response.rfind(f"{can}.")
#print(f"Checking position of choice: {can} at {index}")
# print(f"Checking position of choice: {can} at {index}")
start_indexes.append(index)
elif ans_with_colon:
for can in candidates:
index = response.rfind(f"{can}:")
#print(f"Checking position of choice: {can} at {index}")
# print(f"Checking position of choice: {can} at {index}")
start_indexes.append(index)
elif ans_with_brack:
for can in candidates:
index = response.rfind(f"({can})")
#print(f"Checking position of choice with parentheses: {can} at {index}")
# print(f"Checking position of choice with parentheses: {can} at {index}")
start_indexes.append(index)
else:
for can in candidates:
index = response.rfind(f" {can} ")
#print(f"Checking position of choice: {can} at {index}")
# print(f"Checking position of choice: {can} at {index}")
start_indexes.append(index)
else:
for can in candidates:
index = response.lower().rfind(index2ans[can].lower())
#print(f"Checking position of content match: {can} at {index}")
# print(f"Checking position of content match: {can} at {index}")
start_indexes.append(index)
# Get the last one (max index)
pred_index = candidates[np.argmax(start_indexes)]
#print(f"Multiple candidates, selected based on last occurrence: {pred_index}")
# print(f"Multiple candidates, selected based on last occurrence: {pred_index}")
else:
# If only one candidate, use it
pred_index = candidates[0]
#print(f"Only one candidate found, selected: {pred_index}")
# print(f"Only one candidate found, selected: {pred_index}")
# pred_index = "Z"
# print(pred_index)
return pred_index
Expand Down
2 changes: 1 addition & 1 deletion lmms_eval/tasks/videosearch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def videosearch_process_results(doc, results):
# Handle the case where 'question_type' might be missing for perception and understanding
question_type = doc.get("question_type", "perception")
if question_type == "multiple-choice":
#index2ans, all_choices = get_multi_choice_info(ast.literal_eval(doc["options"]))
# index2ans, all_choices = get_multi_choice_info(ast.literal_eval(doc["options"]))
index2ans, all_choices = get_multi_choice_info(doc["options"])
parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
elif question_type == "open":
Expand Down
Loading