diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py
index 97316ae30dc9..c3d3db2781f8 100644
--- a/python/perf-kernels/flash-attention.py
+++ b/python/perf-kernels/flash-attention.py
@@ -34,6 +34,14 @@ if TORCH_HAS_FP8E5:
     torch_dtype:tl.constexpr = torch.float8_e5m2fnuz

+# Helper function, but not always usable due to compiler bugs (esp. when used with tl.trans)
+@triton.jit
+def dot(BLOCK_M : tl.constexpr, QDIM : tl.constexpr, KDIM : tl.constexpr, q, k):
+    if BLOCK_M == 1:
+        return tl.sum(tl.view(q, [QDIM]) * tl.view(k, [KDIM]))
+    else:
+        return tl.dot(q, k)
+
 class MetaData():
     cu_seqlens_q = None
     cu_seqlens_k = None
@@ -452,7 +460,7 @@ def attn_fwd(
         tl.debug_barrier()
     # Remaining blocks, if any, are full / not masked.
-    if (masked_blocks > 0): 
+    if (masked_blocks > 0):
         if IS_CAUSAL:
             offs_n_causal = offs_n + (seqlen_q - seqlen_k)
         else:
@@ -462,7 +470,7 @@ def attn_fwd(
         if bias_ptr is not None:
             bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks*BLOCK_N))
         if RETURN_ENCODED_SOFTMAX:
-            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, 
+            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
                                                    (0, n_full_blocks))
         acc, l_i, m_i = _attn_fwd_inner(
             acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
@@ -515,25 +523,64 @@ def attn_fwd(
         block_shape=(BLOCK_M, BLOCK_DMODEL),
         order=(1, 0)
     )
-    # Need boundary check on this to make sure the padding from the 
+    # Need boundary check on this to make sure the padding from the
     # Q and KV tensors in both dims are not part of what we store back.
-    # TODO: Do the boundary check optionally. 
+    # TODO: Do the boundary check optionally.
     tl.store(O_block_ptr, acc, boundary_check=(0,1))

 @triton.jit
-def _attn_bwd_preprocess(O, DO, #
-                         NewDO, Delta, #
-                         BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, #
-                         ):
-    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
-    off_n = tl.arange(0, D_HEAD)
+def _attn_bwd_preprocess(
+    Out, DO,
+    Delta,
+    stride_oz, stride_oh, stride_om, stride_on,
+    stride_doz, stride_doh, stride_dom, stride_don,
+    seqlen_q,
+    head_dim,
+    BLOCK_M: tl.constexpr,
+    D_HEAD: tl.constexpr,
+):
+    # off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
+    # off_n = tl.arange(0, D_HEAD)
+    off_m = tl.program_id(0) * BLOCK_M
+    off_h = tl.program_id(1) # head index
+    off_z = tl.program_id(2) # batch index
+    num_h = tl.num_programs(1)
+    o_offset = off_h * stride_oh + off_z * stride_oz
+    O_block_ptr = tl.make_block_ptr(
+        base=Out + o_offset,
+        shape=(seqlen_q, head_dim),
+        strides=(stride_om, stride_on),
+        offsets=(off_m, 0),
+        block_shape=(BLOCK_M, D_HEAD),
+        order=(1, 0)
+    )
+    do_offset = off_h * stride_doh + off_z * stride_doz
+    DO_block_ptr = tl.make_block_ptr(
+        base=DO + do_offset,
+        shape=(seqlen_q, head_dim),
+        strides=(stride_dom, stride_don),
+        offsets=(off_m, 0),
+        block_shape=(BLOCK_M, D_HEAD),
+        order=(1, 0)
+    )
     # load
-    o = tl.load(O + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
-    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
+    # o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
+    # do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
+    o = tl.load(O_block_ptr, boundary_check=(0,1), padding_option="zero").to(tl.float32)
+    do = tl.load(DO_block_ptr, boundary_check=(0,1), padding_option="zero").to(tl.float32)
+    # compute
     delta = tl.sum(o * do, axis=1)
-    # write-back
-    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
-    tl.store(Delta + off_m, delta)
+    # write-back, shape (q.shape[0] * q.shape[1], q.shape[2])
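+    # Delta is viewed as (q.shape[0] * q.shape[1], q.shape[2]): one row of
+    # seqlen_q scalars per (batch, head) pair; off_zh below selects this
+    # program's row and off_m its BLOCK_M-wide slice within that row.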
+    off_zh = off_z * num_h + off_h * 1
+    # Check for OOB accesses
+    delta_ptrs = Delta + off_zh * seqlen_q + off_m + tl.arange(0, BLOCK_M)
+    overflow = off_m + BLOCK_M - seqlen_q
+    if overflow > 0:
+        boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow, dtype=tl.int32)
+        mask = boundary > tl.arange(0, BLOCK_M)
+        tl.store(delta_ptrs, delta, mask=mask)
+    else:
+        tl.store(delta_ptrs, delta)

 @triton.jit
 def _bwd_kernel_dk_dv(
@@ -544,132 +591,198 @@
     stride_qz, stride_qh, stride_qm, stride_qk,
     stride_kz, stride_kh, stride_kn, stride_kk,
     stride_vz, stride_vh, stride_vk, stride_vn,
+    stride_oz, stride_oh, stride_om, stride_ok,
     seqlen_q, seqlen_k,
+    head_dim,
     dropout_p, philox_seed, philox_offset_base,
     BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr, CAUSAL: tl.constexpr,
-    ENABLE_DROPOUT: tl.constexpr
+    ENABLE_DROPOUT: tl.constexpr,
+    PADDED_HEAD: tl.constexpr,
 ):
     start_m = tl.program_id(0) * BLOCK_N
-    off_hz = tl.program_id(1)
-    # Q is consumed depending on block ID. Every block uses
-    # previous block offset by BLOCK_M x D_HEAD.
-    qvk_offset = off_hz * stride_qh
+    off_h = tl.program_id(1) # head index
+    off_z = tl.program_id(2) # batch index
+    num_h = tl.num_programs(1)
+    num_z = tl.num_programs(2)
     # initialize offsets
-    offs_m = start_m + tl.arange(0, BLOCK_M)
-    offs_n = tl.arange(0, BLOCK_N)
-    offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_m = start_m + tl.arange(0, BLOCK_N)
+    offs_n = tl.arange(0, BLOCK_M)
     # Initialize pointers to Q, K, V
-    q_offset = off_hz * stride_qh
+    # Q is consumed depending on block ID. Every block uses
+    # previous block offset by BLOCK_M x D_HEAD.
+    q_offset = off_h * stride_qh + off_z * stride_qz
     Q_block_ptr = tl.make_block_ptr(
         base=Q + q_offset,
-        shape=(seqlen_q, BLOCK_DMODEL),
+        shape=(seqlen_q, head_dim),
         strides=(stride_qm, stride_qk),
         offsets=(0, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
         order=(1, 0)
     )
-    k_offset = off_hz * stride_kh
-    K_block_ptr = tl.make_block_ptr(
+    k_offset = off_h * stride_kh + off_z * stride_kz
+    KT_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
-        shape=(BLOCK_DMODEL, seqlen_k),
+        shape=(head_dim, seqlen_k),
         strides=(stride_kk, stride_kn),
         offsets=(0, start_m),
         block_shape=(BLOCK_DMODEL, BLOCK_N),
         order=(0, 1)
     )
-    v_offset = off_hz * stride_vh
+    v_offset = off_h * stride_vh + off_z * stride_vz
     VT_block_ptr = tl.make_block_ptr(
         base=V + v_offset,
-        shape=(BLOCK_DMODEL, seqlen_k),
+        shape=(head_dim, seqlen_k),
         strides=(stride_vn, stride_vk),
         offsets=(0, start_m),
         block_shape=(BLOCK_DMODEL, BLOCK_N),
         order=(0, 1)
     )
-    do_offset = q_offset
+    do_offset = off_h * stride_oh + off_z * stride_oz
     DO_block_ptr = tl.make_block_ptr(
         base=DO + do_offset,
-        shape=(seqlen_q, BLOCK_DMODEL),
-        strides=(stride_qm, stride_qk),
+        shape=(seqlen_q, head_dim),
+        strides=(stride_om, stride_ok),
         offsets=(0, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
         order=(1, 0)
     )
+    off_zh = off_z * num_h + off_h * 1
     # pointer to row-wise quantities in value-like data
-    D_ptrs = D + off_hz * seqlen_q
-    l_ptrs = L + off_hz * seqlen_q
-    qk_scale = sm_scale * 1.44269504
+    D_ptrs = D + off_zh * seqlen_q
+    l_ptrs = L + off_zh * seqlen_q
+    qk_scale = sm_scale * 1.44269504089
     # load k and v: they will stay in SRAM throughout
-    k = tl.load(K_block_ptr)
-    k = (k * qk_scale).to(K_block_ptr.type.element_ty)
-    vt = tl.load(VT_block_ptr)
+    # (BLOCK_DMODEL, BLOCK_N)
+    if PADDED_HEAD:
+        kt = tl.load(KT_block_ptr, boundary_check=(1,0), padding_option="zero")
+    else:
+        kt = tl.load(KT_block_ptr, boundary_check=(1,), padding_option="zero")
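+    # For this (head_dim, seqlen_k) view, boundary_check axis 0 is head_dim and
+    # axis 1 is seqlen_k; the unpadded path only needs to guard seqlen_k.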
padding_option="zero") + kt = (kt * qk_scale).to(KT_block_ptr.type.element_ty) + # (BLOCK_DMODEL, BLOCK_N) + if PADDED_HEAD: + vt = tl.load(VT_block_ptr, boundary_check=(1,0), padding_option="zero") + else: + vt = tl.load(VT_block_ptr, boundary_check=(1,), padding_option="zero") dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + # This lower loop bound is because of the causal mask. We create a lower triangular # result. The upper triangular is -inf (becomes 0 when we do e^x). As such, it can # be ignored in the GEMM. - lo = start_m if CAUSAL else 0 + lo = (start_m // BLOCK_M) * BLOCK_M if CAUSAL else 0 hi = seqlen_q Q_block_ptr = tl.advance(Q_block_ptr, (lo, 0)) DO_block_ptr = tl.advance(DO_block_ptr, (lo, 0)) - batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + batch_philox_offset = philox_offset_base + off_zh * seqlen_q * seqlen_k + ''' + K1 K2 (d)V dO + Q1 qk11 qk12 (d)v1 dO1 + Q2 qk21 qk22 (d)v2 dO2 + + QK: (seqlen_q, seqlen_k) + dO: (seqlen_q, hdim) + dV: (seqlen_k, hdim) + + dV = (QK)^T dO + + dV1 = qk11 dO1 + qk21 dO2 = q1 k1 dO1 + q2 k1 dO2 + dV2 = qk12 dO1 + qk22 dO2 = q1 k2 dO1 + q2 k2 dO2 + ~~~~~ = 0 + start_m: select k and dV + start_n: select q and dO + ''' # loop over q (seqlen_q, dhead), do (seqlen_q, d_head) for start_n in range(lo, hi, BLOCK_M): - offs_m_curr = offs_n[:, None] + start_n + offs_m_curr = offs_n[:, None] + start_n # (BLOCK_M, 1) # -- load q, do -- - q = tl.load(Q_block_ptr) - do = tl.load(DO_block_ptr) + # TODO: It is more optimal to do OOB check only in the last iter. + # (BLOCK_M, BLOCK_DMODEL), offs = (BLOCK_M * iter, 0) = (start_n, 0) + # do is (BLOCK_M, BLOCK_DMODEL) + if PADDED_HEAD: + q = tl.load(Q_block_ptr, boundary_check=(0,1), padding_option="zero") + do = tl.load(DO_block_ptr, boundary_check=(0,1), padding_option="zero") + else: + q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero") + do = tl.load(DO_block_ptr, boundary_check=(0,), padding_option="zero") # -- compute qk ---- - qk = tl.dot(q, k) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # TODO: These two checks can be optimized to occur on the last iter. 
+        overflow_size = start_n + BLOCK_M - seqlen_q
+        if overflow_size > 0:
+            boundary_n = tl.full((BLOCK_N, ), seqlen_q, dtype=tl.int32)
+            mask = offs_m_curr < boundary_n[None, :]
+            qk = tl.where(mask, qk, float("-inf"))
         if CAUSAL:
             qk = tl.where(offs_m_curr >= offs_m[None, :], qk, float("-inf"))
-        l_i = tl.load(l_ptrs + offs_m_curr)
-        p = tl.math.exp2(qk - l_i)
+        # q.offs = (start_n, 0), k.offs = (0, start_m)
+        qk += dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, q, kt) # (BLOCK_M, BLOCK_N)
+        # Check for OOB accesses on D and LSE
+        boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32)
+        d_lse_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
+        d_lse_padding = tl.full((BLOCK_M, ), 0, dtype=tl.float32)
+        Di = tl.load(D_ptrs + offs_m_curr,
+                     mask=d_lse_ptrs_mask[:, None],
+                     other=d_lse_padding[:, None])
+        l_i = tl.load(l_ptrs + offs_m_curr,
+                      mask=d_lse_ptrs_mask[:,None],
+                      other=d_lse_padding[:, None])
+        p = tl.math.exp2(qk - l_i) # (BLOCK_M, BLOCK_N)
         # -- compute dv ----
         if ENABLE_DROPOUT:
             philox_offset = batch_philox_offset + start_n * seqlen_k + start_m
             keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, seqlen_k)
             # CAVEAT: do NOT update p, ds needs the original p
-            dv += tl.dot(tl.where(tl.trans(keep), tl.trans(p) / (1 - dropout_p), 0.0).to(Q.dtype.element_ty), do)
+            if BLOCK_M == 1:
+                dv += tl.where(keep, p / (1 - dropout_p), 0.0).to(Q.dtype.element_ty) * do
+            else:
+                dv += tl.dot(tl.trans(tl.where(keep, p / (1 - dropout_p), 0.0)).to(Q.dtype.element_ty), do)
         else:
-            dv += tl.dot(tl.trans(p.to(do.dtype)), do)
-        # compute dp = dot(v, do)
-        Di = tl.load(D_ptrs + offs_m_curr)#NAN WHY
-        dp = tl.zeros([BLOCK_M, BLOCK_M], dtype=tl.float32)
+            if BLOCK_M == 1:
+                dv += p.to(Q.dtype.element_ty) * do
+            else:
+                # dv += tl.dot(tl.trans(p.to(do.dtype)), do)
+                dv += tl.dot(tl.trans(p).to(do.dtype), do)
+        dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        # compute dp = dot(do, vt)
+        # dp += dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, do, vt)
+        # do.shape = (BLOCK_M, BLOCK_DMODEL) vt.shape = (BLOCK_DMODEL, BLOCK_N)
         dp += tl.dot(do, vt)
         if ENABLE_DROPOUT:
             dp = tl.where(keep, dp / (1 - dropout_p), 0)
         # compute ds = p * (dp - delta[:, None])
-        ds = p * (dp - Di)
+        ds = p * (dp - Di) # (BLOCK_M, BLOCK_N)
         # compute dk
-        dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
+        if BLOCK_M == 1:
+            dk += ds.to(Q.dtype.element_ty) * q
+        else:
+            # ds.shape = (BLOCK_M, BLOCK_N), q.shape = (BLOCK_M, BLOCK_DMODEL)
+            dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q) # (BLOCK_N, BLOCK_DMODEL)
         # update pointers
         Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))
         DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0))
     # initialize pointers to output
     DK_block_ptr = tl.make_block_ptr(
         base=DK + k_offset,
-        shape=(seqlen_k, BLOCK_DMODEL),
+        shape=(seqlen_k, head_dim),
         strides=(stride_kn, stride_kk),
         offsets=(start_m, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        block_shape=(BLOCK_N, BLOCK_DMODEL),
         order=(1, 0)
     )
     DV_block_ptr = tl.make_block_ptr(
         base=DV + v_offset,
-        shape=(seqlen_k, BLOCK_DMODEL),
+        shape=(seqlen_k, head_dim),
         strides=(stride_vk, stride_vn),
         offsets=(start_m, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        block_shape=(BLOCK_N, BLOCK_DMODEL),
         order=(1, 0)
     )
-    tl.store(DK_block_ptr, (dk * sm_scale).to(DK.dtype.element_ty))
-    tl.store(DV_block_ptr, dv.to(DK.type.element_ty))
-
+    tl.store(DK_block_ptr, (dk * sm_scale).to(DK.type.element_ty), boundary_check=(0,1))
+    tl.store(DV_block_ptr, dv.to(DV.type.element_ty), boundary_check=(0,1))

 @triton.jit
 def _bwd_kernel_dq(
@@ -680,82 +793,118 @@
     stride_qz, stride_qh, stride_qm, stride_qk,
     stride_kz, stride_kh, stride_kn, stride_kk,
     stride_vz, stride_vh, stride_vk, stride_vn,
-    seqlen_q, seqlen_k, dropout_p, philox_seed, philox_offset_base,
+    stride_oz, stride_oh, stride_om, stride_ok,
+    seqlen_q, seqlen_k, head_dim, dropout_p, philox_seed, philox_offset_base,
     BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr, CAUSAL: tl.constexpr,
     ENABLE_DROPOUT: tl.constexpr,
+    PADDED_HEAD: tl.constexpr,
 ):
-    start_m = tl.program_id(0) * BLOCK_N
-    off_hz = tl.program_id(1)
-    qvk_offset = off_hz * stride_qh
+    start_m = tl.program_id(0) * BLOCK_M
+    off_h = tl.program_id(1) # head index
+    off_z = tl.program_id(2) # batch index
+    num_h = tl.num_programs(1)
+    num_z = tl.num_programs(2)
     # initialize offsets
     offs_m = start_m + tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N)
-    offs_d = tl.arange(0, BLOCK_DMODEL)
     # Initialize pointers to Q, K, V
-    q_offset = off_hz * stride_qh
+    q_offset = off_h * stride_qh + off_z * stride_qz
     Q_block_ptr = tl.make_block_ptr(
         base=Q + q_offset,
-        shape=(seqlen_q, BLOCK_DMODEL),
+        shape=(seqlen_q, head_dim),
         strides=(stride_qm, stride_qk),
         offsets=(start_m, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
         order=(1, 0)
     )
-    k_offset = off_hz * stride_kh
+    k_offset = off_h * stride_kh + off_z * stride_kz
     K_block_ptr = tl.make_block_ptr(
         base=K + k_offset,
-        shape=(BLOCK_DMODEL, seqlen_k),
+        shape=(head_dim, seqlen_k),
         strides=(stride_kk, stride_kn),
         offsets=(0, 0),
         block_shape=(BLOCK_DMODEL, BLOCK_N),
         order=(0, 1)
     )
-    v_offset = off_hz * stride_vh
+    v_offset = off_h * stride_vh + off_z * stride_vz
     V_block_ptr = tl.make_block_ptr(
         base=V + v_offset,
-        shape=(BLOCK_DMODEL, seqlen_k),
+        shape=(head_dim, seqlen_k),
         strides=(stride_vn, stride_vk),
         offsets=(0, 0),
         block_shape=(BLOCK_DMODEL, BLOCK_N),
         order=(0, 1)
     )
+    do_offset = off_h * stride_oh + off_z * stride_oz
     DO_block_ptr = tl.make_block_ptr(
-        base=DO + q_offset,
-        shape=(seqlen_q, BLOCK_DMODEL),
-        strides=(stride_qm, stride_qk),
+        base=DO + do_offset,
+        shape=(seqlen_q, head_dim),
+        strides=(stride_om, stride_ok),
         offsets=(start_m, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
         order=(1, 0)
     )
+    off_zh = off_z * num_h + off_h * 1
     # pointer to row-wise quantities in value-like data
-    D_ptrs = D + off_hz * seqlen_q
-    l_ptrs = L + off_hz * seqlen_q
-    qk_scale = sm_scale * 1.44269504
+    D_ptrs = D + off_zh * seqlen_q
+    l_ptrs = L + off_zh * seqlen_q
+    qk_scale = sm_scale * 1.44269504089
     # load q and do: they will stay in SRAM throughout
-    q = tl.load(Q_block_ptr)
+    if PADDED_HEAD:
+        q = tl.load(Q_block_ptr, boundary_check=(0,1), padding_option="zero")
+    else:
+        q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")
     q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
-    do = tl.load(DO_block_ptr)
-    Di = tl.load(D_ptrs + offs_m)
-    l_i = tl.load(l_ptrs + offs_m)
+    if PADDED_HEAD:
+        do = tl.load(DO_block_ptr, boundary_check=(0,1), padding_option="zero")
+    else:
+        do = tl.load(DO_block_ptr, boundary_check=(0,), padding_option="zero")
+    # Check for OOB accesses on D and LSE
+    overflow_size_q = start_m + BLOCK_M - seqlen_q
+    boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size_q, dtype=tl.int32)
+    d_lse_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
+    d_lse_padding = tl.full((BLOCK_M, ), 0, dtype=tl.float32)
+    Di = tl.load(D_ptrs + offs_m, mask=d_lse_ptrs_mask, other=d_lse_padding)
+    l_i = tl.load(l_ptrs + offs_m, mask=d_lse_ptrs_mask, other=d_lse_padding)
     dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
     # loop over k, v
     lo = 0
     hi = min(start_m + BLOCK_M, seqlen_k) if CAUSAL else seqlen_k
-    batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k
+    batch_philox_offset = philox_offset_base + off_zh * seqlen_q * seqlen_k
+    '''
+           K1   K2     (d)V    dO
+    Q1    qk11 qk12    (d)v1   dO1
+    Q2    qk21 qk22    (d)v2   dO2
+
+    QK: (seqlen_q, seqlen_k)
+    dO: (seqlen_q, hdim)
+    dV: (seqlen_k, hdim)
+    '''
     for start_n in range(lo, hi, BLOCK_N):
         # -- load k, v --
-        k = tl.load(K_block_ptr)
-        v = tl.load(V_block_ptr)
+        # shape = (BLOCK_DMODEL, BLOCK_N), offs = (0, BLOCK_N * iter) = (0, start_n)
+        if PADDED_HEAD:
+            kt = tl.load(K_block_ptr, boundary_check=(1,0), padding_option="zero")
+            vt = tl.load(V_block_ptr, boundary_check=(1,0), padding_option="zero")
+        else:
+            kt = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
+            vt = tl.load(V_block_ptr, boundary_check=(1,), padding_option="zero")
         # -- compute qk ----
-        qk = tl.dot(q, k)
+        # q.offs = (start_m, 0), k.offs = (0, start_n)
+        qk = dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, q, kt)
         if CAUSAL:
             qk = tl.where(offs_m[:, None] >= (offs_n[None, :] + start_n), qk, float("-inf"))
+        overflow_size_k = start_n + BLOCK_N - seqlen_k
+        boundary_n = tl.full((BLOCK_M, ), seqlen_k, dtype=tl.int32)
+        size_n = start_n + tl.arange(0, BLOCK_N)
+        mask = size_n[None, :] < boundary_n[:, None]
+        qk = tl.where(mask, qk, float("-inf"))
         p = tl.math.exp2(qk - l_i[:, None])
         # compute dp = dot(v, do)
         dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        dp += tl.dot(do, v)
+        dp += dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, do, vt)
         if ENABLE_DROPOUT:
             philox_offset = batch_philox_offset + start_m * seqlen_k + start_n
             keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, seqlen_k)
@@ -764,20 +913,24 @@ def _bwd_kernel_dq(
         ds = p * (dp - Di[:, None])
         # compute dq. Unfortunately we cannot avoid transpose here as this loop
         # uses k both normal and transpose.
-        dq += tl.dot(ds.to(Q.dtype.element_ty), tl.trans(k))
+        if BLOCK_M == 1:
+            dq += tl.view(kt, [BLOCK_DMODEL]) * ds.to(Q.type.element_ty)
+        else:
+            # ds.shape = (BLOCK_M, BLOCK_N), kt.shape = (BLOCK_DMODEL, BLOCK_N)
+            dq += tl.dot(ds.to(Q.type.element_ty), tl.trans(kt)) # (BLOCK_M, BLOCK_DMODEL)
         # update pointers
         K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
         V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))
     # initialize pointers to output
     DQ_block_ptr = tl.make_block_ptr(
         base=DQ + q_offset,
-        shape=(seqlen_q, BLOCK_DMODEL),
+        shape=(seqlen_q, head_dim),
         strides=(stride_qm, stride_qk),
         offsets=(start_m, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
         order=(1, 0)
     )
-    tl.store(DQ_block_ptr, (dq * sm_scale).to(DQ_block_ptr.type.element_ty))
+    tl.store(DQ_block_ptr, (dq * sm_scale).to(DQ_block_ptr.type.element_ty), boundary_check=(0,1))

 empty = torch.empty(128, device="cuda")
@@ -843,7 +996,7 @@ def forward(ctx, q, k, v, o, metadata):
         philox_offset = 0x1D4B42
         if metadata.bias is not None:
-            bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), 
+            bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1),
                             metadata.bias.stride(2), metadata.bias.stride(3))
         else:
             bias_strides = (0,0,0,0)
@@ -893,19 +1046,30 @@ def backward(ctx, do, _):
         seqlen_q = q.shape[2]
         seqlen_k = k.shape[2]
         do = do.contiguous()
-        dq = torch.zeros_like(q)
+        dq = torch.zeros_like(q, dtype=torch.float32)
         dk = torch.empty_like(k)
         dv = torch.empty_like(v)
         BATCH, N_HEAD, N_CTX = q.shape[:3]
         delta = torch.empty_like(L)
         do_scaled = torch.empty_like(do)
-        _attn_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
-            o, do,
-            do_scaled, delta,
+        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+        padded_head = (Lk != ctx.BLOCK_DMODEL)
+        grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0])
+        _attn_bwd_preprocess[grid_preprocess](
+            o, do, delta,
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
+            do.stride(0), do.stride(1), do.stride(2), do.stride(3),
+            seqlen_q,
+            head_dim=Lk,
             BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
         dq = torch.zeros_like(q)
-        _bwd_kernel_dk_dv[(triton.cdiv(q.shape[2], BLOCK), ctx.grid[1])](
+        grid_dk_dv = lambda META: (
+            triton.cdiv(seqlen_k, META['BLOCK_N']),
+            q.shape[1],
+            q.shape[0],
+        )
+        _bwd_kernel_dk_dv[grid_dk_dv](
             q, k, v, ctx.sm_scale,
             o, do_scaled,
             dk, dv,
@@ -913,8 +1077,10 @@ def backward(ctx, do, _):
             q.stride(0), q.stride(1), q.stride(2), q.stride(3),
             k.stride(0), k.stride(1), k.stride(2), k.stride(3),
             v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
             seqlen_q=seqlen_q,
             seqlen_k=seqlen_k,
+            head_dim=Lk,
             dropout_p=ctx.dropout_p,
             philox_seed=ctx.philox_seed,
             philox_offset_base=ctx.philox_offset,
@@ -923,10 +1089,16 @@ def backward(ctx, do, _):
             BLOCK_N=BLOCK,
             CAUSAL=ctx.causal,
             ENABLE_DROPOUT=ctx.dropout_p > 0.0,
+            PADDED_HEAD=padded_head,
             num_warps=4,num_stages=1,
         )
         DQ_BLOCK_M = min(seqlen_q, BLOCK)
-        _bwd_kernel_dq[ctx.grid](
+        grid_dq = lambda META: (
+            triton.cdiv(seqlen_q, DQ_BLOCK_M),
+            q.shape[1],
+            q.shape[0],
+        )
+        _bwd_kernel_dq[grid_dq](
             q, k, v, ctx.sm_scale,
             o, do_scaled,
             dq,
@@ -934,8 +1106,10 @@ def backward(ctx, do, _):
             q.stride(0), q.stride(1), q.stride(2), q.stride(3),
             k.stride(0), k.stride(1), k.stride(2), k.stride(3),
             v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
             seqlen_q=seqlen_q,
             seqlen_k=seqlen_k,
+            head_dim=Lk,
             dropout_p=ctx.dropout_p,
             philox_seed=ctx.philox_seed,
             philox_offset_base=ctx.philox_offset,
@@ -944,6 +1118,7 @@ def backward(ctx, do, _):
             BLOCK_N=BLOCK,
             CAUSAL=ctx.causal,
             ENABLE_DROPOUT=ctx.dropout_p > 0.0,
+            PADDED_HEAD=padded_head,
             num_warps=4, waves_per_eu=1, num_stages=1,
         )
         #print(h.asm["ttgir"])
@@ -1056,7 +1231,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor
     scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale
     if causal:
-        mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), 
+        mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"),
                           diagonal=N_CTX_K-N_CTX_Q)
         scores[:, :, mask==0] = float("-inf")
     if use_bias:
@@ -1186,45 +1361,88 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16
     (4, 48, 2048, 64),
     (4, 48, 4096, 64),
     (1, 16, 8192, 64),
+    (1, 16, 128, 32),
 ])
-def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
+@pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None])
+@pytest.mark.parametrize('torch_sdpa_test', [False])
+def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, torch_sdpa_test, dtype=torch.bfloat16):
     torch.manual_seed(20)
     causal = True
-    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
-    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
-    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
+    dropout_p = 0
+    q = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
+    k = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
+    v = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
+    o = torch.empty_like(q)
+
+    if causal and ((N_CTX - 1) & N_CTX):
+        pytest.skip()
+    torch.manual_seed(20)
+    if qseqlen_not_equal_kseqlen is not None:
+        seqlen_q = qseqlen_not_equal_kseqlen
+    else:
+        seqlen_q = N_CTX
+    seqlen_k = N_CTX
+
+    sm_scale = D_HEAD ** -0.5
+    input_metadata = MetaData(sm_scale=sm_scale)
+    input_metadata.max_seqlens_q = seqlen_q
+    input_metadata.max_seqlens_k = seqlen_k
+
+    if causal:
+        input_metadata.need_causal()
-    sm_scale = 0.5
-    split_kernel = True
     dout = torch.randn_like(q)
     # reference implementation
-    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
-    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
-    if causal:
-        p[:, :, M == 0] = float("-inf")
-    p = torch.softmax(p.float(), dim=-1).half()
-    ref_out = torch.matmul(p, v)
-    ref_out.backward(dout)
-    ref_dv, v.grad = v.grad.clone(), None
-    ref_dk, k.grad = k.grad.clone(), None
-    ref_dq, q.grad = q.grad.clone(), None
+    if torch_sdpa_test:
+        ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v,
+                                                                                 dropout_p=dropout_p,
+                                                                                 is_causal=causal,
+                                                                                 scale=sm_scale,
+                                                                                 dropout_mask=None)
+        ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype))
+        ref_dv, v.grad = v.grad.clone(), None
+        ref_dk, k.grad = k.grad.clone(), None
+        ref_dq, q.grad = q.grad.clone(), None
+    else:
+        M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda"))
+        p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
+        if causal:
+            p[:, :, M == 0] = float("-inf")
+        p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype)
+        ref_out = torch.matmul(p, v)
+        ref_out.backward(dout)
+        ref_dv, v.grad = v.grad.clone(), None
+        ref_dk, k.grad = k.grad.clone(), None
+        ref_dq, q.grad = q.grad.clone(), None
+
     # triton implementation
-    tri_out, _ = attention(q, k, v, causal, None, sm_scale, 0, False, True)
-    tri_out.backward(dout)#dout)
+    tri_out, _ = attention(q, k, v, o, input_metadata)
+    tri_out.backward(dout)
     tri_dv, v.grad = v.grad.clone(), None
     tri_dk, k.grad = k.grad.clone(), None
     tri_dq, q.grad = q.grad.clone(), None
     # compare
     torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
-    if torch.version.hip is None:
-        torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=0)
     # The current block size for MI200 series is 64x64. This results in
     # larger differences in float results due to rounding.
+
+    if dtype == torch.bfloat16:
+        ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
+    if dtype == torch.float32:
+        ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
     else:
-        torch.testing.assert_close(ref_dv, tri_dv, atol=5e-2, rtol=0)
-    torch.testing.assert_close(ref_dk, tri_dk, atol=5e-2, rtol=1e-2)
-    torch.testing.assert_close(ref_dq, tri_dq, atol=5e-2, rtol=1e-2)
+        ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
+    RTOL = 0
+
+    torch.testing.assert_close(ref_dv, tri_dv, atol=ATOL, rtol=RTOL)
+    torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL)
+    torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL)

 def nonvarlen_benchmark_configs():
     configs=[(16, 16, 16, 1024, 1024),
@@ -1339,7 +1557,7 @@ def bench_flash_attention(
     o = torch.empty_like(q)
     fn = lambda: attention(q, k, v, o, input_metadata)
     if mode == 'bwd':
-        o = fn()
+        o, _ = fn()
         do = torch.randn_like(o)
         fn = lambda: o.backward(do, retain_graph=True)
     ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
@@ -1388,11 +1606,11 @@ def main():
            "If custom config is specified, please provide \
                all of batch, number of Q heads, Q sequence length \
                and head size."
-    
+
     assert args.dtype in arg_to_torch_dtype, \
         "Only fp16, bf16 and f32 types currently supported."

-    run_benchmark(custom_config) 
+    run_benchmark(custom_config)

 if __name__ == '__main__':
-    sys.exit(main())
\ No newline at end of file
+    sys.exit(main())
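For reference, the row-wise quantity the reworked `_attn_bwd_preprocess` stores is the softmax-backward term `D_i = sum_d O[i, d] * dO[i, d]`, which both backward kernels consume through `ds = p * (dp - Di)`. Below is a minimal PyTorch sketch of that value and of the flattened `(Z * H, seqlen_q)` layout the kernel indexes via `off_zh * seqlen_q + off_m`; the sketch and its names are illustrative, not part of the patch.

```python
import torch

def attn_bwd_preprocess_ref(o: torch.Tensor, do: torch.Tensor) -> torch.Tensor:
    # o, do: (Z, H, seqlen_q, head_dim) -> delta: (Z * H, seqlen_q),
    # matching the Delta + off_zh * seqlen_q + off_m indexing in the kernel.
    Z, H, seqlen_q, _ = o.shape
    delta = (o.float() * do.float()).sum(dim=-1)  # row-wise dot of O and dO
    return delta.reshape(Z * H, seqlen_q)

if __name__ == "__main__":
    o = torch.randn(2, 3, 5, 4)
    do = torch.randn_like(o)
    assert attn_bwd_preprocess_ref(o, do).shape == (6, 5)
```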
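The new `dot` helper exists because `tl.dot` needs genuinely 2-D tiles: when `BLOCK_M == 1` the q tile degenerates to a single row, so the helper flattens both operands with `tl.view` and reduces an elementwise product instead. A quick check of the equivalence it relies on, for the fully degenerate `(1, d) x (d, 1)` case that viewing both operands as 1-D implies; this is a hedged sketch, not part of the patch.

```python
import torch

q = torch.randn(1, 8)   # a single-row q tile (BLOCK_M == 1)
k = torch.randn(8, 1)   # a single-column k tile
# mirrors tl.sum(tl.view(q, [QDIM]) * tl.view(k, [KDIM])) in the helper
fallback = (q.reshape(-1) * k.reshape(-1)).sum()
torch.testing.assert_close(fallback, (q @ k).reshape(()))
```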