Adding deprecation warning for Flash Attention 1 and user warning against using Triton attention. (#921)
ShashankMosaicML authored Jan 31, 2024
1 parent 7dc51e9 commit 169b653
Showing 5 changed files with 35 additions and 16 deletions.
16 changes: 10 additions & 6 deletions README.md
@@ -114,10 +114,10 @@ You can select a specific commit hash such as `mosaicml/llm-foundry:1.13.1_cu117
| Docker Image | Torch Version | Cuda Version | LLM Foundry dependencies installed? |
| ------------------------------------------------------ | ------------- | ----------------- | ----------------------------------- |
| `mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04` | 2.1.0 | 12.1 (Infiniband) | No |
| `mosaicml/llm-foundry:2.1.0_cu121-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v1) |
| `mosaicml/llm-foundry:2.1.0_cu121_flash2-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v2) |
| `mosaicml/llm-foundry:2.1.0_cu121_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v1) |
| `mosaicml/llm-foundry:2.1.0_cu121_flash2_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v2) |
| `mosaicml/llm-foundry:2.1.0_cu121-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v1. Warning: Support for flash attention v1 has been deprecated.) |
| `mosaicml/llm-foundry:2.1.0_cu121_flash2-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v2. Note: We recommend using flash attention v2.) |
| `mosaicml/llm-foundry:2.1.0_cu121_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v1. Warning: Support for flash attention v1 has been deprecated.) |
| `mosaicml/llm-foundry:2.1.0_cu121_flash2_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v2. Note: We recommend using flash attention v2.) |


# Installation
@@ -134,7 +134,9 @@ We *strongly* recommend working with LLM Foundry inside a Docker container (see
```bash
git clone https://github.com/mosaicml/llm-foundry.git
cd llm-foundry
pip install -e ".[gpu]" # or pip install -e . if no NVIDIA GPU
pip install -e ".[gpu-flash2]" # or `pip install -e .` if no NVIDIA GPU.
# Note: Currently, `pip install -e ".[gpu-flash2]"` installs Flash Attention v2, and `pip install -e ".[gpu]"` installs Flash Attention v1.
# However, once the support for Flash Attention v1 is removed, both of these commands will install Flash Attention v2.
```
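To confirm which Flash Attention major version ended up in your environment, a minimal check (this assumes the installed `flash-attn` package exposes a `__version__` attribute):

```python
# Print the installed Flash Attention version:
# the `gpu-flash2` extra should yield a 2.x release, the legacy `gpu` extra a 1.x release.
import flash_attn

print(flash_attn.__version__)
```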

### Without Docker (not recommended)
@@ -152,7 +154,9 @@ source llmfoundry-venv/bin/activate

pip install cmake packaging torch # setup.py requires these be installed

pip install -e ".[gpu]" # or pip install -e . if no NVIDIA GPU
pip install -e ".[gpu-flash2]" # or `pip install -e .` if no NVIDIA GPU.
# Note: Currently, `pip install -e ".[gpu-flash2]"` installs Flash Attention v2, and `pip install -e ".[gpu]"` installs Flash Attention v1.
# However, once the support for Flash Attention v1 is removed, both of these commands will install Flash Attention v2.
```

### TransformerEngine and amp_fp8 support
16 changes: 8 additions & 8 deletions TUTORIAL.md
@@ -144,8 +144,8 @@ name = 'mosaicml/mpt-7b'
# Download config
config = AutoConfig.from_pretrained(name, trust_remote_code=True)
# (Optional) Use `triton` backend for fast attention. Defaults to `torch`.
# config.attn_config['attn_impl'] = 'triton'
# (Optional) Use `flash` (preferred) or `triton` backend for fast attention. Defaults to `torch`.
# config.attn_config['attn_impl'] = 'flash'
# (Optional) Change the `max_seq_len` allowed for inference
# config.max_seq_len = 4096

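Putting the snippet above together, a minimal sketch of loading MPT with the `flash` backend (assuming the usual `transformers` `AutoModelForCausalLM` path with `trust_remote_code=True`):

```python
from transformers import AutoConfig, AutoModelForCausalLM

name = 'mosaicml/mpt-7b'

# Download the config and switch the attention backend to flash.
config = AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'flash'

# Load the model with the modified config.
model = AutoModelForCausalLM.from_pretrained(name, config=config, trust_remote_code=True)
```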
@@ -291,7 +291,7 @@ The purpose of this section is probably pretty self-evident. You’ve got questi
- If OOMs persist with `device_train_microbatch_size: 1` and `device_eval_batch_size: 1`, you may need to use activation checkpointing `fsdp_config.activation_checkpointing: true` (if you are not already) and, as a last resort, activation CPU offloading `fsdp_config.activation_cpu_offload: true`.

### What hardware can I train on?
- In general, this repo should work on any system with NVIDIA GPUs. Check out the `scripts/train/README.md` for more [details on GPU memory requirements]([https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#how-many-gpus-do-i-need-to-train-a-llm](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#how-many-gpus-do-i-need-to-train-a-llm)). Keep in mind you may run into issues with `Triton` support on some GPU types. In that situation, you can fall back to `attn_impl: torch` or raise an issue in the [Triton github repo](https://github.com/openai/triton).
- In general, this repo should work on any system with NVIDIA GPUs. Check out the `scripts/train/README.md` for more [details on GPU memory requirements]([https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#how-many-gpus-do-i-need-to-train-a-llm](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#how-many-gpus-do-i-need-to-train-a-llm)). We recommend using `Flash` attention instead of `Triton` attention, unless you're training Prefix Language Models (in which case use `Triton`). Keep in mind you may run into issues with `Flash` or `Triton` support on some GPU types. In that situation, you can fall back to `attn_impl: torch`, or raise an issue in the [Flash Attention github repo](https://github.com/Dao-AILab/flash-attention).

### What hardware can I run eval on?
- Similar to above…
@@ -305,15 +305,15 @@ The purpose of this section is probably pretty self-evident. You’ve got questi
### What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?
- **Short answer:** `torch` is the native pytorch attention implementation, and `flash` and `triton` are different implementations of the much more optimized [Flash Attention](https://arxiv.org/abs/2205.14135) method. `triton` and `flash` will be faster (and use less GPU memory) than `torch`, but they might not work with all hardware and environment setups.

Our training setups typically use `triton`.
Our training setups typically use `flash`.

- **Long answer:** In NLP, Softmax Attention operates on a sequence. It is an all to all graph operation where, during training, the memory complexity is quadratic with respect to the length of the sequence. Furthermore, on GPUs, naive implementations of Softmax Attention are bandwidth (BW) limited.
[Rabe et al. (2021)](https://arxiv.org/abs/2112.05682) and [Dao et al. (2022)](https://arxiv.org/abs/2205.14135) showed that fusing all operations in Softmax Attention can make the operation much less BW limited.
Furthermore, integrating a recomputation schema decreases the sequence length memory complexity from *quadratic* to *linear*, thereby supporting much longer sequence lengths.

- Setting `attn_config.attn_impl=torch` enables a naive Softmax Attention written using base torch operations.
- Setting `attn_config.attn_impl=flash` enables Flash Attention [implemented by Dao et al in the HazyResearch repo using CUDA](https://github.com/HazyResearch/flash-attention). This will have linear memory complexity (enabling larger batch sizes) and will run much faster.
- Setting `attn_config.attn_impl=triton` enables a Flash Attention [implemented using Triton](https://github.com/mosaicml/llm-foundry/blob/main/llmfoundry/models/layers/flash_attn_triton.py). In our experience, `triton` is slightly faster than `flash`.
- Setting `attn_config.attn_impl=flash` enables Flash Attention [implemented by Dao et al in the Dao-AILab repo using CUDA](https://github.com/Dao-AILab/flash-attention). This will have linear memory complexity (enabling larger batch sizes) and will run much faster.
- Setting `attn_config.attn_impl=triton` enables a Flash Attention [implemented using Triton](https://github.com/mosaicml/llm-foundry/blob/main/llmfoundry/models/layers/flash_attn_triton.py). We recommend using `flash` attention instead of `triton` attention, unless you're training Prefix Language Models (in which case use `Triton`).

<!-- In NLP, Softmax Attention operates on a sequence. It is an all to all graph operation where, during training, the memory complexity is quadratic with respect to the length of the sequence. Furthermore, on GPUs, naive implementations of Softmax Attention are BW limited.
[Rabe et al. (2021)](https://arxiv.org/abs/2112.05682) and [Dao et al. (2022)](https://arxiv.org/abs/2205.14135) noted that fusing all operations in Softmax Attention can make the operation much less BW limited.
@@ -327,7 +327,7 @@ The majority of our training setups use `triton`. -->
#### Limitations
- For training, `torch` uses a lot of memory and is slow.
- `flash` and `triton` cannot return attention weights and therefore cannot be used with methods that require it.
- `flash` cannot accept an attention bias and therefore cannot be used with methods that require it such as ALiBi.
- `flash` cannot accept an attention bias. However, it still allows the use of ALiBi positional bias.

#### What is `triton-pre-mlir`?
- Torch2 installs and requires a specific version of [Triton](https://openai.com/research/triton).
@@ -352,7 +352,7 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.
| Name | YAML Config | Training MFU on MPT-7B trained on 8 A100 80GB GPUs | Notes |
|:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Learned Positional Embeddings | <pre>model:<br> learned_pos_emb:&nbsp;True</pre>| 65.7 | |
| ALiBi | <pre>model:<br> attn_config:<br> alibi:&nbsp;True</pre>| 64.5 | Requires Triton or Torch attention. |
| ALiBi | <pre>model:<br> attn_config:<br> alibi:&nbsp;True</pre>| 64.5 | Requires Flash (v2.4.2 or higher) or Triton or Torch attention. |
| RoPE (Dao-AILab Implementation) | <pre>model:<br> attn_config:<br> rope:&nbsp;True<br> rope_impl:&nbsp;dail</pre>| 64.5 | Requires a CUDA GPU and the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.0.1 or higher to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn v2. Note that the attention implementation can still be `torch`, `triton`, or `flash`. |
| RoPE (Hugging<code>&nbsp;</code>Face Implementation) | <pre>model:<br> attn_config:<br> rope:&nbsp;True<br> rope_impl:&nbsp;hf</pre>| 62.3 | |

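The ALiBi and RoPE rows above map directly onto the MPT config object; a sketch (assuming `MPTConfig` fills any unspecified `attn_config` keys with its defaults):

```python
from llmfoundry.models.mpt import MPTConfig

# ALiBi: requires Flash (v2.4.2 or higher), Triton, or Torch attention.
alibi_config = MPTConfig(attn_config={'attn_impl': 'torch', 'alibi': True})

# RoPE with the Dao-AILab ("dail") implementation: requires flash-attn v2.0.1 or higher.
rope_config = MPTConfig(attn_config={'attn_impl': 'flash', 'rope': True, 'rope_impl': 'dail'})
```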
15 changes: 15 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -9,6 +9,7 @@
from transformers import PretrainedConfig

from llmfoundry.models.layers.attention import (check_alibi_support,
                                                 is_flash_v1_installed,
                                                 is_flash_v2_installed)
from llmfoundry.models.layers.blocks import attn_config_defaults

@@ -222,6 +223,20 @@ def _validate_config(self) -> None:
                'attn_impl'] not in ['torch', 'triton']:
            raise NotImplementedError(
                'prefix_lm only implemented with torch and triton attention.')

        if self.attn_config['attn_impl'] == 'flash' and is_flash_v1_installed():
            warnings.warn(
                DeprecationWarning(
                    'Support for Flash Attention v1 is deprecated. Please upgrade to Flash Attention v2.4.2. To install Flash Attention v2.4.2, please run `pip install -e ".[gpu-flash2]"` from the root directory of the llm-foundry repository.'
                ))

        if self.attn_config[
                'attn_impl'] == 'triton' and not self.attn_config['prefix_lm']:
            warnings.warn(
                UserWarning(
                    'If not using a Prefix Language Model, we recommend setting "attn_impl" to "flash" instead of "triton".'
                ))

        if self.attn_config['alibi'] and not check_alibi_support(
                self.attn_config['attn_impl']):
            raise NotImplementedError(
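A sketch of how the new warnings surface at configuration time (assuming `MPTConfig` is exported from `llmfoundry.models.mpt` and runs `_validate_config` on construction):

```python
import warnings

from llmfoundry.models.mpt import MPTConfig

# Triton attention without prefix_lm should now emit the UserWarning added above.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    MPTConfig(attn_config={'attn_impl': 'triton'})

print([str(w.message) for w in caught])
```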
2 changes: 1 addition & 1 deletion scripts/inference/README.md
@@ -105,7 +105,7 @@ The script will use HuggingFace's `device_map=auto` feature to automatically loa
You can also directly specify `--device_map auto` or `--device_map balanced`, etc.
You can also target a specific **single** device using `--device cuda:0` or `--device cpu`, etc.
For MPT models specifically, you can pass args like `--attn_impl triton`, and `--max_seq_len 4096` to speed up generation time or alter the max generation length at inference time (thanks to ALiBi).
For MPT models specifically, you can pass args like `--attn_impl flash`, and `--max_seq_len 4096` to speed up generation time or alter the max generation length at inference time (thanks to ALiBi).
## Interactive Chat with HF models
2 changes: 1 addition & 1 deletion scripts/train/README.md
@@ -335,7 +335,7 @@ train_loader:
# Using Flash Attention <a name="flashattention"></a>

Flash Attention is an optimized implementation of the attention mechanism, first introduced by [Dao et al.](https://github.com/Dao-AILab/flash-attention). There are three versions of Flash Attention that can be used with LLM Foundry: Flash Attention V1, Flash Attention V2, and a Triton implementation of Flash Attention. To start, we recommend using one of our [provided Docker images](../../README.md#mosaicml-docker-images) corresponding to the Flash Attention version you would like to use. The Triton implementation can be used with either Flash Attention V1 or V2. Next, how you specify to use Flash Attention depends on which model you are using.
Flash Attention is an optimized implementation of the attention mechanism, first introduced by [Dao et al.](https://github.com/Dao-AILab/flash-attention). There are three versions of Flash Attention that can be used with LLM Foundry: Flash Attention V1, Flash Attention V2, and a Triton implementation of Flash Attention. The support for Flash Attention V1 has been deprecated, and we recommend using Flash Attention V2. We also recommend using `Flash` attention instead of `Triton` attention, unless you're training Prefix Language Models (in which case we recommend using `Triton`). To start, we recommend using one of our [provided Docker images](../../README.md#mosaicml-docker-images) corresponding to the Flash Attention version you would like to use. The Triton implementation can be used with either Flash Attention V1 or V2. Next, how you specify to use Flash Attention depends on which model you are using.

For MPT, you can specify Flash Attention in your YAML like so:
```yaml
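# Sketch of the truncated example (exact surrounding keys may differ);
# the relevant setting is `attn_impl` under the model's `attn_config`:
model:
  name: mpt_causal_lm
  attn_config:
    attn_impl: flash  # or `triton` for the Triton implementation of Flash Attention
```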
