fix llama vision (#447)
pufanyi authored Dec 7, 2024
1 parent 835da31 commit 7ee8d59
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions lmms_eval/models/llama_vision.py
@@ -15,7 +15,6 @@
 from lmms_eval.api.instance import Instance
 from lmms_eval.api.model import lmms
 from lmms_eval.api.registry import register_model
-from lmms_eval.models.model_utils.load_video import read_video_pyav_pil
 
 warnings.filterwarnings("ignore")
 
@@ -33,12 +32,10 @@ def __init__(
         device: str = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
         batch_size: int = 1,
-        trust_remote_code: Optional[bool] = True,
+        trust_remote_code: Optional[bool] = False,
         attn_implementation: Optional[str] = None,
         device_map: str = "",
         max_frames_num: Optional[int] = 32,
-        fps: Optional[int] = None,
-        max_image_size: Optional[int] = None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -55,9 +52,7 @@ def __init__(
         if isinstance(dtype, str) and dtype != "auto":
             dtype = getattr(torch, dtype)
 
-        self.fps = fps
         self.max_frames_num = max_frames_num
-        self.max_image_size = max_image_size
         self._model = MllamaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
         self.model.eval()
         self.processor = AutoProcessor.from_pretrained(pretrained)
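
The net effect of the constructor hunks: the fps and max_image_size knobs are removed, and trust_remote_code now defaults to False (the standard transformers default), so custom modeling code only runs when the caller opts in. A minimal sketch of loading the model the way the updated __init__ does; the checkpoint name and dtype below are illustrative, not taken from the diff:

from transformers import AutoProcessor, MllamaForConditionalGeneration

# Illustrative checkpoint; the wrapper receives this via its `pretrained` argument.
pretrained = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# trust_remote_code defaults to False after this commit; pass True only for
# checkpoints whose custom modeling code you have reviewed.
model = MllamaForConditionalGeneration.from_pretrained(
    pretrained,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=False,
)
model.eval()
processor = AutoProcessor.from_pretrained(pretrained)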
@@ -182,11 +177,9 @@ def generate_until(self, requests: List[Instance]) -> List[str]:
 
             for visual in visuals:
                 if isinstance(visual, str):
-                    frames = read_video_pyav_pil(visual, num_frm=self.max_frames_num, fps=self.fps, max_image_size=self.max_image_size)
-                    images.extend(frames)
-                    # frames = self.load_video(visual, self.max_frames_num)
-                    # frames = torch.from_numpy(frames).permute(0, 3, 1, 2)
-                    # images.extend([to_pil_image(frame) for frame in frames])
+                    frames = self.load_video(visual, self.max_frames_num)
+                    frames = torch.from_numpy(frames).permute(0, 3, 1, 2)
+                    images.extend([to_pil_image(frame) for frame in frames])
                 elif isinstance(visual, PIL.Image.Image):
                     images.append(visual)
 

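The restored branch assumes self.load_video returns a NumPy array shaped (num_frames, height, width, channels); the caller then permutes it to channel-first tensors and converts each frame to a PIL image for the processor. A minimal sketch of a compatible helper, assuming uniform frame sampling with decord; the helper name, the decord backend, and the file path are illustrative, not the repository's actual implementation:

import numpy as np
import torch
from decord import VideoReader, cpu
from torchvision.transforms.functional import to_pil_image

def load_video_sketch(video_path: str, max_frames_num: int) -> np.ndarray:
    # Uniformly sample up to max_frames_num frames as a (T, H, W, C) uint8 array.
    vr = VideoReader(video_path, ctx=cpu(0))
    indices = np.linspace(0, len(vr) - 1, num=max_frames_num, dtype=int)
    return vr.get_batch(indices).asnumpy()

# Mirrors the restored code path in generate_until:
frames = load_video_sketch("clip.mp4", max_frames_num=32)  # (T, H, W, C)
frames = torch.from_numpy(frames).permute(0, 3, 1, 2)      # (T, C, H, W)
images = [to_pil_image(frame) for frame in frames]         # list of PIL.Image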