diff --git a/models/pllava/configuration_pllava.py b/models/pllava/configuration_pllava.py
index 6c429ce..53b58ab 100644
--- a/models/pllava/configuration_pllava.py
+++ b/models/pllava/configuration_pllava.py
@@ -16,6 +16,7 @@
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 from transformers.models.auto import CONFIG_MAPPING
+from utils.basic_utils import is_gpu_ampere_or_later
 
 logger = logging.get_logger(__name__)
 
@@ -141,6 +142,8 @@ def __init__(
         elif text_config is None:
             tmp_config = {"_attn_implementation":"flash_attention_2",
                           "gradient_checkpointing": self.gradient_checkpointing}
+            if not is_gpu_ampere_or_later():
+                del tmp_config['_attn_implementation']
             self.text_config = CONFIG_MAPPING["llama"](**tmp_config)
             self.text_config.gradient_checkpointing = self.gradient_checkpointing
             # self.text_config["_attn_implementation"]="flash_attention_2" # xl: temporal hard code
diff --git a/models/pllava/modeling_pllava.py b/models/pllava/modeling_pllava.py
index 04d64cf..da01f01 100644
--- a/models/pllava/modeling_pllava.py
+++ b/models/pllava/modeling_pllava.py
@@ -36,6 +36,8 @@
 from .configuration_pllava import PllavaConfig
 import pickle
 
+from utils.basic_utils import is_gpu_ampere_or_later
+
 logger = logging.get_logger(__name__)
 
 
@@ -175,7 +177,7 @@ class PllavaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
-    _supports_flash_attn_2 = True
+    _supports_flash_attn_2 = is_gpu_ampere_or_later()
 
     def _init_weights(self, module):
         # important: this ported version of Llava isn't meant for training from scratch - only
@@ -291,7 +293,10 @@ def __init__(self, config: PllavaConfig):
         self.vision_tower = AutoModel.from_config(config.vision_config)
         self.multi_modal_projector = PllavaMultiModalProjector(config)
         self.vocab_size = config.vocab_size
-        self.language_model = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype, attn_implementation="flash_attention_2")
+        if is_gpu_ampere_or_later():
+            self.language_model = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype, attn_implementation="flash_attention_2")
+        else:
+            self.language_model = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype)
         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.text_config.pad_token_id
         assert self.pad_token_id is not None, 'provide the model with pad_token_id, this would be used to arranging new embedings'
         self.post_init()
diff --git a/tasks/eval/demo/pllava_demo.py b/tasks/eval/demo/pllava_demo.py
index 935734b..f3f78b8 100644
--- a/tasks/eval/demo/pllava_demo.py
+++ b/tasks/eval/demo/pllava_demo.py
@@ -13,10 +13,14 @@
 )
 from tasks.eval.demo import pllava_theme
 
-SYSTEM="""You are Pllava, a large vision-language assistant.
-You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language.
-Follow the instructions carefully and explain your answers in detail based on the provided video.
+SYSTEM="""You are a powerful Video Magic ChatBot, a large vision-language assistant.
+You are able to understand the video content that the user provides and assist the user in a video-language related task.
+The user might provide you with the video and maybe some extra noisy information to help you out or ask you a question. Make use of the information in a proper way to be competent for the job.
+### INSTRUCTIONS:
+1. Follow the user's instruction.
+2. Be critical yet believe in yourself.
 """
+
 INIT_CONVERSATION: Conversation = conv_plain_v1.copy()
 
 
diff --git a/tasks/eval/model_utils.py b/tasks/eval/model_utils.py
index f1a700f..4651801 100644
--- a/tasks/eval/model_utils.py
+++ b/tasks/eval/model_utils.py
@@ -10,6 +10,8 @@
 from accelerate.utils import get_balanced_memory
 from transformers import StoppingCriteria
 
+from utils.basic_utils import is_gpu_ampere_or_later
+
 class KeywordsStoppingCriteria(StoppingCriteria):
     def __init__(self, keywords, tokenizer, input_ids):
         self.keywords = keywords
@@ -45,6 +47,7 @@ def load_pllava(repo_id, num_frames, use_lora=False, weight_dir=None, lora_alpha
         kwargs.update(pooling_shape=(0,12,12)) # produce a bug if ever usen the pooling projector
     config = PllavaConfig.from_pretrained(
         repo_id if not use_lora else weight_dir,
+        use_flash_attention_2=is_gpu_ampere_or_later(),
        pooling_shape=pooling_shape,
         **kwargs,
     )
diff --git a/utils/basic_utils.py b/utils/basic_utils.py
index fb453d3..ba6bb66 100644
--- a/utils/basic_utils.py
+++ b/utils/basic_utils.py
@@ -12,12 +12,17 @@
 import torch
 import torch.distributed as dist
+from torch.cuda import get_device_properties
 
 from .distributed import is_dist_avail_and_initialized
 
 logger = logging.getLogger(__name__)
 
 
+def is_gpu_ampere_or_later():
+    return get_device_properties(0).major >= 8
+
+
 class SmoothedValue(object):
     """Track a series of values and provide access to smoothed values over a
     window or the global series average.
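Note on the new helper: `torch.cuda.get_device_properties(0)` raises a `RuntimeError` on hosts without a usable CUDA device, and since `_supports_flash_attn_2 = is_gpu_ampere_or_later()` runs at class-definition time, importing `modeling_pllava` on a CPU-only machine would fail before any fallback logic gets a chance to run. Below is a minimal defensive sketch of the same check, not part of the patch; the `device_index` parameter is an illustrative addition.

```python
import torch

def is_gpu_ampere_or_later(device_index=None):
    """True iff the selected CUDA device has compute capability >= 8.0
    (Ampere or newer), the floor required by FlashAttention-2."""
    if not torch.cuda.is_available():
        # No CUDA at all: report False so callers fall back to the
        # default attention implementation instead of crashing.
        return False
    if device_index is None:
        # Respect the current device rather than hard-coding GPU 0.
        device_index = torch.cuda.current_device()
    return torch.cuda.get_device_properties(device_index).major >= 8
```

The `major >= 8` test matches FlashAttention-2's hardware requirement (Ampere/SM80 and newer), so pre-Ampere cards such as Turing (SM75) correctly take the non-flash-attention path.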