From 9f82c0fff529c7aab25d9b2a133a32d2c4c46e9f Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Tue, 14 May 2024 19:45:39 +0000 Subject: [PATCH] Add support for layouts --- python/perf-kernels/flash-attention.py | 380 ++++++++++++------------- 1 file changed, 178 insertions(+), 202 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index d70a43ecd36c..fa3c3202c120 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -26,13 +26,6 @@ import triton import triton.language as tl -torch_dtype: tl.constexpr = torch.float16 - -TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz') -if TORCH_HAS_FP8E5: - torch_dtype: tl.constexpr = torch.float8_e5m2fnuz - - class MetaData(): cu_seqlens_q = None cu_seqlens_k = None @@ -43,6 +36,7 @@ class MetaData(): causal = False num_contexts = 0 varlen = False + layout = None dropout_p, return_encoded_softmax = 0.0, False def __init__(self, sm_scale=1.0): @@ -50,6 +44,7 @@ def __init__(self, sm_scale=1.0): def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): self.varlen = True + self.layout = 'thd' self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k # Without "varlen", there should still be one sequence. @@ -108,39 +103,35 @@ def check_args(self, q, k, v, o): assert head_size <= 256 assert o.shape == q.shape assert (nheads_q % nheads_k) == 0 - + assert self.layout is not None + assert self.layout == 'thd' or not self.varlen @triton.jit -def cdiv_fn(x, y): +def cdiv_fn(x,y): return (x + y - 1) // y - @triton.jit def max_fn(x, y): return tl.math.max(x, y) - @triton.jit def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): ms = tl.arange(0, m) ns = tl.arange(0, n) return philox_offset + ms[:, None] * stride + ns[None, :] - @triton.jit def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) # TODO: use tl.randint for better performance return tl.rand(philox_seed, rng_offsets) - @triton.jit def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) rng_keep = rng_output > dropout_p return rng_keep - @triton.jit def load_fn(block_ptr, first, second, pad): if first and second: @@ -153,7 +144,6 @@ def load_fn(block_ptr, first, second, pad): tensor = tl.load(block_ptr) return tensor - @triton.jit def print_gpu(prefix, val=None): if (tl.program_id(0) == 0) and ((tl.program_id(1) == 0) and (tl.program_id(2) == 0)): @@ -162,7 +152,6 @@ def print_gpu(prefix, val=None): else: tl.device_print(prefix) - @triton.jit def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix @@ -195,7 +184,6 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo else: return alibi_block - def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) @@ -288,87 +276,46 @@ def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_ encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) return acc, l_i, m_i - @triton.autotune( - configs=[ - 
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), - # TODO: This config fails with head_size not pow2 with data mismatches. Check why. - # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - ], - key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], - use_cuda_graph=True, + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + # TODO: This config fails with head_size not pow2 with data mismatches. Check why. 
+ # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + ], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + use_cuda_graph=True, ) @triton.jit def attn_fwd( - Q, - K, - V, - bias, - sm_scale, - L, - Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, - stride_az, - stride_ah, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset_base, - encoded_softmax, + Q, K, V, bias, sm_scale, L, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + stride_bz, stride_bh, stride_bm, stride_bn, + stride_az, stride_ah, + cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, - HQ: tl.constexpr, - HK: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, + HQ: tl.constexpr, HK:tl.constexpr, + ACTUAL_BLOCK_DMODEL:tl.constexpr, + MAX_SEQLENS_Q:tl.constexpr, MAX_SEQLENS_K:tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, BIAS_TYPE: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, - USE_ALIBI: tl.constexpr, - BATCH_SIZE: tl.constexpr, + USE_ALIBI: tl.constexpr ): start_m = tl.program_id(0) off_h_q = tl.program_id(1) @@ -644,7 +591,6 @@ def _attn_bwd_preprocess( else: tl.store(delta_ptrs, delta) - @triton.jit def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, # shared by Q/K/V/DO. @@ -696,7 +642,6 @@ def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0)) return dk, dv - @triton.jit def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, # shared by Q/K/V/DO. @@ -744,7 +689,6 @@ def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) return dq - @triton.jit def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, # shared by Q/K/V/DO. @@ -874,9 +818,34 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, empty = torch.empty(128, device="cuda") +# TODO: This can probably optimized to have fewer lines of code. 
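# Editor's note (descriptive): the helper defined just below normalizes every
# supported layout to a single stride tuple ordered as (batch, head, seqlen,
# head_dim), which is the order of the kernel's stride_qz/stride_qh/stride_qm/
# stride_qk style arguments. For the varlen 'thd' layout the batch stride is 0,
# because sequence starts are resolved through the cu_seqlens offsets rather
# than a fixed per-batch stride.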
+def get_strides_from_layout(metadata, q, k, v, o): + if metadata.layout == 'thd': + batch, nheads_q = metadata.num_contexts, q.shape[1] + nheads_k = k.shape[1] + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + elif metadata.layout == 'bhsd': + batch, nheads_q, _, head_size = q.shape + nheads_k = k.shape[1] + q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) + k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) + v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) + o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + elif metadata.layout == 'bshd': + batch, _, nheads_q, head_size = q.shape + nheads_k = k.shape[2] + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + else: + assert False, 'Got unsupported layout.' + return batch, nheads_q, nheads_k, q_strides, k_strides, v_strides, o_strides class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, o, metadata): # NOTE: a large bias tensor leads to overflow during pointer arithmetic @@ -886,24 +855,15 @@ def forward(ctx, q, k, v, o, metadata): if o is None: o = torch.empty_like(q, dtype=v.dtype) metadata.check_args(q, k, v, o) - if metadata.varlen: - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - batch = metadata.num_contexts - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - else: - batch, nheads_q, seqlen_q, head_size = q.shape - _, nheads_k, seqlen_k, _ = k.shape - q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) - k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) - v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) - o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + + batch, nheads_q, nheads_k, q_strides, k_strides, v_strides, o_strides = \ + get_strides_from_layout(metadata, q, k, v, o) + head_size = q.shape[-1] # Get closest power of 2 over or equal to 32. padded_d_model = 1 << (head_size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. 
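        # Editor's note (worked example): for head_size = 72 the bit_length()
        # expression above gives 1 << (72 - 1).bit_length() == 128, so the kernel
        # tile uses BLOCK_DMODEL = 128; for head_size = 8 it gives 8, and the
        # max() below raises it to the supported minimum of 16.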
padded_d_model = max(padded_d_model, 16) grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), nheads_q, batch) @@ -942,8 +902,9 @@ def forward(ctx, q, k, v, o, metadata): ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=metadata.max_seqlens_q, MAX_SEQLENS_K=metadata.max_seqlens_k, IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen, BLOCK_DMODEL=padded_d_model, BIAS_TYPE=0 if metadata.bias is None else 1, - USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p - > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, BATCH_SIZE=q.shape[0]) + USE_ALIBI=False if metadata.alibi_slopes is None else True, + ENABLE_DROPOUT=metadata.dropout_p > 0.0, + RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax) ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid @@ -1031,42 +992,48 @@ def backward(ctx, do, _): return dq, dk, dv, None, None - attention = _attention.apply - -def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): +def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout): torch.manual_seed(20) # Initialize q, k, v - q = torch.randn((Z, HQ, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + if layout == 'bhsd': + q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) + k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) + elif layout == 'bshd': + q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) + k_tensor_shape = (Z, N_CTX_K, HQ, D_HEAD) + else: + assert False, 'Got unsupported tensor layout' + q = torch.randn(q_tensor_shape, dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=True) sm_scale = D_HEAD**-0.5 input_metadata = MetaData(sm_scale=sm_scale) input_metadata.max_seqlens_q = N_CTX_Q input_metadata.max_seqlens_k = N_CTX_K + input_metadata.layout = layout return q, k, v, input_metadata - -def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): +def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False): torch.manual_seed(20) # Random sequence lengths. 
Using N_CTX as kind of max of sum of individual seqs - max_seqlens_q = N_CTX_Q // Z - max_seqlens_k = N_CTX_K // Z - seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32) - seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32) - max_seqlens_q = torch.max(seqlens_q).item() - max_seqlens_k = torch.max(seqlens_k).item() + if not equal_seqlens: + max_seqlens_q = N_CTX_Q // Z + max_seqlens_k = N_CTX_K // Z + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32) + else: + seqlens_q = torch.full((Z,), N_CTX_Q // Z) + seqlens_k = torch.full((Z,), N_CTX_K // Z) # Calculate cumulative sequence lengths cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0, dtype=torch.int32)]) cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0, dtype=torch.int32)]) cu_seqlens_q = cu_seqlens_q.to(device="cuda") cu_seqlens_k = cu_seqlens_k.to(device="cuda") - # -1 because the last entry of cu_seqlens_q specifies the end of the last seq - # num_ctxs = len(cu_seqlens_q) - 1 # Initialize q, k, v with variable lengths total_q = cu_seqlens_q[-1].item() @@ -1079,7 +1046,6 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) return q, k, v, input_metadata - @pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ (4, 48, 24, 1024, 1024, 64), (1, 24, 6, 8192, 8192, 64), @@ -1102,7 +1068,9 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): @pytest.mark.parametrize('use_alibi', [True, False]) def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=torch.float16): torch.manual_seed(20) - q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) + # TODO: Adapt test for bshd + layout = 'bhsd' + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) if causal: input_metadata.need_causal() @@ -1114,9 +1082,6 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to else: alibi_slopes = None - if TORCH_HAS_FP8E5: - q = q.to(torch_dtype) - k = k.to(torch_dtype) o = torch.empty_like(q) # triton implementation @@ -1147,7 +1112,6 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to # compare torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) - @pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ (4, 48, 1024, 1024, 64), (4, 24, 8192, 8192, 64), @@ -1185,9 +1149,6 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor q = torch.randn((Z, H, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() k = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() v = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - if TORCH_HAS_FP8E5: - q = q.to(torch_dtype) - k = k.to(torch_dtype) o = torch.empty_like(q) # triton implementation @@ -1218,9 +1179,8 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor (4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)]) @pytest.mark.parametrize('causal', [True, False]) def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): - pytest.skip() - q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, D_HEAD, dtype) + 
q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) tri_out = torch.empty_like(q) ref_out = torch.empty_like(q) @@ -1233,7 +1193,6 @@ def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): attention(q, k, v, tri_out, input_metadata) torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) - @pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), (4, 48, 12, 256, 64), (4, 48, 4, 512, 64), (4, 48, 2, 1024, 64), (8, 48, 6, 4096, 64), (4, 48, 8, 16384, 64), (4, 64, 16, 128, 128), (4, 64, 4, 4096, 128), @@ -1261,7 +1220,6 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16 attention(q, k, v, tri_out, input_metadata) torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) - @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ (4, 48, 1024, 64), (4, 48, 2048, 64), @@ -1277,7 +1235,6 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16 @pytest.mark.parametrize('use_alibi', [False, True]) def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, dtype=torch.float16): - pytest.skip() torch.manual_seed(20) if qseqlen_not_equal_kseqlen is not None: seqlen_q = qseqlen_not_equal_kseqlen @@ -1363,66 +1320,60 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sd torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL) torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL) - def nonvarlen_benchmark_configs(): - configs = [ - (16, 16, 16, 1024, 1024), - (8, 16, 16, 2048, 2048), - (4, 16, 16, 4096, 4096), - (2, 16, 16, 8192, 8192), - (1, 16, 16, 16384, 16384), - (2, 48, 48, 1024, 1024), - (2, 48, 48, 2048, 1024), - (2, 48, 48, 4096, 8192), - (2, 48, 48, 8192, 4096), - (2, 48, 48, 16384, 8192), - (8, 16, 16, 1989, 15344), - (4, 16, 16, 4097, 163), - (2, 16, 16, 8122, 2159), - (1, 16, 16, 16281, 7), - (2, 48, 48, 1021, 1020), - (2, 48, 48, 2001, 2048), - (2, 48, 48, 3996, 9639), - (2, 48, 48, 8181, 1021), - ] + configs=[(16, 16, 16, 1024, 1024), + (8, 16, 16, 2048, 2048), + (4, 16, 16, 4096, 4096), + (2, 16, 16, 8192, 8192), + (1, 16, 16, 16384, 16384), + (2, 48, 48, 1024, 1024), + (2, 48, 48, 2048, 1024), + (2, 48, 48, 4096, 8192), + (2, 48, 48, 8192, 4096), + (2, 48, 48, 16384, 8192), + (8, 16, 16, 1989, 15344), + (4, 16, 16, 4097, 163), + (2, 16, 16, 8122, 2159), + (1, 16, 16, 16281, 7), + (2, 48, 48, 1021, 1020), + (2, 48, 48, 2001, 2048), + (2, 48, 48, 3996, 9639), + (2, 48, 48, 8181, 1021), + ] return configs - def varlen_benchmark_configs(): - configs = [ - (2, 16, 4, 1024, 1024), - (8, 16, 2, 2048, 2048), - (4, 16, 8, 4096, 4096), - (2, 16, 4, 8192, 8192), - (2, 16, 8, 16384, 16384), - (2, 48, 12, 1024, 1024), - (2, 48, 24, 2048, 2048), - (2, 48, 8, 4096, 4096), - (2, 48, 4, 8192, 8192), - (2, 48, 2, 16384, 16384), - (2, 64, 32, 1024, 1024), - (4, 64, 16, 2048, 2048), - (4, 64, 8, 4096, 4096), - (4, 64, 32, 8192, 8192), - (4, 128, 16, 16384, 16384), - ] + configs=[(2, 16, 4, 1024, 1024), + (8, 16, 2, 2048, 2048), + (4, 16, 8, 4096, 4096), + (2, 16, 4, 8192, 8192), + (2, 16, 8, 16384, 16384), + (2, 48, 12, 1024, 1024), + (2, 48, 24, 2048, 2048), + (2, 48, 8, 4096, 4096), + (2, 48, 4, 8192, 8192), + (2, 48, 2, 16384, 16384), + (2, 64, 32, 1024, 1024), + (4, 64, 16, 2048, 2048), + (4, 64, 8, 4096, 4096), + (4, 64, 32, 8192, 8192), + (4, 128, 16, 16384, 16384), + ] return configs +def run_benchmark(custom, args): -def run_benchmark(custom): - - args = parse_args() 
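    # Editor's note: run_benchmark now receives the parsed args from main()
    # instead of calling parse_args() itself; hk and sk below fall back to
    # hq and sq when the -hk / -sk flags are not given.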
dtype = arg_to_torch_dtype[args.dtype] - # hk = args.hq if not args.hk else args.hk - # sk = args.sq if not args.sk else args.sk + hk = args.hq if not args.hk else args.hk + sk = args.sq if not args.sk else args.sk head_size = 128 if not args.d else args.d mode = 'fwd' - x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] + x_names=['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] causal = args.causal - varlen = args.varlen + varlen = args.layout == 'thd' configs = [] if custom: - x_vals_list = [(args.b, args.hq, args.hk, args.sq, args.sk)] + x_vals_list=[(args.b, args.hq, hk, args.sq, sk)] else: if varlen: x_vals_list = varlen_benchmark_configs() @@ -1430,14 +1381,26 @@ def run_benchmark(custom): x_vals_list = nonvarlen_benchmark_configs() print_time = args.return_time line_names = 'Time (ms)' if print_time else 'TFLOPS' - configs.append( - triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'], - line_names=[line_names], styles=[('red', '-')], ylabel='ms', - plot_name=f'fused-attention-{mode}-d{head_size}{"-varlen" if varlen else ""}', - args={'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode})) + configs.append(triton.testing.Benchmark( + x_names=x_names, + x_vals=x_vals_list, + line_arg='provider', + line_vals=['triton'], + line_names=[line_names], + styles=[('red', '-')], + ylabel='ms', + plot_name=f'fused-attention-{mode}-d{head_size}-layout{args.layout}', + args={ + 'D_HEAD': head_size, + 'dtype': dtype, + 'causal': causal, + 'mode': mode}) + ) @triton.testing.perf_report(configs) - def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda"): + def bench_flash_attention( + BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda" + ): assert mode in ["fwd", "bwd"] warmup = 25 rep = 100 @@ -1455,14 +1418,14 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal flops_per_matmul = 0 if varlen: - q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) - for i in range(0, input_metadata.num_contexts): - seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i] - seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i] + q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, args.equal_seqlens) + for i in range (0, input_metadata.num_contexts): + seqlen_q = input_metadata.cu_seqlens_q[i+1] - input_metadata.cu_seqlens_q[i] + seqlen_k = input_metadata.cu_seqlens_k[i+1] - input_metadata.cu_seqlens_k[i] # x2 for 2 GEMMs flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2 else: - q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) + q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, args.layout) flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD if causal: input_metadata.need_causal() @@ -1486,6 +1449,13 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal bench_flash_attention.run(save_path=".", print_data=True) +def supported_layouts(): + layouts = \ + 'bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]' \ + 'bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]' \ + 'thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]' \ + 'This layout is sometimes called "varlen" or "grouped" layout.' 
+ return layouts def parse_args(): parser = argparse.ArgumentParser( @@ -1497,20 +1467,27 @@ def parse_args(): parser.add_argument("-hk", type=int, default=0) parser.add_argument("-sq", type=int, default=0) parser.add_argument("-sk", type=int, default=0) + parser.add_argument("-equal_seqlens", action='store_true', default=False, + help='If specified, each context within the thd layout' \ + ' has same seqlen as sq and sk') parser.add_argument("-d", type=int, default=0) parser.add_argument("-causal", action='store_true', default=False) - parser.add_argument("-varlen", action='store_true', default=False) parser.add_argument("-dtype", default='fp16') parser.add_argument("-return_time", action='store_true', default=False) + parser.add_argument("-layout", type=str, default='bhsd', help=supported_layouts()) return parser.parse_args() - -arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32} - +arg_to_torch_dtype = { + 'fp16': torch.float16, + 'bf16': torch.bfloat16, + 'fp32': torch.float32 +} def main(): args = parse_args() custom_config = False + assert args.layout == 'thd' or not args.equal_seqlens, \ + "Equal sequence lengths arg must be used with the thd layout." if args.b or args.hq or args.hk or args.sq or args.sk or args.d: custom_config = True assert args.b and args.hq and args.sq and args.d, \ @@ -1521,8 +1498,7 @@ def main(): assert args.dtype in arg_to_torch_dtype, \ "Only fp16, bf16 and f32 types currently supported." - run_benchmark(custom_config) - + run_benchmark(custom_config, args) if __name__ == '__main__': sys.exit(main())
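Editor's appendix (illustrative sketch only, not part of the patch): one way to drive the new layout support from Python, mirroring input_helper and varlen_input_helper above and assuming MetaData and attention from this file are in scope.

import torch

# 'bshd' layout: q, k, v are [batch, seqlen_q/k, num_heads, head_size].
batch, heads, seqlen, head_dim = 2, 16, 1024, 128
q = torch.randn((batch, seqlen, heads, head_dim), dtype=torch.float16, device="cuda")
k = torch.randn((batch, seqlen, heads, head_dim), dtype=torch.float16, device="cuda")
v = torch.randn((batch, seqlen, heads, head_dim), dtype=torch.float16, device="cuda")
meta = MetaData(sm_scale=head_dim**-0.5)
meta.max_seqlens_q = seqlen
meta.max_seqlens_k = seqlen
meta.layout = 'bshd'  # strides are derived per layout by get_strides_from_layout
o = torch.empty_like(q)
attention(q, k, v, o, meta)

# 'thd' (varlen) layout: tokens of all sequences are packed along dim 0 and
# delimited by cumulative sequence lengths that start at 0.
seqlens = torch.tensor([128, 512, 384], dtype=torch.int32)
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), seqlens.cumsum(0, dtype=torch.int32)]).to("cuda")
total_tokens = int(cu_seqlens[-1])
q = torch.randn((total_tokens, heads, head_dim), dtype=torch.float16, device="cuda")
k = torch.randn((total_tokens, heads, head_dim), dtype=torch.float16, device="cuda")
v = torch.randn((total_tokens, heads, head_dim), dtype=torch.float16, device="cuda")
meta = MetaData(sm_scale=head_dim**-0.5)
meta.set_varlen_params(cu_seqlens, cu_seqlens)  # sets varlen=True and layout='thd'
o = torch.empty_like(q)
attention(q, k, v, o, meta)

From the command line the same choice is exposed through the new -layout flag (bhsd, bshd or thd); -equal_seqlens is only valid together with the thd layout.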