Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RuntimeError: FlashAttention only supports Ampere GPUs or newer #62

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions models/pllava/configuration_pllava.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.auto import CONFIG_MAPPING
from utils.basic_utils import is_gpu_ampere_or_later


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -141,6 +142,8 @@ def __init__(
elif text_config is None:
tmp_config = {"_attn_implementation":"flash_attention_2",
"gradient_checkpointing": self.gradient_checkpointing}
if not is_gpu_ampere_or_later():
del tmp_config['_attn_implementation']
self.text_config = CONFIG_MAPPING["llama"](**tmp_config)
self.text_config.gradient_checkpointing = self.gradient_checkpointing
# self.text_config["_attn_implementation"]="flash_attention_2" # xl: temporal hard code
Expand Down
9 changes: 7 additions & 2 deletions models/pllava/modeling_pllava.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

from .configuration_pllava import PllavaConfig
import pickle
from utils.basic_utils import is_gpu_ampere_or_later


logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -175,7 +177,7 @@ class PllavaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_flash_attn_2 = is_gpu_ampere_or_later()

def _init_weights(self, module):
# important: this ported version of Llava isn't meant for training from scratch - only
Expand Down Expand Up @@ -291,7 +293,10 @@ def __init__(self, config: PllavaConfig):
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = PllavaMultiModalProjector(config)
self.vocab_size = config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype, attn_implementation="flash_attention_2")
if is_gpu_ampere_or_later():
self.language_model = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype, attn_implementation="flash_attention_2")
else:
self.language_model = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.text_config.pad_token_id
assert self.pad_token_id is not None, 'provide the model with pad_token_id, this would be used to arranging new embedings'
self.post_init()
Expand Down
10 changes: 7 additions & 3 deletions tasks/eval/demo/pllava_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
)
from tasks.eval.demo import pllava_theme

SYSTEM="""You are Pllava, a large vision-language assistant.
You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language.
Follow the instructions carefully and explain your answers in detail based on the provided video.
SYSTEM="""You are a powerful Video Magic ChatBot, a large vision-language assistant.
You are able to understand the video content that the user provides and assist the user in a video-language related task.
The user might provide you with the video and maybe some extra noisy information to help you out or ask you a question. Make use of the information in a proper way to be competent for the job.
### INSTRUCTIONS:
1. Follow the user's instruction.
2. Be critical yet believe in yourself.
"""

INIT_CONVERSATION: Conversation = conv_plain_v1.copy()


Expand Down
3 changes: 3 additions & 0 deletions tasks/eval/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from accelerate.utils import get_balanced_memory

from transformers import StoppingCriteria
from utils.basic_utils import is_gpu_ampere_or_later

class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
Expand Down Expand Up @@ -45,6 +47,7 @@ def load_pllava(repo_id, num_frames, use_lora=False, weight_dir=None, lora_alpha
kwargs.update(pooling_shape=(0,12,12)) # produce a bug if ever usen the pooling projector
config = PllavaConfig.from_pretrained(
repo_id if not use_lora else weight_dir,
use_flash_attention_2=is_gpu_ampere_or_later(),
pooling_shape=pooling_shape,
**kwargs,
)
Expand Down
5 changes: 5 additions & 0 deletions utils/basic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,17 @@

import torch
import torch.distributed as dist
from torch.cuda import get_device_properties
from .distributed import is_dist_avail_and_initialized


logger = logging.getLogger(__name__)


def is_gpu_ampere_or_later():
    """Return True if CUDA device 0 is Ampere (SM 8.x) or newer.

    FlashAttention-2 only supports compute capability >= 8.0, so callers use
    this predicate to decide whether to enable the ``flash_attention_2``
    attention implementation.

    Returns:
        bool: True when a CUDA device is present and its major compute
        capability is >= 8; False otherwise (including CPU-only hosts).
    """
    # Guard first: get_device_properties(0) raises on machines without a
    # CUDA device, which would crash config/model loading on CPU-only hosts.
    if not torch.cuda.is_available():
        return False
    return get_device_properties(0).major >= 8


class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
Expand Down