Fix pinned triton version (NVIDIA#7925)
* Fix pinned triton version

Signed-off-by: Cheng-Ping Hsieh <chsieh@nvidia.com>

* Remove comment

Signed-off-by: Cheng-Ping Hsieh <chsieh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Change README

Signed-off-by: Cheng-Ping Hsieh <chsieh@nvidia.com>

* Remove flash-attn in Dockerfile

Signed-off-by: Cheng-Ping Hsieh <chsieh@nvidia.com>

* Revert

Signed-off-by: Cheng-Ping Hsieh <chsieh@nvidia.com>

---------

Signed-off-by: Cheng-Ping Hsieh <chsieh@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
hsiehjackson and pre-commit-ci[bot] authored Nov 22, 2023
1 parent 23ef428 commit 1a5ee38
Showing 5 changed files with 22 additions and 8 deletions.
Dockerfile (4 changes: 1 addition & 3 deletions)
@@ -85,10 +85,8 @@ WORKDIR /tmp/nemo
 COPY requirements .
 RUN for f in $(ls requirements*.txt); do pip3 install --disable-pip-version-check --no-cache-dir -r $f; done
 
-# install flash attention dependencies
+# install flash attention
 RUN pip install flash-attn
-# pinned triton version for flash-attention https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3
-RUN pip install triton==2.0.0.dev20221202
 # install numba for latest containers
 RUN pip install numba>=0.57.1
 
README.rst (2 changes: 1 addition & 1 deletion)
@@ -319,7 +319,7 @@ Transformer Engine requires PyTorch to be built with CUDA 11.8.

 Flash Attention
 ~~~~~~~~~~~~~~~~~~~~
-Transformer Engine already supports Flash Attention for GPT models. If you want to use Flash Attention for non-causal models or use with attention bias (introduced from position encoding, e.g. Alibi), please install `flash-attn <https://github.com/HazyResearch/flash-attention>`_.
+Transformer Engine already supports Flash Attention for GPT models. If you want to use Flash Attention for non-causal models, please install `flash-attn <https://github.com/HazyResearch/flash-attention>`_. If you want to use Flash Attention with attention bias (introduced from position encoding, e.g. Alibi), please also install triton pinned version following the `implementation <https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3>`_.
 
 .. code-block:: bash
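
For readers following the new README guidance, a minimal environment check (not part of the commit; the pinned version string is copied from the asserts added elsewhere in this diff, and the exact reporting below is only a sketch):

    import pkg_resources

    # Pin used by flash-attn's Triton kernel (flash_attn/flash_attn_triton.py),
    # as hard-coded in the asserts this commit adds; adjust if the kernel changes.
    PINNED_TRITON = "2.0.0.dev20221202"

    try:
        installed = pkg_resources.get_distribution("triton").version
    except pkg_resources.DistributionNotFound:
        installed = None

    if installed != PINNED_TRITON:
        print(f"Flash Attention with attention bias expects triton=={PINNED_TRITON}, found {installed}")
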
nemo/collections/nlp/modules/common/megatron/attention.py (17 changes: 14 additions & 3 deletions)
@@ -65,11 +65,23 @@

 HAVE_MEGATRON_CORE = False
 
+try:
+    # Flash Attention Triton
+    import pkg_resources
+    from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton
+
+    # pinned triton version for flash-attention triton https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3
+    assert pkg_resources.get_distribution("triton").version == '2.0.0.dev20221202'
+
+except (ImportError, ModuleNotFoundError, AssertionError):
+
+    flash_attn_func_triton = None
+
+
 try:
     # Flash Attention 1.X
     from flash_attn.bert_padding import pad_input, unpad_input
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-    from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton
 
     HAVE_FLASH_ATTENTION = True
     flash_attn_func = None
@@ -85,8 +97,7 @@
 except (ImportError, ModuleNotFoundError):
 
     HAVE_FLASH_ATTENTION = False
-
-    flash_attn_unpadded_func, flash_attn_func_triton, flash_attn_func = None, None, None
+    flash_attn_unpadded_func, flash_attn_func = None, None
     unpad_input, pad_input = None, None
 
 try:
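
Taken out of the diff context, the availability guard this file now uses can be read as the following standalone sketch (names mirror the hunk above; the closing comment about dispatch is an assumption about downstream use, which this hunk does not show):

    try:
        # Flash Attention Triton path: usable only when flash-attn's Triton kernel
        # imports cleanly and the installed triton matches the version it was pinned to.
        import pkg_resources
        from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton

        assert pkg_resources.get_distribution("triton").version == '2.0.0.dev20221202'
    except (ImportError, ModuleNotFoundError, AssertionError):
        # Either flash-attn is missing or triton is absent/unpinned; disable the Triton path.
        flash_attn_func_triton = None

    # Downstream code can then branch on `flash_attn_func_triton is not None` to decide
    # whether the attention-bias (e.g. ALiBi) Flash Attention path is available.
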
requirements/requirements.txt (1 change: 1 addition & 0 deletions)
@@ -10,5 +10,6 @@ tensorboard
 text-unidecode
 torch
 tqdm>=4.41.0
+triton
 wget
 wrapt

tests/collections/nlp/test_flash_attention.py (6 changes: 5 additions & 1 deletion)
@@ -39,10 +39,14 @@
 HAVE_FA = False
 
 try:
+    import pkg_resources
     import triton
 
+    # pinned triton version for flash-attention triton https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3
+    assert pkg_resources.get_distribution("triton").version == '2.0.0.dev20221202'
+
     HAVE_TRITON = True
-except (ImportError, ModuleNotFoundError):
+except (ImportError, ModuleNotFoundError, AssertionError):
     HAVE_TRITON = False
 
 try:
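As a usage note, availability flags like HAVE_TRITON are typically consumed by skip markers in the test module; a hypothetical sketch (the real decorators in test_flash_attention.py are outside this hunk, so the test name and reason text below are illustrative only):

    import pytest

    HAVE_TRITON = False  # in the real module this is set by the guarded import shown above

    @pytest.mark.skipif(not HAVE_TRITON, reason="pinned triton for flash-attn Triton kernel not installed")
    def test_flash_attention_triton():
        # Exercise the attention-bias (Triton) path only when the pinned triton is present.
        ...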
