From 6b20902d94ef9120181cd26cdce1e139046dbdf4 Mon Sep 17 00:00:00 2001 From: Zhang Peiyuan Date: Thu, 29 Feb 2024 13:40:02 +0800 Subject: [PATCH] Dev/py add models (#57) * add instructblip * minicpm_v * remove from qwen-vl * speed up postprocessing * Optimize build context speed --------- Co-authored-by: Pu Fanyi Co-authored-by: kcz358 --- lmms_eval/api/task.py | 29 ++- lmms_eval/evaluator.py | 12 +- lmms_eval/models/__init__.py | 3 +- lmms_eval/models/instructblip.py | 230 +++++++++++++++++ lmms_eval/models/minicpm_v.py | 232 ++++++++++++++++++ lmms_eval/models/qwen_vl.py | 4 + lmms_eval/tasks/mme/mme.yaml | 3 + .../textvqa/_default_template_textvqa_yaml | 5 +- 8 files changed, 512 insertions(+), 6 deletions(-) create mode 100644 lmms_eval/models/instructblip.py create mode 100644 lmms_eval/models/minicpm_v.py diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py index efcbfe95..77120924 100644 --- a/lmms_eval/api/task.py +++ b/lmms_eval/api/task.py @@ -239,6 +239,19 @@ def download(self, data_dir=None, cache_dir=None, download_mode=None) -> None: cache_dir=cache_dir, download_mode=download_mode, ) + self.dataset_no_image = datasets.load_dataset( + path=self.DATASET_PATH, + name=self.DATASET_NAME, + data_dir=data_dir, + cache_dir=cache_dir, + download_mode=download_mode, + ) + for doc_name in self.dataset_no_image: + column_names = self.dataset_no_image[doc_name].column_names + image_column = [col for col in column_names if "image" in col.lower()] + # remove image column from docs + if image_column: + self.dataset_no_image[doc_name] = self.dataset_no_image[doc_name].remove_columns(image_column) @property def config(self): @@ -455,7 +468,7 @@ def fewshot_context( assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`" description = description if description else "" - doc = self.dataset[split][doc_id] + doc = self.dataset_no_image[split][doc_id] if num_fewshot == 0: labeled_examples = "" @@ -674,6 +687,18 @@ def download(self, dataset_kwargs=None) -> None: download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS, **dataset_kwargs if dataset_kwargs is not None else {}, ) + self.dataset_no_image = datasets.load_dataset( + path=self.DATASET_PATH, + name=self.DATASET_NAME, + download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS, + **dataset_kwargs if dataset_kwargs is not None else {}, + ) + for doc_name in self.dataset_no_image: + column_names = self.dataset_no_image[doc_name].column_names + image_column = [col for col in column_names if "image" in col.lower()] + # remove image column from docs + if image_column: + self.dataset_no_image[doc_name] = self.dataset_no_image[doc_name].remove_columns(image_column) def has_training_docs(self) -> bool: if self.config.training_split is not None: @@ -731,7 +756,7 @@ def fewshot_context(self, doc_id, num_fewshot, split): :returns: str The fewshot context. """ - doc = self.dataset[split][doc_id] + doc = self.dataset_no_image[split][doc_id] if num_fewshot == 0: # always prepend the (possibly empty) task description labeled_examples = self.config.description diff --git a/lmms_eval/evaluator.py b/lmms_eval/evaluator.py index ffc8cf13..7d6174dc 100644 --- a/lmms_eval/evaluator.py +++ b/lmms_eval/evaluator.py @@ -314,7 +314,15 @@ def evaluate( # TODO: make it possible to use a different metric per filter # iterate over different filters used for key in task.instances[0].filtered_resps.keys(): - doc_iterator = itertools.islice(enumerate(task.test_docs()), lm.rank, limit, lm.world_size) if task.has_test_docs() else itertools.islice(enumerate(task.validation_docs()), lm.rank, limit, lm.world_size) + # hack: remove image columns to speed avoid loading images and speed up postprocessing + # reason: doc_iterator will actually load image if it's in the doc. + docs = task.test_docs() if task.has_test_docs() else task.validation_docs() + column_names = docs.column_names + image_column = [col for col in column_names if "image" in col.lower()] + # remove image column from docs + if image_column: + docs = docs.remove_columns(image_column) + doc_iterator = itertools.islice(enumerate(docs), lm.rank, limit, lm.world_size) # Instead of converting the iterator to a list, use `itertools.tee` to create a parallel iterator for counting # doc_iterator, doc_iterator_for_counting = itertools.tee(doc_iterator) # Don't use above one, this would crash if doc_iterator_for_counting contains too many objects and very slow @@ -330,8 +338,8 @@ def evaluate( target = task.doc_to_target(doc) example = { "doc_id": doc_id, - "doc": {k: v for k, v in doc.items() if "image" not in k}, # do not include image "target": target, + "doc": doc, "arguments": [tuple(a for a in req.args if isinstance(a, (int, str))) for req in requests], # do not include image "resps": [req.resps for req in requests], "filtered_resps": [req.filtered_resps[key] for req in requests], diff --git a/lmms_eval/models/__init__.py b/lmms_eval/models/__init__.py index 55dbf549..760269c7 100644 --- a/lmms_eval/models/__init__.py +++ b/lmms_eval/models/__init__.py @@ -3,7 +3,8 @@ from .qwen_vl import Qwen_VL from .fuyu import Fuyu from .gpt4v import GPT4V - +from .instructblip import InstructBLIP +from .minicpm_v import MiniCPM_V import os try: diff --git a/lmms_eval/models/instructblip.py b/lmms_eval/models/instructblip.py new file mode 100644 index 00000000..1ad56207 --- /dev/null +++ b/lmms_eval/models/instructblip.py @@ -0,0 +1,230 @@ +import torch +import logging +import copy +from tqdm import tqdm +from lmms_eval import utils +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model +from accelerate import Accelerator, DistributedType +from accelerate.state import AcceleratorState +from typing import List, Optional, Union, Tuple +from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration + +from lmms_eval.utils import stop_sequences_criteria + + +import warnings + +warnings.filterwarnings("ignore") + +eval_logger = logging.getLogger("lmms-eval") + + +@register_model("instructblip") +class InstructBLIP(lmms): + """ + InstructBLIP Model + """ + + def __init__( + self, + pretrained: str = "Salesforce/instructblip-vicuna-7b", + device: Optional[str] = "cuda", + dtype: Optional[Union[str, torch.dtype]] = "auto", + batch_size: Optional[Union[int, str]] = 1, + **kwargs, + ) -> None: + super().__init__() + # Do not use kwargs for now + assert kwargs == {}, f"Unexpected kwargs: {kwargs}" + + accelerator = Accelerator() + if accelerator.num_processes > 1: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + else: + self._device = device + self._model = InstructBlipForConditionalGeneration.from_pretrained(pretrained,device_map=self._device) + self._image_processor = InstructBlipProcessor.from_pretrained(pretrained) + self._tokenizer = self._image_processor.tokenizer + self._config = self._model.config + self.model.eval() + self.model.tie_weights() + self.batch_size_per_gpu = int(batch_size) + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [ + DistributedType.FSDP, + DistributedType.MULTI_GPU, + DistributedType.DEEPSPEED + ], "Unsupported distributed type provided. Only DDP and FSDP are supported." + # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model + # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works + # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work. + if accelerator.distributed_type == DistributedType.DEEPSPEED: + kwargs = { + "train_micro_batch_size_per_gpu": self.batch_size_per_gpu, + "train_batch_size" : self.batch_size_per_gpu * accelerator.num_processes, + } + AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs) + eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0") + if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED: + self._model = accelerator.prepare(self.model) + else: + self._model = accelerator.prepare_model(self.model, evaluation_mode=True) + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self.model.to(self._device) + self._rank = 0 + self._word_size = 1 + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. + return self._config + + @property + def tokenizer(self): + return self._tokenizer + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + return self._max_length + + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]: + """ """ + add_special_tokens = False if add_special_tokens is None else add_special_tokens + encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens) + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + return encoding + + def tok_decode(self, tokens): + return self.tokenizer.decode(tokens) + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + # TODO + assert False, "We have not implemented this function for InstructBLIP yet" + + def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: + # TODO + assert False, "We have not implemented this function for InstructBLIP yet" + + def flatten(self, input): + new_list = [] + for i in input: + for j in i: + new_list.append(j) + return new_list + + def generate_until(self, requests: List[Instance]) -> List[str]: + res = [] + + def _collate(x): + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(x[0]) + return -len(toks), x[0] + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True) + chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) + num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1 + pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding") + for chunk in chunks: + contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk) + task = task[0] + split = split[0] + visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] + visuals = self.flatten(visuals) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + + # Set default values for until and max_new_tokens + until = [self.tok_decode(self.eot_token_id)] + + # Update values from gen_kwargs if present + if "until" in gen_kwargs: + until = gen_kwargs.pop("until") + if isinstance(until, str): + until = [until] + elif not isinstance(until, list): + raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}") + assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now" + context = contexts[0] + if "" in context: + # instruct blip does not expect the tag + context = context.replace("", "") + inputs = self._image_processor(images=visuals, text=context, return_tensors="pt").to(self.device) + + gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))] + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0 + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None + if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 + try: + cont = self.model.generate( + **inputs, + do_sample=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + except Exception as e: + eval_logger.error(f"Error {e} in generating") + cont = "" + text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip() + res.append(text_outputs) + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + return res diff --git a/lmms_eval/models/minicpm_v.py b/lmms_eval/models/minicpm_v.py new file mode 100644 index 00000000..482ef301 --- /dev/null +++ b/lmms_eval/models/minicpm_v.py @@ -0,0 +1,232 @@ +import torch +import logging +from tqdm import tqdm +from lmms_eval import utils +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model +from accelerate import Accelerator, DistributedType +from accelerate.state import AcceleratorState +from typing import List, Optional, Union, Tuple +from transformers import AutoModel, AutoTokenizer + + + +import warnings + +warnings.filterwarnings("ignore") + +eval_logger = logging.getLogger("lmms-eval") + + +@register_model("minicpm_v") +class MiniCPM_V(lmms): + """ + MiniCPM_V Model + """ + + def __init__( + self, + pretrained: str = "openbmb/MiniCPM-V", + device: Optional[str] = "cuda", + dtype: Optional[Union[str, torch.dtype]] = torch.bfloat16, + batch_size: Optional[Union[int, str]] = 1, + trust_remote_code: Optional[bool] = True, + **kwargs, + ) -> None: + super().__init__() + # Do not use kwargs for now + assert kwargs == {}, f"Unexpected kwargs: {kwargs}" + + accelerator = Accelerator() + if accelerator.num_processes > 1: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + else: + self._device = device + self._model = AutoModel.from_pretrained(pretrained, trust_remote_code=trust_remote_code, torch_dtype=dtype, device_map=self._device).to(dtype) + self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code) + self._config = self._model.config + self.model.eval() + self.model.tie_weights() + self.batch_size_per_gpu = int(batch_size) + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [ + DistributedType.FSDP, + DistributedType.MULTI_GPU, + DistributedType.DEEPSPEED + ], "Unsupported distributed type provided. Only DDP and FSDP are supported." + # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model + # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works + # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work. + if accelerator.distributed_type == DistributedType.DEEPSPEED: + kwargs = { + "train_micro_batch_size_per_gpu": self.batch_size_per_gpu, + "train_batch_size" : self.batch_size_per_gpu * accelerator.num_processes, + } + AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs) + eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0") + if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED: + self._model = accelerator.prepare(self.model) + else: + self._model = accelerator.prepare_model(self.model, evaluation_mode=True) + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self.model.to(self._device) + self._rank = 0 + self._word_size = 1 + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. + return self._config + + @property + def tokenizer(self): + return self._tokenizer + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + return self._max_length + + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]: + """ """ + add_special_tokens = False if add_special_tokens is None else add_special_tokens + encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens) + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + return encoding + + def tok_decode(self, tokens): + return self.tokenizer.decode(tokens) + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + # TODO + assert False, "We have not implemented this function for MiniCPM_V yet" + + def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: + # TODO + assert False, "We have not implemented this function for MiniCPM_V yet" + + def flatten(self, input): + new_list = [] + for i in input: + for j in i: + new_list.append(j) + return new_list + + def generate_until(self, requests: List[Instance]) -> List[str]: + res = [] + + def _collate(x): + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(x[0]) + return -len(toks), x[0] + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True) + chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) + num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1 + pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding") + for chunk in chunks: + contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk) + task = task[0] + split = split[0] + visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] + visuals = self.flatten(visuals) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + + # Set default values for until and max_new_tokens + until = [self.tok_decode(self.eot_token_id)] + + # Update values from gen_kwargs if present + if "until" in gen_kwargs: + until = gen_kwargs.pop("until") + if isinstance(until, str): + until = [until] + elif not isinstance(until, list): + raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}") + assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now" + assert len(visuals) == 1, "MiniCPM_V interface does not support bn_image > 1 for now" + context = contexts[0] + if "" in context: + # minicpm does not expect the tag + context = context.replace("", "") + msgs = [{'role': 'user', 'content': context}] + + gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))] + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0 + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None + if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 + try: + # ominicpm does not give much information on how they do eval so I just use the chat format. + response, context, _ = self.model.chat( + image=visuals[0], + msgs=msgs, + context=None, + tokenizer=self.tokenizer, + sampling=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + except Exception as e: + eval_logger.error(f"Error {e} in generating") + cont = "" + res.append(response) + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), response) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + return res diff --git a/lmms_eval/models/qwen_vl.py b/lmms_eval/models/qwen_vl.py index f6cfd336..5201f79f 100644 --- a/lmms_eval/models/qwen_vl.py +++ b/lmms_eval/models/qwen_vl.py @@ -234,6 +234,10 @@ def _collate(x): until = [until] elif not isinstance(until, list): raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}") + for i in range(len(contexts)): + if "" in contexts[i]: + context[i] = contexts[i].replace("", "") + questions = [self.prompt.format(visual_path, context) for visual_path, context in zip(visual_paths, contexts)] # Similar to llava, is visual paths has len 0 # Then nothing will be executed diff --git a/lmms_eval/tasks/mme/mme.yaml b/lmms_eval/tasks/mme/mme.yaml index 93f34f14..cf49702d 100644 --- a/lmms_eval/tasks/mme/mme.yaml +++ b/lmms_eval/tasks/mme/mme.yaml @@ -27,6 +27,9 @@ model_specific_prompt_kwargs: default: pre_prompt: "" post_prompt: "\nAnswer the question using a single word or phrase." + qwen_vl: + pre_prompt: "" + post_prompt: " Answer:" otterhd: pre_prompt: "" post_prompt: "" diff --git a/lmms_eval/tasks/textvqa/_default_template_textvqa_yaml b/lmms_eval/tasks/textvqa/_default_template_textvqa_yaml index 3b7a64f5..282b401b 100644 --- a/lmms_eval/tasks/textvqa/_default_template_textvqa_yaml +++ b/lmms_eval/tasks/textvqa/_default_template_textvqa_yaml @@ -11,4 +11,7 @@ model_specific_prompt_kwargs: default: pre_prompt: "" post_prompt: "\nAnswer the question using a single word or phrase." - ocr: false \ No newline at end of file + ocr: false + qwen_vl: + pre_prompt: "" + post_prompt: " Answer:"