Enhance the robustness of the flash attention check (#20495)
* Enhance the robustness of the flash attention check.

* Fix CI

* Fix CI again

* Fix GPU CI again and again...

* No raise in tests

* Pin coverage==7.6.1

* Fix the comment
james77777778 authored Nov 15, 2024
1 parent cb99d0e commit c014c5e
Showing 4 changed files with 90 additions and 59 deletions.
32 changes: 14 additions & 18 deletions keras/src/backend/jax/nn.py
@@ -982,29 +982,25 @@ def psnr(x1, x2, max_val):


def _can_use_flash_attention(query, key, value, bias, raise_error=False):
# Ref: https://github.com/jax-ml/jax/blob/main/jax/_src/cudnn/fused_attention_stablehlo.py
from jax._src.cudnn.fused_attention_stablehlo import _normalize_layout
from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
from jax._src.cudnn.fused_attention_stablehlo import check_layout

"""Verify the availability of flash attention."""
try:
# The older version of jax doesn't have `check_compute_capability`
from jax._src.cudnn.fused_attention_stablehlo import _normalize_layout
from jax._src.cudnn.fused_attention_stablehlo import (
check_compute_capability,
)
from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
from jax._src.cudnn.fused_attention_stablehlo import check_layout
from jax.nn import dot_product_attention as dot_product_attention
except ImportError:
if raise_error:
raise
raise ImportError(
"Flash attention is not supported in your current JAX version. "
"Please update it by following the official guide: "
"https://jax.readthedocs.io/en/latest/installation.html"
)
return False

try:
# `dot_product_attention` is only available in jax>=0.4.31
if not hasattr(jax.nn, "dot_product_attention"):
raise ValueError(
"Flash attention is not supported in your "
"current JAX version. Please update it "
"using `pip install -U jax jaxlib`."
)
# Check if cuDNN is installed and raise RuntimeError if cuDNN is not
# detected
check_cudnn_version()
@@ -1119,10 +1115,10 @@ def dot_product_attention(
)

if flash_attention:
raise ValueError(
"Flash attention is not supported in your "
"current JAX version. Please update it "
"using `pip install -U jax jaxlib`."
raise RuntimeError(
"Flash attention is not supported in your current JAX version. "
"Please update it by following the official guide: "
"https://jax.readthedocs.io/en/latest/installation.html"
)
# Ref: jax.nn.dot_product_attention
# https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
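
A note on the pattern: the JAX-side change moves the version-dependent imports and environment checks behind a soft-failure path, so an unsupported setup degrades to False unless the caller explicitly asks for an error. The sketch below is a simplified stand-in for illustration only, not the Keras implementation (which, per the diff above, also validates compute capability and input layout via check_compute_capability, _normalize_layout and check_layout):

def flash_attention_supported(raise_error=False):
    """Simplified sketch of a robust flash-attention availability check."""
    try:
        # These APIs only exist in sufficiently new JAX releases.
        from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
        from jax.nn import dot_product_attention  # noqa: F401 (jax>=0.4.31)
    except ImportError:
        if raise_error:
            raise ImportError(
                "Flash attention is not supported in your current JAX version. "
                "Please update it by following the official guide: "
                "https://jax.readthedocs.io/en/latest/installation.html"
            )
        return False

    try:
        # Raises RuntimeError if cuDNN is not detected or is too old.
        check_cudnn_version()
    except RuntimeError:
        if raise_error:
            raise
        return False
    return True

With raise_error=False a caller can silently fall back to the non-flash path; with raise_error=True the same kind of check surfaces the reason, which is how an explicit flash-attention request is meant to fail loudly.
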
53 changes: 30 additions & 23 deletions keras/src/backend/torch/nn.py
@@ -895,29 +895,47 @@ def _get_large_negative(dtype):


def _can_use_flash_attention(
query, key, value, mask=None, is_causal=False, debug=False
query, key, value, mask=None, is_causal=False, raise_error=False
):
"""Verify the availability of flash attention."""
try:
spda_params = torch.backends.cuda.SDPAParams(
from torch.backends.cuda import SDPAParams
from torch.backends.cuda import can_use_flash_attention
except ImportError:
if raise_error:
raise ImportError(
"Flash attention is not supported in your current PyTorch "
"version. Please update it by following the official guide: "
"https://pytorch.org/get-started/locally/"
)
return False

try:
spda_params = SDPAParams(
query,
key,
value,
mask,
0.0, # dropout_p
is_causal,
False, # enable_gqa
)
except TypeError:
# The signature changed in newer version of torch.
spda_params = torch.backends.cuda.SDPAParams(
# The old function signature for the older version of PyTorch
spda_params = SDPAParams(
query,
key,
value,
mask,
0.0, # dropout_p
is_causal,
False, # enable_gqa
)
return torch.backends.cuda.can_use_flash_attention(spda_params, debug)
if raise_error and can_use_flash_attention(spda_params, True) is False:
raise RuntimeError(
"Flash attention is not supported with the provided inputs. "
"Please check the warnings for more details."
)
return can_use_flash_attention(spda_params, False)


def dot_product_attention(
@@ -943,7 +961,6 @@ def dot_product_attention(
f"Received: query.shape={query.shape}, key.shape={key.shape}, "
f"value.shape={value.shape}."
)
bias = bias if bias is None else convert_to_tensor(bias)
mask = mask if mask is None else convert_to_tensor(mask, dtype="bool")
if mask is not None:
# Explicit set `is_causal` to `False` when `mask` is not `None`.
Expand All @@ -957,23 +974,13 @@ def dot_product_attention(

if flash_attention is None:
flash_attention = _can_use_flash_attention(
query=query, key=key, value=value, mask=mask, is_causal=is_causal
)
elif (
flash_attention is True
and _can_use_flash_attention(
query=query,
key=key,
value=value,
mask=mask,
is_causal=is_causal,
debug=True,
query, key, value, mask, is_causal
)
is False
):
raise ValueError(
"Flash attention is not supported with the provided inputs. "
"Please check the warnings for more details."
elif flash_attention is True:
# Use `raise_error=True` to provide more details if the inputs failed to
# use flash attention
_can_use_flash_attention(
query, key, value, mask, is_causal, raise_error=True
)
if flash_attention:
with torch.nn.attention.sdpa_kernel(
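
On the PyTorch side the same idea applies: SDPAParams and can_use_flash_attention are imported inside a try/except, and raise_error replaces the old debug flag. A rough usage sketch of the underlying PyTorch probe follows (assuming a build that exposes these APIs; the enable_gqa constructor argument only exists in newer releases, hence the TypeError fallback — this is an illustration, not the Keras code):

import torch
from torch.backends.cuda import SDPAParams, can_use_flash_attention


def flash_attention_available(query, key, value, mask=None, is_causal=False):
    """Best-effort probe for the flash SDP kernel (sketch)."""
    try:
        params = SDPAParams(query, key, value, mask, 0.0, is_causal, False)
    except TypeError:
        # Older PyTorch: SDPAParams has no `enable_gqa` argument.
        params = SDPAParams(query, key, value, mask, 0.0, is_causal)
    # Passing debug=True instead makes PyTorch warn about why the flash
    # kernel was rejected, which is what the raise_error path builds on.
    return can_use_flash_attention(params, False)


if torch.cuda.is_available():
    q = k = v = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
    print(flash_attention_available(q, k, v, is_causal=True))
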
61 changes: 43 additions & 18 deletions keras/src/layers/attention/multi_head_attention_test.py
@@ -85,11 +85,27 @@ def test_basics_with_flash_attention(self):
run_training_check=False,
)
disable_flash_attention()
except ValueError as e:
self.assertStartsWith(
e.args[0],
"Flash attention is not supported with the provided inputs",
)
except ImportError as e:
if "Flash attention is not supported" in str(e.args[0]):
self.assertTrue(
(
"Flash attention is not supported in your current "
"PyTorch version."
)
in str(e.args[0])
)
except RuntimeError as e:
if (
"Flash attention is not supported with the provided inputs"
in str(e.args[0])
):
self.assertTrue(
(
"Flash attention is not supported with the "
"provided inputs"
)
in str(e.args[0])
)
elif backend.backend() == "jax":
try:
enable_flash_attention()
@@ -113,20 +129,29 @@
run_training_check=False,
)
disable_flash_attention()
except ValueError as e:
self.assertStartsWith(
e.args[0],
(
"Flash attention is not supported in your current JAX "
"version."
),
)
except ImportError as e:
if "Flash attention is not supported" in str(e.args[0]):
self.assertTrue(
(
"Flash attention is not supported in your current "
"JAX version."
)
in str(e.args[0])
)
except RuntimeError as e:
if str(e.args[0]).startswith("cuDNN"):
self.assertStartsWith(e.args[0], "cuDNN is not detected.")
elif str(e.args[0]).startswith("Require at least"):
self.assertStartsWith(
e.args[0], "Require at least Ampere arch to run"
if "cuDNN" in str(e.args[0]):
self.assertTrue("cuDNN is not detected." in str(e.args[0]))
elif "Require at least" in str(e.args[0]):
self.assertTrue(
"Require at least Ampere arch to run" in str(e.args[0])
)
elif "Flash attention" in str(e.args[0]):
self.assertTrue(
(
"Flash attention is not supported in your current "
"JAX version."
)
in str(e.args[0])
)

@parameterized.named_parameters(
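
The test now has to recognize every way the stricter check can fail, and (per the "No raise in tests" item above) it deliberately avoids re-raising so that machines without flash-attention support still pass. A condensed sketch of the failure taxonomy it distinguishes, assuming only the error messages visible in the diff:

def classify_flash_attention_failure(exc):
    """Map an exception from forcing flash attention to a label (sketch)."""
    msg = str(exc)
    if isinstance(exc, ImportError) and "Flash attention is not supported" in msg:
        return "backend too old for flash attention"
    if isinstance(exc, RuntimeError):
        if "cuDNN is not detected" in msg:
            return "cuDNN missing"
        if "Require at least Ampere arch to run" in msg:
            return "GPU architecture too old"
        if "Flash attention is not supported" in msg:
            return "unsupported inputs or JAX version"
    return "unrelated failure"
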
3 changes: 3 additions & 0 deletions requirements-common.txt
@@ -18,3 +18,6 @@ pytest-cov
packaging
# for tree_test.py
dm_tree
# TODO: Don't pin coverage version. Higher version causes issues:
# https://github.com/nedbat/coveragepy/issues/1891
coverage==7.6.1
