Skip to content

Commit

Permalink
Fix llava_hf generation for 1.6
Browse files: browse the repository at this point in the history
  • Loading branch information
kcz358 committed May 7, 2024
1 parent fa3ff92 commit 3e56b4f
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions lmms_eval/models/llava_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
from typing import List, Optional, Union, Tuple
from transformers import LlavaForConditionalGeneration, AutoProcessor
from transformers import LlavaForConditionalGeneration, LlavaNextForConditionalGeneration, AutoProcessor

import warnings

Expand Down Expand Up @@ -67,7 +67,15 @@ def __init__(
self.device_map = device_map
if isinstance(dtype, str) and dtype != "auto":
dtype = getattr(torch, dtype)
self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)

if "1.5" in pretrained:
self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
elif "1.6" in pretrained:
self._model = LlavaNextForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
else:
eval_logger.info("Not sure whether you use 1.5 or 1.6. Use 1.5 by default. This might cause bugs if you are actually using 1.6")
self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)

self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision, trust_remote_code=trust_remote_code)
# Pad from left for batched generation: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips
self._image_processor.tokenizer.padding_side = "left"
Expand Down Expand Up @@ -106,6 +114,7 @@ def __init__(
self.model.to(self._device)
self._rank = 0
self._word_size = 1
self.accelerator = accelerator

@property
def config(self):
Expand Down

0 comments on commit 3e56b4f

Please sign in to comment.