Skip to content

Commit

Permalink
Add Muirbench (#143)
Browse files Browse the repository at this point in the history
* handle gen kwargs in internvl2

* Add muirbench
  • Loading branch information
kcz358 authored Jul 16, 2024
1 parent 4f8db1d commit 5fc5f2f
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 0 deletions.
8 changes: 8 additions & 0 deletions lmms_eval/models/internvl2.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,14 @@ def generate_until(self, requests) -> List[str]:
if k not in gen_kwargs:
gen_kwargs[k] = v

pop_keys = []
for k, v in gen_kwargs.items():
if k not in DEFAULT_GEN_KWARGS:
pop_keys.append(k)

for k in pop_keys:
gen_kwargs.pop(k)

visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
visuals = self.flatten(visuals)
if self.modality == "image":
Expand Down
41 changes: 41 additions & 0 deletions lmms_eval/tasks/muirbench/muirbench.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@

dataset_path: MUIRBENCH/MUIRBENCH
task: "muirbench"
dataset_kwargs:
token: True
test_split: test
output_type: generate_until
doc_to_visual: !function utils.muir_doc_to_visual
doc_to_text: !function utils.muir_doc_to_text
doc_to_target: !function utils.muir_doc_to_target
process_results: !function utils.muir_process_results

model_specific_prompt_kwargs:
default:
pre_prompt: ""
post_prompt: "\nAnswer with the option's letter from the given choices directly."


generation_kwargs:
max_new_tokens: 16
temperature: 0
do_sample: False

filter_list:
- name: "flexible-extract"
filter:
- function: !function utils.MultiChoiceRegexFilter
group_select: 0
ignore_case: true
ignore_punctuation: true
regex_pattern: "([A-Z])\\."

metric_list:
- metric: muirbench_score_overall
aggregation: !function utils.muir_aggregation
higher_is_better: true
ignore_case: true
ignore_punctuation: true

metadata:
- version: 0.0
117 changes: 117 additions & 0 deletions lmms_eval/tasks/muirbench/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@

from lmms_eval.filters.extraction import ExtendedRegexFilter
from lmms_eval.filters.transformation import MapFilter
import re
import pandas as pd


def muir_doc_to_text(doc, model_specific_prompt_kwargs=None):
question, choices = doc["question"], doc["options"]
len_choices = len(choices)
post_prompt = model_specific_prompt_kwargs["post_prompt"]
pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
options = [chr(ord("A") + i) for i in range(len_choices)]
choices_str = "\n".join([f"{option}. {choice}" for option, choice in zip(options, choices)])
return f"{pre_prompt}{question}\n{choices_str}{post_prompt}"


def muir_doc_to_visual(doc):
image_list = [image.convert("RGB") for image in doc["image_list"]]
return image_list


def muir_doc_to_target(doc):
return doc["answer"]


def muir_process_results(doc, result):
pred = result[0]
task = doc["task"]
idx = doc["idx"]
image_relation = doc["image_relation"]
answer = doc["answer"]
image_type = doc["image_type"]

data_dict = {
"pred" : pred,
"task" : task,
"idx" : idx,
"image_relation" : image_relation,
"answer" : answer,
"image_type" : image_type,
}

return {"muirbench_score_overall" : data_dict}


def muir_aggregation(results):
task_num = {}
score = 0
task_score = {}
for result in results:
if result["task"] not in task_score:
task_score[result["task"]] = 0

if result["task"] not in task_num:
task_num[result["task"]] = 0

if result["pred"].lower().strip() == result["answer"].lower().strip():
task_score[result["task"]] += 1
score += 1
task_num[result["task"]] += 1

score = score / len(results)

task_score = {k : v / task_num[k] for k,v in task_score.items()}

print("=" * 50)
for k, v in task_score.items():
print(f"{k} : {v:.2f}")
print("=" * 50)

return score




class MultiChoiceRegexFilter(ExtendedRegexFilter):
def __init__(self, *args, **kwargs):
"""
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
group_select: Selects the (group_select)th match from the findall result.
ignore_case: Ignores the case during step 1 matching
ignore_punctuation: Remove the punctuation during step 1 matching
regexes_to_ignore: Remove these regexes during step 1 matching
"""
super().__init__(*args, **kwargs)

def apply(self, resps, docs):
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# independently (and keep them a list.)

filtered_resps = []

for r, doc in zip(resps, docs):
# Regex to directly extract the option letter from the model response
option_letter_regex = re.compile(r"^\s*([A-Z])\.")

# Process each response
filtered = []
for resp in r:
# Try to match the option letter at the start of the response
match = option_letter_regex.match(resp)
if match:
# If a match is found, append the matched letter
filtered.append(match.group(1))
else:
# If no match, return the original response
filtered.append(resp)

# Assuming we need the first response that matches or the original response
filtered_resps.append(filtered[0])

return filtered_resps

0 comments on commit 5fc5f2f

Please sign in to comment.