
Flash Attention vs Triton Flash Attention #180

Closed
germanjke opened this issue May 20, 2023 · 7 comments
@germanjke

germanjke commented May 20, 2023

Hi, I want to ask about the attention choice in your MPT model.

Yes, the Triton version supports ALiBi and a fast forward pass, but it has some disadvantages:

  • slow backward;
  • slower forward + backward combined;
  • no dropout;
  • no support for different sequence lengths within a batch;
  • the implementation is really experimental: it works with only some values of num_heads, plus other bugs (I know we could use some custom implementation, but I'm afraid it has some bugs as well)

Do you think the ALiBi choice is that important in this case?

It looks like a trade-off.

@vchiley
Contributor

vchiley commented May 20, 2023

Without delving into the implementation details, using ALiBi as the network's position embedding is beneficial regardless of how exactly it is supported.

Relative position embedding

The default is learned positional embeddings. The issue with learned positional embeddings is that the inference max seq len is limited to the training max seq len.
ALiBi, being a relative position embedding, allows the user to extend the seq len at inference time.
[Figure: sequence-length extrapolation results from the ALiBi paper]
Source: ALiBi paper
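
To make the relative-position point concrete, here is a minimal sketch of how an ALiBi bias can be built for an arbitrary sequence length (illustrative PyTorch, not the MPT implementation; the slope formula is the power-of-two-heads case from the ALiBi paper):

```python
import torch

def alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    # Head-specific slopes: geometric sequence 2^(-8/n), 2^(-16/n), ... (ALiBi paper, n a power of 2).
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / n_heads) for h in range(n_heads)])
    # The bias depends only on the relative distance j - i between key j and query i,
    # so it can be materialized for any seq_len, including one longer than training.
    pos = torch.arange(seq_len)
    distance = pos[None, :] - pos[:, None]           # (seq_len, seq_len), <= 0 on the causal part
    return slopes[:, None, None] * distance[None]    # (n_heads, seq_len, seq_len), added to attention logits

# A model trained at seq_len=1024 can still get a bias for a longer inference context:
bias = alibi_bias(n_heads=8, seq_len=2048)
```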

Convergence

In my experience, ALiBi also converges faster than the other position embedding schemes we tried.

Implementation details, i.e. Triton

This is the flash attn implementation we use if attn_config: attn_impl: triton (note: the custom implementation you cite is the basis of the version we use); a minimal config sketch follows the list below.
Although those comments are copied from here, within a network I've found that the triton version has:

  • "slower faster forward + backward"
  • it doesn't matter if it doesn't have dropout because we don't use it
  • no different batch seqs is fine since we concatenate sequences
  • the cited limitations on n_heads is a limitation on the CUDA version of FlashAttn. The triton implementation we use seems not to have any limitation (besides head_dim <= 128; they do note that it hasn't been thoroughly tested at all dim sizes)
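
For reference, here is a minimal sketch of how these options might be selected in an MPT-style model config, written as a Python dict; the keys mirror the attn_config / attn_impl names above, but treat the exact schema and surrounding fields as assumptions rather than the canonical llm-foundry YAML:

```python
# Illustrative model config fragment (assumed schema, not the official one).
model_cfg = {
    "d_model": 2048,
    "n_heads": 16,
    "attn_config": {
        "attn_impl": "triton",  # the triton flash-attn kernel discussed above
        "alibi": True,          # relative position bias instead of learned positions
        "attn_pdrop": 0.0,      # the triton kernel has no dropout, so keep this at 0
    },
}
```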

Given that triton has faster forward + backward, I'd advocate for using the triton version. The fact that it supports ALiBi is a bonus (a very welcome bonus); I wouldn't necessarily call it a tradeoff.
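
To illustrate the sequence-concatenation point from the list above, here is a toy packing sketch (illustrative only, not the actual MosaicML dataloader): variable-length documents are concatenated into one token stream and then cut into fixed-length rows, so the attention kernel never sees padded or variable-length batches.

```python
import torch

def pack_sequences(token_lists, max_seq_len):
    # Concatenate all documents into one stream, then reshape into fixed-length rows,
    # dropping the ragged tail for simplicity.
    flat = torch.cat([torch.tensor(t, dtype=torch.long) for t in token_lists])
    n_rows = flat.numel() // max_seq_len
    return flat[: n_rows * max_seq_len].view(n_rows, max_seq_len)

# Three documents of different lengths become equally sized training rows.
batch = pack_sequences([[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11, 12, 13]], max_seq_len=4)
print(batch.shape)  # torch.Size([3, 4])
```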

@vchiley vchiley self-assigned this May 20, 2023
@germanjke
Author

Great stuff, if it's faster that's cool, but here they talk about slower computation, am I wrong?

@germanjke
Author

By the way, can you tell me about setup.py please: why do you use torch 1.13.1 there if we want to use torch 2?

@germanjke
Author

germanjke commented May 21, 2023

Why did you remove 'triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python' from setup.py? It was in your torch 2 version, but now master is without it (to be fair, PyPI really doesn't have a package for it).
I mean this pull.

But in the main branch I guess it's OK and it's triton-pre-mlir.
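
For context, that dependency is expressed as a direct git reference because there is no PyPI release of triton-pre-mlir. Here is a minimal, illustrative setup.py sketch (not the repo's actual file) showing the same kind of pin:

```python
# Illustrative only: declaring a dependency built from a subdirectory of a
# pinned git ref when no PyPI package exists.
from setuptools import setup

setup(
    name="example-project",
    install_requires=[
        "torch==1.13.1",
        # Direct reference: install from the python/ subdirectory of the pinned branch.
        "triton-pre-mlir@git+https://github.com/vchiley/triton.git"
        "@triton_pre_mlir#subdirectory=python",
    ],
)
```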

@germanjke
Author

Update: everything is fine with the current main version.

Here are the versions (just for information, for when the branch gets updated):

torch==1.13.1+cu117
triton-pre-mlir @ git+https://github.com/vchiley/triton.git@2dd3b957698a39bbca615c02a447a98482c144a3#subdirectory=python
flash-attn==v1.0.3.post0

Installed everything from setup.py (gpu setup) with this Docker image: mosaicml/pytorch:latest

@vchiley
Contributor

vchiley commented May 22, 2023

#181 undid #178, which we'll try to redo soon.

(The PR was fine, but Hugging Face didn't like it, so we're working on a workaround for the HF stuff.)

@vchiley
Contributor

vchiley commented May 26, 2023

torch2 reintegrated

@vchiley vchiley closed this as completed May 26, 2023