Enable flash attention 2 in Llava model
Luodian committed Mar 11, 2024
1 parent 27027e4 commit 476e6fd
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions lmms_eval/models/llava.py
@@ -1,4 +1,7 @@
 import torch
+
+torch.backends.cuda.matmul.allow_tf32 = True
+
 import logging
 import copy
 from tqdm import tqdm
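
The hunk above turns TF32 matmuls on globally, trading a little matmul precision for throughput on Ampere-and-newer GPUs. A minimal standalone sketch of the same setting (the cuDNN toggle is an optional companion and is not part of this commit):

    import torch

    # Allow TensorFloat-32 for matrix multiplications on Ampere+ GPUs
    # (the same toggle the hunk above adds at import time in llava.py).
    torch.backends.cuda.matmul.allow_tf32 = True
    # Optional companion toggle for cuDNN convolutions; not in this commit.
    torch.backends.cudnn.allow_tf32 = True
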
@@ -47,7 +50,7 @@ def __init__(
         batch_size: Optional[Union[int, str]] = 1,
         trust_remote_code: Optional[bool] = False,
         revision=None,
-        use_flash_attention_2=False,
+        use_flash_attention_2=True,
         conv_template="vicuna_v1",
         use_cache=True,
         truncate_context=False, # whether to truncate the context in generation, set it False for LLaVA-1.6
@@ -67,7 +70,7 @@ def __init__(
             self._model,
             self._image_processor,
             self._max_length,
-        ) = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self._device)
+        ) = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self._device, use_flash_attention_2=use_flash_attention_2, torch_dtype=dtype)
         self._config = self._model.config
         self.model.eval()
         self.model.tie_weights()
@@ -78,18 +81,14 @@ def __init__(
         self.truncate_context = truncate_context
         # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
         if accelerator.num_processes > 1:
-            assert accelerator.distributed_type in [
-                DistributedType.FSDP,
-                DistributedType.MULTI_GPU,
-                DistributedType.DEEPSPEED
-            ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
             # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
             # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
             # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
             if accelerator.distributed_type == DistributedType.DEEPSPEED:
                 kwargs = {
                     "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
-                    "train_batch_size" : self.batch_size_per_gpu * accelerator.num_processes,
+                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                 }
                 AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
                 eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
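
Taken together, the new constructor default and the extra arguments forwarded to load_pretrained_model mean flash attention 2 is requested automatically when the wrapper is built. A minimal usage sketch, assuming the class in lmms_eval/models/llava.py is named Llava; apart from the keyword arguments visible in the diff above, the names here are illustrative:

    from lmms_eval.models.llava import Llava  # class name assumed

    # use_flash_attention_2 now defaults to True; pass False to fall back to the
    # standard attention implementation, e.g. on GPUs without flash-attn 2 support.
    lm = Llava(
        pretrained="liuhaotian/llava-v1.5-7b",  # illustrative checkpoint
        batch_size=1,
        conv_template="vicuna_v1",
        use_flash_attention_2=False,
    )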
