Skip to content

Commit

Permalink
fix compatibility issue of older version llava
Browse files Browse the repository at this point in the history
  • Loading branch information
Luodian committed May 27, 2024
1 parent 616edf4 commit 24dc435
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions lmms_eval/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@
except ImportError:
eval_logger.error("LLaVA is not installed. Please install LLaVA to use this model.")

from transformers.integrations.deepspeed import (
is_deepspeed_zero3_enabled,
set_hf_deepspeed_config,
unset_hf_deepspeed_config,
)

if torch.__version__ > "2.1.2":
best_fit_attn_implementation = "sdpa"
else:
Expand Down Expand Up @@ -94,8 +88,9 @@ def __init__(
# Try to load the model with the multimodal argument
self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
except TypeError:
# for older versions of LLaVA that don't have multimodal argument
# for older versions of LLaVA that don't have multimodal and attn_implementation arguments
llava_model_args.pop("multimodal", None)
llava_model_args.pop("attn_implementation", None)
self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)

self._config = self._model.config
Expand Down

0 comments on commit 24dc435

Please sign in to comment.