Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Benchmark] Support MATH-Vision #292

Merged
merged 6 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def main():
else:
if dataset.TYPE in ['MCQ', 'Y/N']:
judge_kwargs['model'] = 'chatgpt-0125'
elif listinstr(['MMVet', 'MathVista', 'LLaVABench', 'MMBench-Video'], dataset_name):
elif listinstr(['MMVet', 'MathVista', 'LLaVABench', 'MMBench-Video', 'MathVision'], dataset_name):
judge_kwargs['model'] = 'gpt-4-turbo'
elif listinstr(['MMLongBench'], dataset_name):
judge_kwargs['model'] = 'gpt-4o'
Expand Down
4 changes: 2 additions & 2 deletions vlmeval/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from .image_caption import ImageCaptionDataset
from .image_yorn import ImageYORNDataset
from .image_mcq import ImageMCQDataset, MMMUDataset, CustomMCQDataset
from .image_vqa import ImageVQADataset, OCRBench, MathVista, LLaVABench, MMVet, CustomVQADataset
from .image_vqa import ImageVQADataset, MathVision, OCRBench, MathVista, LLaVABench, MMVet, CustomVQADataset
from .mmbench_video import MMBenchVideo
from .utils import *
from ..smp import *


# Add new supported dataset class here
IMAGE_DATASET = [
ImageCaptionDataset, ImageYORNDataset, ImageMCQDataset, ImageVQADataset,
ImageCaptionDataset, ImageYORNDataset, ImageMCQDataset, ImageVQADataset, MathVision,
MMMUDataset, OCRBench, MathVista, LLaVABench, MMVet,
MMLongBench, VCRDataset
]
Expand Down
64 changes: 64 additions & 0 deletions vlmeval/dataset/image_vqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,70 @@ def evaluate(self, eval_file, **judge_kwargs):
return score


class MathVision(ImageBaseDataset):
TYPE = 'VQA'
DATASET_URL = {
'MathVision': 'https://opencompass.openxlab.space/utils/VLMEval/MathVision.tsv',
'MathVision_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/MathVision_MINI.tsv'
}
DATASET_MD5 = {
'MathVision': '93f6de14f7916e598aa1b7165589831e',
'MathVision_MINI': '060fe4fa5d868987ce179307bd5f8a33'
}

# It returns a DataFrame
@classmethod
def evaluate(self, eval_file, **judge_kwargs):
from .utils.mathv import MATH_V_auxeval, MATH_V_acc

if 'model' in judge_kwargs:
model = judge_kwargs['model']
else:
model = os.path.basename(os.environ.get('LOCAL_LLM'))
suffix = eval_file.split('.')[-1]
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
nproc = judge_kwargs.pop('nproc', 4)

if not osp.exists(storage):
data = load(eval_file)
model = build_judge(max_tokens=128, **judge_kwargs)
assert model.working(), ('MATH-Vision evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
lt = len(data)
lines = [data.iloc[i] for i in range(lt)]
tups = [(model, line) for line in lines]
indices = [line['index'] for line in lines]

ans = {}
if osp.exists(tmp_file):
ans = load(tmp_file)
tups = [x for x, i in zip(tups, indices) if i not in ans]
indices = [i for i in indices if i not in ans]

if len(indices):
new_results = track_progress_rich(
MATH_V_auxeval,
tups,
nproc=nproc,
chunksize=nproc,
keys=indices,
save=tmp_file,
)
ans = load(tmp_file)
for k, v in zip(indices, new_results):
assert k in ans
assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']

data['res'] = [ans[idx]['res'] for idx in data['index']]
data['log'] = [ans[idx]['log'] for idx in data['index']]
dump(data, storage)

score = MATH_V_acc(storage)
score_pth = storage.replace('.xlsx', '_score.csv')
dump(score, score_pth)
return score


class LLaVABench(ImageBaseDataset):
TYPE = 'VQA'
DATASET_URL = {'LLaVABench': 'https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv'}
Expand Down
170 changes: 170 additions & 0 deletions vlmeval/dataset/utils/mathv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from ...smp import *
from ...utils import can_infer
try:
from latex2sympy2 import latex2sympy
except ImportError:
print('Please install latex2sympy2 by running "pip install latex2sympy2"')

FAIL_MSG = 'Failed to obtain answer via API.'


def is_equal(asw: str, gt_asw: str) -> bool:
if type(asw) != str or type(gt_asw) != str:
print('Warning: input is not string')
print(asw, gt_asw)
asw = str(asw).lower().strip()
gt_asw = str(gt_asw).lower().strip()
if gt_asw == asw:
return True
try:
a = eval(gt_asw)
b = eval(asw)
if abs(a - b) < 1e-6:
return True
except:
pass
try:
a = latex2sympy(gt_asw)
b = latex2sympy(asw)
if abs(eval(str(a)) - eval(str(b))) < 1e-6:
return True
if abs(a - b) < 1e-6:
return True
except:
pass
return False


def get_gpt4_ICE():
example_1 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: Which number is missing?\n
Model response: The number missing in the sequence is 14.\n
Extracted answer: 14
"""

example_2 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: What is the fraction of females facing the camera?\n
Model response: The fraction of females facing the camera is 0.6,
which means that six out of ten females in the group are facing the camera.\n
Extracted answer: 0.6
"""

example_3 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $)\n
Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.\n
Extracted answer: 1.45
"""

example_4 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: Between which two years does the line graph saw its maximum peak?\n
Model response: The line graph saw its maximum peak between 2007 and 2008.\n
Extracted answer: [2007, 2008]
"""

