switch logic
tianyu-z committed Jun 10, 2024
1 parent e1f04db commit 043b483
Showing 23 changed files with 141 additions and 129 deletions.
2 changes: 1 addition & 1 deletion lmms_eval/tasks/ok_vqa/_default_template_vqa_yaml
@@ -13,7 +13,7 @@ metric_list:
ignore_case: true
ignore_punctuation: true
- metric: submission
aggregation: !function utils.ok_vqa_aggreate_submissions
aggregation: !function utils.ok_vqa_aggregate_submissions
higher_is_better: true
process_results: !function utils.ok_vqa_process_results
model_specific_prompt_kwargs:
10 changes: 7 additions & 3 deletions lmms_eval/tasks/ok_vqa/utils.py
@@ -19,7 +19,9 @@ def ok_vqa_doc_to_visual(doc):

def ok_vqa_process_results(doc, result):
eval_ai_processor = EvalAIAnswerProcessor()
assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}."
assert (
len(result) == 1
), f"The result should be a list of length 1, but got {len(result)}."
resAns = eval_ai_processor(result[0])
accuracy = 0

@@ -30,7 +32,9 @@ def ok_vqa_process_results(doc, result):
doc["answers"][i] = eval_ai_processor(doc["answers"][i])

for i in range(len(doc["answers"])):
otherGTAns = [doc["answers"][j] for j in range(len(doc["answers"])) if i != j]
otherGTAns = [
doc["answers"][j] for j in range(len(doc["answers"])) if i != j
]
matchingAns = [item for item in otherGTAns if item == resAns]
acc = min(1, float(len(matchingAns)) / 3)
gtAcc.append(acc)
@@ -61,7 +65,7 @@ def ok_vqa_doc_to_text(doc, model_specific_prompt_kwargs=None):
return f"{pre_prompt}{question}{post_prompt}"


def ok_vqa_aggreate_submissions(results, args):
def ok_vqa_aggregate_submissions(results, args):
now_date_time = datetime.datetime.now().strftime("%Y-%m%d-%H%M-%S")
file = f"ok_vqa-test-submission-{now_date_time}.json"
path = generate_submission_file(file, args)
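For context, the soft accuracy computed in ok_vqa_process_results (and textvqa_process_results below) is the standard VQA protocol: the prediction is scored against each annotator's answers in a leave-one-out fashion and earns min(1, matches / 3) per round. Below is a minimal standalone sketch of that metric with answer normalization (EvalAIAnswerProcessor) omitted; the helper name and sample values are illustrative, not part of the repository:

from statistics import mean

def vqa_soft_accuracy(prediction, gt_answers):
    # Leave-one-out VQA accuracy: for each annotator, count how many of the
    # remaining annotators gave the same answer as the prediction and cap the
    # credit at three matches, mirroring the loop in the hunk above.
    per_annotator = []
    for i in range(len(gt_answers)):
        others = [gt_answers[j] for j in range(len(gt_answers)) if j != i]
        matches = [a for a in others if a == prediction]
        per_annotator.append(min(1.0, len(matches) / 3))
    return mean(per_annotator)

# 3 of 10 annotators answered "red": rounds that drop a "red" see 2 matches
# (credit 2/3), the other 7 rounds see 3 matches (credit 1), so the mean is 0.9.
print(vqa_soft_accuracy("red", ["red"] * 3 + ["blue"] * 7))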
2 changes: 1 addition & 1 deletion lmms_eval/tasks/textvqa/textvqa_test.yaml
@@ -2,6 +2,6 @@ task: textvqa_test
test_split: test
metric_list:
- metric: submission
aggregation: !function utils.textvqa_aggreate_submissions
aggregation: !function utils.textvqa_aggregate_submissions
higher_is_better: true
include: _default_template_textvqa_yaml
2 changes: 1 addition & 1 deletion lmms_eval/tasks/textvqa/textvqa_val.yaml
@@ -7,6 +7,6 @@ metric_list:
ignore_case: true
ignore_punctuation: true
- metric: submission
aggregation: !function utils.textvqa_aggreate_submissions
aggregation: !function utils.textvqa_aggregate_submissions
higher_is_better: true
include: _default_template_textvqa_yaml
15 changes: 11 additions & 4 deletions lmms_eval/tasks/textvqa/utils.py
@@ -19,7 +19,9 @@ def textvqa_doc_to_visual(doc):

def textvqa_process_results(doc, result):
eval_ai_processor = EvalAIAnswerProcessor()
assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}."
assert (
len(result) == 1
), f"The result should be a list of length 1, but got {len(result)}."
resAns = eval_ai_processor(result[0])
accuracy = 0

@@ -30,7 +32,9 @@ def textvqa_process_results(doc, result):
doc["answers"][i] = eval_ai_processor(doc["answers"][i])

for i in range(len(doc["answers"])):
otherGTAns = [doc["answers"][j] for j in range(len(doc["answers"])) if i != j]
otherGTAns = [
doc["answers"][j] for j in range(len(doc["answers"])) if i != j
]
matchingAns = [item for item in otherGTAns if item == resAns]
acc = min(1, float(len(matchingAns)) / 3)
gtAcc.append(acc)
@@ -54,12 +58,15 @@ def textvqa_doc_to_text(doc, model_specific_prompt_kwargs=None):
pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
if "post_prompt" in model_specific_prompt_kwargs:
post_prompt = model_specific_prompt_kwargs["post_prompt"]
if "ocr" in model_specific_prompt_kwargs and model_specific_prompt_kwargs["ocr"]:
if (
"ocr" in model_specific_prompt_kwargs
and model_specific_prompt_kwargs["ocr"]
):
ocr_ref = f"\nReference OCR token: {', '.join(doc['ocr_tokens'])}"
return f"{pre_prompt}{doc['question'].capitalize()}{ocr_ref}{post_prompt}"


def textvqa_aggreate_submissions(results, args):
def textvqa_aggregate_submissions(results, args):
now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
path = generate_submission_file(f"textvqa_submission_{now_date_time}.json", args)
with open(path, "w") as f:
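For reference, textvqa_doc_to_text assembles the prompt from the pre_prompt, post_prompt, and ocr entries of model_specific_prompt_kwargs shown in the hunk above. A hedged illustration of the resulting string follows; the key names come from the diff, while the concrete question, OCR tokens, and prompt text are invented for the example:

doc = {"question": "what is written on the sign?", "ocr_tokens": ["STOP", "AHEAD"]}
kwargs = {
    "pre_prompt": "",
    "post_prompt": "\nAnswer the question using a single word or phrase.",
    "ocr": True,
}

pre = kwargs.get("pre_prompt", "")
post = kwargs.get("post_prompt", "")
ocr_ref = ""
if kwargs.get("ocr"):
    ocr_ref = f"\nReference OCR token: {', '.join(doc['ocr_tokens'])}"

# Mirrors the return statement in the diff above.
print(f"{pre}{doc['question'].capitalize()}{ocr_ref}{post}")
# What is written on the sign?
# Reference OCR token: STOP, AHEAD
# Answer the question using a single word or phrase.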
137 changes: 64 additions & 73 deletions lmms_eval/tasks/vcr_wiki/utils.py
@@ -6,6 +6,9 @@
from spacy.cli import download
from nltk.util import ngrams
from functools import partial
import datetime
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
import json

# Download the English and Chinese models
download("en_core_web_sm")
@@ -46,7 +49,7 @@ def fast_filter(answer_text):


def vcr_doc_to_visual(doc):
return [doc["stacked_image"].convert("RGB"), doc["only_it_image"].convert("RGB")]
return [doc["stacked_image"].convert("RGB")]


def vcr_doc_to_text(doc, model_specific_prompt_kwargs=None):
@@ -80,7 +83,7 @@ def vcr_process_results_single(crossed_text, result, language):
doc: a instance of the eval dataset
results: [pred]
Returns:
a dictionary with key: metric name (in this case mme score), value: metric value
a dictionary with key: metric name (in this case vcr score), value: metric value
"""

assert language in ["en", "zh"], f"Language {language} is not supported."
@@ -171,29 +174,28 @@ def vcr_en_process_results(doc, results):
doc: a instance of the eval dataset
results: [pred]
Returns:
a dictionary with key: metric name (in this case mme score), value: metric value
a dictionary with key: metric name (in this case vcr score), value: metric value
"""
assert len(results) == 2, f"Expected 2 results, got {len(results)}"
output = {}
for i in range(len(doc["crossed_text"])):
res_stacked_image_results = vcr_process_results_single(
doc["crossed_text"][i], results[0], "en"
)
res_only_image_results = vcr_process_results_single(
doc["crossed_text"][i], results[1], "en"
)
output.update(
{
f"res_stacked_image__{k}___{i}": v
for k, v in res_stacked_image_results.items()
}
)
output.update(
{
f"res_only_it_image__{k}___{i}": v
for k, v in res_only_image_results.items()
}
)
output = {
"max_sim_val": [],
"precision": [],
"recall": [],
"f1": [],
"jaccard": [],
"rouge1": [],
"exact_match": [],
}
crossed_text = doc["crossed_text"]
for i in range(len(crossed_text)):
tmp = vcr_process_results_single(crossed_text[i], results, "en")
for k in output.keys():
output[k].append(
{
"score": tmp[k],
"max_sim_string": tmp["max_sim_string"],
"caption": doc["caption"],
}
)
return output


@@ -203,62 +205,51 @@ def vcr_zh_process_results(doc, results):
doc: a instance of the eval dataset
results: [pred]
Returns:
a dictionary with key: metric name (in this case mme score), value: metric value
a dictionary with key: metric name (in this case vcr score), value: metric value
"""
assert len(results) == 2, f"Expected 2 results, got {len(results)}"
output = {}
for i in range(len(doc["crossed_text"])):
res_stacked_image_results = vcr_process_results_single(
doc["crossed_text"][i], results[0], "zh"
)
res_only_image_results = vcr_process_results_single(
doc["crossed_text"][i], results[1], "zh"
)
output.update(
{
f"res_stacked_image__{k}___{i}": v
for k, v in res_stacked_image_results.items()
}
)
output.update(
{
f"res_only_it_image__{k}___{i}": v
for k, v in res_only_image_results.items()
}
)
output = {
"max_sim_val": [],
"precision": [],
"recall": [],
"f1": [],
"jaccard": [],
"rouge1": [],
"exact_match": [],
}
crossed_text = doc["crossed_text"]
for i in range(len(crossed_text)):
tmp = vcr_process_results_single(crossed_text[i], results, "zh")
for k in output.keys():
output[k].append(
{
"score": tmp[k],
"max_sim_string": tmp["max_sim_string"],
"caption": doc["caption"],
}
)
return output


def vcr_aggregate_results(results):
def vcr_aggregate_results(results, args):
"""
Args:
results: a list of values returned by process_results
Returns:
A dictionary of dictionary of float, where the outer dictionary has keys "res_stacked_image" and "res_only_it_image"
"""
output = {
"res_stacked_image__precision": 0,
"res_stacked_image__recall": 0,
"res_stacked_image__f1": 0,
"res_stacked_image__jaccard": 0,
"res_stacked_image__rouge1": 0,
"res_stacked_image__exact_match": 0,
"res_only_it_image__precision": 0,
"res_only_it_image__recall": 0,
"res_only_it_image__f1": 0,
"res_only_it_image__jaccard": 0,
"res_only_it_image__rouge1": 0,
"res_only_it_image__exact_match": 0,
}

for output_key in output.keys():
count = 0
query_domain, query_metric_name = output_key.split("__")
for inner_dict in results:
for inner_key, inner_value in inner_dict.items():
key_domain, key_metric_name, _ = inner_key.split("__")
if query_domain == key_domain and query_metric_name == key_metric_name:
output[output_key] += inner_value
count += 1
output[output_key] /= count
return output
scores = 0
count = 0
output_dict = {}
for i in range(len(results)):
for blank_id in range(len(results[i])):
scores += results[i][blank_id]["score"]
count += 1
output_dict[str(i)] = results[i]

now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
path = generate_submission_file(f"vcr_submission_{now_date_time}.json", args)
with open(path, "w") as f:
json.dump(output_dict, f)
# print(f"Submission file saved to {path}")
eval_logger.info(f"Submission file saved to {path}")
return scores / count
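With this change, each metric key returned by vcr_en_process_results / vcr_zh_process_results maps to a list with one {"score", "max_sim_string", "caption"} dict per crossed-out blank, and vcr_aggregate_results averages the score fields over every blank of every document before dumping the raw entries to a timestamped submission file. A small sketch of that averaging on fabricated inputs (only the dict keys are taken from the diff):

# Two fabricated documents for one metric, e.g. "jaccard"; the first has one
# blank, the second has two. In the harness these lists come from
# vcr_en_process_results / vcr_zh_process_results above.
results = [
    [{"score": 0.8, "max_sim_string": "city", "caption": "A view of the city."}],
    [
        {"score": 0.5, "max_sim_string": "bridge", "caption": "The old bridge at dusk."},
        {"score": 1.0, "max_sim_string": "river", "caption": "The bridge over the river."},
    ],
]

total, count = 0.0, 0
for per_doc in results:      # one entry per evaluated document
    for blank in per_doc:    # one entry per crossed-out blank
        total += blank["score"]
        count += 1

print(total / count)  # 0.7666..., the mean score vcr_aggregate_results returns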
6 changes: 3 additions & 3 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_en_easy_100.yaml
@@ -1,13 +1,13 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-en-easy-test
task: "vcr_wiki_en_easy"
task: "vcr_wiki_en_easy_100"
test_split: train[:100]
process_results: !function utils.vcr_en_process_results
metric_list:
- metric: vcr_percetion_score
- metric: jaccard
aggregation: !function utils.vcr_en_process_results
higher_is_better: true
- metric: vcr_cognition_score
- metric: exact_match
aggregation: !function utils.vcr_en_process_results
higher_is_better: true
model_specific_prompt_kwargs:
6 changes: 3 additions & 3 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_en_easy_500.yaml
@@ -1,13 +1,13 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-en-easy-test
task: "vcr_wiki_en_easy"
task: "vcr_wiki_en_easy_500"
test_split: train[:500]
process_results: !function utils.vcr_en_process_results
metric_list:
- metric: vcr_percetion_score
- metric: jaccard
aggregation: !function utils.vcr_en_process_results
higher_is_better: true
- metric: vcr_cognition_score
- metric: exact_match
aggregation: !function utils.vcr_en_process_results
higher_is_better: true
model_specific_prompt_kwargs:
6 changes: 3 additions & 3 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_en_easy_5000.yaml
@@ -1,13 +1,13 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-en-easy-test
task: "vcr_wiki_en_easy"
task: "vcr_wiki_en_easy_5000"
test_split: train
process_results: !function utils.vcr_en_process_results
metric_list:
- metric: vcr_percetion_score
- metric: jaccard
aggregation: !function utils.vcr_en_process_results
higher_is_better: true
- metric: vcr_cognition_score
- metric: exact_match
aggregation: !function utils.vcr_en_process_results
higher_is_better: true
model_specific_prompt_kwargs:
6 changes: 3 additions & 3 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_en_hard_100.yaml
@@ -1,13 +1,13 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-en-hard-test
task: "vcr_wiki_en_hard"
task: "vcr_wiki_en_hard_100"
test_split: train[:100]
process_results: !function utils.vcr_en_process_results
metric_list:
- metric: vcr_percetion_score
- metric: jaccard
aggregation: !function utils.vcr_en_process_results
higher_is_better: true
- metric: vcr_cognition_score
- metric: exact_match
aggregation: !function utils.vcr_en_process_results
higher_is_better: true
model_specific_prompt_kwargs:
6 changes: 3 additions & 3 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_en_hard_500.yaml
@@ -1,13 +1,13 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-en-hard-test
task: "vcr_wiki_en_hard"
task: "vcr_wiki_en_hard_500"
test_split: train[:500]
process_results: !function utils.vcr_en_process_results
metric_list:
- metric: vcr_percetion_score
- metric: jaccard
aggregation: !function utils.vcr_en_process_results
higher_is_better: true
- metric: vcr_cognition_score
- metric: exact_match
aggregation: !function utils.vcr_en_process_results
higher_is_better: true
model_specific_prompt_kwargs:
6 changes: 3 additions & 3 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_en_hard_5000.yaml
@@ -1,13 +1,13 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-en-hard-test
task: "vcr_wiki_en_hard"
task: "vcr_wiki_en_hard_5000"
test_split: train
process_results: !function utils.vcr_en_process_results
metric_list:
- metric: vcr_percetion_score
- metric: jaccard
aggregation: !function utils.vcr_en_process_results
higher_is_better: true
- metric: vcr_cognition_score
- metric: exact_match
aggregation: !function utils.vcr_en_process_results
higher_is_better: true
model_specific_prompt_kwargs:
6 changes: 3 additions & 3 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_zh_easy_100.yaml
@@ -1,13 +1,13 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-zh-easy-test
task: "vcr_wiki_zh_easy"
task: "vcr_wiki_zh_easy_100"
test_split: train[:100]
process_results: !function utils.vcr_zh_process_results
metric_list:
- metric: vcr_percetion_score
- metric: jaccard
aggregation: !function utils.vcr_zh_process_results
higher_is_better: true
- metric: vcr_cognition_score
- metric: exact_match
aggregation: !function utils.vcr_zh_process_results
higher_is_better: true
model_specific_prompt_kwargs:
