Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add idefics2 #59

Merged
merged 1 commit into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lmms_eval/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"gpt4v": "GPT4V",
"instructblip": "InstructBLIP",
"minicpm_v": "MiniCPM_V",
"idefics2": "Idefics2",
}

for model_name, model_class in AVAILABLE_MODELS.items():
Expand Down
223 changes: 223 additions & 0 deletions lmms_eval/models/idefics2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import torch
import logging
from tqdm import tqdm
from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
from typing import List, Optional, Union, Tuple
from transformers import Idefics2ForConditionalGeneration, AutoProcessor

import warnings

warnings.filterwarnings("ignore")

eval_logger = logging.getLogger("lmms-eval")

DEFAULT_IMAGE_TOKEN = "<image>"
try:
import flash_attn
best_fit_attn_implementation = "flash_attention_2"
except ImportError:
best_fit_attn_implementation = "eager"

@register_model("idefics2")
class Idefics2(lmms):
"""
Idefics2 Model for Hugging Face Transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py

Example usage:

accelerate launch --num_processes=8 -m lmms_eval \
--model idefics2 \
--model_args pretrained=HuggingFaceM4/idefics2-8b \
--tasks mme \
--batch_size 1 \
--output_path ./logs/ \
--log_samples
"""

def __init__(
self,
pretrained: str = "HuggingFaceM4/idefics2-8b",
revision: str = "main",
device: str = "cuda",
dtype: Optional[Union[str, torch.dtype]] = "float16",
batch_size: int = 1,
trust_remote_code: Optional[bool] = False,
attn_implementation: Optional[str] = best_fit_attn_implementation,
device_map: str = "",
use_cache: bool = True,
do_image_splitting: bool =False,
**kwargs,
) -> None:
super().__init__()
# Do not use kwargs for now
assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

accelerator = Accelerator()
if accelerator.num_processes > 1 and device_map == "":
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"
else:
self._device = torch.device(device)
self.device_map = device_map
if isinstance(dtype, str) and dtype != "auto":
dtype = getattr(torch, dtype)
self._model = Idefics2ForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
self._processor = AutoProcessor.from_pretrained(pretrained, do_image_splitting=do_image_splitting, revision=revision, trust_remote_code=trust_remote_code)

self._tokenizer = self._processor.tokenizer
self._config = self._model.config
self.batch_size_per_gpu = int(batch_size)
self.use_cache = use_cache
if accelerator.num_processes > 1 and device_map == "":
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
# If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
# Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
# I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
if accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs = {
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
"train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
}
AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
elif accelerator.num_processes == 1 and device_map == "auto":
eval_logger.info(f"Using {accelerator.num_processes} devices with pipeline parallelism")
self._rank = 0
self._word_size = 1
else:
eval_logger.info(f"Using single device: {self._device}")
self.model.to(self._device)
self._rank = 0
self._word_size = 1

@property
def config(self):
# return the associated transformers.AutoConfig for the given pretrained model.
return self._config

@property
def tokenizer(self):
return self._tokenizer

@property
def model(self):
# returns the model, unwrapping it if using Accelerate
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self._model)
else:
return self._model

@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id

@property
def max_length(self):
return self._max_length

@property
def batch_size(self):
return self.batch_size_per_gpu

@property
def device(self):
return self._device

@property
def rank(self):
return self._rank

@property
def world_size(self):
return self._world_size

def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
""" """
add_special_tokens = False if add_special_tokens is None else add_special_tokens
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
encoding = encoding[-left_truncate_len:]
return encoding

def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
raise NotImplementedError("Loglikelihood is not implemented for Idefics2 model")

def flatten(self, input):
new_list = []
for i in input:
for j in i:
new_list.append(j)
return new_list

def generate_until(self, requests: List[Instance]) -> List[str]:
res = []

def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = self.tok_encode(x[0])
return -len(toks), x[0]

# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
for chunk in chunks:
contexts, all_gen_kwargs, doc_to_visuals, doc_id, tasks, splits = zip(*chunk)
visuals = [doc_to_visual(self.task_dict[task][split][ids]) for ids, task, split, doc_to_visual in zip(doc_id, tasks, splits, doc_to_visuals)]
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
#
until = gen_kwargs.pop("until", None)
prompts = []
for context, visual in zip(contexts, visuals):
content = []
if DEFAULT_IMAGE_TOKEN not in context:
for image in visual:
content.append({"type": "image"})
content.append({"type": "text", "text": context})
message = [{"role": "user", "content": content}]
prompt = self._processor.apply_chat_template(message, add_generation_prompt=True)
prompts.append(prompt)
inputs = self._processor(text=prompts, images=visuals, padding=True, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
output_ids = self.model.generate(**inputs, **gen_kwargs)
# only retain the generated text
for output_id, input_id in zip(output_ids, inputs["input_ids"]):
generated_id = output_id[len(input_id):]
generated_text = self.tokenizer.decode(generated_id, skip_special_tokens=True)

res.append(generated_text)
pbar.update(1)
# reorder this group of results back to original unsorted form
res = re_ords.get_original(res)

pbar.close()
return res
4 changes: 3 additions & 1 deletion lmms_eval/tasks/mmmu/mmmu_val.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ process_results: !function utils.mmmu_process_results
# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results
generation_kwargs:
max_new_tokens: 16
image_aspect_ratio: original
model_specific_generation_kwargs:
llava:
image_aspect_ratio: original
metric_list:
- metric: mmmu_acc
aggregation: !function utils.mmmu_aggregate_results
Expand Down
4 changes: 4 additions & 0 deletions lmms_eval/tasks/scienceqa/scienceqa_img.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ model_specific_prompt_kwargs:
post_prompt: "\nAnswer with the option's letter from the given choices directly."
qwen_vl:
format: qwen_vl
idefics2:
format: default
pre_prompt: ""
post_prompt: "\nAnswer:"
model_specific_generation_kwargs:
llava:
image_aspect_ratio: original
Expand Down
Loading