
Commit

Merge pull request #105 from tianyu-z/main
Include VCR
Luodian authored Jun 12, 2024
2 parents 5ed0035 + 0ce46d0 commit 44a3379
Showing 25 changed files with 553 additions and 21 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -227,7 +227,7 @@ During the `v0.1` to `v0.2`, we thank the community support from pull requests (
**Datasets:**

-- VCR: Vision_Caption_Restoration (officially from the authors, MILA)
+- VCR: Visual Caption Restoration (officially from the authors, MILA)
- ConBench (officially from the authors, PKU/Bytedance)
- MathVerse (officially from the authors, CUHK)
- MM-UPD (officially from the authors, University of Tokyo)
2 changes: 1 addition & 1 deletion lmms_eval/tasks/ok_vqa/_default_template_vqa_yaml
@@ -13,7 +13,7 @@ metric_list:
ignore_case: true
ignore_punctuation: true
- metric: submission
-    aggregation: !function utils.ok_vqa_aggreate_submissions
+    aggregation: !function utils.ok_vqa_aggregate_submissions
higher_is_better: true
process_results: !function utils.ok_vqa_process_results
model_specific_prompt_kwargs:
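The `!function` tag resolves a dotted name to a Python callable when the task YAML is loaded, so this rename only works together with the matching rename in `utils.py` below. A minimal sketch of how such a tag can be implemented with PyYAML (an illustrative assumption, not necessarily lmms_eval's actual loader):

```python
import importlib

import yaml


def _function_constructor(loader, node):
    # Split "module.attr" and resolve the attribute on the imported module.
    module_name, attr = loader.construct_scalar(node).rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr)


yaml.SafeLoader.add_constructor("!function", _function_constructor)

# "json.dumps" stands in for "utils.ok_vqa_aggregate_submissions" here.
config = yaml.load("aggregation: !function json.dumps", Loader=yaml.SafeLoader)
assert callable(config["aggregation"])
```

Under this scheme, a YAML file still pointing at the old `ok_vqa_aggreate_submissions` spelling would raise an `AttributeError` at load time, which is why both sides of the rename land in the same commit.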
10 changes: 7 additions & 3 deletions lmms_eval/tasks/ok_vqa/utils.py
@@ -19,7 +19,9 @@ def ok_vqa_doc_to_visual(doc):

def ok_vqa_process_results(doc, result):
eval_ai_processor = EvalAIAnswerProcessor()
assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}."
assert (
len(result) == 1
), f"The result should be a list of length 1, but got {len(result)}."
resAns = eval_ai_processor(result[0])
accuracy = 0

@@ -30,7 +32,9 @@ def ok_vqa_process_results(doc, result):
doc["answers"][i] = eval_ai_processor(doc["answers"][i])

for i in range(len(doc["answers"])):
-        otherGTAns = [doc["answers"][j] for j in range(len(doc["answers"])) if i != j]
+        otherGTAns = [
+            doc["answers"][j] for j in range(len(doc["answers"])) if i != j
+        ]
matchingAns = [item for item in otherGTAns if item == resAns]
acc = min(1, float(len(matchingAns)) / 3)
gtAcc.append(acc)
@@ -61,7 +65,7 @@ def ok_vqa_doc_to_text(doc, model_specific_prompt_kwargs=None):
return f"{pre_prompt}{question}{post_prompt}"


-def ok_vqa_aggreate_submissions(results, args):
+def ok_vqa_aggregate_submissions(results, args):
now_date_time = datetime.datetime.now().strftime("%Y-%m%d-%H%M-%S")
file = f"ok_vqa-test-submission-{now_date_time}.json"
path = generate_submission_file(file, args)
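For context, `ok_vqa_process_results` implements the standard VQA accuracy metric: each ground-truth answer is held out in turn, the prediction is matched against the remaining annotators' answers, and each fold scores `min(1, matches / 3)`, so an answer given by at least three humans earns full credit. A self-contained sketch of the same computation (the function name is illustrative, and the `EvalAIAnswerProcessor` normalization step is skipped for brevity):

```python
def vqa_accuracy(prediction: str, gt_answers: list[str]) -> float:
    # Hold each annotator out in turn and score the prediction against the
    # rest; three or more matching answers count as full credit.
    per_fold = []
    for i in range(len(gt_answers)):
        others = [a for j, a in enumerate(gt_answers) if j != i]
        matches = sum(1 for a in others if a == prediction)
        per_fold.append(min(1.0, matches / 3))
    return sum(per_fold) / len(per_fold)


# 3 of 10 annotators answered "cat": the three folds that hold out a "cat"
# see 2 matches (score 2/3), the other seven see all 3 (score 1.0), so the
# accuracy is (3 * 2/3 + 7 * 1.0) / 10 = 0.9.
print(vqa_accuracy("cat", ["cat"] * 3 + ["dog"] * 7))
```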
2 changes: 1 addition & 1 deletion lmms_eval/tasks/textvqa/textvqa_test.yaml
@@ -2,6 +2,6 @@ task: textvqa_test
test_split: test
metric_list:
- metric: submission
-    aggregation: !function utils.textvqa_aggreate_submissions
+    aggregation: !function utils.textvqa_aggregate_submissions
higher_is_better: true
include: _default_template_textvqa_yaml
2 changes: 1 addition & 1 deletion lmms_eval/tasks/textvqa/textvqa_val.yaml
@@ -7,6 +7,6 @@ metric_list:
ignore_case: true
ignore_punctuation: true
- metric: submission
-    aggregation: !function utils.textvqa_aggreate_submissions
+    aggregation: !function utils.textvqa_aggregate_submissions
higher_is_better: true
include: _default_template_textvqa_yaml
15 changes: 11 additions & 4 deletions lmms_eval/tasks/textvqa/utils.py
@@ -19,7 +19,9 @@ def textvqa_doc_to_visual(doc):

def textvqa_process_results(doc, result):
eval_ai_processor = EvalAIAnswerProcessor()
assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}."
assert (
len(result) == 1
), f"The result should be a list of length 1, but got {len(result)}."
resAns = eval_ai_processor(result[0])
accuracy = 0

@@ -30,7 +32,9 @@ def textvqa_process_results(doc, result):
doc["answers"][i] = eval_ai_processor(doc["answers"][i])

for i in range(len(doc["answers"])):
-        otherGTAns = [doc["answers"][j] for j in range(len(doc["answers"])) if i != j]
+        otherGTAns = [
+            doc["answers"][j] for j in range(len(doc["answers"])) if i != j
+        ]
matchingAns = [item for item in otherGTAns if item == resAns]
acc = min(1, float(len(matchingAns)) / 3)
gtAcc.append(acc)
@@ -54,12 +58,15 @@ def textvqa_doc_to_text(doc, model_specific_prompt_kwargs=None):
pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
if "post_prompt" in model_specific_prompt_kwargs:
post_prompt = model_specific_prompt_kwargs["post_prompt"]
if "ocr" in model_specific_prompt_kwargs and model_specific_prompt_kwargs["ocr"]:
if (
"ocr" in model_specific_prompt_kwargs
and model_specific_prompt_kwargs["ocr"]
):
ocr_ref = f"\nReference OCR token: {', '.join(doc['ocr_tokens'])}"
return f"{pre_prompt}{doc['question'].capitalize()}{ocr_ref}{post_prompt}"


-def textvqa_aggreate_submissions(results, args):
+def textvqa_aggregate_submissions(results, args):
now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
path = generate_submission_file(f"textvqa_submission_{now_date_time}.json", args)
with open(path, "w") as f:
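The reflowed `if` above is cosmetic; behaviorally, when a model's prompt kwargs set `ocr: true`, `textvqa_doc_to_text` appends a reference line built from the document's OCR tokens between the question and the post-prompt. A worked example (the document fields and post-prompt string are made up for illustration):

```python
doc = {"question": "what brand is the laptop?", "ocr_tokens": ["lenovo", "ThinkPad"]}
pre_prompt = ""
post_prompt = "\nAnswer the question using a single word or phrase."

ocr_ref = f"\nReference OCR token: {', '.join(doc['ocr_tokens'])}"
print(f"{pre_prompt}{doc['question'].capitalize()}{ocr_ref}{post_prompt}")
# What brand is the laptop?
# Reference OCR token: lenovo, ThinkPad
# Answer the question using a single word or phrase.
```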
17 changes: 17 additions & 0 deletions lmms_eval/tasks/vcr_wiki/_default_template_vcr_yaml
@@ -0,0 +1,17 @@

dataset_kwargs:
token: True
output_type: generate_until
doc_to_visual: !function utils.vcr_doc_to_visual
doc_to_text: !function utils.vcr_doc_to_text
doc_to_target: "answer"
generation_kwargs:
max_new_tokens: 120
temperature: 0
top_p: 0
num_beams: 1
do_sample: false
# The return value of process_results will be used by metrics
# Note that the metric name can be either a registered metric function (such as the case for GQA) or a key name returned by process_results
metadata:
- version: 0.0.1
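The comment in this template describes the harness's metric contract: when a metric name is not a registered metric function, `process_results` must return a dict whose keys match the metric names listed in the task's `metric_list`, and each value is routed to that metric's aggregation. A hypothetical sketch of that shape (names and scoring are illustrative, not the actual vcr_wiki implementation):

```python
def vcr_process_results(doc, results):
    # generate_until tasks receive a list containing one generated string.
    prediction = results[0].strip()
    # Each key below must match a metric name declared in the task yaml.
    return {"exact_match": float(prediction == doc["answer"].strip())}
```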