Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add flash attention #235

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft

add flash attention #235

wants to merge 7 commits into from

Conversation

ehartford
Copy link

I try to add flash attention in the same way that fastchat and axolotl do.

However, I get this error message.

  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 340, in forward
    raise ValueError(
ValueError: Attention mask should be of size (1, 1, 513, 513), but is torch.Size([1, 513])

I wonder if you have any ideas?

@tmm1
Copy link

tmm1 commented Jul 30, 2023

transformers/models/llama/modeling_llama.py", line 340, in forward

this makes it seem like its still using the original forward method instead of the one that is patched in? wouldn't it say patch_flash_attn.py, in forward instead then?

@tmm1
Copy link

tmm1 commented Jul 30, 2023

@ehartford
Copy link
Author

here is full stack trace:

Traceback (most recent call last):
  File "/home/eric/git/qlora/qlora.py", line 845, in <module>
    train()
  File "/home/eric/git/qlora/qlora.py", line 807, in train
    train_result = trainer.train()
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 1809, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 2654, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 2679, in compute_loss
    outputs = model(**inputs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/utils/operations.py", line 581, in forward
    return model_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/utils/operations.py", line 569, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/peft/peft_model.py", line 922, in forward
    return self.base_model(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 806, in forward
    outputs = self.model(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 685, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
Traceback (most recent call last):
  File "/home/eric/git/qlora/qlora.py", line 845, in <module>
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 107, in forward
        train()
outputs = run_function(*args)  File "/home/eric/git/qlora/qlora.py", line 807, in train

  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 681, in custom_forward
    train_result = trainer.train()
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 1539, in train
    return module(*inputs, output_attentions, None)    return inner_training_loop(

  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 1809, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 2654, in training_step
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    loss = self.compute_loss(model, inputs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 2679, in compute_loss
    output = old_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
    outputs = model(**inputs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    hidden_states, self_attn_weights, present_key_value = self.self_attn(    return forward_call(*args, **kwargs)

  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
      File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/utils/operations.py", line 581, in forward
        return model_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/utils/operations.py", line 569, in __call__
output = old_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 340, in forward
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/peft/peft_model.py", line 922, in forward
    raise ValueError(    return self.base_model(

  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
ValueError: Attention mask should be of size (1, 1, 513, 513), but is torch.Size([1, 513])
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 806, in forward
    outputs = self.model(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 685, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 107, in forward
    outputs = run_function(*args)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 681, in custom_forward
    return module(*inputs, output_attentions, None)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 340, in forward
    raise ValueError(
ValueError: Attention mask should be of size (1, 1, 513, 513), but is torch.Size([1, 513])

@ehartford
Copy link
Author

could be possible that qlora patches it again after I patch it?

@tmm1
Copy link

tmm1 commented Jul 30, 2023

/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)

this seems suspicious w/ old_forward and new_forward. maybe its grabbing the original method before we patch it with flash-attn

@tmm1
Copy link

tmm1 commented Jul 30, 2023

ah i think you need to call replace... method to monkey patch before the model is instantiated, i.e. before AutoModelForCausalLM.from_pretrained

@tmm1
Copy link

tmm1 commented Jul 30, 2023

try this? 1b56419

@ehartford
Copy link
Author

ok trying

@ehartford
Copy link
Author

RuntimeError: FlashAttention only support fp16 and bf16 data type

@ehartford
Copy link
Author

seems maybe FlashAttention needs to be modified to support this

@MrigankRaman
Copy link

Does this still work? I still get the same error of "RuntimeError: FlashAttention only support fp16 and bf16 data type"

@Ltrack
Copy link

Ltrack commented Sep 8, 2023

try this: #221 (comment)

replace_attn_with_flash_attn() >load model> model = upcast_layer_for_flash_attention(model, torch.bfloat16)

should fix this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants