diff --git a/axlearn/common/flash_attention/layer.py b/axlearn/common/flash_attention/layer.py
index 36da8d59..d2500304 100644
--- a/axlearn/common/flash_attention/layer.py
+++ b/axlearn/common/flash_attention/layer.py
@@ -20,6 +20,7 @@
     make_segment_mask,
 )
 from axlearn.common.config import config_class
+from axlearn.common.flash_attention import tpu_attention
 from axlearn.common.flash_attention.utils import (
     MultiHeadAttentionImpl,
     flash_attention_implementation,
@@ -169,10 +170,6 @@ def _compute_attention(
         cfg = self.config
         backend = self._backend()

-        # Repeats key/value heads dim if necessary.
-        k_proj = self._repeat_kv_heads(k_proj)
-        v_proj = self._repeat_kv_heads(v_proj)
-
         batch, target_len, num_heads, _ = q_proj.shape
         _, source_len, _, _ = k_proj.shape

@@ -228,7 +225,18 @@ def _compute_attention(
                 f"{k_proj.shape[1]} for correctly supported GPU flash attention usage."
             )

-        if backend == "tpu":
+        if backend == "cpu" and not tpu_attention.check_tpu_splash_attention(
+            query=q_proj,
+            key=k_proj,
+            has_mask=bool(cfg.mask),
+            segment_ids=segment_ids,
+            has_bias=(attention_logit_biases is not None),
+        ):
+            backend = "xla"
+
+        if backend in ("tpu", "cpu"):
+            # Splash attention needs to know the sliding window size.
+            mask_fn = cfg.mask
             assert q_proj.shape[1] % cfg.tpu_block_size == 0, (
                 f"Target seq len {q_proj.shape[1]} must be "
                 f"divisible by block size {cfg.tpu_block_size}."
@@ -263,6 +271,12 @@ def _compute_attention(
         q_proj = self.scale_query(q_proj)
         k_proj = self.scale_key(k_proj)

+        # TODO(dhwang2): splash attention supports GQA natively; drop this repeat once sharded.
+        # https://github.com/jax-ml/jax/blob/7b9914d711593dca8725d46aa1dadb2194284519/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py#L934
+        # Repeats key/value heads dim if necessary.
+        k_proj = self._repeat_kv_heads(k_proj)
+        v_proj = self._repeat_kv_heads(v_proj)
+
         # Constrain input to conform to partitioned MHA expectations.
         q_proj = with_sharding_constraint(q_proj, cfg.mha_dim_to_partition_spec["btnh"])
         k_proj = with_sharding_constraint(k_proj, cfg.mha_dim_to_partition_spec["bsnh"])
diff --git a/axlearn/common/flash_attention/layer_test.py b/axlearn/common/flash_attention/layer_test.py
index 52d1bc4e..5b9e4d2f 100644
--- a/axlearn/common/flash_attention/layer_test.py
+++ b/axlearn/common/flash_attention/layer_test.py
@@ -16,7 +16,7 @@
 import jax
 import jax.numpy as jnp
 import pytest
-from absl.testing import parameterized
+from absl.testing import absltest, parameterized
 from jax.experimental import mesh_utils
 from jax.sharding import Mesh

@@ -91,6 +91,7 @@ def _prepare_layers(
     sliding_window_size,
     inference=False,
     set_layer_bias_recursively=False,
+    tpu_block_size=512,
 ):
     hidden_dim = num_heads * per_head_dim
     kwargs = dict(
@@ -110,6 +111,7 @@
         .set(
             mha_dim_to_partition_spec=default_mha_dim_to_partition_spec(mesh_axis_names),
             output_dim_to_partition_spec=default_output_dim_to_partition_spec(mesh_axis_names),
+            tpu_block_size=tpu_block_size,
         )
     )
     if inference:
@@ -378,7 +380,9 @@ def test_forward(
         mesh_axis_names=mesh_axis_names,
         causal=causal,
         sliding_window_size=sliding_window_size,
+        tpu_block_size=128,
     )
+
     # pylint: disable-next=protected-access
     if test_layer._backend() == "gpu" and query_len_multiplier != 1:
         pytest.skip(
@@ -734,3 +738,7 @@ def test_extend_step(
             atol=2e-2,
         )
         jax.clear_backends()
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/axlearn/common/flash_attention/tpu_attention.py b/axlearn/common/flash_attention/tpu_attention.py
index dff77795..e6ded5b9 100644
--- a/axlearn/common/flash_attention/tpu_attention.py
+++ b/axlearn/common/flash_attention/tpu_attention.py
@@ -31,6 +31,7 @@
     splash_attention_mask,
 )

+from axlearn.common import attention, config
 from axlearn.common.attention import MaskFn, apply_attention_logit_biases, bool_to_bias, causal_mask
 from axlearn.common.utils import Tensor

@@ -42,9 +43,10 @@ def tpu_flash_attention(
     bias: Tensor = None,  # [batch_size, num_heads, source_len, target_len]
     segment_ids: Tensor = None,  # [batch_size, source_len]
     *,
-    mask: Optional[MaskFn] = None,
+    mask: config.ConfigOr[Optional[MaskFn]] = None,
     softmax_scale: float = 1.0,
     block_size: int = 128,
+    interpret: bool = False,
 ):
     """Wraps JAX's TPU flash-attention, with reshapes and softmax-scaling outside kernel.

@@ -66,6 +68,7 @@
         mask: The mask to apply. This is more compute efficient compared to setting bias = -inf.
         softmax_scale: A scaling factor applied to the query.
         block_size: The block size to use for chunking data in the kernel.
+        interpret: If True, run the kernel under the Pallas interpreter (required on CPU).

     Returns:
         The context tensor, of shape [batch_size, source_len, num_heads, head_dim].
@@ -97,6 +100,19 @@
     key = jnp.einsum("bsnh->bnsh", key)
     value = jnp.einsum("bsnh->bnsh", value)
     try:
+        if not check_tpu_splash_attention(
+            query=query,
+            key=key,
+            has_mask=bool(mask),
+            segment_ids=segment_ids,
+            has_bias=(bias is not None),
+        ):
+            raise SplashAttentionUnsupportedError(
+                f"SplashAttention is not supported with arguments: {query.shape=}, {key.shape=}, "
+                f"{bool(mask)=}, {bool(segment_ids)=}, {bool(bias)=}."
+            )
+        mask_shape = (query.shape[2], key.shape[2])
+        mask_fn = to_splash_mask(mask, mask_shape=mask_shape)
         block_sizes = splash_attention_kernel.BlockSizes(
             block_q=block_size,
             block_kv=block_size,
@@ -110,10 +126,18 @@
             use_fused_bwd_kernel=True,
         )
         context = _tpu_splash_attention(
-            query, key, value, bias, segment_ids=segment_ids, mask=mask, block_sizes=block_sizes
+            query,
+            key,
+            value,
+            segment_ids=segment_ids,
+            mask=mask_fn,
+            block_sizes=block_sizes,
+            interpret=interpret,
         )
         logging.info("Using SplashAttention.")
     except SplashAttentionUnsupportedError as e:
+        if isinstance(mask, config.InstantiableConfig):
+            mask = mask.instantiate()
         # TODO(tom_gunter): See if we can do better block-size tuning.
         block_sizes = LegacyBlockSizes(
             block_q=block_size,
@@ -136,6 +160,7 @@
             segment_ids=segment_ids,
             mask=mask,
             block_sizes=block_sizes,
+            interpret=interpret,
         )
         logging.warning(
             "Falling back to legacy flash attention because SplashAttention is not supported.\n"
@@ -152,6 +177,7 @@
     static_argnames=[
         "mask",  # Mask objects don't actually contain jax arrays, so they are static.
         "block_sizes",
+        "interpret",
     ],
 )
 def _legacy_tpu_flash_attention(
@@ -163,6 +189,7 @@
     *,
     mask: Optional[MaskFn] = None,
     block_sizes: Optional[LegacyBlockSizes] = None,
+    interpret: bool = False,
 ) -> Tensor:  # [batch_size, num_heads, source_len, head_dim].
     """Wraps JAX's legacy TPU flash-attention.

@@ -177,6 +204,7 @@
         mask: The mask to apply. This is more compute efficient compared to setting bias = -inf.
         block_sizes: An object containing values that can be used to tune the performance such as
             the block size to chunk things into.
+        interpret: If True, run the kernel under the Pallas interpreter (required on CPU).

     Returns:
         The context tensor, of shape [batch_size, num_heads, source_len, head_dim].
@@ -203,6 +231,7 @@
         sm_scale=1.0,
         block_sizes=block_sizes,
         debug=False,
+        interpret=interpret,
     )
     return context

@@ -212,22 +241,59 @@ class SplashAttentionUnsupportedError(NotImplementedError):
     """An error indicating splash attention is not supported."""


+def check_tpu_splash_attention(
+    *,
+    query: Tensor,  # [batch_size, num_heads, source_len, head_dim]
+    key: Tensor,  # [batch_size, num_heads, target_len, head_dim]
+    has_mask: bool,
+    segment_ids: Tensor = None,  # [batch_size, source_len]
+    has_bias: bool = False,
+) -> bool:
+    """Checks if splash attention is supported for the given arguments.
+
+    Args:
+        query: The query tensor, of shape [batch_size, num_heads, source_len, head_dim].
+        key: The key tensor, of shape [batch_size, num_heads, target_len, head_dim].
+        has_mask: Whether a mask (MaskFn or its config) is provided.
+        segment_ids: The id of which segment each token belongs to. Attention is not computed
+            between tokens in different segments. Shape: [batch_size, source_len].
+        has_bias: Whether attention logit biases are provided.
+
+    Returns:
+        True if splash attention is supported, False otherwise.
+    """
+    source_len = query.shape[2]
+    target_len = key.shape[2]
+    head_dim = query.shape[3]
+    if has_bias:
+        return False
+    if head_dim % splash_attention_kernel.NUM_LANES != 0:
+        return False
+    if segment_ids is not None:
+        return False
+    # TODO(dhwang2): not strictly necessary; this case could be handled with a mask offset.
+    if source_len != target_len and has_mask:
+        return False
+    return True
+
+
 @functools.partial(
     jax.jit,
     static_argnames=[
         "mask",  # Mask objects don't actually contain jax arrays, so they are static.
         "block_sizes",
+        "interpret",
     ],
 )
 def _tpu_splash_attention(
     query: Tensor,  # [batch_size, num_heads, source_len, head_dim]
     key: Tensor,  # [batch_size, num_heads, target_len, head_dim]
     value: Tensor,  # [batch_size, num_heads, target_len, head_dim]
-    bias: Tensor = None,  # [batch_size, num_heads, source_len, target_len]
-    segment_ids: Tensor = None,  # [batch_size, source_len]
     *,
-    mask: Optional[MaskFn] = None,
+    mask: splash_attention_mask.Mask,
+    segment_ids: Tensor = None,  # [batch_size, source_len]
     block_sizes: Optional[splash_attention_kernel.BlockSizes] = None,
+    interpret: bool = False,
 ) -> Tensor:  # [batch_size, num_heads, source_len, head_dim].
     """Wraps JAX's sparse TPU flash-attention.

@@ -235,13 +301,13 @@
         query: The query tensor, of shape [batch_size, num_heads, source_len, head_dim].
         key: The key tensor, of shape [batch_size, num_heads, target_len, head_dim].
         value: The value tensor, of shape [batch_size, num_heads, source_len, head_dim].
-        bias: The attention biases, of shape [batch_size, num_heads, source_len, target_len].
-        segment_ids: The id of which segment each token belongs to. Attention is not computed
-            between tokens in different segments.
-            Shape: [batch_size, source_len].
+        bias: must be None.
         mask: The mask to apply. This is more compute efficient compared to setting bias = -inf.
+        segment_ids: The id of which segment each token belongs to. Attention is not computed
+            between tokens in different segments. Shape: [batch_size, source_len].
         block_sizes: An object containing values that can be used to tune the performance such as
             the block size to chunk things into.
+        interpret: If True, run the kernel under the Pallas interpreter (required on CPU).

     Returns:
         The context tensor, of shape [batch_size, num_heads, source_len, head_dim].
@@ -250,55 +316,55 @@
         NotImplementedError: If a bias is also specified or the head_dim is not
             divisible by 128.
     """
-
-    source_len = query.shape[2]
-    target_len = key.shape[2]
+    # TODO(dhwang2): splash attention can support segment_ids. Support it when needed.
+    del segment_ids
     num_heads = query.shape[1]
-    head_dim = query.shape[3]
-
-    if bias is not None:
-        raise SplashAttentionUnsupportedError("SplashAttention does not support specifying a bias.")
-    if head_dim % splash_attention_kernel.NUM_LANES != 0:
-        raise SplashAttentionUnsupportedError(
-            "SplashAttention requires "
-            f"head_dim=={splash_attention_kernel.NUM_LANES}, "
-            f"got {head_dim}."
-        )
-    if segment_ids is not None:
-        raise SplashAttentionUnsupportedError(
-            "The public API for SplashAttention that we "
-            "currently use does not support segment ids."
-        )
-    if source_len != target_len and mask is not None:
-        raise SplashAttentionUnsupportedError(
-            "Query and key/value must have same length when mask is used."
-        )
-
-    mask_shape = (source_len, target_len)
-    if mask is None:
-        mask = splash_attention_mask.FullMask(mask_shape)
-    else:
-        # Use fewer bytes for the mask.
-        rows = np.arange(source_len, dtype=np.int32)
-        cols = np.arange(target_len, dtype=np.int32)
-        with jax.ensure_compile_time_eval():
-            mask_array = np.asarray(mask(rows[:, None], cols[None, :]))
-
-        # NumpyMask is backed by a dense [source_len, target_len] numpy array.
-        # May consume a large amount of host memory for long sequences at compile time.
-        mask = splash_attention_mask.NumpyMask(array=mask_array)

     kernel = splash_attention_kernel.make_splash_mha(
         mask=splash_attention_mask.MultiHeadMask(masks=[mask] * num_heads),
         block_sizes=block_sizes,
+        # TODO(dhwang2): support sharding over the "seq" and "model" axes.
         head_shards=1,
         q_seq_shards=1,
+        interpret=interpret,
     )
     kernel = jax.vmap(kernel)
     context = kernel(q=query, k=key, v=value)
     return context


+def to_splash_mask(
+    mask: config.ConfigOr[Optional[MaskFn]], *, mask_shape: tuple[int, int], q_seq_shards: int = 1
+) -> splash_attention_mask.Mask:
+    """Converts a MaskFn (or its config) into an equivalent splash attention mask."""
+    if mask is None:
+        return splash_attention_mask.FullMask(mask_shape)
+    if mask is causal_mask:
+        return splash_attention_mask.CausalMask(shape=mask_shape, shard_count=q_seq_shards)
+    if (
+        isinstance(mask, config.InstantiableConfig)
+        and mask.fn is attention.sliding_window_causal_mask
+    ):
+        window_size = (mask.sliding_window_size, 0)
+        return splash_attention_mask.LocalMask(
+            shape=mask_shape, window_size=window_size, offset=0, shard_count=q_seq_shards
+        )
+
+    # NumpyMask fallback for arbitrary MaskFn callables.
+    # Note: it consumes O(source_len * target_len) host memory.
+    source_len, target_len = mask_shape
+
+    # Use fewer bytes for the mask.
+    rows = np.arange(source_len, dtype=np.int32)
+    cols = np.arange(target_len, dtype=np.int32)
+    with jax.ensure_compile_time_eval():
+        mask_array = np.asarray(mask(rows[:, None], cols[None, :]))
+
+    # NumpyMask is backed by a dense [source_len, target_len] numpy array.
+    # May consume a large amount of host memory for long sequences at compile time.
+    return splash_attention_mask.NumpyMask(array=mask_array)
+
+
 # The following code is adapted from jax-ml/jax:
 # Copyright 2023 The JAX Authors.
 # Licensed under the Apache License, Version 2.0 (the "License").
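
Editorial illustration (not part of the patch): a minimal sketch of how the `to_splash_mask` helper added above is expected to lower each supported mask form. The mask shape, the 128-token window, and the custom lambda are assumed example values.

# Illustrative sketch of to_splash_mask usage (assumed values; not part of the diff).
from axlearn.common.attention import causal_mask, sliding_window_causal_mask
from axlearn.common.config import config_for_function
from axlearn.common.flash_attention.tpu_attention import to_splash_mask

mask_shape = (1024, 1024)  # (source_len, target_len); must be square when a mask is used.

# No mask: a FullMask over the whole [source_len, target_len] grid.
full = to_splash_mask(None, mask_shape=mask_shape)

# The shared causal MaskFn maps to a structured CausalMask (no dense numpy materialization).
causal = to_splash_mask(causal_mask, mask_shape=mask_shape)

# A sliding-window mask passed as a config keeps sliding_window_size readable,
# so it lowers to LocalMask(window_size=(N, 0)) rather than a dense NumpyMask.
local = to_splash_mask(
    config_for_function(sliding_window_causal_mask).set(sliding_window_size=128),
    mask_shape=mask_shape,
)

# Any other callable MaskFn falls back to a dense NumpyMask built on the host at trace time.
dense = to_splash_mask(lambda query_pos, key_pos: query_pos >= key_pos, mask_shape=mask_shape)
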
@@ -311,6 +377,7 @@ def _tpu_splash_attention(
         "sm_scale",
         "block_sizes",
         "debug",
+        "interpret",
     ],
 )
 def pallas_tpu_flash_attention(
@@ -324,6 +391,7 @@ def pallas_tpu_flash_attention(
     sm_scale: float = 1.0,
     block_sizes: Optional[LegacyBlockSizes] = None,
     debug: bool = False,
+    interpret: bool = False,
 ):
     batch_size, num_heads, q_seq_len, d_model = q.shape
     batch_size_k, num_heads_k, kv_seq_len, d_model_k = k.shape
@@ -372,10 +440,12 @@ def pallas_tpu_flash_attention(
         block_sizes = LegacyBlockSizes.get_default(
             batch_size, num_heads, q_seq_len, kv_seq_len, d_model
         )
-    return _flash_attention(q, k, v, ab, segment_ids, False, causal, sm_scale, block_sizes, debug)
+    return _flash_attention(
+        q, k, v, ab, segment_ids, False, causal, sm_scale, block_sizes, debug, interpret
+    )


-@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 10))
+@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 11))
 def _flash_attention(
     q,
     k,
@@ -387,6 +457,7 @@ def _flash_attention(
     sm_scale,
     block_sizes,
     debug,
+    interpret,
 ):
     return _flash_attention_impl(
         q,
@@ -402,6 +473,7 @@ def _flash_attention(
         block_sizes.block_k_major,
         block_sizes.block_k,
         debug,
+        interpret,
     )


@@ -416,10 +488,13 @@ def _flash_attention_fwd(
     sm_scale,
     block_sizes,
     debug,
+    interpret,
 ):
     if save_residuals:
         raise NotImplementedError("Higher-order AD not supported")
-    o, l, m = _flash_attention(q, k, v, ab, segment_ids, True, causal, sm_scale, block_sizes, debug)
+    o, l, m = _flash_attention(
+        q, k, v, ab, segment_ids, True, causal, sm_scale, block_sizes, debug, interpret
+    )
     return o, (q, k, v, ab, segment_ids, o, l, m)


@@ -429,6 +504,7 @@ def _flash_attention_bwd(
     sm_scale: float,
     block_sizes: LegacyBlockSizes,
     debug: bool,
+    interpret: bool,
     residuals,
     do,
 ):
@@ -463,6 +539,7 @@ def _flash_attention_bwd(
         causal=causal,
         mask_value=DEFAULT_MASK_VALUE,
         debug=debug,
+        interpret=interpret,
     )

     dq, ds = _flash_attention_bwd_dq(
@@ -482,6 +559,7 @@ def _flash_attention_bwd(
         causal=causal,
         mask_value=DEFAULT_MASK_VALUE,
         debug=debug,
+        interpret=interpret,
     )
     return dq, dk, dv, ds, None

@@ -503,6 +581,7 @@ def _flash_attention_impl(
     block_k_major,
     block_k,
     debug,
+    interpret,
 ):
     batch_size, num_heads, q_seq_len, head_dim = q.shape
     _, _, kv_seq_len, _ = k.shape
@@ -665,6 +744,7 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index):
         ),
         out_shape=out_shape,
         debug=debug,
+        interpret=interpret,
         compiler_params=dict(
             mosaic=dict(
                 dimension_semantics=(
@@ -702,6 +782,7 @@ def _flash_attention_bwd_dkv(
     causal: bool = False,
     mask_value: float = DEFAULT_MASK_VALUE,
     debug: bool = False,
+    interpret: bool = False,
 ):
     batch_size, num_heads, q_seq_len, head_dim = q.shape
     _, _, kv_seq_len, _ = k.shape
@@ -868,6 +949,7 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _):
         ),
         out_shape=out_shapes,
         debug=debug,
+        interpret=interpret,
         compiler_params=dict(
             mosaic=dict(
                 dimension_semantics=(
@@ -902,6 +984,7 @@ def _flash_attention_bwd_dq(
     causal: bool,
     mask_value: float,
     debug: bool,
+    interpret: bool,
 ):
     batch_size, num_heads, q_seq_len, head_dim = q.shape
     _, _, kv_seq_len, _ = k.shape
@@ -1059,6 +1142,7 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index):
         ),
         out_shape=out_shapes,
         debug=debug,
+        interpret=interpret,
         compiler_params=dict(
             mosaic=dict(
                 dimension_semantics=(
diff --git a/axlearn/common/flash_attention/tpu_attention_benchmark.py b/axlearn/common/flash_attention/tpu_attention_benchmark.py
index e203d27b..0522f4a6 100644
--- a/axlearn/common/flash_attention/tpu_attention_benchmark.py
+++ b/axlearn/common/flash_attention/tpu_attention_benchmark.py
@@ -38,6 +38,7 @@
 import jax
 import jax.numpy as jnp

+from axlearn.common import config
 from axlearn.common.attention import causal_mask, sliding_window_causal_mask
 from axlearn.common.flash_attention.utils import flash_attention_implementation, mha_reference

@@ -131,7 +132,9 @@ def _benchmark(
     if causal and sliding_window_size is None:
         mask = causal_mask
     elif causal:
-        mask = sliding_window_causal_mask(sliding_window_size)
+        mask = config.config_for_function(sliding_window_causal_mask).set(
+            sliding_window_size=sliding_window_size
+        )

     # Get fwd & bwd timing information when softmax scaling applied before calling the kernel.
     mha_impl = flash_attention_implementation(
diff --git a/axlearn/common/flash_attention/tpu_attention_test.py b/axlearn/common/flash_attention/tpu_attention_test.py
index 1ba38d7a..1936bf88 100644
--- a/axlearn/common/flash_attention/tpu_attention_test.py
+++ b/axlearn/common/flash_attention/tpu_attention_test.py
@@ -5,11 +5,12 @@

 import unittest

+import chex
 import jax
 import jax.numpy as jnp
 import numpy as np
 import pytest
-from absl.testing import parameterized
+from absl.testing import absltest, parameterized
 from jax.experimental import mesh_utils
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 from jax.experimental.shard_map import shard_map
@@ -17,13 +18,18 @@
 from jax.sharding import Mesh, NamedSharding, PartitionSpec

 from axlearn.common.attention import causal_mask, sliding_window_causal_mask
+from axlearn.common.config import config_for_function
 from axlearn.common.flash_attention import tpu_attention
 from axlearn.common.flash_attention.utils import mha_reference
 from axlearn.common.test_utils import TestCase, is_supported_mesh_shape
 from axlearn.common.utils import Tensor

-if jax.default_backend() != "tpu":
-    pytest.skip(reason="Incompatible hardware", allow_module_level=True)
+
+def setUpModule():
+    chex.set_n_cpu_devices(4)
+    # Comment out the following check to run these tests on CPU.
+    if jax.default_backend() != "tpu":
+        pytest.skip(reason="Incompatible hardware", allow_module_level=True)


 def jax_fn_mask(query_position: Tensor, key_position: Tensor) -> Tensor:
@@ -71,19 +77,21 @@ def test_sliding_window_mask_equivalence(self, seq_len, sliding_window_size):

     @parameterized.product(
         batch_size=[4],
-        seq_len=[32768],
+        seq_len=[1024, 32768],
+        mask_fn=[None, "causal", "sliding", "sliding_fn"],
         sliding_window_size=[1024],
         num_heads=[4],
         per_head_dim=[256],
         mesh=[(4, 1)],
         mesh_axis_names=[("data", "model")],
     )
-    def test_sliding_window_mask(
+    def test_forward(
         self,
         batch_size,
         seq_len,
         num_heads,
         per_head_dim,
+        mask_fn,
         sliding_window_size,
         mesh,
         mesh_axis_names,
@@ -117,10 +125,23 @@ def fn(q, k, v):
         )
         softmax_scale = q.shape[-1] ** -0.5
-        mask = sliding_window_causal_mask(sliding_window_size)
-
+        if mask_fn is None:
+            mask = None
+        elif mask_fn == "causal":
+            mask = causal_mask
+        elif mask_fn.startswith("sliding"):
+            mask = config_for_function(sliding_window_causal_mask).set(
+                sliding_window_size=sliding_window_size
+            )
+            if mask_fn == "sliding_fn":
+                mask = mask.instantiate()
         attn = lambda q, k, v: tpu_attention.tpu_flash_attention(
-            q, k, v, mask=mask, softmax_scale=softmax_scale
+            q,
+            k,
+            v,
+            mask=mask,
+            softmax_scale=softmax_scale,
+            interpret=(jax.default_backend() == "cpu"),
         )

         partitioned_mha = shard_map(
@@ -209,9 +230,17 @@ def fn(q, k, v, bias, ids):
             )
             with record_legacy_call:
                 return tpu_attention.tpu_flash_attention(
-                    q, k, v, bias, ids, mask=mask, softmax_scale=softmax_scale
+                    q,
+                    k,
+                    v,
+                    bias,
+                    ids,
+                    mask=mask,
+                    softmax_scale=softmax_scale,
+                    interpret=(jax.default_backend() == "cpu"),
                 )

+        # TODO(dhwang2): this has been broken for a while on CPU.
         # Compare outputs.
         out = fn(q, k, v, attention_bias, segment_ids)
         ref_out = ref_fn(q, k, v, attention_bias, segment_ids)
@@ -231,3 +260,7 @@ def fn(q, k, v, bias, ids):
             legacy_flash_wrapper.assert_called()
         else:
             legacy_flash_wrapper.assert_not_called()
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py
index 7da0543f..5c03ade4 100644
--- a/axlearn/common/flash_attention/utils.py
+++ b/axlearn/common/flash_attention/utils.py
@@ -8,6 +8,7 @@
 import jax.numpy as jnp
 from absl import logging

+from axlearn.common import config
 from axlearn.common.attention import NEG_INF, MaskFn, causal_mask, softmax_with_biases
 from axlearn.common.flash_attention.gpu_attention import cudnn_dot_product_attention
 from axlearn.common.flash_attention.gpu_attention import flash_attention as gpu_flash_attention
@@ -77,7 +78,7 @@ def mha_reference(
 def flash_attention_implementation(
     backend: Literal["cpu", "tpu", "gpu", "xla"],
     *,
-    mask: Optional[MaskFn] = None,
+    mask: config.ConfigOr[Optional[MaskFn]] = None,
     softmax_scale: float,
     block_size: int = 128,
 ) -> MultiHeadAttentionImpl:
@@ -98,14 +99,16 @@ def flash_attention_implementation(
     Raises:
         NotImplementedError: If implementation for the backend is not available.
     """
-    causal = mask is causal_mask
-    if mask is not None and not causal and backend != "tpu":
-        raise NotImplementedError(
-            "Custom (non-causal, non-full) mask only supported on TPU.\n"
-            "You can use NEG_INF biases instead, but it won't "
-            "have the sparsity optimizations."
-        )
     if backend == "gpu":
+        mask: Optional[MaskFn] = mask
+        causal = mask is causal_mask
+        if mask is not None and not causal:
+            raise NotImplementedError(
+                "Custom (non-causal, non-full) mask is not supported on GPU.\n"
+                "You can use NEG_INF biases instead, but it won't "
+                "have the sparsity optimizations."
+            )
+
         # shard_map-decorated function needs to be jitted.
         @jax.jit
         def jit_attn(query, key, value, bias, segment_ids):
@@ -141,7 +144,10 @@ def jit_attn(query, key, value, bias, segment_ids):

         return jit_attn

-    elif backend == "tpu":
+    elif backend in ("tpu", "cpu"):
+        if backend == "cpu":
+            logging.warning("Flash attention CPU backend is for testing only.")
+
         # shard_map-decorated function needs to be jitted.
         @jax.jit
         def jit_attn(query, key, value, bias, segment_ids):
@@ -154,14 +160,15 @@ def jit_attn(query, key, value, bias, segment_ids):
                 mask=mask,
                 softmax_scale=softmax_scale,
                 block_size=block_size,
+                interpret=(backend == "cpu"),
             )
             return context

         return jit_attn

-    elif backend in ("cpu", "xla"):
-        if backend == "cpu":
-            logging.warning("Flash attention CPU backend is for testing only.")
+    elif backend == "xla":
+        mask: Optional[MaskFn] = mask
+        causal = mask is causal_mask
         logging.warning("Flash attention falling back using plain MHA implementation")

         # shard_map-decorated function needs to be jitted.
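
Editorial illustration (not part of the patch): a minimal end-to-end sketch of the new "cpu" backend path through `flash_attention_implementation`, mirroring the benchmark and test changes above. The mask is passed as a config so the splash kernel can read `sliding_window_size`, and the kernels run under the Pallas interpreter. The shapes, window size, and block size are assumed example values.

# Illustrative CPU smoke run of the interpreted splash-attention path (assumed values).
import jax.numpy as jnp

from axlearn.common.attention import sliding_window_causal_mask
from axlearn.common.config import config_for_function
from axlearn.common.flash_attention.utils import flash_attention_implementation

batch, seq_len, num_heads, per_head_dim = 2, 256, 4, 128  # head_dim must be a multiple of 128.
q = jnp.zeros((batch, seq_len, num_heads, per_head_dim), jnp.float32)
k = v = q

mha = flash_attention_implementation(
    backend="cpu",  # Runs the TPU kernels under the Pallas interpreter (testing only).
    mask=config_for_function(sliding_window_causal_mask).set(sliding_window_size=64),
    softmax_scale=per_head_dim**-0.5,
    block_size=128,  # seq_len must be divisible by the block size.
)
context = mha(q, k, v, None, None)  # bias=None, segment_ids=None.
assert context.shape == (batch, seq_len, num_heads, per_head_dim)
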