-
Notifications
You must be signed in to change notification settings - Fork 186
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #69 from hunterheiden/hsh/new_task/WebSRC
[New Task] WebSRC (multimodal Q&A on web screenshots)
- Loading branch information
Showing
6 changed files
with
244 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# WebSRC | ||
|
||
## Paper | ||
|
||
Title: WebSRC: A Dataset for Web-Based Structural Reading Comprehension | ||
|
||
Abstract: https://arxiv.org/abs/2101.09465 | ||
|
||
Homepage: https://x-lance.github.io/WebSRC/# | ||
|
||
WebSRC is a dataset for web-based structural reading comprehension. | ||
Its full train/dev/test split contains over 400k questions across 6.4k webpages. | ||
This version of the dataset does not contain OCR or original HTML, it simply treats WebSRC as a image-and-text-based multimodal Q&A benchmark on webpage screenshots. | ||
|
||
## Citation | ||
|
||
```bibtex | ||
@inproceedings{chen2021websrc, | ||
title={WebSRC: A Dataset for Web-Based Structural Reading Comprehension}, | ||
author={Chen, Xingyu and Zhao, Zihan and Chen, Lu and Ji, Jiabao and Zhang, Danyang and Luo, Ao and Xiong, Yuxuan and Yu, Kai}, | ||
booktitle={Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing}, | ||
pages={4173--4185}, | ||
year={2021} | ||
} | ||
``` | ||
|
||
## Groups & Tasks | ||
|
||
### Groups | ||
|
||
- `websrc`: Evaluates `websrc-val` and generates a submission file for `websrc-test`. | ||
|
||
### Tasks | ||
|
||
- `websrc-val`: Given a question and a web page, predict the answer. | ||
- `websrc-test`: Given a question and a web page, predict the answer. Ground truth is not provided for this task. | ||
|
||
## Metrics | ||
|
||
This task uses SQUAD-style evaluation metrics, of which F1 score over tokens is used. | ||
The orignal paper also uses Exact Match (EM) score, but this is not implemented here as that metric is more conducive for Encoder-only extraction models. | ||
|
||
### F1 Score | ||
|
||
F1 Score is the harmonic mean of precision and recall. | ||
We calculate precision and recall at the token level, then compute the F1 score as normal using these values. | ||
|
||
### Test Submission | ||
|
||
When evaluaing on the test split, a prediction JSON will be compiled instead of metrics computed. | ||
Instructions for submission are available on the [WebSRC homepage](https://x-lance.github.io/WebSRC/#) and in their [Original GitHub Repo](https://github.com/X-LANCE/WebSRC-Baseline#obtain-test-result). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
from collections import defaultdict | ||
import re | ||
import ast | ||
import base64 | ||
import io | ||
import random | ||
import numpy as np | ||
import os | ||
import json | ||
import logging | ||
from PIL import Image | ||
|
||
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file | ||
|
||
lmms_logger = logging.getLogger("lmms-eval") | ||
|
||
OPEN_ENDED_PROMPT = "Answer the question using a single word or phrase." | ||
|
||
|
||
def construct_prompt(doc): | ||
question = doc["question"] | ||
question = f"{OPEN_ENDED_PROMPT}\n{question}" | ||
return question | ||
|
||
|
||
def websrc_doc_to_text(doc): | ||
question = construct_prompt(doc) | ||
return question | ||
|
||
|
||
def websrc_doc_to_visual(doc): | ||
img_bs64 = doc["image"] | ||
img = Image.open(io.BytesIO(base64.b64decode(img_bs64))) | ||
del doc['image'] | ||
return [img] | ||
|
||
|
||
def websrc_process_results(doc, results): | ||
pred = results[0] | ||
parsed_pred = pred | ||
id = doc["page_id"] | ||
websrc_ans = {"id": id, "domain": doc['domain'], "parsed_pred": parsed_pred} | ||
if "answer" in doc: | ||
websrc_ans["answer"] = doc["answer"] | ||
|
||
if 'id' in doc: | ||
websrc_ans['question_id'] = doc['id'] | ||
|
||
return { | ||
"websrc_squad_f1": websrc_ans, | ||
"submission": { | ||
websrc_ans['question_id']: pred, | ||
}, | ||
} | ||
|
||
|
||
def websrc_test_aggregate_results_for_submission(results, args): | ||
path = generate_submission_file("websrc_test_for_submission.json", args) | ||
with open(path, "w") as f: | ||
out = {} | ||
for result in results: | ||
out.update(result) | ||
json.dump(out, f, indent=4) | ||
lmms_logger.info(f"Results saved to {path}.") | ||
|
||
|
||
def websrc_aggregate_results(results): | ||
evaluation_result = {} | ||
|
||
# Group results by domain | ||
subset_to_eval_samples = defaultdict(list) | ||
for result in results: | ||
subset_to_eval_samples[result["domain"]].append(result) | ||
|
||
# Evaluate each domain | ||
for subset, sub_eval_samples in subset_to_eval_samples.items(): | ||
judge_dict, metric_dict = evaluate_websrc(sub_eval_samples) | ||
metric_dict.update({"num_example": len(sub_eval_samples)}) | ||
evaluation_result[subset] = metric_dict | ||
|
||
# Aggregate results for all domains | ||
printable_results = {} | ||
for domain in DOMAINS: | ||
if domain not in evaluation_result: | ||
continue | ||
printable_results[domain] = { | ||
"num": int(evaluation_result[domain]["num_example"]), | ||
"f1": round(evaluation_result[domain]["f1"], 3), | ||
} | ||
all_ins_f1 = np.sum([cat_results["f1"] * cat_results["num_example"] for cat_results in evaluation_result.values()]) / sum( | ||
[cat_results["num_example"] for cat_results in evaluation_result.values()] | ||
) | ||
printable_results["Overall"] = { | ||
"num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]), | ||
"f1": round(all_ins_f1, 3), | ||
} | ||
print(printable_results) | ||
return printable_results["Overall"]["f1"] | ||
|
||
|
||
################## | ||
# Helper functions written by official MMMU repo. | ||
################## | ||
DOMAINS = [ | ||
'auto', | ||
'book', | ||
'camera', | ||
'game', | ||
'jobs', | ||
'movie', | ||
'phone', | ||
'restaurant', | ||
'sports', | ||
'university', | ||
'hotel', | ||
] | ||
|
||
|
||
def evaluate_websrc(samples): | ||
|
||
def _normalize_str(string): | ||
# lower it | ||
string = string.lower() | ||
|
||
# strip non-alphanumeric characters | ||
string = re.sub(r"[^a-zA-Z0-9]", "", string) | ||
|
||
# strip leading and trailing whitespaces | ||
string = string.strip() | ||
|
||
return string | ||
|
||
judge_list = [] | ||
for sample in samples: | ||
gold_i = set(_normalize_str(sample["answer"])) | ||
pred_i = set(_normalize_str( sample["parsed_pred"])) | ||
if len(pred_i) == 0: | ||
judge_list.append(0.0) | ||
continue | ||
|
||
comm_i = gold_i.intersection(pred_i) | ||
prec_i = len(comm_i) / len(pred_i) | ||
rec_i = len(comm_i) / len(gold_i) | ||
f1_i = 2 * prec_i * rec_i / (prec_i + rec_i) if prec_i + rec_i > 0 else 0 | ||
judge_list.append(f1_i) | ||
|
||
f1 = np.mean(judge_list) | ||
return judge_list, {"f1": f1} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
group: websrc | ||
task: | ||
- websrc_val | ||
- websrc_test |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
dataset_path: rootsautomation/websrc-test | ||
task: "websrc_test" | ||
test_split: test | ||
output_type: generate_until | ||
doc_to_visual: !function utils.websrc_doc_to_visual | ||
doc_to_text: !function utils.websrc_doc_to_text | ||
doc_to_target: "answer" | ||
# The return value of process_results will be used by metrics | ||
process_results: !function utils.websrc_process_results | ||
# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results | ||
generation_kwargs: | ||
max_new_tokens: 16 | ||
image_aspect_ratio: pad | ||
metric_list: | ||
- metric: submission | ||
aggregation: !function utils.websrc_test_aggregate_results_for_submission | ||
higher_is_better: true | ||
metadata: | ||
- version: 0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
dataset_path: rootsautomation/websrc | ||
task: "websrc_val" | ||
test_split: dev | ||
output_type: generate_until | ||
doc_to_visual: !function utils.websrc_doc_to_visual | ||
doc_to_text: !function utils.websrc_doc_to_text | ||
doc_to_target: "answer" | ||
# The return value of process_results will be used by metrics | ||
process_results: !function utils.websrc_process_results | ||
# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results | ||
generation_kwargs: | ||
max_new_tokens: 16 | ||
image_aspect_ratio: pad | ||
metric_list: | ||
- metric: websrc_squad_f1 | ||
aggregation: !function utils.websrc_aggregate_results | ||
higher_is_better: true | ||
metadata: | ||
- version: 0.0 |