Commit
modify the form of VCR
tianyu-z committed Jun 10, 2024
1 parent 96e8d98 commit e1f04db
Showing 20 changed files with 307 additions and 199 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -275,6 +275,11 @@ We also provide the raw data exported from Weights & Biases for the detailed res
- TextVQA (textvqa)
- TextVQA Validation (textvqa_val)
- TextVQA Test (textvqa_test)
- VCR-Wiki (vcr_wiki)
- VCR-Wiki English easy mode (vcr_wiki_en_easy)
- VCR-Wiki English hard mode (vcr_wiki_en_hard)
- VCR-Wiki Chinese easy mode (vcr_wiki_zh_easy)
- VCR-Wiki Chinese hard mode (vcr_wiki_zh_hard)
- VizWizVQA (vizwiz_vqa)
- VizWizVQA Validation (vizwiz_vqa_val)
- VizWizVQA Test (vizwiz_vqa_test)
31 changes: 0 additions & 31 deletions lmms_eval/tasks/vcr/_default_template_vcr_yaml

This file was deleted.

31 changes: 0 additions & 31 deletions lmms_eval/tasks/vcr/vcr_wiki_en_easy.yaml

This file was deleted.

31 changes: 0 additions & 31 deletions lmms_eval/tasks/vcr/vcr_wiki_en_hard.yaml

This file was deleted.

31 changes: 0 additions & 31 deletions lmms_eval/tasks/vcr/vcr_wiki_zh_easy.yaml

This file was deleted.

31 changes: 0 additions & 31 deletions lmms_eval/tasks/vcr/vcr_wiki_zh_hard.yaml

This file was deleted.

17 changes: 17 additions & 0 deletions lmms_eval/tasks/vcr_wiki/_default_template_vcr_yaml
@@ -0,0 +1,17 @@

dataset_kwargs:
token: True
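  # token: True is forwarded to the dataset loader (datasets.load_dataset); it means "use the locally saved Hugging Face token"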
output_type: generate_until
doc_to_visual: !function utils.vcr_doc_to_visual
doc_to_text: !function utils.vcr_doc_to_text
doc_to_target: "answer"
generation_kwargs:
max_new_tokens: 120
temperature: 0
top_p: 0
num_beams: 1
do_sample: false
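  # with do_sample: false and num_beams: 1 decoding is deterministic greedy search; temperature/top_p are then typically ignored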
# The return value of process_results will be used by metrics
# Note that the metric name can be either a registered metric function (as in the case of GQA) or a key name returned by process_results
metadata:
- version: 0.0.1
137 changes: 93 additions & 44 deletions lmms_eval/tasks/vcr/utils.py → lmms_eval/tasks/vcr_wiki/utils.py
@@ -1,9 +1,5 @@
from collections import defaultdict
import os
from difflib import SequenceMatcher as SM
import datetime
import json
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
import evaluate
import logging
import spacy
@@ -34,6 +30,21 @@
}


def fast_filter(answer_text):
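    # Flag refusal or apology responses (English and Chinese); the caller scores these as all-zero metrics instead of string-matching them.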
if "I can't" in answer_text:
return True
elif "I cannot" in answer_text:
return True
elif "sorry" in answer_text.lower():
return True
if "无法" in answer_text:
return True
elif "抱歉" in answer_text:
return True
else:
return False


def vcr_doc_to_visual(doc):
return [doc["stacked_image"].convert("RGB"), doc["only_it_image"].convert("RGB")]

@@ -63,16 +74,29 @@ def tokenize(text, language):
return [token.text for token in processed_text]


def vcr_process_results_single(doc, result, language):
def vcr_process_results_single(crossed_text, result, language):
"""
Args:
doc: a instance of the eval dataset
results: [pred]
Returns:
a dictionary with key: metric name (in this case mme score), value: metric value
"""

assert language in ["en", "zh"], f"Language {language} is not supported."
crossed_text = doc["crossed_text"]

if fast_filter(result):
return {
"crossed_text": crossed_text,
"max_sim_val": 0,
"max_sim_string": "",
"precision": 0,
"recall": 0,
"f1": 0,
"jaccard": 0,
"rouge1": 0,
"exact_match": 0,
}
tokens_result = tokenize(result, language)
tokens_crossed_text = tokenize(crossed_text, language)

@@ -150,10 +174,26 @@ def vcr_en_process_results(doc, results):
a dictionary with key: metric name, value: metric value
"""
assert len(results) == 2, f"Expected 2 results, got {len(results)}"
output = {
"res_stacked_image": vcr_process_results_single(doc, results[0], "en"),
"res_only_it_image": vcr_process_results_single(doc, results[1], "en"),
}
output = {}
for i in range(len(doc["crossed_text"])):
res_stacked_image_results = vcr_process_results_single(
doc["crossed_text"][i], results[0], "en"
)
res_only_image_results = vcr_process_results_single(
doc["crossed_text"][i], results[1], "en"
)
output.update(
{
f"res_stacked_image__{k}___{i}": v
for k, v in res_stacked_image_results.items()
}
)
output.update(
{
f"res_only_it_image__{k}___{i}": v
for k, v in res_only_image_results.items()
}
)
return output


@@ -166,10 +206,26 @@ def vcr_zh_process_results(doc, results):
a dictionary with key: metric name, value: metric value
"""
assert len(results) == 2, f"Expected 2 results, got {len(results)}"
output = {
"res_stacked_image": vcr_process_results_single(doc, results[0], "zh"),
"res_only_it_image": vcr_process_results_single(doc, results[1], "zh"),
}
output = {}
for i in range(len(doc["crossed_text"])):
res_stacked_image_results = vcr_process_results_single(
doc["crossed_text"][i], results[0], "zh"
)
res_only_image_results = vcr_process_results_single(
doc["crossed_text"][i], results[1], "zh"
)
output.update(
{
f"res_stacked_image__{k}___{i}": v
for k, v in res_stacked_image_results.items()
}
)
output.update(
{
f"res_only_it_image__{k}___{i}": v
for k, v in res_only_image_results.items()
}
)
return output


@@ -180,36 +236,29 @@ def vcr_aggregate_results(results):
Returns:
A dictionary of float, with keys of the form "res_stacked_image__<metric>" and "res_only_it_image__<metric>"
"""

output = {
"res_stacked_image": {
"max_sim_val": 0,
"precision": 0,
"recall": 0,
"f1": 0,
"jaccard": 0,
"rouge1": 0,
},
"res_only_it_image": {
"max_sim_val": 0,
"precision": 0,
"recall": 0,
"f1": 0,
"jaccard": 0,
"rouge1": 0,
},
"res_stacked_image__precision": 0,
"res_stacked_image__recall": 0,
"res_stacked_image__f1": 0,
"res_stacked_image__jaccard": 0,
"res_stacked_image__rouge1": 0,
"res_stacked_image__exact_match": 0,
"res_only_it_image__precision": 0,
"res_only_it_image__recall": 0,
"res_only_it_image__f1": 0,
"res_only_it_image__jaccard": 0,
"res_only_it_image__rouge1": 0,
"res_only_it_image__exact_match": 0,
}
for target_domain in output.keys():
for target_metric_name in output[target_domain].keys():
score = 0
count = 0
for inner_dict in results:
for inner_key, inner_value in inner_dict.items():
if inner_key == target_domain:
for blank_id, blank_metrics in inner_value.items():
for metric_name, metric_value in blank_metrics.items():
if metric_name == target_metric_name:
score += metric_value
count += 1
output[target_domain][target_metric_name] = score / count

for output_key in output.keys():
count = 0
query_domain, query_metric_name = output_key.split("__")
for inner_dict in results:
for inner_key, inner_value in inner_dict.items():
key_domain, key_metric_name, _ = inner_key.split("__")
if query_domain == key_domain and query_metric_name == key_metric_name:
output[output_key] += inner_value
count += 1
output[output_key] /= count
return output
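
The flat keys produced by vcr_en_process_results / vcr_zh_process_results pack the image domain, the metric name, and the blank index together with double and triple underscores, and vcr_aggregate_results later recovers the domain and metric with split("__"). A minimal standalone sketch of that round trip (the metric values are invented for illustration and are not part of this commit):

# Hypothetical per-doc output, shaped like what vcr_en_process_results emits for a
# doc with two crossed-out blanks (values are made up for illustration).
doc_result = {
    "res_stacked_image__precision___0": 1.0,
    "res_stacked_image__precision___1": 0.5,
    "res_only_it_image__precision___0": 0.0,
    "res_only_it_image__precision___1": 0.5,
}

# vcr_aggregate_results splits each key on "__":
# "res_stacked_image__precision___0" -> ["res_stacked_image", "precision", "_0"],
# then averages every value whose domain and metric match an output key.
totals = {"res_stacked_image__precision": 0.0, "res_only_it_image__precision": 0.0}
counts = {key: 0 for key in totals}
for inner_key, inner_value in doc_result.items():
    key_domain, key_metric_name, _blank_id = inner_key.split("__")
    out_key = f"{key_domain}__{key_metric_name}"
    if out_key in totals:
        totals[out_key] += inner_value
        counts[out_key] += 1
averages = {key: totals[key] / counts[key] for key in totals}
print(averages)  # {'res_stacked_image__precision': 0.75, 'res_only_it_image__precision': 0.25}

In the actual aggregation this averaging runs over every doc in the results list, not just a single one.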
16 changes: 16 additions & 0 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_en_easy_100.yaml
@@ -0,0 +1,16 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-en-easy-test
task: "vcr_wiki_en_easy"
test_split: train[:100]
process_results: !function utils.vcr_en_process_results
metric_list:
- metric: vcr_percetion_score
aggregation: !function utils.vcr_en_process_results
higher_is_better: true
- metric: vcr_cognition_score
aggregation: !function utils.vcr_en_process_results
higher_is_better: true
model_specific_prompt_kwargs:
default:
pre_prompt: ""
post_prompt: "What is the covered texts in the image? Please restore the covered texts without outputting the explanations."