From a95ed3bce8fa506e54f47e973c1f84a8b565a363 Mon Sep 17 00:00:00 2001 From: Espere-1119-Song Date: Sat, 26 Oct 2024 21:14:45 +0800 Subject: [PATCH 1/5] Add AuroraCap, MovieChat, LLaVA-OneVision-MovieChat --- lmms_eval/models/auroracap.py | 541 ++++++++++++++++++ lmms_eval/models/llava_onevision_moviechat.py | 525 +++++++++++++++++ lmms_eval/models/moviechat.py | 452 +++++++++++++++ 3 files changed, 1518 insertions(+) create mode 100644 lmms_eval/models/auroracap.py create mode 100644 lmms_eval/models/llava_onevision_moviechat.py create mode 100644 lmms_eval/models/moviechat.py diff --git a/lmms_eval/models/auroracap.py b/lmms_eval/models/auroracap.py new file mode 100644 index 00000000..48f8028b --- /dev/null +++ b/lmms_eval/models/auroracap.py @@ -0,0 +1,541 @@ +import copy +import json +import logging +import os +import os.path as osp +from typing import List, Optional, Tuple, Union + +import av +import numpy as np +import torch +from accelerate import Accelerator, DistributedType +from accelerate.state import AcceleratorState +from huggingface_hub import snapshot_download +from peft import PeftModel +from PIL import Image +from tqdm import tqdm +from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + CLIPImageProcessor, +) + +from lmms_eval import utils +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model +from lmms_eval.models.model_utils.load_video import read_video_pyav +from lmms_eval.utils import stop_sequences_criteria + +try: + from lmms_eval.models.aurora_xtuner.model.aurora import ( + AuroraEncoder, + AuroraModel, + AuroraSigEncoder, + ) + from lmms_eval.models.aurora_xtuner.utils import PROMPT_TEMPLATE +except ImportError: + eval_logger.error("AuroraCap is not installed. Please install AuroraCap to use this model by `git clone https://github.com/rese1f/aurora.git` and link `src/xtuner/xtuner` to `lmms_eval/models/aurora_xtuner`") +import warnings + +warnings.filterwarnings("ignore") + +eval_logger = logging.getLogger("lmms-eval") + +try: + from llava.constants import ( + DEFAULT_IM_END_TOKEN, + DEFAULT_IM_START_TOKEN, + DEFAULT_IMAGE_TOKEN, + IGNORE_INDEX, + IMAGE_TOKEN_INDEX, + ) + from llava.conversation import SeparatorStyle, conv_templates + from llava.mm_utils import get_model_name_from_path, tokenizer_image_token +except ImportError: + eval_logger.error("LLaVA is not installed. 
Please install LLaVA to use this model.") + + +@register_model("auroracap") +class AuroraCap(lmms): + """ + auroracap Model + """ + + def __init__( + self, + pretrained_llm: str = "meta-llama/Meta-Llama-3-8B-Instruct", + pretrained_vit: str = "google/siglip-so400m-patch14-384", + pretrained: str = "model/PATH", + resolution: int = 378, + token_merge_ratio: float = 0.4, + device: Optional[str] = "cuda", + dtype: Optional[Union[str, torch.dtype]] = "auto", + batch_size: Optional[Union[int, str]] = 1, + conv_template="vicuna_v1", # vicuna_v1", + video_decode_backend: str = "pyav", + max_frames_num: int = 16, + slowfast: bool = False, + **kwargs, + ) -> None: + super().__init__() + # Do not use kwargs for now + assert kwargs == {}, f"Unexpected kwargs: {kwargs}" + + accelerator = Accelerator() + if accelerator.num_processes > 1: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + else: + self._device = device + + pretrained_pth = snapshot_download(repo_id=pretrained) if not osp.isdir(pretrained) else pretrained + pretrained_llm = pretrained_pth + pretrained_vit = osp.join(pretrained_pth, "visual_encoder") + + self._model = AuroraModel( + slowfast=slowfast, + llm=AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=pretrained_llm, + trust_remote_code=True, + torch_dtype=torch.float16, + ), + visual_encoder=AuroraEncoder.from_pretrained( + pretrained_model_name_or_path=pretrained_vit, + torch_dtype=torch.float16, + ), + ) + + projector_path = osp.join(pretrained_pth, "projector") + self.model.projector = AutoModel.from_pretrained(projector_path, torch_dtype=torch.float16, trust_remote_code=True) + + self._image_processor = CLIPImageProcessor.from_pretrained( + pretrained_model_name_or_path="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", # use standard CLIP processor + trust_remote_code=True, + size=resolution, + crop_size=resolution, + ) + self._tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=pretrained_llm, + trust_remote_code=True, + padding_side="right", + ) + # compute token merge ratio settings + self.patch_size = self._model.visual_encoder.config.patch_size + self.num_layers = self._model.visual_encoder.config.num_hidden_layers + self.token_merge_ratio = token_merge_ratio + + self._config = self._model.config + self.model.eval() + self.model.tie_weights() + self.batch_size_per_gpu = int(batch_size) + self.conv_template = conv_template + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." + # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model + # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works + # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work. + if accelerator.distributed_type == DistributedType.DEEPSPEED: + kwargs = { + "train_micro_batch_size_per_gpu": self.batch_size_per_gpu, + "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes, + } + AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs) + eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. 
Make sure you run `accelerate config` and set zero stage to 0") + if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED: + self._model = accelerator.prepare(self.model) + self._model.visual_encoder = accelerator.prepare(self.model.visual_encoder) + self._model.projector = accelerator.prepare(self.model.projector) + else: # DistributedType.MULTI_GPU + self._model = accelerator.prepare_model(self.model, evaluation_mode=True) + self._model.visual_encoder = accelerator.prepare_model(self.model.visual_encoder, evaluation_mode=True) + self._model.projector = accelerator.prepare_model(self.model.projector, evaluation_mode=True) + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self.model.to(self._device) + self._rank = 0 + self._word_size = 1 + + # For Video Caption + self.video_decode_backend = video_decode_backend + self.max_frames_num = int(max_frames_num) + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. + return self._config + + @property + def tokenizer(self): + return self._tokenizer + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + return self._max_length + + def pad_sequence(self, input_ids, batch_first, padding_value): + if self.tokenizer.padding_side == "left": + input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids] + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value) + if self.tokenizer.padding_side == "left": + input_ids = torch.flip(input_ids, [1]) + return input_ids + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def process_images(self, images, image_processor, model_cfg): + image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) + new_images = [] + if image_aspect_ratio == "pad": + for image in images: + image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean)) + image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + new_images.append(image) + elif image_aspect_ratio == "anyres": + for image in images: + image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) + new_images.append(image) + else: + return image_processor(images, return_tensors="pt")["pixel_values"] + if all(x.shape == new_images[0].shape for x in new_images): + new_images = torch.stack(new_images, dim=0) + return new_images + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]: + """ """ + add_special_tokens = False if add_special_tokens is None else add_special_tokens + encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens) + # left-truncate the encoded context to be 
at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + return encoding + + def tok_decode(self, tokens): + return self.tokenizer.decode(tokens) + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + res = [] + pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") + + for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]: + # encode, pad, and truncate contexts for this batch + if type(doc_to_target) == str: + continuation = doc_to_target + else: + continuation = doc_to_target(self.task_dict[task][split][doc_id]) + visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] + visuals = self.flatten(visuals) + if visuals: + image = self.process_images(visuals, self._image_processor, self._config) + if type(image) is list: + image = [_image.to(dtype=torch.float16, device=self.device) for _image in image] + else: + image = image.to(dtype=torch.float16, device=self.device) + else: + image = None + + prompts_input = contexts[0] + + if image is not None and len(image) != 0 and DEFAULT_IMAGE_TOKEN not in prompts_input: + """ + Three senarios: + 1. No image, and there for, no image token should be added. + 2. image token is already specified in the context, so we don't need to add it. + 3. image token is not specified in the context and there is image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line. + """ + image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals) + image_tokens = " ".join(image_tokens) + prompts_input = image_tokens + "\n" + contexts[0] + conv = conv_templates[self.conv_template].copy() + conv.append_message(conv.roles[0], prompts_input) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id + contxt_id = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device) + # Add the answer of the second role + conv.messages[1][1] = continuation + + prompt = conv.get_prompt() + input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device) + labels = input_ids.clone() + # Context part no need to calculate for loss + labels[0, : contxt_id.shape[1]] = -100 + with torch.inference_mode(): + data = dict() + data["pixel_values"] = image_tensor + data["input_ids"] = input_ids + data["attention_mask"] = attention_masks + self.model.visual_encoder.reset_tome_r(self.token_merge_ratio) + output = self.model(data, mode="tensor") + + loss = outputs["loss"] + # loss = torch.exp(loss) + logits = outputs["logits"] + greedy_tokens = logits.argmax(dim=-1) + cont_toks = input_ids[:, contxt_id.shape[1] :] # [1, seq] + greedy_tokens = greedy_tokens[:, contxt_id.shape[1] : input_ids.shape[1]] # [1, seq] + max_equal = (greedy_tokens == cont_toks).all() + res.append((float(loss.item()), bool(max_equal))) + pbar.update(1) + pbar.close() + return res + + def flatten(self, input): + new_list = [] + for i in input: + for j in i: + new_list.append(j) + return new_list + + def load_video(self, video_path, max_frames_num): + vr = VideoReader(video_path, ctx=cpu(0)) + total_frame_num = len(vr) + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + spare_frames = 
vr.get_batch(frame_idx).asnumpy() + return spare_frames # (frames, height, width, channels) + + def extract_keyframes(self, video_path, keyframes): + container = av.open(video_path) + video_stream = container.streams.video[0] + fps = video_stream.average_rate + time_base = video_stream.time_base + frames = [] + + for keyframe in keyframes: + keyframe_time = float(keyframe) + frame_number = int(keyframe_time * fps) + container.seek(int(keyframe_time / time_base)) + found = False + for packet in container.demux(video=0): + for frame in packet.decode(): + if frame.index >= frame_number: + frames.append(frame) + found = True + break + if found: + break + + if not found: + container.seek(-1, any_frame=False) + for packet in container.demux(video=0): + for frame in packet.decode(): + pass + frames.append(frame) + + return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + + def generate_until(self, requests: List[Instance]) -> List[str]: + res = [] + + def _collate(x): + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(x[0]) + return -len(toks), x[0] + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True) + chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) + num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1 + pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding") + + for chunk in chunks: + contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk) + task = task[0] + split = split[0] + visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # the length of visuals is 1, equal to batchsize + visuals = self.flatten(visuals) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. 
+ gen_kwargs = all_gen_kwargs[0] + + # Set default values for until and max_new_tokens + until = [self.tok_decode(self.eot_token_id)] + + # Update values from gen_kwargs if present + if "until" in gen_kwargs: + until = gen_kwargs.pop("until") + if isinstance(until, str): + until = [until] + elif not isinstance(until, list): + raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}") + + if "image_aspect_ratio" in gen_kwargs.keys() and "image_aspect_ratio" not in self._config.__dict__: + # here we should pop it out of gen_kwargs so that it doesn't get passed to the model for next step of generation + self._config.image_aspect_ratio = gen_kwargs.pop("image_aspect_ratio") + eval_logger.info(f"Setting image aspect ratio: {self._config.image_aspect_ratio}") + # encode, pad, and truncate contexts for this batch + if visuals: + if isinstance(visuals[0], dict): + video_path = visuals[0]["video_path"] + keyframe = visuals[0]["keyframe"] + video = self.extract_keyframes(video_path, keyframe) + image_tensor = self.process_images(video, self._image_processor, self._config).cuda() + elif isinstance(visuals, list): + print(visuals[0]) + if isinstance(visuals[0], Image.Image): + image_tensor = self.process_images(visuals, self._image_processor, self._config) + else: + if visuals[0].endswith("mp4"): + if self.video_decode_backend == "decord": + video = self.load_video(visuals[0], self.max_frames_num) + elif self.video_decode_backend == "pyav": + video = read_video_pyav(visuals[0], num_frm=self.max_frames_num) + image_tensor = self.process_images(video, self._image_processor, self._config).cuda() + elif visuals[0].endswith("mkv"): + assert self.video_decode_backend == "pyav", "we only tested this case, decord may not work" + video = read_video_pyav(visuals[0], num_frm=self.max_frames_num) + image_tensor = self.process_images(video, self._image_processor, self._config).cuda() + + if type(image_tensor) is list: + image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor] + else: + image_tensor = image_tensor.to(dtype=torch.float16, device=self.device) + + else: + image_tensor = None + + question_input = [] + + for visual, context in zip(visuals, contexts): + if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context: + """ + Three senarios: + 1. No image, and there for, no image token should be added. + 2. image token is already specified in the context, so we don't need to add it. + 3. image token is not specified in the context and there is image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line. + """ + if isinstance(visuals[0], dict): + image_tokens = [DEFAULT_IMAGE_TOKEN] * len(video) + elif isinstance(visuals, list): + if isinstance(visuals[0], Image.Image): + image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visual) if isinstance(visual, list) else [DEFAULT_IMAGE_TOKEN] + else: + if visual.endswith("mp4") or visual.endswith("mkv"): + image_tokens = [DEFAULT_IMAGE_TOKEN] * len(video) + + image_tokens = " ".join(image_tokens) + question = image_tokens + "\n" + context + + else: + question = context + conv = conv_templates[self.conv_template].copy() + conv.append_message(conv.roles[0], question) + conv.append_message(conv.roles[1], None) + prompt_question = conv.get_prompt() + question_input.append(prompt_question) + + # The above for loop has bugs. When there is no visuals, e.g. 
pure text, + # there will be no for loop execute resulting in an empty question_input (because no visuals) + # Scenario 1 won't even be execute + if len(visuals) == 0: + for context in contexts: + question = context + conv = conv_templates[self.conv_template].copy() + conv.append_message(conv.roles[0], question) + conv.append_message(conv.roles[1], None) + prompt_question = conv.get_prompt() + question_input.append(prompt_question) + + # preconfigure gen_kwargs with defaults + if isinstance(visuals[0], dict): + gen_kwargs["image_sizes"] = [video[idx].size for idx in range(len(video))] + elif isinstance(visuals, list): + if isinstance(visuals[0], Image.Image): + gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))] + else: + if visuals[0].endswith("mp4"): + gen_kwargs["image_sizes"] = [video[idx].size for idx in range(len(video))] + + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0 + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None + if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 + + input_ids_list = [tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") for prompt in question_input] + pad_token_ids = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id + input_ids = self.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_ids).to(self.device) + attention_masks = input_ids.ne(pad_token_ids).to(self.device) + # These steps are not in LLaVA's original code, but are necessary for generation to work + try: + data = dict() + if isinstance(visuals[0], dict): + data["pixel_values"] = image_tensor.unsqueeze(0) + elif isinstance(visuals, list): + if isinstance(visuals[0], Image.Image): + data["pixel_values"] = image_tensor + else: + if visuals[0].endswith("mp4") or visuals[0].endswith("mkv"): + data["pixel_values"] = image_tensor.unsqueeze(0) + + data["input_ids"] = input_ids + data["attention_mask"] = attention_masks + self.model.visual_encoder.reset_tome_r(self.token_merge_ratio) + output = self.model(data, mode="inference") + cont = self.model.llm.generate( + **output, + do_sample=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True) + except Exception as e: + eval_logger.error(f"Error {e} in generating") + cont = "" + text_outputs = [""] + + print(text_outputs) + + res.extend(text_outputs) + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + return res diff --git a/lmms_eval/models/llava_onevision_moviechat.py b/lmms_eval/models/llava_onevision_moviechat.py new file mode 100644 index 00000000..39ccd802 --- /dev/null +++ b/lmms_eval/models/llava_onevision_moviechat.py @@ -0,0 +1,525 @@ +import copy +import json +import logging +import math +import os +import re +import warnings +from datetime import timedelta +from typing import List, Optional, Tuple, Union + +import av +import numpy as np +import PIL +import torch +import transformers +from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs +from accelerate.state import AcceleratorState 
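+# NOTE: several video backends are imported below: decord (the default `video_decode_backend`),
+# pyav (`av`), and moviepy. All of them are required at import time for this module.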
+from decord import VideoReader, cpu +from moviepy.video.io.VideoFileClip import VideoFileClip +from packaging import version +from PIL import Image +from tqdm import tqdm +from transformers import AutoConfig + +from lmms_eval import utils +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model + +# Suppress warnings +warnings.filterwarnings("ignore") + +# Configure logging +eval_logger = logging.getLogger("lmms-eval") + +# Enable TF32 for CUDA +torch.backends.cuda.matmul.allow_tf32 = True + +# Import LLaVA modules +try: + from llava.constants import ( + DEFAULT_IM_END_TOKEN, + DEFAULT_IM_START_TOKEN, + DEFAULT_IMAGE_TOKEN, + IGNORE_INDEX, + IMAGE_TOKEN_INDEX, + ) + from llava.conversation import SeparatorStyle, conv_templates + from llava.mm_utils import ( + KeywordsStoppingCriteria, + get_model_name_from_path, + process_images, + tokenizer_image_token, + ) + from llava.model.builder import load_pretrained_model +except ImportError as e: + eval_logger.debug(f"LLaVA_NeXT is not installed. Please install llava from `https://github.com/rese1f/MovieChat.git` to use this model.\nError: {e}") + + +# Determine best attention implementation +if version.parse(torch.__version__) >= version.parse("2.1.2"): + best_fit_attn_implementation = "sdpa" +else: + best_fit_attn_implementation = "eager" + + +@register_model("llava_onevision_moviechat") +class Llava_OneVision_MovieChat(lmms): + """ + Llava Model + """ + + def __init__( + self, + pretrained: str = "lmms-lab/llava-onevision-qwen2-7b-ov", + truncation: Optional[bool] = True, + device: Optional[str] = "cuda:0", + batch_size: Optional[Union[int, str]] = 1, + model_name: str = "llava_qwen", + attn_implementation: Optional[str] = best_fit_attn_implementation, + device_map: Optional[str] = "cuda:0", + conv_template: Optional[str] = "qwen_1_5", + use_cache: Optional[bool] = True, + truncate_context: Optional[bool] = False, # whether to truncate the context in generation, set it False for LLaVA-1.6 + customized_config: Optional[str] = None, # ends in json + short_memory_length: Optional[int] = 18, + long_memory_length: Optional[int] = 64, + sliding_window_length: Optional[int] = 8, + merge_frame_length: Optional[int] = 2, + tmp_folder: Optional[str] = "tmp/", + mm_spatial_pool_stride: Optional[int] = 2, + mm_spatial_pool_mode: Optional[str] = "bilinear", + token_strategy: Optional[str] = "single", # could be "single" or "multiple", "multiple" denotes adding multiple tokens for each frame + video_decode_backend: str = "decord", + **kwargs, + ) -> None: + super().__init__() + # Do not use kwargs for now + assert kwargs == {}, f"Unexpected kwargs: {kwargs}" + + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + if accelerator.num_processes > 1: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + self.device_map = f"cuda:{accelerator.local_process_index}" + elif accelerator.num_processes == 1 and device_map == "auto": + self._device = torch.device(device) + self.device_map = device_map + else: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + self.device_map = f"cuda:{accelerator.local_process_index}" + + llava_model_args = { + "multimodal": True, + } + if customized_config is not None: + llava_model_args["customized_config"] = customized_config + if attn_implementation is not None: + llava_model_args["attn_implementation"] = attn_implementation 
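+        # NOTE: `kwargs` is asserted to be empty at the top of __init__, so this
+        # `use_flash_attention_2` branch only takes effect if that assertion is relaxed.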
+ if "use_flash_attention_2" in kwargs: + llava_model_args["use_flash_attention_2"] = kwargs["use_flash_attention_2"] + model_name = model_name if model_name is not None else get_model_name_from_path(pretrained) + + self.pretrained = pretrained + self.token_strategy = token_strategy + self.mm_spatial_pool_stride = mm_spatial_pool_stride + self.mm_spatial_pool_mode = mm_spatial_pool_mode + self.video_decode_backend = video_decode_backend + + self.short_memory_length = short_memory_length + self.long_memory_length = long_memory_length + self.merge_frame_length = merge_frame_length + self.sliding_window_length = sliding_window_length + self.num_clips = (self.long_memory_length // self.merge_frame_length) * ((self.short_memory_length - self.merge_frame_length) // self.sliding_window_length) + self.tmp_folder = tmp_folder + + overwrite_config = {} + overwrite_config["mm_spatial_pool_stride"] = self.mm_spatial_pool_stride + overwrite_config["mm_spatial_pool_mode"] = self.mm_spatial_pool_mode + cfg_pretrained = AutoConfig.from_pretrained(self.pretrained) + + if cfg_pretrained.architectures[0] == "LlavaLlamaForCausalLM": # Ugly code, only used in vicuna that needs ROPE + if "224" in cfg_pretrained.mm_vision_tower: + least_token_number = self.max_frames_num * (16 // self.mm_spatial_pool_stride) ** 2 + 1000 + else: + least_token_number = self.max_frames_num * (24 // self.mm_spatial_pool_stride) ** 2 + 1000 + + scaling_factor = math.ceil(least_token_number / 4096) + if scaling_factor >= 2: + overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"} + overwrite_config["max_sequence_length"] = 4096 * scaling_factor + overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor + + llava_model_args["overwrite_config"] = overwrite_config + from LLaVA_NeXT.llava.model.builder import load_pretrained_model + + try: + # Try to load the model with the multimodal argument + self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args) + except TypeError: + # for older versions of LLaVA that don't have multimodal argument + llava_model_args.pop("multimodal", None) + self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args) + + self._config = self._model.config + self.model.eval() + self.truncation = truncation + self.batch_size_per_gpu = int(batch_size) + self.conv_template = conv_template + self.use_cache = use_cache + self.truncate_context = truncate_context + assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue." + + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." + # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model + # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works + # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work. 
+ if accelerator.distributed_type == DistributedType.DEEPSPEED: + kwargs = { + "train_micro_batch_size_per_gpu": self.batch_size_per_gpu, + "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes, + } + AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs) + eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0") + + if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED: + self._model = accelerator.prepare(self.model) + else: + self._model = accelerator.prepare_model(self.model, evaluation_mode=True) + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + + elif accelerator.num_processes == 1 and device_map == "auto": + eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism") + self._rank = 0 + self._world_size = 1 + + else: + eval_logger.info(f"Using single device: {self._device}") + self.model.to(self._device) + self._rank = 0 + self._world_size = 1 + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. + return self._config + + @property + def tokenizer(self): + return self._tokenizer + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + return self._max_length + + def pad_sequence(self, input_ids, batch_first, padding_value): + if self.tokenizer.padding_side == "left": + input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids] + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value) + if self.tokenizer.padding_side == "left": + input_ids = torch.flip(input_ids, [1]) + return input_ids + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]: + """ """ + add_special_tokens = False if add_special_tokens is None else add_special_tokens + encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens) + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + return encoding + + def tok_decode(self, tokens): + try: + return self.tokenizer.decode(tokens) + except: + return self.tokenizer.decode([tokens]) + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + # TODO + raise NotImplementedError("MovieChat only supports generation.") + + def flatten(self, input): + new_list = [] + for i in input: + for j in i: + new_list.append(j) + return new_list + + def load_video(self, video_path, max_frames_num): + if type(video_path) == str: + vr = 
VideoReader(video_path, ctx=cpu(0)) + else: + vr = VideoReader(video_path[0], ctx=cpu(0)) + total_frame_num = len(vr) + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + spare_frames = vr.get_batch(frame_idx).numpy() + return spare_frames # (frames, height, width, channels) + + def extract_keyframes(self, video_path, keyframes): + container = av.open(video_path) + video_stream = container.streams.video[0] + fps = video_stream.average_rate + time_base = video_stream.time_base + frames = [] + + for keyframe in keyframes: + keyframe_time = float(keyframe) + frame_number = int(keyframe_time * fps) + container.seek(int(keyframe_time / time_base)) + found = False + for packet in container.demux(video=0): + for frame in packet.decode(): + if frame.index >= frame_number: + frames.append(frame) + found = True + break + if found: + break + + if not found: + container.seek(-1, any_frame=False) + for packet in container.demux(video=0): + for frame in packet.decode(): + pass + frames.append(frame) + + video = [x.to_ndarray(format="rgb24") for x in frames] + video_frames = [Image.fromarray(x) for x in video] + return video_frames + + def generate_until(self, requests: List[Instance]) -> List[str]: + res = [] + + def _collate(x): + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(x[0]) + return -len(toks), x[0] + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + metadata = requests[0].metadata + re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True) + chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) + num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1 + pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding") + for chunk in chunks: + batched_contexts, all_gen_kwargs, batched_doc_to_visual, batched_doc_id, batched_task, batched_split = zip(*chunk) + task = batched_task[0] + split = batched_split[0] + batched_visuals = [batched_doc_to_visual[0](self.task_dict[task][split][ids]) for ids in batched_doc_id] # [B, N] + assert len(batched_visuals) == 1 + + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + if "until" in gen_kwargs: + gen_kwargs.pop("until") + + question_input = [] + + for visual, context in zip(batched_visuals, batched_contexts): + if len(visual) > 1 or "image_aspect_ratio" not in self._config.__dict__: # for multi image case, we treat per image aspect ratio as "pad" by default. 
+ self._config.image_aspect_ratio = getattr(gen_kwargs, "image_aspect_ratio", "pad") + eval_logger.info(f"Setting image aspect ratio: {self._config.image_aspect_ratio}") + if type(visual[0]) == PIL.Image.Image and "task_type" not in metadata and "sample_frames" not in metadata: # For image task + raise NotImplementedError("MovieChat only supports video inputs.") + + elif "task_type" in metadata and metadata["task_type"] == "video" and "sample_frames" in metadata: + raise NotImplementedError("MovieChat only supports video inputs.") + + elif type(visual[0]) == str: # For video task + try: + self.short_memory_buffer = [] + self.long_memory_buffer = [] + # try: + os.makedirs(self.tmp_folder, exist_ok=True) + + video = VideoFileClip(visual[0]) + clip_duration = video.duration / self.num_clips + + cur_frame = 0 + for i in range(self.num_clips): + start_time = i * clip_duration + end_time = start_time + clip_duration + # uniformly sample self.sliding_window_length frames from the video from start_time to end_time + frames = list(video.subclip(start_time, end_time).iter_frames(fps=self.sliding_window_length / clip_duration))[: self.sliding_window_length] + frames = [Image.fromarray(frame).convert("RGB") for frame in frames] + preprocess_frames = self._image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].half().cuda() + encoded_window = self.model.encode_images(preprocess_frames) # [frames, 729,3584] + + for frame in encoded_window: + if cur_frame < (self.short_memory_length - self.merge_frame_length): + if len(self.short_memory_buffer) == self.short_memory_length: + self.short_memory_buffer.pop(0) + self.short_memory_buffer.append(frame) + cur_frame += 1 + + if cur_frame == (self.short_memory_length - self.merge_frame_length): + cur_frame = 0 + + # merge short_memory_frames + similar_list = [] + for frame_i in range(len(self.short_memory_buffer) - 1): + scores = self.short_memory_buffer[frame_i] @ self.short_memory_buffer[frame_i + 1].transpose(-1, -2) + frame_silimar = torch.mean(scores) + similar_list.append(frame_silimar) + + while len(self.short_memory_buffer) > self.merge_frame_length: + max_value = max(similar_list) + max_index = similar_list.index(max_value) + new_frame_feature = (self.short_memory_buffer[max_index].cpu() + self.short_memory_buffer[max_index + 1].cpu()) / 2 + self.short_memory_buffer[max_index] = new_frame_feature.cuda() + del self.short_memory_buffer[max_index + 1] + similar_list = [] + for frame_i in range(len(self.short_memory_buffer) - 1): + scores = self.short_memory_buffer[frame_i] @ self.short_memory_buffer[frame_i + 1].transpose(-1, -2) + frame_silimar = torch.mean(scores) + similar_list.append(frame_silimar) + + for frame in self.short_memory_buffer: + self.long_memory_buffer.append(frame) + + image_features = torch.stack(self.long_memory_buffer) + except Exception as e: + print(e) + eval_logger.error(f"Error {e} in loading video") + image_features = None + + task_type = "video" + placeholder_count = len(frames) if self.token_strategy == "multiple" else 1 + + if image_features is not None and DEFAULT_IMAGE_TOKEN not in context: + """ + Three senarios: + 1. No image, and there for, no image token should be added. + 2. image token is already specified in the context, so we don't need to add it. + 3. image token is not specified in the context and there is image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line. + 4. 
For video tasks, we could add a token or multiple tokens for each frame in the context. This depends on the training strategy and should balance in test to decide which is better + """ + # if task_type == "image": # indeed in multi-image case, not the video in frames. + # image_tokens = [DEFAULT_IMAGE_TOKEN] * placeholder_count if isinstance(visual, list) else [DEFAULT_IMAGE_TOKEN] + # elif task_type == "video": + # image_tokens = [DEFAULT_IMAGE_TOKEN] * placeholder_count if self.token_strategy == "multiple" else [DEFAULT_IMAGE_TOKEN] + image_tokens = [DEFAULT_IMAGE_TOKEN] * placeholder_count + image_tokens = " ".join(image_tokens) + question = image_tokens + "\n" + context + else: + question = context + + # This is much safer for llama3, as we now have some object type in it + if "llama_3" in self.conv_template: + conv = copy.deepcopy(conv_templates[self.conv_template]) + else: + conv = conv_templates[self.conv_template].copy() + + if utils.is_json(question): # conversational question input + question = json.loads(question) + for idx, item in enumerate(question): + role = conv.roles[idx % 2] + message = item["value"] + conv.append_message(role, message) + + assert len(conv.messages) % 2 == 1 + conv.append_message(conv.roles[1], None) + prompt_question = conv.get_prompt() + question_input.append(prompt_question) + else: # only simple string for question + conv.append_message(conv.roles[0], question) + conv.append_message(conv.roles[1], None) + prompt_question = conv.get_prompt() + question_input.append(prompt_question) + + # preconfigure gen_kwargs with defaults + if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 + if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0 + if "do_sample" not in gen_kwargs: + gen_kwargs["do_sample"] = False + if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None + if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 + + input_ids_list = [tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") for prompt in question_input] + pad_token_ids = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id + input_ids = self.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_ids).to(self.device) + attention_masks = input_ids.ne(pad_token_ids).to(self.device) + + if task_type == "image": + gen_kwargs["image_sizes"] = [batched_visuals[0][idx].size for idx in range(len(batched_visuals[0]))] + elif task_type == "video": + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) + gen_kwargs["modalities"] = ["video"] + gen_kwargs["stopping_criteria"] = [stopping_criteria] + self._config.mm_spatial_pool_stride = self.mm_spatial_pool_stride + self._config.mm_spatial_pool_mode = self.mm_spatial_pool_mode + + # These steps are not in LLaVA's original code, but are necessary for generation to work + # TODO: attention to this major generation step... 
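+            # `image_aspect_ratio` is not a generation argument, so it is removed from
+            # gen_kwargs before the inputs are handed to generate_moviechat().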
+ if "image_aspect_ratio" in gen_kwargs.keys(): + gen_kwargs.pop("image_aspect_ratio") + try: + with torch.inference_mode(): + gen_kwargs.pop("modalities") + cont = self.model.generate_moviechat(input_ids, attention_mask=attention_masks, pad_token_id=pad_token_ids, image_features=image_features, use_cache=self.use_cache, **gen_kwargs) + text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True) + text_outputs = [response.strip() for response in text_outputs] + except Exception as e: + print(e) + text_outputs = "Can not infer the answer." + + res.extend(text_outputs) + print(text_outputs) + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + return res diff --git a/lmms_eval/models/moviechat.py b/lmms_eval/models/moviechat.py new file mode 100644 index 00000000..cacfe152 --- /dev/null +++ b/lmms_eval/models/moviechat.py @@ -0,0 +1,452 @@ +import copy +import json +import logging +import math +import os +import os.path as osp +import queue +import re +import warnings +from datetime import timedelta +from typing import List, Optional, Tuple, Union + +import av +import einops +import numpy as np +import PIL +import torch +import transformers +from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs +from accelerate.state import AcceleratorState +from decord import VideoReader, cpu +from huggingface_hub import snapshot_download +from moviepy.video.io.VideoFileClip import VideoFileClip +from packaging import version +from PIL import Image +from scipy.spatial.distance import cosine +from skimage import transform +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +from tqdm import tqdm +from transformers import StoppingCriteria, StoppingCriteriaList + +from lmms_eval import utils +from lmms_eval.api.instance import Instance +from lmms_eval.api.model import lmms +from lmms_eval.api.registry import register_model + +# Suppress warnings +warnings.filterwarnings("ignore") + +# Configure logging +eval_logger = logging.getLogger("lmms-eval") + +# Enable TF32 for CUDA +torch.backends.cuda.matmul.allow_tf32 = True + +# Import LLaVA modules +try: + from MovieChat.common.registry import registry +except ImportError as e: + eval_logger.debug( + f"MovieChat is not installed. First, install MovieChat by 'https://github.com/rese1f/MovieChat.git' and 'cd MovieChat'. 
Change the torch version with `python -m pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/cu118`" + ) + + +# Determine best attention implementation +if version.parse(torch.__version__) >= version.parse("2.1.2"): + best_fit_attn_implementation = "sdpa" +else: + best_fit_attn_implementation = "eager" + + +class StoppingCriteriaSub(StoppingCriteria): + def __init__(self, stops=[], encounters=1): + super().__init__() + self.stops = stops + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): + for stop in self.stops: + if torch.all((stop == input_ids[0][-len(stop) :])).item(): + return True + + return False + + +@register_model("moviechat") +class MovieChat(lmms): + """ + MovieChat Model + """ + + def __init__( + self, + truncation: Optional[bool] = True, + device: Optional[str] = "cuda:0", + batch_size: Optional[Union[int, str]] = 1, + pretrained_llama_model: str = "Enxin/MovieChat-vicuna", + pretrained_llama_proj_model: str = "Enxin/MovieChat-proj", + attn_implementation: Optional[str] = best_fit_attn_implementation, + device_map: Optional[str] = "cuda:0", + use_cache: Optional[bool] = True, + truncate_context: Optional[bool] = False, # whether to truncate the context in generation, set it False for LLaVA-1.6 + customized_config: Optional[str] = None, # ends in json + short_memory_length: Optional[int] = 18, + long_memory_length: Optional[int] = 256, + sliding_window_length: Optional[int] = 8, + merge_frame_length: Optional[int] = 2, + tmp_folder: Optional[str] = "tmp/", + **kwargs, + ) -> None: + super().__init__() + # Do not use kwargs for now + assert kwargs == {}, f"Unexpected kwargs: {kwargs}" + + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + if accelerator.num_processes > 1: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + self.device_map = f"cuda:{accelerator.local_process_index}" + elif accelerator.num_processes == 1 and device_map == "auto": + self._device = torch.device(device) + self.device_map = device_map + else: + self._device = torch.device(f"cuda:{accelerator.local_process_index}") + self.device_map = f"cuda:{accelerator.local_process_index}" + + llama_model = snapshot_download(repo_id=pretrained_llama_model) if not osp.isdir(pretrained_llama_model) else pretrained_llama_model + llama_proj_pth = snapshot_download(repo_id=pretrained_llama_proj_model) if not osp.isdir(pretrained_llama_proj_model) else pretrained_llama_proj_model + llama_proj = osp.join(llama_proj_pth, "finetune-vicuna7b-v2.pth") + model_config = { + "arch": "moviechat", + "model_type": "pretrain_vicuna", + "freeze_vit": True, + "freeze_qformer": True, + "max_txt_len": 256, + "end_sym": "###", + "low_resource": False, + "frozen_llama_proj": False, + "llama_model": llama_model, + "llama_proj_model": llama_proj, + } + + model_cls = registry.get_model_class(model_config["arch"]) + self._model = model_cls.from_config(model_config).to(self.device_map) + + vis_processor_cfg = { + "name": "alpro_video_eval", + "n_frms": 8, + } + self.transform = transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] # Resize to 224x224 # Convert PIL Image to Tensor with shape [C, H, W] # Normalize + ) + self._image_processor = registry.get_processor_class(vis_processor_cfg["name"]).from_config(vis_processor_cfg) + + self.model.short_memory_length = 
short_memory_length + self.model.long_memory_length = long_memory_length + self.merge_frame_length = merge_frame_length + self.sliding_window_length = sliding_window_length + self.num_clips = (self.model.long_memory_length // self.merge_frame_length) * ((self.model.short_memory_length - self.merge_frame_length) // self.sliding_window_length) + self.tmp_folder = tmp_folder + + self._tokenizer = self.model.llama_tokenizer + stop_words_ids = [torch.tensor([835]).to(self.device), torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways. + self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) + + self.model.eval() + self.truncation = truncation + self.batch_size_per_gpu = int(batch_size) + self.use_cache = use_cache + self.truncate_context = truncate_context + assert self.batch_size_per_gpu == 1, "MovieChat currently does not support batched generation." + + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported." + # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model + # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works + # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work. + if accelerator.distributed_type == DistributedType.DEEPSPEED: + kwargs = { + "train_micro_batch_size_per_gpu": self.batch_size_per_gpu, + "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes, + } + AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs) + eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0") + + if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED: + self._model = accelerator.prepare(self.model) + else: + self._model = accelerator.prepare_model(self.model, evaluation_mode=True) + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + + elif accelerator.num_processes == 1 and device_map == "auto": + eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism") + self._rank = 0 + self._world_size = 1 + + else: + eval_logger.info(f"Using single device: {self._device}") + self.model.to(self._device) + self._rank = 0 + self._world_size = 1 + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. 
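+        # NOTE: unlike the other wrappers, `_config` is not assigned in __init__ here,
+        # so accessing this property relies on it being set elsewhere.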
+ return self._config + + @property + def tokenizer(self): + return self._tokenizer + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + return self._max_length + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]: + """ """ + add_special_tokens = False if add_special_tokens is None else add_special_tokens + encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens) + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + return encoding + + def tok_decode(self, tokens): + try: + return self.tokenizer.decode(tokens) + except: + return self.tokenizer.decode([tokens]) + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + # TODO + raise NotImplementedError("MovieChat only supports generation.") + + def flatten(self, input): + new_list = [] + for i in input: + for j in i: + new_list.append(j) + return new_list + + def get_context_emb(self, input_text, img_list): + prompt_1 = "You are able to understand the visual content that the user provides.Follow the instructions carefully and explain your answers in details.###Human: " + prompt_2 = input_text + prompt_3 = "###Assistant:" + + prompt = prompt_1 + " " + prompt_2 + prompt_3 + + prompt_segs = prompt.split("") + assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." + seg_tokens = [ + self.model.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids + # only add bos to the first seg + for i, seg in enumerate(prompt_segs) + ] + seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens] + + mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] + mixed_embs = torch.cat(mixed_embs, dim=1) + return mixed_embs + + def answer(self, img_list, input_text, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000): + embs = self.get_context_emb(input_text, img_list) + + current_max_len = embs.shape[1] + max_new_tokens + if current_max_len - max_length > 0: + print("Warning: The number of tokens in current conversation exceeds the max length. 
" "The model will not see the contexts outside the range.") + begin_idx = max(0, current_max_len - max_length) + + embs = embs[:, begin_idx:] + + outputs = self.model.llama_model.generate( + inputs_embeds=embs, + max_new_tokens=max_new_tokens, + stopping_criteria=self.stopping_criteria, + num_beams=num_beams, + do_sample=True, + min_length=min_length, + top_p=top_p, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + temperature=temperature, + ) + + output_token = outputs[0] + if output_token[0] == 0: # the model might output a unknow token + output_token = output_token[1:] + if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it + output_token = output_token[1:] + output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False) + output_text = output_text.split("###")[0] # remove the stop sign '###' + output_text = output_text.split("Assistant:")[-1].strip() + return output_text, output_token.cpu().numpy() + + def generate_until(self, requests: List[Instance]) -> List[str]: + res = [] + + def _collate(x): + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(x[0]) + return -len(toks), x[0] + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + metadata = requests[0].metadata + re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True) + chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) + num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1 + pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding") + for chunk in chunks: + batched_contexts, all_gen_kwargs, batched_doc_to_visual, batched_doc_id, batched_task, batched_split = zip(*chunk) + task = batched_task[0] + split = batched_split[0] + batched_visuals = [batched_doc_to_visual[0](self.task_dict[task][split][ids]) for ids in batched_doc_id] # [B, N] + assert len(batched_visuals) == 1 + + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. 
+ gen_kwargs = all_gen_kwargs[0] + if "until" in gen_kwargs: + gen_kwargs.pop("until") + + text_outputs = [] + + for visual, context in zip(batched_visuals, batched_contexts): + if type(visual[0]) == PIL.Image.Image and "task_type" not in metadata and "sample_frames" not in metadata: # For image task + raise NotImplementedError("MovieChat only supports video inputs.") + + elif "task_type" in metadata and metadata["task_type"] == "video" and "sample_frames" in metadata: + raise NotImplementedError("MovieChat only supports video inputs.") + + elif type(visual[0]) == str: # For video task + image_tensor = [] + self.model.short_memory_buffer = [] + self.model.long_memory_buffer = [] + img_list = [] + # try: + os.makedirs(self.tmp_folder, exist_ok=True) + + video = VideoFileClip(visual[0]) + clip_duration = video.duration / self.num_clips + + cur_frame = 0 + for i in range(self.num_clips): + preprocess_frames = [] + start_time = i * clip_duration + end_time = start_time + clip_duration + # uniformly sample self.sliding_window_length frames from the video from start_time to end_time + frames = list(video.subclip(start_time, end_time).iter_frames(fps=self.sliding_window_length / clip_duration))[: self.sliding_window_length] + for frame in frames: + frame = Image.fromarray(frame) + frame_tensor = self.transform(frame) + frame_tensor = frame_tensor.permute(2, 0, 1) + frame_tensor = frame_tensor.unsqueeze(0) + frame_tensor = self._image_processor.transform(frame_tensor) + frame_tensor = frame_tensor.squeeze(-1).permute(1, 2, 0) + preprocess_frames.append(frame_tensor) + + frames_tensor = torch.stack(preprocess_frames, dim=0) + + image_embeds = self.model.ln_vision(self.model.visual_encoder(frames_tensor.half().to(self.device))) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device) + query_tokens = self.model.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.model.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + encoded_window = query_output.last_hidden_state + + for frame in encoded_window: + if cur_frame < (self.model.short_memory_length - self.merge_frame_length): + if len(self.model.short_memory_buffer) == self.model.short_memory_length: + self.model.short_memory_buffer.pop(0) + self.model.short_memory_buffer.append(frame) + cur_frame += 1 + + if cur_frame == (self.model.short_memory_length - self.merge_frame_length): + cur_frame = 0 + + # merge short_memory_frames + similar_list = [] + for frame_i in range(len(self.model.short_memory_buffer) - 1): + scores = self.model.short_memory_buffer[frame_i] @ self.model.short_memory_buffer[frame_i + 1].transpose(-1, -2) + frame_silimar = torch.mean(scores) + similar_list.append(frame_silimar) + + while len(self.model.short_memory_buffer) > self.merge_frame_length: + max_value = max(similar_list) + max_index = similar_list.index(max_value) + new_frame_feature = (self.model.short_memory_buffer[max_index].cpu() + self.model.short_memory_buffer[max_index + 1].cpu()) / 2 + self.model.short_memory_buffer[max_index] = new_frame_feature.cuda() + del self.model.short_memory_buffer[max_index + 1] + similar_list = [] + for frame_i in range(len(self.model.short_memory_buffer) - 1): + scores = self.model.short_memory_buffer[frame_i] @ self.model.short_memory_buffer[frame_i + 1].transpose(-1, -2) + frame_silimar = torch.mean(scores) + similar_list.append(frame_silimar) + + for frame in self.model.short_memory_buffer: + 
self.model.long_memory_buffer.append(frame) + + cur_image = self.model.encode_image(preprocess_frames[-1].unsqueeze(0).unsqueeze(2).half(), self.device) + video_emb, _ = self.model.encode_long_video(cur_image, device=self.device, middle_video=False) + img_list.append(video_emb) + llm_message = self.answer(img_list=img_list, input_text=context, num_beams=1, temperature=1.0, max_new_tokens=300, max_length=2000)[0] + text_outputs.append(llm_message) + + # except Exception as e: + # eval_logger.error(f"Error {e} in loading video") + # image_tensor = None + + text_outputs = [response.strip() for response in text_outputs] + res.extend(text_outputs) + print(text_outputs) + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + return res From 38a42d6681a88a7476524006b2a5d58f7f1321d7 Mon Sep 17 00:00:00 2001 From: Espere-1119-Song Date: Sat, 26 Oct 2024 21:37:16 +0800 Subject: [PATCH 2/5] Add tasks to readme --- docs/current_tasks.md | 11 +++++++++++ docs/run_examples.md | 3 ++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/docs/current_tasks.md b/docs/current_tasks.md index 8c922aea..2cdc3bc8 100644 --- a/docs/current_tasks.md +++ b/docs/current_tasks.md @@ -292,6 +292,17 @@ - [YouCook2](http://youcook2.eecs.umich.edu/) (youcook2_val) +- [MovieChat](https://github.com/rese1f/MovieChat) (moviechat) + - MovieChat Global Model (moviechat_global) + - MovieChat Breakpoint Model (moviechat_breakpoint) + +- [VDC](https://github.com/rese1f/aurora) (vdc) + - VDC Camera Caption (camera_test) + - VDC Short Caption (short_test) + - VDC Background Caption (background_test) + - VDC Main Object Caption (main_object_test) + + ## 4. 
Text Tasks - [GSM8K](https://github.com/openai/grade-school-math) (gsm8k) diff --git a/docs/run_examples.md b/docs/run_examples.md index 15c5715b..dcfdebdf 100644 --- a/docs/run_examples.md +++ b/docs/run_examples.md @@ -433,4 +433,5 @@ accelerate launch --num_processes 8 --main_process_port 12345 -m lmms_eval \ --log_samples \ --log_samples_suffix $TASK_SUFFIX \ --output_path ./logs/ -``` \ No newline at end of file +``` + From a997f3f5667b989cc44f827d8a294920beca7e57 Mon Sep 17 00:00:00 2001 From: Espere-1119-Song Date: Sat, 26 Oct 2024 21:53:44 +0800 Subject: [PATCH 3/5] Add tasks to readme --- lmms_eval/models/llava_onevision_moviechat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lmms_eval/models/llava_onevision_moviechat.py b/lmms_eval/models/llava_onevision_moviechat.py index 39ccd802..a59b176a 100644 --- a/lmms_eval/models/llava_onevision_moviechat.py +++ b/lmms_eval/models/llava_onevision_moviechat.py @@ -64,6 +64,7 @@ best_fit_attn_implementation = "eager" +# llava_onevision_moviechat uses the same memory consolidation mechanism with the original MovieChat, but changes the base model from VideoLLamA to LLava-OneVision @register_model("llava_onevision_moviechat") class Llava_OneVision_MovieChat(lmms): """ From e091cbc59f1e36c76c318ab6f9006ba77c68f546 Mon Sep 17 00:00:00 2001 From: Espere-1119-Song Date: Sat, 26 Oct 2024 22:22:55 +0800 Subject: [PATCH 4/5] Modify init --- lmms_eval/models/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lmms_eval/models/__init__.py b/lmms_eval/models/__init__.py index 70089b3d..c1871519 100644 --- a/lmms_eval/models/__init__.py +++ b/lmms_eval/models/__init__.py @@ -11,6 +11,7 @@ logger.add(sys.stdout, level="WARNING") AVAILABLE_MODELS = { + "auroracap": "AuroraCap", "batch_gpt4": "BatchGPT4", "claude": "Claude", "cogvlm2": "CogVLM2", @@ -26,12 +27,14 @@ "llava": "Llava", "llava_hf": "LlavaHf", "llava_onevision": "Llava_OneVision", + "llava_onevision_moviechat": "Llava_OneVision_MovieChat", "llava_sglang": "LlavaSglang", "llava_vid": "LlavaVid", "longva": "LongVA", "mantis": "Mantis", "minicpm_v": "MiniCPM_V", "minimonkey": "MiniMonkey", + "moviechat": "MovieChat", "mplug_owl_video": "mplug_Owl", "phi3v": "Phi3v", "qwen_vl": "Qwen_VL", From 9312a9e5897385ae558da71da1afb26c833086d5 Mon Sep 17 00:00:00 2001 From: Espere-1119-Song Date: Sat, 26 Oct 2024 22:31:26 +0800 Subject: [PATCH 5/5] Add instruction --- docs/run_examples.md | 70 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/docs/run_examples.md b/docs/run_examples.md index dcfdebdf..a63b22b6 100644 --- a/docs/run_examples.md +++ b/docs/run_examples.md @@ -435,3 +435,73 @@ accelerate launch --num_processes 8 --main_process_port 12345 -m lmms_eval \ --output_path ./logs/ ``` +### MovieChat + +```bash +cd /path/to/lmms-eval +python3 -m pip install -e .; + +python -m pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/cu118 + +git clone https://github.com/rese1f/MovieChat.git +mv /path/to/MovieChat /path/to/lmms-eval/lmms_eval/models/ + +TASK=$1 +echo $TASK +TASK_SUFFIX="${TASK//,/_}" +echo $TASK_SUFFIX + +accelerate launch --num_processes 8 --main_process_port 12345 -m lmms_eval \ + --model moviechat \ + --tasks $TASK \ + --batch_size 1 \ + --log_samples \ + --log_samples_suffix $TASK_SUFFIX \ + --output_path ./logs/ +``` + +### LLaVA-OneVision-MovieChat + +```bash +cd /path/to/lmms-eval +python3 -m pip install -e .; + +git clone https://github.com/rese1f/MovieChat.git +mv 
/path/to/MovieChat/MovieChat_OneVision/llava /path/to/lmms-eval/
+
+TASK=$1
+echo $TASK
+TASK_SUFFIX="${TASK//,/_}"
+echo $TASK_SUFFIX
+
+accelerate launch --num_processes 8 --main_process_port 12345 -m lmms_eval \
+    --model llava_onevision_moviechat \
+    --tasks $TASK \
+    --batch_size 1 \
+    --log_samples \
+    --log_samples_suffix $TASK_SUFFIX \
+    --output_path ./logs/
+```
+
+### AuroraCap
+
+```bash
+cd /path/to/lmms-eval
+python3 -m pip install -e .;
+
+git clone https://github.com/rese1f/aurora.git
+mv /path/to/aurora/src/xtuner/xtuner /path/to/lmms-eval/lmms_eval/models/aurora_xtuner
+
+TASK=$1
+echo $TASK
+TASK_SUFFIX="${TASK//,/_}"
+echo $TASK_SUFFIX
+
+accelerate launch --num_processes 8 --main_process_port 12345 -m lmms_eval \
+    --model auroracap \
+    --tasks $TASK \
+    --batch_size 1 \
+    --log_samples \
+    --log_samples_suffix $TASK_SUFFIX \
+    --output_path ./logs/
+```
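+
+As a quick sanity check, any of the blocks above can be saved as a small wrapper script and invoked with a task name from `docs/current_tasks.md`; the script name below is only illustrative:
+
+```bash
+# Illustrative only: assumes the AuroraCap block above was saved as run_auroracap.sh (hypothetical name).
+# "vdc" is the task group name listed in docs/current_tasks.md in this patch series.
+bash run_auroracap.sh vdc
+```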