Pufanyi/flickr30k refractor (EvolvingLMMs-Lab#56)
* refactor vizwizvqa task

* Delete vqav2_test and vqav2_val YAML files

* Refactor vqav2_process_results functions

* Add a pack for vqav2

* refactor okvqa

* roll back vizwiz_vqa

* Fix exact_match calculation in ok_vqa_process_results

* Update OKVQA dataset name in readme

* add model_specific_prompt_kwargs

* add model_specific_prompt_kwargs to vizwiz_vqa

* add model_specific_prompt_kwargs for vqav2

* lint

* fix a small bug for eval_logger

* Refactor make_table function to display points as "  -  " if value is None

* Merge commit 'c5e52a785d3cc87a866be9b880deb477d9f73fb7'

* Refactor ok_vqa_aggreate_submissions function

* Merge commit 'e5aa0a9601d6d8ce727315e4b0a8f13f06f26bff'

* Refactor VQA submission file saving

* Update file utils

* Merge commit '560deca9f72483ca091795d6dc2537d4c54b32b0'

* Refactor file path handling and submission generation

* OKVQA path

* vizwizvqa file

* pack cmmmu

* fix a small metric bug for cmmmu

* Add higher_is_better flag to submission metric

* Add CMMMU dataset to README.md

* Add logging and refactor submission file generation in docvqa utils.py

* pack docvqa

* add traceback to print detailed error

* Refactor docvqa_test_aggregate_results to accept additional arguments

* Add metric check in evaluator.py and update test.yaml and val.yaml

* add common `EvalAIAnswerProcessor` for okvqa, textvqa, vizwizvqa and vqav2

* merge textvqa

* textvqa

* Modify submission file generation for COCO test results

* Update test result storage path

* update coco cap file name

* Update COCO 2017 Caption dataset name

* ferret

* Add Ferret dataset

* Refactor hb_doc_to_text function to include model-specific prompts

* Add IconQA and its subtasks

* Refactor image list creation in doc_to_visual function

* Add process_results function to default template

* Update process_results function in iconqa utils.py

* refactor flickr30k

* change aggregation function

* Fix formatting issues and update logging message

* Fix llava not handling text-only questions (no visuals)

* Fix qwen not handling questions without images (no visuals)

* Add fuyu prepare accelerator scripts

* refactor mme

* naming consistency

* aggregation_submissions consistency

* flickr30k naming consistency

* remove submissions for mme

* remove unused submission function

* Refactor infovqa_test.yaml and infovqa_val.yaml

* Refactor code for improved readability and maintainability

* stvqa

* rename sqa

* Update lmms_eval textcaps files and utils.py

* Update default prompt for text captions

* Refactor textcaps_aggregation_result function

* Add generate_submission_file function and update mathvista_aggregate_results signature

* Update nocaps_test.yaml and nocaps_val.yaml

* refactor internal_eval

* Add internal evaluation datasets

* pack multidocvqa

* mmvet

* Fix gpt eval timeout issue for hallubench; restore loading from gpt to avoid re-evaluating

* Refactor llava wild

* Refactor llava-bench-coco

* Add JSON file generation for gpt evaluation details

* mmmu

* Remove MMBench English and Chinese tasks

* Remove unnecessary return statement in mmbench_aggregate_test_results function

* Fix distributed process group initialization

* Update dataset paths and group names in mmbench test configs

* Update import statements in cc_utils.py, cn_utils.py, and en_utils.py

* Add torch module import

* lint

* Remove IconQA dataset from README.md

* Add Multi-DocVQA and its submodules

* Add new datasets and update task names

* Refactor flickr_aggregation_result function to accept additional arguments

* Add timeout kwargs in Accelerator constructor

* Use utf-8 encoding for cmmmu

* Fix llava try/except; remove torch.distributed.init in main

* DeepSpeed prepare script for llava

---------

Co-authored-by: JvThunder <joshuaadrianc@gmail.com>
Co-authored-by: kcz358 <kaichenzhang358@outlook.com>
3 people authored Feb 28, 2024
1 parent cbe3e52 commit 6dbf2a9
Showing 4 changed files with 51 additions and 49 deletions.
7 changes: 3 additions & 4 deletions lmms_eval/__main__.py
@@ -16,6 +16,7 @@
 warnings.simplefilter("ignore", category=DeprecationWarning)
 
 from accelerate import Accelerator
+from accelerate.utils import InitProcessGroupKwargs
 from pathlib import Path
 from typing import Union
 import hashlib
@@ -149,9 +150,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     if not args:
         args = parse_eval_args()
 
-    if args.tasks != "list":
-        torch.distributed.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=600000))
-
     # Check if no arguments were passed after parsing
     if len(sys.argv) == 1:
         print("┌───────────────────────────────────────────────────────────────────────────────┐")
@@ -186,7 +184,8 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
             args_list.append(args)
 
     # initialize Accelerator
-    accelerator = Accelerator()
+    kwargs_handler = InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=60000))
+    accelerator = Accelerator(kwargs_handlers=[kwargs_handler])
     if accelerator.is_main_process:
         is_main_process = True
     else:
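For reference, the pattern this diff adopts: instead of calling `torch.distributed.init_process_group` by hand, the NCCL timeout is routed through Accelerate's `InitProcessGroupKwargs` handler so the `Accelerator` owns process-group setup. A minimal standalone sketch of that usage (the 60000-second timeout matches the diff; the print is illustrative):

```python
import datetime

from accelerate import Accelerator
from accelerate.utils import InitProcessGroupKwargs

# Raise the NCCL timeout well beyond the default (~30 min) so slow
# evaluation phases don't tear down the process group.
kwargs_handler = InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=60000))
accelerator = Accelerator(kwargs_handlers=[kwargs_handler])

if accelerator.is_main_process:
    print(f"world size: {accelerator.num_processes}")  # illustrative
```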
47 changes: 25 additions & 22 deletions lmms_eval/models/llava.py
@@ -9,6 +9,7 @@
 from lmms_eval.utils import stop_sequences_criteria
 
 from accelerate import Accelerator, DistributedType
+from accelerate.state import AcceleratorState
 from typing import List, Optional, Union, Tuple
 import warnings
 
@@ -61,25 +62,12 @@ def __init__(
             self._device = torch.device(f"cuda:{accelerator.local_process_index}")
         else:
             self._device = device
-        try:
-            (
-                self._tokenizer,
-                self._model,
-                self._image_processor,
-                self._max_length,
-            ) = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self._device)
-            is_deepspeed = False
-        except Exception as e:
-            eval_logger.info(f"Encounter error : \n {e} \n Trying again with loading without low_mem_usage = True and device_map for deep speed")
-            if is_deepspeed_zero3_enabled():
-                unset_hf_deepspeed_config()
-            is_deepspeed = True
-            (
-                self._tokenizer,
-                self._model,
-                self._image_processor,
-                self._max_length,
-            ) = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self._device)
+        (
+            self._tokenizer,
+            self._model,
+            self._image_processor,
+            self._max_length,
+        ) = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self._device)
         self._config = self._model.config
         self.model.eval()
         self.model.tie_weights()
@@ -89,8 +77,22 @@ def __init__(
         self.use_cache = use_cache
         self.truncate_context = truncate_context
         # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
-        if accelerator.num_processes > 1 and not is_deepspeed:
-            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+        if accelerator.num_processes > 1:
+            assert accelerator.distributed_type in [
+                DistributedType.FSDP,
+                DistributedType.MULTI_GPU,
+                DistributedType.DEEPSPEED
+            ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+            # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
+            # Also, you have to select zero stage 0 (equivalent to DDP) to make prepare() work with the model
+            # Setting different parameters in the kwargs to make the default zero stage 2 work did not succeed
+            if accelerator.distributed_type == DistributedType.DEEPSPEED:
+                kwargs = {
+                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
+                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
+                }
+                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
+                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
             if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
                 self._model = accelerator.prepare(self.model)
             else:
else:
Expand Down Expand Up @@ -364,16 +366,17 @@ def _collate(x):
max_new_tokens=gen_kwargs["max_new_tokens"],
use_cache=self.use_cache,
)
text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)
except Exception as e:
eval_logger.error(f"Error {e} in generating")
cont = ""
text_outputs = [""]

# cont_toks_list = cont.tolist()
# for cont_toks, context in zip(cont_toks_list, contexts):
# discard context + left-padding toks if using causal decoder-only LMM
# if self.truncate_context:
# cont_toks = cont_toks[input_ids.shape[1] :]
text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
# if self.truncate_context:
# for term in until:
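The DeepSpeed branch above works by pushing the evaluation batch sizes into the DeepSpeed config before `accelerator.prepare` is called; with zero stage 0 this makes DeepSpeed behave like plain DDP. A standalone sketch of the same pattern (assumes `accelerate config` was run with zero stage 0, as the inline comments require; `batch_size_per_gpu` is an illustrative stand-in for the model's batch size argument):

```python
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState

accelerator = Accelerator()
batch_size_per_gpu = 1  # illustrative; lmms-eval takes this from the model kwargs

if accelerator.distributed_type == DistributedType.DEEPSPEED:
    # Sync the batch-size keys into the DeepSpeed config so prepare()
    # does not fail on a train_batch_size mismatch.
    AcceleratorState().deepspeed_plugin.deepspeed_config_process(
        must_match=True,
        train_micro_batch_size_per_gpu=batch_size_per_gpu,
        train_batch_size=batch_size_per_gpu * accelerator.num_processes,
    )
```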
2 changes: 1 addition & 1 deletion lmms_eval/tasks/cmmmu/utils.py
@@ -122,7 +122,7 @@ def cmmmu_process_test_results_for_submission(doc, results):
 
 def cmmmu_test_aggregate_results_for_submission(results, args):
     file = generate_submission_file("cmmmu_test_for_submission.jsonl", args)
-    with open(file, "w") as f:
+    with open(file, "w", encoding='utf8') as f:
         for result in results:
             json.dump(result, f, ensure_ascii=False)
             f.write("\n")
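Why the one-line change matters: `json.dump(..., ensure_ascii=False)` emits raw CJK characters, and `open(file, "w")` without an explicit encoding falls back to the platform default (`locale.getpreferredencoding()`, e.g. cp1252 on Windows), which can raise `UnicodeEncodeError` mid-write. A minimal reproduction of the fixed pattern (record contents and file name are illustrative):

```python
import json

results = [{"id": "val_0001", "answer": "三角形"}]  # illustrative CMMMU-style record

# encoding="utf8" keeps the write portable; without it the platform
# default codec may be unable to encode the Chinese answer text.
with open("cmmmu_test_for_submission.jsonl", "w", encoding="utf8") as f:
    for result in results:
        json.dump(result, f, ensure_ascii=False)
        f.write("\n")
```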
44 changes: 22 additions & 22 deletions lmms_eval/tasks/flickr30k/utils.py
@@ -40,7 +40,7 @@ def flickr_process_result(doc, result):
     return {f"flickr_{metric}": data_dict for metric in FLICKR_METRICS}
 
 
-def flickr_aggregation_result(results, metric):
+def flickr_aggregation_result(results, metric, args):
     scorers = [(Bleu(4), "Bleu_1"), (Bleu(4), "Bleu_2"), (Bleu(4), "Bleu_3"), (Bleu(4), "Bleu_4"), (Meteor(), "METEOR"), (Rouge(), "ROUGE_L"), (Cider(), "CIDEr"), (Spice(), "SPICE")]
     scorers_dict = {s[1]: s for s in scorers}
 
@@ -87,45 +87,45 @@ def flickr_aggregation_result(results, metric):
         n = int(metric.split("_")[-1])
         score = score[n - 1]
 
-    os.makedirs("./submissions", exist_ok=True)
-    if not os.path.exists("./submissions/flickr30k_captions_val2014_alg_results.json"):
-        eval_logger.info("Storing prediction that can be submitted to the server ...")
-        with open("./submissions/flickr30k_captions_val2014_alg_results.json", "w") as f:
-            json.dump(stored_results, f, indent=4)
+    path = generate_submission_file(f"flickr30k_captions_val2014_alg_results_{metric}.json", args)
+
+    eval_logger.info("Storing prediction that can be submitted to the server ...")
+    with open(path, "w") as f:
+        json.dump(stored_results, f, indent=4)
 
     return score
 
 
-def flickr_bleu4(results):
-    return flickr_aggregation_result(results, "Bleu_4")
+def flickr_bleu4(results, args):
+    return flickr_aggregation_result(results, "Bleu_4", args)
 
 
-def flickr_bleu3(results):
-    return flickr_aggregation_result(results, "Bleu_3")
+def flickr_bleu3(results, args):
+    return flickr_aggregation_result(results, "Bleu_3", args)
 
 
-def flickr_bleu2(results):
-    return flickr_aggregation_result(results, "Bleu_2")
+def flickr_bleu2(results, args):
+    return flickr_aggregation_result(results, "Bleu_2", args)
 
 
-def flickr_bleu1(results):
-    return flickr_aggregation_result(results, "Bleu_1")
+def flickr_bleu1(results, args):
+    return flickr_aggregation_result(results, "Bleu_1", args)
 
 
-def flickr_meteor(results):
-    return flickr_aggregation_result(results, "METEOR")
+def flickr_meteor(results, args):
+    return flickr_aggregation_result(results, "METEOR", args)
 
 
-def flickr_rougel(results):
-    return flickr_aggregation_result(results, "ROUGE_L")
+def flickr_rougel(results, args):
+    return flickr_aggregation_result(results, "ROUGE_L", args)
 
 
-def flickr_cider(results):
-    return flickr_aggregation_result(results, "CIDEr")
+def flickr_cider(results, args):
+    return flickr_aggregation_result(results, "CIDEr", args)
 
 
-def flickr_spice(results):
-    return flickr_aggregation_result(results, "SPICE")
+def flickr_spice(results, args):
+    return flickr_aggregation_result(results, "SPICE", args)
 
 
 def flickr_test_process_result(doc, result):
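The refactor threads the harness `args` through every metric entry point so `flickr_aggregation_result` can write one submission file per metric via the shared `generate_submission_file` helper, rather than writing a single hard-coded `./submissions/...` path, which the old code also skipped when the file already existed, silently dropping newer predictions. A sketch of how such a helper might look; the real helper lives in the repo's file utils, so treat this body (including the `args.output_path` attribute) as an assumption:

```python
import os


def generate_submission_file(file_name, args, subpath="submissions"):
    # Assumed behaviour: group submission files under the run's output
    # directory (taken from the CLI args) and return the full path.
    directory = os.path.join(args.output_path, subpath)
    os.makedirs(directory, exist_ok=True)
    return os.path.join(directory, file_name)
```

With the metric name embedded in the file name, Bleu_4, CIDEr, and the other aggregations no longer race to write the same JSON file within one evaluation run.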
