Skip to content

Commit

Permalink
feat: Add tie_weights parameter to Llava model initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Luodian committed Jul 9, 2024
1 parent 2037a86 commit 672d7e5
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lmms_eval/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
device_map="cuda:0",
conv_template="vicuna_v1",
use_cache=True,
tie_weights: bool = True,
truncate_context=False, # whether to truncate the context in generation, set it False for LLaVA-1.6
customized_config=None, # ends in json
**kwargs,
Expand Down Expand Up @@ -97,6 +98,8 @@ def __init__(
self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
self._config = self._model.config
self.model.eval()
if tie_weights:
self.model.tie_weights()

self.truncation = truncation
self.batch_size_per_gpu = int(batch_size)
Expand Down

0 comments on commit 672d7e5

Please sign in to comment.