adapt qwen to sqa, gqa, ai2d, docvqa (EvolvingLMMs-Lab#36)
* adapt qwen to sqa, gqa, ai2d, docvqa

* black
jzhang38 authored Feb 3, 2024
1 parent b94afc7 commit a91b591
Showing 15 changed files with 90 additions and 52 deletions.
15 changes: 11 additions & 4 deletions lmms_eval/api/task.py
@@ -93,6 +93,7 @@ class TaskConfig(dict):
 
     model_specific_prompt_kwargs: dict = None
     model_specific_generation_kwargs: dict = None
+    model_specific_target_kwargs: dict = None
 
     def __post_init__(self) -> None:
         if self.dataset_path and os.path.exists(os.path.dirname(self.dataset_path)):
@@ -347,7 +348,7 @@ def build_all_requests(self, limit=None, rank=None, world_size=None) -> None:
         doc_id_iterator = utils.create_iterator([i for i in range(len(docs))], rank, world_size, limit)
         doc_id_iterator, doc_id_iterator_counting = itertools.tee(doc_id_iterator)
         total_docs = sum(1 for _ in doc_id_iterator_counting)
-        pbar = tqdm(total=total_docs, desc=f"Building context {rank}", position=rank)
+        pbar = tqdm(total=total_docs, desc=f"Building context", disable=(rank != 0))
         for doc_id in doc_id_iterator:
             # sample fewshot context #TODO: need to offset doc_id by rank now!
             fewshot_ctx = self.fewshot_context(doc_id, 0 if self.config.num_fewshot is None else self.config.num_fewshot, self.config.training_split if self.has_training_docs() else split)
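The progress-bar change above swaps one bar per rank (position=rank) for a single bar that only rank 0 renders; the other ranks still iterate, just without output. A minimal sketch of the pattern, not taken from the repo (the function name and data are illustrative):

from tqdm import tqdm

def build_requests_with_progress(docs, rank):
    # disable=True turns tqdm into a no-op wrapper, so only rank 0 draws the bar
    pbar = tqdm(total=len(docs), desc="Building context", disable=(rank != 0))
    for doc in docs:
        ...  # build the request for this doc
        pbar.update(1)
    pbar.close()

build_requests_with_progress(list(range(5)), rank=0)  # shows a bar
build_requests_with_progress(list(range(5)), rank=1)  # silent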
@@ -594,14 +595,20 @@ def _prepare_model_specific_config(self):
             if self.model_name in self.model_specific_prompt_kwargs:
                 self.model_specific_prompt_kwargs = self.model_specific_prompt_kwargs[self.model_name]
             else:
-                self.model_specific_prompt_kwargs = self.model_specific_prompt_kwargs["default"]
+                self.model_specific_prompt_kwargs = self.model_specific_prompt_kwargs.get("default", None)
 
+        self.model_specific_target_kwargs = self.config.model_specific_target_kwargs
+        if self.model_specific_target_kwargs is not None:
+            if self.model_name in self.model_specific_target_kwargs:
+                self.model_specific_target_kwargs = self.model_specific_target_kwargs[self.model_name]
+            else:
+                self.model_specific_target_kwargs = self.model_specific_target_kwargs["default"].get("default", None)
         self.model_specific_generation_kwargs = self.config.model_specific_generation_kwargs
         if self.model_specific_generation_kwargs is not None:
             if self.model_name in self.model_specific_generation_kwargs:
                 self.model_specific_generation_kwargs = self.model_specific_generation_kwargs[self.model_name]
             else:
-                self.model_specific_generation_kwargs = self.model_specific_generation_kwargs["default"]
+                self.model_specific_generation_kwargs = self.model_specific_generation_kwargs.get("default", {})
 
             self.config.generation_kwargs.update(self.model_specific_generation_kwargs)
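The new target-kwargs block follows the same lookup rule as the prompt and generation kwargs above: prefer the entry keyed by the current model name, otherwise fall back to a "default" entry. A self-contained sketch of that lookup, using the values this commit adds to ai2d.yaml (the helper name is illustrative and not part of the commit):

def resolve_model_specific_kwargs(config_value, model_name):
    # Per-model entry wins; otherwise fall back to the "default" entry, if any.
    if config_value is None:
        return None
    if model_name in config_value:
        return config_value[model_name]
    return config_value.get("default", None)

target_kwargs = {"default": "mcq", "qwen_vl": "qa"}
print(resolve_model_specific_kwargs(target_kwargs, "qwen_vl"))  # qa
print(resolve_model_specific_kwargs(target_kwargs, "llava"))    # mcq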

@@ -839,7 +846,7 @@ def doc_to_target(self, doc: dict) -> Union[int, str, list]:
         elif type(doc_to_target) == list:
             return doc_to_target
         elif callable(doc_to_target):
-            return doc_to_target(doc)
+            return doc_to_target(doc, self.model_specific_target_kwargs) if self.model_specific_target_kwargs is not None else doc_to_target(doc)
         # Used when applying a Promptsource template
         elif hasattr(doc_to_target, "apply"):
             applied_prompt = doc_to_target.apply(doc)
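The conditional above keeps older one-argument doc_to_target functions working, while tasks that configure model_specific_target_kwargs receive them as a second argument. A small standalone illustration of that dispatch; every name and value here is invented for the example:

def legacy_doc_to_target(doc):
    return doc["answer"]

def kwargs_aware_doc_to_target(doc, target_kwargs):
    # e.g. reshape the target depending on the configured kwargs
    return doc["answer"].upper() if target_kwargs == "upper" else doc["answer"]

def dispatch(doc_to_target, doc, target_kwargs):
    # Mirrors the task.py logic: pass kwargs only when they were configured.
    return doc_to_target(doc, target_kwargs) if target_kwargs is not None else doc_to_target(doc)

doc = {"answer": "b"}
print(dispatch(legacy_doc_to_target, doc, None))           # b
print(dispatch(kwargs_aware_doc_to_target, doc, "upper"))  # B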
12 changes: 7 additions & 5 deletions lmms_eval/evaluator.py
@@ -319,7 +319,7 @@ def evaluate(
             # Don't use above one, this would crash if doc_iterator_for_counting contains too many objects and very slow
             doc_iterator_for_counting = itertools.islice(range(len(task.test_docs())), lm.rank, limit, lm.world_size) if task.has_test_docs() else itertools.islice(range(len(task.validation_docs())), lm.rank, limit, lm.world_size)
             total_docs = sum(1 for _ in doc_iterator_for_counting)
-            pbar = tqdm(total=total_docs, desc=f"Postprocessing {lm.rank}", position=lm.rank)
+            pbar = tqdm(total=total_docs, desc=f"Postprocessing", disable=(lm.rank != 0))
             for doc_id, doc in doc_iterator:
                 # subset instances to only this document id ; sort by idx
                 requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
@@ -427,12 +427,14 @@ def evaluate(
                 else:
                     group_name = None
                 agg_fn = task.aggregation()[metric]
-                # Bo: for models only need agg items
-                if inspect.getfullargspec(agg_fn).args == ["results"]:
-                    results[task_name][metric_key] = agg_fn(items)
 
                 # Bo: for models that need to know the args to save to correct path
-                elif inspect.getfullargspec(agg_fn).args == ["results", "args"]:
+                if inspect.getfullargspec(agg_fn).args == ["results", "args"]:
                     results[task_name][metric_key] = agg_fn(items, cli_args)
+                else:
+                    # Bo: for models only need agg items
+                    results[task_name][metric_key] = agg_fn(items)
 
                 results[task_name]["samples"] = len(items)
 
                 # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
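The reordered branch above inspects the aggregation function's signature and passes cli_args only to aggregators declared as (results, args); everything else just gets the item list. A self-contained sketch of the dispatch under invented aggregator names (not the repo's own):

import inspect

def mean_agg(results):
    return sum(results) / len(results)

def saving_agg(results, args):
    # an aggregator that also wants CLI args, e.g. to know where to save predictions
    return {"score": sum(results) / len(results), "output_path": getattr(args, "output_path", None)}

def run_aggregation(agg_fn, items, cli_args):
    if inspect.getfullargspec(agg_fn).args == ["results", "args"]:
        return agg_fn(items, cli_args)
    return agg_fn(items)

class Args:
    output_path = "./logs"

print(run_aggregation(mean_agg, [1.0, 0.0, 1.0], Args()))    # plain mean
print(run_aggregation(saving_agg, [1.0, 0.0, 1.0], Args()))  # mean plus output_path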
2 changes: 1 addition & 1 deletion lmms_eval/models/qwen_vl.py
@@ -42,7 +42,7 @@ def __init__(
         self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code)
         self.tokenizer.padding_side = "left"
         self.tokenizer.pad_token_id = self.tokenizer.eod_id
-        self.prompt = "<img>{}</img>{} Answer:"
+        self.prompt = "<img>{}</img>{}"
         self._config = self._model.config
         self.model.tie_weights()
         self.batch_size_per_gpu = int(batch_size)
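Removing the hard-coded " Answer:" suffix from the Qwen-VL template shifts that choice into each task's post_prompt, so a task can pick its own suffix (or none) in YAML. A rough sketch of how the template and a task-formatted context now compose; the image path and question below are made up:

prompt_template = "<img>{}</img>{}"

# the context string already ends with the task's post_prompt, e.g. " Answer:" for qwen_vl
context = "What does the chart show? Answer:"
print(prompt_template.format("/path/to/image.png", context))
# <img>/path/to/image.png</img>What does the chart show? Answer: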
12 changes: 10 additions & 2 deletions lmms_eval/tasks/ai2d/ai2d.yaml
@@ -17,12 +17,20 @@ metric_list:
     higher_is_better: true
     ignore_case: true
     ignore_punctuation: true
-process_results: !function utils.ai2d_process_results
 metadata:
   - version: 0.0
 
 model_specific_prompt_kwargs:
   default:
+    prompt_format: mcq
     pre_prompt: ""
     post_prompt: "\nAnswer with the option's letter from the given choices directly."
 
+  # qwen formulate ai2d as question answering instead of mcq
+  qwen_vl:
+    prompt_format: qa
+    pre_prompt: ""
+    post_prompt: " Answer:"
+
+model_specific_target_kwargs:
+  default: "mcq"
+  qwen_vl: "qa"
37 changes: 16 additions & 21 deletions lmms_eval/tasks/ai2d/utils.py
@@ -1,32 +1,27 @@
 def ai2d_doc_to_text(doc, model_specific_prompt_kwargs=None):
     question, choices = doc["question"], doc["options"]
     len_choices = len(choices)
-    options = [chr(ord("A") + i) for i in range(len_choices)]
-    choices_str = "\n".join([f"{option}. {choice}" for option, choice in zip(options, choices)])
-
     post_prompt = model_specific_prompt_kwargs["post_prompt"]
     pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
-    return f"{pre_prompt}{question}\n{choices_str}{post_prompt}"
+    if model_specific_prompt_kwargs["prompt_format"] == "mcq":
+        options = [chr(ord("A") + i) for i in range(len_choices)]
+        choices_str = "\n".join([f"{option}. {choice}" for option, choice in zip(options, choices)])
+        return f"{pre_prompt}{question}\n{choices_str}{post_prompt}"
+    elif model_specific_prompt_kwargs["prompt_format"] == "qa":
+        options = "\n".join(choices)
+        return f"{pre_prompt}{question}{options}{post_prompt}"
+    else:
+        raise ValueError(f"Unknown prompt format: {model_specific_prompt_kwargs['prompt_format']}")
 
 
 def ai2d_doc_to_visual(doc):
     return [doc["image"].convert("RGB")]
 
 
-def ai2d_doc_to_target(doc):
-    len_choices = len(doc["options"])
-    options = [chr(ord("A") + i) for i in range(len_choices)]
-    return options[int(doc["answer"])]
-
-
-def ai2d_process_results(doc, results):
-    # I know this is weird, but it's how llava parse it.
-    target = ai2d_doc_to_target(doc)
-    pred = results[0]
-    if pred == target:
-        return {"exact_match": 1.0}
-    # pattern: ^[A-Z]\. .*
-    if len(pred) >= 2 and pred[0].isupper() and pred[1] == ".":
-        result = 1.0 if pred[0] == target else 0.0
-        return {"exact_match": result}
-    return {"exact_match": 0.0}
+def ai2d_doc_to_target(doc, model_specific_target_kwargs):
+    if model_specific_target_kwargs == "mcq":
+        len_choices = len(doc["options"])
+        options = [chr(ord("A") + i) for i in range(len_choices)]
+        return options[int(doc["answer"])]
+    elif model_specific_target_kwargs == "qa":
+        return doc["options"][int(doc["answer"])]
2 changes: 1 addition & 1 deletion lmms_eval/tasks/chartqa/chartqa.yaml
@@ -30,5 +30,5 @@ model_specific_prompt_kwargs:
     post_prompt: "\nAnswer the question with a single word."
   qwen_vl:
     pre_prompt: ""
-    post_prompt: ""
+    post_prompt: " Answer:"

6 changes: 3 additions & 3 deletions lmms_eval/tasks/chartqa/utils.py
@@ -2,10 +2,10 @@ def chartqa_doc_to_visual(doc):
     return [doc["image"].convert("RGB")]
 
 
-def chartqa_doc_to_text(doc, mdoel_specific_prompt_kwargs):
+def chartqa_doc_to_text(doc, model_specific_prompt_kwargs):
     question = doc["question"]
-    pre_prompt = mdoel_specific_prompt_kwargs["pre_prompt"]
-    post_prompt = mdoel_specific_prompt_kwargs["post_prompt"]
+    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
+    post_prompt = model_specific_prompt_kwargs["post_prompt"]
     return f"{pre_prompt}{question}{post_prompt}"


4 changes: 3 additions & 1 deletion lmms_eval/tasks/docvqa/docvqa_val.yaml
@@ -20,4 +20,6 @@ model_specific_prompt_kwargs:
   default:
     pre_prompt: ""
     post_prompt: "\nAnswer the question using a single word or phrase."
-
+  qwen_vl:
+    pre_prompt: ""
+    post_prompt: " Answer:"
6 changes: 3 additions & 3 deletions lmms_eval/tasks/docvqa/utils.py
@@ -6,10 +6,10 @@ def docvqa_doc_to_visual(doc):
     return [doc["image"].convert("RGB")]
 
 
-def docvqa_doc_to_text(doc, mdoel_specific_prompt_kwargs):
+def docvqa_doc_to_text(doc, model_specific_prompt_kwargs):
     question = doc["question"]
-    pre_prompt = mdoel_specific_prompt_kwargs["pre_prompt"]
-    post_prompt = mdoel_specific_prompt_kwargs["post_prompt"]
+    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
+    post_prompt = model_specific_prompt_kwargs["post_prompt"]
     return f"{pre_prompt}{question}{post_prompt}"


8 changes: 8 additions & 0 deletions lmms_eval/tasks/gqa/gqa.yaml
@@ -22,3 +22,11 @@ metric_list:
     ignore_punctuation: true
 metadata:
   - version: 0.0
+
+model_specific_prompt_kwargs:
+  default:
+    pre_prompt: ""
+    post_prompt: "\nAnswer the question using a single word or phrase."
+  qwen_vl:
+    pre_prompt: ""
+    post_prompt: " Answer:"
7 changes: 4 additions & 3 deletions lmms_eval/tasks/gqa/utils.py
@@ -1,6 +1,5 @@
 from datasets import load_dataset
 
-prompt = "\nAnswer the question using a single word or phrase."
 GQA_RAW_IMAGE_DATASET = None
 GQA_ID2IMAGE = None
 
@@ -17,6 +16,8 @@ def gqa_doc_to_visual(doc):
     return [image]
 
 
-def gqa_doc_to_text(doc):
+def gqa_doc_to_text(doc, model_specific_prompt_kwargs):
     question = doc["question"]
-    return f"{question}{prompt}"
+    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
+    post_prompt = model_specific_prompt_kwargs["post_prompt"]
+    return f"{pre_prompt}{question}{post_prompt}"
6 changes: 3 additions & 3 deletions lmms_eval/tasks/infovqa/utils.py
@@ -6,10 +6,10 @@ def infovqa_doc_to_visual(doc):
     return [doc["image"].convert("RGB")]
 
 
-def infovqa_doc_to_text(doc, mdoel_specific_prompt_kwargs):
+def infovqa_doc_to_text(doc, model_specific_prompt_kwargs):
     question = doc["question"]
-    pre_prompt = mdoel_specific_prompt_kwargs["pre_prompt"]
-    post_prompt = mdoel_specific_prompt_kwargs["post_prompt"]
+    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
+    post_prompt = model_specific_prompt_kwargs["post_prompt"]
     return f"{pre_prompt}{question}{post_prompt}"


4 changes: 4 additions & 0 deletions lmms_eval/tasks/scienceqa_img/scienceqa.yaml
@@ -24,8 +24,12 @@ metadata:
 
 model_specific_prompt_kwargs:
   default:
+    format: default
     pre_prompt: ""
     post_prompt: "\nAnswer with the option's letter from the given choices directly."
+  qwen_vl:
+    format: qwen_vl
+
 model_specific_generation_kwargs:
   llava:
     image_aspect_ratio: original
3 changes: 3 additions & 0 deletions lmms_eval/tasks/scienceqa_img/scienceqa_img.yaml
@@ -24,8 +24,11 @@ metadata:
 
 model_specific_prompt_kwargs:
   default:
+    format: default
     pre_prompt: ""
     post_prompt: "\nAnswer with the option's letter from the given choices directly."
+  qwen_vl:
+    format: qwen_vl
 model_specific_generation_kwargs:
   llava:
     image_aspect_ratio: original
18 changes: 13 additions & 5 deletions lmms_eval/tasks/scienceqa_img/utils.py
@@ -3,12 +3,20 @@ def sqa_doc_to_text(doc, model_specific_prompt_kwargs=None):
     len_choices = len(choices)
     options = [chr(ord("A") + i) for i in range(len_choices)]
     choices_str = "\n".join([f"{option}. {choice}" for option, choice in zip(options, choices)])
-    if context:
-        context = f"Context: {context}\n"
+    if model_specific_prompt_kwargs["format"] == "default":
+        if context:
+            context = f"Context: {context}\n"
 
-    post_prompt = model_specific_prompt_kwargs["post_prompt"]
-    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
-    return f"{pre_prompt}{context}{question}\n{choices_str}{post_prompt}"
+        post_prompt = model_specific_prompt_kwargs["post_prompt"]
+        pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
+        return f"{pre_prompt}{context}{question}\n{choices_str}{post_prompt}"
+    elif model_specific_prompt_kwargs["format"] == "qwen_vl":
+        prompt = "Context: {}\nQuestion: {}\nOptions: {}\nAnswer:"
+        context = context if context else "N/A"
+        prompt = prompt.format(context, question, choices_str)
+        return prompt
+    else:
+        raise ValueError(f"Unknown prompt format: {model_specific_prompt_kwargs}")
 
 
 def sqa_doc_to_visual(doc):
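For the qwen_vl branch above, the prompt becomes a fixed Context/Question/Options/Answer template, with "N/A" substituted when a sample has no context. A small hedged example of the string it produces; the sample values are invented:

context = ""  # many ScienceQA samples carry no context
question = "Which force pulls the sled downhill?"
choices_str = "A. gravity\nB. friction"

prompt = "Context: {}\nQuestion: {}\nOptions: {}\nAnswer:"
context = context if context else "N/A"
print(prompt.format(context, question, choices_str))
# Context: N/A
# Question: Which force pulls the sled downhill?
# Options: A. gravity
# B. friction
# Answer: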
