forked from EvolvingLMMs-Lab/lmms-eval
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* vqav2 * Add vqav2_process_results function and update vqav2_doc_to_text function * Implement vqav2_process_results function to return exact match score * Refactor fewshot_docs() to use config.fewshot_config * Refactor Task class to handle fewshot_docs when training and validation docs are not available * Add answer processing logic in vqav2_process_results function * Refactor vqav2_process_results function and add submission aggregation * Add vqav2_aggreate_submissions function to utils.py * textvqa * Refactor answer processing in textvqa_process_results() function * textvqa eval * Update dataset path and modify textvqa_doc_to_text function * Capitalize the question in textvqa_doc_to_text function * Update textvqa.yaml and utils.py * Fix formatting issues in lmms_eval/api/task.py, lmms_eval/tasks/gqa/utils.py, lmms_eval/tasks/textvqa/utils.py, and lmms_eval/tasks/vqav2/utils.py --------- Co-authored-by: Li Bo <drluodian@gmail.com>
- Loading branch information
Showing
6 changed files
with
599 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,6 @@ | |
GQA_ID2IMAGE = None | ||
|
||
|
||
|
||
def gqa_doc_to_visual(doc): | ||
global GQA_RAW_IMAGE_DATASET | ||
global GQA_ID2IMAGE | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
task: textvqa | ||
dataset_path: lmms-lab/textvqa | ||
test_split: test | ||
output_type: generate_until | ||
doc_to_visual: !function utils.textvqa_doc_to_visual | ||
doc_to_text: !function utils.textvqa_doc_to_text | ||
doc_to_target: "answer" | ||
generation_kwargs: | ||
until: | ||
- "ASSISTANT:" | ||
metric_list: | ||
- metric: exact_match | ||
aggregation: mean | ||
higher_is_better: true | ||
ignore_case: true | ||
ignore_punctuation: true | ||
- metric: submission | ||
aggregation: !function utils.textvqa_aggreate_submissions | ||
higher_is_better: true | ||
metadata: | ||
- version: 0.0 | ||
- have_ocr_reference: true | ||
process_results: !function utils.textvqa_process_results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,279 @@ | ||
import re | ||
import os | ||
import json | ||
import yaml | ||
import pathlib | ||
import logging | ||
import datetime | ||
import statistics | ||
|
||
eval_logger = logging.getLogger("lmms-eval") | ||
|
||
with open(pathlib.Path(__file__).parent / "textvqa.yaml", "r") as f: | ||
raw_data = f.readlines() | ||
for i in range(len(raw_data)): | ||
raw_data[i] = raw_data[i].replace("!function", "function") | ||
|
||
config = yaml.safe_load("".join(raw_data)) | ||
|
||
|
||
class EvalAIAnswerProcessor: | ||
""" | ||
Processes an answer similar to Eval AI | ||
copied from | ||
https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897 | ||
""" | ||
|
||
CONTRACTIONS = { | ||
"aint": "ain't", | ||
"arent": "aren't", | ||
"cant": "can't", | ||
"couldve": "could've", | ||
"couldnt": "couldn't", | ||
"couldn'tve": "couldn't've", | ||
"couldnt've": "couldn't've", | ||
"didnt": "didn't", | ||
"doesnt": "doesn't", | ||
"dont": "don't", | ||
"hadnt": "hadn't", | ||
"hadnt've": "hadn't've", | ||
"hadn'tve": "hadn't've", | ||
"hasnt": "hasn't", | ||
"havent": "haven't", | ||
"hed": "he'd", | ||
"hed've": "he'd've", | ||
"he'dve": "he'd've", | ||
"hes": "he's", | ||
"howd": "how'd", | ||
"howll": "how'll", | ||
"hows": "how's", | ||
"Id've": "I'd've", | ||
"I'dve": "I'd've", | ||
"Im": "I'm", | ||
"Ive": "I've", | ||
"isnt": "isn't", | ||
"itd": "it'd", | ||
"itd've": "it'd've", | ||
"it'dve": "it'd've", | ||
"itll": "it'll", | ||
"let's": "let's", | ||
"maam": "ma'am", | ||
"mightnt": "mightn't", | ||
"mightnt've": "mightn't've", | ||
"mightn'tve": "mightn't've", | ||
"mightve": "might've", | ||
"mustnt": "mustn't", | ||
"mustve": "must've", | ||
"neednt": "needn't", | ||
"notve": "not've", | ||
"oclock": "o'clock", | ||
"oughtnt": "oughtn't", | ||
"ow's'at": "'ow's'at", | ||
"'ows'at": "'ow's'at", | ||
"'ow'sat": "'ow's'at", | ||
"shant": "shan't", | ||
"shed've": "she'd've", | ||
"she'dve": "she'd've", | ||
"she's": "she's", | ||
"shouldve": "should've", | ||
"shouldnt": "shouldn't", | ||
"shouldnt've": "shouldn't've", | ||
"shouldn'tve": "shouldn't've", | ||
"somebody'd": "somebodyd", | ||
"somebodyd've": "somebody'd've", | ||
"somebody'dve": "somebody'd've", | ||
"somebodyll": "somebody'll", | ||
"somebodys": "somebody's", | ||
"someoned": "someone'd", | ||
"someoned've": "someone'd've", | ||
"someone'dve": "someone'd've", | ||
"someonell": "someone'll", | ||
"someones": "someone's", | ||
"somethingd": "something'd", | ||
"somethingd've": "something'd've", | ||
"something'dve": "something'd've", | ||
"somethingll": "something'll", | ||
"thats": "that's", | ||
"thered": "there'd", | ||
"thered've": "there'd've", | ||
"there'dve": "there'd've", | ||
"therere": "there're", | ||
"theres": "there's", | ||
"theyd": "they'd", | ||
"theyd've": "they'd've", | ||
"they'dve": "they'd've", | ||
"theyll": "they'll", | ||
"theyre": "they're", | ||
"theyve": "they've", | ||
"twas": "'twas", | ||
"wasnt": "wasn't", | ||
"wed've": "we'd've", | ||
"we'dve": "we'd've", | ||
"weve": "we've", | ||
"werent": "weren't", | ||
"whatll": "what'll", | ||
"whatre": "what're", | ||
"whats": "what's", | ||
"whatve": "what've", | ||
"whens": "when's", | ||
"whered": "where'd", | ||
"wheres": "where's", | ||
"whereve": "where've", | ||
"whod": "who'd", | ||
"whod've": "who'd've", | ||
"who'dve": "who'd've", | ||
"wholl": "who'll", | ||
"whos": "who's", | ||
"whove": "who've", | ||
"whyll": "why'll", | ||
"whyre": "why're", | ||
"whys": "why's", | ||
"wont": "won't", | ||
"wouldve": "would've", | ||
"wouldnt": "wouldn't", | ||
"wouldnt've": "wouldn't've", | ||
"wouldn'tve": "wouldn't've", | ||
"yall": "y'all", | ||
"yall'll": "y'all'll", | ||
"y'allll": "y'all'll", | ||
"yall'd've": "y'all'd've", | ||
"y'alld've": "y'all'd've", | ||
"y'all'dve": "y'all'd've", | ||
"youd": "you'd", | ||
"youd've": "you'd've", | ||
"you'dve": "you'd've", | ||
"youll": "you'll", | ||
"youre": "you're", | ||
"youve": "you've", | ||
} | ||
|
||
NUMBER_MAP = { | ||
"none": "0", | ||
"zero": "0", | ||
"one": "1", | ||
"two": "2", | ||
"three": "3", | ||
"four": "4", | ||
"five": "5", | ||
"six": "6", | ||
"seven": "7", | ||
"eight": "8", | ||
"nine": "9", | ||
"ten": "10", | ||
} | ||
ARTICLES = ["a", "an", "the"] | ||
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)") | ||
COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)") | ||
PUNCTUATIONS = [ | ||
";", | ||
r"/", | ||
"[", | ||
"]", | ||
'"', | ||
"{", | ||
"}", | ||
"(", | ||
")", | ||
"=", | ||
"+", | ||
"\\", | ||
"_", | ||
"-", | ||
">", | ||
"<", | ||
"@", | ||
"`", | ||
",", | ||
"?", | ||
"!", | ||
] | ||
|
||
def __init__(self, *args, **kwargs): | ||
pass | ||
|
||
def word_tokenize(self, word): | ||
word = word.lower() | ||
word = word.replace(",", "").replace("?", "").replace("'s", " 's") | ||
return word.strip() | ||
|
||
def process_punctuation(self, in_text): | ||
out_text = in_text | ||
for p in self.PUNCTUATIONS: | ||
if (p + " " in in_text or " " + p in in_text) or (re.search(self.COMMA_STRIP, in_text) is not None): | ||
out_text = out_text.replace(p, "") | ||
else: | ||
out_text = out_text.replace(p, " ") | ||
out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE) | ||
return out_text | ||
|
||
def process_digit_article(self, in_text): | ||
out_text = [] | ||
temp_text = in_text.lower().split() | ||
for word in temp_text: | ||
word = self.NUMBER_MAP.setdefault(word, word) | ||
if word not in self.ARTICLES: | ||
out_text.append(word) | ||
else: | ||
pass | ||
for word_id, word in enumerate(out_text): | ||
if word in self.CONTRACTIONS: | ||
out_text[word_id] = self.CONTRACTIONS[word] | ||
out_text = " ".join(out_text) | ||
return out_text | ||
|
||
def __call__(self, item): | ||
item = self.word_tokenize(item) | ||
item = item.replace("\n", " ").replace("\t", " ").strip() | ||
item = self.process_punctuation(item) | ||
item = self.process_digit_article(item) | ||
return item | ||
|
||
|
||
def textvqa_doc_to_visual(doc): | ||
return [doc["image"].convert("RGB")] | ||
|
||
|
||
def textvqa_process_results(doc, result): | ||
eval_ai_processor = EvalAIAnswerProcessor() | ||
assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}." | ||
resAns = eval_ai_processor(result[0]) | ||
accuracy = 0 | ||
|
||
if "answers" in doc and doc["answers"] is not None: | ||
gtAcc = [] | ||
|
||
for i in range(len(doc["answers"])): | ||
doc["answers"][i] = eval_ai_processor(doc["answers"][i]) | ||
|
||
for i in range(len(doc["answers"])): | ||
otherGTAns = [doc["answers"][j] for j in range(len(doc["answers"])) if i != j] | ||
matchingAns = [item for item in otherGTAns if item == resAns] | ||
acc = min(1, float(len(matchingAns)) / 3) | ||
gtAcc.append(acc) | ||
accuracy = statistics.mean(gtAcc) | ||
|
||
return { | ||
"exact_match": accuracy, | ||
"submission": { | ||
"question_id": doc["question_id"], | ||
"answer": resAns, | ||
}, | ||
} | ||
|
||
|
||
def textvqa_doc_to_text(doc): | ||
ocr_ref = "" | ||
if "have_ocr_reference" in config["metadata"] and config["metadata"]["have_ocr_prompt"] and doc["ocr_tokens"]: | ||
ocr_ref = f"Reference OCR token: {', '.join(doc['ocr_tokens'])}\n" | ||
text = f"{doc['question'].capitalize()}\n{ocr_ref}Answer the question using a single word or phrase." | ||
return text | ||
|
||
|
||
def textvqa_aggreate_submissions(results): | ||
now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") | ||
submission_file_name = f"textvqa-submission-{now_date_time}.json" | ||
path = os.path.abspath(submission_file_name) | ||
with open(path, "w") as f: | ||
json.dump(results, f) | ||
print(f"Submission file saved to {path}") | ||
return 0 |
Oops, something went wrong.