Enhance the robustness of the flash attention check #20495

Merged
merged 7 commits on Nov 15, 2024
32 changes: 14 additions & 18 deletions keras/src/backend/jax/nn.py
@@ -973,29 +973,25 @@ def psnr(x1, x2, max_val):


def _can_use_flash_attention(query, key, value, bias, raise_error=False):
# Ref: https://github.com/jax-ml/jax/blob/main/jax/_src/cudnn/fused_attention_stablehlo.py
from jax._src.cudnn.fused_attention_stablehlo import _normalize_layout
from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
from jax._src.cudnn.fused_attention_stablehlo import check_layout

"""Verify the availability of flash attention."""
try:
# The older version of jax doesn't have `check_compute_capability`
from jax._src.cudnn.fused_attention_stablehlo import _normalize_layout
from jax._src.cudnn.fused_attention_stablehlo import (
check_compute_capability,
)
from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
from jax._src.cudnn.fused_attention_stablehlo import check_layout
from jax.nn import dot_product_attention as dot_product_attention
except ImportError:
if raise_error:
raise
raise ImportError(
"Flash attention is not supported in your current JAX version. "
"Please update it by following the official guide: "
"https://jax.readthedocs.io/en/latest/installation.html"
)
return False

try:
# `dot_product_attention` is only available in jax>=0.4.31
if not hasattr(jax.nn, "dot_product_attention"):
raise ValueError(
"Flash attention is not supported in your "
"current JAX version. Please update it "
"using `pip install -U jax jaxlib`."
)
# Check if cuDNN is installed and raise RuntimeError if cuDNN is not
# detected
check_cudnn_version()
@@ -1110,10 +1106,10 @@ def dot_product_attention(
)

if flash_attention:
raise ValueError(
"Flash attention is not supported in your "
"current JAX version. Please update it "
"using `pip install -U jax jaxlib`."
raise RuntimeError(
"Flash attention is not supported in your current JAX version. "
"Please update it by following the official guide: "
"https://jax.readthedocs.io/en/latest/installation.html"
)
# Ref: jax.nn.dot_product_attention
# https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
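For context, a minimal hedged sketch of the API this check guards. It assumes jax>=0.4.31 (where `jax.nn.dot_product_attention` was introduced); the shapes and dtypes are illustrative only, and `implementation=None` lets JAX pick a backend so the snippet also runs without cuDNN, whereas `implementation="cudnn"` would force the flash-attention kernel and fail on unsupported setups.

```python
import jax
import jax.numpy as jnp

# Illustrative tensors in the (batch, seq_len, num_heads, head_dim) layout
# expected by jax.nn.dot_product_attention.
query = jnp.ones((1, 8, 4, 16), dtype="float32")
key = jnp.ones((1, 8, 4, 16), dtype="float32")
value = jnp.ones((1, 8, 4, 16), dtype="float32")

# The attribute check mirrors what the backend does: the function only
# exists in jax>=0.4.31.
if hasattr(jax.nn, "dot_product_attention"):
    # implementation=None lets JAX choose the backend; "cudnn" would
    # request the flash-attention path and raise if it is unavailable.
    out = jax.nn.dot_product_attention(query, key, value, implementation=None)
    print(out.shape)  # (1, 8, 4, 16)
```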
53 changes: 30 additions & 23 deletions keras/src/backend/torch/nn.py
@@ -890,29 +890,47 @@ def _get_large_negative(dtype):


def _can_use_flash_attention(
query, key, value, mask=None, is_causal=False, debug=False
query, key, value, mask=None, is_causal=False, raise_error=False
):
"""Verify the availability of flash attention."""
try:
spda_params = torch.backends.cuda.SDPAParams(
from torch.backends.cuda import SDPAParams
from torch.backends.cuda import can_use_flash_attention
except ImportError:
if raise_error:
raise ImportError(
"Flash attention is not supported in your current PyTorch "
"version. Please update it by following the official guide: "
"https://pytorch.org/get-started/locally/"
)
return False

try:
spda_params = SDPAParams(
query,
key,
value,
mask,
0.0, # dropout_p
is_causal,
False, # enable_gqa
)
except TypeError:
# The signature changed in newer version of torch.
spda_params = torch.backends.cuda.SDPAParams(
# The old function signature for the older version of PyTorch
spda_params = SDPAParams(
query,
key,
value,
mask,
0.0, # dropout_p
is_causal,
False, # enable_gqa
)
return torch.backends.cuda.can_use_flash_attention(spda_params, debug)
if raise_error and can_use_flash_attention(spda_params, True) is False:
raise RuntimeError(
"Flash attention is not supported with the provided inputs. "
"Please check the warnings for more details."
)
return can_use_flash_attention(spda_params, False)


def dot_product_attention(
@@ -938,7 +956,6 @@ def dot_product_attention(
f"Received: query.shape={query.shape}, key.shape={key.shape}, "
f"value.shape={value.shape}."
)
bias = bias if bias is None else convert_to_tensor(bias)
mask = mask if mask is None else convert_to_tensor(mask, dtype="bool")
if mask is not None:
# Explicit set `is_causal` to `False` when `mask` is not `None`.
@@ -952,23 +969,13 @@

if flash_attention is None:
flash_attention = _can_use_flash_attention(
query=query, key=key, value=value, mask=mask, is_causal=is_causal
)
elif (
flash_attention is True
and _can_use_flash_attention(
query=query,
key=key,
value=value,
mask=mask,
is_causal=is_causal,
debug=True,
query, key, value, mask, is_causal
)
is False
):
raise ValueError(
"Flash attention is not supported with the provided inputs. "
"Please check the warnings for more details."
elif flash_attention is True:
# Use `raise_error=True` to provide more details if the inputs failed to
# use flash attention
_can_use_flash_attention(
query, key, value, mask, is_causal, raise_error=True
)
if flash_attention:
with torch.nn.attention.sdpa_kernel(
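As a companion sketch (not part of the PR), this shows the consumer side of the check: when it returns True, the torch backend runs `scaled_dot_product_attention` inside an `sdpa_kernel` context. The snippet uses the portable `SDPBackend.MATH` backend so it also runs on CPU; swapping in `SDPBackend.FLASH_ATTENTION` mirrors the flash path on supported GPUs. It assumes torch>=2.3, where `torch.nn.attention.sdpa_kernel` is available.

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative tensors in the (batch, num_heads, seq_len, head_dim) layout
# expected by scaled_dot_product_attention.
query = torch.randn(1, 4, 8, 16)
key = torch.randn(1, 4, 8, 16)
value = torch.randn(1, 4, 8, 16)

# MATH runs everywhere; FLASH_ATTENTION additionally requires a CUDA device
# and inputs that pass a check like _can_use_flash_attention above.
with sdpa_kernel(SDPBackend.MATH):
    out = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, is_causal=False
    )
print(out.shape)  # torch.Size([1, 4, 8, 16])
```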
61 changes: 43 additions & 18 deletions keras/src/layers/attention/multi_head_attention_test.py
@@ -85,11 +85,27 @@ def test_basics_with_flash_attention(self):
run_training_check=False,
)
disable_flash_attention()
except ValueError as e:
self.assertStartsWith(
e.args[0],
"Flash attention is not supported with the provided inputs",
)
except ImportError as e:
if "Flash attention is not supported" in str(e.args[0]):
self.assertTrue(
(
"Flash attention is not supported in your current "
"PyTorch version."
)
in str(e.args[0])
)
except RuntimeError as e:
if (
"Flash attention is not supported with the provided inputs"
in str(e.args[0])
):
self.assertTrue(
(
"Flash attention is not supported with the "
"provided inputs"
)
in str(e.args[0])
)
elif backend.backend() == "jax":
try:
enable_flash_attention()
@@ -113,20 +129,29 @@
run_training_check=False,
)
disable_flash_attention()
except ValueError as e:
self.assertStartsWith(
e.args[0],
(
"Flash attention is not supported in your current JAX "
"version."
),
)
except ImportError as e:
if "Flash attention is not supported" in str(e.args[0]):
self.assertTrue(
(
"Flash attention is not supported in your current "
"JAX version."
)
in str(e.args[0])
)
except RuntimeError as e:
if str(e.args[0]).startswith("cuDNN"):
self.assertStartsWith(e.args[0], "cuDNN is not detected.")
elif str(e.args[0]).startswith("Require at least"):
self.assertStartsWith(
e.args[0], "Require at least Ampere arch to run"
if "cuDNN" in str(e.args[0]):
self.assertTrue("cuDNN is not detected." in str(e.args[0]))
elif "Require at least" in str(e.args[0]):
self.assertTrue(
"Require at least Ampere arch to run" in str(e.args[0])
)
elif "Flash attention" in str(e.args[0]):
self.assertTrue(
(
"Flash attention is not supported in your current "
"JAX version."
)
in str(e.args[0])
)

@parameterized.named_parameters(
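For reference, a hedged usage sketch of the toggles the test exercises. It assumes the helpers are exposed as `keras.config.enable_flash_attention()` / `keras.config.disable_flash_attention()` in your Keras release (older versions may not have them); on hardware or library versions without flash-attention support, the layer call is expected to raise one of the ImportError/RuntimeError messages asserted above.

```python
import numpy as np
import keras

# Assumption: these config helpers exist in your Keras release; if not,
# the enable_flash_attention/disable_flash_attention helpers used in the
# test above are the authoritative reference.
keras.config.enable_flash_attention()
try:
    layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)
    x = np.random.rand(2, 8, 32).astype("float32")
    # May raise ImportError/RuntimeError when flash attention is unsupported.
    out = layer(x, x)
    print(out.shape)  # (2, 8, 32)
finally:
    keras.config.disable_flash_attention()
```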
3 changes: 3 additions & 0 deletions requirements-common.txt
@@ -18,3 +18,6 @@ pytest-cov
packaging
# for tree_test.py
dm_tree
# TODO: Don't pin coverage version. Higher version causes issues:
# https://github.com/nedbat/coveragepy/issues/1891
coverage==7.6.1