diff --git a/lmms_eval/models/model_utils/load_video.py b/lmms_eval/models/model_utils/load_video.py
index 0c4ea23d..2d4879cc 100644
--- a/lmms_eval/models/model_utils/load_video.py
+++ b/lmms_eval/models/model_utils/load_video.py
@@ -1,6 +1,19 @@
 import av
 import numpy as np
 from av.codec.context import CodecContext
+from decord import VideoReader, cpu
+
+
+def load_video_decord(video_path, max_frames_num):
+    if type(video_path) == str:
+        vr = VideoReader(video_path, ctx=cpu(0))
+    else:
+        vr = VideoReader(video_path[0], ctx=cpu(0))
+    total_frame_num = len(vr)
+    uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
+    frame_idx = uniform_sampled_frames.tolist()
+    spare_frames = vr.get_batch(frame_idx).asnumpy()
+    return spare_frames  # (frames, height, width, channels)
 
 
 # This one is faster
diff --git a/lmms_eval/models/qwen2_vl.py b/lmms_eval/models/qwen2_vl.py
index 9848f911..565b7888 100755
--- a/lmms_eval/models/qwen2_vl.py
+++ b/lmms_eval/models/qwen2_vl.py
@@ -3,6 +3,7 @@
 from typing import List, Optional, Tuple, Union
 
 import decord
+import numpy as np
 import torch
 from accelerate import Accelerator, DistributedType
 from loguru import logger as eval_logger
@@ -14,6 +15,7 @@
 from lmms_eval.api.instance import Instance
 from lmms_eval.api.model import lmms
 from lmms_eval.api.registry import register_model
+from lmms_eval.models.model_utils.load_video import load_video_decord
 
 try:
     from qwen_vl_utils import process_vision_info
@@ -35,7 +37,10 @@ def __init__(
         device_map: Optional[str] = "cuda",
         batch_size: Optional[Union[int, str]] = 1,
         use_cache=True,
-        use_flash_attention_2: Optional[bool] = True,
+        use_flash_attention_2: Optional[bool] = False,
+        max_pixels: int = 12845056,
+        min_pixels: int = 3136,
+        max_num_frames: int = 32,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -62,7 +67,10 @@ def __init__(
             ).eval()
         else:
             self._model = Qwen2VLForConditionalGeneration.from_pretrained(pretrained, torch_dtype="auto", device_map=self.device_map).eval()
-        self.processor = AutoProcessor.from_pretrained(pretrained)
+        self.processor = AutoProcessor.from_pretrained(pretrained, max_pixels=max_pixels, min_pixels=min_pixels)
+        self.max_pixels = max_pixels
+        self.min_pixels = min_pixels
+        self.max_num_frames = max_num_frames
         self._tokenizer = AutoTokenizer.from_pretrained(pretrained)
 
         self._config = self.model.config
@@ -198,8 +206,8 @@ def _collate(x):
                         vr = decord.VideoReader(visual)
                         first_frame = vr[0].asnumpy()
                         height, width = first_frame.shape[:2]
-                        max_pixels = height * width
-                        message.append({"role": "user", "content": [{"type": "video", "video": visual, "max_pixels": max_pixels}, {"type": "text", "text": context}]})
+                        # max_pixels = height * width
+                        message.append({"role": "user", "content": [{"type": "video", "video": visual, "max_pixels": self.max_pixels}, {"type": "text", "text": context}]})
                     elif isinstance(visual, Image.Image):  # Single image
                         base64_image = visual.convert("RGB")
                         buffer = BytesIO()
@@ -226,6 +234,12 @@ def _collate(x):
             texts = [self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages]
 
             image_inputs, video_inputs = process_vision_info(messages)
+                total_frames = video_inputs[0].shape[0]
+                indices = np.linspace(0, total_frames - 1, self.max_num_frames, dtype=int)
+                # Append the last frame index if not already included
+                if total_frames - 1 not in indices:
+                    indices = np.append(indices, total_frames - 1)
+                video_inputs[0] = video_inputs[0][indices]
             inputs = self.processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
 
             if self.device_map == "auto":
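
For reference, here is a minimal standalone sketch of the uniform frame-sampling scheme this patch introduces (the np.linspace index selection used by both load_video_decord and the generate_until change). The sample_frame_indices helper name and the num_frames=32 default are illustrative assumptions, not code from the patch:

```python
import numpy as np


def sample_frame_indices(total_frames: int, num_frames: int = 32) -> np.ndarray:
    # Hypothetical helper mirroring the sampling logic added in this diff:
    # pick num_frames uniformly spaced indices over [0, total_frames - 1].
    indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    # As in the qwen2_vl.py hunk, keep the last frame even if linspace missed it.
    if total_frames - 1 not in indices:
        indices = np.append(indices, total_frames - 1)
    return indices


# Example: sample a 900-frame clip; typically returns 32 indices
# (33 if the last frame has to be appended).
print(sample_frame_indices(900))
```

On the new processor defaults: max_pixels=12845056 is 16384 × 28 × 28 and min_pixels=3136 is 4 × 28 × 28, matching the qwen_vl_utils defaults (image sizes are handled in multiples of 28 pixels), so they bound the per-image pixel budget rather than being arbitrary constants.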