Skip to content

Commit

Permalink
VQAv2 eval (#4)
Browse files Browse the repository at this point in the history
* vqav2

* Add vqav2_process_results function and update vqav2_doc_to_text function

* Implement vqav2_process_results function to return exact match score

* Refactor fewshot_docs() to use config.fewshot_config

* Refactor Task class to handle fewshot_docs when training and validation docs are not available

* Add answer processing logic in vqav2_process_results function

* Refactor vqav2_process_results function and add submission aggregation

* Add vqav2_aggreate_submissions function to utils.py

* textvqa

* Refactor answer processing in textvqa_process_results() function

* textvqa eval

* Update dataset path and modify textvqa_doc_to_text function

* Capitalize the question in textvqa_doc_to_text function

* Update textvqa.yaml and utils.py

* Fix formatting issues in lmms_eval/api/task.py, lmms_eval/tasks/gqa/utils.py, lmms_eval/tasks/textvqa/utils.py, and lmms_eval/tasks/vqav2/utils.py

---------

Co-authored-by: Li Bo <drluodian@gmail.com>
  • Loading branch information
pufanyi and Luodian authored Jan 17, 2024
1 parent 209f390 commit 7ddb976
Show file tree
Hide file tree
Showing 6 changed files with 599 additions and 4 deletions.
6 changes: 3 additions & 3 deletions lmms_eval/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ def fewshot_docs(self):
elif self.has_validation_docs():
return self.validation_docs()
else:
eval_logger.warning("has_training_docs and has_validation_docs are False" ", using test_docs as fewshot_docs but this is not recommended.")
if self.config.num_fewshot is not None:
eval_logger.warning("has_training_docs and has_validation_docs are False" ", using test_docs as fewshot_docs but this is not recommended.")
return self.test_docs()

def _process_doc(self, doc):
Expand Down Expand Up @@ -517,8 +518,7 @@ def __init__(self) -> None: # TODO no super() call here
self._filters.append(filter_pipeline)
else:
self._filters = [build_filter_ensemble("none", [["take_first", None]])]

if self.fewshot_docs() is not None:
if self.config.fewshot_config is not None:
self.sampler = samplers.get_sampler(self.config.fewshot_config.get("sampler", "default") if self.config.fewshot_config else "default")(list(self.fewshot_docs()), self, rnd=random.Random(1234))

if self.has_test_docs():
Expand Down
1 change: 0 additions & 1 deletion lmms_eval/tasks/gqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
GQA_ID2IMAGE = None



def gqa_doc_to_visual(doc):
global GQA_RAW_IMAGE_DATASET
global GQA_ID2IMAGE
Expand Down
23 changes: 23 additions & 0 deletions lmms_eval/tasks/textvqa/textvqa.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
task: textvqa
dataset_path: lmms-lab/textvqa
test_split: test
output_type: generate_until
doc_to_visual: !function utils.textvqa_doc_to_visual
doc_to_text: !function utils.textvqa_doc_to_text
doc_to_target: "answer"
generation_kwargs:
until:
- "ASSISTANT:"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
- metric: submission
aggregation: !function utils.textvqa_aggreate_submissions
higher_is_better: true
metadata:
- version: 0.0
- have_ocr_reference: true
process_results: !function utils.textvqa_process_results
279 changes: 279 additions & 0 deletions lmms_eval/tasks/textvqa/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
import re
import os
import json
import yaml
import pathlib
import logging
import datetime
import statistics

eval_logger = logging.getLogger("lmms-eval")

with open(pathlib.Path(__file__).parent / "textvqa.yaml", "r") as f:
raw_data = f.readlines()
for i in range(len(raw_data)):
raw_data[i] = raw_data[i].replace("!function", "function")

config = yaml.safe_load("".join(raw_data))


class EvalAIAnswerProcessor:
"""
Processes an answer similar to Eval AI
copied from
https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
"""

CONTRACTIONS = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
}

NUMBER_MAP = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
ARTICLES = ["a", "an", "the"]
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
PUNCTUATIONS = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]

def __init__(self, *args, **kwargs):
pass

def word_tokenize(self, word):
word = word.lower()
word = word.replace(",", "").replace("?", "").replace("'s", " 's")
return word.strip()

def process_punctuation(self, in_text):
out_text = in_text
for p in self.PUNCTUATIONS:
if (p + " " in in_text or " " + p in in_text) or (re.search(self.COMMA_STRIP, in_text) is not None):
out_text = out_text.replace(p, "")
else:
out_text = out_text.replace(p, " ")
out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
return out_text

def process_digit_article(self, in_text):
out_text = []
temp_text = in_text.lower().split()
for word in temp_text:
word = self.NUMBER_MAP.setdefault(word, word)
if word not in self.ARTICLES:
out_text.append(word)
else:
pass
for word_id, word in enumerate(out_text):
if word in self.CONTRACTIONS:
out_text[word_id] = self.CONTRACTIONS[word]
out_text = " ".join(out_text)
return out_text

def __call__(self, item):
item = self.word_tokenize(item)
item = item.replace("\n", " ").replace("\t", " ").strip()
item = self.process_punctuation(item)
item = self.process_digit_article(item)
return item


def textvqa_doc_to_visual(doc):
return [doc["image"].convert("RGB")]


def textvqa_process_results(doc, result):
eval_ai_processor = EvalAIAnswerProcessor()
assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}."
resAns = eval_ai_processor(result[0])
accuracy = 0

if "answers" in doc and doc["answers"] is not None:
gtAcc = []

for i in range(len(doc["answers"])):
doc["answers"][i] = eval_ai_processor(doc["answers"][i])

for i in range(len(doc["answers"])):
otherGTAns = [doc["answers"][j] for j in range(len(doc["answers"])) if i != j]
matchingAns = [item for item in otherGTAns if item == resAns]
acc = min(1, float(len(matchingAns)) / 3)
gtAcc.append(acc)
accuracy = statistics.mean(gtAcc)

return {
"exact_match": accuracy,
"submission": {
"question_id": doc["question_id"],
"answer": resAns,
},
}


def textvqa_doc_to_text(doc):
ocr_ref = ""
if "have_ocr_reference" in config["metadata"] and config["metadata"]["have_ocr_prompt"] and doc["ocr_tokens"]:
ocr_ref = f"Reference OCR token: {', '.join(doc['ocr_tokens'])}\n"
text = f"{doc['question'].capitalize()}\n{ocr_ref}Answer the question using a single word or phrase."
return text


def textvqa_aggreate_submissions(results):
now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
submission_file_name = f"textvqa-submission-{now_date_time}.json"
path = os.path.abspath(submission_file_name)
with open(path, "w") as f:
json.dump(results, f)
print(f"Submission file saved to {path}")
return 0
Loading

0 comments on commit 7ddb976

Please sign in to comment.