Skip to content

Commit

Permalink
update load_video function to allow single frame loading and adjust num_segments usage in InternVL2 class
Browse files Browse the repository at this point in the history
  • Loading branch information
pufanyi committed Dec 20, 2024
1 parent d9b59bb commit 06faf07
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions lmms_eval/models/internvl2.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
return frame_indices


def load_video(video_path, bound=None, input_size=448, max_num=32, num_segments=32):
def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
max_frame = len(vr) - 1
fps = float(vr.get_avg_fps())
Expand Down Expand Up @@ -135,7 +135,6 @@ def __init__(
device_map: str = "cuda:0",
batch_size: str = "1",
num_frame: int = 32,
num_segments: int = 32,
**kwargs,
):
super().__init__()
Expand All @@ -144,7 +143,6 @@ def __init__(
self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, device_map=device_map).eval()
self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True, device_map=device_map)
self.num_frame = num_frame
self.num_segments = num_segments

batch_size = int(batch_size)
assert batch_size == 1, f"Batch size should be 1 for InternVL2, but got {batch_size}."
Expand Down Expand Up @@ -273,7 +271,7 @@ def generate_until(self, requests) -> List[str]:
elif self.modality == "video":
assert len(visuals) == 1, f"Only one video is supported, but got {len(visuals)} videos."
video_path = visuals[0]
pixel_values, num_patches_list = load_video(video_path, num_segments=self.num_segments, max_num=self.num_frame)
pixel_values, num_patches_list = load_video(video_path, num_segments=self.num_frame)
pixel_values = pixel_values.to(torch.bfloat16).cuda()
video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
question = video_prefix + contexts
Expand Down

0 comments on commit 06faf07

Please sign in to comment.