example_5 = """
Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n
Question: What fraction of the shape is blue?\n
Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5\n
Model response: The correct answer is (B) 8/11.\n
Extracted answer: B
"""

return [example_1, example_2, example_3, example_4, example_5]


def build_mathv_gpt4_prompt(line):
task_description = """
Please read the following example.
Then extract the answer from the model response and type it at the end of the prompt.\n
"""
question = line['question']
prediction = str(line['prediction'])
prompt = task_description
examples = get_gpt4_ICE()
for example in examples:
prompt += example + '\n'
prompt += question + '\n'
prompt += 'Model respone: ' + prediction
prompt += 'Extracted answer:'
return prompt


def list_to_dict(lst):
return {chr(65 + i): val for i, val in enumerate(lst)}


def post_check(line, prefetch=False):
res = None
ans = line['answer']
response = line['prediction'] if prefetch else line['res']
try:
if len(eval(line['choices'])) > 0:
ans = line['answer']
choices = list_to_dict(eval(line['choices']))
res = can_infer(response, choices)
if prefetch:
return res
else:
res = str(res)
ans = str(ans)
except ValueError:
pass

if is_equal(res, ans):
return res if prefetch else True
else:
return False


def MATH_V_auxeval(model, line):
prompt = build_mathv_gpt4_prompt(line)
log = ''
retry = 5
if post_check(line, prefetch=True):
res = post_check(line, prefetch=True)
return dict(log='Prefetch succeed', res=res)
for i in range(retry):
prediction = line['prediction']
res = model.generate(prompt, temperature=i * 0.5)

if FAIL_MSG in res:
log += f'Try {i}: output is {prediction}, failed to parse.\n'
else:
log += 'Succeed'
return dict(log=log, res=res)
log += 'All 5 retries failed.\n'
return dict(log=log, res='')


def MATH_V_acc(result_file):
data = load(result_file)
tot = defaultdict(lambda: 0)
fetch = defaultdict(lambda: 0)
hit = defaultdict(lambda: 0)
lt = len(data)
for i in range(lt):
item = data.iloc[i]
cate = item['category']
tot['Overall'] += 1
tot[cate] += 1
if item['log'] == 'Prefetch succeed':
fetch['Overall'] += 1
fetch[cate] += 1
if post_check(item, prefetch=False):
hit['Overall'] += 1
hit[cate] += 1

res = defaultdict(list)
for k in tot.keys():
res['Subject'].append(k)
res['tot'].append(tot[k])
res['prefetch'].append(fetch[k])
res['hit'].append(hit[k])
res['prefetch_rate'].append(fetch[k] / tot[k] * 100)
res['acc'].append(hit[k] / tot[k] * 100)
res = pd.DataFrame(res).sort_values('Subject', ignore_index=True)
return res
7 changes: 6 additions & 1 deletion vlmeval/smp/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@ def update_to(self, b=1, bsize=1, tsize=None):
# Handle Failed Downloads from huggingface.co
if 'huggingface.co' in url:
url_new = url.replace('huggingface.co', 'hf-mirror.com')
os.system(f'wget {url_new} -O {filename}')
try:
os.system(f'wget {url_new} -O {filename}')
except:
raise Exception(f'Failed to download {url}')
else:
raise Exception(f'Failed to download {url}')

return filename

Expand Down
Loading