Fix pinned triton version #7925

Merged: 7 commits, Nov 22, 2023
4 changes: 1 addition & 3 deletions Dockerfile
@@ -85,10 +85,8 @@ WORKDIR /tmp/nemo
 COPY requirements .
 RUN for f in $(ls requirements*.txt); do pip3 install --disable-pip-version-check --no-cache-dir -r $f; done

-# install flash attention dependencies
+# install flash attention
 RUN pip install flash-attn
-# pinned triton version for flash-attention https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3
-RUN pip install triton==2.0.0.dev20221202
 # install numba for latest containers
 RUN pip install numba>=0.57.1
2 changes: 1 addition & 1 deletion README.rst
@@ -319,7 +319,7 @@ Transformer Engine requires PyTorch to be built with CUDA 11.8.

 Flash Attention
 ~~~~~~~~~~~~~~~~~~~~
-Transformer Engine already supports Flash Attention for GPT models. If you want to use Flash Attention for non-causal models or use with attention bias (introduced from position encoding, e.g. Alibi), please install `flash-attn <https://github.com/HazyResearch/flash-attention>`_.
+Transformer Engine already supports Flash Attention for GPT models. If you want to use Flash Attention for non-causal models, please install `flash-attn <https://github.com/HazyResearch/flash-attention>`_. If you want to use Flash Attention with attention bias (introduced by position encoding, e.g. ALiBi), please also install the pinned triton version noted in the `implementation <https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3>`_.

 .. code-block:: bash

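For readers unfamiliar with the attention bias mentioned in the README change: ALiBi adds a per-head, distance-proportional bias to the attention scores, and a bias tensor of that kind is what the linked triton implementation is used for. Below is a minimal sketch of building such a bias in PyTorch; the (1, heads, seq, seq) layout and the helper name alibi_bias are illustrative assumptions, not part of NeMo or flash-attn.

    import torch

    def alibi_bias(num_heads: int, seq_len: int) -> torch.Tensor:
        """Hypothetical helper: ALiBi-style additive attention bias.

        For a power-of-two head count, head h uses slope 2 ** (-8 * (h + 1) / num_heads);
        entry [i, j] is slope * (j - i), so keys further in the past get a larger penalty.
        Returns shape (1, num_heads, seq_len, seq_len) to broadcast over the batch.
        """
        slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)])
        pos = torch.arange(seq_len)
        rel = (pos[None, :] - pos[:, None]).float()        # [i, j] = j - i
        return (slopes[:, None, None] * rel).unsqueeze(0)  # (1, H, S, S)

With the pinned triton installed, a tensor like this would presumably be passed as the bias argument of the triton flash_attn_func; check the linked flash_attn_triton.py for the exact signature.
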
17 changes: 14 additions & 3 deletions nemo/collections/nlp/modules/common/megatron/attention.py
@@ -65,11 +65,23 @@

     HAVE_MEGATRON_CORE = False

+try:
+    # Flash Attention Triton
+    import pkg_resources
+    from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton
+
+    # pinned triton version for flash-attention triton https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3
+    assert pkg_resources.get_distribution("triton").version == '2.0.0.dev20221202'
+
+except (ImportError, ModuleNotFoundError, AssertionError):
+
+    flash_attn_func_triton = None
+
+
 try:
     # Flash Attention 1.X
     from flash_attn.bert_padding import pad_input, unpad_input
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-    from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton

     HAVE_FLASH_ATTENTION = True
     flash_attn_func = None
@@ -85,8 +97,7 @@
 except (ImportError, ModuleNotFoundError):

     HAVE_FLASH_ATTENTION = False
-
-    flash_attn_unpadded_func, flash_attn_func_triton, flash_attn_func = None, None, None
+    flash_attn_unpadded_func, flash_attn_func = None, None
     unpad_input, pad_input = None, None

 try:
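
The net effect of the attention.py change is that flash_attn_func_triton is non-None only when the triton kernel imports and the installed triton matches the pinned build; otherwise the symbol is cleanly set to None instead of breaking the flash-attn 1.X import block. A hedged sketch of how a call site could dispatch on that guard is below; run_flash_attention is a hypothetical name, the (q, k, v, bias) call signature is an assumption taken from the pinned flash_attn_triton.py, and this is not NeMo's actual dispatch code.

    def run_flash_attention(q, k, v, bias=None):
        """Hypothetical dispatcher relying on the module-level guard above."""
        if bias is not None:
            if flash_attn_func_triton is None:
                raise RuntimeError(
                    "Attention bias requires the triton flash-attention kernel; "
                    "install triton==2.0.0.dev20221202 alongside flash-attn."
                )
            # Assumed call signature, per the pinned flash_attn_triton.py.
            return flash_attn_func_triton(q, k, v, bias)
        # Without a bias, the regular flash-attn kernels imported above apply;
        # that branch is elided here.
        ...
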
1 change: 1 addition & 0 deletions requirements/requirements.txt
@@ -10,5 +10,6 @@ tensorboard
 text-unidecode
 torch
 tqdm>=4.41.0
+triton
 wget
 wrapt
6 changes: 5 additions & 1 deletion tests/collections/nlp/test_flash_attention.py
@@ -39,10 +39,14 @@
     HAVE_FA = False

 try:
+    import pkg_resources
     import triton

+    # pinned triton version for flash-attention triton https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3
+    assert pkg_resources.get_distribution("triton").version == '2.0.0.dev20221202'
+
     HAVE_TRITON = True
-except (ImportError, ModuleNotFoundError):
+except (ImportError, ModuleNotFoundError, AssertionError):
     HAVE_TRITON = False

 try:
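
Because the guard now also catches AssertionError, an environment with a different triton build simply reports HAVE_TRITON = False instead of failing at import time. A minimal, hypothetical example of how that flag would gate a test (not one of the actual tests in this file):

    import pytest

    @pytest.mark.skipif(not HAVE_TRITON, reason="triton==2.0.0.dev20221202 not installed")
    def test_triton_flash_attention_guard():
        # Runs only when both the triton import and the version assert succeeded.
        assert HAVE_TRITON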