[Benchmark] Support MATH-Vision (open-compass#292)
* [Benchmark] Support MATH-Vision

* update url

* Fix download_file

* update MATH_V md5

* fix MathVision

* fix lint

---------

Co-authored-by: Ke Wang <wangk.gm@gmail.com>
Co-authored-by: kennymckormick <dhd@pku.edu.cn>
3 people authored Jul 19, 2024
1 parent 19795d8 commit 57e7ac7
Showing 5 changed files with 243 additions and 4 deletions.
run.py (2 changes: 1 addition & 1 deletion)
@@ -129,7 +129,7 @@ def main():
    else:
        if dataset.TYPE in ['MCQ', 'Y/N']:
            judge_kwargs['model'] = 'chatgpt-0125'
-       elif listinstr(['MMVet', 'MathVista', 'LLaVABench', 'MMBench-Video'], dataset_name):
+       elif listinstr(['MMVet', 'MathVista', 'LLaVABench', 'MMBench-Video', 'MathVision'], dataset_name):
            judge_kwargs['model'] = 'gpt-4-turbo'
        elif listinstr(['MMLongBench'], dataset_name):
            judge_kwargs['model'] = 'gpt-4o'
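Routing note: listinstr (from vlmeval.smp) is the substring matcher run.py uses to pick a judge model. A minimal sketch of the assumed matching behavior, enough to see that both 'MathVision' and 'MathVision_MINI' runs get routed to the 'gpt-4-turbo' judge (illustrative only, not the project's implementation):

# Hypothetical re-implementation for illustration; the real listinstr lives in vlmeval.smp.
def listinstr_sketch(lst, s):
    # True if any entry of lst occurs as a substring of s.
    return any(item in s for item in lst)

assert listinstr_sketch(['MathVision'], 'MathVision')
assert listinstr_sketch(['MathVision'], 'MathVision_MINI')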
vlmeval/dataset/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -4,15 +4,15 @@
from .image_caption import ImageCaptionDataset
from .image_yorn import ImageYORNDataset
from .image_mcq import ImageMCQDataset, MMMUDataset, CustomMCQDataset
-from .image_vqa import ImageVQADataset, OCRBench, MathVista, LLaVABench, MMVet, CustomVQADataset
+from .image_vqa import ImageVQADataset, MathVision, OCRBench, MathVista, LLaVABench, MMVet, CustomVQADataset
from .mmbench_video import MMBenchVideo
from .utils import *
from ..smp import *


# Add new supported dataset class here
IMAGE_DATASET = [
-   ImageCaptionDataset, ImageYORNDataset, ImageMCQDataset, ImageVQADataset,
+   ImageCaptionDataset, ImageYORNDataset, ImageMCQDataset, ImageVQADataset, MathVision,
    MMMUDataset, OCRBench, MathVista, LLaVABench, MMVet,
    MMLongBench, VCRDataset
]
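Registration note: adding MathVision to IMAGE_DATASET is what makes the two new TSVs resolvable by name. A rough sketch of how such a registry lookup could work (the actual builder in vlmeval.dataset may differ; illustrative only):

# Hypothetical helper: resolve a dataset name against the registered classes.
def resolve_dataset_class(dataset_name, registry):
    for cls in registry:
        # Classes such as MathVision list their supported names as DATASET_URL keys.
        if dataset_name in getattr(cls, 'DATASET_URL', {}):
            return cls
    return None

# e.g. resolve_dataset_class('MathVision_MINI', IMAGE_DATASET) would return MathVision.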
vlmeval/dataset/image_vqa.py (64 changes: 64 additions & 0 deletions)
@@ -213,6 +213,70 @@ def evaluate(self, eval_file, **judge_kwargs):
        return score


class MathVision(ImageBaseDataset):
    TYPE = 'VQA'
    DATASET_URL = {
        'MathVision': 'https://opencompass.openxlab.space/utils/VLMEval/MathVision.tsv',
        'MathVision_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/MathVision_MINI.tsv'
    }
    DATASET_MD5 = {
        'MathVision': '93f6de14f7916e598aa1b7165589831e',
        'MathVision_MINI': '060fe4fa5d868987ce179307bd5f8a33'
    }

    # It returns a DataFrame
    @classmethod
    def evaluate(self, eval_file, **judge_kwargs):
        from .utils.mathv import MATH_V_auxeval, MATH_V_acc

        if 'model' in judge_kwargs:
            model = judge_kwargs['model']
        else:
            model = os.path.basename(os.environ.get('LOCAL_LLM'))
        suffix = eval_file.split('.')[-1]
        storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
        tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
        nproc = judge_kwargs.pop('nproc', 4)

        if not osp.exists(storage):
            data = load(eval_file)
            model = build_judge(max_tokens=128, **judge_kwargs)
            assert model.working(), ('MATH-Vision evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
            lt = len(data)
            lines = [data.iloc[i] for i in range(lt)]
            tups = [(model, line) for line in lines]
            indices = [line['index'] for line in lines]

            ans = {}
            if osp.exists(tmp_file):
                ans = load(tmp_file)
            tups = [x for x, i in zip(tups, indices) if i not in ans]
            indices = [i for i in indices if i not in ans]

            if len(indices):
                new_results = track_progress_rich(
                    MATH_V_auxeval,
                    tups,
                    nproc=nproc,
                    chunksize=nproc,
                    keys=indices,
                    save=tmp_file,
                )
                ans = load(tmp_file)
                for k, v in zip(indices, new_results):
                    assert k in ans
                    assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']

            data['res'] = [ans[idx]['res'] for idx in data['index']]
            data['log'] = [ans[idx]['log'] for idx in data['index']]
            dump(data, storage)

        score = MATH_V_acc(storage)
        score_pth = storage.replace('.xlsx', '_score.csv')
        dump(score, score_pth)
        return score


class LLaVABench(ImageBaseDataset):
    TYPE = 'VQA'
    DATASET_URL = {'LLaVABench': 'https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv'}
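Usage note: a minimal sketch of scoring an existing prediction file with the new class, assuming predictions were already produced by the normal inference step and an OpenAI key is configured (the file name below is a placeholder, and the judge settings mirror the defaults seen above):

from vlmeval.dataset import MathVision

# 'MathVision_MINI_predictions.xlsx' stands in for the file written during inference.
score = MathVision.evaluate('MathVision_MINI_predictions.xlsx', model='gpt-4-turbo', nproc=4)
print(score)  # per-category accuracy DataFrame produced by MATH_V_acc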
vlmeval/dataset/utils/mathv.py (170 changes: 170 additions & 0 deletions)
@@ -0,0 +1,170 @@
from ...smp import *
from ...utils import can_infer
try:
    from latex2sympy2 import latex2sympy
except ImportError:
    print('Please install latex2sympy2 by running "pip install latex2sympy2"')

FAIL_MSG = 'Failed to obtain answer via API.'


def is_equal(asw: str, gt_asw: str) -> bool:
    if type(asw) != str or type(gt_asw) != str:
        print('Warning: input is not string')
        print(asw, gt_asw)
    asw = str(asw).lower().strip()
    gt_asw = str(gt_asw).lower().strip()
    if gt_asw == asw:
        return True
    try:
        a = eval(gt_asw)
        b = eval(asw)
        if abs(a - b) < 1e-6:
            return True
    except:
        pass
    try:
        a = latex2sympy(gt_asw)
        b = latex2sympy(asw)
        if abs(eval(str(a)) - eval(str(b))) < 1e-6:
            return True
        if abs(a - b) < 1e-6:
            return True
    except:
        pass
    return False


def get_gpt4_ICE():
    example_1 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: Which number is missing?\n
Model response: The number missing in the sequence is 14.\n
Extracted answer: 14
"""

    example_2 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: What is the fraction of females facing the camera?\n
Model response: The fraction of females facing the camera is 0.6,
which means that six out of ten females in the group are facing the camera.\n
Extracted answer: 0.6
"""

    example_3 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $)\n
Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.\n
Extracted answer: 1.45
"""

    example_4 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: Between which two years does the line graph saw its maximum peak?\n
Model response: The line graph saw its maximum peak between 2007 and 2008.\n
Extracted answer: [2007, 2008]
"""

    example_5 = """
Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n
Question: What fraction of the shape is blue?\n
Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5\n
Model response: The correct answer is (B) 8/11.\n
Extracted answer: B
"""

    return [example_1, example_2, example_3, example_4, example_5]


def build_mathv_gpt4_prompt(line):
    task_description = """
Please read the following example.
Then extract the answer from the model response and type it at the end of the prompt.\n
"""
    question = line['question']
    prediction = str(line['prediction'])
    prompt = task_description
    examples = get_gpt4_ICE()
    for example in examples:
        prompt += example + '\n'
    prompt += question + '\n'
    prompt += 'Model response: ' + prediction
    prompt += 'Extracted answer:'
    return prompt


def list_to_dict(lst):
    return {chr(65 + i): val for i, val in enumerate(lst)}


def post_check(line, prefetch=False):
    res = None
    ans = line['answer']
    response = line['prediction'] if prefetch else line['res']
    try:
        if len(eval(line['choices'])) > 0:
            ans = line['answer']
            choices = list_to_dict(eval(line['choices']))
            res = can_infer(response, choices)
            if prefetch:
                return res
        else:
            res = str(res)
            ans = str(ans)
    except ValueError:
        pass

    if is_equal(res, ans):
        return res if prefetch else True
    else:
        return False


def MATH_V_auxeval(model, line):
    prompt = build_mathv_gpt4_prompt(line)
    log = ''
    retry = 5
    if post_check(line, prefetch=True):
        res = post_check(line, prefetch=True)
        return dict(log='Prefetch succeed', res=res)
    for i in range(retry):
        prediction = line['prediction']
        res = model.generate(prompt, temperature=i * 0.5)

        if FAIL_MSG in res:
            log += f'Try {i}: output is {prediction}, failed to parse.\n'
        else:
            log += 'Succeed'
            return dict(log=log, res=res)
    log += 'All 5 retries failed.\n'
    return dict(log=log, res='')


def MATH_V_acc(result_file):
    data = load(result_file)
    tot = defaultdict(lambda: 0)
    fetch = defaultdict(lambda: 0)
    hit = defaultdict(lambda: 0)
    lt = len(data)
    for i in range(lt):
        item = data.iloc[i]
        cate = item['category']
        tot['Overall'] += 1
        tot[cate] += 1
        if item['log'] == 'Prefetch succeed':
            fetch['Overall'] += 1
            fetch[cate] += 1
        if post_check(item, prefetch=False):
            hit['Overall'] += 1
            hit[cate] += 1

    res = defaultdict(list)
    for k in tot.keys():
        res['Subject'].append(k)
        res['tot'].append(tot[k])
        res['prefetch'].append(fetch[k])
        res['hit'].append(hit[k])
        res['prefetch_rate'].append(fetch[k] / tot[k] * 100)
        res['acc'].append(hit[k] / tot[k] * 100)
    res = pd.DataFrame(res).sort_values('Subject', ignore_index=True)
    return res
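Sanity check: a quick illustrative exercise of the matching helpers above, assuming latex2sympy2 is installed (values chosen for demonstration only, not taken from the benchmark):

from vlmeval.dataset.utils.mathv import is_equal, list_to_dict

print(list_to_dict(['3/11', '8/11', '6/11', '3/5']))  # {'A': '3/11', 'B': '8/11', 'C': '6/11', 'D': '3/5'}
print(is_equal('1/2', '0.5'))           # True: both sides evaluate to the same number
print(is_equal('\\frac{1}{2}', '0.5'))  # True via the latex2sympy fallback
print(is_equal('B', 'b'))               # True: comparison is lower-cased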
vlmeval/smp/file.py (7 changes: 6 additions & 1 deletion)
@@ -140,7 +140,12 @@ def update_to(self, b=1, bsize=1, tsize=None):
        # Handle Failed Downloads from huggingface.co
        if 'huggingface.co' in url:
            url_new = url.replace('huggingface.co', 'hf-mirror.com')
-           os.system(f'wget {url_new} -O {filename}')
+           try:
+               os.system(f'wget {url_new} -O {filename}')
+           except:
+               raise Exception(f'Failed to download {url}')
+       else:
+           raise Exception(f'Failed to download {url}')

    return filename

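Download note: the DATASET_URL / DATASET_MD5 entries in the new MathVision class are presumably consumed by a fetch-and-verify step built on the download_file function patched above. An illustrative sketch of that flow (not the project's actual loader; the import path and call signature are assumptions based on the diff):

import hashlib
from vlmeval.smp.file import download_file  # assumed import path

def fetch_tsv_sketch(url, md5, filename):
    # download_file (above) is assumed to handle the huggingface.co -> hf-mirror.com fallback.
    path = download_file(url, filename)
    with open(path, 'rb') as f:
        assert hashlib.md5(f.read()).hexdigest() == md5, f'MD5 mismatch for {path}'
    return path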
