Dev/py add models (#57)
* add instructblip

* minicpm_v

* remove <image> from qwen-vl

* speed up postprocessing

* Optimize build context speed

---------

Co-authored-by: Pu Fanyi <FPU001@e.ntu.edu.sg>
Co-authored-by: kcz358 <kaichenzhang358@outlook.com>
3 people authored Feb 29, 2024
1 parent 21050ba commit 6b20902
Showing 8 changed files with 512 additions and 6 deletions.
29 changes: 27 additions & 2 deletions lmms_eval/api/task.py
@@ -239,6 +239,19 @@ def download(self, data_dir=None, cache_dir=None, download_mode=None) -> None:
cache_dir=cache_dir,
download_mode=download_mode,
)
self.dataset_no_image = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
data_dir=data_dir,
cache_dir=cache_dir,
download_mode=download_mode,
)
for doc_name in self.dataset_no_image:
column_names = self.dataset_no_image[doc_name].column_names
image_column = [col for col in column_names if "image" in col.lower()]
# remove image column from docs
if image_column:
self.dataset_no_image[doc_name] = self.dataset_no_image[doc_name].remove_columns(image_column)

@property
def config(self):
@@ -455,7 +468,7 @@ def fewshot_context(
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"

description = description if description else ""
doc = self.dataset[split][doc_id]
doc = self.dataset_no_image[split][doc_id]

if num_fewshot == 0:
labeled_examples = ""
@@ -674,6 +687,18 @@ def download(self, dataset_kwargs=None) -> None:
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
**dataset_kwargs if dataset_kwargs is not None else {},
)
self.dataset_no_image = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
**dataset_kwargs if dataset_kwargs is not None else {},
)
for doc_name in self.dataset_no_image:
column_names = self.dataset_no_image[doc_name].column_names
image_column = [col for col in column_names if "image" in col.lower()]
# remove image column from docs
if image_column:
self.dataset_no_image[doc_name] = self.dataset_no_image[doc_name].remove_columns(image_column)

def has_training_docs(self) -> bool:
if self.config.training_split is not None:
@@ -731,7 +756,7 @@ def fewshot_context(self, doc_id, num_fewshot, split):
:returns: str
The fewshot context.
"""
doc = self.dataset[split][doc_id]
doc = self.dataset_no_image[split][doc_id]
if num_fewshot == 0:
# always prepend the (possibly empty) task description
labeled_examples = self.config.description
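Note on the task.py hunks: the second load_dataset call normally resolves to the same cached Arrow files, so the extra dataset_no_image copy is cheap, and remove_columns returns a new Dataset whose rows can be indexed without decoding any images. A minimal standalone sketch of the idea (the dataset name below is a placeholder, not one used by the repo):

import datasets

# Both loads hit the same cache, so the second copy is cheap.
full = datasets.load_dataset("lmms-lab/placeholder-benchmark", split="test")  # hypothetical dataset
no_image = full.remove_columns([c for c in full.column_names if "image" in c.lower()])

row = no_image[0]  # plain dict of text fields; no PIL image decoding happens here
print(row.keys())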
12 changes: 10 additions & 2 deletions lmms_eval/evaluator.py
@@ -314,7 +314,15 @@ def evaluate(
# TODO: make it possible to use a different metric per filter
# iterate over different filters used
for key in task.instances[0].filtered_resps.keys():
doc_iterator = itertools.islice(enumerate(task.test_docs()), lm.rank, limit, lm.world_size) if task.has_test_docs() else itertools.islice(enumerate(task.validation_docs()), lm.rank, limit, lm.world_size)
# hack: remove image columns to avoid loading images and speed up postprocessing
# reason: doc_iterator will actually load the image if it is present in the doc.
docs = task.test_docs() if task.has_test_docs() else task.validation_docs()
column_names = docs.column_names
image_column = [col for col in column_names if "image" in col.lower()]
# remove image column from docs
if image_column:
docs = docs.remove_columns(image_column)
doc_iterator = itertools.islice(enumerate(docs), lm.rank, limit, lm.world_size)
# Instead of converting the iterator to a list, use `itertools.tee` to create a parallel iterator for counting
# doc_iterator, doc_iterator_for_counting = itertools.tee(doc_iterator)
# Don't use above one, this would crash if doc_iterator_for_counting contains too many objects and very slow
@@ -330,8 +338,8 @@
target = task.doc_to_target(doc)
example = {
"doc_id": doc_id,
"doc": {k: v for k, v in doc.items() if "image" not in k}, # do not include image
"target": target,
"doc": doc,
"arguments": [tuple(a for a in req.args if isinstance(a, (int, str))) for req in requests], # do not include image
"resps": [req.resps for req in requests],
"filtered_resps": [req.filtered_resps[key] for req in requests],
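Note on the evaluator.py hunk: stripping the image columns before building doc_iterator matters because iterating the dataset materializes each row, and a row containing a datasets.Image column decodes the image on access. The islice call then shards the documents round-robin across ranks. A small self-contained illustration of that sharding (toy values, not the repo's objects):

import itertools

docs = [f"doc_{i}" for i in range(10)]  # stand-in for task.test_docs()
world_size, limit = 3, None  # toy values

for rank in range(world_size):
    shard = itertools.islice(enumerate(docs), rank, limit, world_size)
    print(rank, list(shard))
# rank 0 -> docs 0, 3, 6, 9; rank 1 -> docs 1, 4, 7; rank 2 -> docs 2, 5, 8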
3 changes: 2 additions & 1 deletion lmms_eval/models/__init__.py
@@ -3,7 +3,8 @@
from .qwen_vl import Qwen_VL
from .fuyu import Fuyu
from .gpt4v import GPT4V

from .instructblip import InstructBLIP
from .minicpm_v import MiniCPM_V
import os

try:
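The two new imports only need to run so that the @register_model decorator in each module fires. That decorator follows the usual decorator-registry pattern; a simplified sketch of how such a registry typically works (not the repo's actual implementation):

MODEL_REGISTRY = {}

def register_model(name):
    def decorate(cls):
        MODEL_REGISTRY[name] = cls  # e.g. "instructblip" -> InstructBLIP
        return cls
    return decorate

def get_model(name):
    return MODEL_REGISTRY[name]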
230 changes: 230 additions & 0 deletions lmms_eval/models/instructblip.py
@@ -0,0 +1,230 @@
import torch
import logging
import copy
from tqdm import tqdm
from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
from typing import List, Optional, Union, Tuple
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration

from lmms_eval.utils import stop_sequences_criteria


import warnings

warnings.filterwarnings("ignore")

eval_logger = logging.getLogger("lmms-eval")


@register_model("instructblip")
class InstructBLIP(lmms):
"""
InstructBLIP Model
"""

def __init__(
self,
pretrained: str = "Salesforce/instructblip-vicuna-7b",
device: Optional[str] = "cuda",
dtype: Optional[Union[str, torch.dtype]] = "auto",
batch_size: Optional[Union[int, str]] = 1,
**kwargs,
) -> None:
super().__init__()
# Do not use kwargs for now
assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

accelerator = Accelerator()
if accelerator.num_processes > 1:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
else:
self._device = device
self._model = InstructBlipForConditionalGeneration.from_pretrained(pretrained, device_map=self._device)
self._image_processor = InstructBlipProcessor.from_pretrained(pretrained)
self._tokenizer = self._image_processor.tokenizer
self._config = self._model.config
self.model.eval()
self.model.tie_weights()
self.batch_size_per_gpu = int(batch_size)
if accelerator.num_processes > 1:
assert accelerator.distributed_type in [
DistributedType.FSDP,
DistributedType.MULTI_GPU,
DistributedType.DEEPSPEED
], "Unsupported distributed type provided. Only DDP and FSDP are supported."
# If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
# Also, you have to select ZeRO stage 0 (equivalent to DDP) in order for prepare_model to work
# Setting other parameters in the kwargs to make the default ZeRO stage 2 work did not succeed.
if accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs = {
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
"train_batch_size" : self.batch_size_per_gpu * accelerator.num_processes,
}
AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
else:
self.model.to(self._device)
self._rank = 0
self._world_size = 1

@property
def config(self):
# return the associated transformers.AutoConfig for the given pretrained model.
return self._config

@property
def tokenizer(self):
return self._tokenizer

@property
def model(self):
# returns the model, unwrapping it if using Accelerate
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self._model)
else:
return self._model

@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id

@property
def max_length(self):
return self._max_length


@property
def batch_size(self):
return self.batch_size_per_gpu

@property
def device(self):
return self._device

@property
def rank(self):
return self._rank

@property
def world_size(self):
return self._world_size

def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
""" """
add_special_tokens = False if add_special_tokens is None else add_special_tokens
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
encoding = encoding[-left_truncate_len:]
return encoding

def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
# TODO
assert False, "We have not implemented this function for InstructBLIP yet"

def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
# TODO
assert False, "We have not implemented this function for InstructBLIP yet"

def flatten(self, input):
new_list = []
for i in input:
for j in i:
new_list.append(j)
return new_list

def generate_until(self, requests: List[Instance]) -> List[str]:
res = []

def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = self.tok_encode(x[0])
return -len(toks), x[0]

# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
for chunk in chunks:
contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
task = task[0]
split = split[0]
visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
visuals = self.flatten(visuals)
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]

# Set default values for until and max_new_tokens
until = [self.tok_decode(self.eot_token_id)]

# Update values from gen_kwargs if present
if "until" in gen_kwargs:
until = gen_kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now"
context = contexts[0]
if "<image>" in context:
# instruct blip does not expect the <image> tag
context = context.replace("<image>", "")
inputs = self._image_processor(images=visuals, text=context, return_tensors="pt").to(self.device)

gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 1024
if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0
if "top_p" not in gen_kwargs:
gen_kwargs["top_p"] = None
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1
try:
cont = self.model.generate(
**inputs,
do_sample=True if gen_kwargs["temperature"] > 0 else False,
temperature=gen_kwargs["temperature"],
top_p=gen_kwargs["top_p"],
num_beams=gen_kwargs["num_beams"],
max_new_tokens=gen_kwargs["max_new_tokens"],
)
except Exception as e:
eval_logger.error(f"Error {e} in generating")
cont = None
text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip() if cont is not None else ""
res.append(text_outputs)
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
pbar.update(1)
# reorder this group of results back to original unsorted form
res = re_ords.get_original(res)

pbar.close()
return res
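For context on what generate_until ultimately drives: called directly, the transformers InstructBLIP API looks roughly like the snippet below (the image path and prompt are placeholders; this is a sketch, not code from the commit):

from PIL import Image
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration

processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="cuda")

image = Image.open("example.jpg")  # placeholder image path
inputs = processor(images=image, text="Describe the picture.", return_tensors="pt").to(model.device)
out = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=64)
print(processor.batch_decode(out, skip_special_tokens=True)[0].strip())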